Browse Source

Fixed #34757 -- Added support for following redirects to AsyncClient.

Olivier Tabone 1 year ago
parent
commit
3f8dbe267d
4 changed files with 252 additions and 12 deletions
  1. 231 7
      django/test/client.py
  2. 2 0
      docs/releases/5.0.txt
  3. 4 1
      docs/topics/testing/tools.txt
  4. 15 4
      tests/test_client/tests.py

+ 231 - 7
django/test/client.py

@@ -705,9 +705,6 @@ class AsyncRequestFactory(RequestFactory):
                 ]
             )
             s["_body_file"] = FakePayload(data)
-        follow = extra.pop("follow", None)
-        if follow is not None:
-            s["follow"] = follow
         if query_string := extra.pop("QUERY_STRING", None):
             s["query_string"] = query_string
         if headers:
@@ -1296,10 +1293,6 @@ class AsyncClient(ClientMixin, AsyncRequestFactory):
         query environment, which can be overridden using the arguments to the
         request.
         """
-        if "follow" in request:
-            raise NotImplementedError(
-                "AsyncClient request methods do not accept the follow parameter."
-            )
         scope = self._base_scope(**request)
         # Curry a data dictionary into an instance of the template renderer
         # callback function.
@@ -1338,3 +1331,234 @@ class AsyncClient(ClientMixin, AsyncRequestFactory):
         if response.cookies:
             self.cookies.update(response.cookies)
         return response
+
+    async def get(
+        self,
+        path,
+        data=None,
+        follow=False,
+        secure=False,
+        *,
+        headers=None,
+        **extra,
+    ):
+        """Request a response from the server using GET."""
+        self.extra = extra
+        self.headers = headers
+        response = await super().get(
+            path, data=data, secure=secure, headers=headers, **extra
+        )
+        if follow:
+            response = await self._ahandle_redirects(
+                response, data=data, headers=headers, **extra
+            )
+        return response
+
+    async def post(
+        self,
+        path,
+        data=None,
+        content_type=MULTIPART_CONTENT,
+        follow=False,
+        secure=False,
+        *,
+        headers=None,
+        **extra,
+    ):
+        """Request a response from the server using POST."""
+        self.extra = extra
+        self.headers = headers
+        response = await super().post(
+            path,
+            data=data,
+            content_type=content_type,
+            secure=secure,
+            headers=headers,
+            **extra,
+        )
+        if follow:
+            response = await self._ahandle_redirects(
+                response, data=data, content_type=content_type, headers=headers, **extra
+            )
+        return response
+
+    async def head(
+        self,
+        path,
+        data=None,
+        follow=False,
+        secure=False,
+        *,
+        headers=None,
+        **extra,
+    ):
+        """Request a response from the server using HEAD."""
+        self.extra = extra
+        self.headers = headers
+        response = await super().head(
+            path, data=data, secure=secure, headers=headers, **extra
+        )
+        if follow:
+            response = await self._ahandle_redirects(
+                response, data=data, headers=headers, **extra
+            )
+        return response
+
+    async def options(
+        self,
+        path,
+        data="",
+        content_type="application/octet-stream",
+        follow=False,
+        secure=False,
+        *,
+        headers=None,
+        **extra,
+    ):
+        """Request a response from the server using OPTIONS."""
+        self.extra = extra
+        self.headers = headers
+        response = await super().options(
+            path,
+            data=data,
+            content_type=content_type,
+            secure=secure,
+            headers=headers,
+            **extra,
+        )
+        if follow:
+            response = await self._ahandle_redirects(
+                response, data=data, content_type=content_type, headers=headers, **extra
+            )
+        return response
+
+    async def put(
+        self,
+        path,
+        data="",
+        content_type="application/octet-stream",
+        follow=False,
+        secure=False,
+        *,
+        headers=None,
+        **extra,
+    ):
+        """Send a resource to the server using PUT."""
+        self.extra = extra
+        self.headers = headers
+        response = await super().put(
+            path,
+            data=data,
+            content_type=content_type,
+            secure=secure,
+            headers=headers,
+            **extra,
+        )
+        if follow:
+            response = await self._ahandle_redirects(
+                response, data=data, content_type=content_type, headers=headers, **extra
+            )
+        return response
+
+    async def patch(
+        self,
+        path,
+        data="",
+        content_type="application/octet-stream",
+        follow=False,
+        secure=False,
+        *,
+        headers=None,
+        **extra,
+    ):
+        """Send a resource to the server using PATCH."""
+        self.extra = extra
+        self.headers = headers
+        response = await super().patch(
+            path,
+            data=data,
+            content_type=content_type,
+            secure=secure,
+            headers=headers,
+            **extra,
+        )
+        if follow:
+            response = await self._ahandle_redirects(
+                response, data=data, content_type=content_type, headers=headers, **extra
+            )
+        return response
+
+    async def delete(
+        self,
+        path,
+        data="",
+        content_type="application/octet-stream",
+        follow=False,
+        secure=False,
+        *,
+        headers=None,
+        **extra,
+    ):
+        """Send a DELETE request to the server."""
+        self.extra = extra
+        self.headers = headers
+        response = await super().delete(
+            path,
+            data=data,
+            content_type=content_type,
+            secure=secure,
+            headers=headers,
+            **extra,
+        )
+        if follow:
+            response = await self._ahandle_redirects(
+                response, data=data, content_type=content_type, headers=headers, **extra
+            )
+        return response
+
+    async def trace(
+        self,
+        path,
+        data="",
+        follow=False,
+        secure=False,
+        *,
+        headers=None,
+        **extra,
+    ):
+        """Send a TRACE request to the server."""
+        self.extra = extra
+        self.headers = headers
+        response = await super().trace(
+            path, data=data, secure=secure, headers=headers, **extra
+        )
+        if follow:
+            response = await self._ahandle_redirects(
+                response, data=data, headers=headers, **extra
+            )
+        return response
+
+    async def _ahandle_redirects(
+        self,
+        response,
+        data="",
+        content_type="",
+        headers=None,
+        **extra,
+    ):
+        """
+        Follow any redirects by requesting responses from the server using GET.
+        """
+        response.redirect_chain = []
+        while response.status_code in REDIRECT_STATUS_CODES:
+            redirect_chain = response.redirect_chain
+            response = await self._follow_redirect(
+                response,
+                data=data,
+                content_type=content_type,
+                headers=headers,
+                **extra,
+            )
+            response.redirect_chain = redirect_chain
+            self._ensure_redirects_not_cyclic(response)
+        return response

+ 2 - 0
docs/releases/5.0.txt

@@ -433,6 +433,8 @@ Tests
   :meth:`~django.test.Client.aforce_login`, and
   :meth:`~django.test.Client.alogout`.
 
+* :class:`~django.test.AsyncClient` now supports the ``follow`` parameter.
+
 URLs
 ~~~~
 

+ 4 - 1
docs/topics/testing/tools.txt

@@ -2032,7 +2032,6 @@ test client, with the following exceptions:
 
 * In the initialization, arbitrary keyword arguments in ``defaults`` are added
   directly into the ASGI scope.
-* The ``follow`` parameter is not supported.
 * Headers passed as ``extra`` keyword arguments should not have the ``HTTP_``
   prefix required by the synchronous client (see :meth:`Client.get`). For
   example, here is how to set an HTTP ``Accept`` header:
@@ -2046,6 +2045,10 @@ test client, with the following exceptions:
 
     The ``headers`` parameter was added.
 
+.. versionchanged:: 5.0
+
+    Support for the ``follow`` parameter was added to the ``AsyncClient``.
+
 Using ``AsyncClient`` any method that makes a request must be awaited::
 
     async def test_my_thing(self):

+ 15 - 4
tests/test_client/tests.py

@@ -1135,8 +1135,11 @@ class AsyncClientTest(TestCase):
         response = await self.async_client.get("/middleware_urlconf_view/")
         self.assertEqual(response.resolver_match.url_name, "middleware_urlconf_view")
 
-    async def test_follow_parameter_not_implemented(self):
-        msg = "AsyncClient request methods do not accept the follow parameter."
+    async def test_redirect(self):
+        response = await self.async_client.get("/redirect_view/")
+        self.assertEqual(response.status_code, 302)
+
+    async def test_follow_redirect(self):
         tests = (
             "get",
             "post",
@@ -1150,8 +1153,16 @@ class AsyncClientTest(TestCase):
         for method_name in tests:
             with self.subTest(method=method_name):
                 method = getattr(self.async_client, method_name)
-                with self.assertRaisesMessage(NotImplementedError, msg):
-                    await method("/redirect_view/", follow=True)
+                response = await method("/redirect_view/", follow=True)
+                self.assertEqual(response.status_code, 200)
+                self.assertEqual(response.resolver_match.url_name, "get_view")
+
+    async def test_follow_double_redirect(self):
+        response = await self.async_client.get("/double_redirect_view/", follow=True)
+        self.assertRedirects(
+            response, "/get_view/", status_code=302, target_status_code=200
+        )
+        self.assertEqual(len(response.redirect_chain), 2)
 
     async def test_get_data(self):
         response = await self.async_client.get("/get_view/", {"var": "val"})