|
@@ -9,7 +9,10 @@ from importlib import import_module
|
|
|
from io import BytesIO
|
|
|
from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit
|
|
|
|
|
|
+from asgiref.sync import sync_to_async
|
|
|
+
|
|
|
from django.conf import settings
|
|
|
+from django.core.handlers.asgi import ASGIRequest
|
|
|
from django.core.handlers.base import BaseHandler
|
|
|
from django.core.handlers.wsgi import WSGIRequest
|
|
|
from django.core.serializers.json import DjangoJSONEncoder
|
|
@@ -157,6 +160,52 @@ class ClientHandler(BaseHandler):
|
|
|
return response
|
|
|
|
|
|
|
|
|
+class AsyncClientHandler(BaseHandler):
|
|
|
+ """An async version of ClientHandler."""
|
|
|
+ def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
|
|
|
+ self.enforce_csrf_checks = enforce_csrf_checks
|
|
|
+ super().__init__(*args, **kwargs)
|
|
|
+
|
|
|
+ async def __call__(self, scope):
|
|
|
+ # Set up middleware if needed. We couldn't do this earlier, because
|
|
|
+ # settings weren't available.
|
|
|
+ if self._middleware_chain is None:
|
|
|
+ self.load_middleware(is_async=True)
|
|
|
+ # Extract body file from the scope, if provided.
|
|
|
+ if '_body_file' in scope:
|
|
|
+ body_file = scope.pop('_body_file')
|
|
|
+ else:
|
|
|
+ body_file = FakePayload('')
|
|
|
+
|
|
|
+ request_started.disconnect(close_old_connections)
|
|
|
+ await sync_to_async(request_started.send)(sender=self.__class__, scope=scope)
|
|
|
+ request_started.connect(close_old_connections)
|
|
|
+ request = ASGIRequest(scope, body_file)
|
|
|
+ # Sneaky little hack so that we can easily get round
|
|
|
+ # CsrfViewMiddleware. This makes life easier, and is probably required
|
|
|
+ # for backwards compatibility with external tests against admin views.
|
|
|
+ request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
|
|
|
+ # Request goes through middleware.
|
|
|
+ response = await self.get_response_async(request)
|
|
|
+ # Simulate behaviors of most Web servers.
|
|
|
+ conditional_content_removal(request, response)
|
|
|
+ # Attach the originating ASGI request to the response so that it could
|
|
|
+ # be later retrieved.
|
|
|
+ response.asgi_request = request
|
|
|
+ # Emulate a server by calling the close method on completion.
|
|
|
+ if response.streaming:
|
|
|
+ response.streaming_content = await sync_to_async(closing_iterator_wrapper)(
|
|
|
+ response.streaming_content,
|
|
|
+ response.close,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ request_finished.disconnect(close_old_connections)
|
|
|
+ # Will fire request_finished.
|
|
|
+ await sync_to_async(response.close)()
|
|
|
+ request_finished.connect(close_old_connections)
|
|
|
+ return response
|
|
|
+
|
|
|
+
|
|
|
def store_rendered_templates(store, signal, sender, template, context, **kwargs):
|
|
|
"""
|
|
|
Store templates and contexts that are rendered.
|
|
@@ -421,7 +470,194 @@ class RequestFactory:
|
|
|
return self.request(**r)
|
|
|
|
|
|
|
|
|
-class Client(RequestFactory):
|
|
|
+class AsyncRequestFactory(RequestFactory):
|
|
|
+ """
|
|
|
+ Class that lets you create mock ASGI-like Request objects for use in
|
|
|
+ testing. Usage:
|
|
|
+
|
|
|
+ rf = AsyncRequestFactory()
|
|
|
+ get_request = await rf.get('/hello/')
|
|
|
+ post_request = await rf.post('/submit/', {'foo': 'bar'})
|
|
|
+
|
|
|
+ Once you have a request object you can pass it to any view function,
|
|
|
+ including synchronous ones. The reason we have a separate class here is:
|
|
|
+ a) this makes ASGIRequest subclasses, and
|
|
|
+ b) AsyncTestClient can subclass it.
|
|
|
+ """
|
|
|
+ def _base_scope(self, **request):
|
|
|
+ """The base scope for a request."""
|
|
|
+ # This is a minimal valid ASGI scope, plus:
|
|
|
+ # - headers['cookie'] for cookie support,
|
|
|
+ # - 'client' often useful, see #8551.
|
|
|
+ scope = {
|
|
|
+ 'asgi': {'version': '3.0'},
|
|
|
+ 'type': 'http',
|
|
|
+ 'http_version': '1.1',
|
|
|
+ 'client': ['127.0.0.1', 0],
|
|
|
+ 'server': ('testserver', '80'),
|
|
|
+ 'scheme': 'http',
|
|
|
+ 'method': 'GET',
|
|
|
+ 'headers': [],
|
|
|
+ **self.defaults,
|
|
|
+ **request,
|
|
|
+ }
|
|
|
+ scope['headers'].append((
|
|
|
+ b'cookie',
|
|
|
+ b'; '.join(sorted(
|
|
|
+ ('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii')
|
|
|
+ for morsel in self.cookies.values()
|
|
|
+ )),
|
|
|
+ ))
|
|
|
+ return scope
|
|
|
+
|
|
|
+ def request(self, **request):
|
|
|
+ """Construct a generic request object."""
|
|
|
+ # This is synchronous, which means all methods on this class are.
|
|
|
+ # AsyncClient, however, has an async request function, which makes all
|
|
|
+ # its methods async.
|
|
|
+ if '_body_file' in request:
|
|
|
+ body_file = request.pop('_body_file')
|
|
|
+ else:
|
|
|
+ body_file = FakePayload('')
|
|
|
+ return ASGIRequest(self._base_scope(**request), body_file)
|
|
|
+
|
|
|
+ def generic(
|
|
|
+ self, method, path, data='', content_type='application/octet-stream',
|
|
|
+ secure=False, **extra,
|
|
|
+ ):
|
|
|
+ """Construct an arbitrary HTTP request."""
|
|
|
+ parsed = urlparse(str(path)) # path can be lazy.
|
|
|
+ data = force_bytes(data, settings.DEFAULT_CHARSET)
|
|
|
+ s = {
|
|
|
+ 'method': method,
|
|
|
+ 'path': self._get_path(parsed),
|
|
|
+ 'server': ('127.0.0.1', '443' if secure else '80'),
|
|
|
+ 'scheme': 'https' if secure else 'http',
|
|
|
+ 'headers': [(b'host', b'testserver')],
|
|
|
+ }
|
|
|
+ if data:
|
|
|
+ s['headers'].extend([
|
|
|
+ (b'content-length', bytes(len(data))),
|
|
|
+ (b'content-type', content_type.encode('ascii')),
|
|
|
+ ])
|
|
|
+ s['_body_file'] = FakePayload(data)
|
|
|
+ s.update(extra)
|
|
|
+ # If QUERY_STRING is absent or empty, we want to extract it from the
|
|
|
+ # URL.
|
|
|
+ if not s.get('query_string'):
|
|
|
+ s['query_string'] = parsed[4]
|
|
|
+ return self.request(**s)
|
|
|
+
|
|
|
+
|
|
|
+class ClientMixin:
|
|
|
+ """
|
|
|
+ Mixin with common methods between Client and AsyncClient.
|
|
|
+ """
|
|
|
+ def store_exc_info(self, **kwargs):
|
|
|
+ """Store exceptions when they are generated by a view."""
|
|
|
+ self.exc_info = sys.exc_info()
|
|
|
+
|
|
|
+ def check_exception(self, response):
|
|
|
+ """
|
|
|
+ Look for a signaled exception, clear the current context exception
|
|
|
+ data, re-raise the signaled exception, and clear the signaled exception
|
|
|
+ from the local cache.
|
|
|
+ """
|
|
|
+ response.exc_info = self.exc_info
|
|
|
+ if self.exc_info:
|
|
|
+ _, exc_value, _ = self.exc_info
|
|
|
+ self.exc_info = None
|
|
|
+ if self.raise_request_exception:
|
|
|
+ raise exc_value
|
|
|
+
|
|
|
+ @property
|
|
|
+ def session(self):
|
|
|
+ """Return the current session variables."""
|
|
|
+ engine = import_module(settings.SESSION_ENGINE)
|
|
|
+ cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
|
|
|
+ if cookie:
|
|
|
+ return engine.SessionStore(cookie.value)
|
|
|
+ session = engine.SessionStore()
|
|
|
+ session.save()
|
|
|
+ self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
|
|
+ return session
|
|
|
+
|
|
|
+ def login(self, **credentials):
|
|
|
+ """
|
|
|
+ Set the Factory to appear as if it has successfully logged into a site.
|
|
|
+
|
|
|
+ Return True if login is possible or False if the provided credentials
|
|
|
+ are incorrect.
|
|
|
+ """
|
|
|
+ from django.contrib.auth import authenticate
|
|
|
+ user = authenticate(**credentials)
|
|
|
+ if user:
|
|
|
+ self._login(user)
|
|
|
+ return True
|
|
|
+ return False
|
|
|
+
|
|
|
+ def force_login(self, user, backend=None):
|
|
|
+ def get_backend():
|
|
|
+ from django.contrib.auth import load_backend
|
|
|
+ for backend_path in settings.AUTHENTICATION_BACKENDS:
|
|
|
+ backend = load_backend(backend_path)
|
|
|
+ if hasattr(backend, 'get_user'):
|
|
|
+ return backend_path
|
|
|
+
|
|
|
+ if backend is None:
|
|
|
+ backend = get_backend()
|
|
|
+ user.backend = backend
|
|
|
+ self._login(user, backend)
|
|
|
+
|
|
|
+ def _login(self, user, backend=None):
|
|
|
+ from django.contrib.auth import login
|
|
|
+ # Create a fake request to store login details.
|
|
|
+ request = HttpRequest()
|
|
|
+ if self.session:
|
|
|
+ request.session = self.session
|
|
|
+ else:
|
|
|
+ engine = import_module(settings.SESSION_ENGINE)
|
|
|
+ request.session = engine.SessionStore()
|
|
|
+ login(request, user, backend)
|
|
|
+ # Save the session values.
|
|
|
+ request.session.save()
|
|
|
+ # Set the cookie to represent the session.
|
|
|
+ session_cookie = settings.SESSION_COOKIE_NAME
|
|
|
+ self.cookies[session_cookie] = request.session.session_key
|
|
|
+ cookie_data = {
|
|
|
+ 'max-age': None,
|
|
|
+ 'path': '/',
|
|
|
+ 'domain': settings.SESSION_COOKIE_DOMAIN,
|
|
|
+ 'secure': settings.SESSION_COOKIE_SECURE or None,
|
|
|
+ 'expires': None,
|
|
|
+ }
|
|
|
+ self.cookies[session_cookie].update(cookie_data)
|
|
|
+
|
|
|
+ def logout(self):
|
|
|
+ """Log out the user by removing the cookies and session object."""
|
|
|
+ from django.contrib.auth import get_user, logout
|
|
|
+ request = HttpRequest()
|
|
|
+ if self.session:
|
|
|
+ request.session = self.session
|
|
|
+ request.user = get_user(request)
|
|
|
+ else:
|
|
|
+ engine = import_module(settings.SESSION_ENGINE)
|
|
|
+ request.session = engine.SessionStore()
|
|
|
+ logout(request)
|
|
|
+ self.cookies = SimpleCookie()
|
|
|
+
|
|
|
+ def _parse_json(self, response, **extra):
|
|
|
+ if not hasattr(response, '_json'):
|
|
|
+ if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
|
|
|
+ raise ValueError(
|
|
|
+ 'Content-Type header is "%s", not "application/json"'
|
|
|
+ % response.get('Content-Type')
|
|
|
+ )
|
|
|
+ response._json = json.loads(response.content.decode(response.charset), **extra)
|
|
|
+ return response._json
|
|
|
+
|
|
|
+
|
|
|
+class Client(ClientMixin, RequestFactory):
|
|
|
"""
|
|
|
A class that can act as a client for testing purposes.
|
|
|
|
|
@@ -446,23 +682,6 @@ class Client(RequestFactory):
|
|
|
self.exc_info = None
|
|
|
self.extra = None
|
|
|
|
|
|
- def store_exc_info(self, **kwargs):
|
|
|
- """Store exceptions when they are generated by a view."""
|
|
|
- self.exc_info = sys.exc_info()
|
|
|
-
|
|
|
- @property
|
|
|
- def session(self):
|
|
|
- """Return the current session variables."""
|
|
|
- engine = import_module(settings.SESSION_ENGINE)
|
|
|
- cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
|
|
|
- if cookie:
|
|
|
- return engine.SessionStore(cookie.value)
|
|
|
-
|
|
|
- session = engine.SessionStore()
|
|
|
- session.save()
|
|
|
- self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
|
|
- return session
|
|
|
-
|
|
|
def request(self, **request):
|
|
|
"""
|
|
|
The master request method. Compose the environment dictionary and pass
|
|
@@ -486,15 +705,8 @@ class Client(RequestFactory):
|
|
|
finally:
|
|
|
signals.template_rendered.disconnect(dispatch_uid=signal_uid)
|
|
|
got_request_exception.disconnect(dispatch_uid=exception_uid)
|
|
|
- # Look for a signaled exception, clear the current context exception
|
|
|
- # data, then re-raise the signaled exception. Also clear the signaled
|
|
|
- # exception from the local cache.
|
|
|
- response.exc_info = self.exc_info
|
|
|
- if self.exc_info:
|
|
|
- _, exc_value, _ = self.exc_info
|
|
|
- self.exc_info = None
|
|
|
- if self.raise_request_exception:
|
|
|
- raise exc_value
|
|
|
+ # Check for signaled exceptions.
|
|
|
+ self.check_exception(response)
|
|
|
# Save the client and request that stimulated the response.
|
|
|
response.client = self
|
|
|
response.request = request
|
|
@@ -583,85 +795,6 @@ class Client(RequestFactory):
|
|
|
response = self._handle_redirects(response, data=data, **extra)
|
|
|
return response
|
|
|
|
|
|
- def login(self, **credentials):
|
|
|
- """
|
|
|
- Set the Factory to appear as if it has successfully logged into a site.
|
|
|
-
|
|
|
- Return True if login is possible; False if the provided credentials
|
|
|
- are incorrect.
|
|
|
- """
|
|
|
- from django.contrib.auth import authenticate
|
|
|
- user = authenticate(**credentials)
|
|
|
- if user:
|
|
|
- self._login(user)
|
|
|
- return True
|
|
|
- else:
|
|
|
- return False
|
|
|
-
|
|
|
- def force_login(self, user, backend=None):
|
|
|
- def get_backend():
|
|
|
- from django.contrib.auth import load_backend
|
|
|
- for backend_path in settings.AUTHENTICATION_BACKENDS:
|
|
|
- backend = load_backend(backend_path)
|
|
|
- if hasattr(backend, 'get_user'):
|
|
|
- return backend_path
|
|
|
- if backend is None:
|
|
|
- backend = get_backend()
|
|
|
- user.backend = backend
|
|
|
- self._login(user, backend)
|
|
|
-
|
|
|
- def _login(self, user, backend=None):
|
|
|
- from django.contrib.auth import login
|
|
|
- engine = import_module(settings.SESSION_ENGINE)
|
|
|
-
|
|
|
- # Create a fake request to store login details.
|
|
|
- request = HttpRequest()
|
|
|
-
|
|
|
- if self.session:
|
|
|
- request.session = self.session
|
|
|
- else:
|
|
|
- request.session = engine.SessionStore()
|
|
|
- login(request, user, backend)
|
|
|
-
|
|
|
- # Save the session values.
|
|
|
- request.session.save()
|
|
|
-
|
|
|
- # Set the cookie to represent the session.
|
|
|
- session_cookie = settings.SESSION_COOKIE_NAME
|
|
|
- self.cookies[session_cookie] = request.session.session_key
|
|
|
- cookie_data = {
|
|
|
- 'max-age': None,
|
|
|
- 'path': '/',
|
|
|
- 'domain': settings.SESSION_COOKIE_DOMAIN,
|
|
|
- 'secure': settings.SESSION_COOKIE_SECURE or None,
|
|
|
- 'expires': None,
|
|
|
- }
|
|
|
- self.cookies[session_cookie].update(cookie_data)
|
|
|
-
|
|
|
- def logout(self):
|
|
|
- """Log out the user by removing the cookies and session object."""
|
|
|
- from django.contrib.auth import get_user, logout
|
|
|
-
|
|
|
- request = HttpRequest()
|
|
|
- engine = import_module(settings.SESSION_ENGINE)
|
|
|
- if self.session:
|
|
|
- request.session = self.session
|
|
|
- request.user = get_user(request)
|
|
|
- else:
|
|
|
- request.session = engine.SessionStore()
|
|
|
- logout(request)
|
|
|
- self.cookies = SimpleCookie()
|
|
|
-
|
|
|
- def _parse_json(self, response, **extra):
|
|
|
- if not hasattr(response, '_json'):
|
|
|
- if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
|
|
|
- raise ValueError(
|
|
|
- 'Content-Type header is "{}", not "application/json"'
|
|
|
- .format(response.get('Content-Type'))
|
|
|
- )
|
|
|
- response._json = json.loads(response.content.decode(response.charset), **extra)
|
|
|
- return response._json
|
|
|
-
|
|
|
def _handle_redirects(self, response, data='', content_type='', **extra):
|
|
|
"""
|
|
|
Follow any redirects by requesting responses from the server using GET.
|
|
@@ -714,3 +847,66 @@ class Client(RequestFactory):
|
|
|
raise RedirectCycleError("Too many redirects.", last_response=response)
|
|
|
|
|
|
return response
|
|
|
+
|
|
|
+
|
|
|
+class AsyncClient(ClientMixin, AsyncRequestFactory):
|
|
|
+ """
|
|
|
+ An async version of Client that creates ASGIRequests and calls through an
|
|
|
+ async request path.
|
|
|
+
|
|
|
+ Does not currently support "follow" on its methods.
|
|
|
+ """
|
|
|
+ def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
|
|
|
+ super().__init__(**defaults)
|
|
|
+ self.handler = AsyncClientHandler(enforce_csrf_checks)
|
|
|
+ self.raise_request_exception = raise_request_exception
|
|
|
+ self.exc_info = None
|
|
|
+ self.extra = None
|
|
|
+
|
|
|
+ async def request(self, **request):
|
|
|
+ """
|
|
|
+ The master request method. Compose the scope dictionary and pass to the
|
|
|
+ handler, return the result of the handler. Assume defaults for the
|
|
|
+ query environment, which can be overridden using the arguments to the
|
|
|
+ request.
|
|
|
+ """
|
|
|
+ if 'follow' in request:
|
|
|
+ raise NotImplementedError(
|
|
|
+ 'AsyncClient request methods do not accept the follow '
|
|
|
+ 'parameter.'
|
|
|
+ )
|
|
|
+ scope = self._base_scope(**request)
|
|
|
+ # Curry a data dictionary into an instance of the template renderer
|
|
|
+ # callback function.
|
|
|
+ data = {}
|
|
|
+ on_template_render = partial(store_rendered_templates, data)
|
|
|
+ signal_uid = 'template-render-%s' % id(request)
|
|
|
+ signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
|
|
|
+ # Capture exceptions created by the handler.
|
|
|
+ exception_uid = 'request-exception-%s' % id(request)
|
|
|
+ got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid)
|
|
|
+ try:
|
|
|
+ response = await self.handler(scope)
|
|
|
+ finally:
|
|
|
+ signals.template_rendered.disconnect(dispatch_uid=signal_uid)
|
|
|
+ got_request_exception.disconnect(dispatch_uid=exception_uid)
|
|
|
+ # Check for signaled exceptions.
|
|
|
+ self.check_exception(response)
|
|
|
+ # Save the client and request that stimulated the response.
|
|
|
+ response.client = self
|
|
|
+ response.request = request
|
|
|
+ # Add any rendered template detail to the response.
|
|
|
+ response.templates = data.get('templates', [])
|
|
|
+ response.context = data.get('context')
|
|
|
+ response.json = partial(self._parse_json, response)
|
|
|
+ # Attach the ResolverMatch instance to the response.
|
|
|
+ response.resolver_match = SimpleLazyObject(lambda: resolve(request['path']))
|
|
|
+ # Flatten a single context. Not really necessary anymore thanks to the
|
|
|
+ # __getattr__ flattening in ContextList, but has some edge case
|
|
|
+ # backwards compatibility implications.
|
|
|
+ if response.context and len(response.context) == 1:
|
|
|
+ response.context = response.context[0]
|
|
|
+ # Update persistent cookie data.
|
|
|
+ if response.cookies:
|
|
|
+ self.cookies.update(response.cookies)
|
|
|
+ return response
|