Browse Source

Fixed #34338 -- Allowed customizing code of ValidationError in BaseConstraint and subclasses.

Xavier Fernandez 2 years ago
parent
commit
5b3d3e400a

+ 19 - 4
django/contrib/postgres/constraints.py

@@ -32,6 +32,7 @@ class ExclusionConstraint(BaseConstraint):
         condition=None,
         deferrable=None,
         include=None,
+        violation_error_code=None,
         violation_error_message=None,
     ):
         if index_type and index_type.lower() not in {"gist", "spgist"}:
@@ -60,7 +61,11 @@ class ExclusionConstraint(BaseConstraint):
         self.condition = condition
         self.deferrable = deferrable
         self.include = tuple(include) if include else ()
-        super().__init__(name=name, violation_error_message=violation_error_message)
+        super().__init__(
+            name=name,
+            violation_error_code=violation_error_code,
+            violation_error_message=violation_error_message,
+        )
 
     def _get_expressions(self, schema_editor, query):
         expressions = []
@@ -149,12 +154,13 @@ class ExclusionConstraint(BaseConstraint):
                 and self.condition == other.condition
                 and self.deferrable == other.deferrable
                 and self.include == other.include
+                and self.violation_error_code == other.violation_error_code
                 and self.violation_error_message == other.violation_error_message
             )
         return super().__eq__(other)
 
     def __repr__(self):
