Browse Source

Fixed #32076 -- Added async methods to BaseCache.

This also makes DummyCache async-compatible.
Andrew-Chen-Wang 4 years ago
parent
commit
301a85a12f
4 changed files with 312 additions and 1 deletions
  1. 87 0
      django/core/cache/backends/base.py
  2. 7 1
      docs/releases/4.0.txt
  3. 31 0
      docs/topics/cache.txt
  4. 187 0
      tests/cache/tests_async.py

+ 87 - 0
django/core/cache/backends/base.py

@@ -2,6 +2,8 @@
 import time
 import warnings
 
+from asgiref.sync import sync_to_async
+
 from django.core.exceptions import ImproperlyConfigured
 from django.utils.module_loading import import_string
 
@@ -130,6 +132,9 @@ class BaseCache:
         """
         raise NotImplementedError('subclasses of BaseCache must provide an add() method')
 
+    async def aadd(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
+        return await sync_to_async(self.add, thread_sensitive=True)(key, value, timeout, version)
+
     def get(self, key, default=None, version=None):
         """
         Fetch a given key from the cache. If the key does not exist, return
@@ -137,6 +142,9 @@ class BaseCache:
         """
         raise NotImplementedError('subclasses of BaseCache must provide a get() method')
 
+    async def aget(self, key, default=None, version=None):
+        return await sync_to_async(self.get, thread_sensitive=True)(key, default, version)
+
     def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
         """
         Set a value in the cache. If timeout is given, use that timeout for the
@@ -144,6 +152,9 @@ class BaseCache:
         """
         raise NotImplementedError('subclasses of BaseCache must provide a set() method')
 
+    async def aset(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
+        return await sync_to_async(self.set, thread_sensitive=True)(key, value, timeout, version)
+
     def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
         """
         Update the key's expiry time using timeout. Return True if successful
@@ -151,6 +162,9 @@ class BaseCache:
         """
         raise NotImplementedError('subclasses of BaseCache must provide a touch() method')
 
+    async def atouch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
+        return await sync_to_async(self.touch, thread_sensitive=True)(key, timeout, version)
+
     def delete(self, key, version=None):
         """
         Delete a key from the cache and return whether it succeeded, failing
