Browse Source

Fixed #30457 -- Added TestCase.captureOnCommitCallbacks().

Adam Johnson 4 years ago
parent
commit
e906ff6fca

+ 15 - 0
django/test/testcases.py

@@ -1229,6 +1229,21 @@ class TestCase(TransactionTestCase):
             not connection.needs_rollback and connection.is_usable()
             not connection.needs_rollback and connection.is_usable()
         )
         )
 
 
+    @classmethod
+    @contextmanager
+    def captureOnCommitCallbacks(cls, *, using=DEFAULT_DB_ALIAS, execute=False):
+        """Context manager to capture transaction.on_commit() callbacks."""
+        callbacks = []
+        start_count = len(connections[using].run_on_commit)
+        try:
+            yield callbacks
+        finally:
+            run_on_commit = connections[using].run_on_commit[start_count:]
+            callbacks[:] = [func for sids, func in run_on_commit]
+            if execute:
+                for callback in callbacks:
+                    callback()
+
 
 
 class CheckCondition:
 class CheckCondition:
     """Descriptor class for deferred condition checking."""
     """Descriptor class for deferred condition checking."""

+ 5 - 0
docs/releases/3.2.txt

@@ -276,6 +276,11 @@ Tests
 * :class:`~django.test.Client` now preserves the request query string when
 * :class:`~django.test.Client` now preserves the request query string when
   following 307 and 308 redirects.
   following 307 and 308 redirects.
 
 
+* The new :meth:`.TestCase.captureOnCommitCallbacks` method captures callback
+  functions passed to :func:`transaction.on_commit()
+  <django.db.transaction.on_commit>` in a list. This allows you to test such
+  callbacks without using the slower :class:`.TransactionTestCase`.
+
 URLs
 URLs
 ~~~~
 ~~~~
 
 

+ 13 - 3
docs/topics/db/transactions.txt

@@ -394,9 +394,19 @@ Use in tests
 Django's :class:`~django.test.TestCase` class wraps each test in a transaction
 Django's :class:`~django.test.TestCase` class wraps each test in a transaction
 and rolls back that transaction after each test, in order to provide test
 and rolls back that transaction after each test, in order to provide test
 isolation. This means that no transaction is ever actually committed, thus your
 isolation. This means that no transaction is ever actually committed, thus your
-:func:`on_commit` callbacks will never be run. If you need to test the results
+:func:`on_commit` callbacks will never be run.
-of an :func:`on_commit` callback, use a
+
-:class:`~django.test.TransactionTestCase` instead.
+You can overcome this limitation by using
+:meth:`.TestCase.captureOnCommitCallbacks`. This captures your
+:func:`on_commit` callbacks in a list, allowing you to make assertions on them,
+or emulate the transaction committing by calling them.
+
+Another way to overcome the limitation is to use
+:class:`~django.test.TransactionTestCase` instead of
+:class:`~django.test.TestCase`. This will mean your transactions are committed,
+and the callbacks will run. However
+:class:`~django.test.TransactionTestCase` flushes the database between tests,
+which is significantly slower than :class:`~django.test.TestCase`\'s isolation.
 
 
 Why no rollback hook?
 Why no rollback hook?
 ---------------------
 ---------------------

+ 36 - 0
docs/topics/testing/tools.txt

@@ -881,6 +881,42 @@ It also provides an additional method:
         previous versions of Django these objects were reused and changes made
         previous versions of Django these objects were reused and changes made
         to them were persisted between test methods.
         to them were persisted between test methods.
 
 
