Browse Source

Fixed #21134 -- Prevented queries in broken transactions.

Squashed commit of the following:

commit 63ddb271a44df389b2c302e421fc17b7f0529755
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Sun Sep 29 22:51:00 2013 +0200

    Clarified interactions between atomic and exceptions.

commit 2899ec299228217c876ba3aa4024e523a41c8504
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Sun Sep 22 22:45:32 2013 +0200

    Fixed TransactionManagementError in tests.

    Previous commit introduced an additional check to prevent running
    queries in transactions that will be rolled back, which triggered a few
    failures in the tests. In practice using transaction.atomic instead of
    the low-level savepoint APIs was enough to fix the problems.

commit 4a639b059ea80aeb78f7f160a7d4b9f609b9c238
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Tue Sep 24 22:24:17 2013 +0200

    Allowed nesting constraint_checks_disabled inside atomic.

    Since MySQL handles transactions loosely, this isn't a problem.

commit 2a4ab1cb6e83391ff7e25d08479e230ca564bfef
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Sat Sep 21 18:43:12 2013 +0200

    Prevented running queries in transactions that will be rolled back.

    This avoids a counter-intuitive behavior in an edge case on databases
    with non-atomic transaction semantics.

    It prevents using savepoint_rollback() inside an atomic block without
    calling set_rollback(False) first, which is backwards-incompatible in
    tests.

    Refs #21134.

commit 8e3db393853c7ac64a445b66e57f3620a3fde7b0
Author: Aymeric Augustin <aymeric.augustin@m4x.org>
Date:   Sun Sep 22 22:14:17 2013 +0200

    Replaced manual savepoints by atomic blocks.

    This ensures the rollback flag is handled consistently in internal APIs.
Aymeric Augustin 11 years ago
parent
commit
728548e483

+ 2 - 3
django/contrib/sessions/backends/db.py

@@ -58,12 +58,11 @@ class SessionStore(SessionBase):
             expire_date=self.get_expiry_date()
             expire_date=self.get_expiry_date()
         )
         )
         using = router.db_for_write(Session, instance=obj)
         using = router.db_for_write(Session, instance=obj)
-        sid = transaction.savepoint(using=using)
         try:
         try:
-            obj.save(force_insert=must_create, using=using)
+            with transaction.atomic(using=using):
+                obj.save(force_insert=must_create, using=using)
         except IntegrityError:
         except IntegrityError:
             if must_create:
             if must_create:
-                transaction.savepoint_rollback(sid, using=using)
                 raise CreateError
                 raise CreateError
             raise
             raise
 
 

+ 9 - 0
django/db/backends/__init__.py

@@ -361,6 +361,12 @@ class BaseDatabaseWrapper(object):
             raise TransactionManagementError(
             raise TransactionManagementError(
                 "This is forbidden when an 'atomic' block is active.")
                 "This is forbidden when an 'atomic' block is active.")
 
 
+    def validate_no_broken_transaction(self):
+        if self.needs_rollback:
+            raise TransactionManagementError(
+                "An error occurred in the current transaction. You can't "
+                "execute queries until the end of the 'atomic' block.")
+
     def abort(self):
     def abort(self):
         """
         """
         Roll back any ongoing transaction and clean the transaction state
         Roll back any ongoing transaction and clean the transaction state
@@ -638,6 +644,9 @@ class BaseDatabaseFeatures(object):
     # when autocommit is disabled? http://bugs.python.org/issue8145#msg109965
     # when autocommit is disabled? http://bugs.python.org/issue8145#msg109965
     autocommits_when_autocommit_is_off = False
     autocommits_when_autocommit_is_off = False
 
 
+    # Does the backend prevent running SQL queries in broken transactions?
+    atomic_transactions = True
+
     # Can we roll back DDL in a transaction?
     # Can we roll back DDL in a transaction?
     can_rollback_ddl = False
     can_rollback_ddl = False
 
 

+ 8 - 1
django/db/backends/mysql/base.py

