|
@@ -2,6 +2,7 @@ import json
|
|
|
|
|
|
from django.conf import settings
|
|
|
from django.contrib.messages.storage.base import BaseStorage, Message
|
|
|
+from django.core import signing
|
|
|
from django.http import SimpleCookie
|
|
|
from django.utils.crypto import constant_time_compare, salted_hmac
|
|
|
from django.utils.safestring import SafeData, mark_safe
|
|
@@ -58,6 +59,10 @@ class CookieStorage(BaseStorage):
|
|
|
not_finished = '__messagesnotfinished__'
|
|
|
key_salt = 'django.contrib.messages'
|
|
|
|
|
|
+ def __init__(self, *args, **kwargs):
|
|
|
+ super().__init__(*args, **kwargs)
|
|
|
+ self.signer = signing.get_cookie_signer(salt=self.key_salt)
|
|
|
+
|
|
|
def _get(self, *args, **kwargs):
|
|
|
"""
|
|
|
Retrieve a list of messages from the messages cookie. If the
|
|
@@ -118,8 +123,9 @@ class CookieStorage(BaseStorage):
|
|
|
self._update_cookie(encoded_data, response)
|
|
|
return unstored_messages
|
|
|
|
|
|
- def _hash(self, value):
|
|
|
+ def _legacy_hash(self, value):
|
|
|
"""
|
|
|
+
|
|
|
Create an HMAC/SHA1 hash based on the value and the project setting's
|
|
|
SECRET_KEY, modified to make it unique for the present purpose.
|
|
|
"""
|
|
@@ -136,7 +142,7 @@ class CookieStorage(BaseStorage):
|
|
|
if messages or encode_empty:
|
|
|
encoder = MessageEncoder(separators=(',', ':'))
|
|
|
value = encoder.encode(messages)
|
|
|
- return '%s$%s' % (self._hash(value), value)
|
|
|
+ return self.signer.sign(value)
|
|
|
|
|
|
def _decode(self, data):
|
|
|
"""
|
|
@@ -147,17 +153,28 @@ class CookieStorage(BaseStorage):
|
|
|
"""
|
|
|
if not data:
|
|
|
return None
|
|
|
- bits = data.split('$', 1)
|
|
|
- if len(bits) == 2:
|
|
|
- hash, value = bits
|
|
|
- if constant_time_compare(hash, self._hash(value)):
|
|
|
- try:
|
|
|
-
|
|
|
-
|
|
|
- return json.loads(value, cls=MessageDecoder)
|
|
|
- except json.JSONDecodeError:
|
|
|
- pass
|
|
|
+ try:
|
|
|
+ decoded = self.signer.unsign(data)
|
|
|
+ except signing.BadSignature:
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ decoded = self._legacy_decode(data)
|
|
|
+ if decoded:
|
|
|
+ try:
|
|
|
+ return json.loads(decoded, cls=MessageDecoder)
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ pass
|
|
|
|
|
|
|
|
|
self.used = True
|
|
|
return None
|
|
|
+
|
|
|
+ def _legacy_decode(self, data):
|
|
|
+
|
|
|
+ bits = data.split('$', 1)
|
|
|
+ if len(bits) == 2:
|
|
|
+ hash_, value = bits
|
|
|
+ if constant_time_compare(hash_, self._legacy_hash(value)):
|
|
|
+ return value
|
|
|
+ return None
|