Bladeren bron

Fixed #33342 -- Added support for using OpClass() in exclusion constraints.

Hannes Ljungberg 3 jaren geleden
bovenliggende
commit
0e656c02fe

+ 26 - 22
django/contrib/postgres/constraints.py

@@ -1,13 +1,19 @@
+from django.contrib.postgres.indexes import OpClass
 from django.db import NotSupportedError
-from django.db.backends.ddl_references import Statement, Table
+from django.db.backends.ddl_references import Expressions, Statement, Table
 from django.db.models import Deferrable, F, Q
 from django.db.models.constraints import BaseConstraint
-from django.db.models.expressions import Col
+from django.db.models.expressions import ExpressionList
+from django.db.models.indexes import IndexExpression
 from django.db.models.sql import Query
 
 __all__ = ['ExclusionConstraint']
 
 
+class ExclusionConstraintExpression(IndexExpression):
+    template = '%(expressions)s WITH %(operator)s'
+
+
 class ExclusionConstraint(BaseConstraint):
     template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(include)s%(where)s%(deferrable)s'
 
@@ -63,24 +69,19 @@ class ExclusionConstraint(BaseConstraint):
         self.opclasses = opclasses
         super().__init__(name=name)
 
-    def _get_expression_sql(self, compiler, schema_editor, query):
+    def _get_expressions(self, schema_editor, query):
         expressions = []
         for idx, (expression, operator) in enumerate(self.expressions):
             if isinstance(expression, str):
                 expression = F(expression)
-            expression = expression.resolve_expression(query=query)
-            sql, params = compiler.compile(expression)
-            if not isinstance(expression, Col):
-                sql = f'({sql})'
             try:
-                opclass = self.opclasses[idx]
-                if opclass:
-                    sql = '%s %s' % (sql, opclass)
+                expression = OpClass(expression, self.opclasses[idx])
             except IndexError:
                 pass
-            sql = sql % tuple(schema_editor.quote_value(p) for p in params)
-            expressions.append('%s WITH %s' % (sql, operator))
-        return expressions
+            expression = ExclusionConstraintExpression(expression, operator=operator)
+            expression.set_wrapper_classes(schema_editor.connection)
+            expressions.append(expression)
+        return ExpressionList(*expressions).resolve_expression(query)
 
     def _get_condition_sql(self, compiler, schema_editor, query):
         if self.condition is None:
@@ -92,17 +93,20 @@ class ExclusionConstraint(BaseConstraint):
     def constraint_sql(self, model, schema_editor):
         query = Query(model, alias_cols=False)
         compiler = query.get_compiler(connection=schema_editor.connection)
-        expressions = self._get_expression_sql(compiler, schema_editor, query)
+        expressions = self._get_expressions(schema_editor, query)
+        table = model._meta.db_table
         condition = self._get_condition_sql(compiler, schema_editor, query)
         include = [model._meta.get_field(field_name).column for field_name in self.include]
-        return self.template % {
-            'name': schema_editor.quote_name(self.name),
-            'index_type': self.index_type,
-            'expressions': ', '.join(expressions),
-            'include': schema_editor._index_include_sql(model, include),
-            'where': ' WHERE (%s)' % condition if condition else '',
-            'deferrable': schema_editor._deferrable_constraint_sql(self.deferrable),
-        }
+        return Statement(
+            self.template,
+            table=Table(table, schema_editor.quote_name),
+            name=schema_editor.quote_name(self.name),
+            index_type=self.index_type,
+            expressions=Expressions(table, expressions, compiler, schema_editor.quote_value),
+            where=' WHERE (%s)' % condition if condition else '',
+            include=schema_editor._index_include_sql(model, include),
+            deferrable=schema_editor._deferrable_constraint_sql(self.deferrable),
+        )
 
     def create_sql(self, model, schema_editor):
         self.check_supported(schema_editor)

+ 18 - 0
docs/ref/contrib/postgres/constraints.txt

@@ -53,6 +53,10 @@ operators with strings. For example::
 
     Only commutative operators can be used in exclusion constraints.
 
+.. versionchanged:: 4.1
+
+    Support for the ``OpClass()`` expression was added.
+
 ``index_type``
 --------------
 
@@ -143,6 +147,20 @@ For example::
 
 creates an exclusion constraint on ``circle`` using ``circle_ops``.
 
+Alternatively, you can use
+:class:`OpClass() <django.contrib.postgres.indexes.OpClass>` in
+:attr:`~ExclusionConstraint.expressions`::
+
+    ExclusionConstraint(
+        name='exclude_overlapping_opclasses',
+        expressions=[(OpClass('circle', 'circle_ops'), RangeOperators.OVERLAPS)],
+    )
+
+.. versionchanged:: 4.1
+
+    Support for specifying operator classes with the ``OpClass()`` expression
+    was added.
+
 Examples
 --------
 

+ 19 - 6
docs/ref/contrib/postgres/indexes.txt

