Browse Source

Fixed #31702 -- Added support for PostgreSQL opclasses in UniqueConstraint.

Hannes Ljungberg 4 years ago
parent
commit
7edc6e53a7

+ 17 - 7
django/db/backends/base/schema.py

@@ -1092,13 +1092,16 @@ class BaseDatabaseSchemaEditor:
         if deferrable == Deferrable.IMMEDIATE:
             return ' DEFERRABLE INITIALLY IMMEDIATE'
 
-    def _unique_sql(self, model, fields, name, condition=None, deferrable=None, include=None):
+    def _unique_sql(
+        self, model, fields, name, condition=None, deferrable=None,
+        include=None, opclasses=None,
+    ):
         if (
             deferrable and
             not self.connection.features.supports_deferrable_unique_constraints
         ):
             return None
-        if condition or include:
+        if condition or include or opclasses:
             # Databases support conditional and covering unique constraints via
             # a unique index.
             sql = self._create_unique_sql(
@@ -1107,6 +1110,7 @@ class BaseDatabaseSchemaEditor:
                 name=name,
                 condition=condition,
                 include=include,
+                opclasses=opclasses,
             )
             if sql:
                 self.deferred_sql.append(sql)
@@ -1120,7 +1124,10 @@ class BaseDatabaseSchemaEditor:
             'constraint': constraint,
         }
 
-    def _create_unique_sql(self, model, columns, name=None, condition=None, deferrable=None, include=None):
+    def _create_unique_sql(
+        self, model, columns, name=None, condition=None, deferrable=None,
+        include=None, opclasses=None,
+    ):
         if (
             (
                 deferrable and
@@ -1139,8 +1146,8 @@ class BaseDatabaseSchemaEditor:
             name = IndexName(model._meta.db_table, columns, '_uniq', create_unique_name)
         else:
             name = self.quote_name(name)
-        columns = Columns(table, columns, self.quote_name)
-        if condition or include:
+        columns = self._index_columns(table, columns, col_suffixes=(), opclasses=opclasses)
+        if condition or include or opclasses:
             sql = self.sql_create_unique_index
         else:
             sql = self.sql_create_unique
@@ -1154,7 +1161,10 @@ class BaseDatabaseSchemaEditor:
             include=self._index_include_sql(model, include),
         )
 
-    def _delete_unique_sql(self, model, name, condition=None, deferrable=None, include=None):
+    def _delete_unique_sql(
+        self, model, name, condition=None, deferrable=None, include=None,
+        opclasses=None,
+    ):
         if (
             (
                 deferrable and
@@ -1164,7 +1174,7 @@ class BaseDatabaseSchemaEditor:
             (include and not self.connection.features.supports_covering_indexes)
         ):
             return None
-        if condition or include:
+        if condition or include or opclasses:
             sql = self.sql_delete_index
         else:
             sql = self.sql_delete_unique

+ 27 - 4
django/db/models/constraints.py

@@ -77,7 +77,16 @@ class Deferrable(Enum):
 
 
 class UniqueConstraint(BaseConstraint):
-    def __init__(self, *, fields, name, condition=None, deferrable=None, include=None):
+    def __init__(
+        self,
+        *,
+        fields,
+        name,
+        condition=None,
+        deferrable=None,
+        include=None,
+        opclasses=(),
+    ):
         if not fields:
             raise ValueError('At least one field is required to define a unique constraint.')
         if not isinstance(condition, (type(None), Q)):
@@ -92,10 +101,18 @@ class UniqueConstraint(BaseConstraint):
             )
         if not isinstance(include, (type(None), list, tuple)):
             raise ValueError('UniqueConstraint.include must be a list or tuple.')
+        if not isinstance(opclasses, (list, tuple)):
+            raise ValueError('UniqueConstraint.opclasses must be a list or tuple.')
+        if opclasses and len(fields) != len(opclasses):
+            raise ValueError(
+                'UniqueConstraint.fields and UniqueConstraint.opclasses must '
+                'have the same number of elements.'
+            )
         self.fields = tuple(fields)
         self.condition = condition
         self.deferrable = deferrable
         self.include = tuple(include) if include else ()
+        self.opclasses = opclasses
         super().__init__(name)
 
     def _get_condition_sql(self, model, schema_editor):
@@ -114,6 +131,7 @@ class UniqueConstraint(BaseConstraint):
         return schema_editor._unique_sql(
             model, fields, self.name, condition=condition,
             deferrable=self.deferrable, include=include,
+            opclasses=self.opclasses,
         )
 
     def create_sql(self, model, schema_editor):
@@ -123,6 +141,7 @@ class UniqueConstraint(BaseConstraint):
         return schema_editor._create_unique_sql(
             model, fields, self.name, condition=condition,
             deferrable=self.deferrable, include=include,
+            opclasses=self.opclasses,
         )
 
     def remove_sql(self, model, schema_editor):
@@ -130,15 +149,16 @@ class UniqueConstraint(BaseConstraint):
         include = [model._meta.get_field(field_name).column for field_name in self.include]
         return schema_editor._delete_unique_sql(
             model, self.name, condition=condition, deferrable=self.deferrable,
-            include=include,
+            include=include, opclasses=self.opclasses,
         )
 
     def __repr__(self):
-        return '<%s: fields=%r name=%r%s%s%s>' % (
+        return '<%s: fields=%r name=%r%s%s%s%s>' % (
             self.__class__.__name__, self.fields, self.name,
             '' if self.condition is None else ' condition=%s' % self.condition,
             '' if self.deferrable is None else ' deferrable=%s' % self.deferrable,
             '' if not self.include else ' include=%s' % repr(self.include),
+            '' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses),
         )
 
     def __eq__(self, other):
@@ -148,7 +168,8 @@ class UniqueConstraint(BaseConstraint):
                 self.fields == other.fields and
                 self.condition == other.condition and
                 self.deferrable == other.deferrable and
-                self.include == other.include
+                self.include == other.include and
+                self.opclasses == other.opclasses
             )
         return super().__eq__(other)
 
