فهرست منبع

Fixed #31709 -- Added support for opclasses in ExclusionConstraint.

Hannes Ljungberg 4 سال پیش
والد
کامیت
0d6d4e78b1

+ 25 - 4
django/contrib/postgres/constraints.py

@@ -12,7 +12,7 @@ class ExclusionConstraint(BaseConstraint):
 
     def __init__(
         self, *, name, expressions, index_type=None, condition=None,
-        deferrable=None, include=None,
+        deferrable=None, include=None, opclasses=(),
     ):
         if index_type and index_type.lower() not in {'gist', 'spgist'}:
             raise ValueError(
@@ -48,20 +48,37 @@ class ExclusionConstraint(BaseConstraint):
             raise ValueError(
                 'Covering exclusion constraints only support GiST indexes.'
             )
+        if not isinstance(opclasses, (list, tuple)):
+            raise ValueError(
+                'ExclusionConstraint.opclasses must be a list or tuple.'
+            )
+        if opclasses and len(expressions) != len(opclasses):
+            raise ValueError(
+                'ExclusionConstraint.expressions and '
+                'ExclusionConstraint.opclasses must have the same number of '
+                'elements.'
+            )
         self.expressions = expressions
         self.index_type = index_type or 'GIST'
         self.condition = condition
         self.deferrable = deferrable
         self.include = tuple(include) if include else ()
+        self.opclasses = opclasses
         super().__init__(name=name)
 
     def _get_expression_sql(self, compiler, connection, query):
         expressions = []
-        for expression, operator in self.expressions:
+        for idx, (expression, operator) in enumerate(self.expressions):
             if isinstance(expression, str):
                 expression = F(expression)
             expression = expression.resolve_expression(query=query)
             sql, params = expression.as_sql(compiler, connection)
+            try:
+                opclass = self.opclasses[idx]
+                if opclass:
+                    sql = '%s %s' % (sql, opclass)
+            except IndexError:
+                pass
             expressions.append('%s WITH %s' % (sql % params, operator))
         return expressions
 
@@ -119,6 +136,8 @@ class ExclusionConstraint(BaseConstraint):
             kwargs['deferrable'] = self.deferrable
         if self.include:
             kwargs['include'] = self.include
+        if self.opclasses:
+            kwargs['opclasses'] = self.opclasses
         return path, args, kwargs
 
     def __eq__(self, other):
@@ -129,16 +148,18 @@ class ExclusionConstraint(BaseConstraint):
                 self.expressions == other.expressions and
                 self.condition == other.condition and
                 self.deferrable == other.deferrable and
-                self.include == other.include
+                self.include == other.include and
+                self.opclasses == other.opclasses
             )
         return super().__eq__(other)
 
     def __repr__(self):
-        return '<%s: index_type=%s, expressions=%s%s%s%s>' % (
+        return '<%s: index_type=%s, expressions=%s%s%s%s%s>' % (
             self.__class__.__qualname__,
             self.index_type,
             self.expressions,
             '' if self.condition is None else ', condition=%s' % self.condition,
             '' if self.deferrable is None else ', deferrable=%s' % self.deferrable,
             '' if not self.include else ', include=%s' % repr(self.include),
+            '' if not self.opclasses else ', opclasses=%s' % repr(self.opclasses),
         )

+ 23 - 1
docs/ref/contrib/postgres/constraints.txt

@@ -12,7 +12,7 @@ PostgreSQL supports additional data integrity constraints available from the
 ``ExclusionConstraint``
 =======================
 
-.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None)
+.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None, opclasses=())
 
     Creates an exclusion constraint in the database. Internally, PostgreSQL
     implements exclusion constraints using indexes. The default index type is
@@ -121,6 +121,28 @@ used for queries that select only included fields
 
 ``include`` is supported only for GiST indexes on PostgreSQL 12+.
 
