Browse Source

Fixed #30382 -- Allowed specifying parent classes in force_insert of Model.save().

Akash Kumar Sen 1 year ago
parent
commit
a40b0103bc

+ 30 - 3
django/db/models/base.py

@@ -832,6 +832,26 @@ class Model(AltersData, metaclass=ModelBase):
 
     asave.alters_data = True
 
+    @classmethod
+    def _validate_force_insert(cls, force_insert):
+        if force_insert is False:
+            return ()
+        if force_insert is True:
+            return (cls,)
+        if not isinstance(force_insert, tuple):
+            raise TypeError("force_insert must be a bool or tuple.")
+        for member in force_insert:
+            if not isinstance(member, ModelBase):
+                raise TypeError(
+                    f"Invalid force_insert member. {member!r} must be a model subclass."
+                )
+            if not issubclass(cls, member):
+                raise TypeError(
+                    f"Invalid force_insert member. {member.__qualname__} must be a "
+                    f"base of {cls.__qualname__}."
+                )
+        return force_insert
+
     def save_base(
         self,
         raw=False,
@@ -873,7 +893,11 @@ class Model(AltersData, metaclass=ModelBase):
         with context_manager:
             parent_inserted = False
             if not raw:
-                parent_inserted = self._save_parents(cls, using, update_fields)
+                # Validate force insert only when parents are inserted.
+                force_insert = self._validate_force_insert(force_insert)
+                parent_inserted = self._save_parents(
+                    cls, using, update_fields, force_insert
+                )
             updated = self._save_table(
                 raw,
                 cls,
@@ -900,7 +924,9 @@ class Model(AltersData, metaclass=ModelBase):
 
     save_base.alters_data = True
 
-    def _save_parents(self, cls, using, update_fields, updated_parents=None):
+    def _save_parents(
+        self, cls, using, update_fields, force_insert, updated_parents=None
+    ):
         """Save all the parents of cls using values from self."""
         meta = cls._meta
         inserted = False
@@ -919,13 +945,14 @@ class Model(AltersData, metaclass=ModelBase):
                     cls=parent,
                     using=using,
                     update_fields=update_fields,
+                    force_insert=force_insert,
                     updated_parents=updated_parents,
                 )
                 updated = self._save_table(
                     cls=parent,
                     using=using,
                     update_fields=update_fields,
-                    force_insert=parent_inserted,
+                    force_insert=parent_inserted or issubclass(parent, force_insert),
                 )
                 if not updated:
                     inserted = True

+ 17 - 0
docs/ref/models/instances.txt

@@ -589,6 +589,18 @@ row. In these cases you can pass the ``force_insert=True`` or
 Passing both parameters is an error: you cannot both insert *and* update at the
 same time!
 
+When using :ref:`multi-table inheritance <multi-table-inheritance>`, it's also
+possible to provide a tuple of parent classes to ``force_insert`` in order to
+force ``INSERT`` statements for each base. For example::
+
+    Restaurant(pk=1, name="Bob's Cafe").save(force_insert=(Place,))
+
+    Restaurant(pk=1, name="Bob's Cafe", rating=4).save(force_insert=(Place, Rating))
+
+You can pass ``force_insert=(models.Model,)`` to force an ``INSERT`` statement
+for all parents. By default, ``force_insert=True`` only forces the insertion of
+a new row for the current model.
+
 It should be very rare that you'll need to use these parameters. Django will
 almost always do the right thing and trying to override that will lead to
 errors that are difficult to track down. This feature is for advanced use
@@ -596,6 +608,11 @@ only.
 
 Using ``update_fields`` will force an update similarly to ``force_update``.
 
+.. versionchanged:: 5.0
+
+    Support for passing a tuple of parent classes to ``force_insert`` was
+    added.
+
 .. _ref-models-field-updates-using-f-expressions:
 
 Updating attributes based on existing fields

+ 4 - 0
docs/releases/5.0.txt

@@ -335,6 +335,10 @@ Models
   :ref:`Choices classes <field-choices-enum-types>` directly instead of
   requiring expansion with the ``choices`` attribute.
 
+* The :ref:`force_insert <ref-models-force-insert>` argument of
+  :meth:`.Model.save` now allows specifying a tuple of parent classes that must
+  be forced to be inserted.
+
 Pagination
 ~~~~~~~~~~
 

+ 1 - 1
tests/extra_regress/models.py

@@ -10,7 +10,7 @@ class RevisionableModel(models.Model):
     title = models.CharField(blank=True, max_length=255)
     when = models.DateTimeField(default=datetime.datetime.now)
 
-    def save(self, *args, force_insert=None, force_update=None, **kwargs):
+    def save(self, *args, force_insert=False, force_update=False, **kwargs):
         super().save(
             *args, force_insert=force_insert, force_update=force_update, **kwargs
         )

+ 10 - 0
tests/force_insert_update/models.py

@@ -30,3 +30,13 @@ class SubSubCounter(SubCounter):
 class WithCustomPK(models.Model):
     name = models.IntegerField(primary_key=True)
     value = models.IntegerField()
+
+
+class OtherSubCounter(Counter):
+    other_counter_ptr = models.OneToOneField(
+        Counter, primary_key=True, parent_link=True, on_delete=models.CASCADE
+    )
+
+
+class DiamondSubSubCounter(SubCounter, OtherSubCounter):
+    pass

+ 76 - 1
tests/force_insert_update/tests.py

@@ -1,9 +1,11 @@
-from django.db import DatabaseError, IntegrityError, transaction
+from django.db import DatabaseError, IntegrityError, models, transaction
 from django.test import TestCase
 
 from .models import (
     Counter,
+    DiamondSubSubCounter,
     InheritedCounter,
+    OtherSubCounter,
     ProxyCounter,
     SubCounter,
     SubSubCounter,
@@ -76,6 +78,29 @@ class InheritanceTests(TestCase):
 
 
 class ForceInsertInheritanceTests(TestCase):
+    def test_force_insert_not_bool_or_tuple(self):
+        msg = "force_insert must be a bool or tuple."
+        with self.assertRaisesMessage(TypeError, msg), transaction.atomic():
+            Counter().save(force_insert=1)
+        with self.assertRaisesMessage(TypeError, msg), transaction.atomic():
+            Counter().save(force_insert="test")
+        with self.assertRaisesMessage(TypeError, msg), transaction.atomic():
+            Counter().save(force_insert=[])
+
+    def test_force_insert_not_model(self):
+        msg = f"Invalid force_insert member. {object!r} must be a model subclass."
+        with self.assertRaisesMessage(TypeError, msg), transaction.atomic():
+            Counter().save(force_insert=(object,))
+        instance = Counter()
+        msg = f"Invalid force_insert member. {instance!r} must be a model subclass."
+        with self.assertRaisesMessage(TypeError, msg), transaction.atomic():
+            Counter().save(force_insert=(instance,))
+
+    def test_force_insert_not_base(self):
+        msg = "Invalid force_insert member. SubCounter must be a base of Counter."
+        with self.assertRaisesMessage(TypeError, msg):
+            Counter().save(force_insert=(SubCounter,))
+
     def test_force_insert_false(self):
         with self.assertNumQueries(3):
             obj = SubCounter.objects.create(pk=1, value=0)
@@ -87,6 +112,10 @@ class ForceInsertInheritanceTests(TestCase):
             SubCounter(pk=obj.pk, value=2).save(force_insert=False)
         obj.refresh_from_db()
         self.assertEqual(obj.value, 2)
+        with self.assertNumQueries(2):
+            SubCounter(pk=obj.pk, value=3).save(force_insert=())
+        obj.refresh_from_db()
+        self.assertEqual(obj.value, 3)
 
     def test_force_insert_false_with_existing_parent(self):
         parent = Counter.objects.create(pk=1, value=1)
@@ -96,13 +125,59 @@ class ForceInsertInheritanceTests(TestCase):
     def test_force_insert_parent(self):
         with self.assertNumQueries(3):
             SubCounter(pk=1, value=1).save(force_insert=True)
+        # Force insert a new parent and don't UPDATE first.
+        with self.assertNumQueries(2):
+            SubCounter(pk=2, value=1).save(force_insert=(Counter,))
+        with self.assertNumQueries(2):
+            SubCounter(pk=3, value=1).save(force_insert=(models.Model,))
 
     def test_force_insert_with_grandparent(self):
         with self.assertNumQueries(4):
             SubSubCounter(pk=1, value=1).save(force_insert=True)
+        # Force insert parents on all levels and don't UPDATE first.
+        with self.assertNumQueries(3):
+            SubSubCounter(pk=2, value=1).save(force_insert=(models.Model,))
+        with self.assertNumQueries(3):
+            SubSubCounter(pk=3, value=1).save(force_insert=(Counter,))
+        # Force insert only the last parent.
+        with self.assertNumQueries(4):
+            SubSubCounter(pk=4, value=1).save(force_insert=(SubCounter,))
 
     def test_force_insert_with_existing_grandparent(self):
         # Force insert only the last child.
         grandparent = Counter.objects.create(pk=1, value=1)
         with self.assertNumQueries(4):
             SubSubCounter(pk=grandparent.pk, value=1).save(force_insert=True)
+        # Force insert a parent, and don't force insert a grandparent.
+        grandparent = Counter.objects.create(pk=2, value=1)
+        with self.assertNumQueries(3):
+            SubSubCounter(pk=grandparent.pk, value=1).save(force_insert=(SubCounter,))
+        # Force insert parents on all levels, grandparent conflicts.
+        grandparent = Counter.objects.create(pk=3, value=1)
+        with self.assertRaises(IntegrityError), transaction.atomic():
+            SubSubCounter(pk=grandparent.pk, value=1).save(force_insert=(Counter,))
+
+    def test_force_insert_diamond_mti(self):
+        # Force insert all parents.
+        with self.assertNumQueries(4):
+            DiamondSubSubCounter(pk=1, value=1).save(
+                force_insert=(Counter, SubCounter, OtherSubCounter)
+            )
+        with self.assertNumQueries(4):
+            DiamondSubSubCounter(pk=2, value=1).save(force_insert=(models.Model,))
+        # Force insert parents, and don't force insert a common grandparent.
+        with self.assertNumQueries(5):
+            DiamondSubSubCounter(pk=3, value=1).save(
+                force_insert=(SubCounter, OtherSubCounter)
+            )
+        grandparent = Counter.objects.create(pk=4, value=1)
+        with self.assertNumQueries(4):
+            DiamondSubSubCounter(pk=grandparent.pk, value=1).save(
+                force_insert=(SubCounter, OtherSubCounter),
+            )
+        # Force insert all parents, grandparent conflicts.
+        grandparent = Counter.objects.create(pk=5, value=1)
+        with self.assertRaises(IntegrityError), transaction.atomic():
+            DiamondSubSubCounter(pk=grandparent.pk, value=1).save(
+                force_insert=(models.Model,)
+            )