@@ -172,6 +172,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     requires_explicit_null_ordering_when_grouping = True
     requires_explicit_null_ordering_when_grouping = True
     allows_primary_key_0 = False
     allows_primary_key_0 = False
     uses_savepoints = True
     uses_savepoints = True
+    atomic_transactions = False
     supports_check_constraints = False
     supports_check_constraints = False
 
 
     def __init__(self, connection):
     def __init__(self, connection):
@@ -484,7 +485,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         """
         """
         Re-enable foreign key checks after they have been disabled.
         Re-enable foreign key checks after they have been disabled.
         """
         """
-        self.cursor().execute('SET foreign_key_checks=1')
+        # Override needs_rollback in case constraint_checks_disabled is
+        # nested inside transaction.atomic.
+        self.needs_rollback, needs_rollback = False, self.needs_rollback
+        try:
+            self.cursor().execute('SET foreign_key_checks=1')
+        finally:
+            self.needs_rollback = needs_rollback
 
 
     def check_constraints(self, table_names=None):
     def check_constraints(self, table_names=None):
         """
         """

+ 1 - 0
django/db/backends/oracle/base.py

@@ -96,6 +96,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     has_bulk_insert = True
     has_bulk_insert = True
     supports_tablespaces = True
     supports_tablespaces = True
     supports_sequence_reset = False
     supports_sequence_reset = False
+    atomic_transactions = False
     supports_combined_alters = False
     supports_combined_alters = False
     max_index_name_length = 30
     max_index_name_length = 30
     nulls_order_largest = True
     nulls_order_largest = True

+ 1 - 0
django/db/backends/sqlite3/base.py

@@ -105,6 +105,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     supports_foreign_keys = False
     supports_foreign_keys = False
     supports_check_constraints = False
     supports_check_constraints = False
     autocommits_when_autocommit_is_off = True
     autocommits_when_autocommit_is_off = True
+    atomic_transactions = False
     supports_paramstyle_pyformat = False
     supports_paramstyle_pyformat = False
     supports_sequence_reset = False
     supports_sequence_reset = False
 
 

+ 32 - 15
django/db/backends/utils.py

@@ -19,14 +19,9 @@ class CursorWrapper(object):
         self.cursor = cursor
         self.cursor = cursor
         self.db = db
         self.db = db
 
 
-    SET_DIRTY_ATTRS = frozenset(['execute', 'executemany', 'callproc'])
-    WRAP_ERROR_ATTRS = frozenset([
-        'callproc', 'close', 'execute', 'executemany',
-        'fetchone', 'fetchmany', 'fetchall', 'nextset'])
+    WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset'])
 
 
     def __getattr__(self, attr):
     def __getattr__(self, attr):
-        if attr in CursorWrapper.SET_DIRTY_ATTRS:
-            self.db.set_dirty()
         cursor_attr = getattr(self.cursor, attr)
         cursor_attr = getattr(self.cursor, attr)
         if attr in CursorWrapper.WRAP_ERROR_ATTRS:
         if attr in CursorWrapper.WRAP_ERROR_ATTRS:
             return self.db.wrap_database_errors(cursor_attr)
             return self.db.wrap_database_errors(cursor_attr)
@@ -44,18 +39,42 @@ class CursorWrapper(object):
         # specific behavior.
         # specific behavior.
         self.close()
         self.close()
 
 
+    # The following methods cannot be implemented in __getattr__, because the
+    # code must run when the method is invoked, not just when it is accessed.
 
 
-class CursorDebugWrapper(CursorWrapper):
+    def callproc(self, procname, params=None):
+        self.db.validate_no_broken_transaction()
+        self.db.set_dirty()
+        with self.db.wrap_database_errors:
+            if params is None:
+                return self.cursor.callproc(procname)
+            else:
+                return self.cursor.callproc(procname, params)
 
 
     def execute(self, sql, params=None):
     def execute(self, sql, params=None):
+        self.db.validate_no_broken_transaction()
         self.db.set_dirty()
         self.db.set_dirty()
