|
@@ -9,7 +9,9 @@ from django.contrib.staticfiles.finders import get_finder, get_finders
|
|
|
from django.contrib.staticfiles.storage import staticfiles_storage
|
|
|
from django.core.exceptions import ImproperlyConfigured
|
|
|
from django.core.files.storage import default_storage
|
|
|
-from django.db import connection, connections, models, router
|
|
|
+from django.db import (
|
|
|
+ IntegrityError, connection, connections, models, router, transaction,
|
|
|
+)
|
|
|
from django.forms import EmailField, IntegerField
|
|
|
from django.http import HttpResponse
|
|
|
from django.template.loader import render_to_string
|
|
@@ -1273,6 +1275,71 @@ class TestBadSetUpTestData(TestCase):
|
|
|
self.assertFalse(self._in_atomic_block)
|
|
|
|
|
|
|
|
|
+class CaptureOnCommitCallbacksTests(TestCase):
|
|
|
+ databases = {'default', 'other'}
|
|
|
+ callback_called = False
|
|
|
+
|
|
|
+ def enqueue_callback(self, using='default'):
|
|
|
+ def hook():
|
|
|
+ self.callback_called = True
|
|
|
+
|
|
|
+ transaction.on_commit(hook, using=using)
|
|
|
+
|
|
|
+ def test_no_arguments(self):
|
|
|
+ with self.captureOnCommitCallbacks() as callbacks:
|
|
|
+ self.enqueue_callback()
|
|
|
+
|
|
|
+ self.assertEqual(len(callbacks), 1)
|
|
|
+ self.assertIs(self.callback_called, False)
|
|
|
+ callbacks[0]()
|
|
|
+ self.assertIs(self.callback_called, True)
|
|
|
+
|
|
|
+ def test_using(self):
|
|
|
+ with self.captureOnCommitCallbacks(using='other') as callbacks:
|
|
|
+ self.enqueue_callback(using='other')
|
|
|
+
|
|
|
+ self.assertEqual(len(callbacks), 1)
|
|
|
+ self.assertIs(self.callback_called, False)
|
|
|
+ callbacks[0]()
|
|
|
+ self.assertIs(self.callback_called, True)
|
|
|
+
|
|
|
+ def test_different_using(self):
|
|
|
+ with self.captureOnCommitCallbacks(using='default') as callbacks:
|
|
|
+ self.enqueue_callback(using='other')
|
|
|
+
|
|
|
+ self.assertEqual(callbacks, [])
|
|
|
+
|
|
|
+ def test_execute(self):
|
|
|
+ with self.captureOnCommitCallbacks(execute=True) as callbacks:
|
|
|
+ self.enqueue_callback()
|
|
|
+
|
|
|
+ self.assertEqual(len(callbacks), 1)
|
|
|
+ self.assertIs(self.callback_called, True)
|
|
|
+
|
|
|
+ def test_pre_callback(self):
|
|
|
+ def pre_hook():
|
|
|
+ pass
|
|
|
+
|
|
|
+ transaction.on_commit(pre_hook, using='default')
|
|
|
+ with self.captureOnCommitCallbacks() as callbacks:
|
|
|
+ self.enqueue_callback()
|
|
|
+
|
|
|
+ self.assertEqual(len(callbacks), 1)
|
|
|
+ self.assertNotEqual(callbacks[0], pre_hook)
|
|
|
+
|
|
|
+ def test_with_rolled_back_savepoint(self):
|
|
|
+ with self.captureOnCommitCallbacks() as callbacks:
|
|
|
+ try:
|
|
|
+ with transaction.atomic():
|
|
|
+ self.enqueue_callback()
|
|
|
+ raise IntegrityError
|
|
|
+ except IntegrityError:
|
|
|
+ # Inner transaction.atomic() has been rolled back.
|
|
|
+ pass
|
|
|
+
|
|
|
+ self.assertEqual(callbacks, [])
|
|
|
+
|
|
|
+
|
|
|
class DisallowedDatabaseQueriesTests(SimpleTestCase):
|
|
|
def test_disallowed_database_connections(self):
|
|
|
expected_message = (
|