Browse Source

Fixed #31685 -- Added support for updating conflicts to QuerySet.bulk_create().

Thanks Florian Apolloner, Chris Jerdonek, Hannes Ljungberg, Nick Pope,
and Mariusz Felisiak for reviews.
sean_c_hsu 4 years ago
parent
commit
0f6946495a

+ 4 - 0
django/db/backends/base/features.py

@@ -271,6 +271,10 @@ class BaseDatabaseFeatures:
     # Does the backend support ignoring constraint or uniqueness errors during
     # INSERT?
     supports_ignore_conflicts = True
+    # Does the backend support updating rows on constraint or uniqueness errors
+    # during INSERT?
+    supports_update_conflicts = False
+    supports_update_conflicts_with_target = False
 
     # Does this backend require casting the results of CASE expressions used
     # in UPDATE statements to ensure the expression has the correct type?

+ 2 - 2
django/db/backends/base/operations.py

@@ -717,8 +717,8 @@ class BaseDatabaseOperations:
             raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys())))
         return self.explain_prefix
 
-    def insert_statement(self, ignore_conflicts=False):
+    def insert_statement(self, on_conflict=None):
         return 'INSERT INTO'
 
-    def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
+    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
         return ''

+ 1 - 0
django/db/backends/mysql/features.py

@@ -24,6 +24,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     supports_select_difference = False
     supports_slicing_ordering_in_compound = True
     supports_index_on_text_field = False
