浏览代码

Fixed #31649 -- Added support for covering exclusion constraints on PostgreSQL 12+.

Hannes Ljungberg 4 年之前
父节点
当前提交
e0cdd0fcf5

+ 27 - 4
django/contrib/postgres/constraints.py

@@ -1,3 +1,4 @@
+from django.db import NotSupportedError
 from django.db.backends.ddl_references import Statement, Table
 from django.db.models import Deferrable, F, Q
 from django.db.models.constraints import BaseConstraint
@@ -7,11 +8,11 @@ __all__ = ['ExclusionConstraint']
 
 
 class ExclusionConstraint(BaseConstraint):
-    template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(where)s%(deferrable)s'
+    template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(include)s%(where)s%(deferrable)s'
 
     def __init__(
         self, *, name, expressions, index_type=None, condition=None,
-        deferrable=None,
+        deferrable=None, include=None,
     ):
         if index_type and index_type.lower() not in {'gist', 'spgist'}:
             raise ValueError(
@@ -39,10 +40,19 @@ class ExclusionConstraint(BaseConstraint):
             raise ValueError(
                 'ExclusionConstraint.deferrable must be a Deferrable instance.'
             )
+        if not isinstance(include, (type(None), list, tuple)):
+            raise ValueError(
+                'ExclusionConstraint.include must be a list or tuple.'
+            )
+        if include and index_type and index_type.lower() != 'gist':
+            raise ValueError(
+                'Covering exclusion constraints only support GiST indexes.'
+            )
         self.expressions = expressions
         self.index_type = index_type or 'GIST'
         self.condition = condition
         self.deferrable = deferrable
+        self.include = tuple(include) if include else ()
         super().__init__(name=name)
 
     def _get_expression_sql(self, compiler, connection, query):
@@ -67,15 +77,18 @@ class ExclusionConstraint(BaseConstraint):
         compiler = query.get_compiler(connection=schema_editor.connection)
         expressions = self._get_expression_sql(compiler, schema_editor.connection, query)
         condition = self._get_condition_sql(compiler, schema_editor, query)
+        include = [model._meta.get_field(field_name).column for field_name in self.include]
         return self.template % {
             'name': schema_editor.quote_name(self.name),
             'index_type': self.index_type,
             'expressions': ', '.join(expressions),
+            'include': schema_editor._index_include_sql(model, include),
             'where': ' WHERE (%s)' % condition if condition else '',
             'deferrable': schema_editor._deferrable_constraint_sql(self.deferrable),
         }
 
     def create_sql(self, model, schema_editor):
+        self.check_supported(schema_editor)
         return Statement(
             'ALTER TABLE %(table)s ADD %(constraint)s',
             table=Table(model._meta.db_table, schema_editor.quote_name),
@@ -89,6 +102,12 @@ class ExclusionConstraint(BaseConstraint):
             schema_editor.quote_name(self.name),
         )
 
+    def check_supported(self, schema_editor):
+        if self.include and not schema_editor.connection.features.supports_covering_gist_indexes:
+            raise NotSupportedError(
+                'Covering exclusion constraints requires PostgreSQL 12+.'
+            )
+
     def deconstruct(self):
         path, args, kwargs = super().deconstruct()
         kwargs['expressions'] = self.expressions
@@ -98,6 +117,8 @@ class ExclusionConstraint(BaseConstraint):
             kwargs['index_type'] = self.index_type
         if self.deferrable:
             kwargs['deferrable'] = self.deferrable
+        if self.include:
+            kwargs['include'] = self.include
         return path, args, kwargs
 
     def __eq__(self, other):
@@ -107,15 +128,17 @@ class ExclusionConstraint(BaseConstraint):
                 self.index_type == other.index_type and
                 self.expressions == other.expressions and
                 self.condition == other.condition and
-                self.deferrable == other.deferrable
+                self.deferrable == other.deferrable and
+                self.include == other.include
             )
         return super().__eq__(other)
 
     def __repr__(self):
