浏览代码

Fixed #28344 -- Allowed customizing queryset in Model.refresh_from_db()/arefresh_from_db().

The from_queryset parameter can be used to:
- use a custom Manager
- lock the row until the end of transaction
- select additional related objects
Aivars Kalvans 1 年之前
父节点
当前提交
f92641a636
共有 5 个文件被更改,包括 111 次插入14 次删除
  1. 18 10
      django/db/models/base.py
  2. 23 2
      docs/ref/models/instances.txt
  3. 5 0
      docs/releases/5.1.txt
  4. 11 0
      tests/async/test_async_model_methods.py
  5. 54 2
      tests/basic/tests.py

+ 18 - 10
django/db/models/base.py

@@ -673,7 +673,7 @@ class Model(AltersData, metaclass=ModelBase):
             if f.attname not in self.__dict__
         }
 
-    def refresh_from_db(self, using=None, fields=None):
+    def refresh_from_db(self, using=None, fields=None, from_queryset=None):
         """
         Reload field values from the database.
 
@@ -705,10 +705,13 @@ class Model(AltersData, metaclass=ModelBase):
                     "are not allowed in fields." % LOOKUP_SEP
                 )
 
-        hints = {"instance": self}
-        db_instance_qs = self.__class__._base_manager.db_manager(
-            using, hints=hints
-        ).filter(pk=self.pk)
+        if from_queryset is None:
+            hints = {"instance": self}
+            from_queryset = self.__class__._base_manager.db_manager(using, hints=hints)
+        elif using is not None:
+            from_queryset = from_queryset.using(using)
+
+        db_instance_qs = from_queryset.filter(pk=self.pk)
 
         # Use provided fields, if not set then reload all non-deferred fields.
         deferred_fields = self.get_deferred_fields()
@@ -729,9 +732,12 @@ class Model(AltersData, metaclass=ModelBase):
                 # This field wasn't refreshed - skip ahead.
                 continue
             setattr(self, field.attname, getattr(db_instance, field.attname))
-            # Clear cached foreign keys.
-            if field.is_relation and field.is_cached(self):
-                field.delete_cached_value(self)
+            # Clear or copy cached foreign keys.
+            if field.is_relation:
+                if field.is_cached(db_instance):
+                    field.set_cached_value(self, field.get_cached_value(db_instance))
+                elif field.is_cached(self):
+                    field.delete_cached_value(self)
 
         # Clear cached relations.
         for field in self._meta.related_objects:
@@ -745,8 +751,10 @@ class Model(AltersData, metaclass=ModelBase):
 
         self._state.db = db_instance._state.db
 
-    async def arefresh_from_db(self, using=None, fields=None):
-        return await sync_to_async(self.refresh_from_db)(using=using, fields=fields)
+    async def arefresh_from_db(self, using=None, fields=None, from_queryset=None):
+        return await sync_to_async(self.refresh_from_db)(
+            using=using, fields=fields, from_queryset=from_queryset
+        )
 
     def serializable_value(self, field_name):
         """

+ 23 - 2
docs/ref/models/instances.txt

@@ -142,8 +142,8 @@ value from the database:
     >>> del obj.field
     >>> obj.field  # Loads the field from the database
 
-.. method:: Model.refresh_from_db(using=None, fields=None)
-.. method:: Model.arefresh_from_db(using=None, fields=None)
+.. method:: Model.refresh_from_db(using=None, fields=None, from_queryset=None)
+.. method:: Model.arefresh_from_db(using=None, fields=None, from_queryset=None)
 
 *Asynchronous version*: ``arefresh_from_db()``
 
@@ -197,6 +197,27 @@ all of the instance's fields when a deferred field is reloaded::
                     fields = fields.union(deferred_fields)
             super().refresh_from_db(using, fields, **kwargs)
 
+The ``from_queryset`` argument allows using a different queryset than the one
+created from :attr:`~django.db.models.Model._base_manager`. It gives you more
+control over how the model is reloaded. For example, when your model uses soft
+deletion you can make ``refresh_from_db()`` to take this into account::
+
+    obj.refresh_from_db(from_queryset=MyModel.active_objects.all())
+
+You can cache related objects that otherwise would be cleared from the reloaded
+instance::
+
+    obj.refresh_from_db(from_queryset=MyModel.objects.select_related("related_field"))
+
+You can lock the row until the end of transaction before reloading a model's
+values::
+
+    obj.refresh_from_db(from_queryset=MyModel.objects.select_for_update())
+
+.. versionchanged:: 5.1
+
+    The ``from_queryset`` argument was added.
+
 .. method:: Model.get_deferred_fields()
 
 A helper method that returns a set containing the attribute names of all those

