Browse Source

[5.0.x] Fixed #35059 -- Ensured that ASGIHandler always sends the request_finished signal.

Prior to this work, when async tasks that process the request are cancelled due
to receiving an early "http.disconnect" ASGI message, the request_finished
signal was not being sent, potentially leading to resource leaks (such as
database connections).

This branch ensures that the request_finished signal is sent even in the case
of early termination of the response.

Regression in 64cea1e48f285ea2162c669208d95188b32bbc82.

Co-authored-by: Natalia <124304+nessita@users.noreply.github.com>
Co-authored-by: Carlton Gibson <carlton.gibson@noumenal.es>

Backport of 11393ab1316f973c5fbb534305750740d909b4e4 from main
James Thorniley 1 năm trước cách đây
mục cha
commit
f1fbd061ac
3 tập tin đã thay đổi với 153 bổ sung3 xóa
  1. 16 2
      django/core/handlers/asgi.py
  2. 4 0
      docs/releases/5.0.2.txt
  3. 133 1
      tests/asgi/tests.py

+ 16 - 2
django/core/handlers/asgi.py

@@ -186,11 +186,18 @@ class ASGIHandler(base.BaseHandler):
         if request is None:
             body_file.close()
             await self.send_response(error_response, send)
+            await sync_to_async(error_response.close)()
             return
 
         async def process_request(request, send):
             response = await self.run_get_response(request)
-            await self.send_response(response, send)
+            try:
+                await self.send_response(response, send)
+            except asyncio.CancelledError:
+                # Client disconnected during send_response (ignore exception).
+                pass
+
+            return response
 
         # Try to catch a disconnect while getting response.
         tasks = [
@@ -221,6 +228,14 @@ class ASGIHandler(base.BaseHandler):
                 except asyncio.CancelledError:
                     # Task re-raised the CancelledError as expected.
                     pass
+
+        try:
+            response = tasks[1].result()
+        except asyncio.CancelledError:
+            await signals.request_finished.asend(sender=self.__class__)
+        else:
+            await sync_to_async(response.close)()
+
         body_file.close()
 
     async def listen_for_disconnect(self, receive):
@@ -346,7 +361,6 @@ class ASGIHandler(base.BaseHandler):
                         "more_body": not last,
                     }
                 )
-        await sync_to_async(response.close, thread_sensitive=True)()
 
     @classmethod
     def chunk_bytes(cls, data):

+ 4 - 0
docs/releases/5.0.2.txt

@@ -28,3 +28,7 @@ Bugfixes
 * Fixed a regression in Django 5.0 that caused a crash of the ``dumpdata``
   management command when a base queryset used ``prefetch_related()``
   (:ticket:`35159`).
+
+* Fixed a regression in Django 5.0 that caused the ``request_finished`` signal to
+  sometimes not be fired when running Django through an ASGI server, resulting
+  in potential resource leaks (:ticket:`35059`).

+ 133 - 1
tests/asgi/tests.py

@@ -1,12 +1,15 @@
 import asyncio
 import sys
 import threading
+import time
 from pathlib import Path
 
+from asgiref.sync import sync_to_async
 from asgiref.testing import ApplicationCommunicator
 
 from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
 from django.core.asgi import get_asgi_application
+from django.core.exceptions import RequestDataTooBig
 from django.core.handlers.asgi import ASGIHandler, ASGIRequest
 from django.core.signals import request_finished, request_started
 from django.db import close_old_connections
@@ -20,6 +23,7 @@ from django.test import (
 )
 from django.urls import path
 from django.utils.http import http_date
+from django.views.decorators.csrf import csrf_exempt
 
 from .urls import sync_waiter, test_filename
 
@@ -207,6 +211,96 @@ class ASGITest(SimpleTestCase):
         self.assertEqual(response_body["type"], "http.response.body")
         self.assertEqual(response_body["body"], b"Echo!")
 
