浏览代码

Fixed #35030 -- Made django.contrib.auth decorators to work with async functions.

Dingning 1 年之前
父节点
当前提交
549320946d
共有 4 个文件被更改,包括 288 次插入19 次删除
  1. 61 19
      django/contrib/auth/decorators.py
  2. 5 0
      docs/releases/5.1.txt
  3. 13 0
      docs/topics/auth/default.txt
  4. 209 0
      tests/auth_tests/test_decorators.py

+ 61 - 19
django/contrib/auth/decorators.py

@@ -1,6 +1,9 @@
+import asyncio
 from functools import wraps
 from functools import wraps
 from urllib.parse import urlparse
 from urllib.parse import urlparse
 
 
+from asgiref.sync import async_to_sync, sync_to_async
+
 from django.conf import settings
 from django.conf import settings
 from django.contrib.auth import REDIRECT_FIELD_NAME
 from django.contrib.auth import REDIRECT_FIELD_NAME
 from django.core.exceptions import PermissionDenied
 from django.core.exceptions import PermissionDenied
@@ -17,10 +20,7 @@ def user_passes_test(
     """
     """
 
 
     def decorator(view_func):
     def decorator(view_func):
-        @wraps(view_func)
-        def _wrapper_view(request, *args, **kwargs):
-            if test_func(request.user):
-                return view_func(request, *args, **kwargs)
+        def _redirect_to_login(request):
             path = request.build_absolute_uri()
             path = request.build_absolute_uri()
             resolved_login_url = resolve_url(login_url or settings.LOGIN_URL)
             resolved_login_url = resolve_url(login_url or settings.LOGIN_URL)
             # If the login url is the same scheme and net location then just
             # If the login url is the same scheme and net location then just
@@ -35,7 +35,32 @@ def user_passes_test(
 
 
             return redirect_to_login(path, resolved_login_url, redirect_field_name)
             return redirect_to_login(path, resolved_login_url, redirect_field_name)
 
 
-        return _wrapper_view
+        if asyncio.iscoroutinefunction(view_func):
+
+            async def _view_wrapper(request, *args, **kwargs):
+                auser = await request.auser()
+                if asyncio.iscoroutinefunction(test_func):
+                    test_pass = await test_func(auser)
+                else:
+                    test_pass = await sync_to_async(test_func)(auser)
+
+                if test_pass:
+                    return await view_func(request, *args, **kwargs)
+                return _redirect_to_login(request)
+
+        else:
+
+            def _view_wrapper(request, *args, **kwargs):
+                if asyncio.iscoroutinefunction(test_func):
+                    test_pass = async_to_sync(test_func)(request.user)
+                else:
+                    test_pass = test_func(request.user)
+
+                if test_pass:
+                    return view_func(request, *args, **kwargs)
+                return _redirect_to_login(request)
+
+        return wraps(view_func)(_view_wrapper)
 
 
     return decorator
     return decorator
 
 
@@ -64,19 +89,36 @@ def permission_required(perm, login_url=None, raise_exception=False):
     If the raise_exception parameter is given the PermissionDenied exception
     If the raise_exception parameter is given the PermissionDenied exception
     is raised.
     is raised.
     """
     """
+    if isinstance(perm, str):
+        perms = (perm,)
+    else:
+        perms = perm
+
+    def decorator(view_func):
+        if asyncio.iscoroutinefunction(view_func):
+
+            async def check_perms(user):
+                # First check if the user has the permission (even anon users).
+                if await sync_to_async(user.has_perms)(perms):
+                    return True
+                # In case the 403 handler should be called raise the exception.
+                if raise_exception:
+                    raise PermissionDenied
+                # As the last resort, show the login form.
+                return False
 
 
-    def check_perms(user):
-        if isinstance(perm, str):
-            perms = (perm,)
         else:
         else:
-            perms = perm
-        # First check if the user has the permission (even anon users)
-        if user.has_perms(perms):
-            return True
-        # In case the 403 handler should be called raise the exception
-        if raise_exception:
-            raise PermissionDenied
-        # As the last resort, show the login form
-        return False
-
-    return user_passes_test(check_perms, login_url=login_url)
+
+            def check_perms(user):
+                # First check if the user has the permission (even anon users).
+                if user.has_perms(perms):
+                    return True
+                # In case the 403 handler should be called raise the exception.
+                if raise_exception:
+                    raise PermissionDenied
+                # As the last resort, show the login form.
+                return False
+
+        return user_passes_test(check_perms, login_url=login_url)(view_func)
+
+    return decorator

+ 5 - 0
docs/releases/5.1.txt

@@ -52,6 +52,11 @@ Minor features
   form save. This is now available in the admin when visiting the user creation
   form save. This is now available in the admin when visiting the user creation
   and password change pages.
   and password change pages.
 
 
+* :func:`~.django.contrib.auth.decorators.login_required`,
+  :func:`~.django.contrib.auth.decorators.permission_required`, and
+  :func:`~.django.contrib.auth.decorators.user_passes_test` decorators now
+  support wrapping asynchronous view functions.
+
 :mod:`django.contrib.contenttypes`
 :mod:`django.contrib.contenttypes`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 

+ 13 - 0
docs/topics/auth/default.txt

@@ -617,6 +617,10 @@ The ``login_required`` decorator
     :func:`django.contrib.admin.views.decorators.staff_member_required`
     :func:`django.contrib.admin.views.decorators.staff_member_required`
     decorator a useful alternative to ``login_required()``.
     decorator a useful alternative to ``login_required()``.
 
 
+.. versionchanged:: 5.1
+
+    Support for wrapping asynchronous view functions was added.
+
 .. currentmodule:: django.contrib.auth.mixins
 .. currentmodule:: django.contrib.auth.mixins
 
 
 The ``LoginRequiredMixin`` mixin
 The ``LoginRequiredMixin`` mixin
@@ -714,6 +718,11 @@ email in the desired domain and if not, redirects to the login page::
         @user_passes_test(email_check, login_url="/login/")
         @user_passes_test(email_check, login_url="/login/")
         def my_view(request): ...
         def my_view(request): ...
 
 
+    .. versionchanged:: 5.1
+
+        Support for wrapping asynchronous view functions and using asynchronous
+        test callables was added.
+
 .. currentmodule:: django.contrib.auth.mixins
 .. currentmodule:: django.contrib.auth.mixins
 
 
 .. class:: UserPassesTestMixin
 .. class:: UserPassesTestMixin
@@ -818,6 +827,10 @@ The ``permission_required`` decorator
     ``redirect_authenticated_user=True`` and the logged-in user doesn't have
     ``redirect_authenticated_user=True`` and the logged-in user doesn't have
     all of the required permissions.
     all of the required permissions.
 
 
+.. versionchanged:: 5.1
+
+    Support for wrapping asynchronous view functions was added.
+
 .. currentmodule:: django.contrib.auth.mixins
 .. currentmodule:: django.contrib.auth.mixins
 
 
 The ``PermissionRequiredMixin`` mixin
 The ``PermissionRequiredMixin`` mixin

+ 209 - 0
tests/auth_tests/test_decorators.py

@@ -1,3 +1,7 @@
+from asyncio import iscoroutinefunction
+
+from asgiref.sync import sync_to_async
+
 from django.conf import settings
 from django.conf import settings
 from django.contrib.auth import models
 from django.contrib.auth import models
 from django.contrib.auth.decorators import (
 from django.contrib.auth.decorators import (
@@ -19,6 +23,22 @@ class LoginRequiredTestCase(AuthViewsTestCase):
     Tests the login_required decorators
     Tests the login_required decorators
     """
     """
 
 
+    factory = RequestFactory()
+
+    def test_wrapped_sync_function_is_not_coroutine_function(self):
+        def sync_view(request):
+            return HttpResponse()
+
+        wrapped_view = login_required(sync_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), False)
+
+    def test_wrapped_async_function_is_coroutine_function(self):
+        async def async_view(request):
+            return HttpResponse()
+
+        wrapped_view = login_required(async_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), True)
+
     def test_callable(self):
     def test_callable(self):
         """
         """
         login_required is assignable to callable objects.
         login_required is assignable to callable objects.
@@ -63,6 +83,35 @@ class LoginRequiredTestCase(AuthViewsTestCase):
             view_url="/login_required_login_url/", login_url="/somewhere/"
             view_url="/login_required_login_url/", login_url="/somewhere/"
         )
         )
 
 
