浏览代码

Fixed #30171 -- Fixed DatabaseError in servers tests.

Made DatabaseWrapper thread sharing logic reentrant. Used a reference
counting like scheme to allow nested uses.

The error appeared after 8c775391b78b2a4a2b57c5e89ed4888f36aada4b.
Jon Dufresne 6 年之前
父节点
当前提交
76990cbbda

+ 24 - 14
django/db/backends/base/base.py

@@ -1,4 +1,5 @@
 import copy
+import threading
 import time
 import warnings
 from collections import deque
@@ -43,8 +44,7 @@ class BaseDatabaseWrapper:
 
     queries_limit = 9000
 
-    def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS,
-                 allow_thread_sharing=False):
+    def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
         # Connection related attributes.
         # The underlying database connection.
         self.connection = None
@@ -80,7 +80,8 @@ class BaseDatabaseWrapper:
         self.errors_occurred = False
 
         # Thread-safety related attributes.
-        self.allow_thread_sharing = allow_thread_sharing
+        self._thread_sharing_lock = threading.Lock()
+        self._thread_sharing_count = 0
         self._thread_ident = _thread.get_ident()
 
         # A list of no-argument functions to run when the transaction commits.
@@ -515,12 +516,27 @@ class BaseDatabaseWrapper:
 
     # ##### Thread safety handling #####
 
