Pārlūkot izejas kodu

Fixed #20147 -- Added HttpRequest.headers.

Santiago Basulto 6 gadi atpakaļ
vecāks
revīzija
4fc35a9c3e

+ 29 - 1
django/http/request.py

@@ -12,7 +12,9 @@ from django.core.exceptions import (
 )
 from django.core.files import uploadhandler
 from django.http.multipartparser import MultiPartParser, MultiPartParserError
-from django.utils.datastructures import ImmutableList, MultiValueDict
+from django.utils.datastructures import (
+    CaseInsensitiveMapping, ImmutableList, MultiValueDict,
+)
 from django.utils.deprecation import RemovedInDjango30Warning
 from django.utils.encoding import escape_uri_path, iri_to_uri
 from django.utils.functional import cached_property
@@ -65,6 +67,10 @@ class HttpRequest:
             return '<%s>' % self.__class__.__name__
         return '<%s: %s %r>' % (self.__class__.__name__, self.method, self.get_full_path())
 
+    @cached_property
+    def headers(self):
+        return HttpHeaders(self.META)
+
     def _get_raw_host(self):
         """
         Return the HTTP host using the environment or request headers. Skip
@@ -359,6 +365,28 @@ class HttpRequest:
         return list(self)
 
 
+class HttpHeaders(CaseInsensitiveMapping):
+    HTTP_PREFIX = 'HTTP_'
+    # PEP 333 gives two headers which aren't prepended with HTTP_.
+    UNPREFIXED_HEADERS = {'CONTENT_TYPE', 'CONTENT_LENGTH'}
+
+    def __init__(self, environ):
+        headers = {}
+        for header, value in environ.items():
+            name = self.parse_header_name(header)
+            if name:
+                headers[name] = value
+        super().__init__(headers)
+
+    @classmethod
+    def parse_header_name(cls, header):
+        if header.startswith(cls.HTTP_PREFIX):
+            header = header[len(cls.HTTP_PREFIX):]
+        elif header not in cls.UNPREFIXED_HEADERS:
+            return None
+        return header.replace('_', '-').title()
+
+
 class QueryDict(MultiValueDict):
     """
     A specialized MultiValueDict which represents a query string.

+ 59 - 0
django/utils/datastructures.py

@@ -1,5 +1,6 @@
 import copy
 from collections import OrderedDict
+from collections.abc import Mapping
 
 
 class OrderedSet:
@@ -280,3 +281,61 @@ class DictWrapper(dict):
         if use_func:
             return self.func(value)
         return value
+
+
+def _destruct_iterable_mapping_values(data):
+    for i, elem in enumerate(data):
+        if len(elem) != 2:
+            raise ValueError(
+                'dictionary update sequence element #{} has '
+                'length {}; 2 is required.'.format(i, len(elem))
+            )
+        if not isinstance(elem[0], str):
+            raise ValueError('Element key %r invalid, only strings are allowed' % elem[0])
+        yield tuple(elem)
+
+
+class CaseInsensitiveMapping(Mapping):
+    """
+    Mapping allowing case-insensitive key lookups. Original case of keys is
+    preserved for iteration and string representation.
+
+    Example::
+
+        >>> ci_map = CaseInsensitiveMapping({'name': 'Jane'})
+        >>> ci_map['Name']
+        Jane
+        >>> ci_map['NAME']
+        Jane
+        >>> ci_map['name']
+        Jane
+        >>> ci_map  # original case preserved
+        {'name': 'Jane'}
+    """
+
+    def __init__(self, data):
+        if not isinstance(data, Mapping):
+            data = {k: v for k, v in _destruct_iterable_mapping_values(data)}
+        self._store = {k.lower(): (k, v) for k, v in data.items()}
+
+    def __getitem__(self, key):
+        return self._store[key.lower()][1]
+
+    def __len__(self):
+        return len(self._store)
+
+    def __eq__(self, other):
+        return isinstance(other, Mapping) and {
+            k.lower(): v for k, v in self.items()
+        } == {
+            k.lower(): v for k, v in other.items()
+        }
+
+    def __iter__(self):
+        return (original_key for original_key, value in self._store.values())
+
+    def __repr__(self):
+        return repr({key: value for key, value in self._store.values()})
+
+    def copy(self):
+        return self

+ 32 - 0
docs/ref/request-response.txt

@@ -167,6 +167,38 @@ All attributes should be considered read-only, unless stated otherwise.
     underscores in WSGI environment variables. It matches the behavior of
     Web servers like Nginx and Apache 2.4+.
 
