Browse Source

Fixed #33161 -- Enabled durability check for nested atomic blocks in TestCase.

Co-Authored-By: Adam Johnson <me@adamj.eu>
Krzysztof Jagiello 3 years ago
parent
commit
8d9827c06c

+ 1 - 0
AUTHORS

@@ -544,6 +544,7 @@ answer newbie questions, and generally made Django that much better:
     Kowito Charoenratchatabhan <kowito@felspar.com>
     Krišjānis Vaiders <krisjanisvaiders@gmail.com>
     krzysiek.pawlik@silvermedia.pl
+    Krzysztof Jagiello <me@kjagiello.com>
     Krzysztof Jurewicz <krzysztof.jurewicz@gmail.com>
     Krzysztof Kulewski <kulewski@gmail.com>
     kurtiss@meetro.com

+ 3 - 0
django/db/backends/base/base.py

@@ -79,6 +79,8 @@ class BaseDatabaseWrapper:
         self.savepoint_state = 0
         # List of savepoints created by 'atomic'.
         self.savepoint_ids = []
+        # Stack of active 'atomic' blocks.
+        self.atomic_blocks = []
         # Tracks if the outermost 'atomic' block should commit on exit,
         # ie. if autocommit was active on entry.
         self.commit_on_exit = True
@@ -200,6 +202,7 @@ class BaseDatabaseWrapper:
         # In case the previous connection was closed while in an atomic block
         self.in_atomic_block = False
         self.savepoint_ids = []
+        self.atomic_blocks = []
         self.needs_rollback = False
         # Reset parameters defining when to close the connection
         max_age = self.settings_dict['CONN_MAX_AGE']

+ 12 - 4
django/db/transaction.py