-        return "<%s: index_type=%s expressions=%s name=%s%s%s%s%s>" % (
+        return "<%s: index_type=%s expressions=%s name=%s%s%s%s%s%s>" % (
             self.__class__.__qualname__,
             repr(self.index_type),
             repr(self.expressions),
@@ -162,6 +168,11 @@ class ExclusionConstraint(BaseConstraint):
             "" if self.condition is None else " condition=%s" % self.condition,
             "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
             "" if not self.include else " include=%s" % repr(self.include),
+            (
+                ""
+                if self.violation_error_code is None
+                else " violation_error_code=%r" % self.violation_error_code
+            ),
             (
                 ""
                 if self.violation_error_message is None
@@ -204,9 +215,13 @@ class ExclusionConstraint(BaseConstraint):
             queryset = queryset.exclude(pk=model_class_pk)
         if not self.condition:
             if queryset.exists():
-                raise ValidationError(self.get_violation_error_message())
+                raise ValidationError(
+                    self.get_violation_error_message(), code=self.violation_error_code
+                )
         else:
             if (self.condition & Exists(queryset.filter(self.condition))).check(
                 replacement_map, using=using
             ):
-                raise ValidationError(self.get_violation_error_message())
+                raise ValidationError(
+                    self.get_violation_error_message(), code=self.violation_error_code
+                )

+ 51 - 11
django/db/models/constraints.py

@@ -18,11 +18,16 @@ __all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"
 
 class BaseConstraint:
     default_violation_error_message = _("Constraint “%(name)s” is violated.")
+    violation_error_code = None
     violation_error_message = None
 
     # RemovedInDjango60Warning: When the deprecation ends, replace with:
-    # def __init__(self, *, name, violation_error_message=None):
-    def __init__(self, *args, name=None, violation_error_message=None):
+    # def __init__(
+    #     self, *, name, violation_error_code=None, violation_error_message=None
+    # ):
+    def __init__(
+        self, *args, name=None, violation_error_code=None, violation_error_message=None
+    ):
         # RemovedInDjango60Warning.
         if name is None and not args:
             raise TypeError(
@@ -30,6 +35,8 @@ class BaseConstraint:
                 f"argument: 'name'"
             )
         self.name = name
+        if violation_error_code is not None:
+            self.violation_error_code = violation_error_code
         if violation_error_message is not None:
             self.violation_error_message = violation_error_message
         else:
@@ -74,6 +81,8 @@ class BaseConstraint:
             and self.violation_error_message != self.default_violation_error_message
         ):
             kwargs["violation_error_message"] = self.violation_error_message
+        if self.violation_error_code is not None:
+            kwargs["violation_error_code"] = self.violation_error_code
         return (path, (), kwargs)
 
     def clone(self):
@@ -82,13 +91,19 @@ class BaseConstraint:
 
 
 class CheckConstraint(BaseConstraint):
-    def __init__(self, *, check, name, violation_error_message=None):
+    def __init__(
+        self, *, check, name, violation_error_code=None, violation_error_message=None
+    ):
         self.check = check
         if not getattr(check, "conditional", False):
             raise TypeError(
                 "CheckConstraint.check must be a Q instance or boolean expression."
             )
-        super().__init__(name=name, violation_error_message=violation_error_message)
+        super().__init__(
+            name=name,
+            violation_error_code=violation_error_code,
+            violation_error_message=violation_error_message,
+        )
 
     def _get_check_sql(self, model, schema_editor):
         query = Query(model=model, alias_cols=False)
@@ -112,15 +127,22 @@ class CheckConstraint(BaseConstraint):
         against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
         try:
             if not Q(self.check).check(against, using=using):
-                raise ValidationError(self.get_violation_error_message())
+                raise ValidationError(
+                    self.get_violation_error_message(), code=self.violation_error_code
+                )
         except FieldError:
             pass
 
     def __repr__(self):
-        return "<%s: check=%s name=%s%s>" % (
+        return "<%s: check=%s name=%s%s%s>" % (
             self.__class__.__qualname__,
             self.check,
             repr(self.name),
+            (
+                ""
+                if self.violation_error_code is None
+                else " violation_error_code=%r" % self.violation_error_code
+            ),
             (
                 ""
                 if self.violation_error_message is None
@@ -134,6 +156,7 @@ class CheckConstraint(BaseConstraint):
             return (
                 self.name == other.name
                 and self.check == other.check
+                and self.violation_error_code == other.violation_error_code
                 and self.violation_error_message == other.violation_error_message
             )
         return super().__eq__(other)
@@ -163,6 +186,7 @@ class UniqueConstraint(BaseConstraint):
         deferrable=None,
         include=None,
         opclasses=(),
+        violation_error_code=None,
         violation_error_message=None,
     ):
         if not name:
@@ -213,7 +237,11 @@ class UniqueConstraint(BaseConstraint):
             F(expression) if isinstance(expression, str) else expression
             for expression in expressions
         )
-        super().__init__(name=name, violation_error_message=violation_error_message)
+        super().__init__(
+            name=name,
+            violation_error_code=violation_error_code,
+            violation_error_message=violation_error_message,
+        )
 
     @property
     def contains_expressions(self):
@@ -293,7 +321,7 @@ class UniqueConstraint(BaseConstraint):
         )
 
     def __repr__(self):
-        return "<%s:%s%s%s%s%s%s%s%s>" % (
+        return "<%s:%s%s%s%s%s%s%s%s%s>" % (
             self.__class__.__qualname__,
             "" if not self.fields else " fields=%s" % repr(self.fields),
             "" if not self.expressions else " expressions=%s" % repr(self.expressions),
@@ -302,6 +330,11 @@ class UniqueConstraint(BaseConstraint):
             "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
             "" if not self.include else " include=%s" % repr(self.include),
             "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
+            (
+                ""
+                if self.violation_error_code is None
+                else " violation_error_code=%r" % self.violation_error_code
+            ),
             (
                 ""
                 if self.violation_error_message is None
@@ -320,6 +353,7 @@ class UniqueConstraint(BaseConstraint):
                 and self.include == other.include
                 and self.opclasses == other.opclasses
                 and self.expressions == other.expressions
+                and self.violation_error_code == other.violation_error_code
                 and self.violation_error_message == other.violation_error_message
             )
         return super().__eq__(other)
@@ -385,14 +419,17 @@ class UniqueConstraint(BaseConstraint):
         if not self.condition:
             if queryset.exists():
                 if self.expressions:
-                    raise ValidationError(self.get_violation_error_message())
+                    raise ValidationError(
+                        self.get_violation_error_message(),
+                        code=self.violation_error_code,
+                    )
                 # When fields are defined, use the unique_error_message() for
                 # backward compatibility.
                 for model, constraints in instance.get_constraints():
                     for constraint in constraints:
                         if constraint is self:
                             raise ValidationError(
-                                instance.unique_error_message(model, self.fields)
+                                instance.unique_error_message(model, self.fields),
                             )
         else:
             against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
@@ -400,6 +437,9 @@ class UniqueConstraint(BaseConstraint):
                 if (self.condition & Exists(queryset.filter(self.condition))).check(
                     against, using=using
                 ):
-                    raise ValidationError(self.get_violation_error_message())
+                    raise ValidationError(
+                        self.get_violation_error_message(),
+                        code=self.violation_error_code,
+                    )
             except FieldError:
                 pass

+ 11 - 1
docs/ref/contrib/postgres/constraints.txt

@@ -12,7 +12,7 @@ PostgreSQL supports additional data integrity constraints available from the
 ``ExclusionConstraint``
 =======================
 
-.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None, violation_error_message=None)
+.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None, violation_error_code=None, violation_error_message=None)
 
     Creates an exclusion constraint in the database. Internally, PostgreSQL
     implements exclusion constraints using indexes. The default index type is
@@ -133,6 +133,16 @@ used for queries that select only included fields
 ``include`` is supported for GiST indexes. PostgreSQL 14+ also supports
 ``include`` for SP-GiST indexes.
 
+``violation_error_code``
+------------------------
+
+.. versionadded:: 5.0
+
+.. attribute:: ExclusionConstraint.violation_error_code
+
+The error code used when ``ValidationError`` is raised during
+:ref:`model validation <validating-objects>`. Defaults to ``None``.
+
 ``violation_error_message``
 ---------------------------
 

+ 29 - 3
docs/ref/models/constraints.txt

@@ -48,7 +48,7 @@ option.
 ``BaseConstraint``
 ==================
 
-.. class:: BaseConstraint(*, name, violation_error_message=None)
+.. class:: BaseConstraint(* name, violation_error_code=None, violation_error_message=None)
 
     Base class for all constraints. Subclasses must implement
     ``constraint_sql()``, ``create_sql()``, ``remove_sql()`` and
@@ -68,6 +68,16 @@ All constraints have the following parameters in common:
 The name of the constraint. You must always specify a unique name for the
 constraint.
 
+``violation_error_code``
+------------------------
+
+.. versionadded:: 5.0
+
+.. attribute:: BaseConstraint.violation_error_code
+
+The error code used when ``ValidationError`` is raised during
+:ref:`model validation <validating-objects>`. Defaults to ``None``.
+
 ``violation_error_message``
 ---------------------------
 
@@ -94,7 +104,7 @@ This method must be implemented by a subclass.
 ``CheckConstraint``
 ===================
 
-.. class:: CheckConstraint(*, check, name, violation_error_message=None)
+.. class:: CheckConstraint(*, check, name, violation_error_code=None, violation_error_message=None)
 
     Creates a check constraint in the database.
 
@@ -121,7 +131,7 @@ ensures the age field is never less than 18.
 ``UniqueConstraint``
 ====================
 
-.. class:: UniqueConstraint(*expressions, fields=(), name=None, condition=None, deferrable=None, include=None, opclasses=(), violation_error_message=None)
+.. class:: UniqueConstraint(*expressions, fields=(), name=None, condition=None, deferrable=None, include=None, opclasses=(), violation_error_code=None, violation_error_message=None)
 
     Creates a unique constraint in the database.
 
@@ -242,6 +252,22 @@ creates a unique index on ``username`` using ``varchar_pattern_ops``.
 
 ``opclasses`` are ignored for databases besides PostgreSQL.
 
+``violation_error_code``
+------------------------
+
+.. versionadded:: 5.0
+
+.. attribute:: UniqueConstraint.violation_error_code
+
+The error code used when ``ValidationError`` is raised during
+:ref:`model validation <validating-objects>`. Defaults to ``None``.
+
+This code is *not used* for :class:`UniqueConstraint`\s with
+:attr:`~UniqueConstraint.fields` and without a
+:attr:`~UniqueConstraint.condition`. Such :class:`~UniqueConstraint`\s have the
+same error code as constraints defined with :attr:`.Field.unique` or in
+:attr:`Meta.unique_together <django.db.models.Options.constraints>`.
+
 ``violation_error_message``
 ---------------------------
 

+ 11 - 1
docs/releases/5.0.txt

@@ -78,7 +78,10 @@ Minor features
 :mod:`django.contrib.postgres`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-* ...
+* The new :attr:`~.ExclusionConstraint.violation_error_code` attribute of
+  :class:`~django.contrib.postgres.constraints.ExclusionConstraint` allows
+  customizing the ``code`` of ``ValidationError`` raised during
+  :ref:`model validation <validating-objects>`.
 
 :mod:`django.contrib.redirects`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -182,6 +185,13 @@ Models
   and :meth:`.QuerySet.aupdate_or_create` methods allows specifying a different
   field values for the create operation.
 
+* The new ``violation_error_code`` attribute of
+  :class:`~django.db.models.BaseConstraint`,
+  :class:`~django.db.models.CheckConstraint`, and
+  :class:`~django.db.models.UniqueConstraint` allows customizing the ``code``
+  of ``ValidationError`` raised during
+  :ref:`model validation <validating-objects>`.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 108 - 2
tests/constraints/tests.py

@@ -77,17 +77,26 @@ class BaseConstraintTests(SimpleTestCase):
             "custom base_name message",
         )
 
+    def test_custom_violation_code_message(self):
+        c = BaseConstraint(name="base_name", violation_error_code="custom_code")
+        self.assertEqual(c.violation_error_code, "custom_code")
+
     def test_deconstruction(self):
         constraint = BaseConstraint(
             name="base_name",
             violation_error_message="custom %(name)s message",
+            violation_error_code="custom_code",
         )
         path, args, kwargs = constraint.deconstruct()
         self.assertEqual(path, "django.db.models.BaseConstraint")
         self.assertEqual(args, ())
         self.assertEqual(
             kwargs,
-            {"name": "base_name", "violation_error_message": "custom %(name)s message"},
+            {
+                "name": "base_name",
+                "violation_error_message": "custom %(name)s message",
+                "violation_error_code": "custom_code",
+            },
         )
 
     def test_deprecation(self):
@@ -148,6 +157,20 @@ class CheckConstraintTests(TestCase):
                 check=check1, name="price", violation_error_message="custom error"
             ),
         )
