Browse Source

Fixed #32233 -- Cleaned-up duplicate connection functionality.

Florian Apolloner 7 years ago
parent
commit
98e05ccde4

+ 9 - 56
django/core/cache/__init__.py

@@ -12,13 +12,11 @@ object.
 
 See docs/topics/cache.txt for information on the public API.
 """
-from asgiref.local import Local
-
-from django.conf import settings
 from django.core import signals
 from django.core.cache.backends.base import (
     BaseCache, CacheKeyWarning, InvalidCacheBackendError, InvalidCacheKey,
 )
+from django.utils.connection import BaseConnectionHandler, ConnectionProxy
 from django.utils.module_loading import import_string
 
 __all__ = [
@@ -29,28 +27,12 @@ __all__ = [
 DEFAULT_CACHE_ALIAS = 'default'
 
 
-class CacheHandler:
-    """
-    A Cache Handler to manage access to Cache instances.
-
-    Ensure only one instance of each alias exists per thread.
-    """
-    def __init__(self):
-        self._caches = Local()
-
-    def __getitem__(self, alias):
-        try:
-            return self._caches.caches[alias]
-        except AttributeError:
-            self._caches.caches = {}
-        except KeyError:
-            pass
+class CacheHandler(BaseConnectionHandler):
+    settings_name = 'CACHES'
+    exception_class = InvalidCacheBackendError
 
-        if alias not in settings.CACHES:
-            raise InvalidCacheBackendError(
-                "Could not find config for '%s' in settings.CACHES" % alias
-            )
-        params = settings.CACHES[alias].copy()
+    def create_connection(self, alias):
+        params = self.settings[alias].copy()
         backend = params.pop('BACKEND')
         location = params.pop('LOCATION', '')
         try:
@@ -58,42 +40,13 @@ class CacheHandler:
         except ImportError as e:
             raise InvalidCacheBackendError(
                 "Could not find backend '%s': %s" % (backend, e)
-            )
-        cache = backend_cls(location, params)
-        self._caches.caches[alias] = cache
-        return cache
-
-    def all(self):
-        return getattr(self._caches, 'caches', {}).values()
+            ) from e
+        return backend_cls(location, params)
 
 
 caches = CacheHandler()
 
-
-class DefaultCacheProxy:
-    """
-    Proxy access to the default Cache object's attributes.
-
-    This allows the legacy `cache` object to be thread-safe using the new
-    ``caches`` API.
-    """
-    def __getattr__(self, name):
-        return getattr(caches[DEFAULT_CACHE_ALIAS], name)
-
-    def __setattr__(self, name, value):
-        return setattr(caches[DEFAULT_CACHE_ALIAS], name, value)
-
-    def __delattr__(self, name):
-        return delattr(caches[DEFAULT_CACHE_ALIAS], name)
-
-    def __contains__(self, key):
-        return key in caches[DEFAULT_CACHE_ALIAS]
-
-    def __eq__(self, other):
-        return caches[DEFAULT_CACHE_ALIAS] == other
-
-
-cache = DefaultCacheProxy()
+cache = ConnectionProxy(caches, DEFAULT_CACHE_ALIAS)
 
 
 def close_caches(**kwargs):

+ 2 - 21
django/db/__init__.py

@@ -5,6 +5,7 @@ from django.db.utils import (
     InterfaceError, InternalError, NotSupportedError, OperationalError,
     ProgrammingError,
 )
+from django.utils.connection import ConnectionProxy
 
 __all__ = [
     'connection', 'connections', 'router', 'DatabaseError', 'IntegrityError',
@@ -17,28 +18,8 @@ connections = ConnectionHandler()
 
 router = ConnectionRouter()
 
-
-class DefaultConnectionProxy:
-    """
-    Proxy for accessing the default DatabaseWrapper object's attributes. If you
-    need to access the DatabaseWrapper object itself, use
-    connections[DEFAULT_DB_ALIAS] instead.
-    """
-    def __getattr__(self, item):
-        return getattr(connections[DEFAULT_DB_ALIAS], item)
-
-    def __setattr__(self, name, value):
-        return setattr(connections[DEFAULT_DB_ALIAS], name, value)
-
-    def __delattr__(self, name):
-        return delattr(connections[DEFAULT_DB_ALIAS], name)
-
-    def __eq__(self, other):
-        return connections[DEFAULT_DB_ALIAS] == other
-
-
 # For backwards compatibility. Prefer connections['default'] instead.
-connection = DefaultConnectionProxy()
+connection = ConnectionProxy(connections, DEFAULT_DB_ALIAS)
 
 
 # Register an event to reset saved queries when a Django request is started.

+ 29 - 54
django/db/utils.py

@@ -2,10 +2,11 @@ import pkgutil
 from importlib import import_module
 from pathlib import Path
 
-from asgiref.local import Local
-
 from django.conf import settings
 from django.core.exceptions import ImproperlyConfigured
+# For backwards compatibility with Django < 3.2
+from django.utils.connection import ConnectionDoesNotExist  # NOQA: F401
+from django.utils.connection import BaseConnectionHandler
 from django.utils.functional import cached_property
 from django.utils.module_loading import import_string
 
@@ -131,39 +132,30 @@ def load_backend(backend_name):
             raise
 
 
-class ConnectionDoesNotExist(Exception):
-    pass
-
-
-class ConnectionHandler:
-    def __init__(self, databases=None):
-        """
-        databases is an optional dictionary of database definitions (structured
-        like settings.DATABASES).
-        """
-        self._databases = databases
-        # Connections needs to still be an actual thread local, as it's truly
-        # thread-critical. Database backends should use @async_unsafe to protect
-        # their code from async contexts, but this will give those contexts
-        # separate connections in case it's needed as well. There's no cleanup
-        # after async contexts, though, so we don't allow that if we can help it.
-        self._connections = Local(thread_critical=True)
+class ConnectionHandler(BaseConnectionHandler):
+    settings_name = 'DATABASES'
+    # Connections needs to still be an actual thread local, as it's truly
+    # thread-critical. Database backends should use @async_unsafe to protect
+    # their code from async contexts, but this will give those contexts
+    # separate connections in case it's needed as well. There's no cleanup
+    # after async contexts, though, so we don't allow that if we can help it.
+    thread_critical = True
+
+    def configure_settings(self, databases):
+        databases = super().configure_settings(databases)
+        if databases == {}:
+            databases[DEFAULT_DB_ALIAS] = {'ENGINE': 'django.db.backends.dummy'}
+        elif DEFAULT_DB_ALIAS not in databases:
+            raise ImproperlyConfigured(
+                f"You must define a '{DEFAULT_DB_ALIAS}' database."
+            )
+        elif databases[DEFAULT_DB_ALIAS] == {}:
+            databases[DEFAULT_DB_ALIAS]['ENGINE'] = 'django.db.backends.dummy'
+        return databases
 
-    @cached_property
+    @property
     def databases(self):
-        if self._databases is None:
-            self._databases = settings.DATABASES
-        if self._databases == {}:
-            self._databases = {
-                DEFAULT_DB_ALIAS: {
-                    'ENGINE': 'django.db.backends.dummy',
-                },
-            }
-        if DEFAULT_DB_ALIAS not in self._databases:
-            raise ImproperlyConfigured("You must define a '%s' database." % DEFAULT_DB_ALIAS)
-        if self._databases[DEFAULT_DB_ALIAS] == {}:
-            self._databases[DEFAULT_DB_ALIAS]['ENGINE'] = 'django.db.backends.dummy'
-        return self._databases
+        return self.settings
 
     def ensure_defaults(self, alias):
         """
