Browse Source

Fixed #31685 -- Added support for updating conflicts to QuerySet.bulk_create().

Thanks Florian Apolloner, Chris Jerdonek, Hannes Ljungberg, Nick Pope,
and Mariusz Felisiak for reviews.
sean_c_hsu 4 years ago
parent
commit
0f6946495a

+ 4 - 0
django/db/backends/base/features.py

@@ -271,6 +271,10 @@ class BaseDatabaseFeatures:
     # Does the backend support ignoring constraint or uniqueness errors during
     # Does the backend support ignoring constraint or uniqueness errors during
     # INSERT?
     # INSERT?
     supports_ignore_conflicts = True
     supports_ignore_conflicts = True
+    # Does the backend support updating rows on constraint or uniqueness errors
+    # during INSERT?
+    supports_update_conflicts = False
+    supports_update_conflicts_with_target = False
 
 
     # Does this backend require casting the results of CASE expressions used
     # Does this backend require casting the results of CASE expressions used
     # in UPDATE statements to ensure the expression has the correct type?
     # in UPDATE statements to ensure the expression has the correct type?

+ 2 - 2
django/db/backends/base/operations.py

@@ -717,8 +717,8 @@ class BaseDatabaseOperations:
             raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys())))
             raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys())))
         return self.explain_prefix
         return self.explain_prefix
 
 
-    def insert_statement(self, ignore_conflicts=False):
+    def insert_statement(self, on_conflict=None):
         return 'INSERT INTO'
         return 'INSERT INTO'
 
 
-    def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
+    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
         return ''
         return ''

+ 1 - 0
django/db/backends/mysql/features.py

@@ -24,6 +24,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     supports_select_difference = False
     supports_select_difference = False
     supports_slicing_ordering_in_compound = True
     supports_slicing_ordering_in_compound = True
     supports_index_on_text_field = False
     supports_index_on_text_field = False
+    supports_update_conflicts = True
     create_test_procedure_without_params_sql = """
     create_test_procedure_without_params_sql = """
         CREATE PROCEDURE test_procedure ()
         CREATE PROCEDURE test_procedure ()
         BEGIN
         BEGIN

+ 29 - 2
django/db/backends/mysql/operations.py

@@ -4,6 +4,7 @@ from django.conf import settings
 from django.db.backends.base.operations import BaseDatabaseOperations
 from django.db.backends.base.operations import BaseDatabaseOperations
 from django.db.backends.utils import split_tzname_delta
 from django.db.backends.utils import split_tzname_delta
 from django.db.models import Exists, ExpressionWrapper, Lookup
 from django.db.models import Exists, ExpressionWrapper, Lookup
+from django.db.models.constants import OnConflict
 from django.utils import timezone
 from django.utils import timezone
 from django.utils.encoding import force_str
 from django.utils.encoding import force_str
 
 
@@ -365,8 +366,10 @@ class DatabaseOperations(BaseDatabaseOperations):
         match_option = 'c' if lookup_type == 'regex' else 'i'
         match_option = 'c' if lookup_type == 'regex' else 'i'
         return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
         return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
 
 
-    def insert_statement(self, ignore_conflicts=False):
+    def insert_statement(self, on_conflict=None):
-        return 'INSERT IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
+        if on_conflict == OnConflict.IGNORE:
+            return 'INSERT IGNORE INTO'
+        return super().insert_statement(on_conflict=on_conflict)
 
 
     def lookup_cast(self, lookup_type, internal_type=None):
     def lookup_cast(self, lookup_type, internal_type=None):
         lookup = '%s'
         lookup = '%s'
@@ -388,3 +391,27 @@ class DatabaseOperations(BaseDatabaseOperations):
         if getattr(expression, 'conditional', False):
         if getattr(expression, 'conditional', False):
             return False
             return False
         return super().conditional_expression_supported_in_where_clause(expression)
         return super().conditional_expression_supported_in_where_clause(expression)
+
+    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
+        if on_conflict == OnConflict.UPDATE:
+            conflict_suffix_sql = 'ON DUPLICATE KEY UPDATE %(fields)s'
+            field_sql = '%(field)s = VALUES(%(field)s)'
+            # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
+            # aliases for the new row and its columns available in MySQL
+            # 8.0.19+.
+            if not self.connection.mysql_is_mariadb:
+                if self.connection.mysql_version >= (8, 0, 19):
+                    conflict_suffix_sql = f'AS new {conflict_suffix_sql}'
+                    field_sql = '%(field)s = new.%(field)s'
+            # VALUES() was renamed to VALUE() in MariaDB 10.3.3+.
+            elif self.connection.mysql_version >= (10, 3, 3):
+                field_sql = '%(field)s = VALUE(%(field)s)'
+
+            fields = ', '.join([
+                field_sql % {'field': field}
+                for field in map(self.quote_name, update_fields)
+            ])
+            return conflict_suffix_sql % {'fields': fields}
+        return super().on_conflict_suffix_sql(
+            fields, on_conflict, update_fields, unique_fields,
+        )

+ 2 - 0
django/db/backends/postgresql/features.py

@@ -57,6 +57,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     supports_deferrable_unique_constraints = True
     supports_deferrable_unique_constraints = True
     has_json_operators = True
     has_json_operators = True
     json_key_contains_list_matching_requires_list = True
     json_key_contains_list_matching_requires_list = True