-        return '<%s: index_type=%s, expressions=%s%s%s>' % (
+        return '<%s: index_type=%s, expressions=%s%s%s%s>' % (
             self.__class__.__qualname__,
             self.index_type,
             self.expressions,
             '' if self.condition is None else ', condition=%s' % self.condition,
             '' if self.deferrable is None else ', deferrable=%s' % self.deferrable,
+            '' if not self.include else ', include=%s' % repr(self.include),
         )

+ 16 - 1
docs/ref/contrib/postgres/constraints.txt

@@ -12,7 +12,7 @@ PostgreSQL supports additional data integrity constraints available from the
 ``ExclusionConstraint``
 =======================
 
-.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None)
+.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None)
 
     Creates an exclusion constraint in the database. Internally, PostgreSQL
     implements exclusion constraints using indexes. The default index type is
@@ -106,6 +106,21 @@ enforced immediately after every command.
     Deferred exclusion constraints may lead to a `performance penalty
     <https://www.postgresql.org/docs/current/sql-createtable.html#id-1.9.3.85.9.4>`_.
 
+``include``
+-----------
+
+.. attribute:: ExclusionConstraint.include
+
+.. versionadded:: 3.2
+
+A list or tuple of the names of the fields to be included in the covering
+exclusion constraint as non-key columns. This allows index-only scans to be
+used for queries that select only included fields
+(:attr:`~ExclusionConstraint.include`) and filter only by indexed fields
+(:attr:`~ExclusionConstraint.expressions`).
+
+``include`` is supported only for GiST indexes on PostgreSQL 12+.
+
 Examples
 --------
 

+ 2 - 1
docs/releases/3.2.txt

@@ -70,7 +70,8 @@ Minor features
 :mod:`django.contrib.postgres`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-* ...
+* The new :attr:`.ExclusionConstraint.include` attribute allows creating
+  covering exclusion constraints on PostgreSQL 12+.
 
 :mod:`django.contrib.redirects`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

+ 130 - 1
tests/postgres_tests/test_constraints.py

@@ -1,8 +1,11 @@
 import datetime
 from unittest import mock
 
-from django.db import IntegrityError, connection, transaction
+from django.db import (
+    IntegrityError, NotSupportedError, connection, transaction,
+)
 from django.db.models import CheckConstraint, Deferrable, F, Func, Q
+from django.test import skipUnlessDBFeature
 from django.utils import timezone
 
 from . import PostgreSQLTestCase
@@ -146,6 +149,25 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
                 deferrable=Deferrable.DEFERRED,
             )
 
+    def test_invalid_include_type(self):
+        msg = 'ExclusionConstraint.include must be a list or tuple.'
+        with self.assertRaisesMessage(ValueError, msg):
+            ExclusionConstraint(
+                name='exclude_invalid_include',
+                expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
+                include='invalid',
+            )
+
+    def test_invalid_include_index_type(self):
+        msg = 'Covering exclusion constraints only support GiST indexes.'
+        with self.assertRaisesMessage(ValueError, msg):
+            ExclusionConstraint(
+                name='exclude_invalid_index_type',
+                expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
+                include=['cancelled'],
+                index_type='spgist',
+            )
+
     def test_repr(self):
         constraint = ExclusionConstraint(
             name='exclude_overlapping',
@@ -180,6 +202,16 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
             "<ExclusionConstraint: index_type=GIST, expressions=["
             "(F(datespan), '-|-')], deferrable=Deferrable.IMMEDIATE>",
         )
+        constraint = ExclusionConstraint(
+            name='exclude_overlapping',
+            expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)],
+            include=['cancelled', 'room'],
+        )
+        self.assertEqual(
+            repr(constraint),
+            "<ExclusionConstraint: index_type=GIST, expressions=["
+            "(F(datespan), '-|-')], include=('cancelled', 'room')>",
+        )
 
     def test_eq(self):
         constraint_1 = ExclusionConstraint(
@@ -218,6 +250,23 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
             ],
             deferrable=Deferrable.IMMEDIATE,
         )
+        constraint_6 = ExclusionConstraint(
+            name='exclude_overlapping',
+            expressions=[
+                ('datespan', RangeOperators.OVERLAPS),
+                ('room', RangeOperators.EQUAL),
+            ],
+            deferrable=Deferrable.IMMEDIATE,
+            include=['cancelled'],
+        )
+        constraint_7 = ExclusionConstraint(
+            name='exclude_overlapping',
+            expressions=[
+                ('datespan', RangeOperators.OVERLAPS),
+                ('room', RangeOperators.EQUAL),
+            ],
+            include=['cancelled'],
+        )
         self.assertEqual(constraint_1, constraint_1)
         self.assertEqual(constraint_1, mock.ANY)
         self.assertNotEqual(constraint_1, constraint_2)