+ 5 - 0
docs/releases/5.1.txt

@@ -208,6 +208,11 @@ Models
   :class:`~django.contrib.postgres.fields.ArrayField` can now be :ref:`sliced
   <slicing-using-f>`.
 
+* The new ``from_queryset`` argument of :meth:`.Model.refresh_from_db` and
+  :meth:`.Model.arefresh_from_db`  allows customizing the queryset used to
+  reload a model's value. This can be used to lock the row before reloading or
+  to select related objects.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 11 - 0
tests/async/test_async_model_methods.py

@@ -23,3 +23,14 @@ class AsyncModelOperationTest(TestCase):
         await SimpleModel.objects.filter(pk=self.s1.pk).aupdate(field=20)
         await self.s1.arefresh_from_db()
         self.assertEqual(self.s1.field, 20)
+
+    async def test_arefresh_from_db_from_queryset(self):
+        await SimpleModel.objects.filter(pk=self.s1.pk).aupdate(field=20)
+        with self.assertRaises(SimpleModel.DoesNotExist):
+            await self.s1.arefresh_from_db(
+                from_queryset=SimpleModel.objects.filter(field=0)
+            )
+        await self.s1.arefresh_from_db(
+            from_queryset=SimpleModel.objects.filter(field__gt=0)
+        )
+        self.assertEqual(self.s1.field, 20)

+ 54 - 2
tests/basic/tests.py

@@ -4,7 +4,14 @@ from datetime import datetime, timedelta
 from unittest import mock
 
 from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist
-from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models
+from django.db import (
+    DEFAULT_DB_ALIAS,
+    DatabaseError,
+    connection,
+    connections,
+    models,
+    transaction,
+)
 from django.db.models.manager import BaseManager
 from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet
 from django.test import (
@@ -13,7 +20,8 @@ from django.test import (
     TransactionTestCase,
     skipUnlessDBFeature,
 )
-from django.test.utils import ignore_warnings
+from django.test.utils import CaptureQueriesContext, ignore_warnings
+from django.utils.connection import ConnectionDoesNotExist
 from django.utils.deprecation import RemovedInDjango60Warning
 from django.utils.translation import gettext_lazy
 
@@ -1003,3 +1011,47 @@ class ModelRefreshTests(TestCase):
         # Cache was cleared and new results are available.
         self.assertCountEqual(a2_prefetched.selfref_set.all(), [s])
         self.assertCountEqual(a2_prefetched.cited.all(), [s])
+
+    @skipUnlessDBFeature("has_select_for_update")
+    def test_refresh_for_update(self):
+        a = Article.objects.create(pub_date=datetime.now())
+        for_update_sql = connection.ops.for_update_sql()
+
+        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
+            a.refresh_from_db(from_queryset=Article.objects.select_for_update())
+        self.assertTrue(
+            any(for_update_sql in query["sql"] for query in ctx.captured_queries)
+        )
+
+    def test_refresh_with_related(self):
+        a = Article.objects.create(pub_date=datetime.now())
+        fa = FeaturedArticle.objects.create(article=a)
+
+        from_queryset = FeaturedArticle.objects.select_related("article")
+        with self.assertNumQueries(1):
+            fa.refresh_from_db(from_queryset=from_queryset)
+            self.assertEqual(fa.article.pub_date, a.pub_date)
+        with self.assertNumQueries(2):
+            fa.refresh_from_db()
+            self.assertEqual(fa.article.pub_date, a.pub_date)
+
+    def test_refresh_overwrites_queryset_using(self):
+        a = Article.objects.create(pub_date=datetime.now())
+
+        from_queryset = Article.objects.using("nonexistent")
+        with self.assertRaises(ConnectionDoesNotExist):
+            a.refresh_from_db(from_queryset=from_queryset)
+        a.refresh_from_db(using="default", from_queryset=from_queryset)
+
+    def test_refresh_overwrites_queryset_fields(self):
+        a = Article.objects.create(pub_date=datetime.now())
+        headline = "headline"
+        Article.objects.filter(pk=a.pk).update(headline=headline)
+
+        from_queryset = Article.objects.only("pub_date")
+        with self.assertNumQueries(1):
+            a.refresh_from_db(from_queryset=from_queryset)
+            self.assertNotEqual(a.headline, headline)
+        with self.assertNumQueries(1):
+            a.refresh_from_db(fields=["headline"], from_queryset=from_queryset)
+            self.assertEqual(a.headline, headline)