+        with self.db.wrap_database_errors:
+            if params is None:
+                return self.cursor.execute(sql)
+            else:
+                return self.cursor.execute(sql, params)
+
+    def executemany(self, sql, param_list):
+        self.db.validate_no_broken_transaction()
+        self.db.set_dirty()
+        with self.db.wrap_database_errors:
+            return self.cursor.executemany(sql, param_list)
+
+
+class CursorDebugWrapper(CursorWrapper):
+
+    # XXX callproc isn't instrumented at this time.
+
+    def execute(self, sql, params=None):
         start = time()
         start = time()
         try:
         try:
-            with self.db.wrap_database_errors:
-                if params is None:
-                    # params default might be backend specific
-                    return self.cursor.execute(sql)
-                return self.cursor.execute(sql, params)
+            return super(CursorDebugWrapper, self).execute(sql, params)
         finally:
         finally:
             stop = time()
             stop = time()
             duration = stop - start
             duration = stop - start
@@ -69,11 +88,9 @@ class CursorDebugWrapper(CursorWrapper):
             )
             )
 
 
     def executemany(self, sql, param_list):
     def executemany(self, sql, param_list):
-        self.db.set_dirty()
         start = time()
         start = time()
         try:
         try:
-            with self.db.wrap_database_errors:
-                return self.cursor.executemany(sql, param_list)
+            return super(CursorDebugWrapper, self).executemany(sql, param_list)
         finally:
         finally:
             stop = time()
             stop = time()
             duration = stop - start
             duration = stop - start

+ 9 - 17
django/db/models/query.py

@@ -436,14 +436,9 @@ class QuerySet(object):
         for k, v in six.iteritems(defaults):
         for k, v in six.iteritems(defaults):
             setattr(obj, k, v)
             setattr(obj, k, v)
 
 
-        sid = transaction.savepoint(using=self.db)
-        try:
+        with transaction.atomic(using=self.db):
             obj.save(using=self.db)
             obj.save(using=self.db)
-            transaction.savepoint_commit(sid, using=self.db)
-            return obj, False
-        except DatabaseError:
-            transaction.savepoint_rollback(sid, using=self.db)
-            six.reraise(*sys.exc_info())
+        return obj, False
 
 
     def _create_object_from_params(self, lookup, params):
     def _create_object_from_params(self, lookup, params):
         """
         """
@@ -451,19 +446,16 @@ class QuerySet(object):
         Used by get_or_create and update_or_create
         Used by get_or_create and update_or_create
         """
         """
         obj = self.model(**params)
         obj = self.model(**params)
-        sid = transaction.savepoint(using=self.db)
         try:
         try:
-            obj.save(force_insert=True, using=self.db)
-            transaction.savepoint_commit(sid, using=self.db)
+            with transaction.atomic(using=self.db):
+                obj.save(force_insert=True, using=self.db)
             return obj, True
             return obj, True
-        except DatabaseError as e:
-            transaction.savepoint_rollback(sid, using=self.db)
+        except IntegrityError:
             exc_info = sys.exc_info()
             exc_info = sys.exc_info()
-            if isinstance(e, IntegrityError):
-                try:
-                    return self.get(**lookup), False
-                except self.model.DoesNotExist:
-                    pass
+            try:
+                return self.get(**lookup), False
+            except self.model.DoesNotExist:
+                pass
             six.reraise(*exc_info)
             six.reraise(*exc_info)
 
 
     def _extract_model_params(self, defaults, **kwargs):
     def _extract_model_params(self, defaults, **kwargs):

+ 5 - 4
django/db/transaction.py

@@ -16,14 +16,15 @@ import warnings
 
 
 from functools import wraps
 from functools import wraps
 
 
-from django.db import connections, DatabaseError, DEFAULT_DB_ALIAS
+from django.db import (
+        connections, DEFAULT_DB_ALIAS,
+        DatabaseError, ProgrammingError)
 from django.utils.decorators import available_attrs
 from django.utils.decorators import available_attrs
 
 
 
 
