Browse Source

Refs #31949 -- Made make_middleware_decorator to work with async functions.

Ben Lomax 1 year ago
parent
commit
74f7deec9e

+ 46 - 11
django/utils/decorators.py

@@ -2,6 +2,8 @@
 
 from functools import partial, update_wrapper, wraps
 
+from asgiref.sync import iscoroutinefunction
+
 
 class classonlymethod(classmethod):
     def __get__(self, instance, cls=None):
@@ -120,8 +122,7 @@ def make_middleware_decorator(middleware_class):
         def _decorator(view_func):
             middleware = middleware_class(view_func, *m_args, **m_kwargs)
 
-            @wraps(view_func)
-            def _wrapper_view(request, *args, **kwargs):
+            def _pre_process_request(request, *args, **kwargs):
                 if hasattr(middleware, "process_request"):
                     result = middleware.process_request(request)
                     if result is not None:
@@ -130,14 +131,16 @@ def make_middleware_decorator(middleware_class):
                     result = middleware.process_view(request, view_func, args, kwargs)
                     if result is not None:
                         return result
-                try:
-                    response = view_func(request, *args, **kwargs)
-                except Exception as e:
-                    if hasattr(middleware, "process_exception"):
-                        result = middleware.process_exception(request, e)
-                        if result is not None:
-                            return result
-                    raise
+                return None
+
+            def _process_exception(request, exception):
+                if hasattr(middleware, "process_exception"):
+                    result = middleware.process_exception(request, exception)
+                    if result is not None:
+                        return result
+                raise
+
+            def _post_process_request(request, response):
                 if hasattr(response, "render") and callable(response.render):
                     if hasattr(middleware, "process_template_response"):
                         response = middleware.process_template_response(
@@ -156,7 +159,39 @@ def make_middleware_decorator(middleware_class):
                         return middleware.process_response(request, response)
                 return response
 
-            return _wrapper_view
+            if iscoroutinefunction(view_func):
+
+                async def _view_wrapper(request, *args, **kwargs):
+                    result = _pre_process_request(request, *args, **kwargs)
+                    if result is not None:
+                        return result
+
+                    try:
+                        response = await view_func(request, *args, **kwargs)
+                    except Exception as e:
+                        result = _process_exception(request, e)
+                        if result is not None:
+                            return result
+
+                    return _post_process_request(request, response)
+
+            else:
+
+                def _view_wrapper(request, *args, **kwargs):
+                    result = _pre_process_request(request, *args, **kwargs)
+                    if result is not None:
+                        return result
+
+                    try:
+                        response = view_func(request, *args, **kwargs)
+                    except Exception as e:
+                        result = _process_exception(request, e)
+                        if result is not None:
+                            return result
+
+                    return _post_process_request(request, response)
+
+            return wraps(view_func)(_view_wrapper)
 
         return _decorator
 

+ 12 - 0
docs/ref/csrf.txt

@@ -170,6 +170,10 @@ class-based views<decorating-class-based-views>`.
             # ...
             return render(request, "a_template.html", c)
 
+    .. versionchanged:: 5.0
+
+        Support for wrapping asynchronous view functions was added.
+
 .. function:: requires_csrf_token(view)
 
     Normally the :ttag:`csrf_token` template tag will not work if
@@ -190,10 +194,18 @@ class-based views<decorating-class-based-views>`.
             # ...
             return render(request, "a_template.html", c)
 
+    .. versionchanged:: 5.0
+
+        Support for wrapping asynchronous view functions was added.
+
 .. function:: ensure_csrf_cookie(view)
 
     This decorator forces a view to send the CSRF cookie.
 
+    .. versionchanged:: 5.0
+
+        Support for wrapping asynchronous view functions was added.
+
 Settings
 ========
 

+ 5 - 0
docs/releases/5.0.txt

@@ -322,9 +322,14 @@ Decorators
   * :func:`~django.views.decorators.cache.never_cache`
   * :func:`~django.views.decorators.common.no_append_slash`
   * :func:`~django.views.decorators.csrf.csrf_exempt`
+  * :func:`~django.views.decorators.csrf.csrf_protect`
+  * :func:`~django.views.decorators.csrf.ensure_csrf_cookie`
+  * :func:`~django.views.decorators.csrf.requires_csrf_token`
   * :func:`~django.views.decorators.debug.sensitive_variables`
   * :func:`~django.views.decorators.debug.sensitive_post_parameters`
+  * :func:`~django.views.decorators.gzip.gzip_page`
   * :func:`~django.views.decorators.http.condition`
+  * ``conditional_page()``
   * :func:`~django.views.decorators.http.etag`
   * :func:`~django.views.decorators.http.last_modified`
   * :func:`~django.views.decorators.http.require_http_methods`

+ 5 - 0
docs/topics/async.txt

@@ -85,9 +85,14 @@ view functions:
 * :func:`~django.views.decorators.cache.never_cache`
 * :func:`~django.views.decorators.common.no_append_slash`
 * :func:`~django.views.decorators.csrf.csrf_exempt`
+* :func:`~django.views.decorators.csrf.csrf_protect`
+* :func:`~django.views.decorators.csrf.ensure_csrf_cookie`
+* :func:`~django.views.decorators.csrf.requires_csrf_token`
 * :func:`~django.views.decorators.debug.sensitive_variables`
 * :func:`~django.views.decorators.debug.sensitive_post_parameters`
+* :func:`~django.views.decorators.gzip.gzip_page`
 * :func:`~django.views.decorators.http.condition`
+* ``conditional_page()``
 * :func:`~django.views.decorators.http.etag`
 * :func:`~django.views.decorators.http.last_modified`
 * :func:`~django.views.decorators.http.require_http_methods`

+ 4 - 0
docs/topics/http/decorators.txt

@@ -105,6 +105,10 @@ compression on a per-view basis.
     It sets the ``Vary`` header accordingly, so that caches will base their
     storage on the ``Accept-Encoding`` header.
 
+    .. versionchanged:: 5.0
+
+        Support for wrapping asynchronous view functions was added.
+
 .. module:: django.views.decorators.vary
 
 Vary headers

+ 87 - 0
tests/decorators/test_csrf.py

@@ -24,6 +24,20 @@ class CsrfTestMixin:
 
 
 class CsrfProtectTests(CsrfTestMixin, SimpleTestCase):
+    def test_wrapped_sync_function_is_not_coroutine_function(self):
+        def sync_view(request):
+            return HttpResponse()
+
+        wrapped_view = csrf_protect(sync_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), False)
+
+    def test_wrapped_async_function_is_coroutine_function(self):
+        async def async_view(request):
+            return HttpResponse()
+
+        wrapped_view = csrf_protect(async_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), True)
+
     def test_csrf_protect_decorator(self):
         @csrf_protect
         def sync_view(request):