@@ -165,19 +165,21 @@ class Atomic(ContextDecorator):
 
     This is a private API.
     """
-    # This private flag is provided only to disable the durability checks in
-    # TestCase.
-    _ensure_durability = True
 
     def __init__(self, using, savepoint, durable):
         self.using = using
         self.savepoint = savepoint
         self.durable = durable
+        self._from_testcase = False
 
     def __enter__(self):
         connection = get_connection(self.using)
 
-        if self.durable and self._ensure_durability and connection.in_atomic_block:
+        if (
+            self.durable and
+            connection.atomic_blocks and
+            not connection.atomic_blocks[-1]._from_testcase
+        ):
             raise RuntimeError(
                 'A durable atomic block cannot be nested within another '
                 'atomic block.'
@@ -207,9 +209,15 @@ class Atomic(ContextDecorator):
             connection.set_autocommit(False, force_begin_transaction_with_broken_autocommit=True)
             connection.in_atomic_block = True
 
+        if connection.in_atomic_block:
+            connection.atomic_blocks.append(self)
+
     def __exit__(self, exc_type, exc_value, traceback):
         connection = get_connection(self.using)
 
+        if connection.in_atomic_block:
+            connection.atomic_blocks.pop()
+
         if connection.savepoint_ids:
             sid = connection.savepoint_ids.pop()
         else:

+ 19 - 25
django/test/testcases.py

@@ -1146,8 +1146,10 @@ class TestCase(TransactionTestCase):
         """Open atomic blocks for multiple databases."""
         atomics = {}
         for db_name in cls._databases_names():
-            atomics[db_name] = transaction.atomic(using=db_name)
-            atomics[db_name].__enter__()
+            atomic = transaction.atomic(using=db_name)
+            atomic._from_testcase = True
+            atomic.__enter__()
+            atomics[db_name] = atomic
         return atomics
 
     @classmethod
@@ -1166,35 +1168,27 @@ class TestCase(TransactionTestCase):
         super().setUpClass()
         if not cls._databases_support_transactions():
             return
-        # Disable the durability check to allow testing durable atomic blocks
-        # in a transaction for performance reasons.
-        transaction.Atomic._ensure_durability = False
+        cls.cls_atomics = cls._enter_atomics()
+
+        if cls.fixtures:
+            for db_name in cls._databases_names(include_mirrors=False):
+                try:
+                    call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name})
+                except Exception:
+                    cls._rollback_atomics(cls.cls_atomics)
+                    raise
+        pre_attrs = cls.__dict__.copy()
         try:
-            cls.cls_atomics = cls._enter_atomics()
-
-            if cls.fixtures:
-                for db_name in cls._databases_names(include_mirrors=False):
-                    try:
-                        call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name})
-                    except Exception:
-                        cls._rollback_atomics(cls.cls_atomics)
-                        raise
-            pre_attrs = cls.__dict__.copy()
-            try:
-                cls.setUpTestData()
-            except Exception:
-                cls._rollback_atomics(cls.cls_atomics)
-                raise
-            for name, value in cls.__dict__.items():
-                if value is not pre_attrs.get(name):
-                    setattr(cls, name, TestData(name, value))
+            cls.setUpTestData()
         except Exception:
-            transaction.Atomic._ensure_durability = True
+            cls._rollback_atomics(cls.cls_atomics)
             raise
+        for name, value in cls.__dict__.items():
+            if value is not pre_attrs.get(name):
+                setattr(cls, name, TestData(name, value))
 
     @classmethod
     def tearDownClass(cls):
-        transaction.Atomic._ensure_durability = True
         if cls._databases_support_transactions():
             cls._rollback_atomics(cls.cls_atomics)
             for conn in connections.all():

+ 2 - 1
docs/releases/4.1.txt

@@ -215,7 +215,8 @@ Templates
 Tests
 ~~~~~
 
-* ...
+* A nested atomic block marked as durable in :class:`django.test.TestCase` now
+  raises a ``RuntimeError``, the same as outside of tests.
 
 URLs
 ~~~~

+ 3 - 4
docs/topics/db/transactions.txt

@@ -238,11 +238,10 @@ Django provides a single API to control database transactions.
     is especially important if you're using :func:`atomic` in long-running
     processes, outside of Django's request / response cycle.
 
-.. warning::
+.. versionchanged:: 4.1
 
-    :class:`django.test.TestCase` disables the durability check to allow
-    testing durable atomic blocks in a transaction for performance reasons. Use
-    :class:`django.test.TransactionTestCase` for testing durability.
+    In older versions, the durability check was disabled in
+    :class:`django.test.TestCase`.
 
 Autocommit
 ==========

+ 12 - 36
tests/transactions/tests.py

@@ -501,7 +501,7 @@ class NonAutocommitTests(TransactionTestCase):
         Reporter.objects.create(first_name="Tintin")
 
 
-class DurableTests(TransactionTestCase):
+class DurableTestsBase:
     available_apps = ['transactions']
 
     def test_commit(self):
@@ -533,42 +533,18 @@ class DurableTests(TransactionTestCase):
                 with transaction.atomic(durable=True):
                     pass
 
-
-class DisableDurabiltityCheckTests(TestCase):
-    """
-    TestCase runs all tests in a transaction by default. Code using
-    durable=True would always fail when run from TestCase. This would mean
-    these tests would be forced to use the slower TransactionTestCase even when
-    not testing durability. For this reason, TestCase disables the durability
-    check.
-    """
-    available_apps = ['transactions']
-
-    def test_commit(self):
+    def test_sequence_of_durables(self):
         with transaction.atomic(durable=True):
-            reporter = Reporter.objects.create(first_name='Tintin')
-        self.assertEqual(Reporter.objects.get(), reporter)
-
-    def test_nested_outer_durable(self):
+            reporter = Reporter.objects.create(first_name='Tintin 1')
+        self.assertEqual(Reporter.objects.get(first_name='Tintin 1'), reporter)
         with transaction.atomic(durable=True):
-            reporter1 = Reporter.objects.create(first_name='Tintin')
-            with transaction.atomic():
-                reporter2 = Reporter.objects.create(
-                    first_name='Archibald',
-                    last_name='Haddock',
-                )
-        self.assertSequenceEqual(Reporter.objects.all(), [reporter2, reporter1])
+            reporter = Reporter.objects.create(first_name='Tintin 2')
+        self.assertEqual(Reporter.objects.get(first_name='Tintin 2'), reporter)
 
-    def test_nested_both_durable(self):
-        with transaction.atomic(durable=True):
-            # Error is not raised.
-            with transaction.atomic(durable=True):
-                reporter = Reporter.objects.create(first_name='Tintin')
-        self.assertEqual(Reporter.objects.get(), reporter)
 
-    def test_nested_inner_durable(self):
-        with transaction.atomic():
-            # Error is not raised.
-            with transaction.atomic(durable=True):
-                reporter = Reporter.objects.create(first_name='Tintin')
-        self.assertEqual(Reporter.objects.get(), reporter)
+class DurableTransactionTests(DurableTestsBase, TransactionTestCase):
+    pass
+
+
+class DurableTests(DurableTestsBase, TestCase):
+    pass