@@ -161,4 +182,6 @@ class UniqueConstraint(BaseConstraint):
             kwargs['deferrable'] = self.deferrable
         if self.include:
             kwargs['include'] = self.include
+        if self.opclasses:
+            kwargs['opclasses'] = self.opclasses
         return path, args, kwargs

+ 22 - 1
docs/ref/models/constraints.txt

@@ -73,7 +73,7 @@ constraint.
 ``UniqueConstraint``
 ====================
 
-.. class:: UniqueConstraint(*, fields, name, condition=None, deferrable=None, include=None)
+.. class:: UniqueConstraint(*, fields, name, condition=None, deferrable=None, include=None, opclasses=())
 
     Creates a unique constraint in the database.
 
@@ -168,3 +168,24 @@ while fetching data only from the index.
 ``include`` is supported only on PostgreSQL.
 
 Non-key columns have the same database restrictions as :attr:`Index.include`.
+
+
+``opclasses``
+-------------
+
+.. attribute:: UniqueConstraint.opclasses
+
+.. versionadded:: 3.2
+
+The names of the `PostgreSQL operator classes
+<https://www.postgresql.org/docs/current/indexes-opclass.html>`_ to use for
+this unique index. If you require a custom operator class, you must provide one
+for each field in the index.
+
+For example::
+
+    UniqueConstraint(name='unique_username', fields=['username'], opclasses=['varchar_pattern_ops'])
+
+creates a unique index on ``username`` using ``varchar_pattern_ops``.
+
+``opclasses`` are ignored for databases besides PostgreSQL.

+ 3 - 0
docs/releases/3.2.txt

@@ -196,6 +196,9 @@ Models
   attributes allow creating covering indexes and covering unique constraints on
   PostgreSQL 11+.
 
