test_cookie.py 8.5 KB


  1. import json
  2. import random
  3. from unittest import TestCase
  4. from django.conf import settings
  5. from django.contrib.messages import Message, constants
  6. from django.contrib.messages.storage.cookie import (
  7. CookieStorage,
  8. MessageDecoder,
  9. MessageEncoder,
  10. bisect_keep_left,
  11. bisect_keep_right,
  12. )
  13. from django.test import SimpleTestCase, override_settings
  14. from django.utils.crypto import get_random_string
  15. from django.utils.safestring import SafeData, mark_safe
  16. from .base import BaseTests
  def set_cookie_data(storage, messages, invalid=False, encode_empty=False):
      """
      Encode ``messages`` with the storage backend, install the result as the
      sole entry of ``request.COOKIES``, and drop the backend's cached
      ``_loaded_data`` if it has one.

      When ``invalid`` is true, the leading character of the encoded payload
      is removed before it is stored.
      """
      payload = storage._encode(messages, encode_empty=encode_empty)
      # Dropping the first character makes the hash invalid.
      cookie_value = payload[1:] if invalid else payload
      storage.request.COOKIES = {CookieStorage.cookie_name: cookie_value}
      if hasattr(storage, "_loaded_data"):
          del storage._loaded_data
  def stored_cookie_messages_count(storage, response):
      """
      Return how many messages are held in the storage's cookie on
      ``response``.

      A missing cookie, a cookie whose max-age is 0, or a cookie that decodes
      to nothing counts as zero. A trailing ``CookieStorage.not_finished``
      marker is not counted as a message.
      """
      morsel = response.cookies.get(storage.cookie_name)
      # A max-age of 0 means the cookie has been marked for deletion.
      if morsel and morsel["max-age"] != 0:
          decoded = storage._decode(morsel.value)
          if decoded:
              if decoded[-1] == CookieStorage.not_finished:
                  decoded.pop()
              return len(decoded)
      return 0
  @override_settings(
      SESSION_COOKIE_DOMAIN=".example.com",
      SESSION_COOKIE_SECURE=True,
      SESSION_COOKIE_HTTPONLY=True,
  )
  class CookieTests(BaseTests, SimpleTestCase):
      """Tests for the cookie-based message storage backend."""

      storage_class = CookieStorage

      def stored_messages_count(self, storage, response):
          """Delegate to the module-level cookie message counter."""
          return stored_cookie_messages_count(storage, response)

      def encode_decode(self, *args, **kwargs):
          """
          Build one DEBUG-level message from the given arguments, pass it
          through the storage's ``_encode``/``_decode`` pair, and return the
          decoded message.
          """
          storage = self.get_storage()
          original = [Message(constants.DEBUG, *args, **kwargs)]
          round_tripped = storage._decode(storage._encode(original))
          return round_tripped[0]

      def test_get(self):
          storage = self.storage_class(self.get_request())
          expected = ["test", "me"]
          # Seed the request cookie with the initial data.
          set_cookie_data(storage, expected)
          # Iterating the storage yields exactly what was stored.
          self.assertEqual(list(storage), expected)

      @override_settings(SESSION_COOKIE_SAMESITE="Strict")
      def test_cookie_settings(self):
          """
          The cookie written by CookieStorage carries the domain, secure, and
          httponly values from the SESSION_COOKIE_* settings (#15618, #20972),
          as well as the configured samesite value.
          """
          # Phase 1: messages stored but not yet consumed.
          storage = self.get_storage()
          response = self.get_response()
          storage.add(constants.INFO, "test")
          storage.update(response)
          cookie = response.cookies["messages"]
          decoded = storage._decode(cookie.value)
          self.assertEqual(len(decoded), 1)
          self.assertEqual(decoded[0].message, "test")
          self.assertEqual(cookie["domain"], ".example.com")
          self.assertEqual(cookie["expires"], "")
          self.assertIs(cookie["secure"], True)
          self.assertIs(cookie["httponly"], True)
          self.assertEqual(cookie["samesite"], "Strict")

          # Phase 2: once the messages have been consumed, the cookie is
          # deleted by storing it with an empty value.
          storage = self.get_storage()
          response = self.get_response()
          storage.add(constants.INFO, "test")
          for _ in storage:
              pass  # Iterating the storage simulates consuming the messages.
          storage.update(response)
          cookie = response.cookies["messages"]
          self.assertEqual(cookie.value, "")
          self.assertEqual(cookie["domain"], ".example.com")
          self.assertEqual(cookie["expires"], "Thu, 01 Jan 1970 00:00:00 GMT")
          self.assertEqual(cookie["samesite"], settings.SESSION_COOKIE_SAMESITE)

      def test_get_bad_cookie(self):
          request = self.get_request()
          storage = self.storage_class(request)
          # Seed the request cookie with deliberately invalid data.
          set_cookie_data(storage, ["test", "me"], invalid=True)
          # The invalid cookie yields no messages.
          self.assertEqual(list(storage), [])

      def test_max_cookie_length(self):
          """
          When the data is too large for a cookie, the oldest messages are
          dropped before saving and handed back by ``update``.
          """
          storage = self.get_storage()
          response = self.get_response()
          # A stored cookie has a constant overhead of approx 54 chars, and
          # each message has a constant overhead of about 37 chars plus a
          # variable overhead of zero in the best case. The message size is
          # chosen so that 4 messages fit into the cookie, but not 5.
          # See also FallbackTest.test_session_fallback
          msg_size = int((CookieStorage.max_cookie_size - 54) / 4.5 - 37)
          # Generate the same (tested) content every time that does not get
          # run through zlib compression.
          random.seed(42)
          bodies = []
          for _ in range(5):
              body = get_random_string(msg_size)
              storage.add(constants.INFO, body)
              bodies.append(body)
          first_msg = bodies[0]
          unstored_messages = storage.update(response)

          cookie_storing = self.stored_messages_count(storage, response)
          self.assertEqual(cookie_storing, 4)

          self.assertEqual(len(unstored_messages), 1)
          self.assertEqual(unstored_messages[0].message, first_msg)

      def test_message_rfc6265(self):
          non_compliant_chars = ["\\", ",", ";", '"']
          messages = ["\\te,st", ';m"e', "\u2019", '123"NOTRECEIVED"']
          encoded = self.get_storage()._encode(messages)
          # None of the non-compliant characters may appear in the encoded
          # value.
          for char in non_compliant_chars:
              self.assertEqual(encoded.find(char), -1)

      def test_json_encoder_decoder(self):
          """
          The custom JSON encoder/decoder classes round-trip a complex nested
          data structure containing Message instances.
          """
          repeated = [Message(constants.INFO, "message %s") for x in range(5)]
          nested = {"another-message": Message(constants.ERROR, "error")}
          messages = [
              {
                  "message": Message(constants.INFO, "Test message"),
                  "message_list": repeated + [nested],
              },
              Message(constants.INFO, "message %s"),
          ]
          value = MessageEncoder().encode(messages)
          decoded_messages = json.loads(value, cls=MessageDecoder)
          self.assertEqual(messages, decoded_messages)

      def test_safedata(self):
          """
          A message built from SafeData is still SafeData after a trip through
          the message storage; a plain string message is not.
          """
          safe_result = self.encode_decode(mark_safe("<b>Hello Django!</b>"))
          self.assertIsInstance(safe_result.message, SafeData)
          plain_result = self.encode_decode("<b>Hello Django!</b>")
          self.assertNotIsInstance(plain_result.message, SafeData)

      def test_extra_tags(self):
          """
          The extra_tags attribute of a message survives a trip through the
          message storage unchanged.
          """
          for extra_tags in ["", None, "some tags"]:
              with self.subTest(extra_tags=extra_tags):
                  decoded = self.encode_decode("message", extra_tags=extra_tags)
                  self.assertEqual(decoded.extra_tags, extra_tags)
  class BisectTests(TestCase):
      """Tests for the ``bisect_keep_left`` and ``bisect_keep_right`` helpers."""

      @staticmethod
      def _sum_is_not(total):
          """Return a predicate that is true when ``sum(arr)`` differs from ``total``."""
          return lambda arr: sum(arr) != total

      def test_bisect_keep_left(self):
          # Each case is (input list, total the predicate compares against,
          # expected result).
          cases = [
              ([1, 1, 1], 2, 2),
              ([1, 1, 1], 0, 0),
              ([], 0, 0),
          ]
          for items, total, expected in cases:
              self.assertEqual(
                  bisect_keep_left(items, fn=self._sum_is_not(total)), expected
              )

      def test_bisect_keep_right(self):
          # Each case is (input list, total the predicate compares against,
          # expected result).
          cases = [
              ([1, 1, 1], 2, 1),
              ([1, 1, 1, 1], 2, 2),
              ([1, 1, 1, 1, 1], 1, 4),
              ([], 0, 0),
          ]
          for items, total, expected in cases:
              self.assertEqual(
                  bisect_keep_right(items, fn=self._sum_is_not(total)), expected
              )