Quellcode durchsuchen

Fixed #23646 -- Added QuerySet.bulk_update() to efficiently update many models.

Tom Forbes vor 6 Jahren
Ursprung
Commit
9cbdb44014

+ 4 - 0
django/db/backends/base/features.py

@@ -265,6 +265,10 @@ class BaseDatabaseFeatures:
     # INSERT?
     supports_ignore_conflicts = True
 
+    # Does this backend require casting the results of CASE expressions used
+    # in UPDATE statements to ensure the expression has the correct type?
+    requires_casted_case_in_updates = False
+
     def __init__(self, connection):
         self.connection = connection
 

+ 1 - 0
django/db/backends/postgresql/features.py

@@ -48,6 +48,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
             V_I := P_I;
         END;
     $$ LANGUAGE plpgsql;"""
+    requires_casted_case_in_updates = True
     supports_over_clause = True
     supports_aggregate_filter_clause = True
     supported_explain_formats = {'JSON', 'TEXT', 'XML', 'YAML'}

+ 46 - 2
django/db/models/query.py

@@ -18,9 +18,9 @@ from django.db import (
 from django.db.models import DateField, DateTimeField, sql
 from django.db.models.constants import LOOKUP_SEP
 from django.db.models.deletion import Collector
-from django.db.models.expressions import F
+from django.db.models.expressions import Case, Expression, F, Value, When
 from django.db.models.fields import AutoField
-from django.db.models.functions import Trunc
+from django.db.models.functions import Cast, Trunc
 from django.db.models.query_utils import FilteredRelation, InvalidQuery, Q
 from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
 from django.db.utils import NotSupportedError
@@ -473,6 +473,50 @@ class QuerySet:
 
         return objs
 
+    def bulk_update(self, objs, fields, batch_size=None):
+        """
+        Update the given fields in each of the given objects in the database.
+        """
+        if batch_size is not None and batch_size < 0:
+            raise ValueError('Batch size must be a positive integer.')
+        if not fields:
+            raise ValueError('Field names must be given to bulk_update().')
+        objs = tuple(objs)
+        if not all(obj.pk for obj in objs):
+            raise ValueError('All bulk_update() objects must have a primary key set.')
+        fields = [self.model._meta.get_field(name) for name in fields]
+        if any(not f.concrete or f.many_to_many for f in fields):
+            raise ValueError('bulk_update() can only be used with concrete fields.')
+        if any(f.primary_key for f in fields):
+            raise ValueError('bulk_update() cannot be used with primary key fields.')
+        if not objs:
+            return
+        # PK is used twice in the resulting update query, once in the filter
+        # and once in the WHEN. Each field will also have one CAST.
+        max_batch_size = connections[self.db].ops.bulk_batch_size(['pk', 'pk'] + fields, objs)
+        batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
+        requires_casting = connections[self.db].features.requires_casted_case_in_updates
+        batches = (objs[i:i + batch_size] for i in range(0, len(objs), batch_size))
+        updates = []
+        for batch_objs in batches:
+            update_kwargs = {}
+            for field in fields:
+                when_statements = []
+                for obj in batch_objs:
+                    attr = getattr(obj, field.attname)
+                    if not isinstance(attr, Expression):
+                        attr = Value(attr, output_field=field)
+                    when_statements.append(When(pk=obj.pk, then=attr))
+                case_statement = Case(*when_statements, output_field=field)
+                if requires_casting:
+                    case_statement = Cast(case_statement, output_field=field)
+                update_kwargs[field.attname] = case_statement
+            updates.append(([obj.pk for obj in batch_objs], update_kwargs))
+        with transaction.atomic(using=self.db, savepoint=False):
+            for pks, update_kwargs in updates:
+                self.filter(pk__in=pks).update(**update_kwargs)
+    bulk_update.alters_data = True
+
     def get_or_create(self, defaults=None, **kwargs):
         """
         Look up an object with the given kwargs, creating one if necessary.

+ 36 - 0
docs/ref/models/querysets.txt

@@ -2089,6 +2089,42 @@ instance (if the database normally supports it).
 
     The ``ignore_conflicts`` parameter was added.
 
