Browse Source

Fixed #34701 -- Added support for NULLS [NOT] DISTINCT on PostgreSQL 15+.

Simon Charette 1 year ago
parent
commit
595a2abb58

+ 5 - 0
django/db/backends/base/features.py

@@ -27,6 +27,11 @@ class BaseDatabaseFeatures:
     # Does the backend allow inserting duplicate rows when a unique_together
     # constraint exists and some fields are nullable but not all of them?
     supports_partially_nullable_unique_constraints = True
+
+    # Does the backend supports specifying whether NULL values should be
+    # considered distinct in unique constraints?
+    supports_nulls_distinct_unique_constraints = False
+
     # Does the backend support initially deferrable unique constraints?
     supports_deferrable_unique_constraints = False
 

+ 31 - 4
django/db/backends/base/schema.py

@@ -129,7 +129,7 @@ class BaseDatabaseSchemaEditor:
     )
     sql_create_unique_index = (
         "CREATE UNIQUE INDEX %(name)s ON %(table)s "
-        "(%(columns)s)%(include)s%(condition)s"
+        "(%(columns)s)%(include)s%(condition)s%(nulls_distinct)s"
     )
     sql_rename_index = "ALTER INDEX %(old_name)s RENAME TO %(new_name)s"
     sql_delete_index = "DROP INDEX %(name)s"
@@ -1675,12 +1675,20 @@ class BaseDatabaseSchemaEditor:
         if deferrable == Deferrable.IMMEDIATE:
             return " DEFERRABLE INITIALLY IMMEDIATE"
 
+    def _unique_index_nulls_distinct_sql(self, nulls_distinct):
+        if nulls_distinct is False:
+            return " NULLS NOT DISTINCT"
+        elif nulls_distinct is True:
+            return " NULLS DISTINCT"
+        return ""
+
     def _unique_supported(
         self,
         condition=None,
         deferrable=None,
         include=None,
         expressions=None,
+        nulls_distinct=None,
     ):
         return (
             (not condition or self.connection.features.supports_partial_indexes)
@@ -1692,6 +1700,10 @@ class BaseDatabaseSchemaEditor:
             and (
                 not expressions or self.connection.features.supports_expression_indexes
             )
+            and (
+                nulls_distinct is None
+                or self.connection.features.supports_nulls_distinct_unique_constraints
+            )
         )
 
     def _unique_sql(
@@ -1704,17 +1716,26 @@ class BaseDatabaseSchemaEditor:
         include=None,
         opclasses=None,
         expressions=None,
+        nulls_distinct=None,
     ):
         if not self._unique_supported(
             condition=condition,
             deferrable=deferrable,
             include=include,
             expressions=expressions,
+            nulls_distinct=nulls_distinct,
         ):
             return None
-        if condition or include or opclasses or expressions:
-            # Databases support conditional, covering, and functional unique
-            # constraints via a unique index.
+
+        if (
+            condition
+            or include
+            or opclasses
+            or expressions
+            or nulls_distinct is not None
+        ):
+            # Databases support conditional, covering, functional unique,
+            # and nulls distinct constraints via a unique index.
             sql = self._create_unique_sql(
                 model,
                 fields,
@@ -1723,6 +1744,7 @@ class BaseDatabaseSchemaEditor:
                 include=include,
                 opclasses=opclasses,
                 expressions=expressions,
+                nulls_distinct=nulls_distinct,
             )
             if sql:
                 self.deferred_sql.append(sql)
@@ -1746,12 +1768,14 @@ class BaseDatabaseSchemaEditor:
         include=None,
         opclasses=None,
         expressions=None,
+        nulls_distinct=None,
     ):
         if not self._unique_supported(
             condition=condition,
             deferrable=deferrable,
             include=include,
             expressions=expressions,
+            nulls_distinct=nulls_distinct,
         ):
             return None
 
@@ -1782,6 +1806,7 @@ class BaseDatabaseSchemaEditor:
             condition=self._index_condition_sql(condition),
             deferrable=self._deferrable_constraint_sql(deferrable),
             include=self._index_include_sql(model, include),
+            nulls_distinct=self._unique_index_nulls_distinct_sql(nulls_distinct),
         )
 
     def _unique_constraint_name(self, table, columns, quote=True):
@@ -1804,12 +1829,14 @@ class BaseDatabaseSchemaEditor:
         include=None,
         opclasses=None,
         expressions=None,