+    async def test_login_required_async_view(self, login_url=None):
+        async def async_view(request):
+            return HttpResponse()
+
+        async def auser_anonymous():
+            return models.AnonymousUser()
+
+        async def auser():
+            return self.u1
+
+        if login_url is None:
+            async_view = login_required(async_view)
+            login_url = settings.LOGIN_URL
+        else:
+            async_view = login_required(async_view, login_url=login_url)
+
+        request = self.factory.get("/rand")
+        request.auser = auser_anonymous
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 302)
+        self.assertIn(login_url, response.url)
+
+        request.auser = auser
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+
+    async def test_login_required_next_url_async_view(self):
+        await self.test_login_required_async_view(login_url="/somewhere/")
+
 
 
 class PermissionsRequiredDecoratorTest(TestCase):
 class PermissionsRequiredDecoratorTest(TestCase):
     """
     """
@@ -80,6 +129,24 @@ class PermissionsRequiredDecoratorTest(TestCase):
         )
         )
         cls.user.user_permissions.add(*perms)
         cls.user.user_permissions.add(*perms)
 
 
+    @classmethod
+    async def auser(cls):
+        return cls.user
+
+    def test_wrapped_sync_function_is_not_coroutine_function(self):
+        def sync_view(request):
+            return HttpResponse()
+
+        wrapped_view = permission_required([])(sync_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), False)
+
+    def test_wrapped_async_function_is_coroutine_function(self):
+        async def async_view(request):
+            return HttpResponse()
+
+        wrapped_view = permission_required([])(async_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), True)
+
     def test_many_permissions_pass(self):
     def test_many_permissions_pass(self):
         @permission_required(
         @permission_required(
             ["auth_tests.add_customuser", "auth_tests.change_customuser"]
             ["auth_tests.add_customuser", "auth_tests.change_customuser"]
@@ -147,6 +214,73 @@ class PermissionsRequiredDecoratorTest(TestCase):
         with self.assertRaises(PermissionDenied):
         with self.assertRaises(PermissionDenied):
             a_view(request)
             a_view(request)
 
 
+    async def test_many_permissions_pass_async_view(self):
+        @permission_required(
+            ["auth_tests.add_customuser", "auth_tests.change_customuser"]
+        )
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.auser = self.auser
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+
+    async def test_many_permissions_in_set_pass_async_view(self):
+        @permission_required(
+            {"auth_tests.add_customuser", "auth_tests.change_customuser"}
+        )
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.auser = self.auser
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+
+    async def test_single_permission_pass_async_view(self):
+        @permission_required("auth_tests.add_customuser")
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.auser = self.auser
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+
+    async def test_permissioned_denied_redirect_async_view(self):
+        @permission_required(
+            [
+                "auth_tests.add_customuser",
+                "auth_tests.change_customuser",
+                "nonexistent-permission",
+            ]
+        )
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.auser = self.auser
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 302)
+
+    async def test_permissioned_denied_exception_raised_async_view(self):
+        @permission_required(
+            [
+                "auth_tests.add_customuser",
+                "auth_tests.change_customuser",
+                "nonexistent-permission",
+            ],
+            raise_exception=True,
+        )
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.auser = self.auser
+        with self.assertRaises(PermissionDenied):
+            await async_view(request)
+
 
 
 class UserPassesTestDecoratorTest(TestCase):
 class UserPassesTestDecoratorTest(TestCase):
     factory = RequestFactory()
     factory = RequestFactory()
@@ -162,6 +296,28 @@ class UserPassesTestDecoratorTest(TestCase):
         )
         )
         cls.user_pass.user_permissions.add(*perms)
         cls.user_pass.user_permissions.add(*perms)
 
 
