Browse Source

Fixed #21134 -- Prevented queries in broken transactions.

Squashed commit of the following:

commit 63ddb271a44df389b2c302e421fc17b7f0529755
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Sun Sep 29 22:51:00 2013 +0200

    Clarified interactions between atomic and exceptions.

commit 2899ec299228217c876ba3aa4024e523a41c8504
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Sun Sep 22 22:45:32 2013 +0200

    Fixed TransactionManagementError in tests.

    Previous commit introduced an additional check to prevent running
    queries in transactions that will be rolled back, which triggered a few
    failures in the tests. In practice using transaction.atomic instead of
    the low-level savepoint APIs was enough to fix the problems.

commit 4a639b059ea80aeb78f7f160a7d4b9f609b9c238
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Tue Sep 24 22:24:17 2013 +0200

    Allowed nesting constraint_checks_disabled inside atomic.

    Since MySQL handles transactions loosely, this isn't a problem.

commit 2a4ab1cb6e83391ff7e25d08479e230ca564bfef
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Sat Sep 21 18:43:12 2013 +0200

    Prevented running queries in transactions that will be rolled back.

    This avoids a counter-intuitive behavior in an edge case on databases
    with non-atomic transaction semantics.

    It prevents using savepoint_rollback() inside an atomic block without
    calling set_rollback(False) first, which is backwards-incompatible in
    tests.

    Refs #21134.

commit 8e3db393853c7ac64a445b66e57f3620a3fde7b0
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Sun Sep 22 22:14:17 2013 +0200

    Replaced manual savepoints by atomic blocks.

    This ensures the rollback flag is handled consistently in internal APIs.
Aymeric Augustin 11 years ago
parent
commit
728548e483

+ 2 - 3
django/contrib/sessions/backends/db.py

@@ -58,12 +58,11 @@ class SessionStore(SessionBase):
             expire_date=self.get_expiry_date()
         )
         using = router.db_for_write(Session, instance=obj)
-        sid = transaction.savepoint(using=using)
         try:
-            obj.save(force_insert=must_create, using=using)
+            with transaction.atomic(using=using):
+                obj.save(force_insert=must_create, using=using)
         except IntegrityError:
             if must_create:
-                transaction.savepoint_rollback(sid, using=using)
                 raise CreateError
             raise
 

+ 9 - 0
django/db/backends/__init__.py

@@ -361,6 +361,12 @@ class BaseDatabaseWrapper(object):
             raise TransactionManagementError(
                 "This is forbidden when an 'atomic' block is active.")
 
+    def validate_no_broken_transaction(self):
+        if self.needs_rollback:
+            raise TransactionManagementError(
+                "An error occurred in the current transaction. You can't "
+                "execute queries until the end of the 'atomic' block.")
+
     def abort(self):
         """
         Roll back any ongoing transaction and clean the transaction state
@@ -638,6 +644,9 @@ class BaseDatabaseFeatures(object):
     # when autocommit is disabled? http://bugs.python.org/issue8145#msg109965
     autocommits_when_autocommit_is_off = False
 
+    # Does the backend prevent running SQL queries in broken transactions?
+    atomic_transactions = True
+
     # Can we roll back DDL in a transaction?
     can_rollback_ddl = False
 

+ 8 - 1
django/db/backends/mysql/base.py

@@ -172,6 +172,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     requires_explicit_null_ordering_when_grouping = True
     allows_primary_key_0 = False
     uses_savepoints = True
+    atomic_transactions = False
     supports_check_constraints = False
 
     def __init__(self, connection):
@@ -484,7 +485,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         """
         Re-enable foreign key checks after they have been disabled.
         """
-        self.cursor().execute('SET foreign_key_checks=1')
+        # Override needs_rollback in case constraint_checks_disabled is
+        # nested inside transaction.atomic.
+        self.needs_rollback, needs_rollback = False, self.needs_rollback
+        try:
+            self.cursor().execute('SET foreign_key_checks=1')
+        finally:
+            self.needs_rollback = needs_rollback
 
     def check_constraints(self, table_names=None):
         """

+ 1 - 0
django/db/backends/oracle/base.py

@@ -96,6 +96,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     has_bulk_insert = True
     supports_tablespaces = True
     supports_sequence_reset = False
+    atomic_transactions = False
     supports_combined_alters = False
     max_index_name_length = 30
     nulls_order_largest = True