+        nulls_distinct=None,
     ):
         if not self._unique_supported(
             condition=condition,
             deferrable=deferrable,
             include=include,
             expressions=expressions,
+            nulls_distinct=nulls_distinct,
         ):
             return None
         if condition or include or opclasses or expressions:

+ 7 - 0
django/db/backends/postgresql/features.py

@@ -132,6 +132,13 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     def is_postgresql_14(self):
         return self.connection.pg_version >= 140000
 
+    @cached_property
+    def is_postgresql_15(self):
+        return self.connection.pg_version >= 150000
+
     has_bit_xor = property(operator.attrgetter("is_postgresql_14"))
     supports_covering_spgist_indexes = property(operator.attrgetter("is_postgresql_14"))
     supports_unlimited_charfield = True
+    supports_nulls_distinct_unique_constraints = property(
+        operator.attrgetter("is_postgresql_15")
+    )

+ 23 - 0
django/db/models/base.py

@@ -2442,6 +2442,29 @@ class Model(AltersData, metaclass=ModelBase):
                         id="models.W044",
                     )
                 )
+            if not (
+                connection.features.supports_nulls_distinct_unique_constraints
+                or (
+                    "supports_nulls_distinct_unique_constraints"
+                    in cls._meta.required_db_features
+                )
+            ) and any(
+                isinstance(constraint, UniqueConstraint)
+                and constraint.nulls_distinct is not None
+                for constraint in cls._meta.constraints
+            ):
+                errors.append(
+                    checks.Warning(
+                        "%s does not support unique constraints with "
+                        "nulls distinct." % connection.display_name,
+                        hint=(
+                            "A constraint won't be created. Silence this "
+                            "warning if you don't care about it."
+                        ),
+                        obj=cls,
+                        id="models.W047",
+                    )
+                )
             fields = set(
                 chain.from_iterable(
                     (*constraint.fields, *constraint.include)

+ 25 - 4
django/db/models/constraints.py

@@ -186,6 +186,7 @@ class UniqueConstraint(BaseConstraint):
         deferrable=None,
         include=None,
         opclasses=(),
+        nulls_distinct=None,
         violation_error_code=None,
         violation_error_message=None,
     ):
@@ -223,6 +224,8 @@ class UniqueConstraint(BaseConstraint):
             raise ValueError("UniqueConstraint.include must be a list or tuple.")
         if not isinstance(opclasses, (list, tuple)):
             raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
+        if not isinstance(nulls_distinct, (NoneType, bool)):
+            raise ValueError("UniqueConstraint.nulls_distinct must be a bool.")
         if opclasses and len(fields) != len(opclasses):
             raise ValueError(
                 "UniqueConstraint.fields and UniqueConstraint.opclasses must "
@@ -233,6 +236,7 @@ class UniqueConstraint(BaseConstraint):
         self.deferrable = deferrable
         self.include = tuple(include) if include else ()
         self.opclasses = opclasses
+        self.nulls_distinct = nulls_distinct
         self.expressions = tuple(
             F(expression) if isinstance(expression, str) else expression
             for expression in expressions
@@ -284,6 +288,7 @@ class UniqueConstraint(BaseConstraint):
             include=include,
             opclasses=self.opclasses,
             expressions=expressions,
+            nulls_distinct=self.nulls_distinct,
         )
 
     def create_sql(self, model, schema_editor):
@@ -302,6 +307,7 @@ class UniqueConstraint(BaseConstraint):
             include=include,
             opclasses=self.opclasses,
             expressions=expressions,
+            nulls_distinct=self.nulls_distinct,
         )
 
     def remove_sql(self, model, schema_editor):
@@ -318,10 +324,11 @@ class UniqueConstraint(BaseConstraint):
             include=include,
             opclasses=self.opclasses,
             expressions=expressions,
+            nulls_distinct=self.nulls_distinct,
         )
 
     def __repr__(self):