@@ -150,10 +150,10 @@ available from the ``django.contrib.postgres.indexes`` module.
 .. class:: OpClass(expression, name)
 
     An ``OpClass()`` expression represents the ``expression`` with a custom
-    `operator class`_ that can be used to define functional indexes or unique
-    constraints. To use it, you need to add ``'django.contrib.postgres'`` in
-    your :setting:`INSTALLED_APPS`. Set the ``name`` parameter to the name of
-    the `operator class`_.
+    `operator class`_ that can be used to define functional indexes, functional
+    unique constraints, or exclusion constraints. To use it, you need to add
+    ``'django.contrib.postgres'`` in your :setting:`INSTALLED_APPS`. Set the
+    ``name`` parameter to the name of the `operator class`_.
 
     For example::
 
@@ -163,8 +163,7 @@ available from the ``django.contrib.postgres.indexes`` module.
         )
 
     creates an index on ``Lower('username')`` using ``varchar_pattern_ops``.
-
-    Another example::
+    ::
 
         UniqueConstraint(
             OpClass(Upper('description'), name='text_pattern_ops'),
@@ -173,9 +172,23 @@ available from the ``django.contrib.postgres.indexes`` module.
 
     creates a unique constraint on ``Upper('description')`` using
     ``text_pattern_ops``.
+    ::
+
+        ExclusionConstraint(
+            name='exclude_overlapping_ops',
+            expressions=[
+                (OpClass('circle', name='circle_ops'), RangeOperators.OVERLAPS),
+            ],
+        )
+
+    creates an exclusion constraint on ``circle`` using ``circle_ops``.
 
     .. versionchanged:: 4.0
 
         Support for functional unique constraints was added.
 
+    .. versionchanged:: 4.1
+
+        Support for exclusion constraints was added.
+
     .. _operator class: https://www.postgresql.org/docs/current/indexes-opclass.html

+ 4 - 0
docs/releases/4.1.txt

@@ -108,6 +108,10 @@ Minor features
   <django.contrib.postgres.fields.DecimalRangeField.default_bounds>` allows
   specifying bounds for list and tuple inputs.
 
+* :class:`~django.contrib.postgres.constraints.ExclusionConstraint` now allows
+  specifying operator classes with the
+  :class:`OpClass() <django.contrib.postgres.indexes.OpClass>` expression.
+
 :mod:`django.contrib.redirects`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

+ 47 - 0
tests/postgres_tests/test_constraints.py

@@ -198,6 +198,7 @@ class SchemaTests(PostgreSQLTestCase):
         Scene.objects.create(scene='ScEnE 10', setting="Sir Bedemir's Castle")
 
 
+@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
 class ExclusionConstraintTests(PostgreSQLTestCase):
     def get_constraints(self, table):
         """Get the constraints on the table using a new cursor."""
@@ -604,6 +605,24 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
         )
         self._test_range_overlaps(constraint)
 
+    def test_range_overlaps_custom_opclass_expression(self):
+        class TsTzRange(Func):
+            function = 'TSTZRANGE'
+            output_field = DateTimeRangeField()
+
+        constraint = ExclusionConstraint(
+            name='exclude_overlapping_reservations_custom_opclass',
+            expressions=[
+                (
+                    OpClass(TsTzRange('start', 'end', RangeBoundary()), 'range_ops'),
+                    RangeOperators.OVERLAPS,
+                ),
+                (OpClass('room', 'gist_int4_ops'), RangeOperators.EQUAL),
+            ],
+            condition=Q(cancelled=False),
+        )
+        self._test_range_overlaps(constraint)
+
     def test_range_overlaps(self):
         constraint = ExclusionConstraint(
             name='exclude_overlapping_reservations',
@@ -914,6 +933,34 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
             editor.add_constraint(RangesModel, constraint)
         self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
 
+    def test_opclass_expression(self):
+        constraint_name = 'ints_adjacent_opclass_expression'
+        self.assertNotIn(
+            constraint_name,
+            self.get_constraints(RangesModel._meta.db_table),
+        )
+        constraint = ExclusionConstraint(
+            name=constraint_name,
+            expressions=[(OpClass('ints', 'range_ops'), RangeOperators.ADJACENT_TO)],
+        )
+        with connection.schema_editor() as editor:
+            editor.add_constraint(RangesModel, constraint)
+        constraints = self.get_constraints(RangesModel._meta.db_table)
+        self.assertIn(constraint_name, constraints)
+        with editor.connection.cursor() as cursor:
+            cursor.execute(SchemaTests.get_opclass_query, [constraint_name])
+            self.assertEqual(
+                cursor.fetchall(),
+                [('range_ops', constraint_name)],
+            )
+        # Drop the constraint.
+        with connection.schema_editor() as editor:
+            editor.remove_constraint(RangesModel, constraint)
+        self.assertNotIn(
+            constraint_name,
+            self.get_constraints(RangesModel._meta.db_table),
+        )
+
     def test_range_equal_cast(self):
         constraint_name = 'exclusion_equal_room_cast'
         self.assertNotIn(constraint_name, self.get_constraints(Room._meta.db_table))