+ 1 - 0
django/db/backends/sqlite3/base.py

@@ -105,6 +105,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     supports_foreign_keys = False
     supports_check_constraints = False
     autocommits_when_autocommit_is_off = True
+    atomic_transactions = False
     supports_paramstyle_pyformat = False
     supports_sequence_reset = False
 

+ 32 - 15
django/db/backends/utils.py

@@ -19,14 +19,9 @@ class CursorWrapper(object):
         self.cursor = cursor
         self.db = db
 
-    SET_DIRTY_ATTRS = frozenset(['execute', 'executemany', 'callproc'])
-    WRAP_ERROR_ATTRS = frozenset([
-        'callproc', 'close', 'execute', 'executemany',
-        'fetchone', 'fetchmany', 'fetchall', 'nextset'])
+    WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset'])
 
     def __getattr__(self, attr):
-        if attr in CursorWrapper.SET_DIRTY_ATTRS:
-            self.db.set_dirty()
         cursor_attr = getattr(self.cursor, attr)
         if attr in CursorWrapper.WRAP_ERROR_ATTRS:
             return self.db.wrap_database_errors(cursor_attr)
@@ -44,18 +39,42 @@ class CursorWrapper(object):
         # specific behavior.
         self.close()
 
+    # The following methods cannot be implemented in __getattr__, because the
+    # code must run when the method is invoked, not just when it is accessed.
 
-class CursorDebugWrapper(CursorWrapper):
+    def callproc(self, procname, params=None):
+        self.db.validate_no_broken_transaction()
+        self.db.set_dirty()
+        with self.db.wrap_database_errors:
+            if params is None:
+                return self.cursor.callproc(procname)
+            else:
+                return self.cursor.callproc(procname, params)
 
     def execute(self, sql, params=None):
+        self.db.validate_no_broken_transaction()
         self.db.set_dirty()
+        with self.db.wrap_database_errors:
+            if params is None:
+                return self.cursor.execute(sql)
+            else:
+                return self.cursor.execute(sql, params)
+
+    def executemany(self, sql, param_list):
+        self.db.validate_no_broken_transaction()
+        self.db.set_dirty()
+        with self.db.wrap_database_errors:
+            return self.cursor.executemany(sql, param_list)
+
+
+class CursorDebugWrapper(CursorWrapper):
+
+    # XXX callproc isn't instrumented at this time.
+
+    def execute(self, sql, params=None):
         start = time()
         try:
-            with self.db.wrap_database_errors:
-                if params is None:
-                    # params default might be backend specific
-                    return self.cursor.execute(sql)
-                return self.cursor.execute(sql, params)
+            return super(CursorDebugWrapper, self).execute(sql, params)
         finally:
             stop = time()
             duration = stop - start
@@ -69,11 +88,9 @@ class CursorDebugWrapper(CursorWrapper):
             )
 
     def executemany(self, sql, param_list):
-        self.db.set_dirty()
         start = time()
         try:
-            with self.db.wrap_database_errors:
-                return self.cursor.executemany(sql, param_list)
+            return super(CursorDebugWrapper, self).executemany(sql, param_list)
         finally:
             stop = time()
             duration = stop - start

+ 9 - 17
django/db/models/query.py

@@ -436,14 +436,9 @@ class QuerySet(object):
         for k, v in six.iteritems(defaults):
             setattr(obj, k, v)
 
-        sid = transaction.savepoint(using=self.db)
-        try:
+        with transaction.atomic(using=self.db):
             obj.save(using=self.db)
-            transaction.savepoint_commit(sid, using=self.db)
-            return obj, False
-        except DatabaseError:
-            transaction.savepoint_rollback(sid, using=self.db)
-            six.reraise(*sys.exc_info())
+        return obj, False
 
     def _create_object_from_params(self, lookup, params):
         """
@@ -451,19 +446,16 @@ class QuerySet(object):
         Used by get_or_create and update_or_create
         """
         obj = self.model(**params)
-        sid = transaction.savepoint(using=self.db)
         try:
-            obj.save(force_insert=True, using=self.db)
-            transaction.savepoint_commit(sid, using=self.db)
+            with transaction.atomic(using=self.db):
+                obj.save(force_insert=True, using=self.db)
             return obj, True
-        except DatabaseError as e:
-            transaction.savepoint_rollback(sid, using=self.db)
+        except IntegrityError:
             exc_info = sys.exc_info()