+        self.assertNotEqual(
+            models.CheckConstraint(check=check1, name="price"),
+            models.CheckConstraint(
+                check=check1, name="price", violation_error_code="custom_code"
+            ),
+        )
+        self.assertEqual(
+            models.CheckConstraint(
+                check=check1, name="price", violation_error_code="custom_code"
+            ),
+            models.CheckConstraint(
+                check=check1, name="price", violation_error_code="custom_code"
+            ),
+        )
 
     def test_repr(self):
         constraint = models.CheckConstraint(
@@ -172,6 +195,18 @@ class CheckConstraintTests(TestCase):
             "violation_error_message='More than 1'>",
         )
 
+    def test_repr_with_violation_error_code(self):
+        constraint = models.CheckConstraint(
+            check=models.Q(price__lt=1),
+            name="price_lt_one",
+            violation_error_code="more_than_one",
+        )
+        self.assertEqual(
+            repr(constraint),
+            "<CheckConstraint: check=(AND: ('price__lt', 1)) name='price_lt_one' "
+            "violation_error_code='more_than_one'>",
+        )
+
     def test_invalid_check_types(self):
         msg = "CheckConstraint.check must be a Q instance or boolean expression."
         with self.assertRaisesMessage(TypeError, msg):
@@ -237,6 +272,21 @@ class CheckConstraintTests(TestCase):
         # Valid product.
         constraint.validate(Product, Product(price=10, discounted_price=5))
 