+    :attr:`HttpRequest.headers` is a simpler way to access all HTTP-prefixd
+    headers, plus ``CONTENT_LENGTH`` and ``CONTENT_TYPE``.
+
+.. attribute:: HttpRequest.headers
+
+    .. versionadded:: 2.2
+
+    A case insensitive, dict-like object that provides access to all
+    HTTP-prefixed headers (plus ``Content-Length`` and ``Content-Type``) from
+    the request.
+
+    The name of each header is stylized with title-casing (e.g. ``User-Agent``)
+    when it's displayed. You can access headers case-insensitively::
+
+        >>> request.headers
+        {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6', ...}
+
+        >>> 'User-Agent' in request.headers
+        True
+        >>> 'user-agent' in request.headers
+        True
+
+        >>> request.headers['User-Agent']
+        Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)
+        >>> request.headers['user-agent']
+        Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)
+
+        >>> request.headers.get('User-Agent')
+        Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)
+        >>> request.headers.get('user-agent')
+        Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)
+
 .. attribute:: HttpRequest.resolver_match
 
     An instance of :class:`~django.urls.ResolverMatch` representing the

+ 2 - 1
docs/releases/2.2.txt

@@ -266,7 +266,8 @@ Models
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 
-* ...
+* Added :attr:`.HttpRequest.headers` to allow simple access to a request's
+  headers.
 
 Serialization
 ~~~~~~~~~~~~~

+ 83 - 1
tests/requests/tests.py

@@ -5,7 +5,7 @@ from urllib.parse import urlencode
 from django.core.exceptions import DisallowedHost
 from django.core.handlers.wsgi import LimitedStream, WSGIRequest
 from django.http import HttpRequest, RawPostDataException, UnreadablePostError
-from django.http.request import split_domain_port
+from django.http.request import HttpHeaders, split_domain_port
 from django.test import RequestFactory, SimpleTestCase, override_settings
 from django.test.client import FakePayload
 
@@ -830,3 +830,85 @@ class BuildAbsoluteURITests(SimpleTestCase):
         for location, expected_url in tests:
             with self.subTest(location=location):
                 self.assertEqual(request.build_absolute_uri(location=location), expected_url)
+
+
+class RequestHeadersTests(SimpleTestCase):
+    ENVIRON = {
+        # Non-headers are ignored.
+        'PATH_INFO': '/somepath/',
+        'REQUEST_METHOD': 'get',
+        'wsgi.input': BytesIO(b''),
+        'SERVER_NAME': 'internal.com',
+        'SERVER_PORT': 80,
+        # These non-HTTP prefixed headers are included.
+        'CONTENT_TYPE': 'text/html',
+        'CONTENT_LENGTH': '100',
+        # All HTTP-prefixed headers are included.
+        'HTTP_ACCEPT': '*',
+        'HTTP_HOST': 'example.com',
+        'HTTP_USER_AGENT': 'python-requests/1.2.0',
+    }
+
+    def test_base_request_headers(self):
+        request = HttpRequest()
+        request.META = self.ENVIRON
+        self.assertEqual(dict(request.headers), {
+            'Content-Type': 'text/html',
+            'Content-Length': '100',
+            'Accept': '*',
+            'Host': 'example.com',
+            'User-Agent': 'python-requests/1.2.0',
+        })
+
+    def test_wsgi_request_headers(self):
+        request = WSGIRequest(self.ENVIRON)
+        self.assertEqual(dict(request.headers), {
+            'Content-Type': 'text/html',
+            'Content-Length': '100',
+            'Accept': '*',
+            'Host': 'example.com',
+            'User-Agent': 'python-requests/1.2.0',
+        })
+
+    def test_wsgi_request_headers_getitem(self):
+        request = WSGIRequest(self.ENVIRON)
+        self.assertEqual(request.headers['User-Agent'], 'python-requests/1.2.0')
+        self.assertEqual(request.headers['user-agent'], 'python-requests/1.2.0')
+        self.assertEqual(request.headers['Content-Type'], 'text/html')
+        self.assertEqual(request.headers['Content-Length'], '100')
+
+    def test_wsgi_request_headers_get(self):
+        request = WSGIRequest(self.ENVIRON)
+        self.assertEqual(request.headers.get('User-Agent'), 'python-requests/1.2.0')
+        self.assertEqual(request.headers.get('user-agent'), 'python-requests/1.2.0')
+        self.assertEqual(request.headers.get('Content-Type'), 'text/html')
+        self.assertEqual(request.headers.get('Content-Length'), '100')
+
+
+class HttpHeadersTests(SimpleTestCase):
+    def test_basic(self):
+        environ = {
+            'CONTENT_TYPE': 'text/html',
+            'CONTENT_LENGTH': '100',
+            'HTTP_HOST': 'example.com',
+        }
+        headers = HttpHeaders(environ)
+        self.assertEqual(sorted(headers), ['Content-Length', 'Content-Type', 'Host'])
+        self.assertEqual(headers, {
+            'Content-Type': 'text/html',
+            'Content-Length': '100',
+            'Host': 'example.com',
+        })
+
+    def test_parse_header_name(self):
+        tests = (
+            ('PATH_INFO', None),
+            ('HTTP_ACCEPT', 'Accept'),
+            ('HTTP_USER_AGENT', 'User-Agent'),
+            ('HTTP_X_FORWARDED_PROTO', 'X-Forwarded-Proto'),
+            ('CONTENT_TYPE', 'Content-Type'),
+            ('CONTENT_LENGTH', 'Content-Length'),
+        )
+        for header, expected in tests:
+            with self.subTest(header=header):
+                self.assertEqual(HttpHeaders.parse_header_name(header), expected)

