Browse Source

Refs #31949 -- Made http decorators to work with async functions.

th3nn3ss 1 year ago
parent
commit
3152f9de47

+ 59 - 19
django/views/decorators/http.py

@@ -4,6 +4,8 @@ Decorators for views based on HTTP headers.
 import datetime
 from functools import wraps
 
+from asgiref.sync import iscoroutinefunction
+
 from django.http import HttpResponseNotAllowed
 from django.middleware.http import ConditionalGetMiddleware
 from django.utils import timezone
@@ -28,19 +30,37 @@ def require_http_methods(request_method_list):
     """
 
     def decorator(func):
-        @wraps(func)
-        def inner(request, *args, **kwargs):
-            if request.method not in request_method_list:
-                response = HttpResponseNotAllowed(request_method_list)
-                log_response(
-                    "Method Not Allowed (%s): %s",
-                    request.method,
-                    request.path,
-                    response=response,
-                    request=request,
-                )
-                return response
-            return func(request, *args, **kwargs)
+        if iscoroutinefunction(func):
+
+            @wraps(func)
+            async def inner(request, *args, **kwargs):
+                if request.method not in request_method_list:
+                    response = HttpResponseNotAllowed(request_method_list)
+                    log_response(
+                        "Method Not Allowed (%s): %s",
+                        request.method,
+                        request.path,
+                        response=response,
+                        request=request,
+                    )
+                    return response
+                return await func(request, *args, **kwargs)
+
+        else:
+
+            @wraps(func)
+            def inner(request, *args, **kwargs):
+                if request.method not in request_method_list:
+                    response = HttpResponseNotAllowed(request_method_list)
+                    log_response(
+                        "Method Not Allowed (%s): %s",
+                        request.method,
+                        request.path,
+                        response=response,
+                        request=request,
+                    )
+                    return response
+                return func(request, *args, **kwargs)
 
         return inner
 
@@ -83,8 +103,7 @@ def condition(etag_func=None, last_modified_func=None):
     """
 
     def decorator(func):
-        @wraps(func)
-        def inner(request, *args, **kwargs):
+        def _pre_process_request(request, *args, **kwargs):
             # Compute values (if any) for the requested resource.
             res_last_modified = None
             if last_modified_func:
@@ -100,10 +119,9 @@ def condition(etag_func=None, last_modified_func=None):
                 etag=res_etag,
                 last_modified=res_last_modified,
             )
+            return response, res_etag, res_last_modified
 
-            if response is None:
-                response = func(request, *args, **kwargs)
-
+        def _post_process_request(request, response, res_etag, res_last_modified):
             # Set relevant headers on the response if they don't already exist
             # and if the request method is safe.
             if request.method in ("GET", "HEAD"):
@@ -112,7 +130,29 @@ def condition(etag_func=None, last_modified_func=None):
                 if res_etag:
                     response.headers.setdefault("ETag", res_etag)
 
-            return response
+        if iscoroutinefunction(func):
+
+            @wraps(func)
+            async def inner(request, *args, **kwargs):
+                response, res_etag, res_last_modified = _pre_process_request(
+                    request, *args, **kwargs
+                )
+                if response is None:
+                    response = await func(request, *args, **kwargs)
+                _post_process_request(request, response, res_etag, res_last_modified)
+                return response
+
+        else:
+
+            @wraps(func)
+            def inner(request, *args, **kwargs):
+                response, res_etag, res_last_modified = _pre_process_request(
+                    request, *args, **kwargs
+                )
+                if response is None:
+                    response = func(request, *args, **kwargs)
+                _post_process_request(request, response, res_etag, res_last_modified)
+                return response
 
         return inner
 

+ 7 - 0
docs/releases/5.0.txt

@@ -243,6 +243,13 @@ Decorators
   * :func:`~django.views.decorators.common.no_append_slash`
   * :func:`~django.views.decorators.debug.sensitive_variables`
   * :func:`~django.views.decorators.debug.sensitive_post_parameters`
