瀏覽代碼

Fixed #33616 -- Allowed registering callbacks that can fail in transaction.on_commit().

Thanks David Wobrock and Mariusz Felisiak for reviews.
SirAbhi13 2 年之前
父節點
當前提交
4a1150b41d

+ 32 - 6
django/db/backends/base/base.py

@@ -1,6 +1,7 @@
 import _thread
 import copy
 import datetime
+import logging
 import threading
 import time
 import warnings
@@ -26,6 +27,8 @@ from django.utils.functional import cached_property
 NO_DB_ALIAS = "__no_db__"
 RAN_DB_VERSION_CHECK = set()
 
+logger = logging.getLogger("django.db.backends.base")
+
 
 # RemovedInDjango50Warning
 def timezone_constructor(tzname):
@@ -417,7 +420,9 @@ class BaseDatabaseWrapper:
 
         # Remove any callbacks registered while this savepoint was active.
         self.run_on_commit = [
-            (sids, func) for (sids, func) in self.run_on_commit if sid not in sids
+            (sids, func, robust)
+            for (sids, func, robust) in self.run_on_commit
+            if sid not in sids
         ]
 
     @async_unsafe
@@ -723,12 +728,12 @@ class BaseDatabaseWrapper:
             )
         return self.SchemaEditorClass(self, *args, **kwargs)
 
-    def on_commit(self, func):
+    def on_commit(self, func, robust=False):
         if not callable(func):
             raise TypeError("on_commit()'s callback must be a callable.")
         if self.in_atomic_block:
             # Transaction in progress; save for execution on commit.