+    supports_update_conflicts = True
+    supports_update_conflicts_with_target = True
     test_collations = {
     test_collations = {
         'non_default': 'sv-x-icu',
         'non_default': 'sv-x-icu',
         'swedish_ci': 'sv-x-icu',
         'swedish_ci': 'sv-x-icu',

+ 15 - 2
django/db/backends/postgresql/operations.py

@@ -3,6 +3,7 @@ from psycopg2.extras import Inet
 from django.conf import settings
 from django.conf import settings
 from django.db.backends.base.operations import BaseDatabaseOperations
 from django.db.backends.base.operations import BaseDatabaseOperations
 from django.db.backends.utils import split_tzname_delta
 from django.db.backends.utils import split_tzname_delta
+from django.db.models.constants import OnConflict
 
 
 
 
 class DatabaseOperations(BaseDatabaseOperations):
 class DatabaseOperations(BaseDatabaseOperations):
@@ -272,5 +273,17 @@ class DatabaseOperations(BaseDatabaseOperations):
             prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items())
             prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items())
         return prefix
         return prefix
 
 
-    def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
+    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
-        return 'ON CONFLICT DO NOTHING' if ignore_conflicts else super().ignore_conflicts_suffix_sql(ignore_conflicts)
+        if on_conflict == OnConflict.IGNORE:
+            return 'ON CONFLICT DO NOTHING'
+        if on_conflict == OnConflict.UPDATE:
+            return 'ON CONFLICT(%s) DO UPDATE SET %s' % (
+                ', '.join(map(self.quote_name, unique_fields)),
+                ', '.join([
+                    f'{field} = EXCLUDED.{field}'
+                    for field in map(self.quote_name, update_fields)
+                ]),
+            )
+        return super().on_conflict_suffix_sql(
+            fields, on_conflict, update_fields, unique_fields,
+        )

+ 2 - 0
django/db/backends/sqlite3/features.py

@@ -40,6 +40,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0)
     supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0)
     order_by_nulls_first = True
     order_by_nulls_first = True
     supports_json_field_contains = False
     supports_json_field_contains = False
+    supports_update_conflicts = Database.sqlite_version_info >= (3, 24, 0)
+    supports_update_conflicts_with_target = supports_update_conflicts
     test_collations = {
     test_collations = {
         'ci': 'nocase',
         'ci': 'nocase',
         'cs': 'binary',
         'cs': 'binary',

+ 21 - 2
django/db/backends/sqlite3/operations.py

@@ -8,6 +8,7 @@ from django.conf import settings
 from django.core.exceptions import FieldError
 from django.core.exceptions import FieldError
 from django.db import DatabaseError, NotSupportedError, models
 from django.db import DatabaseError, NotSupportedError, models
 from django.db.backends.base.operations import BaseDatabaseOperations
 from django.db.backends.base.operations import BaseDatabaseOperations
+from django.db.models.constants import OnConflict
 from django.db.models.expressions import Col
 from django.db.models.expressions import Col
 from django.utils import timezone
 from django.utils import timezone
 from django.utils.dateparse import parse_date, parse_datetime, parse_time
 from django.utils.dateparse import parse_date, parse_datetime, parse_time
@@ -370,8 +371,10 @@ class DatabaseOperations(BaseDatabaseOperations):
             return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params
             return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params
         return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params
         return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params
 
 
-    def insert_statement(self, ignore_conflicts=False):
+    def insert_statement(self, on_conflict=None):
-        return 'INSERT OR IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
+        if on_conflict == OnConflict.IGNORE:
+            return 'INSERT OR IGNORE INTO'
+        return super().insert_statement(on_conflict=on_conflict)
 
 
     def return_insert_columns(self, fields):
     def return_insert_columns(self, fields):
         # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
         # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
@@ -384,3 +387,19 @@ class DatabaseOperations(BaseDatabaseOperations):
             ) for field in fields
             ) for field in fields
         ]
         ]
         return 'RETURNING %s' % ', '.join(columns), ()
         return 'RETURNING %s' % ', '.join(columns), ()
+
+    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
+        if (
+            on_conflict == OnConflict.UPDATE and
+            self.connection.features.supports_update_conflicts_with_target
+        ):
+            return 'ON CONFLICT(%s) DO UPDATE SET %s' % (
+                ', '.join(map(self.quote_name, unique_fields)),
+                ', '.join([
+                    f'{field} = EXCLUDED.{field}'
+                    for field in map(self.quote_name, update_fields)
+                ]),
+            )
+        return super().on_conflict_suffix_sql(
+            fields, on_conflict, update_fields, unique_fields,
+        )

+ 6 - 0
django/db/models/constants.py

@@ -1,6 +1,12 @@
 """
 """
 Constants used across the ORM in general.
 Constants used across the ORM in general.
 """
 """
+from enum import Enum
 
 
 # Separator used to split filter strings apart.
 # Separator used to split filter strings apart.
 LOOKUP_SEP = '__'
 LOOKUP_SEP = '__'
+
+
+class OnConflict(Enum):
+    IGNORE = 'ignore'
+    UPDATE = 'update'

+ 106 - 13
django/db/models/query.py

@@ -15,7 +15,7 @@ from django.db import (
     router, transaction,
     router, transaction,
 )
 )
 from django.db.models import AutoField, DateField, DateTimeField, sql
 from django.db.models import AutoField, DateField, DateTimeField, sql
-from django.db.models.constants import LOOKUP_SEP
+from django.db.models.constants import LOOKUP_SEP, OnConflict
 from django.db.models.deletion import Collector
 from django.db.models.deletion import Collector
 from django.db.models.expressions import Case, Expression, F, Ref, Value, When
 from django.db.models.expressions import Case, Expression, F, Ref, Value, When
 from django.db.models.functions import Cast, Trunc
 from django.db.models.functions import Cast, Trunc