+``opclasses``
+-------------
+
+.. attribute:: ExclusionConstraint.opclasses
+
+.. versionadded:: 3.2
+
+The names of the `PostgreSQL operator classes
+<https://www.postgresql.org/docs/current/indexes-opclass.html>`_ to use for
+this constraint. If you require a custom operator class, you must provide one
+for each expression in the constraint.
+
+For example::
+
+    ExclusionConstraint(
+        name='exclude_overlapping_opclasses',
+        expressions=[('circle', RangeOperators.OVERLAPS)],
+        opclasses=['circle_ops'],
+    )
+
+creates an exclusion constraint on ``circle`` using ``circle_ops``.
+
 Examples
 --------
 

+ 3 - 0
docs/releases/3.2.txt

@@ -73,6 +73,9 @@ Minor features
 * The new :attr:`.ExclusionConstraint.include` attribute allows creating
   covering exclusion constraints on PostgreSQL 12+.
 
+* The new :attr:`.ExclusionConstraint.opclasses` attribute allows setting
+  PostgreSQL operator classes.
+
 * The new :attr:`.JSONBAgg.ordering` attribute determines the ordering of the
   aggregated elements.
 

+ 128 - 0
tests/postgres_tests/test_constraints.py

@@ -246,6 +246,28 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
                 index_type='spgist',
             )
 
+    def test_invalid_opclasses_type(self):
+        msg = 'ExclusionConstraint.opclasses must be a list or tuple.'
+        with self.assertRaisesMessage(ValueError, msg):
+            ExclusionConstraint(
+                name='exclude_invalid_opclasses',
+                expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
+                opclasses='invalid',
+            )
+
+    def test_opclasses_and_expressions_same_length(self):
+        msg = (
+            'ExclusionConstraint.expressions and '
+            'ExclusionConstraint.opclasses must have the same number of '
+            'elements.'
+        )
+        with self.assertRaisesMessage(ValueError, msg):
+            ExclusionConstraint(
+                name='exclude_invalid_expressions_opclasses_length',
+                expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
+                opclasses=['foo', 'bar'],
+            )
+
     def test_repr(self):
         constraint = ExclusionConstraint(
             name='exclude_overlapping',
@@ -290,6 +312,16 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
             "<ExclusionConstraint: index_type=GIST, expressions=["
             "(F(datespan), '-|-')], include=('cancelled', 'room')>",
         )
+        constraint = ExclusionConstraint(
+            name='exclude_overlapping',
+            expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)],
+            opclasses=['range_ops'],
+        )
+        self.assertEqual(
+            repr(constraint),
+            "<ExclusionConstraint: index_type=GIST, expressions=["
+            "(F(datespan), '-|-')], opclasses=['range_ops']>",
+        )
 
     def test_eq(self):
         constraint_1 = ExclusionConstraint(
@@ -345,6 +377,23 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
             ],
             include=['cancelled'],
         )
+        constraint_8 = ExclusionConstraint(
+            name='exclude_overlapping',
+            expressions=[
+                ('datespan', RangeOperators.OVERLAPS),
+                ('room', RangeOperators.EQUAL),
+            ],
+            include=['cancelled'],
+            opclasses=['range_ops', 'range_ops']
+        )
+        constraint_9 = ExclusionConstraint(
+            name='exclude_overlapping',
+            expressions=[
+                ('datespan', RangeOperators.OVERLAPS),
+                ('room', RangeOperators.EQUAL),
+            ],
+            opclasses=['range_ops', 'range_ops']
+        )
         self.assertEqual(constraint_1, constraint_1)
         self.assertEqual(constraint_1, mock.ANY)
         self.assertNotEqual(constraint_1, constraint_2)
@@ -353,8 +402,10 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
         self.assertNotEqual(constraint_2, constraint_3)
         self.assertNotEqual(constraint_2, constraint_4)
         self.assertNotEqual(constraint_2, constraint_7)