+    supports_update_conflicts = True
     create_test_procedure_without_params_sql = """
         CREATE PROCEDURE test_procedure ()
         BEGIN

+ 29 - 2
django/db/backends/mysql/operations.py

@@ -4,6 +4,7 @@ from django.conf import settings
 from django.db.backends.base.operations import BaseDatabaseOperations
 from django.db.backends.utils import split_tzname_delta
 from django.db.models import Exists, ExpressionWrapper, Lookup
+from django.db.models.constants import OnConflict
 from django.utils import timezone
 from django.utils.encoding import force_str
 
@@ -365,8 +366,10 @@ class DatabaseOperations(BaseDatabaseOperations):
         match_option = 'c' if lookup_type == 'regex' else 'i'
         return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
 
-    def insert_statement(self, ignore_conflicts=False):
-        return 'INSERT IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
+    def insert_statement(self, on_conflict=None):
+        if on_conflict == OnConflict.IGNORE:
+            return 'INSERT IGNORE INTO'
+        return super().insert_statement(on_conflict=on_conflict)
 
     def lookup_cast(self, lookup_type, internal_type=None):
         lookup = '%s'
@@ -388,3 +391,27 @@ class DatabaseOperations(BaseDatabaseOperations):
         if getattr(expression, 'conditional', False):
             return False
         return super().conditional_expression_supported_in_where_clause(expression)
+
+    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
+        if on_conflict == OnConflict.UPDATE:
+            conflict_suffix_sql = 'ON DUPLICATE KEY UPDATE %(fields)s'
+            field_sql = '%(field)s = VALUES(%(field)s)'
+            # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
+            # aliases for the new row and its columns available in MySQL
+            # 8.0.19+.
+            if not self.connection.mysql_is_mariadb:
+                if self.connection.mysql_version >= (8, 0, 19):
+                    conflict_suffix_sql = f'AS new {conflict_suffix_sql}'
+                    field_sql = '%(field)s = new.%(field)s'
+            # VALUES() was renamed to VALUE() in MariaDB 10.3.3+.
+            elif self.connection.mysql_version >= (10, 3, 3):
+                field_sql = '%(field)s = VALUE(%(field)s)'
+
+            fields = ', '.join([
+                field_sql % {'field': field}
+                for field in map(self.quote_name, update_fields)
+            ])
+            return conflict_suffix_sql % {'fields': fields}
+        return super().on_conflict_suffix_sql(
+            fields, on_conflict, update_fields, unique_fields,
+        )

+ 2 - 0
django/db/backends/postgresql/features.py

@@ -57,6 +57,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     supports_deferrable_unique_constraints = True
     has_json_operators = True
     json_key_contains_list_matching_requires_list = True
+    supports_update_conflicts = True
+    supports_update_conflicts_with_target = True
     test_collations = {
         'non_default': 'sv-x-icu',
         'swedish_ci': 'sv-x-icu',

+ 15 - 2
django/db/backends/postgresql/operations.py

@@ -3,6 +3,7 @@ from psycopg2.extras import Inet
 from django.conf import settings
 from django.db.backends.base.operations import BaseDatabaseOperations
 from django.db.backends.utils import split_tzname_delta
+from django.db.models.constants import OnConflict
 
 
 class DatabaseOperations(BaseDatabaseOperations):
@@ -272,5 +273,17 @@ class DatabaseOperations(BaseDatabaseOperations):
             prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items())
         return prefix
 
-    def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
-        return 'ON CONFLICT DO NOTHING' if ignore_conflicts else super().ignore_conflicts_suffix_sql(ignore_conflicts)
+    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
+        if on_conflict == OnConflict.IGNORE:
+            return 'ON CONFLICT DO NOTHING'
+        if on_conflict == OnConflict.UPDATE:
+            return 'ON CONFLICT(%s) DO UPDATE SET %s' % (
+                ', '.join(map(self.quote_name, unique_fields)),
+                ', '.join([
+                    f'{field} = EXCLUDED.{field}'
+                    for field in map(self.quote_name, update_fields)
+                ]),
+            )
+        return super().on_conflict_suffix_sql(
+            fields, on_conflict, update_fields, unique_fields,
+        )

+ 2 - 0
django/db/backends/sqlite3/features.py

@@ -40,6 +40,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0)
     order_by_nulls_first = True
     supports_json_field_contains = False
+    supports_update_conflicts = Database.sqlite_version_info >= (3, 24, 0)
+    supports_update_conflicts_with_target = supports_update_conflicts
     test_collations = {
         'ci': 'nocase',
         'cs': 'binary',

+ 21 - 2
django/db/backends/sqlite3/operations.py

@@ -8,6 +8,7 @@ from django.conf import settings
 from django.core.exceptions import FieldError
 from django.db import DatabaseError, NotSupportedError, models
 from django.db.backends.base.operations import BaseDatabaseOperations
+from django.db.models.constants import OnConflict
 from django.db.models.expressions import Col
 from django.utils import timezone
 from django.utils.dateparse import parse_date, parse_datetime, parse_time
@@ -370,8 +371,10 @@ class DatabaseOperations(BaseDatabaseOperations):
             return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params
         return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params
 
-    def insert_statement(self, ignore_conflicts=False):
-        return 'INSERT OR IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
+    def insert_statement(self, on_conflict=None):
+        if on_conflict == OnConflict.IGNORE:
+            return 'INSERT OR IGNORE INTO'
+        return super().insert_statement(on_conflict=on_conflict)
 
     def return_insert_columns(self, fields):
         # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
@@ -384,3 +387,19 @@ class DatabaseOperations(BaseDatabaseOperations):
             ) for field in fields
         ]
         return 'RETURNING %s' % ', '.join(columns), ()
+
+    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
+        if (
+            on_conflict == OnConflict.UPDATE and
+            self.connection.features.supports_update_conflicts_with_target
+        ):
+            return 'ON CONFLICT(%s) DO UPDATE SET %s' % (
+                ', '.join(map(self.quote_name, unique_fields)),
+                ', '.join([
+                    f'{field} = EXCLUDED.{field}'
+                    for field in map(self.quote_name, update_fields)
+                ]),
+            )
+        return super().on_conflict_suffix_sql(
+            fields, on_conflict, update_fields, unique_fields,
+        )

+ 6 - 0
django/db/models/constants.py

@@ -1,6 +1,12 @@
 """
 Constants used across the ORM in general.
 """
+from enum import Enum
 
 # Separator used to split filter strings apart.
 LOOKUP_SEP = '__'
+
+
+class OnConflict(Enum):
+    IGNORE = 'ignore'
+    UPDATE = 'update'

+ 106 - 13
django/db/models/query.py

@@ -15,7 +15,7 @@ from django.db import (
     router, transaction,
 )
 from django.db.models import AutoField, DateField, DateTimeField, sql
-from django.db.models.constants import LOOKUP_SEP
+from django.db.models.constants import LOOKUP_SEP, OnConflict
 from django.db.models.deletion import Collector
 from django.db.models.expressions import Case, Expression, F, Ref, Value, When
 from django.db.models.functions import Cast, Trunc
@@ -466,7 +466,69 @@ class QuerySet:
                 obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
             obj._prepare_related_fields_for_save(operation_name='bulk_create')
 
-    def bulk_create(self, objs, batch_size=None, ignore_conflicts=False):
+    def _check_bulk_create_options(self, ignore_conflicts, update_conflicts, update_fields, unique_fields):
+        if ignore_conflicts and update_conflicts:
+            raise ValueError(
+                'ignore_conflicts and update_conflicts are mutually exclusive.'
+            )
+        db_features = connections[self.db].features
+        if ignore_conflicts:
+            if not db_features.supports_ignore_conflicts:
+                raise NotSupportedError(
+                    'This database backend does not support ignoring conflicts.'
+                )
+            return OnConflict.IGNORE
+        elif update_conflicts:
+            if not db_features.supports_update_conflicts:
+                raise NotSupportedError(
+                    'This database backend does not support updating conflicts.'
+                )
+            if not update_fields:
+                raise ValueError(
+                    'Fields that will be updated when a row insertion fails '
+                    'on conflicts must be provided.'
+                )
+            if unique_fields and not db_features.supports_update_conflicts_with_target:
+                raise NotSupportedError(
+                    'This database backend does not support updating '
+                    'conflicts with specifying unique fields that can trigger '
+                    'the upsert.'
+                )
+            if not unique_fields and db_features.supports_update_conflicts_with_target:
+                raise ValueError(
+                    'Unique fields that can trigger the upsert must be '
+                    'provided.'
+                )
+            # Updating primary keys and non-concrete fields is forbidden.
+            update_fields = [self.model._meta.get_field(name) for name in update_fields]
+            if any(not f.concrete or f.many_to_many for f in update_fields):
+                raise ValueError(
+                    'bulk_create() can only be used with concrete fields in '
+                    'update_fields.'
+                )
+            if any(f.primary_key for f in update_fields):
+                raise ValueError(
+                    'bulk_create() cannot be used with primary keys in '
+                    'update_fields.'
+                )
+            if unique_fields:
+                # Primary key is allowed in unique_fields.
+                unique_fields = [
+                    self.model._meta.get_field(name)
+                    for name in unique_fields if name != 'pk'
+                ]
+                if any(not f.concrete or f.many_to_many for f in unique_fields):
+                    raise ValueError(
+                        'bulk_create() can only be used with concrete fields '
+                        'in unique_fields.'
+                    )
+            return OnConflict.UPDATE
+        return None
+
+    def bulk_create(
+        self, objs, batch_size=None, ignore_conflicts=False,
+        update_conflicts=False, update_fields=None, unique_fields=None,
+    ):
         """
         Insert each of the instances into the database. Do *not* call
         save() on each of the instances, do not send any pre/post_save
