test_cookie.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. import json
  2. import random
  3. from unittest import TestCase
  4. from django.conf import settings
  5. from django.contrib.messages import constants
  6. from django.contrib.messages.storage.base import Message
  7. from django.contrib.messages.storage.cookie import (
  8. CookieStorage,
  9. MessageDecoder,
  10. MessageEncoder,
  11. bisect_keep_left,
  12. bisect_keep_right,
  13. )
  14. from django.test import SimpleTestCase, override_settings
  15. from django.utils.crypto import get_random_string
  16. from django.utils.safestring import SafeData, mark_safe
  17. from .base import BaseTests
  def set_cookie_data(storage, messages, invalid=False, encode_empty=False):
      """
      Put ``messages`` into ``storage.request.COOKIES`` as an encoded cookie.

      The value is produced by ``storage._encode`` and stored under
      ``CookieStorage.cookie_name``. Any cached ``_loaded_data`` on the
      storage is discarded so the next read decodes the cookie again.

      When ``invalid`` is true, the first character of the encoded value is
      dropped so that its hash no longer validates.
      """
      cookie_value = storage._encode(messages, encode_empty=encode_empty)
      # Chopping off the leading character corrupts the hash on purpose.
      stored_value = cookie_value[1:] if invalid else cookie_value
      storage.request.COOKIES = {CookieStorage.cookie_name: stored_value}
      # Forget any previously loaded messages so the new cookie is read.
      if hasattr(storage, "_loaded_data"):
          del storage._loaded_data
  def stored_cookie_messages_count(storage, response):
      """
      Return how many messages ``response`` carries in the storage's cookie.

      Zero is returned when the cookie is absent, when its ``max-age`` is 0
      (the cookie has been marked for deletion), or when it decodes to an
      empty value. A trailing ``CookieStorage.not_finished`` marker is not
      counted as a message.
      """
      morsel = response.cookies.get(storage.cookie_name)
      # A cookie with a max-age of 0 is being deleted, so it holds nothing.
      if not morsel or morsel["max-age"] == 0:
          return 0
      decoded = storage._decode(morsel.value)
      if not decoded:
          return 0
      # Strip the end-of-data marker before counting.
      if decoded[-1] == CookieStorage.not_finished:
          decoded.pop()
      return len(decoded)
  @override_settings(
      SESSION_COOKIE_DOMAIN=".example.com",
      SESSION_COOKIE_SECURE=True,
      SESSION_COOKIE_HTTPONLY=True,
  )
  class CookieTests(BaseTests, SimpleTestCase):
      """
      Tests for ``CookieStorage``.

      Generic storage behavior is presumably exercised by the inherited
      ``BaseTests`` mixin; the methods here cover cookie-specific concerns:
      cookie attributes, hash validation, the cookie size limit, and JSON
      encoding of ``Message`` objects.
      """

      storage_class = CookieStorage

      def stored_messages_count(self, storage, response):
          """Return the number of messages stored in ``response``'s cookie."""
          return stored_cookie_messages_count(storage, response)

      def encode_decode(self, *args, **kwargs):
          """
          Encode a single DEBUG ``Message(*args, **kwargs)`` with the storage's
          ``_encode`` and return the first message ``_decode`` gives back.
          """
          storage = self.get_storage()
          message = [Message(constants.DEBUG, *args, **kwargs)]
          encoded = storage._encode(message)
          return storage._decode(encoded)[0]

      def test_get(self):
          storage = self.storage_class(self.get_request())
          # Set initial data.
          example_messages = ["test", "me"]
          set_cookie_data(storage, example_messages)
          # The messages read back are exactly the ones stored.
          self.assertEqual(list(storage), example_messages)

      @override_settings(SESSION_COOKIE_SAMESITE="Strict")
      def test_cookie_settings(self):
          """
          CookieStorage honors SESSION_COOKIE_DOMAIN, SESSION_COOKIE_SECURE,
          SESSION_COOKIE_HTTPONLY, and SESSION_COOKIE_SAMESITE (#15618, #20972).
          """
          # Test before the messages have been consumed.
          storage = self.get_storage()
          response = self.get_response()
          storage.add(constants.INFO, "test")
          storage.update(response)
          cookie = response.cookies["messages"]
          messages = storage._decode(cookie.value)
          self.assertEqual(len(messages), 1)
          self.assertEqual(messages[0].message, "test")
          self.assertEqual(cookie["domain"], ".example.com")
          self.assertEqual(cookie["expires"], "")
          self.assertIs(cookie["secure"], True)
          self.assertIs(cookie["httponly"], True)
          self.assertEqual(cookie["samesite"], "Strict")

          # Deletion of the cookie (storing with an empty value) after the
          # messages have been consumed.
          storage = self.get_storage()
          response = self.get_response()
          storage.add(constants.INFO, "test")
          # Exhaust the iterator to simulate consumption of the messages.
          list(storage)
          storage.update(response)
          cookie = response.cookies["messages"]
          self.assertEqual(cookie.value, "")
          self.assertEqual(cookie["domain"], ".example.com")
          self.assertEqual(cookie["expires"], "Thu, 01 Jan 1970 00:00:00 GMT")
          self.assertEqual(cookie["samesite"], settings.SESSION_COOKIE_SAMESITE)

      def test_get_bad_cookie(self):
          request = self.get_request()
          storage = self.storage_class(request)
          # Set initial (invalid) data.
          example_messages = ["test", "me"]
          set_cookie_data(storage, example_messages, invalid=True)
          # A cookie whose hash does not validate yields no messages.
          self.assertEqual(list(storage), [])

      def test_max_cookie_length(self):
          """
          If the data exceeds what is allowed in a cookie, older messages are
          removed before saving (and returned by the ``update`` method).
          """
          storage = self.get_storage()
          response = self.get_response()
          # When storing as a cookie, the cookie has constant overhead of approx
          # 54 chars, and each message has a constant overhead of about 37 chars
          # and a variable overhead of zero in the best case. We aim for a message
          # size which will fit 4 messages into the cookie, but not 5.
          # See also FallbackTest.test_session_fallback
          msg_size = int((CookieStorage.max_cookie_size - 54) / 4.5 - 37)
          # Random content is used, presumably so that compression cannot
          # shrink the payload enough to skew the size estimate above.
          # NOTE(review): recent Django versions implement get_random_string()
          # with the ``secrets`` module, which ignores random.seed(); if so,
          # this seed does not make the generated content reproducible.
          random.seed(42)
          sent_messages = [get_random_string(msg_size) for _ in range(5)]
          for msg in sent_messages:
              storage.add(constants.INFO, msg)
          unstored_messages = storage.update(response)
          cookie_storing = self.stored_messages_count(storage, response)
          self.assertEqual(cookie_storing, 4)
          # Only the oldest (first added) message is left out of the cookie.
          self.assertEqual(len(unstored_messages), 1)
          self.assertEqual(unstored_messages[0].message, sent_messages[0])

      def test_message_rfc6265(self):
          """
          The encoded cookie value contains none of the characters RFC 6265
          disallows in cookie values (backslash, comma, semicolon, double
          quote), even when the messages themselves contain them.
          """
          non_compliant_chars = ["\\", ",", ";", '"']
          messages = ["\\te,st", ';m"e', "\u2019", '123"NOTRECEIVED"']
          storage = self.get_storage()
          encoded = storage._encode(messages)
          for illegal in non_compliant_chars:
              with self.subTest(illegal=illegal):
                  self.assertNotIn(illegal, encoded)

      def test_json_encoder_decoder(self):
          """
          A complex nested data structure containing Message
          instances is properly encoded/decoded by the custom JSON
          encoder/decoder classes.
          """
          messages = [
              {
                  "message": Message(constants.INFO, "Test message"),
                  # Distinct texts so the round trip also checks list order.
                  "message_list": [
                      Message(constants.INFO, "message %s" % x) for x in range(5)
                  ]
                  + [{"another-message": Message(constants.ERROR, "error")}],
              },
              # A literal "%s" must survive encoding untouched.
              Message(constants.INFO, "message %s"),
          ]
          encoder = MessageEncoder()
          value = encoder.encode(messages)
          decoded_messages = json.loads(value, cls=MessageDecoder)
          self.assertEqual(messages, decoded_messages)

      def test_safedata(self):
          """
          A message containing SafeData is keeping its safe status when
          retrieved from the message storage.
          """
          self.assertIsInstance(
              self.encode_decode(mark_safe("<b>Hello Django!</b>")).message,
              SafeData,
          )
          self.assertNotIsInstance(
              self.encode_decode("<b>Hello Django!</b>").message,
              SafeData,
          )

      def test_extra_tags(self):
          """
          A message's extra_tags attribute is correctly preserved when retrieved
          from the message storage.
          """
          for extra_tags in ["", None, "some tags"]:
              with self.subTest(extra_tags=extra_tags):
                  self.assertEqual(
                      self.encode_decode("message", extra_tags=extra_tags).extra_tags,
                      extra_tags,
                  )
  class BisectTests(TestCase):
      """Table-driven checks for ``bisect_keep_left`` and ``bisect_keep_right``."""

      def test_bisect_keep_left(self):
          # Each case: (items, target, expected); predicate is sum(arr) != target.
          cases = (
              ([1, 1, 1], 2, 2),
              ([1, 1, 1], 0, 0),
              ([], 0, 0),
          )
          for items, target, expected in cases:
              with self.subTest(items=items, target=target):
                  outcome = bisect_keep_left(
                      items, fn=lambda arr, t=target: sum(arr) != t
                  )
                  self.assertEqual(outcome, expected)

      def test_bisect_keep_right(self):
          # Each case: (items, target, expected); predicate is sum(arr) != target.
          cases = (
              ([1, 1, 1], 2, 1),
              ([1, 1, 1, 1], 2, 2),
              ([1, 1, 1, 1, 1], 1, 4),
              ([], 0, 0),
          )
          for items, target, expected in cases:
              with self.subTest(items=items, target=target):
                  outcome = bisect_keep_right(
                      items, fn=lambda arr, t=target: sum(arr) != t
                  )
                  self.assertEqual(outcome, expected)