+    def test_validate_custom_error(self):
+        check = models.Q(price__gt=models.F("discounted_price"))
+        constraint = models.CheckConstraint(
+            check=check,
+            name="price",
+            violation_error_message="discount is fake",
+            violation_error_code="fake_discount",
+        )
+        # Invalid product.
+        invalid_product = Product(price=10, discounted_price=42)
+        msg = "discount is fake"
+        with self.assertRaisesMessage(ValidationError, msg) as cm:
+            constraint.validate(Product, invalid_product)
+        self.assertEqual(cm.exception.code, "fake_discount")
+
     def test_validate_boolean_expressions(self):
         constraint = models.CheckConstraint(
             check=models.expressions.ExpressionWrapper(
@@ -341,6 +391,30 @@ class UniqueConstraintTests(TestCase):
                 violation_error_message="custom error",
             ),
         )
+        self.assertNotEqual(
+            models.UniqueConstraint(
+                fields=["foo", "bar"],
+                name="unique",
+                violation_error_code="custom_error",
+            ),
+            models.UniqueConstraint(
+                fields=["foo", "bar"],
+                name="unique",
+                violation_error_code="other_custom_error",
+            ),
+        )
+        self.assertEqual(
+            models.UniqueConstraint(
+                fields=["foo", "bar"],
+                name="unique",
+                violation_error_code="custom_error",
+            ),
+            models.UniqueConstraint(
+                fields=["foo", "bar"],
+                name="unique",
+                violation_error_code="custom_error",
+            ),
+        )
 
     def test_eq_with_condition(self):
         self.assertEqual(
@@ -512,6 +586,20 @@ class UniqueConstraintTests(TestCase):
             ),
         )
 
