瀏覽代碼

Fixed #33738 -- Allowed handling ASGI http.disconnect in long-lived requests.

th3nn3ss 2 年之前
父節點
當前提交
1d1ddffc27
共有 5 個文件被更改,包括 157 次插入3 次删除
  1. 38 3
      django/core/handlers/asgi.py
  2. 7 0
      docs/releases/5.0.txt
  3. 20 0
      docs/topics/async.txt
  4. 84 0
      tests/asgi/tests.py
  5. 8 0
      tests/asgi/urls.py

+ 38 - 3
django/core/handlers/asgi.py

@@ -1,3 +1,4 @@
+import asyncio
 import logging
 import sys
 import tempfile
@@ -177,15 +178,49 @@ class ASGIHandler(base.BaseHandler):
             body_file.close()
             await self.send_response(error_response, send)
             return
-        # Get the response, using the async mode of BaseHandler.
+        # Try to catch a disconnect while getting response.
+        tasks = [
+            asyncio.create_task(self.run_get_response(request)),
+            asyncio.create_task(self.listen_for_disconnect(receive)),
+        ]
+        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+        done, pending = done.pop(), pending.pop()
+        # Allow views to handle cancellation.
+        pending.cancel()
+        try:
+            await pending
+        except asyncio.CancelledError:
+            # Task re-raised the CancelledError as expected.
+            pass
+        try:
+            response = done.result()
+        except RequestAborted:
+            body_file.close()
+            return
+        except AssertionError:
+            body_file.close()
+            raise
+        # Send the response.
+        await self.send_response(response, send)
+
+    async def listen_for_disconnect(self, receive):
+        """Listen for disconnect from the client."""
+        message = await receive()
+        if message["type"] == "http.disconnect":
+            raise RequestAborted()
+        # This should never happen.
+        assert False, "Invalid ASGI message after request body: %s" % message["type"]
+
+    async def run_get_response(self, request):
+        """Get async response."""
+        # Use the async mode of BaseHandler.
         response = await self.get_response_async(request)
         response._handler_class = self.__class__
         # Increase chunk size on file responses (ASGI servers handles low-level
         # chunking).
         if isinstance(response, FileResponse):
             response.block_size = self.chunk_size
-        # Send the response.
-        await self.send_response(response, send)
+        return response
 
     async def read_body(self, receive):
         """Reads an HTTP body from an ASGI connection."""

+ 7 - 0
docs/releases/5.0.txt

@@ -192,6 +192,13 @@ Minor features
 
 * ...
 
+Asynchronous views
+~~~~~~~~~~~~~~~~~~
+
+* Under ASGI, ``http.disconnect`` events are now handled. This allows views to
+  perform any necessary cleanup if a client disconnects before the response is
+  generated. See :ref:`async-handling-disconnect` for more details.
+
 Cache
 ~~~~~
 

+ 20 - 0
docs/topics/async.txt

@@ -136,6 +136,26 @@ a purely synchronous codebase under ASGI because the request-handling code is
 still all running asynchronously. In general you will only want to enable ASGI
 mode if you have asynchronous code in your project.
 
+.. _async-handling-disconnect:
+
+Handling disconnects
+--------------------
+
+.. versionadded:: 5.0
+
+For long-lived requests, a client may disconnect before the view returns a
+response. In this case, an ``asyncio.CancelledError`` will be raised in the
+view. You can catch this error and handle it if you need to perform any
+cleanup::
+
+    async def my_view(request):
+        try:
+            # Do some work
+            ...
+        except asyncio.CancelledError:
+            # Handle disconnect
+            raise
+
 .. _async-safety:
 
 Async safety

+ 84 - 0
tests/asgi/tests.py

@@ -7,8 +7,10 @@ from asgiref.testing import ApplicationCommunicator
 
 from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
 from django.core.asgi import get_asgi_application
+from django.core.handlers.asgi import ASGIHandler, ASGIRequest
 from django.core.signals import request_finished, request_started
 from django.db import close_old_connections