@@ -497,6 +559,12 @@ class QuerySet:
                 raise ValueError("Can't bulk create a multi-table inherited model")
         if not objs:
             return objs
+        on_conflict = self._check_bulk_create_options(
+            ignore_conflicts,
+            update_conflicts,
+            update_fields,
+            unique_fields,
+        )
         self._for_write = True
         opts = self.model._meta
         fields = opts.concrete_fields
@@ -506,7 +574,12 @@ class QuerySet:
             objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
             if objs_with_pk:
                 returned_columns = self._batched_insert(
-                    objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
+                    objs_with_pk,
+                    fields,
+                    batch_size,
+                    on_conflict=on_conflict,
+                    update_fields=update_fields,
+                    unique_fields=unique_fields,
                 )
                 for obj_with_pk, results in zip(objs_with_pk, returned_columns):
                     for result, field in zip(results, opts.db_returning_fields):
@@ -518,10 +591,15 @@ class QuerySet:
             if objs_without_pk:
                 fields = [f for f in fields if not isinstance(f, AutoField)]
                 returned_columns = self._batched_insert(
-                    objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
+                    objs_without_pk,
+                    fields,
+                    batch_size,
+                    on_conflict=on_conflict,
+                    update_fields=update_fields,
+                    unique_fields=unique_fields,
                 )
                 connection = connections[self.db]
-                if connection.features.can_return_rows_from_bulk_insert and not ignore_conflicts:
+                if connection.features.can_return_rows_from_bulk_insert and on_conflict is None:
                     assert len(returned_columns) == len(objs_without_pk)
                 for obj_without_pk, results in zip(objs_without_pk, returned_columns):
                     for result, field in zip(results, opts.db_returning_fields):
@@ -1293,7 +1371,10 @@ class QuerySet:
     # PRIVATE METHODS #
     ###################
 