+    def test_repr_with_violation_error_code(self):
+        constraint = models.UniqueConstraint(
+            models.F("baz__lower"),
+            name="unique_lower_baz",
+            violation_error_code="baz",
+        )
+        self.assertEqual(
+            repr(constraint),
+            (
+                "<UniqueConstraint: expressions=(F(baz__lower),) "
+                "name='unique_lower_baz' violation_error_code='baz'>"
+            ),
+        )
+
     def test_deconstruction(self):
         fields = ["foo", "bar"]
         name = "unique_fields"
@@ -656,12 +744,16 @@ class UniqueConstraintTests(TestCase):
 
     def test_validate(self):
         constraint = UniqueConstraintProduct._meta.constraints[0]
+        # Custom message and error code are ignored.
+        constraint.violation_error_message = "Custom message"
+        constraint.violation_error_code = "custom_code"
         msg = "Unique constraint product with this Name and Color already exists."
         non_unique_product = UniqueConstraintProduct(
             name=self.p1.name, color=self.p1.color
         )
-        with self.assertRaisesMessage(ValidationError, msg):
+        with self.assertRaisesMessage(ValidationError, msg) as cm:
             constraint.validate(UniqueConstraintProduct, non_unique_product)
+        self.assertEqual(cm.exception.code, "unique_together")
         # Null values are ignored.
         constraint.validate(
             UniqueConstraintProduct,
@@ -716,6 +808,20 @@ class UniqueConstraintTests(TestCase):
             exclude={"name"},
         )
 
+    @skipUnlessDBFeature("supports_partial_indexes")
+    def test_validate_conditon_custom_error(self):
+        p1 = UniqueConstraintConditionProduct.objects.create(name="p1")
+        constraint = UniqueConstraintConditionProduct._meta.constraints[0]
+        constraint.violation_error_message = "Custom message"
+        constraint.violation_error_code = "custom_code"
+        msg = "Custom message"
+        with self.assertRaisesMessage(ValidationError, msg) as cm:
+            constraint.validate(
+                UniqueConstraintConditionProduct,
+                UniqueConstraintConditionProduct(name=p1.name, color=None),
+            )
+        self.assertEqual(cm.exception.code, "custom_code")
+
     def test_validate_expression(self):
         constraint = models.UniqueConstraint(Lower("name"), name="name_lower_uniq")
         msg = "Constraint “name_lower_uniq” is violated."

