Browse Source

Refs #28948 -- Removed superfluous messages from cookie through bisect.

David Wobrock 2 years ago
parent
commit
21757bbdcd
2 changed files with 76 additions and 11 deletions
  1. 56 11
      django/contrib/messages/storage/cookie.py
  2. 20 0
      tests/messages_tests/test_cookie.py

+ 56 - 11
django/contrib/messages/storage/cookie.py

@@ -144,20 +144,31 @@ class CookieStorage(BaseStorage):
             # adds its own overhead, which we must account for.
             cookie = SimpleCookie()  # create outside the loop
 
-            def stored_length(val):
-                return len(cookie.value_encode(val)[1])
+            def is_too_large_for_cookie(data):
+                return data and len(cookie.value_encode(data)[1]) > self.max_cookie_size
 
-            while encoded_data and stored_length(encoded_data) > self.max_cookie_size:
+            def compute_msg(some_serialized_msg):
+                return self._encode_parts(
+                    some_serialized_msg + [self.not_finished_json],
+                    encode_empty=True,
+                )
+
+            if is_too_large_for_cookie(encoded_data):
                 if remove_oldest:
-                    unstored_messages.append(messages.pop(0))
-                    serialized_messages.pop(0)
+                    idx = bisect_keep_right(
+                        serialized_messages,
+                        fn=lambda m: is_too_large_for_cookie(compute_msg(m)),
+                    )
+                    unstored_messages = messages[:idx]
+                    encoded_data = compute_msg(serialized_messages[idx:])
                 else:
-                    unstored_messages.insert(0, messages.pop())
-                    serialized_messages.pop()
-                encoded_data = self._encode_parts(
-                    serialized_messages + [self.not_finished_json],
-                    encode_empty=bool(unstored_messages),
-                )
+                    idx = bisect_keep_left(
+                        serialized_messages,
+                        fn=lambda m: is_too_large_for_cookie(compute_msg(m)),
+                    )
+                    unstored_messages = messages[idx:]
+                    encoded_data = compute_msg(serialized_messages[:idx])
+
         self._update_cookie(encoded_data, response)
         return unstored_messages
 
@@ -201,3 +212,37 @@ class CookieStorage(BaseStorage):
         # with the data.
         self.used = True
         return None
+
+
+def bisect_keep_left(a, fn):
+    """
+    Find the index of the first element from the start of the array that
+    verifies the given condition.
+    The function is applied from the start of the array to the pivot.
+    """
+    lo = 0
+    hi = len(a)
+    while lo < hi:
+        mid = (lo + hi) // 2
+        if fn(a[: mid + 1]):
+            hi = mid
+        else:
+            lo = mid + 1
+    return lo
+
+
+def bisect_keep_right(a, fn):
+    """
+    Find the index of the first element from the end of the array that verifies
+    the given condition.
+    The function is applied from the pivot to the end of array.
+    """
+    lo = 0
+    hi = len(a)
+    while lo < hi:
+        mid = (lo + hi) // 2
+        if fn(a[mid:]):
+            lo = mid + 1
+        else:
+            hi = mid
+    return lo

+ 20 - 0
tests/messages_tests/test_cookie.py

@@ -1,5 +1,6 @@
 import json
 import random
+from unittest import TestCase
 
 from django.conf import settings
 from django.contrib.messages import constants
@@ -8,6 +9,8 @@ from django.contrib.messages.storage.cookie import (
     CookieStorage,
     MessageDecoder,
     MessageEncoder,
+    bisect_keep_left,
+    bisect_keep_right,
 )
 from django.test import SimpleTestCase, override_settings
 from django.utils.crypto import get_random_string
@@ -204,3 +207,20 @@ class CookieTests(BaseTests, SimpleTestCase):
                     self.encode_decode("message", extra_tags=extra_tags).extra_tags,
                     extra_tags,
                 )
+
+
+class BisectTests(TestCase):
+    def test_bisect_keep_left(self):
+        self.assertEqual(bisect_keep_left([1, 1, 1], fn=lambda arr: sum(arr) != 2), 2)
+        self.assertEqual(bisect_keep_left([1, 1, 1], fn=lambda arr: sum(arr) != 0), 0)
+        self.assertEqual(bisect_keep_left([], fn=lambda arr: sum(arr) != 0), 0)
+
+    def test_bisect_keep_right(self):
+        self.assertEqual(bisect_keep_right([1, 1, 1], fn=lambda arr: sum(arr) != 2), 1)
+        self.assertEqual(
+            bisect_keep_right([1, 1, 1, 1], fn=lambda arr: sum(arr) != 2), 2
+        )
+        self.assertEqual(
+            bisect_keep_right([1, 1, 1, 1, 1], fn=lambda arr: sum(arr) != 1), 4
+        )
+        self.assertEqual(bisect_keep_right([], fn=lambda arr: sum(arr) != 0), 0)