-            if isinstance(e, IntegrityError):
-                try:
-                    return self.get(**lookup), False
-                except self.model.DoesNotExist:
-                    pass
+            try:
+                return self.get(**lookup), False
+            except self.model.DoesNotExist:
+                pass
             six.reraise(*exc_info)
 
     def _extract_model_params(self, defaults, **kwargs):

+ 5 - 4
django/db/transaction.py

@@ -16,14 +16,15 @@ import warnings
 
 from functools import wraps
 
-from django.db import connections, DatabaseError, DEFAULT_DB_ALIAS
+from django.db import (
+        connections, DEFAULT_DB_ALIAS,
+        DatabaseError, ProgrammingError)
 from django.utils.decorators import available_attrs
 
 
-class TransactionManagementError(Exception):
+class TransactionManagementError(ProgrammingError):
     """
-    This exception is thrown when something bad happens with transaction
-    management.
+    This exception is thrown when transaction management is used improperly.
     """
     pass
 

+ 20 - 9
docs/topics/db/transactions.txt

@@ -163,20 +163,31 @@ Django provides a single API to control database transactions.
     called, so the exception handler can also operate on the database if
     necessary.
 
-    .. admonition:: Don't catch database exceptions inside ``atomic``!
-
-        If you catch :exc:`~django.db.DatabaseError` or a subclass such as
-        :exc:`~django.db.IntegrityError` inside an ``atomic`` block, you will
-        hide from Django the fact that an error has occurred and that the
-        transaction is broken. At this point, Django's behavior is unspecified
-        and database-dependent. It will usually result in a rollback, which
-        may break your expectations, since you caught the exception.
+    .. admonition:: Avoid catching exceptions inside ``atomic``!
+
+        When exiting an ``atomic`` block, Django looks at whether it's exited
+        normally or with an exception to determine whether to commit or roll
+        back. If you catch and handle exceptions inside an ``atomic`` block,
+        you may hide from Django the fact that a problem has happened. This
+        can result in unexpected behavior.
+
+        This is mostly a concern for :exc:`~django.db.DatabaseError` and its
+        subclasses such as :exc:`~django.db.IntegrityError`. After such an
+        error, the transaction is broken and Django will perform a rollback at
+        the end of the ``atomic`` block. If you attempt to run database
+        queries before the rollback happens, Django will raise a
+        :class:`~django.db.transaction.TransactionManagementError`. You may
+        also encounter this behavior when an ORM-related signal handler raises
+        an exception.
 
         The correct way to catch database errors is around an ``atomic`` block
         as shown above. If necessary, add an extra ``atomic`` block for this
-        purpose -- it's cheap! This pattern is useful to delimit explicitly
+        purpose. This pattern has another advantage: it delimits explicitly
         which operations will be rolled back if an exception occurs.
 
+        If you catch exceptions raised by raw SQL queries, Django's behavior
+        is unspecified and database-dependent.
+
     In order to guarantee atomicity, ``atomic`` disables some APIs. Attempting
     to commit, roll back, or change the autocommit state of the database
     connection within an ``atomic`` block will raise an exception.

+ 6 - 10
tests/custom_pk/tests.py

@@ -149,11 +149,9 @@ class CustomPKTests(TestCase):
         Employee.objects.create(
             employee_code=123, first_name="Frank", last_name="Jones"
         )
-        sid = transaction.savepoint()
-        self.assertRaises(IntegrityError,
-            Employee.objects.create, employee_code=123, first_name="Fred", last_name="Jones"
-        )
-        transaction.savepoint_rollback(sid)
+        with self.assertRaises(IntegrityError):
+            with transaction.atomic():
+                Employee.objects.create(employee_code=123, first_name="Fred", last_name="Jones")
 
     def test_custom_field_pk(self):
         # Regression for #10785 -- Custom fields can be used for primary keys.
@@ -175,8 +173,6 @@ class CustomPKTests(TestCase):
     def test_required_pk(self):
         # The primary key must be specified, so an error is raised if you
         # try to create an object without it.
-        sid = transaction.savepoint()
-        self.assertRaises(IntegrityError,
-            Employee.objects.create, first_name="Tom", last_name="Smith"
-        )
-        transaction.savepoint_rollback(sid)
+        with self.assertRaises(IntegrityError):
+            with transaction.atomic():
+                Employee.objects.create(first_name="Tom", last_name="Smith")