@@ -39,8 +53,37 @@ class CsrfProtectTests(CsrfTestMixin, SimpleTestCase):
             response = sync_view(request)
             self.assertEqual(response.status_code, 403)
 
+    async def test_csrf_protect_decorator_async_view(self):
+        @csrf_protect
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.get_request()
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+        self.assertIs(request.csrf_processing_done, True)
+
+        with self.assertLogs("django.security.csrf", "WARNING"):
+            request = self.get_request(token=None)
+            response = await async_view(request)
+            self.assertEqual(response.status_code, 403)
+
 
 class RequiresCsrfTokenTests(CsrfTestMixin, SimpleTestCase):
+    def test_wrapped_sync_function_is_not_coroutine_function(self):
+        def sync_view(request):
+            return HttpResponse()
+
+        wrapped_view = requires_csrf_token(sync_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), False)
+
+    def test_wrapped_async_function_is_coroutine_function(self):
+        async def async_view(request):
+            return HttpResponse()
+
+        wrapped_view = requires_csrf_token(async_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), True)
+
     def test_requires_csrf_token_decorator(self):
         @requires_csrf_token
         def sync_view(request):
@@ -56,8 +99,37 @@ class RequiresCsrfTokenTests(CsrfTestMixin, SimpleTestCase):
             response = sync_view(request)
             self.assertEqual(response.status_code, 200)
 
