Przeglądaj źródła

Fixed #35030 -- Made django.contrib.auth decorators to work with async functions.

Dingning 1 rok temu
rodzic
commit
549320946d

+ 61 - 19
django/contrib/auth/decorators.py

@@ -1,6 +1,9 @@
+import asyncio
 from functools import wraps
 from urllib.parse import urlparse
 
+from asgiref.sync import async_to_sync, sync_to_async
+
 from django.conf import settings
 from django.contrib.auth import REDIRECT_FIELD_NAME
 from django.core.exceptions import PermissionDenied
@@ -17,10 +20,7 @@ def user_passes_test(
     """
 
     def decorator(view_func):
-        @wraps(view_func)
-        def _wrapper_view(request, *args, **kwargs):
-            if test_func(request.user):
-                return view_func(request, *args, **kwargs)
+        def _redirect_to_login(request):
             path = request.build_absolute_uri()
             resolved_login_url = resolve_url(login_url or settings.LOGIN_URL)
             # If the login url is the same scheme and net location then just
@@ -35,7 +35,32 @@ def user_passes_test(
 
             return redirect_to_login(path, resolved_login_url, redirect_field_name)
 
-        return _wrapper_view
+        if asyncio.iscoroutinefunction(view_func):
+
+            async def _view_wrapper(request, *args, **kwargs):
+                auser = await request.auser()
+                if asyncio.iscoroutinefunction(test_func):
+                    test_pass = await test_func(auser)
+                else:
+                    test_pass = await sync_to_async(test_func)(auser)
+
+                if test_pass:
+                    return await view_func(request, *args, **kwargs)
+                return _redirect_to_login(request)
+
+        else:
+
+            def _view_wrapper(request, *args, **kwargs):
+                if asyncio.iscoroutinefunction(test_func):
+                    test_pass = async_to_sync(test_func)(request.user)
+                else:
+                    test_pass = test_func(request.user)
+
+                if test_pass:
+                    return view_func(request, *args, **kwargs)
+                return _redirect_to_login(request)
+
+        return wraps(view_func)(_view_wrapper)
 
     return decorator
 
@@ -64,19 +89,36 @@ def permission_required(perm, login_url=None, raise_exception=False):
     If the raise_exception parameter is given the PermissionDenied exception
     is raised.
     """
+    if isinstance(perm, str):
+        perms = (perm,)
+    else:
+        perms = perm
+
+    def decorator(view_func):
+        if asyncio.iscoroutinefunction(view_func):
+
+            async def check_perms(user):
+                # First check if the user has the permission (even anon users).
+                if await sync_to_async(user.has_perms)(perms):
+                    return True
+                # In case the 403 handler should be called raise the exception.
+                if raise_exception:
+                    raise PermissionDenied
+                # As the last resort, show the login form.
+                return False
 
-    def check_perms(user):
-        if isinstance(perm, str):
-            perms = (perm,)
         else:
-            perms = perm
-        # First check if the user has the permission (even anon users)
-        if user.has_perms(perms):
-            return True
-        # In case the 403 handler should be called raise the exception
-        if raise_exception:
-            raise PermissionDenied
-        # As the last resort, show the login form
-        return False
-
-    return user_passes_test(check_perms, login_url=login_url)
+
+            def check_perms(user):
+                # First check if the user has the permission (even anon users).
+                if user.has_perms(perms):
+                    return True
+                # In case the 403 handler should be called raise the exception.
+                if raise_exception:
+                    raise PermissionDenied
+                # As the last resort, show the login form.
+                return False
+
+        return user_passes_test(check_perms, login_url=login_url)(view_func)
+
+    return decorator

+ 5 - 0
docs/releases/5.1.txt

@@ -52,6 +52,11 @@ Minor features
   form save. This is now available in the admin when visiting the user creation
   and password change pages.
 
+* :func:`~.django.contrib.auth.decorators.login_required`,
+  :func:`~.django.contrib.auth.decorators.permission_required`, and
+  :func:`~.django.contrib.auth.decorators.user_passes_test` decorators now
+  support wrapping asynchronous view functions.
+
 :mod:`django.contrib.contenttypes`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

+ 13 - 0
docs/topics/auth/default.txt

@@ -617,6 +617,10 @@ The ``login_required`` decorator
     :func:`django.contrib.admin.views.decorators.staff_member_required`
     decorator a useful alternative to ``login_required()``.
 
+.. versionchanged:: 5.1
+
+    Support for wrapping asynchronous view functions was added.
+
 .. currentmodule:: django.contrib.auth.mixins
 
 The ``LoginRequiredMixin`` mixin
@@ -714,6 +718,11 @@ email in the desired domain and if not, redirects to the login page::
         @user_passes_test(email_check, login_url="/login/")
         def my_view(request): ...
 
+    .. versionchanged:: 5.1
+
+        Support for wrapping asynchronous view functions and using asynchronous
+        test callables was added.
+
 .. currentmodule:: django.contrib.auth.mixins
 
 .. class:: UserPassesTestMixin
@@ -818,6 +827,10 @@ The ``permission_required`` decorator
     ``redirect_authenticated_user=True`` and the logged-in user doesn't have
     all of the required permissions.
 
+.. versionchanged:: 5.1
+
+    Support for wrapping asynchronous view functions was added.
+
 .. currentmodule:: django.contrib.auth.mixins
 
 The ``PermissionRequiredMixin`` mixin

+ 209 - 0
tests/auth_tests/test_decorators.py

@@ -1,3 +1,7 @@
+from asyncio import iscoroutinefunction
+
+from asgiref.sync import sync_to_async
+
 from django.conf import settings
 from django.contrib.auth import models
 from django.contrib.auth.decorators import (
@@ -19,6 +23,22 @@ class LoginRequiredTestCase(AuthViewsTestCase):
     Tests the login_required decorators
     """
 
+    factory = RequestFactory()
+
+    def test_wrapped_sync_function_is_not_coroutine_function(self):
+        def sync_view(request):
+            return HttpResponse()
+
+        wrapped_view = login_required(sync_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), False)
+
+    def test_wrapped_async_function_is_coroutine_function(self):
+        async def async_view(request):
+            return HttpResponse()
+
+        wrapped_view = login_required(async_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), True)
+
     def test_callable(self):
         """
         login_required is assignable to callable objects.
@@ -63,6 +83,35 @@ class LoginRequiredTestCase(AuthViewsTestCase):
             view_url="/login_required_login_url/", login_url="/somewhere/"
         )
 
+    async def test_login_required_async_view(self, login_url=None):
+        async def async_view(request):
+            return HttpResponse()
+
+        async def auser_anonymous():
+            return models.AnonymousUser()
+
+        async def auser():
+            return self.u1
+
+        if login_url is None:
+            async_view = login_required(async_view)
+            login_url = settings.LOGIN_URL
+        else:
+            async_view = login_required(async_view, login_url=login_url)
+
+        request = self.factory.get("/rand")
+        request.auser = auser_anonymous
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 302)
+        self.assertIn(login_url, response.url)
+
+        request.auser = auser
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+
+    async def test_login_required_next_url_async_view(self):
+        await self.test_login_required_async_view(login_url="/somewhere/")
+
 
 class PermissionsRequiredDecoratorTest(TestCase):
     """
@@ -80,6 +129,24 @@ class PermissionsRequiredDecoratorTest(TestCase):
         )
         cls.user.user_permissions.add(*perms)
 
+    @classmethod
+    async def auser(cls):
+        return cls.user
+
+    def test_wrapped_sync_function_is_not_coroutine_function(self):
+        def sync_view(request):
+            return HttpResponse()
+
+        wrapped_view = permission_required([])(sync_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), False)
+
+    def test_wrapped_async_function_is_coroutine_function(self):
+        async def async_view(request):
+            return HttpResponse()
+
+        wrapped_view = permission_required([])(async_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), True)
+
     def test_many_permissions_pass(self):
         @permission_required(
             ["auth_tests.add_customuser", "auth_tests.change_customuser"]
@@ -147,6 +214,73 @@ class PermissionsRequiredDecoratorTest(TestCase):
         with self.assertRaises(PermissionDenied):
             a_view(request)
 
+    async def test_many_permissions_pass_async_view(self):
+        @permission_required(
+            ["auth_tests.add_customuser", "auth_tests.change_customuser"]
+        )
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.auser = self.auser
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+
+    async def test_many_permissions_in_set_pass_async_view(self):
+        @permission_required(
+            {"auth_tests.add_customuser", "auth_tests.change_customuser"}
+        )
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.auser = self.auser
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+
+    async def test_single_permission_pass_async_view(self):
+        @permission_required("auth_tests.add_customuser")
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.auser = self.auser
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+
+    async def test_permissioned_denied_redirect_async_view(self):
+        @permission_required(
+            [
+                "auth_tests.add_customuser",
+                "auth_tests.change_customuser",
+                "nonexistent-permission",
+            ]
+        )
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.auser = self.auser
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 302)
+
+    async def test_permissioned_denied_exception_raised_async_view(self):
+        @permission_required(
+            [
+                "auth_tests.add_customuser",
+                "auth_tests.change_customuser",
+                "nonexistent-permission",
+            ],
+            raise_exception=True,
+        )
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.auser = self.auser
+        with self.assertRaises(PermissionDenied):
+            await async_view(request)
+
 
 class UserPassesTestDecoratorTest(TestCase):
     factory = RequestFactory()
@@ -162,6 +296,28 @@ class UserPassesTestDecoratorTest(TestCase):
         )
         cls.user_pass.user_permissions.add(*perms)
 
+    @classmethod
+    async def auser_pass(cls):
+        return cls.user_pass
+
+    @classmethod
+    async def auser_deny(cls):
+        return cls.user_deny
+
+    def test_wrapped_sync_function_is_not_coroutine_function(self):
+        def sync_view(request):
+            return HttpResponse()
+
+        wrapped_view = user_passes_test(lambda user: True)(sync_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), False)
+
+    def test_wrapped_async_function_is_coroutine_function(self):
+        async def async_view(request):
+            return HttpResponse()
+
+        wrapped_view = user_passes_test(lambda user: True)(async_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), True)
+
     def test_decorator(self):
         def sync_test_func(user):
             return bool(
@@ -180,3 +336,56 @@ class UserPassesTestDecoratorTest(TestCase):
         request.user = self.user_deny
         response = sync_view(request)
         self.assertEqual(response.status_code, 302)
+
+    def test_decorator_async_test_func(self):
+        async def async_test_func(user):
+            return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"])
+
+        @user_passes_test(async_test_func)
+        def sync_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.user = self.user_pass
+        response = sync_view(request)
+        self.assertEqual(response.status_code, 200)
+
+        request.user = self.user_deny
+        response = sync_view(request)
+        self.assertEqual(response.status_code, 302)
+
+    async def test_decorator_async_view(self):
+        def sync_test_func(user):
+            return bool(
+                models.Group.objects.filter(name__istartswith=user.username).exists()
+            )
+
+        @user_passes_test(sync_test_func)
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.auser = self.auser_pass
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+
+        request.auser = self.auser_deny
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 302)
+
+    async def test_decorator_async_view_async_test_func(self):
+        async def async_test_func(user):
+            return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"])
+
+        @user_passes_test(async_test_func)
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.auser = self.auser_pass
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+
+        request.auser = self.auser_deny
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 302)