@@ -466,7 +466,69 @@ class QuerySet:
                 obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
                 obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
             obj._prepare_related_fields_for_save(operation_name='bulk_create')
             obj._prepare_related_fields_for_save(operation_name='bulk_create')
 
 
-    def bulk_create(self, objs, batch_size=None, ignore_conflicts=False):
+    def _check_bulk_create_options(self, ignore_conflicts, update_conflicts, update_fields, unique_fields):
+        if ignore_conflicts and update_conflicts:
+            raise ValueError(
+                'ignore_conflicts and update_conflicts are mutually exclusive.'
+            )
+        db_features = connections[self.db].features
+        if ignore_conflicts:
+            if not db_features.supports_ignore_conflicts:
+                raise NotSupportedError(
+                    'This database backend does not support ignoring conflicts.'
+                )
+            return OnConflict.IGNORE
+        elif update_conflicts:
+            if not db_features.supports_update_conflicts:
+                raise NotSupportedError(
+                    'This database backend does not support updating conflicts.'
+                )
+            if not update_fields:
+                raise ValueError(
+                    'Fields that will be updated when a row insertion fails '
+                    'on conflicts must be provided.'
+                )
+            if unique_fields and not db_features.supports_update_conflicts_with_target:
+                raise NotSupportedError(
+                    'This database backend does not support updating '
+                    'conflicts with specifying unique fields that can trigger '
+                    'the upsert.'
+                )
+            if not unique_fields and db_features.supports_update_conflicts_with_target:
+                raise ValueError(
+                    'Unique fields that can trigger the upsert must be '
+                    'provided.'
+                )
+            # Updating primary keys and non-concrete fields is forbidden.
+            update_fields = [self.model._meta.get_field(name) for name in update_fields]
+            if any(not f.concrete or f.many_to_many for f in update_fields):
+                raise ValueError(
+                    'bulk_create() can only be used with concrete fields in '
+                    'update_fields.'
+                )
+            if any(f.primary_key for f in update_fields):
+                raise ValueError(
+                    'bulk_create() cannot be used with primary keys in '
+                    'update_fields.'
+                )
+            if unique_fields:
+                # Primary key is allowed in unique_fields.
+                unique_fields = [
+                    self.model._meta.get_field(name)
+                    for name in unique_fields if name != 'pk'
+                ]
+                if any(not f.concrete or f.many_to_many for f in unique_fields):
+                    raise ValueError(
+                        'bulk_create() can only be used with concrete fields '
+                        'in unique_fields.'
+                    )
+            return OnConflict.UPDATE
+        return None
+
+    def bulk_create(
+        self, objs, batch_size=None, ignore_conflicts=False,
+        update_conflicts=False, update_fields=None, unique_fields=None,
+    ):
         """
         """
         Insert each of the instances into the database. Do *not* call
         Insert each of the instances into the database. Do *not* call
         save() on each of the instances, do not send any pre/post_save
         save() on each of the instances, do not send any pre/post_save
@@ -497,6 +559,12 @@ class QuerySet:
                 raise ValueError("Can't bulk create a multi-table inherited model")
                 raise ValueError("Can't bulk create a multi-table inherited model")
         if not objs:
         if not objs:
             return objs
             return objs
+        on_conflict = self._check_bulk_create_options(
+            ignore_conflicts,
+            update_conflicts,
+            update_fields,
+            unique_fields,
+        )
         self._for_write = True
         self._for_write = True
         opts = self.model._meta
         opts = self.model._meta
         fields = opts.concrete_fields
         fields = opts.concrete_fields
@@ -506,7 +574,12 @@ class QuerySet:
             objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
             objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
             if objs_with_pk:
             if objs_with_pk:
                 returned_columns = self._batched_insert(
                 returned_columns = self._batched_insert(
-                    objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
+                    objs_with_pk,
+                    fields,
+                    batch_size,
+                    on_conflict=on_conflict,
+                    update_fields=update_fields,
+                    unique_fields=unique_fields,
                 )
                 )
                 for obj_with_pk, results in zip(objs_with_pk, returned_columns):
                 for obj_with_pk, results in zip(objs_with_pk, returned_columns):
                     for result, field in zip(results, opts.db_returning_fields):
                     for result, field in zip(results, opts.db_returning_fields):
@@ -518,10 +591,15 @@ class QuerySet:
             if objs_without_pk:
             if objs_without_pk:
                 fields = [f for f in fields if not isinstance(f, AutoField)]
                 fields = [f for f in fields if not isinstance(f, AutoField)]
                 returned_columns = self._batched_insert(
                 returned_columns = self._batched_insert(
-                    objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
+                    objs_without_pk,
+                    fields,
+                    batch_size,
+                    on_conflict=on_conflict,
+                    update_fields=update_fields,
+                    unique_fields=unique_fields,
                 )
                 )
                 connection = connections[self.db]
                 connection = connections[self.db]
-                if connection.features.can_return_rows_from_bulk_insert and not ignore_conflicts:
+                if connection.features.can_return_rows_from_bulk_insert and on_conflict is None:
                     assert len(returned_columns) == len(objs_without_pk)
                     assert len(returned_columns) == len(objs_without_pk)
                 for obj_without_pk, results in zip(objs_without_pk, returned_columns):
                 for obj_without_pk, results in zip(objs_without_pk, returned_columns):
                     for result, field in zip(results, opts.db_returning_fields):
                     for result, field in zip(results, opts.db_returning_fields):
@@ -1293,7 +1371,10 @@ class QuerySet:
     # PRIVATE METHODS #
     # PRIVATE METHODS #
     ###################
     ###################
 
 
