瀏覽代碼

Refs #35030 -- Added more tests for @user_passes_test decorator.

Mariusz Felisiak 1 年之前
父節點
當前提交
c4df2a7776
共有 1 個文件被更改,包括 39 次插入1 次删除
  1. 39 1
      tests/auth_tests/test_decorators.py

+ 39 - 1
tests/auth_tests/test_decorators.py

@@ -1,6 +1,10 @@
 from django.conf import settings
 from django.contrib.auth import models
-from django.contrib.auth.decorators import login_required, permission_required
+from django.contrib.auth.decorators import (
+    login_required,
+    permission_required,
+    user_passes_test,
+)
 from django.core.exceptions import PermissionDenied
 from django.http import HttpResponse
 from django.test import TestCase, override_settings
@@ -142,3 +146,37 @@ class PermissionsRequiredDecoratorTest(TestCase):
         request.user = self.user
         with self.assertRaises(PermissionDenied):
             a_view(request)
+
+
+class UserPassesTestDecoratorTest(TestCase):
+    factory = RequestFactory()
+
+    @classmethod
+    def setUpTestData(cls):
+        cls.user_pass = models.User.objects.create(username="joe", password="qwerty")
+        cls.user_deny = models.User.objects.create(username="jim", password="qwerty")
+        models.Group.objects.create(name="Joe group")
+        # Add permissions auth.add_customuser and auth.change_customuser
+        perms = models.Permission.objects.filter(
+            codename__in=("add_customuser", "change_customuser")
+        )
+        cls.user_pass.user_permissions.add(*perms)
+
+    def test_decorator(self):
+        def sync_test_func(user):
+            return bool(
+                models.Group.objects.filter(name__istartswith=user.username).exists()
+            )
+
+        @user_passes_test(sync_test_func)
+        def sync_view(request):
+            return HttpResponse()
+
+        request = self.factory.get("/rand")
+        request.user = self.user_pass
+        response = sync_view(request)
+        self.assertEqual(response.status_code, 200)
+
+        request.user = self.user_deny
+        response = sync_view(request)
+        self.assertEqual(response.status_code, 302)