Browse Source

Fixed #33829 -- Made BaseConstraint.deconstruct() and equality handle violation_error_message.

Regression in 667105877e6723c6985399803a364848891513cc.
Stéphane "Twidi" Angel 2 years ago
parent
commit
ccbf714ebe

+ 1 - 0
django/contrib/postgres/constraints.py

@@ -177,6 +177,7 @@ class ExclusionConstraint(BaseConstraint):
                 and self.deferrable == other.deferrable
                 and self.include == other.include
                 and self.opclasses == other.opclasses
+                and self.violation_error_message == other.violation_error_message
             )
         return super().__eq__(other)
 

+ 17 - 3
django/db/models/constraints.py

@@ -14,12 +14,15 @@ __all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"
 
 
 class BaseConstraint:
-    violation_error_message = _("Constraint “%(name)s” is violated.")
+    default_violation_error_message = _("Constraint “%(name)s” is violated.")
+    violation_error_message = None
 
     def __init__(self, name, violation_error_message=None):
         self.name = name
         if violation_error_message is not None:
             self.violation_error_message = violation_error_message
+        else:
+            self.violation_error_message = self.default_violation_error_message
 
     @property
     def contains_expressions(self):
@@ -43,7 +46,13 @@ class BaseConstraint:
     def deconstruct(self):
         path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
         path = path.replace("django.db.models.constraints", "django.db.models")
-        return (path, (), {"name": self.name})
+        kwargs = {"name": self.name}
+        if (
+            self.violation_error_message is not None
+            and self.violation_error_message != self.default_violation_error_message
+        ):
+            kwargs["violation_error_message"] = self.violation_error_message
+        return (path, (), kwargs)
 
     def clone(self):
         _, args, kwargs = self.deconstruct()
@@ -94,7 +103,11 @@ class CheckConstraint(BaseConstraint):
 
     def __eq__(self, other):
         if isinstance(other, CheckConstraint):
-            return self.name == other.name and self.check == other.check
+            return (
+                self.name == other.name
+                and self.check == other.check
+                and self.violation_error_message == other.violation_error_message
+            )
         return super().__eq__(other)
 
     def deconstruct(self):
@@ -273,6 +286,7 @@ class UniqueConstraint(BaseConstraint):
                 and self.include == other.include
                 and self.opclasses == other.opclasses
                 and self.expressions == other.expressions
+                and self.violation_error_message == other.violation_error_message
             )
         return super().__eq__(other)
 

+ 77 - 0
tests/constraints/tests.py

@@ -65,6 +65,29 @@ class BaseConstraintTests(SimpleTestCase):
         )
         self.assertEqual(c.get_violation_error_message(), "custom base_name message")
 
+    def test_custom_violation_error_message_clone(self):
+        constraint = BaseConstraint(
+            "base_name",
+            violation_error_message="custom %(name)s message",
+        ).clone()
+        self.assertEqual(
+            constraint.get_violation_error_message(),
+            "custom base_name message",
+        )
+
+    def test_deconstruction(self):
+        constraint = BaseConstraint(
+            "base_name",
+            violation_error_message="custom %(name)s message",
+        )
+        path, args, kwargs = constraint.deconstruct()
+        self.assertEqual(path, "django.db.models.BaseConstraint")
+        self.assertEqual(args, ())
+        self.assertEqual(
+            kwargs,
+            {"name": "base_name", "violation_error_message": "custom %(name)s message"},
+        )
+
 
 class CheckConstraintTests(TestCase):
     def test_eq(self):
@@ -84,6 +107,28 @@ class CheckConstraintTests(TestCase):
             models.CheckConstraint(check=check2, name="price"),
         )
         self.assertNotEqual(models.CheckConstraint(check=check1, name="price"), 1)
+        self.assertNotEqual(
+            models.CheckConstraint(check=check1, name="price"),
+            models.CheckConstraint(
+                check=check1, name="price", violation_error_message="custom error"
+            ),
+        )
+        self.assertNotEqual(
+            models.CheckConstraint(
+                check=check1, name="price", violation_error_message="custom error"
+            ),
+            models.CheckConstraint(
+                check=check1, name="price", violation_error_message="other custom error"
+            ),
+        )
+        self.assertEqual(
+            models.CheckConstraint(
+                check=check1, name="price", violation_error_message="custom error"
+            ),
+            models.CheckConstraint(
+                check=check1, name="price", violation_error_message="custom error"
+            ),
+        )
 
     def test_repr(self):
         constraint = models.CheckConstraint(
@@ -216,6 +261,38 @@ class UniqueConstraintTests(TestCase):
         self.assertNotEqual(
             models.UniqueConstraint(fields=["foo", "bar"], name="unique"), 1
         )
+        self.assertNotEqual(
+            models.UniqueConstraint(fields=["foo", "bar"], name="unique"),
+            models.UniqueConstraint(
+                fields=["foo", "bar"],
+                name="unique",
+                violation_error_message="custom error",
+            ),
+        )
+        self.assertNotEqual(
+            models.UniqueConstraint(
+                fields=["foo", "bar"],
+                name="unique",
+                violation_error_message="custom error",
+            ),
+            models.UniqueConstraint(
+                fields=["foo", "bar"],
+                name="unique",
+                violation_error_message="other custom error",
+            ),
+        )
+        self.assertEqual(
+            models.UniqueConstraint(
+                fields=["foo", "bar"],
+                name="unique",
+                violation_error_message="custom error",
+            ),
+            models.UniqueConstraint(
+                fields=["foo", "bar"],
+                name="unique",
+                violation_error_message="custom error",
+            ),
+        )
 
     def test_eq_with_condition(self):
         self.assertEqual(

+ 22 - 0
tests/postgres_tests/test_constraints.py

@@ -444,17 +444,39 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
             )
             self.assertNotEqual(constraint_2, constraint_9)
             self.assertNotEqual(constraint_7, constraint_8)
+
+        constraint_10 = ExclusionConstraint(
+            name="exclude_overlapping",
+            expressions=[
+                (F("datespan"), RangeOperators.OVERLAPS),
+                (F("room"), RangeOperators.EQUAL),
+            ],
+            condition=Q(cancelled=False),
+            violation_error_message="custom error",
+        )
+        constraint_11 = ExclusionConstraint(
+            name="exclude_overlapping",
+            expressions=[
+                (F("datespan"), RangeOperators.OVERLAPS),
+                (F("room"), RangeOperators.EQUAL),
+            ],
+            condition=Q(cancelled=False),
+            violation_error_message="other custom error",
+        )
         self.assertEqual(constraint_1, constraint_1)
         self.assertEqual(constraint_1, mock.ANY)
         self.assertNotEqual(constraint_1, constraint_2)
         self.assertNotEqual(constraint_1, constraint_3)
         self.assertNotEqual(constraint_1, constraint_4)
+        self.assertNotEqual(constraint_1, constraint_10)
         self.assertNotEqual(constraint_2, constraint_3)
         self.assertNotEqual(constraint_2, constraint_4)
         self.assertNotEqual(constraint_2, constraint_7)
         self.assertNotEqual(constraint_4, constraint_5)
         self.assertNotEqual(constraint_5, constraint_6)
         self.assertNotEqual(constraint_1, object())
+        self.assertNotEqual(constraint_10, constraint_11)
+        self.assertEqual(constraint_10, constraint_10)
 
     def test_deconstruct(self):
         constraint = ExclusionConstraint(