+    async def test_create_request_error(self):
+        # Track request_finished signal.
+        signal_handler = SignalHandler()
+        request_finished.connect(signal_handler)
+        self.addCleanup(request_finished.disconnect, signal_handler)
+
+        # Request class that always fails creation with RequestDataTooBig.
+        class TestASGIRequest(ASGIRequest):
+
+            def __init__(self, scope, body_file):
+                super().__init__(scope, body_file)
+                raise RequestDataTooBig()
+
+        # Handler to use the custom request class.
+        class TestASGIHandler(ASGIHandler):
+            request_class = TestASGIRequest
+
+        application = TestASGIHandler()
+        scope = self.async_request_factory._base_scope(path="/not-important/")
+        communicator = ApplicationCommunicator(application, scope)
+
+        # Initiate request.
+        await communicator.send_input({"type": "http.request"})
+        # Give response.close() time to finish.
+        await communicator.wait()
+
+        self.assertEqual(len(signal_handler.calls), 1)
+        self.assertNotEqual(
+            signal_handler.calls[0]["thread"], threading.current_thread()
+        )
+
+    async def test_cancel_post_request_with_sync_processing(self):
+        """
+        The request.body object should be available and readable in view
+        code, even if the ASGIHandler cancels processing part way through.
+        """
+        loop = asyncio.get_event_loop()
+        # Events to monitor the view processing from the parent test code.
+        view_started_event = asyncio.Event()
+        view_finished_event = asyncio.Event()
+        # Record received request body or exceptions raised in the test view
+        outcome = []
+
+        # This view will run in a new thread because it is wrapped in
+        # sync_to_async. The view consumes the POST body data after a short
+        # delay. The test will cancel the request using http.disconnect during
+        # the delay, but because this is a sync view the code runs to
+        # completion. There should be no exceptions raised inside the view
+        # code.
+        @csrf_exempt
+        @sync_to_async
+        def post_view(request):
+            try:
+                loop.call_soon_threadsafe(view_started_event.set)
+                time.sleep(0.1)
+                # Do something to read request.body after pause
+                outcome.append({"request_body": request.body})
+                return HttpResponse("ok")
+            except Exception as e:
+                outcome.append({"exception": e})
+            finally:
+                loop.call_soon_threadsafe(view_finished_event.set)
+
+        # Request class to use the view.
+        class TestASGIRequest(ASGIRequest):
+            urlconf = (path("post/", post_view),)
+
+        # Handler to use request class.
+        class TestASGIHandler(ASGIHandler):
+            request_class = TestASGIRequest
+
+        application = TestASGIHandler()
+        scope = self.async_request_factory._base_scope(
+            method="POST",
+            path="/post/",
+        )
+        communicator = ApplicationCommunicator(application, scope)
+
+        await communicator.send_input({"type": "http.request", "body": b"Body data!"})
+
+        # Wait until the view code has started, then send http.disconnect.
+        await view_started_event.wait()
+        await communicator.send_input({"type": "http.disconnect"})
+        # Wait until view code has finished.
+        await view_finished_event.wait()
+        with self.assertRaises(asyncio.TimeoutError):
+            await communicator.receive_output()
+
+        self.assertEqual(outcome, [{"request_body": b"Body data!"}])
+
     async def test_untouched_request_body_gets_closed(self):
         application = get_asgi_application()
         scope = self.async_request_factory._base_scope(method="POST", path="/post/")
@@ -347,7 +441,9 @@ class ASGITest(SimpleTestCase):
         # AsyncToSync should have executed the signals in the same thread.
         self.assertEqual(len(signal_handler.calls), 2)
         request_started_call, request_finished_call = signal_handler.calls
-        self.assertEqual(request_started_call["thread"], request_finished_call["thread"])
+        self.assertEqual(
+            request_started_call["thread"], request_finished_call["thread"]
+        )
 
     async def test_concurrent_async_uses_multiple_thread_pools(self):
         sync_waiter.active_threads.clear()
@@ -383,6 +479,10 @@ class ASGITest(SimpleTestCase):
     async def test_asyncio_cancel_error(self):
         # Flag to check if the view was cancelled.
         view_did_cancel = False
+        # Track request_finished signal.
+        signal_handler = SignalHandler()
+        request_finished.connect(signal_handler)
+        self.addCleanup(request_finished.disconnect, signal_handler)
 
         # A view that will listen for the cancelled error.
         async def view(request):
@@ -417,6 +517,13 @@ class ASGITest(SimpleTestCase):
         # Give response.close() time to finish.
         await communicator.wait()
         self.assertIs(view_did_cancel, False)
+        # Exactly one call to request_finished handler.
+        self.assertEqual(len(signal_handler.calls), 1)
+        handler_call = signal_handler.calls.pop()
+        # It was NOT on the async thread.
+        self.assertNotEqual(handler_call["thread"], threading.current_thread())
+        # The signal sender is the handler class.
+        self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})
 
         # Request cycle with a disconnect before the view can respond.
         application = TestASGIHandler()
@@ -432,11 +539,22 @@ class ASGITest(SimpleTestCase):
             await communicator.receive_output()
         await communicator.wait()
         self.assertIs(view_did_cancel, True)
+        # Exactly one call to request_finished handler.
+        self.assertEqual(len(signal_handler.calls), 1)
+        handler_call = signal_handler.calls.pop()
+        # It was NOT on the async thread.
+        self.assertNotEqual(handler_call["thread"], threading.current_thread())
+        # The signal sender is the handler class.
+        self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})
 
     async def test_asyncio_streaming_cancel_error(self):
         # Similar to test_asyncio_cancel_error(), but during a streaming
         # response.
         view_did_cancel = False
+        # Track request_finished signals.
+        signal_handler = SignalHandler()
+        request_finished.connect(signal_handler)
+        self.addCleanup(request_finished.disconnect, signal_handler)
 
         async def streaming_response():
             nonlocal view_did_cancel
@@ -471,6 +589,13 @@ class ASGITest(SimpleTestCase):
         self.assertEqual(response_body["body"], b"Hello World!")
         await communicator.wait()
         self.assertIs(view_did_cancel, False)
+        # Exactly one call to request_finished handler.
+        self.assertEqual(len(signal_handler.calls), 1)
+        handler_call = signal_handler.calls.pop()
+        # It was NOT on the async thread.
+        self.assertNotEqual(handler_call["thread"], threading.current_thread())
+        # The signal sender is the handler class.
+        self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})
 
         # Request cycle with a disconnect.
         application = TestASGIHandler()
@@ -489,6 +614,13 @@ class ASGITest(SimpleTestCase):
             await communicator.receive_output()
         await communicator.wait()
         self.assertIs(view_did_cancel, True)
+        # Exactly one call to request_finished handler.
+        self.assertEqual(len(signal_handler.calls), 1)
+        handler_call = signal_handler.calls.pop()
+        # It was NOT on the async thread.
+        self.assertNotEqual(handler_call["thread"], threading.current_thread())
+        # The signal sender is the handler class.
+        self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})
 
     async def test_streaming(self):
         scope = self.async_request_factory._base_scope(