Browse Source

Fixed #20147 -- Added HttpRequest.headers.

Santiago Basulto 6 years ago
parent
commit
4fc35a9c3e

+ 29 - 1
django/http/request.py

@@ -12,7 +12,9 @@ from django.core.exceptions import (
 )
 )
 from django.core.files import uploadhandler
 from django.core.files import uploadhandler
 from django.http.multipartparser import MultiPartParser, MultiPartParserError
 from django.http.multipartparser import MultiPartParser, MultiPartParserError
-from django.utils.datastructures import ImmutableList, MultiValueDict
+from django.utils.datastructures import (
+    CaseInsensitiveMapping, ImmutableList, MultiValueDict,
+)
 from django.utils.deprecation import RemovedInDjango30Warning
 from django.utils.deprecation import RemovedInDjango30Warning
 from django.utils.encoding import escape_uri_path, iri_to_uri
 from django.utils.encoding import escape_uri_path, iri_to_uri
 from django.utils.functional import cached_property
 from django.utils.functional import cached_property
@@ -65,6 +67,10 @@ class HttpRequest:
             return '<%s>' % self.__class__.__name__
             return '<%s>' % self.__class__.__name__
         return '<%s: %s %r>' % (self.__class__.__name__, self.method, self.get_full_path())
         return '<%s: %s %r>' % (self.__class__.__name__, self.method, self.get_full_path())
 
 
+    @cached_property
+    def headers(self):
+        return HttpHeaders(self.META)
+
     def _get_raw_host(self):
     def _get_raw_host(self):
         """
         """
         Return the HTTP host using the environment or request headers. Skip
         Return the HTTP host using the environment or request headers. Skip
@@ -359,6 +365,28 @@ class HttpRequest:
         return list(self)
         return list(self)
 
 
 
 
+class HttpHeaders(CaseInsensitiveMapping):
+    HTTP_PREFIX = 'HTTP_'
+    # PEP 333 gives two headers which aren't prepended with HTTP_.
+    UNPREFIXED_HEADERS = {'CONTENT_TYPE', 'CONTENT_LENGTH'}
+
+    def __init__(self, environ):
+        headers = {}
+        for header, value in environ.items():
+            name = self.parse_header_name(header)
+            if name:
+                headers[name] = value
+        super().__init__(headers)
+
+    @classmethod
+    def parse_header_name(cls, header):
+        if header.startswith(cls.HTTP_PREFIX):
+            header = header[len(cls.HTTP_PREFIX):]
+        elif header not in cls.UNPREFIXED_HEADERS:
+            return None
+        return header.replace('_', '-').title()
+
+
 class QueryDict(MultiValueDict):
 class QueryDict(MultiValueDict):
     """
     """
     A specialized MultiValueDict which represents a query string.
     A specialized MultiValueDict which represents a query string.

+ 59 - 0
django/utils/datastructures.py

@@ -1,5 +1,6 @@
 import copy
 import copy
 from collections import OrderedDict
 from collections import OrderedDict
+from collections.abc import Mapping
 
 
 
 
 class OrderedSet:
 class OrderedSet:
@@ -280,3 +281,61 @@ class DictWrapper(dict):
         if use_func:
         if use_func:
             return self.func(value)
             return self.func(value)
         return value
         return value
