فهرست منبع

Fixed #34752 -- Fixed handling ASGI http.disconnect for streaming responses.

Sam Toyer 1 سال پیش
والد
کامیت
64cea1e48f
5فایلهای تغییر یافته به همراه197 افزوده شده و 23 حذف شده
  1. 31 20
      django/core/handlers/asgi.py
  2. 30 0
      docs/ref/request-response.txt
  3. 3 0
      docs/topics/async.txt
  4. 119 2
      tests/asgi/tests.py
  5. 14 1
      tests/asgi/urls.py

+ 31 - 20
django/core/handlers/asgi.py

@@ -187,30 +187,41 @@ class ASGIHandler(base.BaseHandler):
             body_file.close()
             await self.send_response(error_response, send)
             return
+
+        async def process_request(request, send):
+            response = await self.run_get_response(request)
+            await self.send_response(response, send)
+
         # Try to catch a disconnect while getting response.
         tasks = [
-            asyncio.create_task(self.run_get_response(request)),
+            # Check the status of these tasks and (optionally) terminate them
+            # in this order. The listen_for_disconnect() task goes first
+            # because it should not raise unexpected errors that would prevent
+            # us from cancelling process_request().
             asyncio.create_task(self.listen_for_disconnect(receive)),
+            asyncio.create_task(process_request(request, send)),
         ]
-        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-        done, pending = done.pop(), pending.pop()
-        # Allow views to handle cancellation.
-        pending.cancel()
-        try:
-            await pending
-        except asyncio.CancelledError:
-            # Task re-raised the CancelledError as expected.
-            pass
-        try:
-            response = done.result()
-        except RequestAborted:
-            body_file.close()
-            return
-        except AssertionError:
-            body_file.close()
-            raise
-        # Send the response.
-        await self.send_response(response, send)
+        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+        # Now wait on both tasks (they may have both finished by now).
+        for task in tasks:
+            if task.done():
+                try:
+                    task.result()
+                except RequestAborted:
+                    # Ignore client disconnects.
+                    pass
+                except AssertionError:
+                    body_file.close()
+                    raise
+            else:
+                # Allow views to handle cancellation.
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    # Task re-raised the CancelledError as expected.
+                    pass
+        body_file.close()
 
     async def listen_for_disconnect(self, receive):
         """Listen for disconnect from the client."""

+ 30 - 0
docs/ref/request-response.txt

@@ -1282,6 +1282,36 @@ Attributes
     This is useful for middleware needing to wrap
     :attr:`StreamingHttpResponse.streaming_content`.
 
+.. _request-response-streaming-disconnect:
+
+Handling disconnects
+--------------------
+
+.. versionadded:: 5.0
+
+If the client disconnects during a streaming response, Django will cancel the
+coroutine that is handling the response. If you want to clean up resources
+manually, you can do so by catching the ``asyncio.CancelledError``::
+
+    async def streaming_response():
+        try:
+            # Do some work here
+            async for chunk in my_streaming_iterator():
+                yield chunk
+        except asyncio.CancelledError:
+            # Handle disconnect
+            ...
+            raise
+
+
+    async def my_streaming_view(request):
+        return StreamingHttpResponse(streaming_response())
+
+This example only shows how to handle client disconnection while the response
+is streaming. If you perform long-running operations in your view before
+returning the ``StreamingHttpResponse`` object, then you may also want to
+:ref:`handle disconnections in the view <async-handling-disconnect>` itself.
+
 ``FileResponse`` objects
 ========================
 

+ 3 - 0
docs/topics/async.txt

@@ -197,6 +197,9 @@ cleanup::
             # Handle disconnect
             raise
 
+You can also :ref:`handle client disconnects in streaming responses
+<request-response-streaming-disconnect>`.
+
 .. _async-safety:
 
 Async safety

+ 119 - 2
tests/asgi/tests.py

@@ -10,7 +10,7 @@ from django.core.asgi import get_asgi_application
 from django.core.handlers.asgi import ASGIHandler, ASGIRequest
 from django.core.signals import request_finished, request_started
 from django.db import close_old_connections