-    def _insert(self, objs, fields, returning_fields=None, raw=False, using=None, ignore_conflicts=False):
+    def _insert(
+        self, objs, fields, returning_fields=None, raw=False, using=None,
+        on_conflict=None, update_fields=None, unique_fields=None,
+    ):
         """
         Insert a new record for the given model. This provides an interface to
         the InsertQuery class and is how Model.save() is implemented.
@@ -1301,33 +1382,45 @@ class QuerySet:
         self._for_write = True
         if using is None:
             using = self.db
-        query = sql.InsertQuery(self.model, ignore_conflicts=ignore_conflicts)
+        query = sql.InsertQuery(
+            self.model,
+            on_conflict=on_conflict,
+            update_fields=update_fields,
+            unique_fields=unique_fields,
+        )
         query.insert_values(fields, objs, raw=raw)
         return query.get_compiler(using=using).execute_sql(returning_fields)
     _insert.alters_data = True
     _insert.queryset_only = False
 
-    def _batched_insert(self, objs, fields, batch_size, ignore_conflicts=False):
+    def _batched_insert(
+        self, objs, fields, batch_size, on_conflict=None, update_fields=None,
+        unique_fields=None,
+    ):
         """
         Helper method for bulk_create() to insert objs one batch at a time.
         """
         connection = connections[self.db]
-        if ignore_conflicts and not connection.features.supports_ignore_conflicts:
-            raise NotSupportedError('This database backend does not support ignoring conflicts.')
         ops = connection.ops
         max_batch_size = max(ops.bulk_batch_size(fields, objs), 1)
         batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
         inserted_rows = []
         bulk_return = connection.features.can_return_rows_from_bulk_insert
         for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
-            if bulk_return and not ignore_conflicts:
+            if bulk_return and on_conflict is None:
                 inserted_rows.extend(self._insert(
                     item, fields=fields, using=self.db,
                     returning_fields=self.model._meta.db_returning_fields,
-                    ignore_conflicts=ignore_conflicts,
                 ))
             else:
-                self._insert(item, fields=fields, using=self.db, ignore_conflicts=ignore_conflicts)
+                self._insert(
+                    item,
+                    fields=fields,
+                    using=self.db,
+                    on_conflict=on_conflict,
+                    update_fields=update_fields,
+                    unique_fields=unique_fields,
+                )
         return inserted_rows
 
     def _chain(self):

+ 14 - 9
django/db/models/sql/compiler.py

@@ -1387,7 +1387,9 @@ class SQLInsertCompiler(SQLCompiler):
         # going to be column names (so we can avoid the extra overhead).
         qn = self.connection.ops.quote_name
         opts = self.query.get_meta()
-        insert_statement = self.connection.ops.insert_statement(ignore_conflicts=self.query.ignore_conflicts)
+        insert_statement = self.connection.ops.insert_statement(
+            on_conflict=self.query.on_conflict,
+        )
         result = ['%s %s' % (insert_statement, qn(opts.db_table))]
         fields = self.query.fields or [opts.pk]
         result.append('(%s)' % ', '.join(qn(f.column) for f in fields))
@@ -1410,8 +1412,11 @@ class SQLInsertCompiler(SQLCompiler):
 
         placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
 
-        ignore_conflicts_suffix_sql = self.connection.ops.ignore_conflicts_suffix_sql(
-            ignore_conflicts=self.query.ignore_conflicts
+        on_conflict_suffix_sql = self.connection.ops.on_conflict_suffix_sql(
+            fields,
+            self.query.on_conflict,
+            self.query.update_fields,
+            self.query.unique_fields,
         )
         if self.returning_fields and self.connection.features.can_return_columns_from_insert:
             if self.connection.features.can_return_rows_from_bulk_insert:
@@ -1420,8 +1425,8 @@ class SQLInsertCompiler(SQLCompiler):
             else:
                 result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
                 params = [param_rows[0]]
-            if ignore_conflicts_suffix_sql:
-                result.append(ignore_conflicts_suffix_sql)
+            if on_conflict_suffix_sql:
+                result.append(on_conflict_suffix_sql)
             # Skip empty r_sql to allow subclasses to customize behavior for
             # 3rd party backends. Refs #19096.
             r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.returning_fields)
@@ -1432,12 +1437,12 @@ class SQLInsertCompiler(SQLCompiler):
 
         if can_bulk:
             result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
-            if ignore_conflicts_suffix_sql:
-                result.append(ignore_conflicts_suffix_sql)
+            if on_conflict_suffix_sql:
+                result.append(on_conflict_suffix_sql)
             return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]
         else:
-            if ignore_conflicts_suffix_sql:
-                result.append(ignore_conflicts_suffix_sql)
+            if on_conflict_suffix_sql:
+                result.append(on_conflict_suffix_sql)
             return [
                 (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
                 for p, vals in zip(placeholder_rows, param_rows)

+ 4 - 2
django/db/models/sql/subqueries.py

@@ -138,11 +138,13 @@ class UpdateQuery(Query):
 class InsertQuery(Query):
     compiler = 'SQLInsertCompiler'
 
-    def __init__(self, *args, ignore_conflicts=False, **kwargs):
+    def __init__(self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs):
         super().__init__(*args, **kwargs)
         self.fields = []
         self.objs = []
-        self.ignore_conflicts = ignore_conflicts
+        self.on_conflict = on_conflict
+        self.update_fields = update_fields or []
+        self.unique_fields = unique_fields or []
 
     def insert_values(self, fields, objs, raw=False):
         self.fields = fields

+ 18 - 4
docs/ref/models/querysets.txt

@@ -2155,7 +2155,7 @@ exists in the database, an :exc:`~django.db.IntegrityError` is raised.
 ``bulk_create()``
 ~~~~~~~~~~~~~~~~~
 
-.. method:: bulk_create(objs, batch_size=None, ignore_conflicts=False)
+.. method:: bulk_create(objs, batch_size=None, ignore_conflicts=False, update_conflicts=False, update_fields=None, unique_fields=None)
 
 This method inserts the provided list of objects into the database in an
 efficient manner (generally only 1 query, no matter how many objects there
@@ -2198,9 +2198,17 @@ where the default is such that at most 999 variables per query are used.
 
 On databases that support it (all but Oracle), setting the ``ignore_conflicts``
 parameter to ``True`` tells the database to ignore failure to insert any rows
-that fail constraints such as duplicate unique values. Enabling this parameter
-disables setting the primary key on each model instance (if the database
-normally supports it).
+that fail constraints such as duplicate unique values.
+
+On databases that support it (all except Oracle and SQLite < 3.24), setting the
+``update_conflicts`` parameter to ``True``, tells the database to update
+``update_fields`` when a row insertion fails on conflicts. On PostgreSQL and
+SQLite, in addition to ``update_fields``, a list of ``unique_fields`` that may
+be in conflict must be provided.
+
+Enabling the ``ignore_conflicts`` or ``update_conflicts`` parameter disable
+setting the primary key on each model instance (if the database normally
+support it).
 
 .. warning::
 
@@ -2217,6 +2225,12 @@ normally supports it).
 
     Support for the fetching primary key attributes on SQLite 3.35+ was added.
 
+.. versionchanged:: 4.1
+
+    The ``update_conflicts``, ``update_fields``, and ``unique_fields``
+    parameters were added to support updating fields when a row insertion fails
+    on conflict.
+
 ``bulk_update()``
 ~~~~~~~~~~~~~~~~~
 

+ 12 - 0
docs/releases/4.1.txt

@@ -232,6 +232,10 @@ Models
   in order to reduce the number of failed requests, e.g. after database server
   restart.
 
+* :meth:`.QuerySet.bulk_create` now supports updating fields when a row
+  insertion fails uniqueness constraints. This is supported on MariaDB, MySQL,
+  PostgreSQL, and SQLite 3.24+.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 
@@ -298,6 +302,14 @@ backends.
 * ``DatabaseIntrospection.get_key_columns()`` is removed. Use
   ``DatabaseIntrospection.get_relations()`` instead.
 
+* ``DatabaseOperations.ignore_conflicts_suffix_sql()`` method is replaced by
+  ``DatabaseOperations.on_conflict_suffix_sql()`` that accepts the ``fields``,
+  ``on_conflict``, ``update_fields``, and ``unique_fields`` arguments.
+
+* The ``ignore_conflicts`` argument of the
+  ``DatabaseOperations.insert_statement()`` method is replaced by
+  ``on_conflict`` that accepts ``django.db.models.constants.OnConflict``.
+
 Dropped support for MariaDB 10.2
 --------------------------------
 

+ 21 - 0
tests/bulk_create/models.py

@@ -16,6 +16,14 @@ class Country(models.Model):
     iso_two_letter = models.CharField(max_length=2)
     description = models.TextField()
 
+    class Meta:
+        constraints = [
+            models.UniqueConstraint(
+                fields=['iso_two_letter', 'name'],
+                name='country_name_iso_unique',
+            ),
+        ]
+
 
 class ProxyCountry(Country):
     class Meta:
@@ -58,6 +66,13 @@ class State(models.Model):
 class TwoFields(models.Model):
     f1 = models.IntegerField(unique=True)
     f2 = models.IntegerField(unique=True)
+    name = models.CharField(max_length=15, null=True)
+
+
+class UpsertConflict(models.Model):
+    number = models.IntegerField(unique=True)
+    rank = models.IntegerField()
+    name = models.CharField(max_length=15)
 
 
 class NoFields(models.Model):
@@ -103,3 +118,9 @@ class NullableFields(models.Model):
     text_field = models.TextField(null=True, default='text')
     url_field = models.URLField(null=True, default='/')
     uuid_field = models.UUIDField(null=True, default=uuid.uuid4)
+
+
+class RelatedModel(models.Model):
+    name = models.CharField(max_length=15, null=True)
+    country = models.OneToOneField(Country, models.CASCADE, primary_key=True)
+    big_auto_fields = models.ManyToManyField(BigAutoFieldModel)

+ 285 - 7
tests/bulk_create/tests.py

@@ -1,7 +1,11 @@
 from math import ceil
 from operator import attrgetter
 
-from django.db import IntegrityError, NotSupportedError, connection
+from django.core.exceptions import FieldDoesNotExist
+from django.db import (
+    IntegrityError, NotSupportedError, OperationalError, ProgrammingError,
+    connection,
+)
 from django.db.models import FileField, Value
 from django.db.models.functions import Lower
 from django.test import (
@@ -11,7 +15,8 @@ from django.test import (
 from .models import (
     BigAutoFieldModel, Country, NoFields, NullableFields, Pizzeria,
     ProxyCountry, ProxyMultiCountry, ProxyMultiProxyCountry, ProxyProxyCountry,
-    Restaurant, SmallAutoFieldModel, State, TwoFields,
+    RelatedModel, Restaurant, SmallAutoFieldModel, State, TwoFields,
+    UpsertConflict,
 )
 
 
@@ -53,10 +58,10 @@ class BulkCreateTests(TestCase):
     @skipUnlessDBFeature('has_bulk_insert')
     def test_long_and_short_text(self):
         Country.objects.bulk_create([
-            Country(description='a' * 4001),
-            Country(description='a'),
-            Country(description='Ж' * 2001),
-            Country(description='Ж'),
+            Country(description='a' * 4001, iso_two_letter='A'),
+            Country(description='a', iso_two_letter='B'),
+            Country(description='Ж' * 2001, iso_two_letter='C'),
+            Country(description='Ж', iso_two_letter='D'),
         ])
         self.assertEqual(Country.objects.count(), 4)
 
@@ -218,7 +223,7 @@ class BulkCreateTests(TestCase):
 
     @skipUnlessDBFeature('has_bulk_insert')
     def test_explicit_batch_size_respects_max_batch_size(self):
-        objs = [Country() for i in range(1000)]
+        objs = [Country(name=f'Country {i}') for i in range(1000)]
         fields = ['name', 'iso_two_letter', 'description']
         max_batch_size = max(connection.ops.bulk_batch_size(fields, objs), 1)
         with self.assertNumQueries(ceil(len(objs) / max_batch_size)):
@@ -352,3 +357,276 @@ class BulkCreateTests(TestCase):
         msg = 'Batch size must be a positive integer.'
         with self.assertRaisesMessage(ValueError, msg):
             Country.objects.bulk_create([], batch_size=-1)
+
+    @skipIfDBFeature('supports_update_conflicts')
+    def test_update_conflicts_unsupported(self):
+        msg = 'This database backend does not support updating conflicts.'
+        with self.assertRaisesMessage(NotSupportedError, msg):
+            Country.objects.bulk_create(self.data, update_conflicts=True)
+
+    @skipUnlessDBFeature('supports_ignore_conflicts', 'supports_update_conflicts')
+    def test_ignore_update_conflicts_exclusive(self):
+        msg = 'ignore_conflicts and update_conflicts are mutually exclusive'
+        with self.assertRaisesMessage(ValueError, msg):
+            Country.objects.bulk_create(
+                self.data,
+                ignore_conflicts=True,
+                update_conflicts=True,
+            )
+
+    @skipUnlessDBFeature('supports_update_conflicts')
+    def test_update_conflicts_no_update_fields(self):
+        msg = (
+            'Fields that will be updated when a row insertion fails on '
+            'conflicts must be provided.'
+        )
+        with self.assertRaisesMessage(ValueError, msg):
+            Country.objects.bulk_create(self.data, update_conflicts=True)
+
+    @skipUnlessDBFeature('supports_update_conflicts')
+    @skipIfDBFeature('supports_update_conflicts_with_target')
+    def test_update_conflicts_unique_field_unsupported(self):
+        msg = (
+            'This database backend does not support updating conflicts with '
+            'specifying unique fields that can trigger the upsert.'
+        )
+        with self.assertRaisesMessage(NotSupportedError, msg):
+            TwoFields.objects.bulk_create(
+                [TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)],
+                update_conflicts=True,
+                update_fields=['f2'],
+                unique_fields=['f1'],
+            )
+
+    @skipUnlessDBFeature('supports_update_conflicts')
+    def test_update_conflicts_nonexistent_update_fields(self):
+        unique_fields = None
+        if connection.features.supports_update_conflicts_with_target:
+            unique_fields = ['f1']
+        msg = "TwoFields has no field named 'nonexistent'"
+        with self.assertRaisesMessage(FieldDoesNotExist, msg):
+            TwoFields.objects.bulk_create(
+                [TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)],
+                update_conflicts=True,
+                update_fields=['nonexistent'],
+                unique_fields=unique_fields,
+            )
+
+    @skipUnlessDBFeature(
+        'supports_update_conflicts', 'supports_update_conflicts_with_target',
+    )
+    def test_update_conflicts_unique_fields_required(self):
+        msg = 'Unique fields that can trigger the upsert must be provided.'
+        with self.assertRaisesMessage(ValueError, msg):
+            TwoFields.objects.bulk_create(
+                [TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)],
+                update_conflicts=True,
+                update_fields=['f1'],
+            )
+
+    @skipUnlessDBFeature(
+        'supports_update_conflicts', 'supports_update_conflicts_with_target',
+    )
+    def test_update_conflicts_invalid_update_fields(self):
+        msg = (
+            'bulk_create() can only be used with concrete fields in '
+            'update_fields.'
+        )
+        # Reverse one-to-one relationship.
+        with self.assertRaisesMessage(ValueError, msg):
+            Country.objects.bulk_create(
+                self.data,
+                update_conflicts=True,
+                update_fields=['relatedmodel'],
+                unique_fields=['pk'],
+            )
+        # Many-to-many relationship.
+        with self.assertRaisesMessage(ValueError, msg):
+            RelatedModel.objects.bulk_create(
+                [RelatedModel(country=self.data[0])],
+                update_conflicts=True,
+                update_fields=['big_auto_fields'],
+                unique_fields=['country'],
+            )
+
+    @skipUnlessDBFeature(
+        'supports_update_conflicts', 'supports_update_conflicts_with_target',
+    )
+    def test_update_conflicts_pk_in_update_fields(self):
+        msg = 'bulk_create() cannot be used with primary keys in update_fields.'
+        with self.assertRaisesMessage(ValueError, msg):
+            BigAutoFieldModel.objects.bulk_create(
+                [BigAutoFieldModel()],
+                update_conflicts=True,
+                update_fields=['id'],
+                unique_fields=['id'],
+            )
+
+    @skipUnlessDBFeature(
+        'supports_update_conflicts', 'supports_update_conflicts_with_target',
+    )
+    def test_update_conflicts_invalid_unique_fields(self):
+        msg = (
+            'bulk_create() can only be used with concrete fields in '
+            'unique_fields.'
+        )
+        # Reverse one-to-one relationship.
+        with self.assertRaisesMessage(ValueError, msg):
+            Country.objects.bulk_create(
+                self.data,
+                update_conflicts=True,
+                update_fields=['name'],
+                unique_fields=['relatedmodel'],
+            )
+        # Many-to-many relationship.
+        with self.assertRaisesMessage(ValueError, msg):
+            RelatedModel.objects.bulk_create(
+                [RelatedModel(country=self.data[0])],
+                update_conflicts=True,
+                update_fields=['name'],
+                unique_fields=['big_auto_fields'],
+            )
+
+    def _test_update_conflicts_two_fields(self, unique_fields):
+        TwoFields.objects.bulk_create([
+            TwoFields(f1=1, f2=1, name='a'),
+            TwoFields(f1=2, f2=2, name='b'),
+        ])
+        self.assertEqual(TwoFields.objects.count(), 2)
+
+        conflicting_objects = [
+            TwoFields(f1=1, f2=1, name='c'),
+            TwoFields(f1=2, f2=2, name='d'),
+        ]
+        TwoFields.objects.bulk_create(
+            conflicting_objects,
+            update_conflicts=True,
+            unique_fields=unique_fields,
+            update_fields=['name'],
+        )
+        self.assertEqual(TwoFields.objects.count(), 2)
+        self.assertCountEqual(TwoFields.objects.values('f1', 'f2', 'name'), [
+            {'f1': 1, 'f2': 1, 'name': 'c'},
+            {'f1': 2, 'f2': 2, 'name': 'd'},
+        ])
+
+    @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
+    def test_update_conflicts_two_fields_unique_fields_first(self):
+        self._test_update_conflicts_two_fields(['f1'])
+
+    @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
+    def test_update_conflicts_two_fields_unique_fields_second(self):
+        self._test_update_conflicts_two_fields(['f2'])
+
+    @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
+    def test_update_conflicts_two_fields_unique_fields_both(self):
+        with self.assertRaises((OperationalError, ProgrammingError)):
+            self._test_update_conflicts_two_fields(['f1', 'f2'])
+
+    @skipUnlessDBFeature('supports_update_conflicts')
+    @skipIfDBFeature('supports_update_conflicts_with_target')
+    def test_update_conflicts_two_fields_no_unique_fields(self):
+        self._test_update_conflicts_two_fields([])
+
+    def _test_update_conflicts_unique_two_fields(self, unique_fields):
+        Country.objects.bulk_create(self.data)
+        self.assertEqual(Country.objects.count(), 4)
+
+        new_data = [
+            # Conflicting countries.
+            Country(name='Germany', iso_two_letter='DE', description=(
+                'Germany is a country in Central Europe.'
+            )),
+            Country(name='Czech Republic', iso_two_letter='CZ', description=(
+                'The Czech Republic is a landlocked country in Central Europe.'
+            )),
+            # New countries.
+            Country(name='Australia', iso_two_letter='AU'),
+            Country(name='Japan', iso_two_letter='JP', description=(
+                'Japan is an island country in East Asia.'
+            )),
+        ]
+        Country.objects.bulk_create(
+            new_data,
+            update_conflicts=True,
+            update_fields=['description'],
+            unique_fields=unique_fields,
+        )
+        self.assertEqual(Country.objects.count(), 6)
+        self.assertCountEqual(Country.objects.values('iso_two_letter', 'description'), [
+            {'iso_two_letter': 'US', 'description': ''},
+            {'iso_two_letter': 'NL', 'description': ''},
+            {'iso_two_letter': 'DE', 'description': (
+                'Germany is a country in Central Europe.'
+            )},
+            {'iso_two_letter': 'CZ', 'description': (
+                'The Czech Republic is a landlocked country in Central Europe.'
+            )},
+            {'iso_two_letter': 'AU', 'description': ''},
+            {'iso_two_letter': 'JP', 'description': (
+                'Japan is an island country in East Asia.'
+            )},
+        ])
+
+    @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
+    def test_update_conflicts_unique_two_fields_unique_fields_both(self):
+        self._test_update_conflicts_unique_two_fields(['iso_two_letter', 'name'])
+
+    @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
+    def test_update_conflicts_unique_two_fields_unique_fields_one(self):
+        with self.assertRaises((OperationalError, ProgrammingError)):
+            self._test_update_conflicts_unique_two_fields(['iso_two_letter'])
+
+    @skipUnlessDBFeature('supports_update_conflicts')
+    @skipIfDBFeature('supports_update_conflicts_with_target')
+    def test_update_conflicts_unique_two_fields_unique_no_unique_fields(self):
+        self._test_update_conflicts_unique_two_fields([])
+
+    def _test_update_conflicts(self, unique_fields):
+        UpsertConflict.objects.bulk_create([
+            UpsertConflict(number=1, rank=1, name='John'),
+            UpsertConflict(number=2, rank=2, name='Mary'),
+            UpsertConflict(number=3, rank=3, name='Hannah'),
+        ])
+        self.assertEqual(UpsertConflict.objects.count(), 3)
+
+        conflicting_objects = [
+            UpsertConflict(number=1, rank=4, name='Steve'),
+            UpsertConflict(number=2, rank=2, name='Olivia'),
+            UpsertConflict(number=3, rank=1, name='Hannah'),
+        ]
+        UpsertConflict.objects.bulk_create(
+            conflicting_objects,
+            update_conflicts=True,
+            update_fields=['name', 'rank'],
+            unique_fields=unique_fields,
+        )
+        self.assertEqual(UpsertConflict.objects.count(), 3)
+        self.assertCountEqual(UpsertConflict.objects.values('number', 'rank', 'name'), [
+            {'number': 1, 'rank': 4, 'name': 'Steve'},
+            {'number': 2, 'rank': 2, 'name': 'Olivia'},
+            {'number': 3, 'rank': 1, 'name': 'Hannah'},
+        ])
+
+        UpsertConflict.objects.bulk_create(
+            conflicting_objects + [UpsertConflict(number=4, rank=4, name='Mark')],
+            update_conflicts=True,
+            update_fields=['name', 'rank'],
+            unique_fields=unique_fields,
+        )
+        self.assertEqual(UpsertConflict.objects.count(), 4)
+        self.assertCountEqual(UpsertConflict.objects.values('number', 'rank', 'name'), [
+            {'number': 1, 'rank': 4, 'name': 'Steve'},
+            {'number': 2, 'rank': 2, 'name': 'Olivia'},
+            {'number': 3, 'rank': 1, 'name': 'Hannah'},
+            {'number': 4, 'rank': 4, 'name': 'Mark'},
+        ])
+
+    @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
+    def test_update_conflicts_unique_fields(self):
+        self._test_update_conflicts(unique_fields=['number'])
+
+    @skipUnlessDBFeature('supports_update_conflicts')
+    @skipIfDBFeature('supports_update_conflicts_with_target')
+    def test_update_conflicts_no_unique_fields(self):
+        self._test_update_conflicts([])