+
+
+def _destruct_iterable_mapping_values(data):
+    for i, elem in enumerate(data):
+        if len(elem) != 2:
+            raise ValueError(
+                'dictionary update sequence element #{} has '
+                'length {}; 2 is required.'.format(i, len(elem))
+            )
+        if not isinstance(elem[0], str):
+            raise ValueError('Element key %r invalid, only strings are allowed' % elem[0])
+        yield tuple(elem)
+
+
+class CaseInsensitiveMapping(Mapping):
+    """
+    Mapping allowing case-insensitive key lookups. Original case of keys is
+    preserved for iteration and string representation.
+
+    Example::
+
+        >>> ci_map = CaseInsensitiveMapping({'name': 'Jane'})
+        >>> ci_map['Name']
+        Jane
+        >>> ci_map['NAME']
+        Jane
+        >>> ci_map['name']
+        Jane
+        >>> ci_map  # original case preserved
+        {'name': 'Jane'}
+    """
+
+    def __init__(self, data):
+        if not isinstance(data, Mapping):
+            data = {k: v for k, v in _destruct_iterable_mapping_values(data)}
+        self._store = {k.lower(): (k, v) for k, v in data.items()}
+
+    def __getitem__(self, key):
+        return self._store[key.lower()][1]
+
+    def __len__(self):
+        return len(self._store)
+
+    def __eq__(self, other):
+        return isinstance(other, Mapping) and {
+            k.lower(): v for k, v in self.items()
+        } == {
+            k.lower(): v for k, v in other.items()
+        }
+
+    def __iter__(self):
+        return (original_key for original_key, value in self._store.values())
+
+    def __repr__(self):
+        return repr({key: value for key, value in self._store.values()})
+
+    def copy(self):
+        return self

+ 32 - 0
docs/ref/request-response.txt

@@ -167,6 +167,38 @@ All attributes should be considered read-only, unless stated otherwise.
     underscores in WSGI environment variables. It matches the behavior of
     underscores in WSGI environment variables. It matches the behavior of
     Web servers like Nginx and Apache 2.4+.
     Web servers like Nginx and Apache 2.4+.
 
 
+    :attr:`HttpRequest.headers` is a simpler way to access all HTTP-prefixd
+    headers, plus ``CONTENT_LENGTH`` and ``CONTENT_TYPE``.
+
+.. attribute:: HttpRequest.headers
+
+    .. versionadded:: 2.2
+
+    A case insensitive, dict-like object that provides access to all
+    HTTP-prefixed headers (plus ``Content-Length`` and ``Content-Type``) from
+    the request.
+
+    The name of each header is stylized with title-casing (e.g. ``User-Agent``)
+    when it's displayed. You can access headers case-insensitively::
+
+        >>> request.headers
+        {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6', ...}
+
+        >>> 'User-Agent' in request.headers
+        True
+        >>> 'user-agent' in request.headers
+        True
+
+        >>> request.headers['User-Agent']
+        Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)
+        >>> request.headers['user-agent']
+        Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)
+
+        >>> request.headers.get('User-Agent')
+        Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)
+        >>> request.headers.get('user-agent')
+        Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)
+
 .. attribute:: HttpRequest.resolver_match
 .. attribute:: HttpRequest.resolver_match
 
 
     An instance of :class:`~django.urls.ResolverMatch` representing the
     An instance of :class:`~django.urls.ResolverMatch` representing the

+ 2 - 1
docs/releases/2.2.txt

@@ -266,7 +266,8 @@ Models
 Requests and Responses
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~~~~~~
 
 
-* ...
+* Added :attr:`.HttpRequest.headers` to allow simple access to a request's
+  headers.
 
 
 Serialization
 Serialization
 ~~~~~~~~~~~~~
 ~~~~~~~~~~~~~

+ 83 - 1
tests/requests/tests.py

@@ -5,7 +5,7 @@ from urllib.parse import urlencode
 from django.core.exceptions import DisallowedHost
 from django.core.exceptions import DisallowedHost
 from django.core.handlers.wsgi import LimitedStream, WSGIRequest
 from django.core.handlers.wsgi import LimitedStream, WSGIRequest
 from django.http import HttpRequest, RawPostDataException, UnreadablePostError
 from django.http import HttpRequest, RawPostDataException, UnreadablePostError
-from django.http.request import split_domain_port
+from django.http.request import HttpHeaders, split_domain_port
 from django.test import RequestFactory, SimpleTestCase, override_settings
 from django.test import RequestFactory, SimpleTestCase, override_settings
 from django.test.client import FakePayload
 from django.test.client import FakePayload
 
 
@@ -830,3 +830,85 @@ class BuildAbsoluteURITests(SimpleTestCase):
         for location, expected_url in tests:
         for location, expected_url in tests:
             with self.subTest(location=location):
             with self.subTest(location=location):
                 self.assertEqual(request.build_absolute_uri(location=location), expected_url)
                 self.assertEqual(request.build_absolute_uri(location=location), expected_url)