+* The new :attr:`.UniqueConstraint.opclasses` attribute allows setting
+  PostgreSQL operator classes.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 61 - 0
tests/constraints/tests.py

@@ -196,6 +196,20 @@ class UniqueConstraintTests(TestCase):
         self.assertEqual(constraint_1, constraint_1)
         self.assertNotEqual(constraint_1, constraint_2)
 
+    def test_eq_with_opclasses(self):
+        constraint_1 = models.UniqueConstraint(
+            fields=['foo', 'bar'],
+            name='opclasses',
+            opclasses=['text_pattern_ops', 'varchar_pattern_ops'],
+        )
+        constraint_2 = models.UniqueConstraint(
+            fields=['foo', 'bar'],
+            name='opclasses',
+            opclasses=['varchar_pattern_ops', 'text_pattern_ops'],
+        )
+        self.assertEqual(constraint_1, constraint_1)
+        self.assertNotEqual(constraint_1, constraint_2)
+
     def test_repr(self):
         fields = ['foo', 'bar']
         name = 'unique_fields'
@@ -241,6 +255,18 @@ class UniqueConstraintTests(TestCase):
             "include=('baz_1', 'baz_2')>",
         )
 
+    def test_repr_with_opclasses(self):
+        constraint = models.UniqueConstraint(
+            fields=['foo', 'bar'],
+            name='opclasses_fields',
+            opclasses=['text_pattern_ops', 'varchar_pattern_ops'],
+        )
+        self.assertEqual(
+            repr(constraint),
+            "<UniqueConstraint: fields=('foo', 'bar') name='opclasses_fields' "
+            "opclasses=['text_pattern_ops', 'varchar_pattern_ops']>",
+        )
+
     def test_deconstruction(self):
         fields = ['foo', 'bar']
         name = 'unique_fields'
@@ -291,6 +317,20 @@ class UniqueConstraintTests(TestCase):
             'include': tuple(include),
         })
 
+    def test_deconstruction_with_opclasses(self):
+        fields = ['foo', 'bar']
+        name = 'unique_fields'
+        opclasses = ['varchar_pattern_ops', 'text_pattern_ops']
+        constraint = models.UniqueConstraint(fields=fields, name=name, opclasses=opclasses)
+        path, args, kwargs = constraint.deconstruct()
+        self.assertEqual(path, 'django.db.models.UniqueConstraint')
+        self.assertEqual(args, ())
+        self.assertEqual(kwargs, {
+            'fields': tuple(fields),
+            'name': name,
+            'opclasses': opclasses,
+        })
+
     def test_database_constraint(self):
         with self.assertRaises(IntegrityError):
             UniqueConstraintProduct.objects.create(name=self.p1.name, color=self.p1.color)
@@ -392,3 +432,24 @@ class UniqueConstraintTests(TestCase):
                 fields=['field'],
                 include='other',
             )
+
+    def test_invalid_opclasses_argument(self):
+        msg = 'UniqueConstraint.opclasses must be a list or tuple.'
+        with self.assertRaisesMessage(ValueError, msg):
+            models.UniqueConstraint(
+                name='uniq_opclasses',
+                fields=['field'],
+                opclasses='jsonb_path_ops',
+            )
+
+    def test_opclasses_and_fields_same_length(self):
+        msg = (
+            'UniqueConstraint.fields and UniqueConstraint.opclasses must have '
+            'the same number of elements.'
+        )
+        with self.assertRaisesMessage(ValueError, msg):
+            models.UniqueConstraint(
+                name='uniq_opclasses',
+                fields=['field'],
+                opclasses=['foo', 'bar'],
+            )

+ 80 - 2
tests/postgres_tests/test_constraints.py

@@ -4,12 +4,14 @@ from unittest import mock
 from django.db import (
     IntegrityError, NotSupportedError, connection, transaction,
 )
-from django.db.models import CheckConstraint, Deferrable, F, Func, Q
+from django.db.models import (
+    CheckConstraint, Deferrable, F, Func, Q, UniqueConstraint,
+)
 from django.test import skipUnlessDBFeature
 from django.utils import timezone
 
 from . import PostgreSQLTestCase
