Browse Source

Fixed #32596 -- Added CsrfViewMiddleware._check_referer().

This encapsulates CsrfViewMiddleware's Referer-checking logic in a new
_check_referer() method, which raises RejectRequest on failure, and
updates existing tests to also check the "seam" introduced by the
refactor wherever doing so improves the test.
Chris Jerdonek 4 years ago
parent
commit
71179a6124
2 changed files with 77 additions and 44 deletions
  1. 52 43
      django/middleware/csrf.py
  2. 25 1
      tests/csrf_tests/tests.py

+ 52 - 43
django/middleware/csrf.py

@@ -132,6 +132,11 @@ def _compare_masked_tokens(request_csrf_token, csrf_token):
     )
 
 
+class RejectRequest(Exception):
+    def __init__(self, reason):
+        self.reason = reason
+
+
 class CsrfViewMiddleware(MiddlewareMixin):
     """
     Require a present and correct csrfmiddlewaretoken for POST requests that
@@ -251,6 +256,50 @@ class CsrfViewMiddleware(MiddlewareMixin):
             for host in self.allowed_origin_subdomains.get(request_scheme, ())
         )
 
+    def _check_referer(self, request):
+        referer = request.META.get('HTTP_REFERER')
+        if referer is None:
+            raise RejectRequest(REASON_NO_REFERER)
+
+        try:
+            referer = urlparse(referer)
+        except ValueError:
+            raise RejectRequest(REASON_MALFORMED_REFERER)
+
+        # Make sure we have a valid URL for Referer.
+        if '' in (referer.scheme, referer.netloc):
+            raise RejectRequest(REASON_MALFORMED_REFERER)
+
+        # Ensure that our Referer is also secure.
+        if referer.scheme != 'https':
+            raise RejectRequest(REASON_INSECURE_REFERER)
+
+        good_referer = (
+            settings.SESSION_COOKIE_DOMAIN
+            if settings.CSRF_USE_SESSIONS
+            else settings.CSRF_COOKIE_DOMAIN
+        )
+        if good_referer is None:
+            # If no cookie domain is configured, allow matching the current
+            # host:port exactly if it's permitted by ALLOWED_HOSTS.
+            try:
+                # request.get_host() includes the port.
+                good_referer = request.get_host()
+            except DisallowedHost:
+                pass
+        else:
+            server_port = request.get_port()
+            if server_port not in ('443', '80'):
+                good_referer = '%s:%s' % (good_referer, server_port)
+
+        # Create an iterable of all acceptable HTTP referers.
+        good_hosts = self.csrf_trusted_origins_hosts
+        if good_referer is not None:
+            good_hosts = (*good_hosts, good_referer)
+
+        if not any(is_same_domain(referer.netloc, host) for host in good_hosts):
+            raise RejectRequest(REASON_BAD_REFERER % referer.geturl())
+
     def process_request(self, request):
         csrf_token = self._get_token(request)
         if csrf_token is not None:
@@ -300,50 +349,10 @@ class CsrfViewMiddleware(MiddlewareMixin):
                 # Barth et al. found that the Referer header is missing for
                 # same-domain requests in only about 0.2% of cases or less, so
                 # we can use strict Referer checking.
-                referer = request.META.get('HTTP_REFERER')
-                if referer is None:
-                    return self._reject(request, REASON_NO_REFERER)
-
                 try:
-                    referer = urlparse(referer)
-                except ValueError:
-                    return self._reject(request, REASON_MALFORMED_REFERER)
-
-                # Make sure we have a valid URL for Referer.
-                if '' in (referer.scheme, referer.netloc):
-                    return self._reject(request, REASON_MALFORMED_REFERER)
-
-                # Ensure that our Referer is also secure.
-                if referer.scheme != 'https':
-                    return self._reject(request, REASON_INSECURE_REFERER)
-
-                good_referer = (
-                    settings.SESSION_COOKIE_DOMAIN
-                    if settings.CSRF_USE_SESSIONS
-                    else settings.CSRF_COOKIE_DOMAIN
-                )
-                if good_referer is None:
-                    # If no cookie domain is configured, allow matching the
-                    # current host:port exactly if it's permitted by
-                    # ALLOWED_HOSTS.
-                    try:
-                        # request.get_host() includes the port.
-                        good_referer = request.get_host()
-                    except DisallowedHost:
-                        pass
-                else:
-                    server_port = request.get_port()
-                    if server_port not in ('443', '80'):
-                        good_referer = '%s:%s' % (good_referer, server_port)
-
-                # Create an iterable of all acceptable HTTP referers.
-                good_hosts = self.csrf_trusted_origins_hosts
-                if good_referer is not None:
-                    good_hosts = (*good_hosts, good_referer)
-
-                if not any(is_same_domain(referer.netloc, host) for host in good_hosts):
-                    reason = REASON_BAD_REFERER % referer.geturl()
-                    return self._reject(request, reason)
+                    self._check_referer(request)
+                except RejectRequest as exc:
+                    return self._reject(request, exc.reason)
 
             # Access csrf_token via self._get_token() as rotate_token() may
             # have been called by an authentication middleware during the

+ 25 - 1
tests/csrf_tests/tests.py

@@ -6,7 +6,7 @@ from django.core.exceptions import ImproperlyConfigured
 from django.http import HttpRequest, HttpResponse
 from django.middleware.csrf import (
     CSRF_SESSION_KEY, CSRF_TOKEN_LENGTH, REASON_BAD_ORIGIN, REASON_BAD_TOKEN,
-    REASON_NO_CSRF_COOKIE, CsrfViewMiddleware,
+    REASON_NO_CSRF_COOKIE, CsrfViewMiddleware, RejectRequest,
     _compare_masked_tokens as equivalent_tokens, get_token,
 )
 from django.test import SimpleTestCase, override_settings
@@ -305,12 +305,17 @@ class CsrfViewMiddlewareTestMixin:
             status_code=403,
         )
 
+    def _check_referer_rejects(self, mw, req):
+        with self.assertRaises(RejectRequest):
+            mw._check_referer(req)
+
     @override_settings(DEBUG=True)
     def test_https_no_referer(self):
         """A POST HTTPS request with a missing referer is rejected."""
         req = self._get_POST_request_with_token()
         req._is_secure_override = True
         mw = CsrfViewMiddleware(post_form_view)
+        self._check_referer_rejects(mw, req)
         response = mw.process_view(req, post_form_view, (), {})
         self.assertContains(
             response,
@@ -329,6 +334,12 @@ class CsrfViewMiddlewareTestMixin:
         req.META['HTTP_REFERER'] = 'https://www.evil.org/somepage'
         req.META['SERVER_PORT'] = '443'
         mw = CsrfViewMiddleware(token_view)
+        expected = (
+            'Referer checking failed - https://www.evil.org/somepage does not '
+            'match any trusted origins.'
+        )
+        with self.assertRaisesMessage(RejectRequest, expected):
+            mw._check_referer(req)
         response = mw.process_view(req, token_view, (), {})
         self.assertEqual(response.status_code, 403)
 
@@ -338,6 +349,7 @@ class CsrfViewMiddlewareTestMixin:
         req.META['HTTP_HOST'] = '@malformed'
         req.META['HTTP_ORIGIN'] = 'https://www.evil.org'
         mw = CsrfViewMiddleware(token_view)
+        self._check_referer_rejects(mw, req)
         response = mw.process_view(req, token_view, (), {})
         self.assertEqual(response.status_code, 403)
 
@@ -351,6 +363,7 @@ class CsrfViewMiddlewareTestMixin:
         req._is_secure_override = True
         req.META['HTTP_REFERER'] = 'http://http://www.example.com/'
         mw = CsrfViewMiddleware(post_form_view)
+        self._check_referer_rejects(mw, req)
         response = mw.process_view(req, post_form_view, (), {})
         self.assertContains(
             response,
@@ -359,28 +372,33 @@ class CsrfViewMiddlewareTestMixin:
         )
         # Empty
         req.META['HTTP_REFERER'] = ''
+        self._check_referer_rejects(mw, req)
         response = mw.process_view(req, post_form_view, (), {})
         self.assertContains(response, malformed_referer_msg, status_code=403)
         # Non-ASCII
         req.META['HTTP_REFERER'] = 'ØBöIß'
+        self._check_referer_rejects(mw, req)
         response = mw.process_view(req, post_form_view, (), {})
         self.assertContains(response, malformed_referer_msg, status_code=403)
         # missing scheme
         # >>> urlparse('//example.com/')
         # ParseResult(scheme='', netloc='example.com', path='/', params='', query='', fragment='')
         req.META['HTTP_REFERER'] = '//example.com/'
+        self._check_referer_rejects(mw, req)
         response = mw.process_view(req, post_form_view, (), {})
         self.assertContains(response, malformed_referer_msg, status_code=403)
         # missing netloc
         # >>> urlparse('https://')
         # ParseResult(scheme='https', netloc='', path='', params='', query='', fragment='')
         req.META['HTTP_REFERER'] = 'https://'
+        self._check_referer_rejects(mw, req)
         response = mw.process_view(req, post_form_view, (), {})
         self.assertContains(response, malformed_referer_msg, status_code=403)
         # Invalid URL
         # >>> urlparse('https://[')
         # ValueError: Invalid IPv6 URL
         req.META['HTTP_REFERER'] = 'https://['
+        self._check_referer_rejects(mw, req)
         response = mw.process_view(req, post_form_view, (), {})
         self.assertContains(response, malformed_referer_msg, status_code=403)
 
@@ -562,6 +580,7 @@ class CsrfViewMiddlewareTestMixin:
         req.META['HTTP_HOST'] = 'www.example.com'
         req.META['HTTP_ORIGIN'] = 'https://www.evil.org'
         mw = CsrfViewMiddleware(post_form_view)
+        self._check_referer_rejects(mw, req)
         self.assertIs(mw._origin_verified(req), False)
         with self.assertLogs('django.security.csrf', 'WARNING') as cm:
             response = mw.process_view(req, post_form_view, (), {})
@@ -576,6 +595,7 @@ class CsrfViewMiddlewareTestMixin:
         req.META['HTTP_HOST'] = 'www.example.com'
         req.META['HTTP_ORIGIN'] = 'null'
         mw = CsrfViewMiddleware(post_form_view)
+        self._check_referer_rejects(mw, req)
         self.assertIs(mw._origin_verified(req), False)
         with self.assertLogs('django.security.csrf', 'WARNING') as cm:
             response = mw.process_view(req, post_form_view, (), {})
@@ -591,6 +611,7 @@ class CsrfViewMiddlewareTestMixin:
         req.META['HTTP_HOST'] = 'www.example.com'
         req.META['HTTP_ORIGIN'] = 'http://example.com'
         mw = CsrfViewMiddleware(post_form_view)
+        self._check_referer_rejects(mw, req)
         self.assertIs(mw._origin_verified(req), False)
         with self.assertLogs('django.security.csrf', 'WARNING') as cm:
             response = mw.process_view(req, post_form_view, (), {})
@@ -617,6 +638,7 @@ class CsrfViewMiddlewareTestMixin:
         req.META['HTTP_HOST'] = 'www.example.com'
         req.META['HTTP_ORIGIN'] = 'http://foo.example.com'
         mw = CsrfViewMiddleware(post_form_view)
+        self._check_referer_rejects(mw, req)
         self.assertIs(mw._origin_verified(req), False)
         with self.assertLogs('django.security.csrf', 'WARNING') as cm:
             response = mw.process_view(req, post_form_view, (), {})
@@ -639,6 +661,7 @@ class CsrfViewMiddlewareTestMixin:
         req.META['HTTP_HOST'] = 'www.example.com'
         req.META['HTTP_ORIGIN'] = 'https://['
         mw = CsrfViewMiddleware(post_form_view)
+        self._check_referer_rejects(mw, req)
         self.assertIs(mw._origin_verified(req), False)
         with self.assertLogs('django.security.csrf', 'WARNING') as cm:
             response = mw.process_view(req, post_form_view, (), {})
@@ -867,6 +890,7 @@ class CsrfViewMiddlewareTests(CsrfViewMiddlewareTestMixin, SimpleTestCase):
         req.META['HTTP_REFERER'] = 'http://example.com/'
         req.META['SERVER_PORT'] = '443'
         mw = CsrfViewMiddleware(post_form_view)
+        self._check_referer_rejects(mw, req)
         response = mw.process_view(req, post_form_view, (), {})
         self.assertContains(
             response,