asgi.py

import asyncio
import logging
import sys
import tempfile
import traceback
from contextlib import aclosing

from asgiref.sync import ThreadSensitiveContext, sync_to_async

from django.conf import settings
from django.core import signals
from django.core.exceptions import RequestAborted, RequestDataTooBig
from django.core.handlers import base
from django.http import (
    FileResponse,
    HttpRequest,
    HttpResponse,
    HttpResponseBadRequest,
    HttpResponseServerError,
    QueryDict,
    parse_cookie,
)
from django.urls import set_script_prefix
from django.utils.functional import cached_property

logger = logging.getLogger("django.request")


def get_script_prefix(scope):
    """
    Return the script prefix to use from either the scope or a setting.
    """
    if settings.FORCE_SCRIPT_NAME:
        return settings.FORCE_SCRIPT_NAME
    return scope.get("root_path", "") or ""
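
# Illustrative example (not part of the module): with FORCE_SCRIPT_NAME unset,
# a scope of {"root_path": "/app", "path": "/app/users/"} yields a script
# prefix of "/app"; setting FORCE_SCRIPT_NAME overrides whatever the server
# put in "root_path".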


class ASGIRequest(HttpRequest):
    """
    Custom request subclass that decodes from an ASGI-standard request dict
    and wraps request body handling.
    """

    # Number of seconds until a Request gives up on trying to read a request
    # body and aborts.
    body_receive_timeout = 60

    def __init__(self, scope, body_file):
        self.scope = scope
        self._post_parse_error = False
        self._read_started = False
        self.resolver_match = None
        self.script_name = get_script_prefix(scope)
        if self.script_name:
            # TODO: Better is-prefix checking, slash handling?
            self.path_info = scope["path"].removeprefix(self.script_name)
        else:
            self.path_info = scope["path"]
        # The Django path differs from the ASGI scope's path arguments; it
        # should combine the script name with the remaining path info.
        if self.script_name:
            self.path = "%s/%s" % (
                self.script_name.rstrip("/"),
                self.path_info.replace("/", "", 1),
            )
        else:
            self.path = scope["path"]
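        # For example (illustrative values): script_name "/app/" and path_info
        # "/users/" combine into path "/app/users/"; with no script name the
        # raw scope path is used unchanged.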
        # HTTP basics.
        self.method = self.scope["method"].upper()
        # Ensure query string is encoded correctly.
        query_string = self.scope.get("query_string", "")
        if isinstance(query_string, bytes):
            query_string = query_string.decode()
        self.META = {
            "REQUEST_METHOD": self.method,
            "QUERY_STRING": query_string,
            "SCRIPT_NAME": self.script_name,
            "PATH_INFO": self.path_info,
            # WSGI-expecting code will need these for a while.
            "wsgi.multithread": True,
            "wsgi.multiprocess": True,
        }
        if self.scope.get("client"):
            self.META["REMOTE_ADDR"] = self.scope["client"][0]
            self.META["REMOTE_HOST"] = self.META["REMOTE_ADDR"]
            self.META["REMOTE_PORT"] = self.scope["client"][1]
        if self.scope.get("server"):
            self.META["SERVER_NAME"] = self.scope["server"][0]
            self.META["SERVER_PORT"] = str(self.scope["server"][1])
        else:
            self.META["SERVER_NAME"] = "unknown"
            self.META["SERVER_PORT"] = "0"
        # Headers go into META.
        for name, value in self.scope.get("headers", []):
            name = name.decode("latin1")
            if name == "content-length":
                corrected_name = "CONTENT_LENGTH"
            elif name == "content-type":
                corrected_name = "CONTENT_TYPE"
            else:
                corrected_name = "HTTP_%s" % name.upper().replace("-", "_")
            # HTTP/2 says only ASCII chars are allowed in headers, but decode
            # latin1 just in case.
            value = value.decode("latin1")
            if corrected_name in self.META:
                value = self.META[corrected_name] + "," + value
            self.META[corrected_name] = value
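            # Illustrative example: two "x-token" headers with values b"a" and
            # b"b" fold into META["HTTP_X_TOKEN"] == "a,b".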
        # Pull out request encoding, if provided.
        self._set_content_type_params(self.META)
        # Directly assign the body file to be our stream.
        self._stream = body_file
        # Other bits.
        self.resolver_match = None

    @cached_property
    def GET(self):
        return QueryDict(self.META["QUERY_STRING"])

    def _get_scheme(self):
        return self.scope.get("scheme") or super()._get_scheme()

    def _get_post(self):
        if not hasattr(self, "_post"):
            self._load_post_and_files()
        return self._post

    def _set_post(self, post):
        self._post = post

    def _get_files(self):
        if not hasattr(self, "_files"):
            self._load_post_and_files()
        return self._files

    POST = property(_get_post, _set_post)
    FILES = property(_get_files)

    @cached_property
    def COOKIES(self):
        return parse_cookie(self.META.get("HTTP_COOKIE", ""))

    def close(self):
        super().close()
        self._stream.close()
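

# A minimal construction sketch (illustrative only, assuming Django settings
# are configured); useful for seeing how a scope maps onto the request:
#
#     import io
#     scope = {"method": "GET", "path": "/", "query_string": b"", "headers": []}
#     request = ASGIRequest(scope, io.BytesIO(b""))
#     request.method      # "GET"
#     request.path_info   # "/"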


class ASGIHandler(base.BaseHandler):
    """Handler for ASGI requests."""

    request_class = ASGIRequest
    # Size to chunk response bodies into for multiple response messages.
    chunk_size = 2**16

    def __init__(self):
        super().__init__()
        self.load_middleware(is_async=True)

    async def __call__(self, scope, receive, send):
        """
        Async entrypoint - parses the request and hands off to get_response.
        """
        # Serve only HTTP connections.
        # FIXME: Allow overriding this.
        if scope["type"] != "http":
            raise ValueError(
                "Django can only handle ASGI/HTTP connections, not %s." % scope["type"]
            )

        async with ThreadSensitiveContext():
            await self.handle(scope, receive, send)
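
    # Deployment note (illustrative): a project's asgi.py typically obtains an
    # instance of this handler via django.core.asgi.get_asgi_application() and
    # exposes it as ``application`` for an ASGI server such as Daphne or
    # Uvicorn to call.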

    async def handle(self, scope, receive, send):
        """
        Handles the ASGI request. Called via the __call__ method.
        """
        # Receive the HTTP request body as a stream object.
        try:
            body_file = await self.read_body(receive)
        except RequestAborted:
            return
        # Request is complete and can be served.
        set_script_prefix(get_script_prefix(scope))
        await signals.request_started.asend(sender=self.__class__, scope=scope)
        # Get the request and check for basic issues.
        request, error_response = self.create_request(scope, body_file)
        if request is None:
            body_file.close()
            await self.send_response(error_response, send)
            return

        async def process_request(request, send):
            response = await self.run_get_response(request)
            await self.send_response(response, send)

        # Try to catch a disconnect while getting response.
        tasks = [
            # Check the status of these tasks and (optionally) terminate them
            # in this order. The listen_for_disconnect() task goes first
            # because it should not raise unexpected errors that would prevent
            # us from cancelling process_request().
            asyncio.create_task(self.listen_for_disconnect(receive)),
            asyncio.create_task(process_request(request, send)),
        ]
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        # Now wait on both tasks (they may have both finished by now).
        for task in tasks:
            if task.done():
                try:
                    task.result()
                except RequestAborted:
                    # Ignore client disconnects.
                    pass
                except AssertionError:
                    body_file.close()
                    raise
            else:
                # Allow views to handle cancellation.
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    # Task re-raised the CancelledError as expected.
                    pass

        body_file.close()

    async def listen_for_disconnect(self, receive):
        """Listen for disconnect from the client."""
        message = await receive()
        if message["type"] == "http.disconnect":
            raise RequestAborted()
        # This should never happen.
        assert False, "Invalid ASGI message after request body: %s" % message["type"]

    async def run_get_response(self, request):
        """Get async response."""
        # Use the async mode of BaseHandler.
        response = await self.get_response_async(request)
        response._handler_class = self.__class__
        # Increase chunk size on file responses (ASGI servers handle low-level
        # chunking).
        if isinstance(response, FileResponse):
            response.block_size = self.chunk_size
        return response

    async def read_body(self, receive):
        """Reads an HTTP body from an ASGI connection."""
        # Use a tempfile that automatically rolls over to a disk file as it
        # fills up.
        body_file = tempfile.SpooledTemporaryFile(
            max_size=settings.FILE_UPLOAD_MAX_MEMORY_SIZE, mode="w+b"
        )
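        # Note: FILE_UPLOAD_MAX_MEMORY_SIZE defaults to 2.5 MB, so small bodies
        # stay in memory while larger ones spill to an on-disk temporary file.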
        while True:
            message = await receive()
            if message["type"] == "http.disconnect":
                body_file.close()
                # Early client disconnect.
                raise RequestAborted()
            # Add a body chunk from the message, if provided.
            if "body" in message:
                body_file.write(message["body"])
            # Quit out if that's the end.
            if not message.get("more_body", False):
                break
        body_file.seek(0)
        return body_file

    def create_request(self, scope, body_file):
        """
        Create the Request object and return either (request, None) or
        (None, response) if there is an error response.
        """
        try:
            return self.request_class(scope, body_file), None
        except UnicodeDecodeError:
            logger.warning(
                "Bad Request (UnicodeDecodeError)",
                exc_info=sys.exc_info(),
                extra={"status_code": 400},
            )
            return None, HttpResponseBadRequest()
        except RequestDataTooBig:
            return None, HttpResponse("413 Payload too large", status=413)

    def handle_uncaught_exception(self, request, resolver, exc_info):
        """Last-chance handler for exceptions."""
        # There's no WSGI server to catch the exception further up
        # if this fails, so translate it into a plain text response.
        try:
            return super().handle_uncaught_exception(request, resolver, exc_info)
        except Exception:
            return HttpResponseServerError(
                traceback.format_exc() if settings.DEBUG else "Internal Server Error",
                content_type="text/plain",
            )

    async def send_response(self, response, send):
        """Encode and send a response out over ASGI."""
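        # ASGI message sequence (per the ASGI HTTP spec): one
        # "http.response.start" message with status and headers, followed by
        # one or more "http.response.body" messages; the final body message
        # leaves "more_body" false (or omits it) to end the response.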
        # Collect cookies into headers. Have to preserve header case as there
        # are some non-RFC compliant clients that require e.g. Content-Type.
        response_headers = []
        for header, value in response.items():
            if isinstance(header, str):
                header = header.encode("ascii")
            if isinstance(value, str):
                value = value.encode("latin1")
            response_headers.append((bytes(header), bytes(value)))
        for c in response.cookies.values():
            response_headers.append(
                (b"Set-Cookie", c.output(header="").encode("ascii").strip())
            )
        # Initial response message.
        await send(
            {
                "type": "http.response.start",
                "status": response.status_code,
                "headers": response_headers,
            }
        )
        # Streaming responses need to be pinned to their iterator.
        if response.streaming:
            # - Consume via `__aiter__` and not `streaming_content` directly, to
            #   allow mapping of a sync iterator.
            # - Use aclosing() when consuming aiter.
            #   See https://github.com/python/cpython/commit/6e8dcda
            async with aclosing(aiter(response)) as content:
                async for part in content:
                    for chunk, _ in self.chunk_bytes(part):
                        await send(
                            {
                                "type": "http.response.body",
                                "body": chunk,
                                # Ignore "more" as there may be more parts; instead,
                                # use an empty final closing message with False.
                                "more_body": True,
                            }
                        )
            # Final closing message.
            await send({"type": "http.response.body"})
        # Other responses just need chunking.
        else:
            # Yield chunks of response.
            for chunk, last in self.chunk_bytes(response.content):
                await send(
                    {
                        "type": "http.response.body",
                        "body": chunk,
                        "more_body": not last,
                    }
                )
        await sync_to_async(response.close, thread_sensitive=True)()

    @classmethod
    def chunk_bytes(cls, data):
        """
        Chunks some data up so it can be sent in reasonable size messages.
        Yields (chunk, last_chunk) tuples.
        """
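        # For example, with the default chunk_size of 2**16 (65536 bytes),
        # 70,000 bytes of data yield (first 65536 bytes, False) and then
        # (remaining 4464 bytes, True); empty data yields a single (data, True).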
        position = 0
        if not data:
            yield data, True
            return
        while position < len(data):
            yield (
                data[position : position + cls.chunk_size],
                (position + cls.chunk_size) >= len(data),
            )
            position += cls.chunk_size