+ 81 - 2
tests/utils_tests/test_datastructures.py

@@ -6,8 +6,8 @@ import copy
 
 from django.test import SimpleTestCase
 from django.utils.datastructures import (
-    DictWrapper, ImmutableList, MultiValueDict, MultiValueDictKeyError,
-    OrderedSet,
+    CaseInsensitiveMapping, DictWrapper, ImmutableList, MultiValueDict,
+    MultiValueDictKeyError, OrderedSet,
 )
 
 
@@ -148,3 +148,82 @@ class DictWrapperTests(SimpleTestCase):
             "Normal: %(a)s. Modified: %(xx_a)s" % d,
             'Normal: a. Modified: *a'
         )
+
+
+class CaseInsensitiveMappingTests(SimpleTestCase):
+    def setUp(self):
+        self.dict1 = CaseInsensitiveMapping({
+            'Accept': 'application/json',
+            'content-type': 'text/html',
+        })
+
+    def test_create_with_invalid_values(self):
+        msg = 'dictionary update sequence element #1 has length 4; 2 is required'
+        with self.assertRaisesMessage(ValueError, msg):
+            CaseInsensitiveMapping([('Key1', 'Val1'), 'Key2'])
+
+    def test_create_with_invalid_key(self):
+        msg = 'Element key 1 invalid, only strings are allowed'
+        with self.assertRaisesMessage(ValueError, msg):
+            CaseInsensitiveMapping([(1, '2')])
+
+    def test_list(self):
+        self.assertEqual(sorted(list(self.dict1)), sorted(['Accept', 'content-type']))
+
+    def test_dict(self):
+        self.assertEqual(dict(self.dict1), {'Accept': 'application/json', 'content-type': 'text/html'})
+
+    def test_repr(self):
+        dict1 = CaseInsensitiveMapping({'Accept': 'application/json'})
+        dict2 = CaseInsensitiveMapping({'content-type': 'text/html'})
+        self.assertEqual(repr(dict1), repr({'Accept': 'application/json'}))
+        self.assertEqual(repr(dict2), repr({'content-type': 'text/html'}))
+
+    def test_str(self):
+        dict1 = CaseInsensitiveMapping({'Accept': 'application/json'})
+        dict2 = CaseInsensitiveMapping({'content-type': 'text/html'})
+        self.assertEqual(str(dict1), str({'Accept': 'application/json'}))
+        self.assertEqual(str(dict2), str({'content-type': 'text/html'}))
+
+    def test_equal(self):
+        self.assertEqual(self.dict1, {'Accept': 'application/json', 'content-type': 'text/html'})
+        self.assertNotEqual(self.dict1, {'accept': 'application/jso', 'Content-Type': 'text/html'})
+        self.assertNotEqual(self.dict1, 'string')
+
+    def test_items(self):
+        other = {'Accept': 'application/json', 'content-type': 'text/html'}
+        self.assertEqual(sorted(self.dict1.items()), sorted(other.items()))
+
+    def test_copy(self):
+        copy = self.dict1.copy()
+        self.assertIs(copy, self.dict1)
+        self.assertEqual(copy, self.dict1)
+
+    def test_getitem(self):
+        self.assertEqual(self.dict1['Accept'], 'application/json')
+        self.assertEqual(self.dict1['accept'], 'application/json')
+        self.assertEqual(self.dict1['aCCept'], 'application/json')
+        self.assertEqual(self.dict1['content-type'], 'text/html')
+        self.assertEqual(self.dict1['Content-Type'], 'text/html')
+        self.assertEqual(self.dict1['Content-type'], 'text/html')
+
+    def test_in(self):
+        self.assertIn('Accept', self.dict1)
+        self.assertIn('accept', self.dict1)
+        self.assertIn('aCCept', self.dict1)
+        self.assertIn('content-type', self.dict1)
+        self.assertIn('Content-Type', self.dict1)
+
+    def test_del(self):
+        self.assertIn('Accept', self.dict1)
+        msg = "'CaseInsensitiveMapping' object does not support item deletion"
+        with self.assertRaisesMessage(TypeError, msg):
+            del self.dict1['Accept']
+        self.assertIn('Accept', self.dict1)
+
+    def test_set(self):
+        self.assertEqual(len(self.dict1), 2)
+        msg = "'CaseInsensitiveMapping' object does not support item assignment"
+        with self.assertRaisesMessage(TypeError, msg):
+            self.dict1['New Key'] = 1
+        self.assertEqual(len(self.dict1), 2)