فهرست منبع

Fixed #36118 -- Accounted for multiple primary keys in bulk_update max_batch_size.

Co-authored-by: Simon Charette <charette.s@gmail.com>
Sarah Boyce 2 ماه پیش
والد
کامیت
5a2c1bc07d

+ 14 - 1
django/db/backends/oracle/operations.py

@@ -1,12 +1,19 @@
 import datetime
 import uuid
 from functools import lru_cache
+from itertools import chain
 
 from django.conf import settings
 from django.db import NotSupportedError
 from django.db.backends.base.operations import BaseDatabaseOperations
 from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name
-from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup
+from django.db.models import (
+    AutoField,
+    CompositePrimaryKey,
+    Exists,
+    ExpressionWrapper,
+    Lookup,
+)
 from django.db.models.expressions import RawSQL
 from django.db.models.sql.where import WhereNode
 from django.utils import timezone
@@ -699,6 +706,12 @@ END;
 
     def bulk_batch_size(self, fields, objs):
         """Oracle restricts the number of parameters in a query."""
+        fields = list(
+            chain.from_iterable(
+                field.fields if isinstance(field, CompositePrimaryKey) else [field]
+                for field in fields
+            )
+        )
         if fields:
             return self.connection.features.max_query_params // len(fields)
         return len(objs)

+ 10 - 0
django/db/backends/sqlite3/operations.py

@@ -36,6 +36,16 @@ class DatabaseOperations(BaseDatabaseOperations):
         If there's only a single field to insert, the limit is 500
         (SQLITE_MAX_COMPOUND_SELECT).
         """
+        fields = list(
+            chain.from_iterable(
+                (
+                    field.fields
+                    if isinstance(field, models.CompositePrimaryKey)
+                    else [field]
+                )
+                for field in fields
+            )
+        )
         if len(fields) == 1:
             return 500
         elif len(fields) > 1:

+ 1 - 2
django/db/models/deletion.py

@@ -230,9 +230,8 @@ class Collector:
         """
         Return the objs in suitably sized batches for the used connection.
         """
-        field_names = [field.name for field in fields]
         conn_batch_size = max(
-            connections[self.using].ops.bulk_batch_size(field_names, objs), 1
+            connections[self.using].ops.bulk_batch_size(fields, objs), 1
         )
         if len(objs) > conn_batch_size:
             return [

+ 7 - 4
django/db/models/query.py

@@ -874,11 +874,12 @@ class QuerySet(AltersData):
         objs = tuple(objs)
         if not all(obj._is_pk_set() for obj in objs):
             raise ValueError("All bulk_update() objects must have a primary key set.")
-        fields = [self.model._meta.get_field(name) for name in fields]
+        opts = self.model._meta
+        fields = [opts.get_field(name) for name in fields]
         if any(not f.concrete or f.many_to_many for f in fields):
             raise ValueError("bulk_update() can only be used with concrete fields.")
-        all_pk_fields = set(self.model._meta.pk_fields)
-        for parent in self.model._meta.all_parents:
+        all_pk_fields = set(opts.pk_fields)
+        for parent in opts.all_parents:
             all_pk_fields.update(parent._meta.pk_fields)
         if any(f in all_pk_fields for f in fields):
             raise ValueError("bulk_update() cannot be used with primary key fields.")
@@ -892,7 +893,9 @@ class QuerySet(AltersData):
         # and once in the WHEN. Each field will also have one CAST.
         self._for_write = True
         connection = connections[self.db]
-        max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs)
+        max_batch_size = connection.ops.bulk_batch_size(
+            [opts.pk, opts.pk] + fields, objs
+        )
         batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
         requires_casting = connection.features.requires_casted_case_in_updates
         batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size))

+ 14 - 3
tests/backends/oracle/test_operations.py

@@ -1,7 +1,7 @@
 import unittest
 
 from django.core.management.color import no_style
-from django.db import connection
+from django.db import connection, models
 from django.test import TransactionTestCase
 
 from ..models import Person, Tag
@@ -22,14 +22,25 @@ class OperationsTests(TransactionTestCase):
         objects = range(2**16)
         self.assertEqual(connection.ops.bulk_batch_size([], objects), len(objects))
         # Each field is a parameter for each object.
+        first_name_field = Person._meta.get_field("first_name")
+        last_name_field = Person._meta.get_field("last_name")
         self.assertEqual(
-            connection.ops.bulk_batch_size(["id"], objects),
+            connection.ops.bulk_batch_size([first_name_field], objects),
             connection.features.max_query_params,
         )
         self.assertEqual(
-            connection.ops.bulk_batch_size(["id", "other"], objects),
+            connection.ops.bulk_batch_size(
+                [first_name_field, last_name_field],
+                objects,
+            ),
             connection.features.max_query_params // 2,
         )
+        composite_pk = models.CompositePrimaryKey("first_name", "last_name")
+        composite_pk.fields = [first_name_field, last_name_field]
+        self.assertEqual(
+            connection.ops.bulk_batch_size([composite_pk, first_name_field], objects),
+            connection.features.max_query_params // 3,
+        )
 
     def test_sql_flush(self):
         statements = connection.ops.sql_flush(

+ 23 - 1
tests/backends/sqlite/test_operations.py

@@ -1,7 +1,7 @@
 import unittest
 
 from django.core.management.color import no_style
-from django.db import connection
+from django.db import connection, models
 from django.test import TestCase
 
 from ..models import Person, Tag
@@ -86,3 +86,25 @@ class SQLiteOperationsTests(TestCase):
             "zzz'",
             statements[-1],
         )
+
+    def test_bulk_batch_size(self):
+        self.assertEqual(connection.ops.bulk_batch_size([], [Person()]), 1)
+        first_name_field = Person._meta.get_field("first_name")
+        last_name_field = Person._meta.get_field("last_name")
+        self.assertEqual(
+            connection.ops.bulk_batch_size([first_name_field], [Person()]), 500
+        )
+        self.assertEqual(
+            connection.ops.bulk_batch_size(
+                [first_name_field, last_name_field], [Person()]
+            ),
+            connection.features.max_query_params // 2,
+        )
+        composite_pk = models.CompositePrimaryKey("first_name", "last_name")
+        composite_pk.fields = [first_name_field, last_name_field]
+        self.assertEqual(
+            connection.ops.bulk_batch_size(
+                [composite_pk, first_name_field], [Person()]
+            ),
+            connection.features.max_query_params // 3,
+        )