Răsfoiți Sursa

Fixed #32381 -- Made QuerySet.bulk_update() return the number of objects updated.

Co-authored-by: Diego Lima <diego.lima@lais.huol.ufrn.br>
abhiabhi94 3 ani în urmă
părinte
comite
cd124295d8

+ 4 - 2
django/db/models/query.py

@@ -541,7 +541,7 @@ class QuerySet:
         if any(f.primary_key for f in fields):
             raise ValueError('bulk_update() cannot be used with primary key fields.')
         if not objs:
-            return
+            return 0
         # PK is used twice in the resulting update query, once in the filter
         # and once in the WHEN. Each field will also have one CAST.
         max_batch_size = connections[self.db].ops.bulk_batch_size(['pk', 'pk'] + fields, objs)
@@ -563,9 +563,11 @@ class QuerySet:
                     case_statement = Cast(case_statement, output_field=field)
                 update_kwargs[field.attname] = case_statement
             updates.append(([obj.pk for obj in batch_objs], update_kwargs))
+        rows_updated = 0
         with transaction.atomic(using=self.db, savepoint=False):
             for pks, update_kwargs in updates:
-                self.filter(pk__in=pks).update(**update_kwargs)
+                rows_updated += self.filter(pk__in=pks).update(**update_kwargs)
+        return rows_updated
     bulk_update.alters_data = True
 
     def get_or_create(self, defaults=None, **kwargs):

+ 11 - 1
docs/ref/models/querysets.txt

@@ -2221,7 +2221,8 @@ normally supports it).
 .. method:: bulk_update(objs, fields, batch_size=None)
 
 This method efficiently updates the given fields on the provided model
-instances, generally with one query::
+instances, generally with one query, and returns the number of objects
+updated::
 
     >>> objs = [
     ...    Entry.objects.create(headline='Entry 1'),
@@ -2230,6 +2231,11 @@ instances, generally with one query::
     >>> objs[0].headline = 'This is entry 1'
     >>> objs[1].headline = 'This is entry 2'
     >>> Entry.objects.bulk_update(objs, ['headline'])
+    2
+
+.. versionchanged:: 4.0
+
+    The return value of the number of objects updated was added.
 
 :meth:`.QuerySet.update` is used to save the changes, so this is more efficient
 than iterating through the list of models and calling ``save()`` on each of
@@ -2246,6 +2252,10 @@ them, but it has a few caveats:
   extra query per ancestor.
 * When an individual batch contains duplicates, only the first instance in that
   batch will result in an update.
+* The number of objects updated returned by the function may be fewer than the
+  number of objects passed in. This can be due to duplicate objects passed in
+  which are updated in the same batch or race conditions such that objects are
+  no longer present in the database.
 
 The ``batch_size`` parameter controls how many objects are saved in a single
 query. The default is to update all objects in one batch, except for SQLite

+ 2 - 0
docs/releases/4.0.txt

@@ -263,6 +263,8 @@ Models
 * :class:`~django.db.models.DurationField` now supports multiplying and
   dividing by scalar values on SQLite.
 
+* :meth:`.QuerySet.bulk_update` now returns the number of objects updated.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 12 - 2
tests/queries/test_bulk_update.py

@@ -125,7 +125,8 @@ class BulkUpdateTests(TestCase):
 
     def test_empty_objects(self):
         with self.assertNumQueries(0):
-            Note.objects.bulk_update([], ['note'])
+            rows_updated = Note.objects.bulk_update([], ['note'])
+        self.assertEqual(rows_updated, 0)
 
     def test_large_batch(self):
         Note.objects.bulk_create([
@@ -133,7 +134,16 @@ class BulkUpdateTests(TestCase):
             for i in range(0, 2000)
         ])
         notes = list(Note.objects.all())
-        Note.objects.bulk_update(notes, ['note'])
+        rows_updated = Note.objects.bulk_update(notes, ['note'])
+        self.assertEqual(rows_updated, 2000)
+
+    def test_updated_rows_when_passing_duplicates(self):
+        note = Note.objects.create(note='test-note', misc='test')
+        rows_updated = Note.objects.bulk_update([note, note], ['note'])
+        self.assertEqual(rows_updated, 1)
+        # Duplicates in different batches.
+        rows_updated = Note.objects.bulk_update([note, note], ['note'], batch_size=1)
+        self.assertEqual(rows_updated, 2)
 
     def test_only_concrete_fields_allowed(self):
         obj = Valid.objects.create(valid='test')