Browse Source

Fixed #33060 -- Added BaseCache.make_and_validate_key() hook.

This helper function reduces the amount of duplicated code and makes it
easier to ensure that we always validate the keys.
Nick Pope 3 years ago
parent
commit
42dfa97e19

+ 15 - 9
django/core/cache/backends/base.py

@@ -105,6 +105,21 @@ class BaseCache:
 
         return self.key_func(key, self.key_prefix, version)
 
+    def validate_key(self, key):
+        """
+        Warn about keys that would not be portable to the memcached
+        backend. This encourages (but does not force) writing backend-portable
+        cache code.
+        """
+        for warning in memcache_key_warnings(key):
+            warnings.warn(warning, CacheKeyWarning)
+
+    def make_and_validate_key(self, key, version=None):
+        """Helper to make and validate keys."""
+        key = self.make_key(key, version=version)
+        self.validate_key(key)
+        return key
+
     def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
         """
         Set a value in the cache if the key does not already exist. If
@@ -240,15 +255,6 @@ class BaseCache:
         """Remove *all* values from the cache at once."""
         raise NotImplementedError('subclasses of BaseCache must provide a clear() method')
 
-    def validate_key(self, key):
-        """
-        Warn about keys that would not be portable to the memcached
-        backend. This encourages (but does not force) writing backend-portable
-        cache code.
-        """
-        for warning in memcache_key_warnings(key):
-            warnings.warn(warning, CacheKeyWarning)
-
     def incr_version(self, key, delta=1, version=None):
         """
         Add delta to the cache version for the supplied key. Return the new

+ 8 - 21
django/core/cache/backends/db.py

@@ -54,11 +54,7 @@ class DatabaseCache(BaseDatabaseCache):
         if not keys:
             return {}
 
-        key_map = {}
-        for key in keys:
-            new_key = self.make_key(key, version)
-            self.validate_key(new_key)
-            key_map[new_key] = key
+        key_map = {self.make_and_validate_key(key, version=version): key for key in keys}
 
         db = router.db_for_read(self.cache_model_class)
         connection = connections[db]
@@ -96,18 +92,15 @@ class DatabaseCache(BaseDatabaseCache):
         return result
 
     def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         self._base_set('set', key, value, timeout)
 
     def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         return self._base_set('add', key, value, timeout)
 
     def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         return self._base_set('touch', key, None, timeout)
 
     def _base_set(self, mode, key, value, timeout=DEFAULT_TIMEOUT):
@@ -197,17 +190,12 @@ class DatabaseCache(BaseDatabaseCache):
                 return True
 
     def delete(self, key, version=None):