-class TransactionManagementError(Exception):
+class TransactionManagementError(ProgrammingError):
     """
     """
-    This exception is thrown when something bad happens with transaction
-    management.
+    This exception is thrown when transaction management is used improperly.
     """
     """
     pass
     pass
 
 

+ 20 - 9
docs/topics/db/transactions.txt

@@ -163,20 +163,31 @@ Django provides a single API to control database transactions.
     called, so the exception handler can also operate on the database if
     called, so the exception handler can also operate on the database if
     necessary.
     necessary.
 
 
-    .. admonition:: Don't catch database exceptions inside ``atomic``!
-
-        If you catch :exc:`~django.db.DatabaseError` or a subclass such as
-        :exc:`~django.db.IntegrityError` inside an ``atomic`` block, you will
-        hide from Django the fact that an error has occurred and that the
-        transaction is broken. At this point, Django's behavior is unspecified
-        and database-dependent. It will usually result in a rollback, which
-        may break your expectations, since you caught the exception.
+    .. admonition:: Avoid catching exceptions inside ``atomic``!
+
+        When exiting an ``atomic`` block, Django looks at whether it's exited
+        normally or with an exception to determine whether to commit or roll
+        back. If you catch and handle exceptions inside an ``atomic`` block,
+        you may hide from Django the fact that a problem has happened. This
+        can result in unexpected behavior.
+
+        This is mostly a concern for :exc:`~django.db.DatabaseError` and its
+        subclasses such as :exc:`~django.db.IntegrityError`. After such an
+        error, the transaction is broken and Django will perform a rollback at
+        the end of the ``atomic`` block. If you attempt to run database
+        queries before the rollback happens, Django will raise a
+        :class:`~django.db.transaction.TransactionManagementError`. You may
+        also encounter this behavior when an ORM-related signal handler raises
+        an exception.
 
 
         The correct way to catch database errors is around an ``atomic`` block
         The correct way to catch database errors is around an ``atomic`` block
         as shown above. If necessary, add an extra ``atomic`` block for this
         as shown above. If necessary, add an extra ``atomic`` block for this
-        purpose -- it's cheap! This pattern is useful to delimit explicitly
+        purpose. This pattern has another advantage: it delimits explicitly
         which operations will be rolled back if an exception occurs.
         which operations will be rolled back if an exception occurs.
 
 
+        If you catch exceptions raised by raw SQL queries, Django's behavior
+        is unspecified and database-dependent.
+
     In order to guarantee atomicity, ``atomic`` disables some APIs. Attempting
     In order to guarantee atomicity, ``atomic`` disables some APIs. Attempting
     to commit, roll back, or change the autocommit state of the database
     to commit, roll back, or change the autocommit state of the database
     connection within an ``atomic`` block will raise an exception.
     connection within an ``atomic`` block will raise an exception.

+ 6 - 10
tests/custom_pk/tests.py

@@ -149,11 +149,9 @@ class CustomPKTests(TestCase):
         Employee.objects.create(
         Employee.objects.create(
             employee_code=123, first_name="Frank", last_name="Jones"
             employee_code=123, first_name="Frank", last_name="Jones"
         )
         )
-        sid = transaction.savepoint()
-        self.assertRaises(IntegrityError,
-            Employee.objects.create, employee_code=123, first_name="Fred", last_name="Jones"
-        )
-        transaction.savepoint_rollback(sid)
+        with self.assertRaises(IntegrityError):
+            with transaction.atomic():
+                Employee.objects.create(employee_code=123, first_name="Fred", last_name="Jones")
 
 
     def test_custom_field_pk(self):
     def test_custom_field_pk(self):
         # Regression for #10785 -- Custom fields can be used for primary keys.
         # Regression for #10785 -- Custom fields can be used for primary keys.
@@ -175,8 +173,6 @@ class CustomPKTests(TestCase):
     def test_required_pk(self):
     def test_required_pk(self):
         # The primary key must be specified, so an error is raised if you
         # The primary key must be specified, so an error is raised if you
         # try to create an object without it.
         # try to create an object without it.