@@ -173,7 +165,7 @@ class ConnectionHandler:
         try:
             conn = self.databases[alias]
         except KeyError:
-            raise ConnectionDoesNotExist("The connection %s doesn't exist" % alias)
+            raise self.exception_class(f"The connection '{alias}' doesn't exist.")
 
         conn.setdefault('ATOMIC_REQUESTS', False)
         conn.setdefault('AUTOCOMMIT', True)
@@ -193,7 +185,7 @@ class ConnectionHandler:
         try:
             conn = self.databases[alias]
         except KeyError:
-            raise ConnectionDoesNotExist("The connection %s doesn't exist" % alias)
+            raise self.exception_class(f"The connection '{alias}' doesn't exist.")
 
         test_settings = conn.setdefault('TEST', {})
         default_test_settings = [
@@ -206,29 +198,12 @@ class ConnectionHandler:
         for key, value in default_test_settings:
             test_settings.setdefault(key, value)
 
-    def __getitem__(self, alias):
-        if hasattr(self._connections, alias):
-            return getattr(self._connections, alias)
-
+    def create_connection(self, alias):
         self.ensure_defaults(alias)
         self.prepare_test_settings(alias)
         db = self.databases[alias]
         backend = load_backend(db['ENGINE'])
-        conn = backend.DatabaseWrapper(db, alias)
-        setattr(self._connections, alias, conn)
-        return conn
-
-    def __setitem__(self, key, value):
-        setattr(self._connections, key, value)
-
-    def __delitem__(self, key):
-        delattr(self._connections, key)
-
-    def __iter__(self):
-        return iter(self.databases)
-
-    def all(self):
-        return [self[alias] for alias in self]
+        return backend.DatabaseWrapper(db, alias)
 
     def close_all(self):
         for alias in self:

+ 2 - 1
django/test/signals.py

@@ -28,7 +28,8 @@ def clear_cache_handlers(**kwargs):
     if kwargs['setting'] == 'CACHES':
         from django.core.cache import caches, close_caches
         close_caches()
-        caches._caches = Local()
+        caches._settings = caches.settings = caches.configure_settings(None)
+        caches._connections = Local()
 
 
 @receiver(setting_changed)

+ 76 - 0
django/utils/connection.py

@@ -0,0 +1,76 @@
+from asgiref.local import Local
+
+from django.conf import settings as django_settings
+from django.utils.functional import cached_property
+
+
+class ConnectionProxy:
+    """Proxy for accessing a connection object's attributes."""
+
+    def __init__(self, connections, alias):
+        self.__dict__['_connections'] = connections
+        self.__dict__['_alias'] = alias
+
+    def __getattr__(self, item):
+        return getattr(self._connections[self._alias], item)
+
+    def __setattr__(self, name, value):
+        return setattr(self._connections[self._alias], name, value)
+
+    def __delattr__(self, name):
+        return delattr(self._connections[self._alias], name)
+
+    def __contains__(self, key):
+        return key in self._connections[self._alias]
+
+    def __eq__(self, other):
+        return self._connections[self._alias] == other
+
+
+class ConnectionDoesNotExist(Exception):
+    pass
+
+
+class BaseConnectionHandler:
+    settings_name = None
+    exception_class = ConnectionDoesNotExist
+    thread_critical = False
+
+    def __init__(self, settings=None):
+        self._settings = settings
+        self._connections = Local(self.thread_critical)
+
+    @cached_property
+    def settings(self):
+        self._settings = self.configure_settings(self._settings)
+        return self._settings
+
+    def configure_settings(self, settings):
+        if settings is None:
+            settings = getattr(django_settings, self.settings_name)
+        return settings
+
+    def create_connection(self, alias):
+        raise NotImplementedError('Subclasses must implement create_connection().')
+
+    def __getitem__(self, alias):
+        try:
+            return getattr(self._connections, alias)
+        except AttributeError:
+            if alias not in self.settings:
+                raise self.exception_class(f"The connection '{alias}' doesn't exist.")
+        conn = self.create_connection(alias)
+        setattr(self._connections, alias, conn)
+        return conn
+
+    def __setitem__(self, key, value):
+        setattr(self._connections, key, value)
+
+    def __delitem__(self, key):
+        delattr(self._connections, key)
+
+    def __iter__(self):
+        return iter(self.settings)
+
+    def all(self):
+        return [self[alias] for alias in self]

+ 1 - 1
docs/topics/db/multi-db.txt

@@ -74,7 +74,7 @@ example ``settings.py`` snippet defining two non-default databases, with the
 
 If you attempt to access a database that you haven't defined in your
 :setting:`DATABASES` setting, Django will raise a
-``django.db.utils.ConnectionDoesNotExist`` exception.
+``django.utils.connection.ConnectionDoesNotExist`` exception.
 
 .. _synchronizing_multiple_databases:
 

+ 10 - 9
tests/cache/tests.py

@@ -17,7 +17,8 @@ from unittest import mock, skipIf
 from django.conf import settings
 from django.core import management, signals
 from django.core.cache import (
-    DEFAULT_CACHE_ALIAS, CacheKeyWarning, InvalidCacheKey, cache, caches,
+    DEFAULT_CACHE_ALIAS, CacheHandler, CacheKeyWarning, InvalidCacheKey, cache,
+    caches,
 )
 from django.core.cache.backends.base import InvalidCacheBackendError
 from django.core.cache.utils import make_template_fragment_key
@@ -2501,19 +2502,19 @@ class CacheHandlerTest(SimpleTestCase):
         self.assertIsNot(c[0], c[1])
 
     def test_nonexistent_alias(self):
-        msg = "Could not find config for 'nonexistent' in settings.CACHES"
+        msg = "The connection 'nonexistent' doesn't exist."
         with self.assertRaisesMessage(InvalidCacheBackendError, msg):
             caches['nonexistent']
 
     def test_nonexistent_backend(self):
+        test_caches = CacheHandler({
+            'invalid_backend': {
+                'BACKEND': 'django.nonexistent.NonexistentBackend',
+            },
+        })
         msg = (
             "Could not find backend 'django.nonexistent.NonexistentBackend': "
             "No module named 'django.nonexistent'"
         )
-        with self.settings(CACHES={
-            'invalid_backend': {
-                'BACKEND': 'django.nonexistent.NonexistentBackend',
-            },
-        }):
-            with self.assertRaisesMessage(InvalidCacheBackendError, msg):
-                caches['invalid_backend']
+        with self.assertRaisesMessage(InvalidCacheBackendError, msg):
+            test_caches['invalid_backend']

+ 5 - 6
tests/db_utils/tests.py

@@ -3,10 +3,9 @@ import unittest
 
 from django.core.exceptions import ImproperlyConfigured
 from django.db import DEFAULT_DB_ALIAS, ProgrammingError, connection
-from django.db.utils import (
-    ConnectionDoesNotExist, ConnectionHandler, load_backend,
-)
+from django.db.utils import ConnectionHandler, load_backend
 from django.test import SimpleTestCase, TestCase
+from django.utils.connection import ConnectionDoesNotExist
 
 
 class ConnectionHandlerTests(SimpleTestCase):
@@ -41,7 +40,7 @@ class ConnectionHandlerTests(SimpleTestCase):
             conns['other'].ensure_connection()
 
     def test_nonexistent_alias(self):
-        msg = "The connection nonexistent doesn't exist"
+        msg = "The connection 'nonexistent' doesn't exist."
         conns = ConnectionHandler({
             DEFAULT_DB_ALIAS: {'ENGINE': 'django.db.backends.dummy'},
         })
@@ -49,7 +48,7 @@ class ConnectionHandlerTests(SimpleTestCase):
             conns['nonexistent']
 
     def test_ensure_defaults_nonexistent_alias(self):
-        msg = "The connection nonexistent doesn't exist"
+        msg = "The connection 'nonexistent' doesn't exist."
         conns = ConnectionHandler({
             DEFAULT_DB_ALIAS: {'ENGINE': 'django.db.backends.dummy'},
         })
@@ -57,7 +56,7 @@ class ConnectionHandlerTests(SimpleTestCase):
             conns.ensure_defaults('nonexistent')
 
     def test_prepare_test_settings_nonexistent_alias(self):
-        msg = "The connection nonexistent doesn't exist"
+        msg = "The connection 'nonexistent' doesn't exist."
         conns = ConnectionHandler({
             DEFAULT_DB_ALIAS: {'ENGINE': 'django.db.backends.dummy'},
         })

+ 10 - 0
tests/utils_tests/test_connection.py

@@ -0,0 +1,10 @@
+from django.test import SimpleTestCase
+from django.utils.connection import BaseConnectionHandler
+
+
+class BaseConnectionHandlerTests(SimpleTestCase):
+    def test_create_connection(self):
+        handler = BaseConnectionHandler()
+        msg = 'Subclasses must implement create_connection().'
+        with self.assertRaisesMessage(NotImplementedError, msg):
+            handler.create_connection(None)