-        key = self.make_key(key, version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         return self._base_delete_many([key])
 
     def delete_many(self, keys, version=None):
-        key_list = []
-        for key in keys:
-            key = self.make_key(key, version)
-            self.validate_key(key)
-            key_list.append(key)
-        self._base_delete_many(key_list)
+        keys = [self.make_and_validate_key(key, version=version) for key in keys]
+        self._base_delete_many(keys)
 
     def _base_delete_many(self, keys):
         if not keys:
@@ -230,8 +218,7 @@ class DatabaseCache(BaseDatabaseCache):
             return bool(cursor.rowcount)
 
     def has_key(self, key, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
 
         db = router.db_for_read(self.cache_model_class)
         connection = connections[db]

+ 6 - 12
django/core/cache/backends/dummy.py

@@ -8,32 +8,26 @@ class DummyCache(BaseCache):
         super().__init__(*args, **kwargs)
 
     def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        self.make_and_validate_key(key, version=version)
         return True
 
     def get(self, key, default=None, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        self.make_and_validate_key(key, version=version)
         return default
 
     def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        self.make_and_validate_key(key, version=version)
 
     def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        self.make_and_validate_key(key, version=version)
         return False
 
     def delete(self, key, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        self.make_and_validate_key(key, version=version)
         return False
 
     def has_key(self, key, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        self.make_and_validate_key(key, version=version)
         return False
 
     def clear(self):

+ 1 - 2
django/core/cache/backends/filebased.py

@@ -127,8 +127,7 @@ class FileBasedCache(BaseCache):
         Convert a key into a cache file path. Basically this is the
         root cache path joined with the md5sum of the key and a suffix.
         """
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         return os.path.join(self._dir, ''.join(
             [hashlib.md5(key.encode()).hexdigest(), self.cache_suffix]))
 

+ 7 - 14
django/core/cache/backends/locmem.py

@@ -23,8 +23,7 @@ class LocMemCache(BaseCache):
         self._lock = _locks.setdefault(name, Lock())
 
     def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         pickled = pickle.dumps(value, self.pickle_protocol)
         with self._lock:
             if self._has_expired(key):
@@ -33,8 +32,7 @@ class LocMemCache(BaseCache):
             return False
 
     def get(self, key, default=None, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         with self._lock:
             if self._has_expired(key):
                 self._delete(key)
@@ -51,15 +49,13 @@ class LocMemCache(BaseCache):
         self._expire_info[key] = self.get_backend_timeout(timeout)
 
     def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         pickled = pickle.dumps(value, self.pickle_protocol)
         with self._lock:
             self._set(key, pickled, timeout)
 
     def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         with self._lock:
             if self._has_expired(key):
                 return False
@@ -67,8 +63,7 @@ class LocMemCache(BaseCache):
             return True
 
     def incr(self, key, delta=1, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         with self._lock:
             if self._has_expired(key):
                 self._delete(key)
@@ -82,8 +77,7 @@ class LocMemCache(BaseCache):
         return new_value
 
     def has_key(self, key, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         with self._lock:
             if self._has_expired(key):
                 self._delete(key)
@@ -113,8 +107,7 @@ class LocMemCache(BaseCache):
         return True
 
     def delete(self, key, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         with self._lock:
             return self._delete(key)
 

+ 12 - 26
django/core/cache/backends/memcached.py

@@ -67,36 +67,29 @@ class BaseMemcachedCache(BaseCache):
         return int(timeout)
 
     def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         return self._cache.add(key, value, self.get_backend_timeout(timeout))
 
     def get(self, key, default=None, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         return self._cache.get(key, default)
 
     def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         if not self._cache.set(key, value, self.get_backend_timeout(timeout)):
             # make sure the key doesn't keep its old value in case of failure to set (memcached's 1MB limit)
             self._cache.delete(key)
 
     def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         return bool(self._cache.touch(key, self.get_backend_timeout(timeout)))
 
     def delete(self, key, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         return bool(self._cache.delete(key))
 
     def get_many(self, keys, version=None):
-        key_map = {self.make_key(key, version=version): key for key in keys}
-        for key in key_map:
-            self.validate_key(key)
+        key_map = {self.make_and_validate_key(key, version=version): key for key in keys}
         ret = self._cache.get_multi(key_map.keys())
         return {key_map[k]: v for k, v in ret.items()}
 
@@ -105,8 +98,7 @@ class BaseMemcachedCache(BaseCache):
         self._cache.disconnect_all()
 
     def incr(self, key, delta=1, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         try:
             # Memcached doesn't support negative delta.
             if delta < 0:
@@ -126,17 +118,14 @@ class BaseMemcachedCache(BaseCache):
         safe_data = {}
         original_keys = {}
         for key, value in data.items():
-            safe_key = self.make_key(key, version=version)
-            self.validate_key(safe_key)
+            safe_key = self.make_and_validate_key(key, version=version)
             safe_data[safe_key] = value
             original_keys[safe_key] = key
         failed_keys = self._cache.set_multi(safe_data, self.get_backend_timeout(timeout))
         return [original_keys[k] for k in failed_keys]
 
     def delete_many(self, keys, version=None):
-        keys = [self.make_key(key, version=version) for key in keys]
-        for key in keys:
-            self.validate_key(key)
+        keys = [self.make_and_validate_key(key, version=version) for key in keys]
         self._cache.delete_multi(keys)
 
     def clear(self):
@@ -167,8 +156,7 @@ class MemcachedCache(BaseMemcachedCache):
         self._options = {'pickleProtocol': pickle.HIGHEST_PROTOCOL, **self._options}
 
     def get(self, key, default=None, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         val = self._cache.get(key)
         # python-memcached doesn't support default values in get().
         # https://github.com/linsomniac/python-memcached/issues/159
@@ -181,8 +169,7 @@ class MemcachedCache(BaseMemcachedCache):
         # python-memcached's delete() returns True when key doesn't exist.
         # https://github.com/linsomniac/python-memcached/issues/170
         # Call _deletetouch() without the NOT_FOUND in expected results.
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         return bool(self._cache._deletetouch([b'DELETED'], 'delete', key))
 
 
@@ -200,8 +187,7 @@ class PyLibMCCache(BaseMemcachedCache):
         return output
 
     def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
-        key = self.make_key(key, version=version)
-        self.validate_key(key)
+        key = self.make_and_validate_key(key, version=version)
         if timeout == 0:
             return self._cache.delete(key)
         return self._cache.touch(key, self.get_backend_timeout(timeout))