+    async def test_requires_csrf_token_decorator_async_view(self):
+        @requires_csrf_token
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.get_request()
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+        self.assertIs(request.csrf_processing_done, True)
+
+        with self.assertNoLogs("django.security.csrf", "WARNING"):
+            request = self.get_request(token=None)
+            response = await async_view(request)
+            self.assertEqual(response.status_code, 200)
+
 
 class EnsureCsrfCookieTests(CsrfTestMixin, SimpleTestCase):
+    def test_wrapped_sync_function_is_not_coroutine_function(self):
+        def sync_view(request):
+            return HttpResponse()
+
+        wrapped_view = ensure_csrf_cookie(sync_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), False)
+
+    def test_wrapped_async_function_is_coroutine_function(self):
+        async def async_view(request):
+            return HttpResponse()
+
+        wrapped_view = ensure_csrf_cookie(async_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), True)
+
     def test_ensure_csrf_cookie_decorator(self):
         @ensure_csrf_cookie
         def sync_view(request):
@@ -73,6 +145,21 @@ class EnsureCsrfCookieTests(CsrfTestMixin, SimpleTestCase):
             response = sync_view(request)
             self.assertEqual(response.status_code, 200)
 
+    async def test_ensure_csrf_cookie_decorator_async_view(self):
+        @ensure_csrf_cookie
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.get_request()
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+        self.assertIs(request.csrf_processing_done, True)
+
+        with self.assertNoLogs("django.security.csrf", "WARNING"):
+            request = self.get_request(token=None)
+            response = await async_view(request)
+            self.assertEqual(response.status_code, 200)
+
 
 class CsrfExemptTests(SimpleTestCase):
     def test_wrapped_sync_function_is_not_coroutine_function(self):

+ 27 - 0
tests/decorators/test_gzip.py

@@ -1,3 +1,5 @@
+from asgiref.sync import iscoroutinefunction
+
 from django.http import HttpRequest, HttpResponse
 from django.test import SimpleTestCase
 from django.views.decorators.gzip import gzip_page
@@ -7,6 +9,20 @@ class GzipPageTests(SimpleTestCase):
     # Gzip ignores content that is too short.
     content = "Content " * 100
 
+    def test_wrapped_sync_function_is_not_coroutine_function(self):
+        def sync_view(request):
+            return HttpResponse()
+
+        wrapped_view = gzip_page(sync_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), False)
+
+    def test_wrapped_async_function_is_coroutine_function(self):
+        async def async_view(request):
+            return HttpResponse()
+
+        wrapped_view = gzip_page(async_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), True)
+
     def test_gzip_page_decorator(self):
         @gzip_page
         def sync_view(request):
@@ -17,3 +33,14 @@ class GzipPageTests(SimpleTestCase):
         response = sync_view(request)
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.get("Content-Encoding"), "gzip")
+
+    async def test_gzip_page_decorator_async_view(self):
+        @gzip_page
+        async def async_view(request):
+            return HttpResponse(content=self.content)
+
+        request = HttpRequest()
+        request.META["HTTP_ACCEPT_ENCODING"] = "gzip"
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.get("Content-Encoding"), "gzip")

+ 28 - 0
tests/decorators/test_http.py

@@ -163,6 +163,20 @@ class ConditionDecoratorTest(SimpleTestCase):
 
 
 class ConditionalPageTests(SimpleTestCase):
+    def test_wrapped_sync_function_is_not_coroutine_function(self):
+        def sync_view(request):
+            return HttpResponse()
+
+        wrapped_view = conditional_page(sync_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), False)
+
+    def test_wrapped_async_function_is_coroutine_function(self):
+        async def async_view(request):
+            return HttpResponse()
+
+        wrapped_view = conditional_page(async_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), True)
+
     def test_conditional_page_decorator_successful(self):
         @conditional_page
         def sync_view(request):
@@ -176,3 +190,17 @@ class ConditionalPageTests(SimpleTestCase):
         response = sync_view(request)
         self.assertEqual(response.status_code, 200)
         self.assertIsNotNone(response.get("Etag"))
+
+    async def test_conditional_page_decorator_successful_async_view(self):
+        @conditional_page
+        async def async_view(request):
+            response = HttpResponse()
+            response.content = b"test"
+            response["Cache-Control"] = "public"
+            return response
+
+        request = HttpRequest()
+        request.method = "GET"
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+        self.assertIsNotNone(response.get("Etag"))