-            self.run_on_commit.append((set(self.savepoint_ids), func))
+            self.run_on_commit.append((set(self.savepoint_ids), func, robust))
         elif not self.get_autocommit():
             raise TransactionManagementError(
                 "on_commit() cannot be used in manual transaction management"
@@ -736,15 +741,36 @@ class BaseDatabaseWrapper:
         else:
             # No transaction in progress and in autocommit mode; execute
             # immediately.
-            func()
+            if robust:
+                try:
+                    func()
+                except Exception as e:
+                    logger.error(
+                        f"Error calling {func.__qualname__} in on_commit() (%s).",
+                        e,
+                        exc_info=True,
+                    )
+            else:
+                func()
 
     def run_and_clear_commit_hooks(self):
         self.validate_no_atomic_block()
         current_run_on_commit = self.run_on_commit
         self.run_on_commit = []
         while current_run_on_commit:
-            sids, func = current_run_on_commit.pop(0)
-            func()
+            _, func, robust = current_run_on_commit.pop(0)
+            if robust:
+                try:
+                    func()
+                except Exception as e:
+                    logger.error(
+                        f"Error calling {func.__qualname__} in on_commit() during "
+                        f"transaction (%s).",
+                        e,
+                        exc_info=True,
+                    )
+            else:
+                func()
 
     @contextmanager
     def execute_wrapper(self, wrapper):

+ 2 - 2
django/db/transaction.py

@@ -125,12 +125,12 @@ def mark_for_rollback_on_error(using=None):
         raise
 
 
-def on_commit(func, using=None):
+def on_commit(func, using=None, robust=False):
     """
     Register `func` to be called when the current transaction is committed.
     If the current transaction is rolled back, `func` will not be called.
     """
-    get_connection(using).on_commit(func)
+    get_connection(using).on_commit(func, robust)
 
 
 #################################

+ 17 - 2
django/test/testcases.py

@@ -59,6 +59,8 @@ from django.utils.functional import classproperty
 from django.utils.version import PY310
 from django.views.static import serve
 
+logger = logging.getLogger("django.test")
+
 __all__ = (
     "TestCase",
     "TransactionTestCase",
@@ -1510,10 +1512,23 @@ class TestCase(TransactionTestCase):
         finally:
             while True:
                 callback_count = len(connections[using].run_on_commit)
-                for _, callback in connections[using].run_on_commit[start_count:]:
+                for _, callback, robust in connections[using].run_on_commit[
+                    start_count:
+                ]:
                     callbacks.append(callback)
                     if execute:
-                        callback()
+                        if robust:
+                            try:
+                                callback()
+                            except Exception as e:
+                                logger.error(
+                                    f"Error calling {callback.__qualname__} in "
+                                    f"on_commit() (%s).",
+                                    e,
+                                    exc_info=True,
+                                )
+                        else:
+                            callback()
 
                 if callback_count == len(connections[using].run_on_commit):
                     break

+ 4 - 0
docs/releases/4.2.txt

@@ -212,6 +212,10 @@ Models
 * :ref:`Registering lookups <lookup-registration-api>` on
   :class:`~django.db.models.Field` instances is now supported.
 
+* The new ``robust`` argument for :func:`~django.db.transaction.on_commit`
+  allows performing actions that can fail after a database transaction is
+  successfully committed.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 18 - 5
docs/topics/db/transactions.txt

@@ -297,7 +297,7 @@ include a `Celery`_ task, an email notification, or a cache invalidation.
 Django provides the :func:`on_commit` function to register callback functions
 that should be executed after a transaction is successfully committed:
 
-.. function:: on_commit(func, using=None)
+.. function:: on_commit(func, using=None, robust=False)
 
 Pass any function (that takes no arguments) to :func:`on_commit`::
 
@@ -325,6 +325,15 @@ If that hypothetical database write is instead rolled back (typically when an
 unhandled exception is raised in an :func:`atomic` block), your function will
 be discarded and never called.
 
+It's sometimes useful to register callback functions that can fail. Passing
+``robust=True`` allows the next functions to be executed even if the current
+function throws an exception. All errors derived from Python's ``Exception``
+class are caught and logged to the ``django.db.backends.base`` logger.
+
+.. versionchanged:: 4.2
+
+    The ``robust`` argument was added.
+
 Savepoints
 ----------
 
@@ -366,10 +375,14 @@ registered.
 Exception handling
 ------------------
 
-If one on-commit function within a given transaction raises an uncaught
-exception, no later registered functions in that same transaction will run.
-This is the same behavior as if you'd executed the functions sequentially
-yourself without :func:`on_commit`.
+If one on-commit function registered with ``robust=False`` within a given
+transaction raises an uncaught exception, no later registered functions in that
+same transaction will run. This is the same behavior as if you'd executed the
+functions sequentially yourself without :func:`on_commit`.
+
+.. versionchanged:: 4.2
+
+    The ``robust`` argument was added.
 
 Timing of execution
 -------------------

+ 26 - 0
tests/test_utils/tests.py

@@ -2285,6 +2285,32 @@ class CaptureOnCommitCallbacksTests(TestCase):
 
         self.assertEqual(callbacks, [branch_1, branch_2, leaf_3, leaf_1, leaf_2])
 
+    def test_execute_robust(self):
+        class MyException(Exception):
+            pass
+
+        def hook():
+            self.callback_called = True
+            raise MyException("robust callback")
+
+        with self.assertLogs("django.test", "ERROR") as cm:
+            with self.captureOnCommitCallbacks(execute=True) as callbacks:
+                transaction.on_commit(hook, robust=True)
+
+        self.assertEqual(len(callbacks), 1)
+        self.assertIs(self.callback_called, True)
+
+        log_record = cm.records[0]
+        self.assertEqual(
+            log_record.getMessage(),
+            "Error calling CaptureOnCommitCallbacksTests.test_execute_robust.<locals>."
+            "hook in on_commit() (robust callback).",
+        )
+        self.assertIsNotNone(log_record.exc_info)
+        raised_exception = log_record.exc_info[1]
+        self.assertIsInstance(raised_exception, MyException)
+        self.assertEqual(str(raised_exception), "robust callback")
+
 
 class DisallowedDatabaseQueriesTests(SimpleTestCase):
     def test_disallowed_database_connections(self):

+ 41 - 0
tests/transaction_hooks/tests.py

@@ -43,6 +43,47 @@ class TestConnectionOnCommit(TransactionTestCase):
         self.do(1)
         self.assertDone([1])
 
+    def test_robust_if_no_transaction(self):
+        def robust_callback():
+            raise ForcedError("robust callback")
+
+        with self.assertLogs("django.db.backends.base", "ERROR") as cm:
+            transaction.on_commit(robust_callback, robust=True)
+            self.do(1)
+
+        self.assertDone([1])
+        log_record = cm.records[0]
+        self.assertEqual(
+            log_record.getMessage(),
+            "Error calling TestConnectionOnCommit.test_robust_if_no_transaction."
+            "<locals>.robust_callback in on_commit() (robust callback).",
+        )
+        self.assertIsNotNone(log_record.exc_info)
+        raised_exception = log_record.exc_info[1]
+        self.assertIsInstance(raised_exception, ForcedError)
+        self.assertEqual(str(raised_exception), "robust callback")
+
+    def test_robust_transaction(self):
+        def robust_callback():
+            raise ForcedError("robust callback")
+
+        with self.assertLogs("django.db.backends", "ERROR") as cm:
+            with transaction.atomic():
+                transaction.on_commit(robust_callback, robust=True)
+                self.do(1)
+
+        self.assertDone([1])
+        log_record = cm.records[0]
+        self.assertEqual(
+            log_record.getMessage(),
+            "Error calling TestConnectionOnCommit.test_robust_transaction.<locals>."
+            "robust_callback in on_commit() during transaction (robust callback).",
+        )
+        self.assertIsNotNone(log_record.exc_info)
+        raised_exception = log_record.exc_info[1]
+        self.assertIsInstance(raised_exception, ForcedError)
+        self.assertEqual(str(raised_exception), "robust callback")
+
     def test_delays_execution_until_after_transaction_commit(self):
         with transaction.atomic():
             self.do(1)