-from django.http import HttpResponse
+from django.http import HttpResponse, StreamingHttpResponse
 from django.test import (
     AsyncRequestFactory,
     SimpleTestCase,
@@ -237,6 +237,31 @@ class ASGITest(SimpleTestCase):
         with self.assertRaises(asyncio.TimeoutError):
             await communicator.receive_output()
 
+    async def test_disconnect_both_return(self):
+        # Force both the disconnect listener and the task that sends the
+        # response to finish at the same time.
+        application = get_asgi_application()
+        scope = self.async_request_factory._base_scope(path="/")
+        communicator = ApplicationCommunicator(application, scope)
+        await communicator.send_input({"type": "http.request", "body": b"some body"})
+        # Fetch response headers (this yields to asyncio and causes
+        # ASGHandler.send_response() to dump the body of the response in the
+        # queue).
+        await communicator.receive_output()
+        # Fetch response body (there's already some data queued up, so this
+        # doesn't actually yield to the event loop, it just succeeds
+        # instantly).
+        await communicator.receive_output()
+        # Send disconnect at the same time that response finishes (this just
+        # puts some info in a queue, it doesn't have to yield to the event
+        # loop).
+        await communicator.send_input({"type": "http.disconnect"})
+        # Waiting for the communicator _does_ yield to the event loop, since
+        # ASGIHandler.send_response() is still waiting to do response.close().
+        # It so happens that there are enough remaining yield points in both
+        # tasks that they both finish while the loop is running.
+        await communicator.wait()
+
     async def test_disconnect_with_body(self):
         application = get_asgi_application()
         scope = self.async_request_factory._base_scope(path="/")
@@ -254,7 +279,7 @@ class ASGITest(SimpleTestCase):
         await communicator.send_input({"type": "http.not_a_real_message"})
         msg = "Invalid ASGI message after request body: http.not_a_real_message"
         with self.assertRaisesMessage(AssertionError, msg):
-            await communicator.receive_output()
+            await communicator.wait()
 
     async def test_delayed_disconnect_with_body(self):
         application = get_asgi_application()
@@ -402,3 +427,95 @@ class ASGITest(SimpleTestCase):
             await communicator.receive_output()
         await communicator.wait()
         self.assertIs(view_did_cancel, True)
+
+    async def test_asyncio_streaming_cancel_error(self):
+        # Similar to test_asyncio_cancel_error(), but during a streaming
+        # response.
+        view_did_cancel = False
+
+        async def streaming_response():
+            nonlocal view_did_cancel
+            try:
+                await asyncio.sleep(0.2)
+                yield b"Hello World!"
+            except asyncio.CancelledError:
+                # Set the flag.
+                view_did_cancel = True
+                raise
+
+        async def view(request):
+            return StreamingHttpResponse(streaming_response())
+
+        class TestASGIRequest(ASGIRequest):
+            urlconf = (path("cancel/", view),)
+
+        class TestASGIHandler(ASGIHandler):
+            request_class = TestASGIRequest
+
+        # With no disconnect, the request cycle should complete in the same
+        # manner as the non-streaming response.
+        application = TestASGIHandler()
+        scope = self.async_request_factory._base_scope(path="/cancel/")
+        communicator = ApplicationCommunicator(application, scope)
+        await communicator.send_input({"type": "http.request"})
+        response_start = await communicator.receive_output()
+        self.assertEqual(response_start["type"], "http.response.start")
+        self.assertEqual(response_start["status"], 200)
+        response_body = await communicator.receive_output()
+        self.assertEqual(response_body["type"], "http.response.body")
+        self.assertEqual(response_body["body"], b"Hello World!")
+        await communicator.wait()
+        self.assertIs(view_did_cancel, False)
+
+        # Request cycle with a disconnect.
+        application = TestASGIHandler()
+        scope = self.async_request_factory._base_scope(path="/cancel/")
+        communicator = ApplicationCommunicator(application, scope)
+        await communicator.send_input({"type": "http.request"})
+        response_start = await communicator.receive_output()
+        # Fetch the start of response so streaming can begin
+        self.assertEqual(response_start["type"], "http.response.start")
+        self.assertEqual(response_start["status"], 200)
+        await asyncio.sleep(0.1)
+        # Now disconnect the client.
+        await communicator.send_input({"type": "http.disconnect"})
+        # This time the handler should not send a response.
+        with self.assertRaises(asyncio.TimeoutError):
+            await communicator.receive_output()
+        await communicator.wait()
+        self.assertIs(view_did_cancel, True)
+
+    async def test_streaming(self):
+        scope = self.async_request_factory._base_scope(
+            path="/streaming/", query_string=b"sleep=0.001"
+        )
+        application = get_asgi_application()
+        communicator = ApplicationCommunicator(application, scope)
+        await communicator.send_input({"type": "http.request"})
+        # Fetch http.response.start.
+        await communicator.receive_output(timeout=1)
+        # Fetch the 'first' and 'last'.
+        first_response = await communicator.receive_output(timeout=1)
+        self.assertEqual(first_response["body"], b"first\n")
+        second_response = await communicator.receive_output(timeout=1)
+        self.assertEqual(second_response["body"], b"last\n")
+        # Fetch the rest of the response so that coroutines are cleaned up.
+        await communicator.receive_output(timeout=1)
+        with self.assertRaises(asyncio.TimeoutError):
+            await communicator.receive_output(timeout=1)
+
+    async def test_streaming_disconnect(self):
+        scope = self.async_request_factory._base_scope(
+            path="/streaming/", query_string=b"sleep=0.1"
+        )
+        application = get_asgi_application()
+        communicator = ApplicationCommunicator(application, scope)
+        await communicator.send_input({"type": "http.request"})
+        await communicator.receive_output(timeout=1)
+        first_response = await communicator.receive_output(timeout=1)
+        self.assertEqual(first_response["body"], b"first\n")
+        # Disconnect the client.
+        await communicator.send_input({"type": "http.disconnect"})
+        # 'last\n' isn't sent.
+        with self.assertRaises(asyncio.TimeoutError):
+            await communicator.receive_output(timeout=0.2)

+ 14 - 1
tests/asgi/urls.py

@@ -1,7 +1,8 @@
+import asyncio
 import threading
 import time
 
-from django.http import FileResponse, HttpResponse
+from django.http import FileResponse, HttpResponse, StreamingHttpResponse
 from django.urls import path
 from django.views.decorators.csrf import csrf_exempt
 
@@ -44,6 +45,17 @@ sync_waiter.lock = threading.Lock()
 sync_waiter.barrier = threading.Barrier(2)
 
 
+async def streaming_inner(sleep_time):
+    yield b"first\n"
+    await asyncio.sleep(sleep_time)
+    yield b"last\n"
+
+
+async def streaming_view(request):
+    sleep_time = float(request.GET["sleep"])
+    return StreamingHttpResponse(streaming_inner(sleep_time))
+
+
 test_filename = __file__
 
 
@@ -54,4 +66,5 @@ urlpatterns = [
     path("post/", post_echo),
     path("wait/", sync_waiter),
     path("delayed_hello/", hello_with_delay),
+    path("streaming/", streaming_view),
 ]