+ 6 - 5
tests/expressions/tests.py

@@ -2,6 +2,7 @@ from __future__ import unicode_literals
 
 from django.core.exceptions import FieldError
 from django.db.models import F
+from django.db import transaction
 from django.test import TestCase
 from django.utils import six
 
@@ -185,11 +186,11 @@ class ExpressionsTests(TestCase):
             "foo",
         )
 
-        self.assertRaises(FieldError,
-            lambda: Company.objects.exclude(
-                ceo__firstname=F('point_of_contact__firstname')
-            ).update(name=F('point_of_contact__lastname'))
-        )
+        with transaction.atomic():
+            with self.assertRaises(FieldError):
+                Company.objects.exclude(
+                    ceo__firstname=F('point_of_contact__firstname')
+                ).update(name=F('point_of_contact__lastname'))
 
         # F expressions can be used to update attributes on single objects
         test_gmbh = Company.objects.get(name="Test GmbH")

+ 11 - 6
tests/force_insert_update/tests.py

@@ -21,24 +21,29 @@ class ForceTests(TestCase):
         # Won't work because force_update and force_insert are mutually
         # exclusive
         c.value = 4
-        self.assertRaises(ValueError, c.save, force_insert=True, force_update=True)
+        with self.assertRaises(ValueError):
+            c.save(force_insert=True, force_update=True)
 
         # Try to update something that doesn't have a primary key in the first
         # place.
         c1 = Counter(name="two", value=2)
-        self.assertRaises(ValueError, c1.save, force_update=True)
+        with self.assertRaises(ValueError):
+            with transaction.atomic():
+                c1.save(force_update=True)
         c1.save(force_insert=True)
 
         # Won't work because we can't insert a pk of the same value.
-        sid = transaction.savepoint()
         c.value = 5
-        self.assertRaises(IntegrityError, c.save, force_insert=True)
-        transaction.savepoint_rollback(sid)
+        with self.assertRaises(IntegrityError):
+            with transaction.atomic():
+                c.save(force_insert=True)
 
         # Trying to update should still fail, even with manual primary keys, if
         # the data isn't in the database already.
         obj = WithCustomPK(name=1, value=1)
-        self.assertRaises(DatabaseError, obj.save, force_update=True)
+        with self.assertRaises(DatabaseError):
+            with transaction.atomic():
+                obj.save(force_update=True)
 
 
 class InheritanceTests(TestCase):

+ 3 - 3
tests/one_to_one/tests.py

@@ -118,7 +118,7 @@ class OneToOneTests(TestCase):
         self.assertEqual(repr(o1.multimodel), '<MultiModel: Multimodel x1>')
         # This will fail because each one-to-one field must be unique (and
         # link2=o1 was used for x1, above).
-        sid = transaction.savepoint()
         mm = MultiModel(link1=self.p2, link2=o1, name="x1")
-        self.assertRaises(IntegrityError, mm.save)
-        transaction.savepoint_rollback(sid)
+        with self.assertRaises(IntegrityError):
+            with transaction.atomic():
+                mm.save()

+ 44 - 27
tests/transactions/tests.py

@@ -4,7 +4,7 @@ import sys
 from unittest import skipIf, skipUnless
 
 from django.db import connection, transaction, DatabaseError, IntegrityError
-from django.test import TransactionTestCase, skipUnlessDBFeature
+from django.test import TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature
 from django.test.utils import IgnoreDeprecationWarningsMixin
 from django.utils import six
 
@@ -204,10 +204,10 @@ class AtomicTests(TransactionTestCase):
                 with transaction.atomic(savepoint=False):
                     connection.cursor().execute(
                             "SELECT no_such_col FROM transactions_reporter")
-            transaction.savepoint_rollback(sid)
-            # atomic block should rollback, but prevent it, as we just did it.
+            # prevent atomic from rolling back since we're recovering manually
             self.assertTrue(transaction.get_rollback())
             transaction.set_rollback(False)
+            transaction.savepoint_rollback(sid)
         self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>'])
 
 
@@ -267,11 +267,19 @@ class AtomicMergeTests(TransactionTestCase):
                     with transaction.atomic(savepoint=False):
                         Reporter.objects.create(first_name="Calculus")
                         raise Exception("Oops, that's his last name")