@@ -225,7 +274,9 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
         self.assertNotEqual(constraint_1, constraint_4)
         self.assertNotEqual(constraint_2, constraint_3)
         self.assertNotEqual(constraint_2, constraint_4)
+        self.assertNotEqual(constraint_2, constraint_7)
         self.assertNotEqual(constraint_4, constraint_5)
+        self.assertNotEqual(constraint_5, constraint_6)
         self.assertNotEqual(constraint_1, object())
 
     def test_deconstruct(self):
@@ -286,6 +337,21 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
             'deferrable': Deferrable.DEFERRED,
         })
 
+    def test_deconstruct_include(self):
+        constraint = ExclusionConstraint(
+            name='exclude_overlapping',
+            expressions=[('datespan', RangeOperators.OVERLAPS)],
+            include=['cancelled', 'room'],
+        )
+        path, args, kwargs = constraint.deconstruct()
+        self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
+        self.assertEqual(args, ())
+        self.assertEqual(kwargs, {
+            'name': 'exclude_overlapping',
+            'expressions': [('datespan', RangeOperators.OVERLAPS)],
+            'include': ('cancelled', 'room'),
+        })
+
     def _test_range_overlaps(self, constraint):
         # Create exclusion constraint.
         self.assertNotIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table))
@@ -417,3 +483,66 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
         adjacent_range.delete()
         RangesModel.objects.create(ints=(10, 19))
         RangesModel.objects.create(ints=(51, 60))
+
+    @skipUnlessDBFeature('supports_covering_gist_indexes')
+    def test_range_adjacent_include(self):
+        constraint_name = 'ints_adjacent_include'
+        self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+        constraint = ExclusionConstraint(
+            name=constraint_name,
+            expressions=[('ints', RangeOperators.ADJACENT_TO)],
+            include=['decimals', 'ints'],
+            index_type='gist',
+        )
+        with connection.schema_editor() as editor:
+            editor.add_constraint(RangesModel, constraint)
+        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+        RangesModel.objects.create(ints=(20, 50))
+        with self.assertRaises(IntegrityError), transaction.atomic():
+            RangesModel.objects.create(ints=(10, 20))
+        RangesModel.objects.create(ints=(10, 19))
+        RangesModel.objects.create(ints=(51, 60))
+
+    @skipUnlessDBFeature('supports_covering_gist_indexes')
+    def test_range_adjacent_include_condition(self):
+        constraint_name = 'ints_adjacent_include_condition'
+        self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+        constraint = ExclusionConstraint(
+            name=constraint_name,
+            expressions=[('ints', RangeOperators.ADJACENT_TO)],
+            include=['decimals'],
+            condition=Q(id__gte=100),
+        )
+        with connection.schema_editor() as editor:
+            editor.add_constraint(RangesModel, constraint)
+        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+
+    @skipUnlessDBFeature('supports_covering_gist_indexes')
+    def test_range_adjacent_include_deferrable(self):
+        constraint_name = 'ints_adjacent_include_deferrable'
+        self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+        constraint = ExclusionConstraint(
+            name=constraint_name,
+            expressions=[('ints', RangeOperators.ADJACENT_TO)],
+            include=['decimals'],
+            deferrable=Deferrable.DEFERRED,
+        )
+        with connection.schema_editor() as editor:
+            editor.add_constraint(RangesModel, constraint)
+        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
+
+    def test_include_not_supported(self):
+        constraint_name = 'ints_adjacent_include_not_supported'
+        constraint = ExclusionConstraint(
+            name=constraint_name,
+            expressions=[('ints', RangeOperators.ADJACENT_TO)],
+            include=['id'],
+        )
+        msg = 'Covering exclusion constraints requires PostgreSQL 12+.'
+        with connection.schema_editor() as editor:
+            with mock.patch(
+                'django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_gist_indexes',
+                False,
+            ):
+                with self.assertRaisesMessage(NotSupportedError, msg):
+                    editor.add_constraint(RangesModel, constraint)