+    @property
+    def allow_thread_sharing(self):
+        with self._thread_sharing_lock:
+            return self._thread_sharing_count > 0
+
+    def inc_thread_sharing(self):
+        with self._thread_sharing_lock:
+            self._thread_sharing_count += 1
+
+    def dec_thread_sharing(self):
+        with self._thread_sharing_lock:
+            if self._thread_sharing_count <= 0:
+                raise RuntimeError('Cannot decrement the thread sharing count below zero.')
+            self._thread_sharing_count -= 1
+
     def validate_thread_sharing(self):
         """
         Validate that the connection isn't accessed by another thread than the
         one which originally created it, unless the connection was explicitly
-        authorized to be shared between threads (via the `allow_thread_sharing`
-        property). Raise an exception if the validation fails.
+        authorized to be shared between threads (via the `inc_thread_sharing()`
+        method). Raise an exception if the validation fails.
         """
         if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
             raise DatabaseError(
@@ -589,11 +605,7 @@ class BaseDatabaseWrapper:
         potential child threads while (or after) the test database is destroyed.
         Refs #10868, #17786, #16969.
         """
-        return self.__class__(
-            {**self.settings_dict, 'NAME': None},
-            alias=NO_DB_ALIAS,
-            allow_thread_sharing=False,
-        )
+        return self.__class__({**self.settings_dict, 'NAME': None}, alias=NO_DB_ALIAS)
 
     def schema_editor(self, *args, **kwargs):
         """
@@ -635,7 +647,7 @@ class BaseDatabaseWrapper:
         finally:
             self.execute_wrappers.pop()
 
-    def copy(self, alias=None, allow_thread_sharing=None):
+    def copy(self, alias=None):
         """
         Return a copy of this connection.
 
@@ -644,6 +656,4 @@ class BaseDatabaseWrapper:
         settings_dict = copy.deepcopy(self.settings_dict)
         if alias is None:
             alias = self.alias
-        if allow_thread_sharing is None:
-            allow_thread_sharing = self.allow_thread_sharing
-        return type(self)(settings_dict, alias, allow_thread_sharing)
+        return type(self)(settings_dict, alias)

+ 0 - 1
django/db/backends/postgresql/base.py

@@ -277,7 +277,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
                     return self.__class__(
                         {**self.settings_dict, 'NAME': connection.settings_dict['NAME']},
                         alias=self.alias,
-                        allow_thread_sharing=False,
                     )
         return nodb_connection
 

+ 4 - 5
django/test/testcases.py

@@ -1442,7 +1442,7 @@ class LiveServerTestCase(TransactionTestCase):
             # the server thread.
             if conn.vendor == 'sqlite' and conn.is_in_memory_db():
                 # Explicitly enable thread-shareability for this connection
-                conn.allow_thread_sharing = True
+                conn.inc_thread_sharing()
                 connections_override[conn.alias] = conn
 
         cls._live_server_modified_settings = modify_settings(
@@ -1478,10 +1478,9 @@ class LiveServerTestCase(TransactionTestCase):
             # Terminate the live server's thread
             cls.server_thread.terminate()
 
-        # Restore sqlite in-memory database connections' non-shareability
-        for conn in connections.all():
-            if conn.vendor == 'sqlite' and conn.is_in_memory_db():
-                conn.allow_thread_sharing = False
+            # Restore sqlite in-memory database connections' non-shareability.
+            for conn in cls.server_thread.connections_override.values():
+                conn.dec_thread_sharing()
 
     @classmethod
     def tearDownClass(cls):

+ 3 - 0
docs/releases/2.2.txt

@@ -286,6 +286,9 @@ backends.
   * ``_delete_fk_sql()`` (to pair with ``_create_fk_sql()``)
   * ``_create_check_sql()`` and ``_delete_check_sql()``
 
+* The third argument of ``DatabaseWrapper.__init__()``,
+  ``allow_thread_sharing``, is removed.
+
 Admin actions are no longer collected from base ``ModelAdmin`` classes
 ----------------------------------------------------------------------
 

+ 64 - 43
tests/backends/tests.py

@@ -605,21 +605,25 @@ class ThreadTests(TransactionTestCase):
             connection = connections[DEFAULT_DB_ALIAS]
             # Allow thread sharing so the connection can be closed by the
             # main thread.
-            connection.allow_thread_sharing = True
+            connection.inc_thread_sharing()
             connection.cursor()
             connections_dict[id(connection)] = connection
-        for x in range(2):
-            t = threading.Thread(target=runner)
-            t.start()
-            t.join()
-        # Each created connection got different inner connection.
-        self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3)
-        # Finish by closing the connections opened by the other threads (the
-        # connection opened in the main thread will automatically be closed on
-        # teardown).
-        for conn in connections_dict.values():
-            if conn is not connection:
-                conn.close()
+        try:
+            for x in range(2):
+                t = threading.Thread(target=runner)
+                t.start()
+                t.join()
+            # Each created connection got different inner connection.
+            self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3)
+        finally:
+            # Finish by closing the connections opened by the other threads
+            # (the connection opened in the main thread will automatically be
+            # closed on teardown).
+            for conn in connections_dict.values():
+                if conn is not connection:
+                    if conn.allow_thread_sharing:
+                        conn.close()
+                        conn.dec_thread_sharing()
 
     def test_connections_thread_local(self):
         """
@@ -636,19 +640,23 @@ class ThreadTests(TransactionTestCase):
             for conn in connections.all():
                 # Allow thread sharing so the connection can be closed by the
                 # main thread.
-                conn.allow_thread_sharing = True
+                conn.inc_thread_sharing()
                 connections_dict[id(conn)] = conn
-        for x in range(2):
-            t = threading.Thread(target=runner)
-            t.start()
-            t.join()
-        self.assertEqual(len(connections_dict), 6)
-        # Finish by closing the connections opened by the other threads (the
-        # connection opened in the main thread will automatically be closed on
-        # teardown).
-        for conn in connections_dict.values():
-            if conn is not connection:
-                conn.close()
+        try:
+            for x in range(2):
+                t = threading.Thread(target=runner)
+                t.start()
+                t.join()
+            self.assertEqual(len(connections_dict), 6)
+        finally:
+            # Finish by closing the connections opened by the other threads
+            # (the connection opened in the main thread will automatically be
+            # closed on teardown).
+            for conn in connections_dict.values():
+                if conn is not connection:
+                    if conn.allow_thread_sharing:
+                        conn.close()
+                        conn.dec_thread_sharing()
 
     def test_pass_connection_between_threads(self):
         """
@@ -668,25 +676,21 @@ class ThreadTests(TransactionTestCase):
             t.start()
             t.join()
 
-        # Without touching allow_thread_sharing, which should be False by default.
-        exceptions = []
-        do_thread()
-        # Forbidden!
-        self.assertIsInstance(exceptions[0], DatabaseError)
-
-        # If explicitly setting allow_thread_sharing to False
-        connections['default'].allow_thread_sharing = False
+        # Without touching thread sharing, which should be False by default.
         exceptions = []
         do_thread()
         # Forbidden!
         self.assertIsInstance(exceptions[0], DatabaseError)
 
-        # If explicitly setting allow_thread_sharing to True
-        connections['default'].allow_thread_sharing = True
-        exceptions = []
-        do_thread()
-        # All good
-        self.assertEqual(exceptions, [])
+        # After calling inc_thread_sharing() on the connection.
+        connections['default'].inc_thread_sharing()
+        try:
+            exceptions = []
+            do_thread()
+            # All good
+            self.assertEqual(exceptions, [])
+        finally:
+            connections['default'].dec_thread_sharing()
 
     def test_closing_non_shared_connections(self):
         """
@@ -721,16 +725,33 @@ class ThreadTests(TransactionTestCase):
                 except DatabaseError as e:
                     exceptions.add(e)
             # Enable thread sharing
-            connections['default'].allow_thread_sharing = True
-            t2 = threading.Thread(target=runner2, args=[connections['default']])
-            t2.start()
-            t2.join()
+            connections['default'].inc_thread_sharing()
+            try:
+                t2 = threading.Thread(target=runner2, args=[connections['default']])
+                t2.start()
+                t2.join()
+            finally:
+                connections['default'].dec_thread_sharing()
         t1 = threading.Thread(target=runner1)
         t1.start()
         t1.join()
         # No exception was raised
         self.assertEqual(len(exceptions), 0)
 
+    def test_thread_sharing_count(self):
+        self.assertIs(connection.allow_thread_sharing, False)
+        connection.inc_thread_sharing()
+        self.assertIs(connection.allow_thread_sharing, True)
+        connection.inc_thread_sharing()
+        self.assertIs(connection.allow_thread_sharing, True)
+        connection.dec_thread_sharing()
+        self.assertIs(connection.allow_thread_sharing, True)
+        connection.dec_thread_sharing()
+        self.assertIs(connection.allow_thread_sharing, False)
+        msg = 'Cannot decrement the thread sharing count below zero.'
+        with self.assertRaisesMessage(RuntimeError, msg):
+            connection.dec_thread_sharing()
+
 
 class MySQLPKZeroTests(TestCase):
     """

+ 2 - 3
tests/servers/test_liveserverthread.py

@@ -18,11 +18,10 @@ class LiveServerThreadTest(TestCase):
         # Pass a connection to the thread to check they are being closed.
         connections_override = {DEFAULT_DB_ALIAS: conn}
 
-        saved_sharing = conn.allow_thread_sharing
+        conn.inc_thread_sharing()
         try:
-            conn.allow_thread_sharing = True
             self.assertTrue(conn.is_usable())
             self.run_live_server_thread(connections_override)
             self.assertFalse(conn.is_usable())
         finally:
-            conn.allow_thread_sharing = saved_sharing
+            conn.dec_thread_sharing()

+ 3 - 0
tests/staticfiles_tests/test_liveserver.py

@@ -64,6 +64,9 @@ class StaticLiveServerChecks(LiveServerBase):
             # app without having set the required STATIC_URL setting.")
             pass
         finally:
+            # Use del to avoid decrementing the database thread sharing count a
+            # second time.
+            del cls.server_thread
             super().tearDownClass()
 
     def test_test_test(self):