-    def _insert(self, objs, fields, returning_fields=None, raw=False, using=None, ignore_conflicts=False):
+    def _insert(
+        self, objs, fields, returning_fields=None, raw=False, using=None,
+        on_conflict=None, update_fields=None, unique_fields=None,
+    ):
         """
         """
         Insert a new record for the given model. This provides an interface to
         Insert a new record for the given model. This provides an interface to
         the InsertQuery class and is how Model.save() is implemented.
         the InsertQuery class and is how Model.save() is implemented.
@@ -1301,33 +1382,45 @@ class QuerySet:
         self._for_write = True
         self._for_write = True
         if using is None:
         if using is None:
             using = self.db
             using = self.db
-        query = sql.InsertQuery(self.model, ignore_conflicts=ignore_conflicts)
+        query = sql.InsertQuery(
+            self.model,
+            on_conflict=on_conflict,
+            update_fields=update_fields,
+            unique_fields=unique_fields,
+        )
         query.insert_values(fields, objs, raw=raw)
         query.insert_values(fields, objs, raw=raw)
         return query.get_compiler(using=using).execute_sql(returning_fields)
         return query.get_compiler(using=using).execute_sql(returning_fields)
     _insert.alters_data = True
     _insert.alters_data = True
     _insert.queryset_only = False
     _insert.queryset_only = False
 
 
-    def _batched_insert(self, objs, fields, batch_size, ignore_conflicts=False):
+    def _batched_insert(
+        self, objs, fields, batch_size, on_conflict=None, update_fields=None,
+        unique_fields=None,
+    ):
         """
         """
         Helper method for bulk_create() to insert objs one batch at a time.
         Helper method for bulk_create() to insert objs one batch at a time.
         """
         """
         connection = connections[self.db]
         connection = connections[self.db]
-        if ignore_conflicts and not connection.features.supports_ignore_conflicts:
-            raise NotSupportedError('This database backend does not support ignoring conflicts.')
         ops = connection.ops
         ops = connection.ops
         max_batch_size = max(ops.bulk_batch_size(fields, objs), 1)
         max_batch_size = max(ops.bulk_batch_size(fields, objs), 1)
         batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
         batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
         inserted_rows = []
         inserted_rows = []
         bulk_return = connection.features.can_return_rows_from_bulk_insert
         bulk_return = connection.features.can_return_rows_from_bulk_insert
         for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
         for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
-            if bulk_return and not ignore_conflicts:
+            if bulk_return and on_conflict is None:
                 inserted_rows.extend(self._insert(
                 inserted_rows.extend(self._insert(
                     item, fields=fields, using=self.db,
                     item, fields=fields, using=self.db,
                     returning_fields=self.model._meta.db_returning_fields,
                     returning_fields=self.model._meta.db_returning_fields,
-                    ignore_conflicts=ignore_conflicts,
                 ))
                 ))
             else:
             else:
-                self._insert(item, fields=fields, using=self.db, ignore_conflicts=ignore_conflicts)
+                self._insert(
+                    item,
+                    fields=fields,
+                    using=self.db,
+                    on_conflict=on_conflict,
+                    update_fields=update_fields,
+                    unique_fields=unique_fields,
+                )
         return inserted_rows
         return inserted_rows
 
 
     def _chain(self):
     def _chain(self):

+ 14 - 9
django/db/models/sql/compiler.py

@@ -1387,7 +1387,9 @@ class SQLInsertCompiler(SQLCompiler):
         # going to be column names (so we can avoid the extra overhead).
         # going to be column names (so we can avoid the extra overhead).
         qn = self.connection.ops.quote_name
         qn = self.connection.ops.quote_name
         opts = self.query.get_meta()
         opts = self.query.get_meta()
-        insert_statement = self.connection.ops.insert_statement(ignore_conflicts=self.query.ignore_conflicts)
+        insert_statement = self.connection.ops.insert_statement(
+            on_conflict=self.query.on_conflict,
+        )
         result = ['%s %s' % (insert_statement, qn(opts.db_table))]
         result = ['%s %s' % (insert_statement, qn(opts.db_table))]
         fields = self.query.fields or [opts.pk]
         fields = self.query.fields or [opts.pk]
         result.append('(%s)' % ', '.join(qn(f.column) for f in fields))
         result.append('(%s)' % ', '.join(qn(f.column) for f in fields))
@@ -1410,8 +1412,11 @@ class SQLInsertCompiler(SQLCompiler):
 
 
         placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
         placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
 
 
-        ignore_conflicts_suffix_sql = self.connection.ops.ignore_conflicts_suffix_sql(
+        on_conflict_suffix_sql = self.connection.ops.on_conflict_suffix_sql(
-            ignore_conflicts=self.query.ignore_conflicts
+            fields,
+            self.query.on_conflict,
+            self.query.update_fields,
+            self.query.unique_fields,
         )
         )
         if self.returning_fields and self.connection.features.can_return_columns_from_insert:
         if self.returning_fields and self.connection.features.can_return_columns_from_insert:
             if self.connection.features.can_return_rows_from_bulk_insert:
             if self.connection.features.can_return_rows_from_bulk_insert:
@@ -1420,8 +1425,8 @@ class SQLInsertCompiler(SQLCompiler):
             else:
             else:
                 result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
                 result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
                 params = [param_rows[0]]
                 params = [param_rows[0]]
-            if ignore_conflicts_suffix_sql:
+            if on_conflict_suffix_sql:
-                result.append(ignore_conflicts_suffix_sql)
+                result.append(on_conflict_suffix_sql)
             # Skip empty r_sql to allow subclasses to customize behavior for
             # Skip empty r_sql to allow subclasses to customize behavior for
             # 3rd party backends. Refs #19096.
             # 3rd party backends. Refs #19096.
             r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.returning_fields)
             r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.returning_fields)