+
+
+class RequestHeadersTests(SimpleTestCase):
+    ENVIRON = {
+        # Non-headers are ignored.
+        'PATH_INFO': '/somepath/',
+        'REQUEST_METHOD': 'get',
+        'wsgi.input': BytesIO(b''),
+        'SERVER_NAME': 'internal.com',
+        'SERVER_PORT': 80,
+        # These non-HTTP prefixed headers are included.
+        'CONTENT_TYPE': 'text/html',
+        'CONTENT_LENGTH': '100',
+        # All HTTP-prefixed headers are included.
+        'HTTP_ACCEPT': '*',
+        'HTTP_HOST': 'example.com',
+        'HTTP_USER_AGENT': 'python-requests/1.2.0',
+    }
+
+    def test_base_request_headers(self):
+        request = HttpRequest()
+        request.META = self.ENVIRON
+        self.assertEqual(dict(request.headers), {
+            'Content-Type': 'text/html',
+            'Content-Length': '100',
+            'Accept': '*',
+            'Host': 'example.com',
+            'User-Agent': 'python-requests/1.2.0',
+        })
+
+    def test_wsgi_request_headers(self):
+        request = WSGIRequest(self.ENVIRON)
+        self.assertEqual(dict(request.headers), {
+            'Content-Type': 'text/html',
+            'Content-Length': '100',
+            'Accept': '*',
+            'Host': 'example.com',
+            'User-Agent': 'python-requests/1.2.0',
+        })
+
+    def test_wsgi_request_headers_getitem(self):
+        request = WSGIRequest(self.ENVIRON)
+        self.assertEqual(request.headers['User-Agent'], 'python-requests/1.2.0')
+        self.assertEqual(request.headers['user-agent'], 'python-requests/1.2.0')
+        self.assertEqual(request.headers['Content-Type'], 'text/html')
+        self.assertEqual(request.headers['Content-Length'], '100')
+
+    def test_wsgi_request_headers_get(self):
+        request = WSGIRequest(self.ENVIRON)
+        self.assertEqual(request.headers.get('User-Agent'), 'python-requests/1.2.0')
+        self.assertEqual(request.headers.get('user-agent'), 'python-requests/1.2.0')
+        self.assertEqual(request.headers.get('Content-Type'), 'text/html')
+        self.assertEqual(request.headers.get('Content-Length'), '100')
+
+
+class HttpHeadersTests(SimpleTestCase):
+    def test_basic(self):
+        environ = {
+            'CONTENT_TYPE': 'text/html',
+            'CONTENT_LENGTH': '100',
+            'HTTP_HOST': 'example.com',
+        }
+        headers = HttpHeaders(environ)
+        self.assertEqual(sorted(headers), ['Content-Length', 'Content-Type', 'Host'])
+        self.assertEqual(headers, {
+            'Content-Type': 'text/html',
+            'Content-Length': '100',
+            'Host': 'example.com',
+        })
+
+    def test_parse_header_name(self):
+        tests = (
+            ('PATH_INFO', None),
+            ('HTTP_ACCEPT', 'Accept'),
+            ('HTTP_USER_AGENT', 'User-Agent'),
+            ('HTTP_X_FORWARDED_PROTO', 'X-Forwarded-Proto'),
+            ('CONTENT_TYPE', 'Content-Type'),
+            ('CONTENT_LENGTH', 'Content-Length'),
+        )
+        for header, expected in tests:
+            with self.subTest(header=header):
+                self.assertEqual(HttpHeaders.parse_header_name(header), expected)

+ 81 - 2
tests/utils_tests/test_datastructures.py

@@ -6,8 +6,8 @@ import copy
 
 
 from django.test import SimpleTestCase
 from django.test import SimpleTestCase
 from django.utils.datastructures import (
 from django.utils.datastructures import (
-    DictWrapper, ImmutableList, MultiValueDict, MultiValueDictKeyError,
-    OrderedSet,
+    CaseInsensitiveMapping, DictWrapper, ImmutableList, MultiValueDict,
+    MultiValueDictKeyError, OrderedSet,
 )
 )
 
 
 
 
