Browse Source

Fixed #30457 -- Added TestCase.captureOnCommitCallbacks().

Adam Johnson 4 years ago
parent
commit
e906ff6fca

+ 15 - 0
django/test/testcases.py

@@ -1229,6 +1229,21 @@ class TestCase(TransactionTestCase):
             not connection.needs_rollback and connection.is_usable()
         )
 
+    @classmethod
+    @contextmanager
+    def captureOnCommitCallbacks(cls, *, using=DEFAULT_DB_ALIAS, execute=False):
+        """Context manager to capture transaction.on_commit() callbacks."""
+        callbacks = []
+        start_count = len(connections[using].run_on_commit)
+        try:
+            yield callbacks
+        finally:
+            run_on_commit = connections[using].run_on_commit[start_count:]
+            callbacks[:] = [func for sids, func in run_on_commit]
+            if execute:
+                for callback in callbacks:
+                    callback()
+
 
 class CheckCondition:
     """Descriptor class for deferred condition checking."""

+ 5 - 0
docs/releases/3.2.txt

@@ -276,6 +276,11 @@ Tests
 * :class:`~django.test.Client` now preserves the request query string when
   following 307 and 308 redirects.
 
+* The new :meth:`.TestCase.captureOnCommitCallbacks` method captures callback
+  functions passed to :func:`transaction.on_commit()
+  <django.db.transaction.on_commit>` in a list. This allows you to test such
+  callbacks without using the slower :class:`.TransactionTestCase`.
+
 URLs
 ~~~~
 

+ 13 - 3
docs/topics/db/transactions.txt

@@ -394,9 +394,19 @@ Use in tests
 Django's :class:`~django.test.TestCase` class wraps each test in a transaction
 and rolls back that transaction after each test, in order to provide test
 isolation. This means that no transaction is ever actually committed, thus your
-:func:`on_commit` callbacks will never be run. If you need to test the results
-of an :func:`on_commit` callback, use a
-:class:`~django.test.TransactionTestCase` instead.
+:func:`on_commit` callbacks will never be run.
+
+You can overcome this limitation by using
+:meth:`.TestCase.captureOnCommitCallbacks`. This captures your
+:func:`on_commit` callbacks in a list, allowing you to make assertions on them,
+or emulate the transaction committing by calling them.
+
+Another way to overcome the limitation is to use
+:class:`~django.test.TransactionTestCase` instead of
+:class:`~django.test.TestCase`. This will mean your transactions are committed,
+and the callbacks will run. However
+:class:`~django.test.TransactionTestCase` flushes the database between tests,
+which is significantly slower than :class:`~django.test.TestCase`\'s isolation.
 
 Why no rollback hook?
 ---------------------

+ 36 - 0
docs/topics/testing/tools.txt

@@ -881,6 +881,42 @@ It also provides an additional method:
         previous versions of Django these objects were reused and changes made
         to them were persisted between test methods.
 
+.. classmethod:: TestCase.captureOnCommitCallbacks(using=DEFAULT_DB_ALIAS, execute=False)
+
+    .. versionadded:: 3.2
+
+    Returns a context manager that captures :func:`transaction.on_commit()
+    <django.db.transaction.on_commit>` callbacks for the given database
+    connection. It returns a list that contains, on exit of the context, the
+    captured callback functions. From this list you can make assertions on the
+    callbacks or call them to invoke their side effects, emulating a commit.
+
+    ``using`` is the alias of the database connection to capture callbacks for.
+
+    If ``execute`` is ``True``, all the callbacks will be called as the context
+    manager exits, if no exception occurred. This emulates a commit after the
+    wrapped block of code.
+
+    For example::
+
+        from django.core import mail
+        from django.test import TestCase
+
+
+        class ContactTests(TestCase):
+            def test_post(self):
+                with self.captureOnCommitCallbacks(execute=True) as callbacks:
+                    response = self.client.post(
+                        '/contact/',
+                        {'message': 'I like your site'},
+                    )
+
+                self.assertEqual(response.status_code, 200)
+                self.assertEqual(len(callbacks), 1)
+                self.assertEqual(len(mail.outbox), 1)
+                self.assertEqual(mail.outbox[0].subject, 'Contact Form')
+                self.assertEqual(mail.outbox[0].body, 'I like your site')
+
 .. _live-test-server:
 
 ``LiveServerTestCase``

+ 68 - 1
tests/test_utils/tests.py

@@ -9,7 +9,9 @@ from django.contrib.staticfiles.finders import get_finder, get_finders
 from django.contrib.staticfiles.storage import staticfiles_storage
 from django.core.exceptions import ImproperlyConfigured
 from django.core.files.storage import default_storage
-from django.db import connection, connections, models, router
+from django.db import (
+    IntegrityError, connection, connections, models, router, transaction,
+)
 from django.forms import EmailField, IntegerField
 from django.http import HttpResponse
 from django.template.loader import render_to_string
@@ -1273,6 +1275,71 @@ class TestBadSetUpTestData(TestCase):
         self.assertFalse(self._in_atomic_block)
 
 
+class CaptureOnCommitCallbacksTests(TestCase):
+    databases = {'default', 'other'}
+    callback_called = False
+
+    def enqueue_callback(self, using='default'):
+        def hook():
+            self.callback_called = True
+
+        transaction.on_commit(hook, using=using)
+
+    def test_no_arguments(self):
+        with self.captureOnCommitCallbacks() as callbacks:
+            self.enqueue_callback()
+
+        self.assertEqual(len(callbacks), 1)
+        self.assertIs(self.callback_called, False)
+        callbacks[0]()
+        self.assertIs(self.callback_called, True)
+
+    def test_using(self):
+        with self.captureOnCommitCallbacks(using='other') as callbacks:
+            self.enqueue_callback(using='other')
+
+        self.assertEqual(len(callbacks), 1)
+        self.assertIs(self.callback_called, False)
+        callbacks[0]()
+        self.assertIs(self.callback_called, True)
+
+    def test_different_using(self):
+        with self.captureOnCommitCallbacks(using='default') as callbacks:
+            self.enqueue_callback(using='other')
+
+        self.assertEqual(callbacks, [])
+
+    def test_execute(self):
+        with self.captureOnCommitCallbacks(execute=True) as callbacks:
+            self.enqueue_callback()
+
+        self.assertEqual(len(callbacks), 1)
+        self.assertIs(self.callback_called, True)
+
+    def test_pre_callback(self):
+        def pre_hook():
+            pass
+
+        transaction.on_commit(pre_hook, using='default')
+        with self.captureOnCommitCallbacks() as callbacks:
+            self.enqueue_callback()
+
+        self.assertEqual(len(callbacks), 1)
+        self.assertNotEqual(callbacks[0], pre_hook)
+
+    def test_with_rolled_back_savepoint(self):
+        with self.captureOnCommitCallbacks() as callbacks:
+            try:
+                with transaction.atomic():
+                    self.enqueue_callback()
+                    raise IntegrityError
+            except IntegrityError:
+                # Inner transaction.atomic() has been rolled back.
+                pass
+
+        self.assertEqual(callbacks, [])
+
+
 class DisallowedDatabaseQueriesTests(SimpleTestCase):
     def test_disallowed_database_connections(self):
         expected_message = (