-        sid = transaction.savepoint()
-        self.assertRaises(IntegrityError,
-            Employee.objects.create, first_name="Tom", last_name="Smith"
-        )
-        transaction.savepoint_rollback(sid)
+        with self.assertRaises(IntegrityError):
+            with transaction.atomic():
+                Employee.objects.create(first_name="Tom", last_name="Smith")

+ 6 - 5
tests/expressions/tests.py

@@ -2,6 +2,7 @@ from __future__ import unicode_literals
 
 
 from django.core.exceptions import FieldError
 from django.core.exceptions import FieldError
 from django.db.models import F
 from django.db.models import F
+from django.db import transaction
 from django.test import TestCase
 from django.test import TestCase
 from django.utils import six
 from django.utils import six
 
 
@@ -185,11 +186,11 @@ class ExpressionsTests(TestCase):
             "foo",
             "foo",
         )
         )
 
 
-        self.assertRaises(FieldError,
-            lambda: Company.objects.exclude(
-                ceo__firstname=F('point_of_contact__firstname')
-            ).update(name=F('point_of_contact__lastname'))
-        )
+        with transaction.atomic():
+            with self.assertRaises(FieldError):
+                Company.objects.exclude(
+                    ceo__firstname=F('point_of_contact__firstname')
+                ).update(name=F('point_of_contact__lastname'))
 
 
         # F expressions can be used to update attributes on single objects
         # F expressions can be used to update attributes on single objects
         test_gmbh = Company.objects.get(name="Test GmbH")
         test_gmbh = Company.objects.get(name="Test GmbH")

+ 11 - 6
tests/force_insert_update/tests.py

@@ -21,24 +21,29 @@ class ForceTests(TestCase):
         # Won't work because force_update and force_insert are mutually
         # Won't work because force_update and force_insert are mutually
         # exclusive
         # exclusive
         c.value = 4
         c.value = 4
-        self.assertRaises(ValueError, c.save, force_insert=True, force_update=True)
+        with self.assertRaises(ValueError):
+            c.save(force_insert=True, force_update=True)
 
 
         # Try to update something that doesn't have a primary key in the first
         # Try to update something that doesn't have a primary key in the first
         # place.
         # place.
         c1 = Counter(name="two", value=2)
         c1 = Counter(name="two", value=2)
-        self.assertRaises(ValueError, c1.save, force_update=True)
+        with self.assertRaises(ValueError):
+            with transaction.atomic():
+                c1.save(force_update=True)
         c1.save(force_insert=True)
         c1.save(force_insert=True)
 
 
         # Won't work because we can't insert a pk of the same value.
         # Won't work because we can't insert a pk of the same value.
-        sid = transaction.savepoint()
         c.value = 5
         c.value = 5
-        self.assertRaises(IntegrityError, c.save, force_insert=True)
-        transaction.savepoint_rollback(sid)
+        with self.assertRaises(IntegrityError):
+            with transaction.atomic():
+                c.save(force_insert=True)
 
 
         # Trying to update should still fail, even with manual primary keys, if
         # Trying to update should still fail, even with manual primary keys, if
         # the data isn't in the database already.
         # the data isn't in the database already.
         obj = WithCustomPK(name=1, value=1)
         obj = WithCustomPK(name=1, value=1)
-        self.assertRaises(DatabaseError, obj.save, force_update=True)
+        with self.assertRaises(DatabaseError):
+            with transaction.atomic():
+                obj.save(force_update=True)
 
 
 
 
 class InheritanceTests(TestCase):
 class InheritanceTests(TestCase):

+ 3 - 3
tests/one_to_one/tests.py

@@ -118,7 +118,7 @@ class OneToOneTests(TestCase):
         self.assertEqual(repr(o1.multimodel), '<MultiModel: Multimodel x1>')
         self.assertEqual(repr(o1.multimodel), '<MultiModel: Multimodel x1>')
         # This will fail because each one-to-one field must be unique (and
         # This will fail because each one-to-one field must be unique (and
         # link2=o1 was used for x1, above).
         # link2=o1 was used for x1, above).