+``bulk_update()``
+~~~~~~~~~~~~~~~~~
+
+.. versionadded:: 2.2
+
+.. method:: bulk_update(objs, fields, batch_size=None)
+
+This method efficiently updates the given fields on the provided model
+instances, generally with one query::
+
+    >>> objs = [
+    ...    Entry.objects.create(headline='Entry 1'),
+    ...    Entry.objects.create(headline='Entry 2'),
+    ... ]
+    >>> objs[0].headline = 'This is entry 1'
+    >>> objs[1].headline = 'This is entry 2'
+    >>> Entry.objects.bulk_update(objs, ['headline'])
+
+:meth:`.QuerySet.update` is used to save the changes, so this is more efficient
+than iterating through the list of models and calling ``save()`` on each of
+them, but it has a few caveats:
+
+* You cannot update the model's primary key.
+* Each model's ``save()`` method isn't called, and the
+  :attr:`~django.db.models.signals.pre_save` and
+  :attr:`~django.db.models.signals.post_save` signals aren't sent.
+* If updating a large number of columns in a large number of rows, the SQL
+  generated can be very large. Avoid this by specifying a suitable
+  ``batch_size``.
+* Updating fields defined on multi-table inheritance ancestors will incur an
+  extra query per ancestor.
+
+The ``batch_size`` parameter controls how many objects are saved in a single
+query. The default is to create all objects in one batch, except for SQLite
+and Oracle which have restrictions on the number of variables used in a query.
+
 ``count()``
 ~~~~~~~~~~~
 

+ 3 - 0
docs/releases/2.2.txt

@@ -199,6 +199,9 @@ Models
   :class:`~django.db.models.DateTimeField`, and the new :lookup:`iso_year`
   lookup allows querying by an ISO-8601 week-numbering year.
 
+* The new :meth:`.QuerySet.bulk_update` method allows efficiently updating
+  specific fields on multiple model instances.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 1 - 0
tests/basic/tests.py

@@ -532,6 +532,7 @@ class ManagerTest(SimpleTestCase):
         'update_or_create',
         'create',
         'bulk_create',
+        'bulk_update',
         'filter',
         'aggregate',
         'annotate',

+ 3 - 3
tests/postgres_tests/migrations/0002_create_test_models.py

@@ -56,9 +56,9 @@ class Migration(migrations.Migration):
             name='OtherTypesArrayModel',
             fields=[
                 ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
-                ('ips', ArrayField(models.GenericIPAddressField(), size=None)),
-                ('uuids', ArrayField(models.UUIDField(), size=None)),
-                ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)),
+                ('ips', ArrayField(models.GenericIPAddressField(), size=None, default=list)),
+                ('uuids', ArrayField(models.UUIDField(), size=None, default=list)),
+                ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None, default=list)),
                 ('tags', ArrayField(TagField(), blank=True, null=True, size=None)),
                 ('json', ArrayField(JSONField(default={}), default=[])),
                 ('int_ranges', ArrayField(IntegerRangeField(), null=True, blank=True)),

+ 3 - 3
tests/postgres_tests/models.py

@@ -63,9 +63,9 @@ class NestedIntegerArrayModel(PostgreSQLModel):
 
 
 class OtherTypesArrayModel(PostgreSQLModel):
-    ips = ArrayField(models.GenericIPAddressField())
-    uuids = ArrayField(models.UUIDField())
-    decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2))
+    ips = ArrayField(models.GenericIPAddressField(), default=list)
+    uuids = ArrayField(models.UUIDField(), default=list)
+    decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2), default=list)
     tags = ArrayField(TagField(), blank=True, null=True)
     json = ArrayField(JSONField(default=dict), default=list)
     int_ranges = ArrayField(IntegerRangeField(), blank=True, null=True)

+ 34 - 0
tests/postgres_tests/test_bulk_update.py

@@ -0,0 +1,34 @@
+from datetime import date
+
+from . import PostgreSQLTestCase
+from .models import (
+    HStoreModel, IntegerArrayModel, JSONModel, NestedIntegerArrayModel,
+    NullableIntegerArrayModel, OtherTypesArrayModel, RangesModel,
+)
+
+try:
+    from psycopg2.extras import NumericRange, DateRange
+except ImportError:
+    pass  # psycopg2 isn't installed.
+
+
+class BulkSaveTests(PostgreSQLTestCase):
+    def test_bulk_update(self):
+        test_data = [
+            (IntegerArrayModel, 'field', [], [1, 2, 3]),
+            (NullableIntegerArrayModel, 'field', [1, 2, 3], None),
+            (JSONModel, 'field', {'a': 'b'}, {'c': 'd'}),
+            (NestedIntegerArrayModel, 'field', [], [[1, 2, 3]]),
+            (HStoreModel, 'field', {}, {1: 2}),
+            (RangesModel, 'ints', None, NumericRange(lower=1, upper=10)),
+            (RangesModel, 'dates', None, DateRange(lower=date.today(), upper=date.today())),
+            (OtherTypesArrayModel, 'ips', [], ['1.2.3.4']),
+            (OtherTypesArrayModel, 'json', [], [{'a': 'b'}])
+        ]
+        for Model, field, initial, new in test_data:
+            with self.subTest(model=Model, field=field):
+                instances = Model.objects.bulk_create(Model(**{field: initial}) for _ in range(20))
+                for instance in instances:
+                    setattr(instance, field, new)
+                Model.objects.bulk_update(instances, [field])
+                self.assertSequenceEqual(Model.objects.filter(**{field: new}), instances)