@@ -158,6 +172,9 @@ class BaseCache:
         """
         raise NotImplementedError('subclasses of BaseCache must provide a delete() method')
 
+    async def adelete(self, key, version=None):
+        return await sync_to_async(self.delete, thread_sensitive=True)(key, version)
+
     def get_many(self, keys, version=None):
         """
         Fetch a bunch of keys from the cache. For certain backends (memcached,
@@ -173,6 +190,15 @@ class BaseCache:
                 d[k] = val
         return d
 
+    async def aget_many(self, keys, version=None):
+        """See get_many()."""
+        d = {}
+        for k in keys:
+            val = await self.aget(k, self._missing_key, version=version)
+            if val is not self._missing_key:
+                d[k] = val
+        return d
+
     def get_or_set(self, key, default, timeout=DEFAULT_TIMEOUT, version=None):
         """
         Fetch a given key from the cache. If the key does not exist,
@@ -192,12 +218,30 @@ class BaseCache:
             return self.get(key, default, version=version)
         return val
 
+    async def aget_or_set(self, key, default, timeout=DEFAULT_TIMEOUT, version=None):
+        """See get_or_set()."""
+        val = await self.aget(key, self._missing_key, version=version)
+        if val is self._missing_key:
+            if callable(default):
+                default = default()
+            await self.aadd(key, default, timeout=timeout, version=version)
+            # Fetch the value again to avoid a race condition if another caller
+            # added a value between the first aget() and the aadd() above.
+            return await self.aget(key, default, version=version)
+        return val
+
     def has_key(self, key, version=None):
         """
         Return True if the key is in the cache and has not expired.
         """
         return self.get(key, self._missing_key, version=version) is not self._missing_key
 
+    async def ahas_key(self, key, version=None):
+        return (
+            await self.aget(key, self._missing_key, version=version)
+            is not self._missing_key
+        )
+
     def incr(self, key, delta=1, version=None):
         """
         Add delta to value in the cache. If the key does not exist, raise a
@@ -210,6 +254,15 @@ class BaseCache:
         self.set(key, new_value, version=version)
         return new_value
 
+    async def aincr(self, key, delta=1, version=None):
+        """See incr()."""
+        value = await self.aget(key, self._missing_key, version=version)
+        if value is self._missing_key:
+            raise ValueError("Key '%s' not found" % key)
+        new_value = value + delta
+        await self.aset(key, new_value, version=version)
+        return new_value
+
     def decr(self, key, delta=1, version=None):
         """
         Subtract delta from value in the cache. If the key does not exist, raise
@@ -217,6 +270,9 @@ class BaseCache:
         """
         return self.incr(key, -delta, version=version)
 
+    async def adecr(self, key, delta=1, version=None):
+        return await self.aincr(key, -delta, version=version)
+
     def __contains__(self, key):
         """
         Return True if the key is in the cache and has not expired.
@@ -242,6 +298,11 @@ class BaseCache:
             self.set(key, value, timeout=timeout, version=version)
         return []
 
+    async def aset_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
+        for key, value in data.items():
+            await self.aset(key, value, timeout=timeout, version=version)
+        return []
+
     def delete_many(self, keys, version=None):
         """
         Delete a bunch of values in the cache at once. For certain backends
@@ -251,10 +312,17 @@ class BaseCache:
         for key in keys:
             self.delete(key, version=version)
 
+    async def adelete_many(self, keys, version=None):
+        for key in keys:
+            await self.adelete(key, version=version)
+
     def clear(self):
         """Remove *all* values from the cache at once."""
         raise NotImplementedError('subclasses of BaseCache must provide a clear() method')
 
+    async def aclear(self):
+        return await sync_to_async(self.clear, thread_sensitive=True)()
+
     def incr_version(self, key, delta=1, version=None):
         """
         Add delta to the cache version for the supplied key. Return the new
@@ -271,6 +339,19 @@ class BaseCache:
         self.delete(key, version=version)
         return version + delta
 
+    async def aincr_version(self, key, delta=1, version=None):
+        """See incr_version()."""
+        if version is None:
+            version = self.version
+
+        value = await self.aget(key, self._missing_key, version=version)
+        if value is self._missing_key:
+            raise ValueError("Key '%s' not found" % key)
+
+        await self.aset(key, value, version=version + delta)
+        await self.adelete(key, version=version)
+        return version + delta
+
     def decr_version(self, key, delta=1, version=None):
         """
         Subtract delta from the cache version for the supplied key. Return the
@@ -278,10 +359,16 @@ class BaseCache:
         """
         return self.incr_version(key, -delta, version)
 
+    async def adecr_version(self, key, delta=1, version=None):
+        return await self.aincr_version(key, -delta, version)
+
     def close(self, **kwargs):
         """Close the cache connection"""
         pass
 
+    async def aclose(self, **kwargs):
+        pass
+
 
 def memcache_key_warnings(key):
     if len(key) > MEMCACHE_MAX_KEY_LENGTH:

+ 7 - 1
docs/releases/4.0.txt

@@ -187,7 +187,13 @@ Minor features
 Cache
 ~~~~~
 
-* ...
+* The new async API for ``django.core.cache.backends.base.BaseCache`` begins
+  the process of making cache backends async-compatible. The new async methods
+  all have ``a`` prefixed names, e.g. ``aadd()``, ``aget()``, ``aset()``,
+  ``aget_or_set()``, or ``adelete_many()``.
+
+  Going forward, the ``a`` prefix will be used for async variants of methods
+  generally.
 
 CSRF
 ~~~~

+ 31 - 0
docs/topics/cache.txt

@@ -808,6 +808,8 @@ Accessing the cache
 
     This object is equivalent to ``caches['default']``.
 
+.. _cache-basic-interface:
+
 Basic usage
 -----------
 
@@ -997,6 +999,16 @@ the cache backend.
 
     For caches that don't implement ``close`` methods it is a no-op.
 
+.. note::
+
+    The async variants of base methods are prefixed with ``a``, e.g.
+    ``cache.aadd()`` or ``cache.adelete_many()``. See `Asynchronous support`_
+    for more details.
+
+.. versionchanged:: 4.0
+
+    The async variants of methods were added to the ``BaseCache``.
+
 .. _cache_key_prefixing:
 
 Cache key prefixing
@@ -1123,6 +1135,25 @@ instance, to do this for the ``locmem`` backend, put this code in a module::
 ...and use the dotted Python path to this class in the
 :setting:`BACKEND <CACHES-BACKEND>` portion of your :setting:`CACHES` setting.
 
+.. _asynchronous_support:
+
+Asynchronous support
+====================
+
+.. versionadded:: 4.0
+
+Django has developing support for asynchronous cache backends, but does not
+yet support asynchronous caching. It will be coming in a future release.
+
+``django.core.cache.backends.base.BaseCache`` has async variants of :ref:`all
+base methods <cache-basic-interface>`. By convention, the asynchronous versions
+of all methods are prefixed with ``a``. By default, the arguments for both
+variants are the same::
+
+    >>> await cache.aset('num', 1)
+    >>> await cache.ahas_key('num')
+    True
+
 .. _downstream-caches:
 
 Downstream caches

+ 187 - 0
tests/cache/tests_async.py

@@ -0,0 +1,187 @@
+import asyncio
+
+from django.core.cache import CacheKeyWarning, cache
+from django.test import SimpleTestCase, override_settings
+
+from .tests import KEY_ERRORS_WITH_MEMCACHED_MSG
+
+
+@override_settings(CACHES={
+    'default': {
+        'BACKEND': 'django.core.cache.backends.dummy.DummyCache',
+    }
+})
+class AsyncDummyCacheTests(SimpleTestCase):
+    async def test_simple(self):
+        """Dummy cache backend ignores cache set calls."""
+        await cache.aset('key', 'value')
+        self.assertIsNone(await cache.aget('key'))
+
+    async def test_aadd(self):
+        """Add doesn't do anything in dummy cache backend."""
+        self.assertIs(await cache.aadd('key', 'value'), True)
+        self.assertIs(await cache.aadd('key', 'new_value'), True)
+        self.assertIsNone(await cache.aget('key'))
+
+    async def test_non_existent(self):
+        """Nonexistent keys aren't found in the dummy cache backend."""
+        self.assertIsNone(await cache.aget('does_not_exist'))
+        self.assertEqual(await cache.aget('does_not_exist', 'default'), 'default')
+
+    async def test_aget_many(self):
+        """aget_many() returns nothing for the dummy cache backend."""
+        await cache.aset_many({'a': 'a', 'b': 'b', 'c': 'c', 'd': 'd'})
+        self.assertEqual(await cache.aget_many(['a', 'c', 'd']), {})
+        self.assertEqual(await cache.aget_many(['a', 'b', 'e']), {})
+
+    async def test_aget_many_invalid_key(self):
+        msg = KEY_ERRORS_WITH_MEMCACHED_MSG % ':1:key with spaces'
+        with self.assertWarnsMessage(CacheKeyWarning, msg):
+            await cache.aget_many(['key with spaces'])
+
+    async def test_adelete(self):
+        """
+        Cache deletion is transparently ignored on the dummy cache backend.
+        """
+        await cache.aset_many({'key1': 'spam', 'key2': 'eggs'})
+        self.assertIsNone(await cache.aget('key1'))
+        self.assertIs(await cache.adelete('key1'), False)
+        self.assertIsNone(await cache.aget('key1'))
+        self.assertIsNone(await cache.aget('key2'))
+
+    async def test_ahas_key(self):
+        """ahas_key() doesn't ever return True for the dummy cache backend."""
+        await cache.aset('hello1', 'goodbye1')
+        self.assertIs(await cache.ahas_key('hello1'), False)
+        self.assertIs(await cache.ahas_key('goodbye1'), False)
+
+    async def test_aincr(self):
+        """Dummy cache values can't be incremented."""
+        await cache.aset('answer', 42)
+        with self.assertRaises(ValueError):
+            await cache.aincr('answer')
+        with self.assertRaises(ValueError):
+            await cache.aincr('does_not_exist')
+        with self.assertRaises(ValueError):
+            await cache.aincr('does_not_exist', -1)
+
+    async def test_adecr(self):
+        """Dummy cache values can't be decremented."""
+        await cache.aset('answer', 42)
+        with self.assertRaises(ValueError):
+            await cache.adecr('answer')
+        with self.assertRaises(ValueError):
+            await cache.adecr('does_not_exist')
+        with self.assertRaises(ValueError):
+            await cache.adecr('does_not_exist', -1)
+
+    async def test_atouch(self):
+        self.assertIs(await cache.atouch('key'), False)
+
+    async def test_data_types(self):
+        """All data types are ignored equally by the dummy cache."""
+        def f():
+            return 42
+
+        class C:
+            def m(n):
+                return 24
+
+        data = {
+            'string': 'this is a string',
+            'int': 42,
+            'list': [1, 2, 3, 4],
+            'tuple': (1, 2, 3, 4),
+            'dict': {'A': 1, 'B': 2},
+            'function': f,
+            'class': C,
+        }
+        await cache.aset('data', data)
+        self.assertIsNone(await cache.aget('data'))
+
+    async def test_expiration(self):
+        """Expiration has no effect on the dummy cache."""
+        await cache.aset('expire1', 'very quickly', 1)
+        await cache.aset('expire2', 'very quickly', 1)
+        await cache.aset('expire3', 'very quickly', 1)
+
+        await asyncio.sleep(2)
+        self.assertIsNone(await cache.aget('expire1'))
+
+        self.assertIs(await cache.aadd('expire2', 'new_value'), True)
+        self.assertIsNone(await cache.aget('expire2'))
+        self.assertIs(await cache.ahas_key('expire3'), False)
+
+    async def test_unicode(self):
+        """Unicode values are ignored by the dummy cache."""
+        tests = {
+            'ascii': 'ascii_value',
+            'unicode_ascii': 'Iñtërnâtiônàlizætiøn1',
+            'Iñtërnâtiônàlizætiøn': 'Iñtërnâtiônàlizætiøn2',
+            'ascii2': {'x': 1},
+        }
+        for key, value in tests.items():
+            with self.subTest(key=key):
+                await cache.aset(key, value)
+                self.assertIsNone(await cache.aget(key))
+
+    async def test_aset_many(self):
+        """aset_many() does nothing for the dummy cache backend."""
+        self.assertEqual(await cache.aset_many({'a': 1, 'b': 2}), [])
+        self.assertEqual(
+            await cache.aset_many({'a': 1, 'b': 2}, timeout=2, version='1'),
+            [],
+        )
+
+    async def test_aset_many_invalid_key(self):
+        msg = KEY_ERRORS_WITH_MEMCACHED_MSG % ':1:key with spaces'
+        with self.assertWarnsMessage(CacheKeyWarning, msg):
+            await cache.aset_many({'key with spaces': 'foo'})
+
+    async def test_adelete_many(self):
+        """adelete_many() does nothing for the dummy cache backend."""
+        await cache.adelete_many(['a', 'b'])
+
+    async def test_adelete_many_invalid_key(self):
+        msg = KEY_ERRORS_WITH_MEMCACHED_MSG % ':1:key with spaces'
+        with self.assertWarnsMessage(CacheKeyWarning, msg):
+            await cache.adelete_many({'key with spaces': 'foo'})
+
+    async def test_aclear(self):
+        """aclear() does nothing for the dummy cache backend."""
+        await cache.aclear()
+
+    async def test_aclose(self):
+        """aclose() does nothing for the dummy cache backend."""
+        await cache.aclose()
+
+    async def test_aincr_version(self):
+        """Dummy cache versions can't be incremented."""
+        await cache.aset('answer', 42)
+        with self.assertRaises(ValueError):
+            await cache.aincr_version('answer')
+        with self.assertRaises(ValueError):
+            await cache.aincr_version('answer', version=2)
+        with self.assertRaises(ValueError):
+            await cache.aincr_version('does_not_exist')
+
+    async def test_adecr_version(self):
+        """Dummy cache versions can't be decremented."""
+        await cache.aset('answer', 42)
+        with self.assertRaises(ValueError):
+            await cache.adecr_version('answer')
+        with self.assertRaises(ValueError):
+            await cache.adecr_version('answer', version=2)
+        with self.assertRaises(ValueError):
+            await cache.adecr_version('does_not_exist')
+
+    async def test_aget_or_set(self):
+        self.assertEqual(await cache.aget_or_set('key', 'default'), 'default')
+        self.assertIsNone(await cache.aget_or_set('key', None))
+
+    async def test_aget_or_set_callable(self):
+        def my_callable():
+            return 'default'
+
+        self.assertEqual(await cache.aget_or_set('key', my_callable), 'default')
+        self.assertEqual(await cache.aget_or_set('key', my_callable()), 'default')