Browse Source

Fixed #30484 -- Added conditional expressions support to CheckConstraint.

Simon Charette 5 years ago
parent
commit
e9a0e1d4f6

+ 5 - 0
django/db/models/constraints.py

@@ -30,6 +30,11 @@ class BaseConstraint:
 class CheckConstraint(BaseConstraint):
     def __init__(self, *, check, name):
         self.check = check
+        if not getattr(check, 'conditional', False):
+            raise TypeError(
+                'CheckConstraint.check must be a Q instance or boolean '
+                'expression.'
+            )
         super().__init__(name)
 
     def _get_check_sql(self, model, schema_editor):

+ 20 - 15
django/db/models/sql/query.py

@@ -1221,8 +1221,19 @@ class Query(BaseExpression):
         """
         if isinstance(filter_expr, dict):
             raise FieldError("Cannot parse keyword query as dict")
+        if isinstance(filter_expr, Q):
+            return self._add_q(
+                filter_expr,
+                branch_negated=branch_negated,
+                current_negated=current_negated,
+                used_aliases=can_reuse,
+                allow_joins=allow_joins,
+                split_subq=split_subq,
+            )
         if hasattr(filter_expr, 'resolve_expression') and getattr(filter_expr, 'conditional', False):
-            condition = self.build_lookup(['exact'], filter_expr.resolve_expression(self), True)
+            condition = self.build_lookup(
+                ['exact'], filter_expr.resolve_expression(self, allow_joins=allow_joins), True
+            )
             clause = self.where_class()
             clause.add(condition, AND)
             return clause, []
@@ -1332,8 +1343,8 @@ class Query(BaseExpression):
             self.where.add(clause, AND)
         self.demote_joins(existing_inner)
 
-    def build_where(self, q_object):
-        return self._add_q(q_object, used_aliases=set(), allow_joins=False)[0]
+    def build_where(self, filter_expr):
+        return self.build_filter(filter_expr, allow_joins=False)[0]
 
     def _add_q(self, q_object, used_aliases, branch_negated=False,
                current_negated=False, allow_joins=True, split_subq=True):
@@ -1345,18 +1356,12 @@ class Query(BaseExpression):
                                          negated=q_object.negated)
         joinpromoter = JoinPromoter(q_object.connector, len(q_object.children), current_negated)
         for child in q_object.children:
-            if isinstance(child, Node):
-                child_clause, needed_inner = self._add_q(
-                    child, used_aliases, branch_negated,
-                    current_negated, allow_joins, split_subq)
-                joinpromoter.add_votes(needed_inner)
-            else:
-                child_clause, needed_inner = self.build_filter(
-                    child, can_reuse=used_aliases, branch_negated=branch_negated,
-                    current_negated=current_negated, allow_joins=allow_joins,
-                    split_subq=split_subq,
-                )
-                joinpromoter.add_votes(needed_inner)
+            child_clause, needed_inner = self.build_filter(
+                child, can_reuse=used_aliases, branch_negated=branch_negated,
+                current_negated=current_negated, allow_joins=allow_joins,
+                split_subq=split_subq,
+            )
+            joinpromoter.add_votes(needed_inner)
             if child_clause:
                 target_clause.add(child_clause, connector)
         needed_inner = joinpromoter.update_join_types(self)

+ 6 - 2
docs/ref/models/constraints.txt

@@ -52,12 +52,16 @@ option.
 
 .. attribute:: CheckConstraint.check
 
-A :class:`Q` object that specifies the check you want the constraint to
-enforce.
+A :class:`Q` object or boolean :class:`~django.db.models.Expression` that
+specifies the check you want the constraint to enforce.
 
 For example, ``CheckConstraint(check=Q(age__gte=18), name='age_gte_18')``
 ensures the age field is never less than 18.
 
+.. versionchanged:: 3.1
+
+    Support for boolean :class:`~django.db.models.Expression` was added.
+
 ``name``
 --------
 

+ 2 - 0
docs/releases/3.1.txt

@@ -204,6 +204,8 @@ Models
   ``OneToOneField`` emulates the behavior of the SQL constraint ``ON DELETE
   RESTRICT``.
 
+* :attr:`.CheckConstraint.check` now supports boolean expressions.
+
 Pagination
 ~~~~~~~~~~
 

+ 13 - 0
tests/constraints/models.py

@@ -18,6 +18,19 @@ class Product(models.Model):
                 check=models.Q(price__gt=0),
                 name='%(app_label)s_%(class)s_price_gt_0',
             ),
+            models.CheckConstraint(
+                check=models.expressions.RawSQL(
+                    'price < %s', (1000,), output_field=models.BooleanField()
+                ),
+                name='%(app_label)s_price_lt_1000_raw',
+            ),
+            models.CheckConstraint(
+                check=models.expressions.ExpressionWrapper(
+                    models.Q(price__gt=500) | models.Q(price__lt=500),
+                    output_field=models.BooleanField()
+                ),
+                name='%(app_label)s_price_neq_500_wrap',
+            ),
         ]
 
 

+ 21 - 0
tests/constraints/tests.py

@@ -61,6 +61,13 @@ class CheckConstraintTests(TestCase):
             "<CheckConstraint: check='{}' name='{}'>".format(check, name),
         )
 
+    def test_invalid_check_types(self):
+        msg = (
+            'CheckConstraint.check must be a Q instance or boolean expression.'
+        )
+        with self.assertRaisesMessage(TypeError, msg):
+            models.CheckConstraint(check=models.F('discounted_price'), name='check')
+
     def test_deconstruction(self):
         check = models.Q(price__gt=models.F('discounted_price'))
         name = 'price_gt_discounted_price'
@@ -76,11 +83,25 @@ class CheckConstraintTests(TestCase):
         with self.assertRaises(IntegrityError):
             Product.objects.create(price=10, discounted_price=20)
 
+    @skipUnlessDBFeature('supports_table_check_constraints')
+    def test_database_constraint_expression(self):
+        Product.objects.create(price=999, discounted_price=5)
+        with self.assertRaises(IntegrityError):
+            Product.objects.create(price=1000, discounted_price=5)
+
+    @skipUnlessDBFeature('supports_table_check_constraints')
+    def test_database_constraint_expressionwrapper(self):
+        Product.objects.create(price=499, discounted_price=5)
+        with self.assertRaises(IntegrityError):
+            Product.objects.create(price=500, discounted_price=5)
+
     @skipUnlessDBFeature('supports_table_check_constraints', 'can_introspect_check_constraints')
     def test_name(self):
         constraints = get_constraints(Product._meta.db_table)
         for expected_name in (
             'price_gt_discounted_price',
+            'constraints_price_lt_1000_raw',
+            'constraints_price_neq_500_wrap',
             'constraints_product_price_gt_0',
         ):
             with self.subTest(expected_name):

+ 17 - 2
tests/queries/test_query.py

@@ -1,8 +1,8 @@
 from datetime import datetime
 
 from django.core.exceptions import FieldError
-from django.db.models import CharField, F, Q
-from django.db.models.expressions import Col
+from django.db.models import BooleanField, CharField, F, Q
+from django.db.models.expressions import Col, Func
 from django.db.models.fields.related_lookups import RelatedIsNull
 from django.db.models.functions import Lower
 from django.db.models.lookups import Exact, GreaterThan, IsNull, LessThan
@@ -129,3 +129,18 @@ class TestQuery(SimpleTestCase):
         name_exact = where.children[0]
         self.assertIsInstance(name_exact, Exact)
         self.assertEqual(name_exact.rhs, "['a', 'b']")
+
+    def test_filter_conditional(self):
+        query = Query(Item)
+        where = query.build_where(Func(output_field=BooleanField()))
+        exact = where.children[0]
+        self.assertIsInstance(exact, Exact)
+        self.assertIsInstance(exact.lhs, Func)
+        self.assertIs(exact.rhs, True)
+
+    def test_filter_conditional_join(self):
+        query = Query(Item)
+        filter_expr = Func('note__note', output_field=BooleanField())
+        msg = 'Joined field references are not permitted in this query'
+        with self.assertRaisesMessage(FieldError, msg):
+            query.build_where(filter_expr)