+ 5 - 0
tests/queries/models.py

@@ -718,3 +718,8 @@ class RelatedIndividual(models.Model):
 
     class Meta:
         db_table = 'RelatedIndividual'
+
+
+class CustomDbColumn(models.Model):
+    custom_column = models.IntegerField(db_column='custom_name', null=True)
+    ip_address = models.GenericIPAddressField(null=True)

+ 223 - 0
tests/queries/test_bulk_update.py

@@ -0,0 +1,223 @@
+import datetime
+
+from django.core.exceptions import FieldDoesNotExist
+from django.db.models import F
+from django.db.models.functions import Lower
+from django.test import TestCase
+
+from .models import (
+    Article, CustomDbColumn, CustomPk, Detail, Individual, Member, Note,
+    Number, Paragraph, SpecialCategory, Tag, Valid,
+)
+
+
+class BulkUpdateNoteTests(TestCase):
+    def setUp(self):
+        self.notes = [
+            Note.objects.create(note=str(i), misc=str(i))
+            for i in range(10)
+        ]
+
+    def create_tags(self):
+        self.tags = [
+            Tag.objects.create(name=str(i))
+            for i in range(10)
+        ]
+
+    def test_simple(self):
+        for note in self.notes:
+            note.note = 'test-%s' % note.id
+        with self.assertNumQueries(1):
+            Note.objects.bulk_update(self.notes, ['note'])
+        self.assertCountEqual(
+            Note.objects.values_list('note', flat=True),
+            [cat.note for cat in self.notes]
+        )
+
+    def test_multiple_fields(self):
+        for note in self.notes:
+            note.note = 'test-%s' % note.id
+            note.misc = 'misc-%s' % note.id
+        with self.assertNumQueries(1):
+            Note.objects.bulk_update(self.notes, ['note', 'misc'])
+        self.assertCountEqual(
+            Note.objects.values_list('note', flat=True),
+            [cat.note for cat in self.notes]
+        )
+        self.assertCountEqual(
+            Note.objects.values_list('misc', flat=True),
+            [cat.misc for cat in self.notes]
+        )
+
+    def test_batch_size(self):
+        with self.assertNumQueries(len(self.notes)):
+            Note.objects.bulk_update(self.notes, fields=['note'], batch_size=1)
+
+    def test_unsaved_models(self):
+        objs = self.notes + [Note(note='test', misc='test')]
+        msg = 'All bulk_update() objects must have a primary key set.'
+        with self.assertRaisesMessage(ValueError, msg):
+            Note.objects.bulk_update(objs, fields=['note'])
+
+    def test_foreign_keys_do_not_lookup(self):
+        self.create_tags()
+        for note, tag in zip(self.notes, self.tags):
+            note.tag = tag
+        with self.assertNumQueries(1):
+            Note.objects.bulk_update(self.notes, ['tag'])
+        self.assertSequenceEqual(Note.objects.filter(tag__isnull=False), self.notes)
+
+    def test_set_field_to_null(self):
+        self.create_tags()
+        Note.objects.update(tag=self.tags[0])
+        for note in self.notes:
+            note.tag = None
+        Note.objects.bulk_update(self.notes, ['tag'])
+        self.assertCountEqual(Note.objects.filter(tag__isnull=True), self.notes)
+
+    def test_set_mixed_fields_to_null(self):
+        self.create_tags()
+        midpoint = len(self.notes) // 2
+        top, bottom = self.notes[:midpoint], self.notes[midpoint:]
+        for note in top:
+            note.tag = None
+        for note in bottom:
+            note.tag = self.tags[0]
+        Note.objects.bulk_update(self.notes, ['tag'])
+        self.assertCountEqual(Note.objects.filter(tag__isnull=True), top)
+        self.assertCountEqual(Note.objects.filter(tag__isnull=False), bottom)
+
+    def test_functions(self):
+        Note.objects.update(note='TEST')
+        for note in self.notes:
+            note.note = Lower('note')
+        Note.objects.bulk_update(self.notes, ['note'])
+        self.assertEqual(set(Note.objects.values_list('note', flat=True)), {'test'})
+
+    # Tests that use self.notes go here, otherwise put them in another class.
+
+
+class BulkUpdateTests(TestCase):
+    def test_no_fields(self):
+        msg = 'Field names must be given to bulk_update().'
+        with self.assertRaisesMessage(ValueError, msg):
+            Note.objects.bulk_update([], fields=[])
+
+    def test_invalid_batch_size(self):
+        msg = 'Batch size must be a positive integer.'
+        with self.assertRaisesMessage(ValueError, msg):
+            Note.objects.bulk_update([], fields=['note'], batch_size=-1)
+
+    def test_nonexistent_field(self):
+        with self.assertRaisesMessage(FieldDoesNotExist, "Note has no field named 'nonexistent'"):
+            Note.objects.bulk_update([], ['nonexistent'])
+
+    pk_fields_error = 'bulk_update() cannot be used with primary key fields.'
+
+    def test_update_primary_key(self):
+        with self.assertRaisesMessage(ValueError, self.pk_fields_error):
+            Note.objects.bulk_update([], ['id'])
+
+    def test_update_custom_primary_key(self):
+        with self.assertRaisesMessage(ValueError, self.pk_fields_error):
+            CustomPk.objects.bulk_update([], ['name'])
+
+    def test_empty_objects(self):
+        with self.assertNumQueries(0):
+            Note.objects.bulk_update([], ['note'])
+
+    def test_large_batch(self):
+        Note.objects.bulk_create([
+            Note(note=str(i), misc=str(i))
+            for i in range(0, 2000)
+        ])
+        notes = list(Note.objects.all())
+        Note.objects.bulk_update(notes, ['note'])
+
+    def test_only_concrete_fields_allowed(self):
+        obj = Valid.objects.create(valid='test')
+        detail = Detail.objects.create(data='test')
+        paragraph = Paragraph.objects.create(text='test')
+        Member.objects.create(name='test', details=detail)
+        msg = 'bulk_update() can only be used with concrete fields.'
+        with self.assertRaisesMessage(ValueError, msg):
+            Detail.objects.bulk_update([detail], fields=['member'])
+        with self.assertRaisesMessage(ValueError, msg):
+            Paragraph.objects.bulk_update([paragraph], fields=['page'])
+        with self.assertRaisesMessage(ValueError, msg):
+            Valid.objects.bulk_update([obj], fields=['parent'])
+
+    def test_custom_db_columns(self):
+        model = CustomDbColumn.objects.create(custom_column=1)
+        model.custom_column = 2
+        CustomDbColumn.objects.bulk_update([model], fields=['custom_column'])
+        model.refresh_from_db()
+        self.assertEqual(model.custom_column, 2)
+
+    def test_custom_pk(self):
+        custom_pks = [
+            CustomPk.objects.create(name='pk-%s' % i, extra='')
+            for i in range(10)
+        ]
+        for model in custom_pks:
+            model.extra = 'extra-%s' % model.pk
+        CustomPk.objects.bulk_update(custom_pks, ['extra'])
+        self.assertCountEqual(
+            CustomPk.objects.values_list('extra', flat=True),
+            [cat.extra for cat in custom_pks]
+        )
+
+    def test_inherited_fields(self):
+        special_categories = [
+            SpecialCategory.objects.create(name=str(i), special_name=str(i))
+            for i in range(10)
+        ]
+        for category in special_categories:
+            category.name = 'test-%s' % category.id
+            category.special_name = 'special-test-%s' % category.special_name
+        SpecialCategory.objects.bulk_update(special_categories, ['name', 'special_name'])
+        self.assertCountEqual(
+            SpecialCategory.objects.values_list('name', flat=True),
+            [cat.name for cat in special_categories]
+        )
+        self.assertCountEqual(
+            SpecialCategory.objects.values_list('special_name', flat=True),
+            [cat.special_name for cat in special_categories]
+        )
+
+    def test_field_references(self):
+        numbers = [Number.objects.create(num=0) for _ in range(10)]
+        for number in numbers:
+            number.num = F('num') + 1
+        Number.objects.bulk_update(numbers, ['num'])
+        self.assertCountEqual(Number.objects.filter(num=1), numbers)
+
+    def test_booleanfield(self):
+        individuals = [Individual.objects.create(alive=False) for _ in range(10)]
+        for individual in individuals:
+            individual.alive = True
+        Individual.objects.bulk_update(individuals, ['alive'])
+        self.assertCountEqual(Individual.objects.filter(alive=True), individuals)
+
+    def test_ipaddressfield(self):
+        for ip in ('2001::1', '1.2.3.4'):
+            with self.subTest(ip=ip):
+                models = [
+                    CustomDbColumn.objects.create(ip_address='0.0.0.0')
+                    for _ in range(10)
+                ]
+                for model in models:
+                    model.ip_address = ip
+                CustomDbColumn.objects.bulk_update(models, ['ip_address'])
+                self.assertCountEqual(CustomDbColumn.objects.filter(ip_address=ip), models)
+
+    def test_datetime_field(self):
+        articles = [
+            Article.objects.create(name=str(i), created=datetime.datetime.today())
+            for i in range(10)
+        ]
+        point_in_time = datetime.datetime(1991, 10, 31)
+        for article in articles:
+            article.created = point_in_time
+        Article.objects.bulk_update(articles, ['created'])
+        self.assertCountEqual(Article.objects.filter(created=point_in_time), articles)