-                # It wasn't possible to roll back
+                # The third insert couldn't be roll back. Temporarily mark the
+                # connection as not needing rollback to check it.
+                self.assertTrue(transaction.get_rollback())
+                transaction.set_rollback(False)
                 self.assertEqual(Reporter.objects.count(), 3)
-            # It wasn't possible to roll back
+                transaction.set_rollback(True)
+            # The second insert couldn't be roll back. Temporarily mark the
+            # connection as not needing rollback to check it.
+            self.assertTrue(transaction.get_rollback())
+            transaction.set_rollback(False)
             self.assertEqual(Reporter.objects.count(), 3)
-        # The outer block must roll back
+            transaction.set_rollback(True)
+        # The first block has a savepoint and must roll back.
         self.assertQuerysetEqual(Reporter.objects.all(), [])
 
     def test_merged_inner_savepoint_rollback(self):
@@ -283,36 +291,22 @@ class AtomicMergeTests(TransactionTestCase):
                     with transaction.atomic(savepoint=False):
                         Reporter.objects.create(first_name="Calculus")
                         raise Exception("Oops, that's his last name")
-                # It wasn't possible to roll back
+                # The third insert couldn't be roll back. Temporarily mark the
+                # connection as not needing rollback to check it.
+                self.assertTrue(transaction.get_rollback())
+                transaction.set_rollback(False)
                 self.assertEqual(Reporter.objects.count(), 3)
-            # The first block with a savepoint must roll back
+                transaction.set_rollback(True)
+            # The second block has a savepoint and must roll back.
             self.assertEqual(Reporter.objects.count(), 1)
         self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>'])
 
-    def test_merged_outer_rollback_after_inner_failure_and_inner_success(self):
-        with transaction.atomic():
-            Reporter.objects.create(first_name="Tintin")
-            # Inner block without a savepoint fails
-            with six.assertRaisesRegex(self, Exception, "Oops"):
-                with transaction.atomic(savepoint=False):
-                    Reporter.objects.create(first_name="Haddock")
-                    raise Exception("Oops, that's his last name")
-            # It wasn't possible to roll back
-            self.assertEqual(Reporter.objects.count(), 2)
-            # Inner block with a savepoint succeeds
-            with transaction.atomic(savepoint=False):
-                Reporter.objects.create(first_name="Archibald", last_name="Haddock")
-            # It still wasn't possible to roll back
-            self.assertEqual(Reporter.objects.count(), 3)
-        # The outer block must rollback
-        self.assertQuerysetEqual(Reporter.objects.all(), [])
-
 
 @skipUnless(connection.features.uses_savepoints,
         "'atomic' requires transactions and savepoints.")
 class AtomicErrorsTests(TransactionTestCase):
 
-    available_apps = []
+    available_apps = ['transactions']
 
     def test_atomic_prevents_setting_autocommit(self):
         autocommit = transaction.get_autocommit()
@@ -336,6 +330,29 @@ class AtomicErrorsTests(TransactionTestCase):
             with self.assertRaises(transaction.TransactionManagementError):
                 transaction.leave_transaction_management()
 
+    def test_atomic_prevents_queries_in_broken_transaction(self):
+        r1 = Reporter.objects.create(first_name="Archibald", last_name="Haddock")
+        with transaction.atomic():
+            r2 = Reporter(first_name="Cuthbert", last_name="Calculus", id=r1.id)
+            with self.assertRaises(IntegrityError):
+                r2.save(force_insert=True)
+            # The transaction is marked as needing rollback.
+            with self.assertRaises(transaction.TransactionManagementError):
+                r2.save(force_update=True)
+        self.assertEqual(Reporter.objects.get(pk=r1.pk).last_name, "Haddock")
+
+    @skipIfDBFeature('atomic_transactions')
+    def test_atomic_allows_queries_after_fixing_transaction(self):
+        r1 = Reporter.objects.create(first_name="Archibald", last_name="Haddock")
+        with transaction.atomic():
+            r2 = Reporter(first_name="Cuthbert", last_name="Calculus", id=r1.id)
+            with self.assertRaises(IntegrityError):
+                r2.save(force_insert=True)
+            # Mark the transaction as no longer needing rollback.
+            transaction.set_rollback(False)
+            r2.save(force_update=True)
+        self.assertEqual(Reporter.objects.get(pk=r1.pk).last_name, "Calculus")
+
 
 class AtomicMiscTests(TransactionTestCase):