123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- import threading
- from asgiref.sync import async_to_sync, iscoroutinefunction
- from django.contrib.admindocs.middleware import XViewMiddleware
- from django.contrib.auth.middleware import (
- AuthenticationMiddleware,
- RemoteUserMiddleware,
- )
- from django.contrib.flatpages.middleware import FlatpageFallbackMiddleware
- from django.contrib.messages.middleware import MessageMiddleware
- from django.contrib.redirects.middleware import RedirectFallbackMiddleware
- from django.contrib.sessions.middleware import SessionMiddleware
- from django.contrib.sites.middleware import CurrentSiteMiddleware
- from django.db import connection
- from django.http.request import HttpRequest
- from django.http.response import HttpResponse
- from django.middleware.cache import (
- CacheMiddleware,
- FetchFromCacheMiddleware,
- UpdateCacheMiddleware,
- )
- from django.middleware.clickjacking import XFrameOptionsMiddleware
- from django.middleware.common import BrokenLinkEmailsMiddleware, CommonMiddleware
- from django.middleware.csrf import CsrfViewMiddleware
- from django.middleware.gzip import GZipMiddleware
- from django.middleware.http import ConditionalGetMiddleware
- from django.middleware.locale import LocaleMiddleware
- from django.middleware.security import SecurityMiddleware
- from django.test import SimpleTestCase
- from django.utils.deprecation import MiddlewareMixin
- class MiddlewareMixinTests(SimpleTestCase):
- middlewares = [
- AuthenticationMiddleware,
- BrokenLinkEmailsMiddleware,
- CacheMiddleware,
- CommonMiddleware,
- ConditionalGetMiddleware,
- CsrfViewMiddleware,
- CurrentSiteMiddleware,
- FetchFromCacheMiddleware,
- FlatpageFallbackMiddleware,
- GZipMiddleware,
- LocaleMiddleware,
- MessageMiddleware,
- RedirectFallbackMiddleware,
- RemoteUserMiddleware,
- SecurityMiddleware,
- SessionMiddleware,
- UpdateCacheMiddleware,
- XFrameOptionsMiddleware,
- XViewMiddleware,
- ]
- def test_repr(self):
- class GetResponse:
- def __call__(self):
- return HttpResponse()
- def get_response():
- return HttpResponse()
- self.assertEqual(
- repr(MiddlewareMixin(GetResponse())),
- "<MiddlewareMixin get_response=GetResponse>",
- )
- self.assertEqual(
- repr(MiddlewareMixin(get_response)),
- "<MiddlewareMixin get_response="
- "MiddlewareMixinTests.test_repr.<locals>.get_response>",
- )
- self.assertEqual(
- repr(CsrfViewMiddleware(GetResponse())),
- "<CsrfViewMiddleware get_response=GetResponse>",
- )
- self.assertEqual(
- repr(CsrfViewMiddleware(get_response)),
- "<CsrfViewMiddleware get_response="
- "MiddlewareMixinTests.test_repr.<locals>.get_response>",
- )
- def test_passing_explicit_none(self):
- msg = "get_response must be provided."
- for middleware in self.middlewares:
- with self.subTest(middleware=middleware):
- with self.assertRaisesMessage(ValueError, msg):
- middleware(None)
- def test_coroutine(self):
- async def async_get_response(request):
- return HttpResponse()
- def sync_get_response(request):
- return HttpResponse()
- for middleware in self.middlewares:
- with self.subTest(middleware=middleware.__qualname__):
- # Middleware appears as coroutine if get_function is
- # a coroutine.
- middleware_instance = middleware(async_get_response)
- self.assertIs(iscoroutinefunction(middleware_instance), True)
- # Middleware doesn't appear as coroutine if get_function is not
- # a coroutine.
- middleware_instance = middleware(sync_get_response)
- self.assertIs(iscoroutinefunction(middleware_instance), False)
- def test_sync_to_async_uses_base_thread_and_connection(self):
- """
- The process_request() and process_response() hooks must be called with
- the sync_to_async thread_sensitive flag enabled, so that database
- operations use the correct thread and connection.
- """
- def request_lifecycle():
- """Fake request_started/request_finished."""
- return (threading.get_ident(), id(connection))
- async def get_response(self):
- return HttpResponse()
- class SimpleMiddleWare(MiddlewareMixin):
- def process_request(self, request):
- request.thread_and_connection = request_lifecycle()
- def process_response(self, request, response):
- response.thread_and_connection = request_lifecycle()
- return response
- threads_and_connections = []
- threads_and_connections.append(request_lifecycle())
- request = HttpRequest()
- response = async_to_sync(SimpleMiddleWare(get_response))(request)
- threads_and_connections.append(request.thread_and_connection)
- threads_and_connections.append(response.thread_and_connection)
- threads_and_connections.append(request_lifecycle())
- self.assertEqual(len(threads_and_connections), 4)
- self.assertEqual(len(set(threads_and_connections)), 1)
|