瀏覽代碼

Fixed #35575 -- Added support for constraint validation on GeneratedFields.

Mark Gensler 8 月之前
父節點
當前提交
228128618b

+ 3 - 9
django/contrib/postgres/constraints.py

@@ -183,17 +183,11 @@ class ExclusionConstraint(BaseConstraint):
         )
         replacements = {F(field): value for field, value in replacement_map.items()}
         lookups = []
-        for idx, (expression, operator) in enumerate(self.expressions):
+        for expression, operator in self.expressions:
             if isinstance(expression, str):
                 expression = F(expression)
-            if exclude:
-                if isinstance(expression, F):
-                    if expression.name in exclude:
-                        return
-                else:
-                    for expr in expression.flatten():
-                        if isinstance(expr, F) and expr.name in exclude:
-                            return
+            if exclude and self._expression_refs_exclude(model, expression, exclude):
+                return
             rhs_expression = expression.replace_expressions(replacements)
             if hasattr(expression, "get_expression_for_validation"):
                 expression = expression.get_expression_for_validation()

+ 25 - 10
django/db/models/base.py

@@ -1337,18 +1337,33 @@ class Model(AltersData, metaclass=ModelBase):
         if exclude is None:
             exclude = set()
         meta = meta or self._meta
-        field_map = {
-            field.name: (
-                value
-                if (value := getattr(self, field.attname))
-                and hasattr(value, "resolve_expression")
-                else Value(value, field)
-            )
-            for field in meta.local_concrete_fields
-            if field.name not in exclude and not field.generated
-        }
+        field_map = {}
+        generated_fields = []
+        for field in meta.local_concrete_fields:
+            if field.name in exclude:
+                continue
+            if field.generated:
+                if any(
+                    ref[0] in exclude
+                    for ref in self._get_expr_references(field.expression)
+                ):
+                    continue
+                generated_fields.append(field)
+                continue
+            value = getattr(self, field.attname)
+            if not value or not hasattr(value, "resolve_expression"):
+                value = Value(value, field)
+            field_map[field.name] = value
         if "pk" not in exclude:
             field_map["pk"] = Value(self.pk, meta.pk)
+        if generated_fields:
+            replacements = {F(name): value for name, value in field_map.items()}
+            for generated_field in generated_fields:
+                field_map[generated_field.name] = ExpressionWrapper(
+                    generated_field.expression.replace_expressions(replacements),
+                    generated_field.output_field,
+                )
+
         return field_map
 
     def prepare_database_save(self, field):

+ 57 - 24
django/db/models/constraints.py

@@ -68,6 +68,19 @@ class BaseConstraint:
     def remove_sql(self, model, schema_editor):
         raise NotImplementedError("This method must be implemented by a subclass.")
 
+    @classmethod
+    def _expression_refs_exclude(cls, model, expression, exclude):
+        get_field = model._meta.get_field
+        for field_name, *__ in model._get_expr_references(expression):
+            if field_name in exclude:
+                return True
+            field = get_field(field_name)
+            if field.generated and cls._expression_refs_exclude(
+                model, field.expression, exclude
+            ):
+                return True
+        return False
+
     def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
         raise NotImplementedError("This method must be implemented by a subclass.")
 
@@ -606,36 +619,56 @@ class UniqueConstraint(BaseConstraint):
         queryset = model._default_manager.using(using)
         if self.fields:
             lookup_kwargs = {}
+            generated_field_names = []
             for field_name in self.fields:
                 if exclude and field_name in exclude:
                     return
                 field = model._meta.get_field(field_name)