+ 39 - 1
tests/postgres_tests/test_constraints.py

@@ -397,6 +397,17 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
             "(F(datespan), '-|-')] name='exclude_overlapping' "
             "violation_error_message='Overlapping must be excluded'>",
         )
+        constraint = ExclusionConstraint(
+            name="exclude_overlapping",
+            expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)],
+            violation_error_code="overlapping_must_be_excluded",
+        )
+        self.assertEqual(
+            repr(constraint),
+            "<ExclusionConstraint: index_type='GIST' expressions=["
+            "(F(datespan), '-|-')] name='exclude_overlapping' "
+            "violation_error_code='overlapping_must_be_excluded'>",
+        )
 
     def test_eq(self):
         constraint_1 = ExclusionConstraint(
@@ -470,6 +481,16 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
             condition=Q(cancelled=False),
             violation_error_message="other custom error",
         )
+        constraint_12 = ExclusionConstraint(
+            name="exclude_overlapping",
+            expressions=[
+                (F("datespan"), RangeOperators.OVERLAPS),
+                (F("room"), RangeOperators.EQUAL),
+            ],
+            condition=Q(cancelled=False),
+            violation_error_code="custom_code",
+            violation_error_message="other custom error",
+        )
         self.assertEqual(constraint_1, constraint_1)
         self.assertEqual(constraint_1, mock.ANY)
         self.assertNotEqual(constraint_1, constraint_2)
@@ -483,7 +504,9 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
         self.assertNotEqual(constraint_5, constraint_6)
         self.assertNotEqual(constraint_1, object())
         self.assertNotEqual(constraint_10, constraint_11)
+        self.assertNotEqual(constraint_11, constraint_12)
         self.assertEqual(constraint_10, constraint_10)
+        self.assertEqual(constraint_12, constraint_12)
 
     def test_deconstruct(self):
         constraint = ExclusionConstraint(
@@ -760,17 +783,32 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
         constraint = ExclusionConstraint(
             name="ints_adjacent",
             expressions=[("ints", RangeOperators.ADJACENT_TO)],
+            violation_error_code="custom_code",
             violation_error_message="Custom error message.",
         )
         range_obj = RangesModel.objects.create(ints=(20, 50))
         constraint.validate(RangesModel, range_obj)
         msg = "Custom error message."
-        with self.assertRaisesMessage(ValidationError, msg):
+        with self.assertRaisesMessage(ValidationError, msg) as cm:
             constraint.validate(RangesModel, RangesModel(ints=(10, 20)))
+        self.assertEqual(cm.exception.code, "custom_code")
         constraint.validate(RangesModel, RangesModel(ints=(10, 19)))
         constraint.validate(RangesModel, RangesModel(ints=(51, 60)))
         constraint.validate(RangesModel, RangesModel(ints=(10, 20)), exclude={"ints"})
 
+    def test_validate_with_custom_code_and_condition(self):
+        constraint = ExclusionConstraint(
+            name="ints_adjacent",
+            expressions=[("ints", RangeOperators.ADJACENT_TO)],
+            violation_error_code="custom_code",
+            condition=Q(ints__lt=(100, 200)),
+        )
+        range_obj = RangesModel.objects.create(ints=(20, 50))
+        constraint.validate(RangesModel, range_obj)
+        with self.assertRaises(ValidationError) as cm:
+            constraint.validate(RangesModel, RangesModel(ints=(10, 20)))
+        self.assertEqual(cm.exception.code, "custom_code")
+
     def test_expressions_with_params(self):
         constraint_name = "scene_left_equal"
         self.assertNotIn(constraint_name, self.get_constraints(Scene._meta.db_table))