-        sid = transaction.savepoint()
         mm = MultiModel(link1=self.p2, link2=o1, name="x1")
         mm = MultiModel(link1=self.p2, link2=o1, name="x1")
-        self.assertRaises(IntegrityError, mm.save)
-        transaction.savepoint_rollback(sid)
+        with self.assertRaises(IntegrityError):
+            with transaction.atomic():
+                mm.save()

+ 44 - 27
tests/transactions/tests.py

@@ -4,7 +4,7 @@ import sys
 from unittest import skipIf, skipUnless
 from unittest import skipIf, skipUnless
 
 
 from django.db import connection, transaction, DatabaseError, IntegrityError
 from django.db import connection, transaction, DatabaseError, IntegrityError
-from django.test import TransactionTestCase, skipUnlessDBFeature
+from django.test import TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature
 from django.test.utils import IgnoreDeprecationWarningsMixin
 from django.test.utils import IgnoreDeprecationWarningsMixin
 from django.utils import six
 from django.utils import six
 
 
@@ -204,10 +204,10 @@ class AtomicTests(TransactionTestCase):
                 with transaction.atomic(savepoint=False):
                 with transaction.atomic(savepoint=False):
                     connection.cursor().execute(
                     connection.cursor().execute(
                             "SELECT no_such_col FROM transactions_reporter")
                             "SELECT no_such_col FROM transactions_reporter")
-            transaction.savepoint_rollback(sid)
-            # atomic block should rollback, but prevent it, as we just did it.
+            # prevent atomic from rolling back since we're recovering manually
             self.assertTrue(transaction.get_rollback())
             self.assertTrue(transaction.get_rollback())
             transaction.set_rollback(False)
             transaction.set_rollback(False)
+            transaction.savepoint_rollback(sid)
         self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>'])
         self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>'])
 
 
 
 
@@ -267,11 +267,19 @@ class AtomicMergeTests(TransactionTestCase):
                     with transaction.atomic(savepoint=False):
                     with transaction.atomic(savepoint=False):
                         Reporter.objects.create(first_name="Calculus")
                         Reporter.objects.create(first_name="Calculus")
                         raise Exception("Oops, that's his last name")
                         raise Exception("Oops, that's his last name")
-                # It wasn't possible to roll back
+                # The third insert couldn't be roll back. Temporarily mark the
+                # connection as not needing rollback to check it.
+                self.assertTrue(transaction.get_rollback())
+                transaction.set_rollback(False)
                 self.assertEqual(Reporter.objects.count(), 3)
                 self.assertEqual(Reporter.objects.count(), 3)
-            # It wasn't possible to roll back
+                transaction.set_rollback(True)
+            # The second insert couldn't be roll back. Temporarily mark the
+            # connection as not needing rollback to check it.
+            self.assertTrue(transaction.get_rollback())
+            transaction.set_rollback(False)
             self.assertEqual(Reporter.objects.count(), 3)
             self.assertEqual(Reporter.objects.count(), 3)
-        # The outer block must roll back
+            transaction.set_rollback(True)
+        # The first block has a savepoint and must roll back.
         self.assertQuerysetEqual(Reporter.objects.all(), [])
         self.assertQuerysetEqual(Reporter.objects.all(), [])
 
 
     def test_merged_inner_savepoint_rollback(self):
     def test_merged_inner_savepoint_rollback(self):
@@ -283,36 +291,22 @@ class AtomicMergeTests(TransactionTestCase):
                     with transaction.atomic(savepoint=False):
                     with transaction.atomic(savepoint=False):
                         Reporter.objects.create(first_name="Calculus")
                         Reporter.objects.create(first_name="Calculus")
                         raise Exception("Oops, that's his last name")
                         raise Exception("Oops, that's his last name")