+    @classmethod
+    async def auser_pass(cls):
+        return cls.user_pass
+
+    @classmethod
+    async def auser_deny(cls):
+        return cls.user_deny
+
+    def test_wrapped_sync_function_is_not_coroutine_function(self):
+        def sync_view(request):
+            return HttpResponse()
+
+        wrapped_view = user_passes_test(lambda user: True)(sync_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), False)
+
+    def test_wrapped_async_function_is_coroutine_function(self):
+        async def async_view(request):
+            return HttpResponse()
+
+        wrapped_view = user_passes_test(lambda user: True)(async_view)
+        self.assertIs(iscoroutinefunction(wrapped_view), True)
+
     def test_decorator(self):
     def test_decorator(self):
         def sync_test_func(user):
         def sync_test_func(user):
             return bool(
             return bool(
@@ -180,3 +336,56 @@ class UserPassesTestDecoratorTest(TestCase):
         request.user = self.user_deny
         request.user = self.user_deny
         response = sync_view(request)
         response = sync_view(request)
         self.assertEqual(response.status_code, 302)
         self.assertEqual(response.status_code, 302)
+
+    def test_decorator_async_test_func(self):
+        async def async_test_func(user):
+            return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"])
+
+        @user_passes_test(async_test_func)
+        def sync_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.user = self.user_pass
+        response = sync_view(request)
+        self.assertEqual(response.status_code, 200)
+
+        request.user = self.user_deny
+        response = sync_view(request)
+        self.assertEqual(response.status_code, 302)
+
+    async def test_decorator_async_view(self):
+        def sync_test_func(user):
+            return bool(
+                models.Group.objects.filter(name__istartswith=user.username).exists()
+            )
+
+        @user_passes_test(sync_test_func)
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.auser = self.auser_pass
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+
+        request.auser = self.auser_deny
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 302)
+
+    async def test_decorator_async_view_async_test_func(self):
+        async def async_test_func(user):
+            return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"])
+
+        @user_passes_test(async_test_func)
+        async def async_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.auser = self.auser_pass
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 200)
+
+        request.auser = self.auser_deny
+        response = await async_view(request)
+        self.assertEqual(response.status_code, 302)