-        return "<%s:%s%s%s%s%s%s%s%s%s>" % (
+        return "<%s:%s%s%s%s%s%s%s%s%s%s>" % (
             self.__class__.__qualname__,
             "" if not self.fields else " fields=%s" % repr(self.fields),
             "" if not self.expressions else " expressions=%s" % repr(self.expressions),
@@ -330,6 +337,11 @@ class UniqueConstraint(BaseConstraint):
             "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
             "" if not self.include else " include=%s" % repr(self.include),
             "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
+            (
+                ""
+                if self.nulls_distinct is None
+                else " nulls_distinct=%r" % self.nulls_distinct
+            ),
             (
                 ""
                 if self.violation_error_code is None
@@ -353,6 +365,7 @@ class UniqueConstraint(BaseConstraint):
                 and self.include == other.include
                 and self.opclasses == other.opclasses
                 and self.expressions == other.expressions
+                and self.nulls_distinct is other.nulls_distinct
                 and self.violation_error_code == other.violation_error_code
                 and self.violation_error_message == other.violation_error_message
             )
@@ -370,6 +383,8 @@ class UniqueConstraint(BaseConstraint):
             kwargs["include"] = self.include
         if self.opclasses:
             kwargs["opclasses"] = self.opclasses
+        if self.nulls_distinct is not None:
+            kwargs["nulls_distinct"] = self.nulls_distinct
         return path, self.expressions, kwargs
 
     def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
@@ -381,9 +396,15 @@ class UniqueConstraint(BaseConstraint):
                     return
                 field = model._meta.get_field(field_name)
                 lookup_value = getattr(instance, field.attname)
-                if lookup_value is None or (
-                    lookup_value == ""
-                    and connections[using].features.interprets_empty_strings_as_nulls
+                if (
+                    self.nulls_distinct is not False
+                    and lookup_value is None
+                    or (
+                        lookup_value == ""
+                        and connections[
+                            using
+                        ].features.interprets_empty_strings_as_nulls
+                    )
                 ):
                     # A composite constraint containing NULL value cannot cause
                     # a violation since NULL != NULL in SQL.

+ 2 - 0
docs/ref/checks.txt

@@ -408,6 +408,8 @@ Models
   expression and won't be validated during the model ``full_clean()``.
 * **models.W046**: ``<database>`` does not support comments on tables
   (``db_table_comment``).
+* **models.W047**: ``<database>`` does not support unique constraints with
+  nulls distinct.
 
 Security
 --------

+ 21 - 1
docs/ref/models/constraints.txt

@@ -131,7 +131,7 @@ ensures the age field is never less than 18.
 ``UniqueConstraint``
 ====================
 
-.. class:: UniqueConstraint(*expressions, fields=(), name=None, condition=None, deferrable=None, include=None, opclasses=(), violation_error_code=None, violation_error_message=None)
+.. class:: UniqueConstraint(*expressions, fields=(), name=None, condition=None, deferrable=None, include=None, opclasses=(), nulls_distinct=None, violation_error_code=None, violation_error_message=None)
 
     Creates a unique constraint in the database.
 
@@ -254,6 +254,26 @@ creates a unique index on ``username`` using ``varchar_pattern_ops``.
 
 ``opclasses`` are ignored for databases besides PostgreSQL.
 
+``nulls_distinct``
+------------------
+
+.. versionadded:: 5.0
+
+.. attribute:: UniqueConstraint.nulls_distinct
+
+Whether rows containing ``NULL`` values covered by the unique constraint should
+be considered distinct from each other. The default value is ``None`` which
+uses the database default which is ``True`` on most backends.
+
+For example::
+
+    UniqueConstraint(name="ordering", fields=["ordering"], nulls_distinct=False)
+
+creates a unique constraint that only allows one row to store a ``NULL`` value
+in the ``ordering`` column.
+
+``nulls_distinct`` is ignored for databases besides PostgreSQL 15+.
+
 ``violation_error_code``
 ------------------------
 

+ 3 - 0
docs/releases/5.0.txt

@@ -361,6 +361,9 @@ Models
   set the primary key on each model instance when the ``update_conflicts``
   parameter is enabled (if the database supports it).
 
+* The new :attr:`.UniqueConstraint.nulls_distinct` attribute allows customizing
+  the treatment of ``NULL`` values on PostgreSQL 15+.
+
 Pagination
 ~~~~~~~~~~
 

+ 58 - 0
tests/constraints/tests.py

@@ -503,6 +503,27 @@ class UniqueConstraintTests(TestCase):
         self.assertEqual(constraint, mock.ANY)
         self.assertNotEqual(constraint, another_constraint)
 
+    def test_eq_with_nulls_distinct(self):
+        constraint_1 = models.UniqueConstraint(
+            Lower("title"),
+            nulls_distinct=False,
+            name="book_func_nulls_distinct_uq",
+        )
+        constraint_2 = models.UniqueConstraint(
+            Lower("title"),
+            nulls_distinct=True,
+            name="book_func_nulls_distinct_uq",
+        )
+        constraint_3 = models.UniqueConstraint(
+            Lower("title"),
+            name="book_func_nulls_distinct_uq",
+        )
+        self.assertEqual(constraint_1, constraint_1)
+        self.assertEqual(constraint_1, mock.ANY)
+        self.assertNotEqual(constraint_1, constraint_2)
+        self.assertNotEqual(constraint_1, constraint_3)
+        self.assertNotEqual(constraint_2, constraint_3)
+
     def test_repr(self):
         fields = ["foo", "bar"]
         name = "unique_fields"
@@ -560,6 +581,18 @@ class UniqueConstraintTests(TestCase):
             "opclasses=['text_pattern_ops', 'varchar_pattern_ops']>",
         )
 
+    def test_repr_with_nulls_distinct(self):
+        constraint = models.UniqueConstraint(
+            fields=["foo", "bar"],
+            name="nulls_distinct_fields",
+            nulls_distinct=False,
+        )
+        self.assertEqual(
+            repr(constraint),
+            "<UniqueConstraint: fields=('foo', 'bar') name='nulls_distinct_fields' "
+            "nulls_distinct=False>",
+        )
+
     def test_repr_with_expressions(self):
         constraint = models.UniqueConstraint(
             Lower("title"),
@@ -679,6 +712,24 @@ class UniqueConstraintTests(TestCase):
             },
         )
 
+    def test_deconstruction_with_nulls_distinct(self):
+        fields = ["foo", "bar"]
+        name = "unique_fields"
+        constraint = models.UniqueConstraint(
+            fields=fields, name=name, nulls_distinct=True
+        )
+        path, args, kwargs = constraint.deconstruct()
+        self.assertEqual(path, "django.db.models.UniqueConstraint")
+        self.assertEqual(args, ())
+        self.assertEqual(
+            kwargs,
+            {
+                "fields": tuple(fields),
+                "name": name,
+                "nulls_distinct": True,
+            },
+        )
+
     def test_deconstruction_with_expressions(self):
         name = "unique_fields"
         constraint = models.UniqueConstraint(Lower("title"), name=name)
@@ -1029,6 +1080,13 @@ class UniqueConstraintTests(TestCase):
                 opclasses="jsonb_path_ops",
             )
 