@@ -1432,12 +1437,12 @@ class SQLInsertCompiler(SQLCompiler):
 
 
         if can_bulk:
         if can_bulk:
             result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
             result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
-            if ignore_conflicts_suffix_sql:
+            if on_conflict_suffix_sql:
-                result.append(ignore_conflicts_suffix_sql)
+                result.append(on_conflict_suffix_sql)
             return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]
             return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]
         else:
         else:
-            if ignore_conflicts_suffix_sql:
+            if on_conflict_suffix_sql:
-                result.append(ignore_conflicts_suffix_sql)
+                result.append(on_conflict_suffix_sql)
             return [
             return [
                 (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
                 (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
                 for p, vals in zip(placeholder_rows, param_rows)
                 for p, vals in zip(placeholder_rows, param_rows)

+ 4 - 2
django/db/models/sql/subqueries.py

@@ -138,11 +138,13 @@ class UpdateQuery(Query):
 class InsertQuery(Query):
 class InsertQuery(Query):
     compiler = 'SQLInsertCompiler'
     compiler = 'SQLInsertCompiler'
 
 
-    def __init__(self, *args, ignore_conflicts=False, **kwargs):
+    def __init__(self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs):
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
         self.fields = []
         self.fields = []
         self.objs = []
         self.objs = []
-        self.ignore_conflicts = ignore_conflicts
+        self.on_conflict = on_conflict
+        self.update_fields = update_fields or []
+        self.unique_fields = unique_fields or []
 
 
     def insert_values(self, fields, objs, raw=False):
     def insert_values(self, fields, objs, raw=False):
         self.fields = fields
         self.fields = fields

+ 18 - 4
docs/ref/models/querysets.txt

@@ -2155,7 +2155,7 @@ exists in the database, an :exc:`~django.db.IntegrityError` is raised.
 ``bulk_create()``
 ``bulk_create()``
 ~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~
 
 
-.. method:: bulk_create(objs, batch_size=None, ignore_conflicts=False)
+.. method:: bulk_create(objs, batch_size=None, ignore_conflicts=False, update_conflicts=False, update_fields=None, unique_fields=None)
 
 
 This method inserts the provided list of objects into the database in an
 This method inserts the provided list of objects into the database in an
 efficient manner (generally only 1 query, no matter how many objects there
 efficient manner (generally only 1 query, no matter how many objects there
@@ -2198,9 +2198,17 @@ where the default is such that at most 999 variables per query are used.
 
 
 On databases that support it (all but Oracle), setting the ``ignore_conflicts``
 On databases that support it (all but Oracle), setting the ``ignore_conflicts``
 parameter to ``True`` tells the database to ignore failure to insert any rows
 parameter to ``True`` tells the database to ignore failure to insert any rows
-that fail constraints such as duplicate unique values. Enabling this parameter
+that fail constraints such as duplicate unique values.
-disables setting the primary key on each model instance (if the database
+
-normally supports it).
+On databases that support it (all except Oracle and SQLite < 3.24), setting the
+``update_conflicts`` parameter to ``True``, tells the database to update
+``update_fields`` when a row insertion fails on conflicts. On PostgreSQL and
+SQLite, in addition to ``update_fields``, a list of ``unique_fields`` that may
+be in conflict must be provided.
+
+Enabling the ``ignore_conflicts`` or ``update_conflicts`` parameter disable
+setting the primary key on each model instance (if the database normally
+support it).
 
 
 .. warning::
 .. warning::
 
 
@@ -2217,6 +2225,12 @@ normally supports it).
 
 
     Support for the fetching primary key attributes on SQLite 3.35+ was added.
     Support for the fetching primary key attributes on SQLite 3.35+ was added.
 
 
+.. versionchanged:: 4.1
+
+    The ``update_conflicts``, ``update_fields``, and ``unique_fields``
+    parameters were added to support updating fields when a row insertion fails
+    on conflict.
+
 ``bulk_update()``
 ``bulk_update()``
 ~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~
 
 

+ 12 - 0
docs/releases/4.1.txt

@@ -232,6 +232,10 @@ Models
   in order to reduce the number of failed requests, e.g. after database server
   in order to reduce the number of failed requests, e.g. after database server
   restart.
   restart.
 
 
+* :meth:`.QuerySet.bulk_create` now supports updating fields when a row
+  insertion fails uniqueness constraints. This is supported on MariaDB, MySQL,
+  PostgreSQL, and SQLite 3.24+.
+
 Requests and Responses
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~~~~~~
 
 
@@ -298,6 +302,14 @@ backends.
 * ``DatabaseIntrospection.get_key_columns()`` is removed. Use
 * ``DatabaseIntrospection.get_key_columns()`` is removed. Use
   ``DatabaseIntrospection.get_relations()`` instead.
   ``DatabaseIntrospection.get_relations()`` instead.
 
 
+* ``DatabaseOperations.ignore_conflicts_suffix_sql()`` method is replaced by
+  ``DatabaseOperations.on_conflict_suffix_sql()`` that accepts the ``fields``,
+  ``on_conflict``, ``update_fields``, and ``unique_fields`` arguments.
+
+* The ``ignore_conflicts`` argument of the
+  ``DatabaseOperations.insert_statement()`` method is replaced by
+  ``on_conflict`` that accepts ``django.db.models.constants.OnConflict``.
+
 Dropped support for MariaDB 10.2
 Dropped support for MariaDB 10.2
 --------------------------------
 --------------------------------
 
 

+ 21 - 0
tests/bulk_create/models.py

@@ -16,6 +16,14 @@ class Country(models.Model):
     iso_two_letter = models.CharField(max_length=2)
     iso_two_letter = models.CharField(max_length=2)
     description = models.TextField()
     description = models.TextField()
 
 
+    class Meta:
+        constraints = [
+            models.UniqueConstraint(
+                fields=['iso_two_letter', 'name'],
+                name='country_name_iso_unique',
+            ),
+        ]
+
 
 
 class ProxyCountry(Country):
 class ProxyCountry(Country):
     class Meta:
     class Meta:
@@ -58,6 +66,13 @@ class State(models.Model):
 class TwoFields(models.Model):
 class TwoFields(models.Model):
     f1 = models.IntegerField(unique=True)
     f1 = models.IntegerField(unique=True)
     f2 = models.IntegerField(unique=True)
     f2 = models.IntegerField(unique=True)
+    name = models.CharField(max_length=15, null=True)
+
+
+class UpsertConflict(models.Model):
+    number = models.IntegerField(unique=True)
+    rank = models.IntegerField()
+    name = models.CharField(max_length=15)
 
 
 
 
 class NoFields(models.Model):
 class NoFields(models.Model):
@@ -103,3 +118,9 @@ class NullableFields(models.Model):
     text_field = models.TextField(null=True, default='text')
     text_field = models.TextField(null=True, default='text')
     url_field = models.URLField(null=True, default='/')
     url_field = models.URLField(null=True, default='/')
     uuid_field = models.UUIDField(null=True, default=uuid.uuid4)
     uuid_field = models.UUIDField(null=True, default=uuid.uuid4)
+
+
+class RelatedModel(models.Model):
+    name = models.CharField(max_length=15, null=True)
+    country = models.OneToOneField(Country, models.CASCADE, primary_key=True)
+    big_auto_fields = models.ManyToManyField(BigAutoFieldModel)

+ 285 - 7
tests/bulk_create/tests.py

@@ -1,7 +1,11 @@
 from math import ceil
 from math import ceil
 from operator import attrgetter
 from operator import attrgetter
 
 
-from django.db import IntegrityError, NotSupportedError, connection
+from django.core.exceptions import FieldDoesNotExist
+from django.db import (
+    IntegrityError, NotSupportedError, OperationalError, ProgrammingError,
+    connection,
+)
 from django.db.models import FileField, Value
 from django.db.models import FileField, Value
 from django.db.models.functions import Lower
 from django.db.models.functions import Lower
 from django.test import (
 from django.test import (
@@ -11,7 +15,8 @@ from django.test import (
 from .models import (
 from .models import (
     BigAutoFieldModel, Country, NoFields, NullableFields, Pizzeria,
     BigAutoFieldModel, Country, NoFields, NullableFields, Pizzeria,
     ProxyCountry, ProxyMultiCountry, ProxyMultiProxyCountry, ProxyProxyCountry,
     ProxyCountry, ProxyMultiCountry, ProxyMultiProxyCountry, ProxyProxyCountry,
-    Restaurant, SmallAutoFieldModel, State, TwoFields,
+    RelatedModel, Restaurant, SmallAutoFieldModel, State, TwoFields,
+    UpsertConflict,
 )
 )
 
 
 
 
@@ -53,10 +58,10 @@ class BulkCreateTests(TestCase):
     @skipUnlessDBFeature('has_bulk_insert')
     @skipUnlessDBFeature('has_bulk_insert')
     def test_long_and_short_text(self):
     def test_long_and_short_text(self):
         Country.objects.bulk_create([
         Country.objects.bulk_create([
-            Country(description='a' * 4001),
+            Country(description='a' * 4001, iso_two_letter='A'),
-            Country(description='a'),
+            Country(description='a', iso_two_letter='B'),
-            Country(description='Ж' * 2001),
+            Country(description='Ж' * 2001, iso_two_letter='C'),
-            Country(description='Ж'),
+            Country(description='Ж', iso_two_letter='D'),
         ])
         ])
         self.assertEqual(Country.objects.count(), 4)
         self.assertEqual(Country.objects.count(), 4)
 
 
@@ -218,7 +223,7 @@ class BulkCreateTests(TestCase):
 
 
     @skipUnlessDBFeature('has_bulk_insert')
     @skipUnlessDBFeature('has_bulk_insert')
     def test_explicit_batch_size_respects_max_batch_size(self):
     def test_explicit_batch_size_respects_max_batch_size(self):
-        objs = [Country() for i in range(1000)]
+        objs = [Country(name=f'Country {i}') for i in range(1000)]
         fields = ['name', 'iso_two_letter', 'description']
         fields = ['name', 'iso_two_letter', 'description']
         max_batch_size = max(connection.ops.bulk_batch_size(fields, objs), 1)
         max_batch_size = max(connection.ops.bulk_batch_size(fields, objs), 1)
         with self.assertNumQueries(ceil(len(objs) / max_batch_size)):
         with self.assertNumQueries(ceil(len(objs) / max_batch_size)):
@@ -352,3 +357,276 @@ class BulkCreateTests(TestCase):
         msg = 'Batch size must be a positive integer.'
         msg = 'Batch size must be a positive integer.'
         with self.assertRaisesMessage(ValueError, msg):
         with self.assertRaisesMessage(ValueError, msg):
             Country.objects.bulk_create([], batch_size=-1)
             Country.objects.bulk_create([], batch_size=-1)
+
+    @skipIfDBFeature('supports_update_conflicts')
+    def test_update_conflicts_unsupported(self):
+        msg = 'This database backend does not support updating conflicts.'
+        with self.assertRaisesMessage(NotSupportedError, msg):
+            Country.objects.bulk_create(self.data, update_conflicts=True)
+
+    @skipUnlessDBFeature('supports_ignore_conflicts', 'supports_update_conflicts')
+    def test_ignore_update_conflicts_exclusive(self):
+        msg = 'ignore_conflicts and update_conflicts are mutually exclusive'
+        with self.assertRaisesMessage(ValueError, msg):
+            Country.objects.bulk_create(
+                self.data,
+                ignore_conflicts=True,
+                update_conflicts=True,
+            )
+
+    @skipUnlessDBFeature('supports_update_conflicts')
+    def test_update_conflicts_no_update_fields(self):
+        msg = (
+            'Fields that will be updated when a row insertion fails on '
+            'conflicts must be provided.'
+        )
+        with self.assertRaisesMessage(ValueError, msg):
+            Country.objects.bulk_create(self.data, update_conflicts=True)
+
+    @skipUnlessDBFeature('supports_update_conflicts')
+    @skipIfDBFeature('supports_update_conflicts_with_target')
+    def test_update_conflicts_unique_field_unsupported(self):
+        msg = (
+            'This database backend does not support updating conflicts with '
+            'specifying unique fields that can trigger the upsert.'
+        )
+        with self.assertRaisesMessage(NotSupportedError, msg):
+            TwoFields.objects.bulk_create(
+                [TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)],
+                update_conflicts=True,
+                update_fields=['f2'],
+                unique_fields=['f1'],
+            )
+
+    @skipUnlessDBFeature('supports_update_conflicts')
+    def test_update_conflicts_nonexistent_update_fields(self):
+        unique_fields = None
+        if connection.features.supports_update_conflicts_with_target:
+            unique_fields = ['f1']
+        msg = "TwoFields has no field named 'nonexistent'"
+        with self.assertRaisesMessage(FieldDoesNotExist, msg):
+            TwoFields.objects.bulk_create(
+                [TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)],
+                update_conflicts=True,
+                update_fields=['nonexistent'],
+                unique_fields=unique_fields,
+            )
+
+    @skipUnlessDBFeature(
+        'supports_update_conflicts', 'supports_update_conflicts_with_target',
+    )
+    def test_update_conflicts_unique_fields_required(self):
+        msg = 'Unique fields that can trigger the upsert must be provided.'
+        with self.assertRaisesMessage(ValueError, msg):
+            TwoFields.objects.bulk_create(
+                [TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)],
+                update_conflicts=True,
+                update_fields=['f1'],
+            )
+
+    @skipUnlessDBFeature(
+        'supports_update_conflicts', 'supports_update_conflicts_with_target',
+    )
+    def test_update_conflicts_invalid_update_fields(self):
+        msg = (
+            'bulk_create() can only be used with concrete fields in '
+            'update_fields.'
+        )
+        # Reverse one-to-one relationship.
+        with self.assertRaisesMessage(ValueError, msg):
+            Country.objects.bulk_create(
+                self.data,
+                update_conflicts=True,
+                update_fields=['relatedmodel'],
+                unique_fields=['pk'],
+            )
+        # Many-to-many relationship.
+        with self.assertRaisesMessage(ValueError, msg):
+            RelatedModel.objects.bulk_create(
+                [RelatedModel(country=self.data[0])],
+                update_conflicts=True,
+                update_fields=['big_auto_fields'],
+                unique_fields=['country'],
+            )
+
+    @skipUnlessDBFeature(
+        'supports_update_conflicts', 'supports_update_conflicts_with_target',
+    )
+    def test_update_conflicts_pk_in_update_fields(self):
+        msg = 'bulk_create() cannot be used with primary keys in update_fields.'
+        with self.assertRaisesMessage(ValueError, msg):
+            BigAutoFieldModel.objects.bulk_create(
+                [BigAutoFieldModel()],
+                update_conflicts=True,
+                update_fields=['id'],
+                unique_fields=['id'],
+            )
+
+    @skipUnlessDBFeature(
+        'supports_update_conflicts', 'supports_update_conflicts_with_target',
+    )
+    def test_update_conflicts_invalid_unique_fields(self):
+        msg = (
+            'bulk_create() can only be used with concrete fields in '
+            'unique_fields.'
+        )
+        # Reverse one-to-one relationship.
+        with self.assertRaisesMessage(ValueError, msg):
+            Country.objects.bulk_create(
+                self.data,
+                update_conflicts=True,
+                update_fields=['name'],
+                unique_fields=['relatedmodel'],
+            )
+        # Many-to-many relationship.
+        with self.assertRaisesMessage(ValueError, msg):
+            RelatedModel.objects.bulk_create(
+                [RelatedModel(country=self.data[0])],
+                update_conflicts=True,
+                update_fields=['name'],
+                unique_fields=['big_auto_fields'],
+            )
+
+    def _test_update_conflicts_two_fields(self, unique_fields):
+        TwoFields.objects.bulk_create([
+            TwoFields(f1=1, f2=1, name='a'),
+            TwoFields(f1=2, f2=2, name='b'),
+        ])
+        self.assertEqual(TwoFields.objects.count(), 2)
+
+        conflicting_objects = [
+            TwoFields(f1=1, f2=1, name='c'),
+            TwoFields(f1=2, f2=2, name='d'),
+        ]
+        TwoFields.objects.bulk_create(
+            conflicting_objects,
+            update_conflicts=True,
+            unique_fields=unique_fields,
+            update_fields=['name'],
+        )
+        self.assertEqual(TwoFields.objects.count(), 2)
+        self.assertCountEqual(TwoFields.objects.values('f1', 'f2', 'name'), [
+            {'f1': 1, 'f2': 1, 'name': 'c'},
+            {'f1': 2, 'f2': 2, 'name': 'd'},
+        ])
+
+    @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
+    def test_update_conflicts_two_fields_unique_fields_first(self):
+        self._test_update_conflicts_two_fields(['f1'])
+
+    @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
+    def test_update_conflicts_two_fields_unique_fields_second(self):
+        self._test_update_conflicts_two_fields(['f2'])
+
+    @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
+    def test_update_conflicts_two_fields_unique_fields_both(self):
+        with self.assertRaises((OperationalError, ProgrammingError)):
+            self._test_update_conflicts_two_fields(['f1', 'f2'])
+
+    @skipUnlessDBFeature('supports_update_conflicts')
+    @skipIfDBFeature('supports_update_conflicts_with_target')
+    def test_update_conflicts_two_fields_no_unique_fields(self):
+        self._test_update_conflicts_two_fields([])
+
+    def _test_update_conflicts_unique_two_fields(self, unique_fields):
+        Country.objects.bulk_create(self.data)
+        self.assertEqual(Country.objects.count(), 4)
+
+        new_data = [
+            # Conflicting countries.
+            Country(name='Germany', iso_two_letter='DE', description=(
+                'Germany is a country in Central Europe.'
+            )),
+            Country(name='Czech Republic', iso_two_letter='CZ', description=(
+                'The Czech Republic is a landlocked country in Central Europe.'
+            )),
+            # New countries.
+            Country(name='Australia', iso_two_letter='AU'),
+            Country(name='Japan', iso_two_letter='JP', description=(
+                'Japan is an island country in East Asia.'
+            )),
+        ]
+        Country.objects.bulk_create(
+            new_data,
+            update_conflicts=True,
+            update_fields=['description'],
+            unique_fields=unique_fields,
+        )
+        self.assertEqual(Country.objects.count(), 6)
+        self.assertCountEqual(Country.objects.values('iso_two_letter', 'description'), [
+            {'iso_two_letter': 'US', 'description': ''},
+            {'iso_two_letter': 'NL', 'description': ''},
+            {'iso_two_letter': 'DE', 'description': (
+                'Germany is a country in Central Europe.'
+            )},
+            {'iso_two_letter': 'CZ', 'description': (
+                'The Czech Republic is a landlocked country in Central Europe.'
+            )},
+            {'iso_two_letter': 'AU', 'description': ''},
+            {'iso_two_letter': 'JP', 'description': (
+                'Japan is an island country in East Asia.'
+            )},
+        ])
+
+    @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
+    def test_update_conflicts_unique_two_fields_unique_fields_both(self):
+        self._test_update_conflicts_unique_two_fields(['iso_two_letter', 'name'])
+
+    @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
+    def test_update_conflicts_unique_two_fields_unique_fields_one(self):
+        with self.assertRaises((OperationalError, ProgrammingError)):
+            self._test_update_conflicts_unique_two_fields(['iso_two_letter'])
+
+    @skipUnlessDBFeature('supports_update_conflicts')
+    @skipIfDBFeature('supports_update_conflicts_with_target')
+    def test_update_conflicts_unique_two_fields_unique_no_unique_fields(self):
+        self._test_update_conflicts_unique_two_fields([])
+
+    def _test_update_conflicts(self, unique_fields):
+        UpsertConflict.objects.bulk_create([
+            UpsertConflict(number=1, rank=1, name='John'),
+            UpsertConflict(number=2, rank=2, name='Mary'),
+            UpsertConflict(number=3, rank=3, name='Hannah'),
+        ])
+        self.assertEqual(UpsertConflict.objects.count(), 3)
+
+        conflicting_objects = [
+            UpsertConflict(number=1, rank=4, name='Steve'),
+            UpsertConflict(number=2, rank=2, name='Olivia'),
+            UpsertConflict(number=3, rank=1, name='Hannah'),
+        ]
+        UpsertConflict.objects.bulk_create(
+            conflicting_objects,
+            update_conflicts=True,
+            update_fields=['name', 'rank'],
+            unique_fields=unique_fields,
+        )
+        self.assertEqual(UpsertConflict.objects.count(), 3)
+        self.assertCountEqual(UpsertConflict.objects.values('number', 'rank', 'name'), [
+            {'number': 1, 'rank': 4, 'name': 'Steve'},
+            {'number': 2, 'rank': 2, 'name': 'Olivia'},
+            {'number': 3, 'rank': 1, 'name': 'Hannah'},
+        ])
+
+        UpsertConflict.objects.bulk_create(
+            conflicting_objects + [UpsertConflict(number=4, rank=4, name='Mark')],
+            update_conflicts=True,
+            update_fields=['name', 'rank'],
+            unique_fields=unique_fields,
+        )
+        self.assertEqual(UpsertConflict.objects.count(), 4)
+        self.assertCountEqual(UpsertConflict.objects.values('number', 'rank', 'name'), [
+            {'number': 1, 'rank': 4, 'name': 'Steve'},
+            {'number': 2, 'rank': 2, 'name': 'Olivia'},
+            {'number': 3, 'rank': 1, 'name': 'Hannah'},
+            {'number': 4, 'rank': 4, 'name': 'Mark'},
+        ])
+
+    @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
+    def test_update_conflicts_unique_fields(self):
+        self._test_update_conflicts(unique_fields=['number'])
+
+    @skipUnlessDBFeature('supports_update_conflicts')
+    @skipIfDBFeature('supports_update_conflicts_with_target')
+    def test_update_conflicts_no_unique_fields(self):
+        self._test_update_conflicts([])