|
@@ -1,13 +1,19 @@
|
|
|
+from django.contrib.postgres.indexes import OpClass
|
|
|
from django.db import NotSupportedError
|
|
|
-from django.db.backends.ddl_references import Statement, Table
|
|
|
+from django.db.backends.ddl_references import Expressions, Statement, Table
|
|
|
from django.db.models import Deferrable, F, Q
|
|
|
from django.db.models.constraints import BaseConstraint
|
|
|
-from django.db.models.expressions import Col
|
|
|
+from django.db.models.expressions import ExpressionList
|
|
|
+from django.db.models.indexes import IndexExpression
|
|
|
from django.db.models.sql import Query
|
|
|
|
|
|
__all__ = ['ExclusionConstraint']
|
|
|
|
|
|
|
|
|
+class ExclusionConstraintExpression(IndexExpression):
|
|
|
+ template = '%(expressions)s WITH %(operator)s'
|
|
|
+
|
|
|
+
|
|
|
class ExclusionConstraint(BaseConstraint):
|
|
|
template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(include)s%(where)s%(deferrable)s'
|
|
|
|
|
@@ -63,24 +69,19 @@ class ExclusionConstraint(BaseConstraint):
|
|
|
self.opclasses = opclasses
|
|
|
super().__init__(name=name)
|
|
|
|
|
|
- def _get_expression_sql(self, compiler, schema_editor, query):
|
|
|
+ def _get_expressions(self, schema_editor, query):
|
|
|
expressions = []
|
|
|
for idx, (expression, operator) in enumerate(self.expressions):
|
|
|
if isinstance(expression, str):
|
|
|
expression = F(expression)
|
|
|
- expression = expression.resolve_expression(query=query)
|
|
|
- sql, params = compiler.compile(expression)
|
|
|
- if not isinstance(expression, Col):
|
|
|
- sql = f'({sql})'
|
|
|
try:
|
|
|
- opclass = self.opclasses[idx]
|
|
|
- if opclass:
|
|
|
- sql = '%s %s' % (sql, opclass)
|
|
|
+ expression = OpClass(expression, self.opclasses[idx])
|
|
|
except IndexError:
|
|
|
pass
|
|
|
- sql = sql % tuple(schema_editor.quote_value(p) for p in params)
|
|
|
- expressions.append('%s WITH %s' % (sql, operator))
|
|
|
- return expressions
|
|
|
+ expression = ExclusionConstraintExpression(expression, operator=operator)
|
|
|
+ expression.set_wrapper_classes(schema_editor.connection)
|
|
|
+ expressions.append(expression)
|
|
|
+ return ExpressionList(*expressions).resolve_expression(query)
|
|
|
|
|
|
def _get_condition_sql(self, compiler, schema_editor, query):
|
|
|
if self.condition is None:
|
|
@@ -92,17 +93,20 @@ class ExclusionConstraint(BaseConstraint):
|
|
|
def constraint_sql(self, model, schema_editor):
|
|
|
query = Query(model, alias_cols=False)
|
|
|
compiler = query.get_compiler(connection=schema_editor.connection)
|
|
|
- expressions = self._get_expression_sql(compiler, schema_editor, query)
|
|
|
+ expressions = self._get_expressions(schema_editor, query)
|
|
|
+ table = model._meta.db_table
|
|
|
condition = self._get_condition_sql(compiler, schema_editor, query)
|
|
|
include = [model._meta.get_field(field_name).column for field_name in self.include]
|
|
|
- return self.template % {
|
|
|
- 'name': schema_editor.quote_name(self.name),
|
|
|
- 'index_type': self.index_type,
|
|
|
- 'expressions': ', '.join(expressions),
|
|
|
- 'include': schema_editor._index_include_sql(model, include),
|
|
|
- 'where': ' WHERE (%s)' % condition if condition else '',
|
|
|
- 'deferrable': schema_editor._deferrable_constraint_sql(self.deferrable),
|
|
|
- }
|
|
|
+ return Statement(
|
|
|
+ self.template,
|
|
|
+ table=Table(table, schema_editor.quote_name),
|
|
|
+ name=schema_editor.quote_name(self.name),
|
|
|
+ index_type=self.index_type,
|
|
|
+ expressions=Expressions(table, expressions, compiler, schema_editor.quote_value),
|
|
|
+ where=' WHERE (%s)' % condition if condition else '',
|
|
|
+ include=schema_editor._index_include_sql(model, include),
|
|
|
+ deferrable=schema_editor._deferrable_constraint_sql(self.deferrable),
|
|
|
+ )
|
|
|
|
|
|
def create_sql(self, model, schema_editor):
|
|
|
self.check_supported(schema_editor)
|