Browse Source

Fixed #34698 -- Made QuerySet.bulk_create() retrieve primary keys when updating conflicts.

Thomas Chaumeny 1 year ago
parent
commit
89c7454dbd
4 changed files with 47 additions and 11 deletions
  1. 6 1
      django/db/models/query.py
  2. 7 3
      docs/ref/models/querysets.txt
  3. 4 0
      docs/releases/5.0.txt
  4. 30 7
      tests/bulk_create/tests.py

+ 6 - 1
django/db/models/query.py

@@ -1837,12 +1837,17 @@ class QuerySet(AltersData):
         inserted_rows = []
         bulk_return = connection.features.can_return_rows_from_bulk_insert
         for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
-            if bulk_return and on_conflict is None:
+            if bulk_return and (
+                on_conflict is None or on_conflict == OnConflict.UPDATE
+            ):
                 inserted_rows.extend(
                     self._insert(
                         item,
                         fields=fields,
                         using=self.db,
+                        on_conflict=on_conflict,
+                        update_fields=update_fields,
+                        unique_fields=unique_fields,
                         returning_fields=self.model._meta.db_returning_fields,
                     )
                 )

+ 7 - 3
docs/ref/models/querysets.txt

@@ -2411,9 +2411,13 @@ On databases that support it (all except Oracle and SQLite < 3.24), setting the
 SQLite, in addition to ``update_fields``, a list of ``unique_fields`` that may
 be in conflict must be provided.
 
-Enabling the ``ignore_conflicts`` or ``update_conflicts`` parameter disable
-setting the primary key on each model instance (if the database normally
-support it).
+Enabling the ``ignore_conflicts`` parameter disables setting the primary key on
+each model instance (if the database normally supports it).
+
+.. versionchanged:: 5.0
+
+    In older versions, enabling the ``update_conflicts`` parameter prevented
+    setting the primary key on each model instance.
 
 .. warning::
 

+ 4 - 0
docs/releases/5.0.txt

@@ -357,6 +357,10 @@ Models
   :meth:`.Model.save` now allows specifying a tuple of parent classes that must
   be forced to be inserted.
 
+* :meth:`.QuerySet.bulk_create` and :meth:`.QuerySet.abulk_create` methods now
+  set the primary key on each model instance when the ``update_conflicts``
+  parameter is enabled (if the database supports it).
+
 Pagination
 ~~~~~~~~~~
 

+ 30 - 7
tests/bulk_create/tests.py

@@ -582,12 +582,16 @@ class BulkCreateTests(TestCase):
             TwoFields(f1=1, f2=1, name="c"),
             TwoFields(f1=2, f2=2, name="d"),
         ]
-        TwoFields.objects.bulk_create(
+        results = TwoFields.objects.bulk_create(
             conflicting_objects,
             update_conflicts=True,
             unique_fields=unique_fields,
             update_fields=["name"],
         )
+        self.assertEqual(len(results), len(conflicting_objects))
+        if connection.features.can_return_rows_from_bulk_insert:
+            for instance in results:
+                self.assertIsNotNone(instance.pk)
         self.assertEqual(TwoFields.objects.count(), 2)
         self.assertCountEqual(
             TwoFields.objects.values("f1", "f2", "name"),
@@ -619,7 +623,6 @@ class BulkCreateTests(TestCase):
                 TwoFields(f1=2, f2=2, name="b"),
             ]
         )
-        self.assertEqual(TwoFields.objects.count(), 2)
 
         obj1 = TwoFields.objects.get(f1=1)
         obj2 = TwoFields.objects.get(f1=2)
@@ -627,12 +630,16 @@ class BulkCreateTests(TestCase):
             TwoFields(pk=obj1.pk, f1=3, f2=3, name="c"),
             TwoFields(pk=obj2.pk, f1=4, f2=4, name="d"),
         ]
-        TwoFields.objects.bulk_create(
+        results = TwoFields.objects.bulk_create(
             conflicting_objects,
             update_conflicts=True,
             unique_fields=["pk"],
             update_fields=["name"],
         )
+        self.assertEqual(len(results), len(conflicting_objects))
+        if connection.features.can_return_rows_from_bulk_insert:
+            for instance in results:
+                self.assertIsNotNone(instance.pk)
         self.assertEqual(TwoFields.objects.count(), 2)
         self.assertCountEqual(
             TwoFields.objects.values("f1", "f2", "name"),
@@ -680,12 +687,16 @@ class BulkCreateTests(TestCase):
                 description=("Japan is an island country in East Asia."),
             ),
         ]
-        Country.objects.bulk_create(
+        results = Country.objects.bulk_create(
             new_data,
             update_conflicts=True,
             update_fields=["description"],
             unique_fields=unique_fields,
         )
+        self.assertEqual(len(results), len(new_data))
+        if connection.features.can_return_rows_from_bulk_insert:
+            for instance in results:
+                self.assertIsNotNone(instance.pk)
         self.assertEqual(Country.objects.count(), 6)
         self.assertCountEqual(
             Country.objects.values("iso_two_letter", "description"),
@@ -743,12 +754,16 @@ class BulkCreateTests(TestCase):
             UpsertConflict(number=2, rank=2, name="Olivia"),
             UpsertConflict(number=3, rank=1, name="Hannah"),
         ]
-        UpsertConflict.objects.bulk_create(
+        results = UpsertConflict.objects.bulk_create(
             conflicting_objects,
             update_conflicts=True,
             update_fields=["name", "rank"],
             unique_fields=unique_fields,
         )
+        self.assertEqual(len(results), len(conflicting_objects))
+        if connection.features.can_return_rows_from_bulk_insert:
+            for instance in results:
+                self.assertIsNotNone(instance.pk)
         self.assertEqual(UpsertConflict.objects.count(), 3)
         self.assertCountEqual(
             UpsertConflict.objects.values("number", "rank", "name"),
@@ -759,12 +774,16 @@ class BulkCreateTests(TestCase):
             ],
         )
 
-        UpsertConflict.objects.bulk_create(
+        results = UpsertConflict.objects.bulk_create(
             conflicting_objects + [UpsertConflict(number=4, rank=4, name="Mark")],
             update_conflicts=True,
             update_fields=["name", "rank"],
             unique_fields=unique_fields,
         )
+        self.assertEqual(len(results), 4)
+        if connection.features.can_return_rows_from_bulk_insert:
+            for instance in results:
+                self.assertIsNotNone(instance.pk)
         self.assertEqual(UpsertConflict.objects.count(), 4)
         self.assertCountEqual(
             UpsertConflict.objects.values("number", "rank", "name"),
@@ -803,12 +822,16 @@ class BulkCreateTests(TestCase):
             FieldsWithDbColumns(rank=1, name="c"),
             FieldsWithDbColumns(rank=2, name="d"),
         ]
-        FieldsWithDbColumns.objects.bulk_create(
+        results = FieldsWithDbColumns.objects.bulk_create(
             conflicting_objects,
             update_conflicts=True,
             unique_fields=["rank"],
             update_fields=["name"],
         )
+        self.assertEqual(len(results), len(conflicting_objects))
+        if connection.features.can_return_rows_from_bulk_insert:
+            for instance in results:
+                self.assertIsNotNone(instance.pk)
         self.assertEqual(FieldsWithDbColumns.objects.count(), 2)
         self.assertCountEqual(
             FieldsWithDbColumns.objects.values("rank", "name"),