+    def test_invalid_nulls_distinct_argument(self):
+        msg = "UniqueConstraint.nulls_distinct must be a bool."
+        with self.assertRaisesMessage(ValueError, msg):
+            models.UniqueConstraint(
+                name="uniq_opclasses", fields=["field"], nulls_distinct="NULLS DISTINCT"
+            )
+
     def test_opclasses_and_fields_same_length(self):
         msg = (
             "UniqueConstraint.fields and UniqueConstraint.opclasses must have "

+ 46 - 0
tests/invalid_models_tests/test_models.py

@@ -2753,6 +2753,52 @@ class ConstraintsTests(TestCase):
 
         self.assertEqual(Model.check(databases=self.databases), [])
 
+    def test_unique_constraint_nulls_distinct(self):
+        class Model(models.Model):
+            name = models.CharField(max_length=10)
+
+            class Meta:
+                constraints = [
+                    models.UniqueConstraint(
+                        fields=["name"],
+                        name="name_uq_distinct_null",
+                        nulls_distinct=True,
+                    ),
+                ]
+
+        warn = Warning(
+            f"{connection.display_name} does not support unique constraints with nulls "
+            "distinct.",
+            hint=(
+                "A constraint won't be created. Silence this warning if you don't care "
+                "about it."
+            ),
+            obj=Model,
+            id="models.W047",
+        )
+        expected = (
+            []
+            if connection.features.supports_nulls_distinct_unique_constraints
+            else [warn]
+        )
+        self.assertEqual(Model.check(databases=self.databases), expected)
+
+    def test_unique_constraint_nulls_distinct_required_db_features(self):
+        class Model(models.Model):
+            name = models.CharField(max_length=10)
+
+            class Meta:
+                constraints = [
+                    models.UniqueConstraint(
+                        fields=["name"],
+                        name="name_uq_distinct_null",
+                        nulls_distinct=True,
+                    ),
+                ]
+                required_db_features = {"supports_nulls_distinct_unique_constraints"}
+
+        self.assertEqual(Model.check(databases=self.databases), [])
+
     @skipUnlessDBFeature("supports_expression_indexes")
     def test_func_unique_constraint_expression_custom_lookup(self):
         class Model(models.Model):

+ 37 - 0
tests/schema/tests.py

@@ -3318,6 +3318,43 @@ class SchemaTests(TransactionTestCase):
             with self.assertRaises(DatabaseError):
                 editor.add_constraint(Author, constraint)
 
+    @skipUnlessDBFeature("supports_nulls_distinct_unique_constraints")
+    def test_unique_constraint_nulls_distinct(self):
+        with connection.schema_editor() as editor:
+            editor.create_model(Author)
+        nulls_distinct = UniqueConstraint(
+            F("height"), name="distinct_height", nulls_distinct=True
+        )
+        nulls_not_distinct = UniqueConstraint(
+            F("weight"), name="not_distinct_weight", nulls_distinct=False
+        )
+        with connection.schema_editor() as editor:
+            editor.add_constraint(Author, nulls_distinct)
+            editor.add_constraint(Author, nulls_not_distinct)
+        Author.objects.create(name="", height=None, weight=None)
+        Author.objects.create(name="", height=None, weight=1)
+        with self.assertRaises(IntegrityError):
+            Author.objects.create(name="", height=1, weight=None)
+        with connection.schema_editor() as editor:
+            editor.remove_constraint(Author, nulls_distinct)
+            editor.remove_constraint(Author, nulls_not_distinct)
+        constraints = self.get_constraints(Author._meta.db_table)
+        self.assertNotIn(nulls_distinct.name, constraints)
+        self.assertNotIn(nulls_not_distinct.name, constraints)
+
+    @skipIfDBFeature("supports_nulls_distinct_unique_constraints")
+    def test_unique_constraint_nulls_distinct_unsupported(self):
+        # UniqueConstraint is ignored on databases that don't support
+        # NULLS [NOT] DISTINCT.
+        with connection.schema_editor() as editor:
+            editor.create_model(Author)
+        constraint = UniqueConstraint(
+            F("name"), name="func_name_uq", nulls_distinct=True
+        )
+        with connection.schema_editor() as editor, self.assertNumQueries(0):
+            self.assertIsNone(editor.add_constraint(Author, constraint))
+            self.assertIsNone(editor.remove_constraint(Author, constraint))
+
     @ignore_warnings(category=RemovedInDjango51Warning)
     def test_index_together(self):
         """

+ 14 - 0
tests/validation/models.py

@@ -217,3 +217,17 @@ class UniqueConstraintConditionProduct(models.Model):
                 condition=models.Q(color__isnull=True),
             ),
         ]
+
+
+class UniqueConstraintNullsDistinctProduct(models.Model):
+    name = models.CharField(max_length=255, blank=True, null=True)
+
+    class Meta:
+        required_db_features = {"supports_nulls_distinct_unique_constraints"}
+        constraints = [
+            models.UniqueConstraint(
+                fields=["name"],
+                name="name_nulls_not_distinct_uniq",
+                nulls_distinct=False,
+            ),
+        ]

+ 23 - 0
tests/validation/test_constraints.py

@@ -6,6 +6,7 @@ from .models import (
     ChildUniqueConstraintProduct,
     Product,
     UniqueConstraintConditionProduct,
+    UniqueConstraintNullsDistinctProduct,
     UniqueConstraintProduct,
 )
 
@@ -93,3 +94,25 @@ class PerformConstraintChecksTest(TestCase):
         UniqueConstraintConditionProduct.objects.create(name="product")
         product = UniqueConstraintConditionProduct(name="product")
         product.full_clean(validate_constraints=False)
+
+    @skipUnlessDBFeature("supports_nulls_distinct_unique_constraints")
+    def test_full_clean_with_nulls_distinct_unique_constraints(self):
+        UniqueConstraintNullsDistinctProduct.objects.create(name=None)
+        product = UniqueConstraintNullsDistinctProduct(name=None)
+        with self.assertRaises(ValidationError) as cm:
+            product.full_clean()
+        self.assertEqual(
+            cm.exception.message_dict,
+            {
+                "name": [
+                    "Unique constraint nulls distinct product with this Name "
+                    "already exists."
+                ]
+            },
+        )
+
+    @skipUnlessDBFeature("supports_nulls_distinct_unique_constraints")
+    def test_full_clean_with_nulls_distinct_unique_constraints_disabled(self):
+        UniqueConstraintNullsDistinctProduct.objects.create(name=None)
+        product = UniqueConstraintNullsDistinctProduct(name=None)
+        product.full_clean(validate_constraints=False)