-                lookup_value = getattr(instance, field.attname)
-                if (
-                    self.nulls_distinct is not False
-                    and lookup_value is None
-                    or (
-                        lookup_value == ""
-                        and connections[
-                            using
-                        ].features.interprets_empty_strings_as_nulls
-                    )
-                ):
-                    # A composite constraint containing NULL value cannot cause
-                    # a violation since NULL != NULL in SQL.
-                    return
-                lookup_kwargs[field.name] = lookup_value
-            queryset = queryset.filter(**lookup_kwargs)
+                if field.generated:
+                    if exclude and self._expression_refs_exclude(
+                        model, field.expression, exclude
+                    ):
+                        return
+                    generated_field_names.append(field.name)
+                else:
+                    lookup_value = getattr(instance, field.attname)
+                    if (
+                        self.nulls_distinct is not False
+                        and lookup_value is None
+                        or (
+                            lookup_value == ""
+                            and connections[
+                                using
+                            ].features.interprets_empty_strings_as_nulls
+                        )
+                    ):
+                        # A composite constraint containing NULL value cannot cause
+                        # a violation since NULL != NULL in SQL.
+                        return
+                    lookup_kwargs[field.name] = lookup_value
+            lookup_args = []
+            if generated_field_names:
+                field_expression_map = instance._get_field_expression_map(
+                    meta=model._meta, exclude=exclude
+                )
+                for field_name in generated_field_names:
+                    expression = field_expression_map[field_name]
+                    if self.nulls_distinct is False:
+                        lhs = F(field_name)
+                        condition = Q(Exact(lhs, expression)) | Q(
+                            IsNull(lhs, True), IsNull(expression, True)
+                        )
+                        lookup_args.append(condition)
+                    else:
+                        lookup_kwargs[field_name] = expression
+            queryset = queryset.filter(*lookup_args, **lookup_kwargs)
         else:
             # Ignore constraints with excluded fields.
