123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
- from django.conf import settings
- from django.contrib.messages import constants
- from django.contrib.messages.storage.base import BaseStorage, Message
- from django.http import SimpleCookie
- from django.utils import simplejson as json
- from django.utils.crypto import salted_hmac, constant_time_compare
- class MessageEncoder(json.JSONEncoder):
- """
- Compactly serializes instances of the ``Message`` class as JSON.
- """
- message_key = '__json_message'
- def default(self, obj):
- if isinstance(obj, Message):
- message = [self.message_key, obj.level, obj.message]
- if obj.extra_tags:
- message.append(obj.extra_tags)
- return message
- return super(MessageEncoder, self).default(obj)
- class MessageDecoder(json.JSONDecoder):
- """
- Decodes JSON that includes serialized ``Message`` instances.
- """
- def process_messages(self, obj):
- if isinstance(obj, list) and obj:
- if obj[0] == MessageEncoder.message_key:
- return Message(*obj[1:])
- return [self.process_messages(item) for item in obj]
- if isinstance(obj, dict):
- return dict([(key, self.process_messages(value))
- for key, value in obj.iteritems()])
- return obj
- def decode(self, s, **kwargs):
- decoded = super(MessageDecoder, self).decode(s, **kwargs)
- return self.process_messages(decoded)
- class CookieStorage(BaseStorage):
- """
- Stores messages in a cookie.
- """
- cookie_name = 'messages'
- # We should be able to store 4K in a cookie, but Internet Explorer
- # imposes 4K as the *total* limit for a domain. To allow other
- # cookies, we go for 3/4 of 4K.
- max_cookie_size = 3072
- not_finished = '__messagesnotfinished__'
- def _get(self, *args, **kwargs):
- """
- Retrieves a list of messages from the messages cookie. If the
- not_finished sentinel value is found at the end of the message list,
- remove it and return a result indicating that not all messages were
- retrieved by this storage.
- """
- data = self.request.COOKIES.get(self.cookie_name)
- messages = self._decode(data)
- all_retrieved = not (messages and messages[-1] == self.not_finished)
- if messages and not all_retrieved:
- # remove the sentinel value
- messages.pop()
- return messages, all_retrieved
- def _update_cookie(self, encoded_data, response):
- """
- Either sets the cookie with the encoded data if there is any data to
- store, or deletes the cookie.
- """
- if encoded_data:
- response.set_cookie(self.cookie_name, encoded_data,
- domain=settings.SESSION_COOKIE_DOMAIN)
- else:
- response.delete_cookie(self.cookie_name)
- def _store(self, messages, response, remove_oldest=True, *args, **kwargs):
- """
- Stores the messages to a cookie, returning a list of any messages which
- could not be stored.
- If the encoded data is larger than ``max_cookie_size``, removes
- messages until the data fits (these are the messages which are
- returned), and add the not_finished sentinel value to indicate as much.
- """
- unstored_messages = []
- encoded_data = self._encode(messages)
- if self.max_cookie_size:
- # data is going to be stored eventually by SimpleCookie, which
- # adds it's own overhead, which we must account for.
- cookie = SimpleCookie() # create outside the loop
- def stored_length(val):
- return len(cookie.value_encode(val)[1])
- while encoded_data and stored_length(encoded_data) > self.max_cookie_size:
- if remove_oldest:
- unstored_messages.append(messages.pop(0))
- else:
- unstored_messages.insert(0, messages.pop())
- encoded_data = self._encode(messages + [self.not_finished],
- encode_empty=unstored_messages)
- self._update_cookie(encoded_data, response)
- return unstored_messages
- def _hash(self, value):
- """
- Creates an HMAC/SHA1 hash based on the value and the project setting's
- SECRET_KEY, modified to make it unique for the present purpose.
- """
- key_salt = 'django.contrib.messages'
- return salted_hmac(key_salt, value).hexdigest()
- def _encode(self, messages, encode_empty=False):
- """
- Returns an encoded version of the messages list which can be stored as
- plain text.
- Since the data will be retrieved from the client-side, the encoded data
- also contains a hash to ensure that the data was not tampered with.
- """
- if messages or encode_empty:
- encoder = MessageEncoder(separators=(',', ':'))
- value = encoder.encode(messages)
- return '%s$%s' % (self._hash(value), value)
- def _decode(self, data):
- """
- Safely decodes a encoded text stream back into a list of messages.
- If the encoded text stream contained an invalid hash or was in an
- invalid format, ``None`` is returned.
- """
- if not data:
- return None
- bits = data.split('$', 1)
- if len(bits) == 2:
- hash, value = bits
- if constant_time_compare(hash, self._hash(value)):
- try:
- # If we get here (and the JSON decode works), everything is
- # good. In any other case, drop back and return None.
- return json.loads(value, cls=MessageDecoder)
- except ValueError:
- pass
- # Mark the data as used (so it gets removed) since something was wrong
- # with the data.
- self.used = True
- return None
|