+from django.http import HttpResponse
 from django.test import (
     AsyncRequestFactory,
     SimpleTestCase,
@@ -16,6 +18,7 @@ from django.test import (
     modify_settings,
     override_settings,
 )
+from django.urls import path
 from django.utils.http import http_date
 
 from .urls import sync_waiter, test_filename
@@ -234,6 +237,34 @@ class ASGITest(SimpleTestCase):
         with self.assertRaises(asyncio.TimeoutError):
             await communicator.receive_output()
 
+    async def test_disconnect_with_body(self):
+        application = get_asgi_application()
+        scope = self.async_request_factory._base_scope(path="/")
+        communicator = ApplicationCommunicator(application, scope)
+        await communicator.send_input({"type": "http.request", "body": b"some body"})
+        await communicator.send_input({"type": "http.disconnect"})
+        with self.assertRaises(asyncio.TimeoutError):
+            await communicator.receive_output()
+
+    async def test_assert_in_listen_for_disconnect(self):
+        application = get_asgi_application()
+        scope = self.async_request_factory._base_scope(path="/")
+        communicator = ApplicationCommunicator(application, scope)
+        await communicator.send_input({"type": "http.request"})
+        await communicator.send_input({"type": "http.not_a_real_message"})
+        msg = "Invalid ASGI message after request body: http.not_a_real_message"
+        with self.assertRaisesMessage(AssertionError, msg):
+            await communicator.receive_output()
+
+    async def test_delayed_disconnect_with_body(self):
+        application = get_asgi_application()
+        scope = self.async_request_factory._base_scope(path="/delayed_hello/")
+        communicator = ApplicationCommunicator(application, scope)
+        await communicator.send_input({"type": "http.request", "body": b"some body"})
+        await communicator.send_input({"type": "http.disconnect"})
+        with self.assertRaises(asyncio.TimeoutError):
+            await communicator.receive_output()
+
     async def test_wrong_connection_type(self):
         application = get_asgi_application()
         scope = self.async_request_factory._base_scope(path="/", type="other")
@@ -318,3 +349,56 @@ class ASGITest(SimpleTestCase):
         self.assertEqual(len(sync_waiter.active_threads), 2)
 
         sync_waiter.active_threads.clear()
+
+    async def test_asyncio_cancel_error(self):
+        # Flag to check if the view was cancelled.
+        view_did_cancel = False
+
+        # A view that will listen for the cancelled error.
+        async def view(request):
+            nonlocal view_did_cancel
+            try:
+                await asyncio.sleep(0.2)
+                return HttpResponse("Hello World!")
+            except asyncio.CancelledError:
+                # Set the flag.
+                view_did_cancel = True
+                raise
+
+        # Request class to use the view.
+        class TestASGIRequest(ASGIRequest):
+            urlconf = (path("cancel/", view),)
+
+        # Handler to use request class.
+        class TestASGIHandler(ASGIHandler):
+            request_class = TestASGIRequest
+
+        # Request cycle should complete since no disconnect was sent.
+        application = TestASGIHandler()
+        scope = self.async_request_factory._base_scope(path="/cancel/")
+        communicator = ApplicationCommunicator(application, scope)
+        await communicator.send_input({"type": "http.request"})
+        response_start = await communicator.receive_output()
+        self.assertEqual(response_start["type"], "http.response.start")
+        self.assertEqual(response_start["status"], 200)
+        response_body = await communicator.receive_output()
+        self.assertEqual(response_body["type"], "http.response.body")
+        self.assertEqual(response_body["body"], b"Hello World!")
+        # Give response.close() time to finish.
+        await communicator.wait()
+        self.assertIs(view_did_cancel, False)
+
+        # Request cycle with a disconnect before the view can respond.
+        application = TestASGIHandler()
+        scope = self.async_request_factory._base_scope(path="/cancel/")
+        communicator = ApplicationCommunicator(application, scope)
+        await communicator.send_input({"type": "http.request"})
+        # Let the view actually start.
+        await asyncio.sleep(0.1)
+        # Disconnect the client.
+        await communicator.send_input({"type": "http.disconnect"})
+        # The handler should not send a response.
+        with self.assertRaises(asyncio.TimeoutError):
+            await communicator.receive_output()
+        await communicator.wait()
+        self.assertIs(view_did_cancel, True)

+ 8 - 0
tests/asgi/urls.py

@@ -1,4 +1,5 @@
 import threading
+import time
 
 from django.http import FileResponse, HttpResponse
 from django.urls import path
@@ -10,6 +11,12 @@ def hello(request):
     return HttpResponse("Hello %s!" % name)
 
 
+def hello_with_delay(request):
+    name = request.GET.get("name") or "World"
+    time.sleep(1)
+    return HttpResponse(f"Hello {name}!")
+
+
 def hello_meta(request):
     return HttpResponse(
         "From %s" % request.META.get("HTTP_REFERER") or "",
@@ -46,4 +53,5 @@ urlpatterns = [
     path("meta/", hello_meta),
     path("post/", post_echo),
     path("wait/", sync_waiter),
+    path("delayed_hello/", hello_with_delay),
 ]