-            if exclude:
-                for expression in self.expressions:
-                    if hasattr(expression, "flatten"):
-                        for expr in expression.flatten():
-                            if isinstance(expr, F) and expr.name in exclude:
-                                return
-                    elif isinstance(expression, F) and expression.name in exclude:
-                        return
+            if exclude and any(
+                self._expression_refs_exclude(model, expression, exclude)
+                for expression in self.expressions
+            ):
+                return
             replacements = {
                 F(field): value
                 for field, value in instance._get_field_expression_map(

+ 3 - 0
docs/releases/5.2.txt

@@ -215,6 +215,9 @@ Models
   methods such as
   :meth:`QuerySet.union()<django.db.models.query.QuerySet.union>` unpredictable.
 
+* Added support for validation of model constraints which use a
+  :class:`~django.db.models.GeneratedField`.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 41 - 0
tests/constraints/models.py

@@ -1,4 +1,5 @@
 from django.db import models
+from django.db.models.functions import Coalesce, Lower
 
 
 class Product(models.Model):
@@ -28,6 +29,46 @@ class Product(models.Model):
         ]
 
 
+class GeneratedFieldStoredProduct(models.Model):
+    name = models.CharField(max_length=255, null=True)
+    price = models.IntegerField(null=True)
+    discounted_price = models.IntegerField(null=True)
+    rebate = models.GeneratedField(
+        expression=Coalesce("price", 0)
+        - Coalesce("discounted_price", Coalesce("price", 0)),
+        output_field=models.IntegerField(),
+        db_persist=True,
+    )
+    lower_name = models.GeneratedField(
+        expression=Lower(models.F("name")),
+        output_field=models.CharField(max_length=255, null=True),
+        db_persist=True,
+    )
+
+    class Meta:
+        required_db_features = {"supports_stored_generated_columns"}
+
+
+class GeneratedFieldVirtualProduct(models.Model):
+    name = models.CharField(max_length=255, null=True)
+    price = models.IntegerField(null=True)
+    discounted_price = models.IntegerField(null=True)
+    rebate = models.GeneratedField(
+        expression=Coalesce("price", 0)
+        - Coalesce("discounted_price", Coalesce("price", 0)),
+        output_field=models.IntegerField(),
+        db_persist=False,
+    )
+    lower_name = models.GeneratedField(
+        expression=Lower(models.F("name")),
+        output_field=models.CharField(max_length=255, null=True),
+        db_persist=False,
+    )
+
+    class Meta:
+        required_db_features = {"supports_virtual_generated_columns"}
+
+
 class UniqueConstraintProduct(models.Model):
     name = models.CharField(max_length=255)
     color = models.CharField(max_length=32, null=True)

+ 110 - 1
tests/constraints/tests.py

@@ -4,7 +4,7 @@ from django.core.exceptions import ValidationError
 from django.db import IntegrityError, connection, models
 from django.db.models import F
 from django.db.models.constraints import BaseConstraint, UniqueConstraint
-from django.db.models.functions import Abs, Lower, Upper
+from django.db.models.functions import Abs, Lower, Sqrt, Upper
 from django.db.transaction import atomic
 from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
 from django.test.utils import ignore_warnings
@@ -13,6 +13,8 @@ from django.utils.deprecation import RemovedInDjango60Warning
 from .models import (
     ChildModel,
     ChildUniqueConstraintProduct,
+    GeneratedFieldStoredProduct,
+    GeneratedFieldVirtualProduct,
     JSONFieldModel,
     ModelWithDatabaseDefault,
     Product,
@@ -384,6 +386,29 @@ class CheckConstraintTests(TestCase):
         with self.assertRaisesMessage(ValidationError, msg):
             json_exact_constraint.validate(JSONFieldModel, JSONFieldModel(data=data))
 
+    @skipUnlessDBFeature("supports_stored_generated_columns")
+    def test_validate_generated_field_stored(self):
+        self.assertGeneratedFieldIsValidated(model=GeneratedFieldStoredProduct)
+
+    @skipUnlessDBFeature("supports_virtual_generated_columns")
+    def test_validate_generated_field_virtual(self):
+        self.assertGeneratedFieldIsValidated(model=GeneratedFieldVirtualProduct)
+
+    def assertGeneratedFieldIsValidated(self, model):
+        constraint = models.CheckConstraint(
+            condition=models.Q(rebate__range=(0, 100)), name="bounded_rebate"
+        )
+        constraint.validate(model, model(price=50, discounted_price=20))
+
+        invalid_product = model(price=1200, discounted_price=500)
+        msg = f"Constraint “{constraint.name}” is violated."
+        with self.assertRaisesMessage(ValidationError, msg):
+            constraint.validate(model, invalid_product)
+
+        # Excluding referenced or generated fields should skip validation.
+        constraint.validate(model, invalid_product, exclude={"price"})
+        constraint.validate(model, invalid_product, exclude={"rebate"})
+
     def test_check_deprecation(self):
         msg = "CheckConstraint.check is deprecated in favor of `.condition`."
         condition = models.Q(foo="bar")
@@ -1062,6 +1087,90 @@ class UniqueConstraintTests(TestCase):
             exclude={"name"},
         )
 
+    @skipUnlessDBFeature("supports_stored_generated_columns")
+    def test_validate_expression_generated_field_stored(self):
+        self.assertGeneratedFieldWithExpressionIsValidated(
+            model=GeneratedFieldStoredProduct
+        )
+
+    @skipUnlessDBFeature("supports_virtual_generated_columns")
+    def test_validate_expression_generated_field_virtual(self):
+        self.assertGeneratedFieldWithExpressionIsValidated(
+            model=GeneratedFieldVirtualProduct
+        )
+
+    def assertGeneratedFieldWithExpressionIsValidated(self, model):
+        constraint = UniqueConstraint(Sqrt("rebate"), name="unique_rebate_sqrt")
+        model.objects.create(price=100, discounted_price=84)
+
+        valid_product = model(price=100, discounted_price=75)
+        constraint.validate(model, valid_product)
+
+        invalid_product = model(price=20, discounted_price=4)
+        with self.assertRaisesMessage(
+            ValidationError, f"Constraint “{constraint.name}” is violated."
+        ):
+            constraint.validate(model, invalid_product)
+
+        # Excluding referenced or generated fields should skip validation.
+        constraint.validate(model, invalid_product, exclude={"rebate"})
+        constraint.validate(model, invalid_product, exclude={"price"})
+
+    @skipUnlessDBFeature("supports_stored_generated_columns")
+    def test_validate_fields_generated_field_stored(self):
+        self.assertGeneratedFieldWithFieldsIsValidated(
+            model=GeneratedFieldStoredProduct
+        )
+
+    @skipUnlessDBFeature("supports_virtual_generated_columns")
+    def test_validate_fields_generated_field_virtual(self):
+        self.assertGeneratedFieldWithFieldsIsValidated(
+            model=GeneratedFieldVirtualProduct
+        )
+
+    def assertGeneratedFieldWithFieldsIsValidated(self, model):
+        constraint = models.UniqueConstraint(
+            fields=["lower_name"], name="lower_name_unique"
+        )
+        model.objects.create(name="Box")
+        constraint.validate(model, model(name="Case"))
+
+        invalid_product = model(name="BOX")
+        msg = str(invalid_product.unique_error_message(model, ["lower_name"]))
+        with self.assertRaisesMessage(ValidationError, msg):
+            constraint.validate(model, invalid_product)
+
+        # Excluding referenced or generated fields should skip validation.
+        constraint.validate(model, invalid_product, exclude={"lower_name"})
+        constraint.validate(model, invalid_product, exclude={"name"})
+
+    @skipUnlessDBFeature("supports_stored_generated_columns")
+    def test_validate_fields_generated_field_stored_nulls_distinct(self):
+        self.assertGeneratedFieldNullsDistinctIsValidated(
+            model=GeneratedFieldStoredProduct
+        )
+
+    @skipUnlessDBFeature("supports_virtual_generated_columns")
+    def test_validate_fields_generated_field_virtual_nulls_distinct(self):
+        self.assertGeneratedFieldNullsDistinctIsValidated(
+            model=GeneratedFieldVirtualProduct
+        )
+
+    def assertGeneratedFieldNullsDistinctIsValidated(self, model):
+        constraint = models.UniqueConstraint(
+            fields=["lower_name"],
+            name="lower_name_unique_nulls_distinct",
+            nulls_distinct=False,
+        )
+        model.objects.create(name=None)
+        valid_product = model(name="Box")
+        constraint.validate(model, valid_product)
+
+        invalid_product = model(name=None)
+        msg = str(invalid_product.unique_error_message(model, ["lower_name"]))
+        with self.assertRaisesMessage(ValidationError, msg):
+            constraint.validate(model, invalid_product)
+
     @skipUnlessDBFeature("supports_table_check_constraints")
     def test_validate_nullable_textfield_with_isnull_true(self):
         is_null_constraint = models.UniqueConstraint(

+ 34 - 0
tests/postgres_tests/test_constraints.py

@@ -14,6 +14,7 @@ from django.db.models import (
     F,
     ForeignKey,
     Func,
+    GeneratedField,
     IntegerField,
     Model,
     Q,
@@ -32,6 +33,7 @@ try:
     from django.contrib.postgres.constraints import ExclusionConstraint
     from django.contrib.postgres.fields import (
         DateTimeRangeField,
+        IntegerRangeField,
         RangeBoundary,
         RangeOperators,
     )
@@ -866,6 +868,38 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
         constraint.validate(RangesModel, RangesModel(ints=(51, 60)))
         constraint.validate(RangesModel, RangesModel(ints=(10, 20)), exclude={"ints"})
 
+    @skipUnlessDBFeature("supports_stored_generated_columns")
+    @isolate_apps("postgres_tests")
+    def test_validate_generated_field_range_adjacent(self):
+        class RangesModelGeneratedField(Model):
+            ints = IntegerRangeField(blank=True, null=True)
+            ints_generated = GeneratedField(
+                expression=F("ints"),
+                output_field=IntegerRangeField(null=True),
+                db_persist=True,
+            )
+
+        with connection.schema_editor() as editor:
+            editor.create_model(RangesModelGeneratedField)
+
+        constraint = ExclusionConstraint(
+            name="ints_adjacent",
+            expressions=[("ints_generated", RangeOperators.ADJACENT_TO)],
+            violation_error_code="custom_code",
+            violation_error_message="Custom error message.",
+        )
+        RangesModelGeneratedField.objects.create(ints=(20, 50))
+
+        range_obj = RangesModelGeneratedField(ints=(3, 20))
+        with self.assertRaisesMessage(ValidationError, "Custom error message."):
+            constraint.validate(RangesModelGeneratedField, range_obj)
+
+        # Excluding referenced or generated field should skip validation.
+        constraint.validate(RangesModelGeneratedField, range_obj, exclude={"ints"})
+        constraint.validate(
+            RangesModelGeneratedField, range_obj, exclude={"ints_generated"}
+        )
+
     def test_validate_with_custom_code_and_condition(self):
         constraint = ExclusionConstraint(
             name="ints_adjacent",