+.. classmethod:: TestCase.captureOnCommitCallbacks(using=DEFAULT_DB_ALIAS, execute=False)
+
+    .. versionadded:: 3.2
+
+    Returns a context manager that captures :func:`transaction.on_commit()
+    <django.db.transaction.on_commit>` callbacks for the given database
+    connection. It returns a list that contains, on exit of the context, the
+    captured callback functions. From this list you can make assertions on the
+    callbacks or call them to invoke their side effects, emulating a commit.
+
+    ``using`` is the alias of the database connection to capture callbacks for.
+
+    If ``execute`` is ``True``, all the callbacks will be called as the context
+    manager exits, if no exception occurred. This emulates a commit after the
+    wrapped block of code.
+
+    For example::
+
+        from django.core import mail
+        from django.test import TestCase
+
+
+        class ContactTests(TestCase):
+            def test_post(self):
+                with self.captureOnCommitCallbacks(execute=True) as callbacks:
+                    response = self.client.post(
+                        '/contact/',
+                        {'message': 'I like your site'},
+                    )
+
+                self.assertEqual(response.status_code, 200)
+                self.assertEqual(len(callbacks), 1)
+                self.assertEqual(len(mail.outbox), 1)
+                self.assertEqual(mail.outbox[0].subject, 'Contact Form')
+                self.assertEqual(mail.outbox[0].body, 'I like your site')
+
 .. _live-test-server:
 .. _live-test-server:
 
 
 ``LiveServerTestCase``
 ``LiveServerTestCase``

+ 68 - 1
tests/test_utils/tests.py

@@ -9,7 +9,9 @@ from django.contrib.staticfiles.finders import get_finder, get_finders
 from django.contrib.staticfiles.storage import staticfiles_storage
 from django.contrib.staticfiles.storage import staticfiles_storage
 from django.core.exceptions import ImproperlyConfigured
 from django.core.exceptions import ImproperlyConfigured
 from django.core.files.storage import default_storage
 from django.core.files.storage import default_storage
-from django.db import connection, connections, models, router
+from django.db import (
+    IntegrityError, connection, connections, models, router, transaction,
+)
 from django.forms import EmailField, IntegerField
 from django.forms import EmailField, IntegerField
 from django.http import HttpResponse
 from django.http import HttpResponse
 from django.template.loader import render_to_string
 from django.template.loader import render_to_string
@@ -1273,6 +1275,71 @@ class TestBadSetUpTestData(TestCase):
         self.assertFalse(self._in_atomic_block)
         self.assertFalse(self._in_atomic_block)
 
 
 
 
+class CaptureOnCommitCallbacksTests(TestCase):
+    databases = {'default', 'other'}
+    callback_called = False
+
+    def enqueue_callback(self, using='default'):
+        def hook():
+            self.callback_called = True
+
+        transaction.on_commit(hook, using=using)
+
+    def test_no_arguments(self):
+        with self.captureOnCommitCallbacks() as callbacks:
+            self.enqueue_callback()
+
+        self.assertEqual(len(callbacks), 1)
+        self.assertIs(self.callback_called, False)
+        callbacks[0]()
+        self.assertIs(self.callback_called, True)
+
+    def test_using(self):
+        with self.captureOnCommitCallbacks(using='other') as callbacks:
+            self.enqueue_callback(using='other')
+
+        self.assertEqual(len(callbacks), 1)
+        self.assertIs(self.callback_called, False)
+        callbacks[0]()
+        self.assertIs(self.callback_called, True)
+
+    def test_different_using(self):
+        with self.captureOnCommitCallbacks(using='default') as callbacks:
+            self.enqueue_callback(using='other')
+
+        self.assertEqual(callbacks, [])
+
+    def test_execute(self):
+        with self.captureOnCommitCallbacks(execute=True) as callbacks:
+            self.enqueue_callback()
+
+        self.assertEqual(len(callbacks), 1)
+        self.assertIs(self.callback_called, True)
+
+    def test_pre_callback(self):
+        def pre_hook():
+            pass
+
+        transaction.on_commit(pre_hook, using='default')
+        with self.captureOnCommitCallbacks() as callbacks:
+            self.enqueue_callback()
+
+        self.assertEqual(len(callbacks), 1)
+        self.assertNotEqual(callbacks[0], pre_hook)
+
+    def test_with_rolled_back_savepoint(self):
+        with self.captureOnCommitCallbacks() as callbacks:
+            try:
+                with transaction.atomic():
+                    self.enqueue_callback()
+                    raise IntegrityError
+            except IntegrityError:
+                # Inner transaction.atomic() has been rolled back.
+                pass
+
+        self.assertEqual(callbacks, [])
+
+
 class DisallowedDatabaseQueriesTests(SimpleTestCase):
 class DisallowedDatabaseQueriesTests(SimpleTestCase):
     def test_disallowed_database_connections(self):
     def test_disallowed_database_connections(self):
         expected_message = (
         expected_message = (