test_decorators.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. from asyncio import iscoroutinefunction
  2. from asgiref.sync import sync_to_async
  3. from django.conf import settings
  4. from django.contrib.auth import models
  5. from django.contrib.auth.decorators import (
  6. login_required,
  7. permission_required,
  8. user_passes_test,
  9. )
  10. from django.core.exceptions import PermissionDenied
  11. from django.http import HttpResponse
  12. from django.test import TestCase, override_settings
  13. from django.test.client import RequestFactory
  14. from .test_views import AuthViewsTestCase
  @override_settings(ROOT_URLCONF="auth_tests.urls")
  class LoginRequiredTestCase(AuthViewsTestCase):
      """
      Exercise login_required on sync views, async views, and callable objects.
      """

      factory = RequestFactory()

      def test_wrapped_sync_function_is_not_coroutine_function(self):
          def sync_view(request):
              return HttpResponse()

          self.assertIs(iscoroutinefunction(login_required(sync_view)), False)

      def test_wrapped_async_function_is_coroutine_function(self):
          async def async_view(request):
              return HttpResponse()

          self.assertIs(iscoroutinefunction(login_required(async_view)), True)

      def test_callable(self):
          """
          login_required accepts an instance of a class that defines __call__.
          """

          class CallableView:
              def __call__(self, *args, **kwargs):
                  pass

          view_instance = CallableView()
          login_required(view_instance)

      def test_view(self):
          """
          login_required accepts a plain function view.
          """

          def plain_view(request):
              pass

          login_required(plain_view)

      def test_login_required(self, view_url="/login_required/", login_url=None):
          """
          An anonymous GET of view_url redirects (302) to a URL containing
          login_url (settings.LOGIN_URL when omitted); once logged in, the same
          URL answers 200. Also reused by test_login_required_next_url.
          """
          expected_url = settings.LOGIN_URL if login_url is None else login_url

          anonymous_response = self.client.get(view_url)
          self.assertEqual(anonymous_response.status_code, 302)
          self.assertIn(expected_url, anonymous_response.url)

          self.login()
          authenticated_response = self.client.get(view_url)
          self.assertEqual(authenticated_response.status_code, 200)

      def test_login_required_next_url(self):
          """
          Same checks as test_login_required, against a view whose decorator
          was given a custom login_url.
          """
          self.test_login_required(
              view_url="/login_required_login_url/",
              login_url="/somewhere/",
          )

      async def test_login_required_async_view(self, login_url=None):
          """
          Async counterpart of test_login_required: the decorated coroutine view
          is called directly with a RequestFactory request whose auser callable
          first yields an AnonymousUser (expect 302) and then self.u1
          (expect 200).
          """

          async def view(request):
              return HttpResponse()

          async def auser_anonymous():
              return models.AnonymousUser()

          async def auser_authenticated():
              return self.u1

          # Decorate first, then resolve the URL the redirect should contain.
          if login_url is None:
              protected_view = login_required(view)
              expected_url = settings.LOGIN_URL
          else:
              protected_view = login_required(view, login_url=login_url)
              expected_url = login_url

          request = self.factory.get("/rand")

          request.auser = auser_anonymous
          anonymous_response = await protected_view(request)
          self.assertEqual(anonymous_response.status_code, 302)
          self.assertIn(expected_url, anonymous_response.url)

          request.auser = auser_authenticated
          authenticated_response = await protected_view(request)
          self.assertEqual(authenticated_response.status_code, 200)

      async def test_login_required_next_url_async_view(self):
          """Async counterpart of test_login_required_next_url."""
          await self.test_login_required_async_view(login_url="/somewhere/")
  class PermissionsRequiredDecoratorTest(TestCase):
      """
      Tests for the permission_required decorator on sync and async views.
      """

      factory = RequestFactory()

      @classmethod
      def setUpTestData(cls):
          cls.user = models.User.objects.create(username="joe", password="qwerty")
          # Grant the two permissions the views below require, spelled
          # "auth_tests.add_customuser" and "auth_tests.change_customuser" in
          # the decorators. The lookup here matches on codename only.
          perms = models.Permission.objects.filter(
              codename__in=("add_customuser", "change_customuser")
          )
          cls.user.user_permissions.add(*perms)

      @classmethod
      async def auser(cls):
          # Assigned as request.auser in the async tests; presumably awaited by
          # the decorator's async wrapper to obtain the user.
          return cls.user

      def test_wrapped_sync_function_is_not_coroutine_function(self):
          def sync_view(request):
              return HttpResponse()

          wrapped_view = permission_required([])(sync_view)
          self.assertIs(iscoroutinefunction(wrapped_view), False)

      def test_wrapped_async_function_is_coroutine_function(self):
          async def async_view(request):
              return HttpResponse()

          wrapped_view = permission_required([])(async_view)
          self.assertIs(iscoroutinefunction(wrapped_view), True)

      def test_many_permissions_pass(self):
          @permission_required(
              ["auth_tests.add_customuser", "auth_tests.change_customuser"]
          )
          def a_view(request):
              return HttpResponse()

          request = self.factory.get("/rand")
          request.user = self.user
          resp = a_view(request)
          self.assertEqual(resp.status_code, 200)

      def test_many_permissions_in_set_pass(self):
          # Same as above, but the permissions are passed as a set.
          @permission_required(
              {"auth_tests.add_customuser", "auth_tests.change_customuser"}
          )
          def a_view(request):
              return HttpResponse()

          request = self.factory.get("/rand")
          request.user = self.user
          resp = a_view(request)
          self.assertEqual(resp.status_code, 200)

      def test_single_permission_pass(self):
          @permission_required("auth_tests.add_customuser")
          def a_view(request):
              return HttpResponse()

          request = self.factory.get("/rand")
          request.user = self.user
          resp = a_view(request)
          self.assertEqual(resp.status_code, 200)

      def test_permissioned_denied_redirect(self):
          # The user holds the first two permissions but not
          # "nonexistent-permission", so access is denied with a redirect.
          @permission_required(
              [
                  "auth_tests.add_customuser",
                  "auth_tests.change_customuser",
                  "nonexistent-permission",
              ]
          )
          def a_view(request):
              return HttpResponse()

          request = self.factory.get("/rand")
          request.user = self.user
          resp = a_view(request)
          self.assertEqual(resp.status_code, 302)

      def test_permissioned_denied_exception_raised(self):
          # Same missing permission, but raise_exception=True turns the denial
          # into PermissionDenied instead of a redirect.
          @permission_required(
              [
                  "auth_tests.add_customuser",
                  "auth_tests.change_customuser",
                  "nonexistent-permission",
              ],
              raise_exception=True,
          )
          def a_view(request):
              return HttpResponse()

          request = self.factory.get("/rand")
          request.user = self.user
          with self.assertRaises(PermissionDenied):
              a_view(request)

      async def test_many_permissions_pass_async_view(self):
          @permission_required(
              ["auth_tests.add_customuser", "auth_tests.change_customuser"]
          )
          async def async_view(request):
              return HttpResponse()

          request = self.factory.get("/rand")
          request.auser = self.auser
          response = await async_view(request)
          self.assertEqual(response.status_code, 200)

      async def test_many_permissions_in_set_pass_async_view(self):
          @permission_required(
              {"auth_tests.add_customuser", "auth_tests.change_customuser"}
          )
          async def async_view(request):
              return HttpResponse()

          request = self.factory.get("/rand")
          request.auser = self.auser
          response = await async_view(request)
          self.assertEqual(response.status_code, 200)

      async def test_single_permission_pass_async_view(self):
          @permission_required("auth_tests.add_customuser")
          async def async_view(request):
              return HttpResponse()

          request = self.factory.get("/rand")
          request.auser = self.auser
          response = await async_view(request)
          self.assertEqual(response.status_code, 200)

      async def test_permissioned_denied_redirect_async_view(self):
          # Async counterpart of test_permissioned_denied_redirect.
          @permission_required(
              [
                  "auth_tests.add_customuser",
                  "auth_tests.change_customuser",
                  "nonexistent-permission",
              ]
          )
          async def async_view(request):
              return HttpResponse()

          request = self.factory.get("/rand")
          request.auser = self.auser
          response = await async_view(request)
          self.assertEqual(response.status_code, 302)

      async def test_permissioned_denied_exception_raised_async_view(self):
          # Async counterpart of test_permissioned_denied_exception_raised.
          @permission_required(
              [
                  "auth_tests.add_customuser",
                  "auth_tests.change_customuser",
                  "nonexistent-permission",
              ],
              raise_exception=True,
          )
          async def async_view(request):
              return HttpResponse()

          request = self.factory.get("/rand")
          request.auser = self.auser
          with self.assertRaises(PermissionDenied):
              await async_view(request)
  class UserPassesTestDecoratorTest(TestCase):
      """
      Tests for the user_passes_test decorator, covering every combination of
      sync/async view and sync/async test function.
      """

      factory = RequestFactory()

      @classmethod
      def setUpTestData(cls):
          cls.user_pass = models.User.objects.create(username="joe", password="qwerty")
          cls.user_deny = models.User.objects.create(username="jim", password="qwerty")
          # "Joe group" starts case-insensitively with "joe" but not with "jim";
          # the group-based test functions below rely on that difference.
          models.Group.objects.create(name="Joe group")
          # Grant user_pass (only) the permissions checked below as
          # "auth_tests.add_customuser" / "auth_tests.change_customuser".
          # The lookup here matches on codename only.
          perms = models.Permission.objects.filter(
              codename__in=("add_customuser", "change_customuser")
          )
          cls.user_pass.user_permissions.add(*perms)

      @classmethod
      async def auser_pass(cls):
          # Assigned as request.auser in the async-view tests.
          return cls.user_pass

      @classmethod
      async def auser_deny(cls):
          # Assigned as request.auser in the async-view tests.
          return cls.user_deny

      def test_wrapped_sync_function_is_not_coroutine_function(self):
          def sync_view(request):
              return HttpResponse()

          wrapped_view = user_passes_test(lambda user: True)(sync_view)
          self.assertIs(iscoroutinefunction(wrapped_view), False)

      def test_wrapped_async_function_is_coroutine_function(self):
          async def async_view(request):
              return HttpResponse()

          wrapped_view = user_passes_test(lambda user: True)(async_view)
          self.assertIs(iscoroutinefunction(wrapped_view), True)

      def test_decorator(self):
          def sync_test_func(user):
              # exists() already returns a bool; no bool() wrapper needed.
              return models.Group.objects.filter(
                  name__istartswith=user.username
              ).exists()

          @user_passes_test(sync_test_func)
          def sync_view(request):
              return HttpResponse()

          request = self.factory.get("/rand")
          request.user = self.user_pass
          response = sync_view(request)
          self.assertEqual(response.status_code, 200)
          request.user = self.user_deny
          response = sync_view(request)
          self.assertEqual(response.status_code, 302)

      def test_decorator_async_test_func(self):
          # A coroutine test function combined with a sync view.
          async def async_test_func(user):
              return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"])

          @user_passes_test(async_test_func)
          def sync_view(request):
              return HttpResponse()

          request = self.factory.get("/rand")
          request.user = self.user_pass
          response = sync_view(request)
          self.assertEqual(response.status_code, 200)
          request.user = self.user_deny
          response = sync_view(request)
          self.assertEqual(response.status_code, 302)

      async def test_decorator_async_view(self):
          # A sync test function combined with an async view.
          def sync_test_func(user):
              return models.Group.objects.filter(
                  name__istartswith=user.username
              ).exists()

          @user_passes_test(sync_test_func)
          async def async_view(request):
              return HttpResponse()

          request = self.factory.get("/rand")
          request.auser = self.auser_pass
          response = await async_view(request)
          self.assertEqual(response.status_code, 200)
          request.auser = self.auser_deny
          response = await async_view(request)
          self.assertEqual(response.status_code, 302)

      async def test_decorator_async_view_async_test_func(self):
          async def async_test_func(user):
              return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"])

          @user_passes_test(async_test_func)
          async def async_view(request):
              return HttpResponse()

          request = self.factory.get("/rand")
          request.auser = self.auser_pass
          response = await async_view(request)
          self.assertEqual(response.status_code, 200)
          request.auser = self.auser_deny
          response = await async_view(request)
          self.assertEqual(response.status_code, 302)