-                # It wasn't possible to roll back
+                # The third insert couldn't be roll back. Temporarily mark the
+                # connection as not needing rollback to check it.
+                self.assertTrue(transaction.get_rollback())
+                transaction.set_rollback(False)
                 self.assertEqual(Reporter.objects.count(), 3)
                 self.assertEqual(Reporter.objects.count(), 3)
-            # The first block with a savepoint must roll back
+                transaction.set_rollback(True)
+            # The second block has a savepoint and must roll back.
             self.assertEqual(Reporter.objects.count(), 1)
             self.assertEqual(Reporter.objects.count(), 1)
         self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>'])
         self.assertQuerysetEqual(Reporter.objects.all(), ['<Reporter: Tintin>'])
 
 
-    def test_merged_outer_rollback_after_inner_failure_and_inner_success(self):
-        with transaction.atomic():
-            Reporter.objects.create(first_name="Tintin")
-            # Inner block without a savepoint fails
-            with six.assertRaisesRegex(self, Exception, "Oops"):
-                with transaction.atomic(savepoint=False):
-                    Reporter.objects.create(first_name="Haddock")
-                    raise Exception("Oops, that's his last name")
-            # It wasn't possible to roll back
-            self.assertEqual(Reporter.objects.count(), 2)
-            # Inner block with a savepoint succeeds
-            with transaction.atomic(savepoint=False):
-                Reporter.objects.create(first_name="Archibald", last_name="Haddock")
-            # It still wasn't possible to roll back
-            self.assertEqual(Reporter.objects.count(), 3)
-        # The outer block must rollback
-        self.assertQuerysetEqual(Reporter.objects.all(), [])
-
 
 
 @skipUnless(connection.features.uses_savepoints,
 @skipUnless(connection.features.uses_savepoints,
         "'atomic' requires transactions and savepoints.")
         "'atomic' requires transactions and savepoints.")
 class AtomicErrorsTests(TransactionTestCase):
 class AtomicErrorsTests(TransactionTestCase):
 
 
-    available_apps = []
+    available_apps = ['transactions']
 
 
     def test_atomic_prevents_setting_autocommit(self):
     def test_atomic_prevents_setting_autocommit(self):
         autocommit = transaction.get_autocommit()
         autocommit = transaction.get_autocommit()
@@ -336,6 +330,29 @@ class AtomicErrorsTests(TransactionTestCase):
             with self.assertRaises(transaction.TransactionManagementError):
             with self.assertRaises(transaction.TransactionManagementError):
                 transaction.leave_transaction_management()
                 transaction.leave_transaction_management()
 
 
+    def test_atomic_prevents_queries_in_broken_transaction(self):
+        r1 = Reporter.objects.create(first_name="Archibald", last_name="Haddock")
+        with transaction.atomic():
+            r2 = Reporter(first_name="Cuthbert", last_name="Calculus", id=r1.id)
+            with self.assertRaises(IntegrityError):
+                r2.save(force_insert=True)
+            # The transaction is marked as needing rollback.
+            with self.assertRaises(transaction.TransactionManagementError):
+                r2.save(force_update=True)
+        self.assertEqual(Reporter.objects.get(pk=r1.pk).last_name, "Haddock")
+
+    @skipIfDBFeature('atomic_transactions')
+    def test_atomic_allows_queries_after_fixing_transaction(self):
+        r1 = Reporter.objects.create(first_name="Archibald", last_name="Haddock")
+        with transaction.atomic():
+            r2 = Reporter(first_name="Cuthbert", last_name="Calculus", id=r1.id)
+            with self.assertRaises(IntegrityError):
+                r2.save(force_insert=True)
+            # Mark the transaction as no longer needing rollback.
+            transaction.set_rollback(False)
+            r2.save(force_update=True)
+        self.assertEqual(Reporter.objects.get(pk=r1.pk).last_name, "Calculus")
+
 
 
 class AtomicMiscTests(TransactionTestCase):
 class AtomicMiscTests(TransactionTestCase):