+  * :func:`~django.views.decorators.http.condition`
+  * :func:`~django.views.decorators.http.etag`
+  * :func:`~django.views.decorators.http.last_modified`
+  * :func:`~django.views.decorators.http.require_http_methods`
+  * :func:`~django.views.decorators.http.require_GET`
+  * :func:`~django.views.decorators.http.require_POST`
+  * :func:`~django.views.decorators.http.require_safe`
   * ``xframe_options_deny()``
   * ``xframe_options_sameorigin()``
   * ``xframe_options_exempt()``

+ 7 - 0
docs/topics/async.txt

@@ -84,6 +84,13 @@ view functions:
 * :func:`~django.views.decorators.cache.cache_control`
 * :func:`~django.views.decorators.cache.never_cache`
 * :func:`~django.views.decorators.common.no_append_slash`
+* :func:`~django.views.decorators.http.condition`
+* :func:`~django.views.decorators.http.etag`
+* :func:`~django.views.decorators.http.last_modified`
+* :func:`~django.views.decorators.http.require_http_methods`
+* :func:`~django.views.decorators.http.require_GET`
+* :func:`~django.views.decorators.http.require_POST`
+* :func:`~django.views.decorators.http.require_safe`
 * ``xframe_options_deny()``
 * ``xframe_options_sameorigin()``
 * ``xframe_options_exempt()``

+ 20 - 0
docs/topics/http/decorators.txt

@@ -33,14 +33,26 @@ a :class:`django.http.HttpResponseNotAllowed` if the conditions are not met.
 
     Note that request methods should be in uppercase.
 
+    .. versionchanged:: 5.0
+
+        Support for wrapping asynchronous view functions was added.
+
 .. function:: require_GET()
 
     Decorator to require that a view only accepts the GET method.
 
+    .. versionchanged:: 5.0
+
+        Support for wrapping asynchronous view functions was added.
+
 .. function:: require_POST()
 
     Decorator to require that a view only accepts the POST method.
 
+    .. versionchanged:: 5.0
+
+        Support for wrapping asynchronous view functions was added.
+
 .. function:: require_safe()
 
     Decorator to require that a view only accepts the GET and HEAD methods.
@@ -55,6 +67,10 @@ a :class:`django.http.HttpResponseNotAllowed` if the conditions are not met.
         such as link checkers, rely on HEAD requests, you might prefer
         using ``require_safe`` instead of ``require_GET``.
 
+    .. versionchanged:: 5.0
+
+        Support for wrapping asynchronous view functions was added.
+
 Conditional view processing
 ===========================
 
@@ -71,6 +87,10 @@ control caching behavior on particular views.
     headers; see
     :doc:`conditional view processing </topics/conditional-view-processing>`.
 
+    .. versionchanged:: 5.0
+
+        Support for wrapping asynchronous view functions was added.
+
 .. module:: django.views.decorators.gzip
 
 GZip compression

+ 87 - 0
tests/decorators/test_http.py

@@ -1,11 +1,27 @@
 import datetime
 
+from asgiref.sync import iscoroutinefunction
+
 from django.http import HttpRequest, HttpResponse, HttpResponseNotAllowed
 from django.test import SimpleTestCase
 from django.views.decorators.http import condition, require_http_methods, require_safe
 
 
 class RequireHttpMethodsTest(SimpleTestCase):
+    def test_wrapped_sync_function_is_not_coroutine_function(self):
+        def sync_view(request):
+            return HttpResponse()
+
+        wrapped_view = require_http_methods(["GET"])(sync_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), False)
+
+    def test_wrapped_async_function_is_coroutine_function(self):
+        async def async_view(request):
+            return HttpResponse()
+
+        wrapped_view = require_http_methods(["GET"])(async_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), True)
+
     def test_require_http_methods_methods(self):
         @require_http_methods(["GET", "PUT"])
         def my_view(request):
