123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- import asyncio
- import os
- from unittest import mock
- from asgiref.sync import async_to_sync, iscoroutinefunction
- from django.core.cache import DEFAULT_CACHE_ALIAS, caches
- from django.core.exceptions import ImproperlyConfigured, SynchronousOnlyOperation
- from django.http import HttpResponse, HttpResponseNotAllowed
- from django.test import RequestFactory, SimpleTestCase
- from django.utils.asyncio import async_unsafe
- from django.views.generic.base import View
- from .models import SimpleModel
class CacheTest(SimpleTestCase):
    """Repeated async lookups of the default cache yield the same object."""

    def test_caches_local(self):
        @async_to_sync
        async def fetch_default_cache():
            # Resolve the default cache alias from inside a coroutine.
            return caches[DEFAULT_CACHE_ALIAS]

        first, second = fetch_default_cache(), fetch_default_cache()
        self.assertIs(first, second)
class DatabaseConnectionTest(SimpleTestCase):
    """Using a database connection from an async context is rejected."""

    async def test_get_async_connection(self):
        # Materialising the queryset triggers a synchronous query; inside this
        # coroutine it is expected to raise SynchronousOnlyOperation.
        with self.assertRaises(SynchronousOnlyOperation):
            queryset = SimpleModel.objects.all()
            list(queryset)
class AsyncUnsafeTest(SimpleTestCase):
    """
    The async_unsafe decorator should work correctly and raise
    SynchronousOnlyOperation with the expected message.
    """

    @async_unsafe
    def dangerous_method(self):
        # The body is trivial on purpose: only the decorator's guard is tested.
        return True

    async def test_async_unsafe(self):
        # Calling the guarded method from a coroutine must fail, and the error
        # must carry the explanatory message below.
        expected_message = (
            "You cannot call this from an async context - use a thread or "
            "sync_to_async."
        )
        with self.assertRaisesMessage(SynchronousOnlyOperation, expected_message):
            self.dangerous_method()

    @mock.patch.dict(os.environ, {"DJANGO_ALLOW_ASYNC_UNSAFE": "true"})
    @async_to_sync  # mock.patch() is not async-aware.
    async def test_async_unsafe_suppressed(self):
        # With the opt-out environment variable set, the guard must not fire.
        try:
            self.dangerous_method()
        except SynchronousOnlyOperation:
            self.fail("SynchronousOnlyOperation should not be raised.")
class SyncView(View):
    """Minimal view exposing a single synchronous GET handler."""

    def get(self, request, *args, **kwargs):
        greeting = "Hello (sync) world!"
        return HttpResponse(greeting)
class AsyncView(View):
    """Minimal view exposing a single coroutine (async def) GET handler."""

    async def get(self, request, *args, **kwargs):
        greeting = "Hello (async) world!"
        return HttpResponse(greeting)
class ViewTests(SimpleTestCase):
    """Sync/async detection and built-in handlers on class-based views."""

    # Each entry pairs a view class with whether it is expected to be async.
    _view_cases = (
        (SyncView, False),
        (AsyncView, True),
    )

    def _assert_handler_response(self, invoke, expected_type):
        """
        Call ``invoke`` on a fresh instance of every view in ``_view_cases``.

        The returned value must be a coroutine exactly when the view is async.
        Coroutines are driven to completion with asyncio.run() before the
        result is checked to be an instance of ``expected_type``.
        """
        for view_cls, is_coroutine in self._view_cases:
            with self.subTest(view_cls=view_cls, is_coroutine=is_coroutine):
                result = invoke(view_cls())
                self.assertIs(asyncio.iscoroutine(result), is_coroutine)
                if is_coroutine:
                    result = asyncio.run(result)
                self.assertIsInstance(result, expected_type)

    def test_views_are_correctly_marked(self):
        for view_cls, is_async in self._view_cases:
            with self.subTest(view_cls=view_cls, is_async=is_async):
                self.assertIs(view_cls.view_is_async, is_async)
                view_func = view_cls.as_view()
                self.assertIs(iscoroutinefunction(view_func), is_async)

    def test_mixed_views_raise_error(self):
        class MixedView(View):
            def get(self, request, *args, **kwargs):
                return HttpResponse("Hello (mixed) world!")

            async def post(self, request, *args, **kwargs):
                return HttpResponse("Hello (mixed) world!")

        # Mixing a sync GET with an async POST must be rejected by as_view().
        expected_message = (
            f"{MixedView.__qualname__} HTTP handlers must either be all sync or all "
            "async."
        )
        with self.assertRaisesMessage(ImproperlyConfigured, expected_message):
            MixedView.as_view()

    def test_options_handler_responds_correctly(self):
        # options() is called without a request object (None is passed).
        self._assert_handler_response(
            lambda view: view.options(None),
            HttpResponse,
        )

    def test_http_method_not_allowed_responds_correctly(self):
        factory = RequestFactory()
        self._assert_handler_response(
            lambda view: view.http_method_not_allowed(factory.post("/")),
            HttpResponseNotAllowed,
        )

    def test_base_view_class_is_sync(self):
        """
        The base View, and therefore any subclass that defines no handlers,
        is treated as sync.
        """
        self.assertIs(View.view_is_async, False)
|