+        self.assertNotEqual(constraint_2, constraint_9)
         self.assertNotEqual(constraint_4, constraint_5)
         self.assertNotEqual(constraint_5, constraint_6)
+        self.assertNotEqual(constraint_7, constraint_8)
         self.assertNotEqual(constraint_1, object())
 
     def test_deconstruct(self):
@@ -430,6 +481,21 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
             'include': ('cancelled', 'room'),
         })
 
+    def test_deconstruct_opclasses(self):
+        constraint = ExclusionConstraint(
+            name='exclude_overlapping',
+            expressions=[('datespan', RangeOperators.OVERLAPS)],
+            opclasses=['range_ops'],
+        )
+        path, args, kwargs = constraint.deconstruct()
+        self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
+        self.assertEqual(args, ())
+        self.assertEqual(kwargs, {
+            'name': 'exclude_overlapping',
+            'expressions': [('datespan', RangeOperators.OVERLAPS)],
+            'opclasses': ['range_ops'],
+        })
+
     def _test_range_overlaps(self, constraint):
         # Create exclusion constraint.
         self.assertNotIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table))
@@ -505,6 +571,7 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
                 ('room', RangeOperators.EQUAL)
             ],
             condition=Q(cancelled=False),
+            opclasses=['range_ops', 'gist_int4_ops'],
         )
         self._test_range_overlaps(constraint)
 
@@ -624,3 +691,64 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
             ):
                 with self.assertRaisesMessage(NotSupportedError, msg):
                     editor.add_constraint(RangesModel, constraint)
+
+    def test_range_adjacent_opclasses(self):
+        constraint_name = 'ints_adjacent_opclasses'
+        self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+        constraint = ExclusionConstraint(
+            name=constraint_name,
+            expressions=[('ints', RangeOperators.ADJACENT_TO)],
+            opclasses=['range_ops'],
+        )
+        with connection.schema_editor() as editor:
+            editor.add_constraint(RangesModel, constraint)
+        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+        RangesModel.objects.create(ints=(20, 50))
+        with self.assertRaises(IntegrityError), transaction.atomic():
+            RangesModel.objects.create(ints=(10, 20))
+        RangesModel.objects.create(ints=(10, 19))
+        RangesModel.objects.create(ints=(51, 60))
+        # Drop the constraint.
+        with connection.schema_editor() as editor:
+            editor.remove_constraint(RangesModel, constraint)
+        self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+
+    def test_range_adjacent_opclasses_condition(self):
+        constraint_name = 'ints_adjacent_opclasses_condition'
+        self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+        constraint = ExclusionConstraint(
+            name=constraint_name,
+            expressions=[('ints', RangeOperators.ADJACENT_TO)],
+            opclasses=['range_ops'],
+            condition=Q(id__gte=100),
+        )
+        with connection.schema_editor() as editor:
+            editor.add_constraint(RangesModel, constraint)
+        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+
+    def test_range_adjacent_opclasses_deferrable(self):
+        constraint_name = 'ints_adjacent_opclasses_deferrable'
+        self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+        constraint = ExclusionConstraint(
+            name=constraint_name,
+            expressions=[('ints', RangeOperators.ADJACENT_TO)],
+            opclasses=['range_ops'],
+            deferrable=Deferrable.DEFERRED,
+        )
+        with connection.schema_editor() as editor:
+            editor.add_constraint(RangesModel, constraint)
+        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+
+    @skipUnlessDBFeature('supports_covering_gist_indexes')
+    def test_range_adjacent_opclasses_include(self):
+        constraint_name = 'ints_adjacent_opclasses_include'
+        self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+        constraint = ExclusionConstraint(
+            name=constraint_name,
+            expressions=[('ints', RangeOperators.ADJACENT_TO)],
+            opclasses=['range_ops'],
+            include=['decimals'],
+        )
+        with connection.schema_editor() as editor:
+            editor.add_constraint(RangesModel, constraint)
+        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))