-from .models import HotelReservation, RangesModel, Room
+from .models import HotelReservation, RangesModel, Room, Scene
 
 try:
     from django.contrib.postgres.constraints import ExclusionConstraint
@@ -21,6 +23,13 @@ except ImportError:
 
 
 class SchemaTests(PostgreSQLTestCase):
+    get_opclass_query = '''
+        SELECT opcname, c.relname FROM pg_opclass AS oc
+        JOIN pg_index as i on oc.oid = ANY(i.indclass)
+        JOIN pg_class as c on c.oid = i.indexrelid
+        WHERE c.relname = %s
+    '''
+
     def get_constraints(self, table):
         """Get the constraints on the table using a new cursor."""
         with connection.cursor() as cursor:
@@ -84,6 +93,75 @@ class SchemaTests(PostgreSQLTestCase):
             timestamps_inner=(datetime_1, datetime_2),
         )
 
+    def test_opclass(self):
+        constraint = UniqueConstraint(
+            name='test_opclass',
+            fields=['scene'],
+            opclasses=['varchar_pattern_ops'],
+        )
+        with connection.schema_editor() as editor:
+            editor.add_constraint(Scene, constraint)
+        self.assertIn(constraint.name, self.get_constraints(Scene._meta.db_table))
+        with editor.connection.cursor() as cursor:
+            cursor.execute(self.get_opclass_query, [constraint.name])
+            self.assertEqual(
+                cursor.fetchall(),
+                [('varchar_pattern_ops', constraint.name)],
+            )
+        # Drop the constraint.
+        with connection.schema_editor() as editor:
+            editor.remove_constraint(Scene, constraint)
+        self.assertNotIn(constraint.name, self.get_constraints(Scene._meta.db_table))
+
+    def test_opclass_multiple_columns(self):
+        constraint = UniqueConstraint(
+            name='test_opclass_multiple',
+            fields=['scene', 'setting'],
+            opclasses=['varchar_pattern_ops', 'text_pattern_ops'],
+        )
+        with connection.schema_editor() as editor:
+            editor.add_constraint(Scene, constraint)
+        with editor.connection.cursor() as cursor:
+            cursor.execute(self.get_opclass_query, [constraint.name])
+            expected_opclasses = (
+                ('varchar_pattern_ops', constraint.name),
+                ('text_pattern_ops', constraint.name),
+            )
+            self.assertCountEqual(cursor.fetchall(), expected_opclasses)
+
+    def test_opclass_partial(self):
+        constraint = UniqueConstraint(
+            name='test_opclass_partial',
+            fields=['scene'],
+            opclasses=['varchar_pattern_ops'],
+            condition=Q(setting__contains="Sir Bedemir's Castle"),
+        )
+        with connection.schema_editor() as editor:
+            editor.add_constraint(Scene, constraint)
+        with editor.connection.cursor() as cursor:
+            cursor.execute(self.get_opclass_query, [constraint.name])
+            self.assertCountEqual(
+                cursor.fetchall(),
+                [('varchar_pattern_ops', constraint.name)],
+            )
+
+    @skipUnlessDBFeature('supports_covering_indexes')
+    def test_opclass_include(self):
+        constraint = UniqueConstraint(
+            name='test_opclass_include',
+            fields=['scene'],
+            opclasses=['varchar_pattern_ops'],
+            include=['setting'],
+        )
+        with connection.schema_editor() as editor:
+            editor.add_constraint(Scene, constraint)
+        with editor.connection.cursor() as cursor:
+            cursor.execute(self.get_opclass_query, [constraint.name])
+            self.assertCountEqual(
+                cursor.fetchall(),
+                [('varchar_pattern_ops', constraint.name)],
+            )
+
 
 class ExclusionConstraintTests(PostgreSQLTestCase):
     def get_constraints(self, table):