@@ -23,6 +39,23 @@ class RequireHttpMethodsTest(SimpleTestCase):
         request.method = "DELETE"
         self.assertIsInstance(my_view(request), HttpResponseNotAllowed)
 
+    async def test_require_http_methods_methods_async_view(self):
+        @require_http_methods(["GET", "PUT"])
+        async def my_view(request):
+            return HttpResponse("OK")
+
+        request = HttpRequest()
+        request.method = "GET"
+        self.assertIsInstance(await my_view(request), HttpResponse)
+        request.method = "PUT"
+        self.assertIsInstance(await my_view(request), HttpResponse)
+        request.method = "HEAD"
+        self.assertIsInstance(await my_view(request), HttpResponseNotAllowed)
+        request.method = "POST"
+        self.assertIsInstance(await my_view(request), HttpResponseNotAllowed)
+        request.method = "DELETE"
+        self.assertIsInstance(await my_view(request), HttpResponseNotAllowed)
+
 
 class RequireSafeDecoratorTest(SimpleTestCase):
     def test_require_safe_accepts_only_safe_methods(self):
@@ -42,6 +75,23 @@ class RequireSafeDecoratorTest(SimpleTestCase):
         request.method = "DELETE"
         self.assertIsInstance(my_safe_view(request), HttpResponseNotAllowed)
 
+    async def test_require_safe_accepts_only_safe_methods_async_view(self):
+        @require_safe
+        async def async_view(request):
+            return HttpResponse("OK")
+
+        request = HttpRequest()
+        request.method = "GET"
+        self.assertIsInstance(await async_view(request), HttpResponse)
+        request.method = "HEAD"
+        self.assertIsInstance(await async_view(request), HttpResponse)
+        request.method = "POST"
+        self.assertIsInstance(await async_view(request), HttpResponseNotAllowed)
+        request.method = "PUT"
+        self.assertIsInstance(await async_view(request), HttpResponseNotAllowed)
+        request.method = "DELETE"
+        self.assertIsInstance(await async_view(request), HttpResponseNotAllowed)
+
 
 class ConditionDecoratorTest(SimpleTestCase):
     def etag_func(request, *args, **kwargs):
@@ -50,6 +100,24 @@ class ConditionDecoratorTest(SimpleTestCase):
     def latest_entry(request, *args, **kwargs):
         return datetime.datetime(2023, 1, 2, 23, 21, 47)
 
+    def test_wrapped_sync_function_is_not_coroutine_function(self):
+        def sync_view(request):
+            return HttpResponse()
+
+        wrapped_view = condition(
+            etag_func=self.etag_func, last_modified_func=self.latest_entry
+        )(sync_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), False)
+
+    def test_wrapped_async_function_is_coroutine_function(self):
+        async def async_view(request):
+            return HttpResponse()
+
+        wrapped_view = condition(
+            etag_func=self.etag_func, last_modified_func=self.latest_entry
+        )(async_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), True)
+
     def test_condition_decorator(self):
         @condition(
             etag_func=self.etag_func,
@@ -68,3 +136,22 @@ class ConditionDecoratorTest(SimpleTestCase):
             response.headers["Last-Modified"],
             "Mon, 02 Jan 2023 23:21:47 GMT",
         )
+
+    async def test_condition_decorator_async_view(self):
+        @condition(
+            etag_func=self.etag_func,
+            last_modified_func=self.latest_entry,
+        )
+        async def async_view(request):
+            return HttpResponse()
+
+        request = HttpRequest()
+        request.method = "GET"
+        response = await async_view(request)
+
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.headers["ETag"], '"b4246ffc4f62314ca13147c9d4f76974"')
+        self.assertEqual(
+            response.headers["Last-Modified"],
+            "Mon, 02 Jan 2023 23:21:47 GMT",
+        )