@@ -148,3 +148,82 @@ class DictWrapperTests(SimpleTestCase):
             "Normal: %(a)s. Modified: %(xx_a)s" % d,
             "Normal: %(a)s. Modified: %(xx_a)s" % d,
             'Normal: a. Modified: *a'
             'Normal: a. Modified: *a'
         )
         )
+
+
+class CaseInsensitiveMappingTests(SimpleTestCase):
+    def setUp(self):
+        self.dict1 = CaseInsensitiveMapping({
+            'Accept': 'application/json',
+            'content-type': 'text/html',
+        })
+
+    def test_create_with_invalid_values(self):
+        msg = 'dictionary update sequence element #1 has length 4; 2 is required'
+        with self.assertRaisesMessage(ValueError, msg):
+            CaseInsensitiveMapping([('Key1', 'Val1'), 'Key2'])
+
+    def test_create_with_invalid_key(self):
+        msg = 'Element key 1 invalid, only strings are allowed'
+        with self.assertRaisesMessage(ValueError, msg):
+            CaseInsensitiveMapping([(1, '2')])
+
+    def test_list(self):
+        self.assertEqual(sorted(list(self.dict1)), sorted(['Accept', 'content-type']))
+
+    def test_dict(self):
+        self.assertEqual(dict(self.dict1), {'Accept': 'application/json', 'content-type': 'text/html'})
+
+    def test_repr(self):
+        dict1 = CaseInsensitiveMapping({'Accept': 'application/json'})
+        dict2 = CaseInsensitiveMapping({'content-type': 'text/html'})
+        self.assertEqual(repr(dict1), repr({'Accept': 'application/json'}))
+        self.assertEqual(repr(dict2), repr({'content-type': 'text/html'}))
+
+    def test_str(self):
+        dict1 = CaseInsensitiveMapping({'Accept': 'application/json'})
+        dict2 = CaseInsensitiveMapping({'content-type': 'text/html'})
+        self.assertEqual(str(dict1), str({'Accept': 'application/json'}))
+        self.assertEqual(str(dict2), str({'content-type': 'text/html'}))
+
+    def test_equal(self):
+        self.assertEqual(self.dict1, {'Accept': 'application/json', 'content-type': 'text/html'})
+        self.assertNotEqual(self.dict1, {'accept': 'application/jso', 'Content-Type': 'text/html'})
+        self.assertNotEqual(self.dict1, 'string')
+
+    def test_items(self):
+        other = {'Accept': 'application/json', 'content-type': 'text/html'}
+        self.assertEqual(sorted(self.dict1.items()), sorted(other.items()))
+
+    def test_copy(self):
+        copy = self.dict1.copy()
+        self.assertIs(copy, self.dict1)
+        self.assertEqual(copy, self.dict1)
+
+    def test_getitem(self):
+        self.assertEqual(self.dict1['Accept'], 'application/json')
+        self.assertEqual(self.dict1['accept'], 'application/json')
+        self.assertEqual(self.dict1['aCCept'], 'application/json')
+        self.assertEqual(self.dict1['content-type'], 'text/html')
+        self.assertEqual(self.dict1['Content-Type'], 'text/html')
+        self.assertEqual(self.dict1['Content-type'], 'text/html')
+
+    def test_in(self):
+        self.assertIn('Accept', self.dict1)
+        self.assertIn('accept', self.dict1)
+        self.assertIn('aCCept', self.dict1)
+        self.assertIn('content-type', self.dict1)
+        self.assertIn('Content-Type', self.dict1)
+
+    def test_del(self):
+        self.assertIn('Accept', self.dict1)
+        msg = "'CaseInsensitiveMapping' object does not support item deletion"
+        with self.assertRaisesMessage(TypeError, msg):
+            del self.dict1['Accept']
+        self.assertIn('Accept', self.dict1)
+
+    def test_set(self):
+        self.assertEqual(len(self.dict1), 2)
+        msg = "'CaseInsensitiveMapping' object does not support item assignment"
+        with self.assertRaisesMessage(TypeError, msg):
+            self.dict1['New Key'] = 1
+        self.assertEqual(len(self.dict1), 2)