Sfoglia il codice sorgente

Fixed #35038 -- Created AlterConstraint operation.

Salvo Polizzi 4 mesi fa
parent
commit
b82f80906a

+ 60 - 2
django/db/migrations/autodetector.py

@@ -219,6 +219,7 @@ class MigrationAutodetector:
         self.generate_altered_unique_together()
         self.generate_added_indexes()
         self.generate_added_constraints()
+        self.generate_altered_constraints()
         self.generate_altered_db_table()
 
         self._sort_migrations()
@@ -1450,6 +1451,19 @@ class MigrationAutodetector:
                     ),
                 )
 
+    def _constraint_should_be_dropped_and_recreated(
+        self, old_constraint, new_constraint
+    ):
+        old_path, old_args, old_kwargs = old_constraint.deconstruct()
+        new_path, new_args, new_kwargs = new_constraint.deconstruct()
+
+        for attr in old_constraint.non_db_attrs:
+            old_kwargs.pop(attr, None)
+        for attr in new_constraint.non_db_attrs:
+            new_kwargs.pop(attr, None)
+
+        return (old_path, old_args, old_kwargs) != (new_path, new_args, new_kwargs)
+
     def create_altered_constraints(self):
         option_name = operations.AddConstraint.option_name
         for app_label, model_name in sorted(self.kept_model_keys):
@@ -1461,14 +1475,41 @@ class MigrationAutodetector:
 
             old_constraints = old_model_state.options[option_name]
             new_constraints = new_model_state.options[option_name]
-            add_constraints = [c for c in new_constraints if c not in old_constraints]
-            rem_constraints = [c for c in old_constraints if c not in new_constraints]
+
+            alt_constraints = []
+            alt_constraints_name = []
+
+            for old_c in old_constraints:
+                for new_c in new_constraints:
+                    old_c_dec = old_c.deconstruct()
+                    new_c_dec = new_c.deconstruct()
+                    if (
+                        old_c_dec != new_c_dec
+                        and old_c.name == new_c.name
+                        and not self._constraint_should_be_dropped_and_recreated(
+                            old_c, new_c
+                        )
+                    ):
+                        alt_constraints.append(new_c)
+                        alt_constraints_name.append(new_c.name)
+
+            add_constraints = [
+                c
+                for c in new_constraints
+                if c not in old_constraints and c.name not in alt_constraints_name
+            ]
+            rem_constraints = [
+                c
+                for c in old_constraints
+                if c not in new_constraints and c.name not in alt_constraints_name
+            ]
 
             self.altered_constraints.update(
                 {
                     (app_label, model_name): {
                         "added_constraints": add_constraints,
                         "removed_constraints": rem_constraints,
+                        "altered_constraints": alt_constraints,
                     }
                 }
             )
@@ -1503,6 +1544,23 @@ class MigrationAutodetector:
                     ),
                 )
 
+    def generate_altered_constraints(self):
+        for (
+            app_label,
+            model_name,
+        ), alt_constraints in self.altered_constraints.items():
+            dependencies = self._get_dependencies_for_model(app_label, model_name)
+            for constraint in alt_constraints["altered_constraints"]:
+                self.add_operation(
+                    app_label,
+                    operations.AlterConstraint(
+                        model_name=model_name,
+                        name=constraint.name,
+                        constraint=constraint,
+                    ),
+                    dependencies=dependencies,
+                )
+
     @staticmethod
     def _get_dependencies_for_foreign_key(app_label, model_name, field, project_state):
         remote_field_model = None

+ 2 - 0
django/db/migrations/operations/__init__.py

@@ -2,6 +2,7 @@ from .fields import AddField, AlterField, RemoveField, RenameField
 from .models import (
     AddConstraint,
     AddIndex,
+    AlterConstraint,
     AlterIndexTogether,
     AlterModelManagers,
     AlterModelOptions,
@@ -36,6 +37,7 @@ __all__ = [
     "RenameField",
     "AddConstraint",
     "RemoveConstraint",
+    "AlterConstraint",
     "SeparateDatabaseAndState",
     "RunSQL",
     "RunPython",

+ 54 - 0
django/db/migrations/operations/models.py

@@ -1230,6 +1230,12 @@ class AddConstraint(IndexOperation):
             and self.constraint.name == operation.name
         ):
             return []
+        if (
+            isinstance(operation, AlterConstraint)
+            and self.model_name_lower == operation.model_name_lower
+            and self.constraint.name == operation.name
+        ):
+            return [AddConstraint(self.model_name, operation.constraint)]
         return super().reduce(operation, app_label)
 
 
@@ -1274,3 +1280,51 @@ class RemoveConstraint(IndexOperation):
     @property
     def migration_name_fragment(self):
         return "remove_%s_%s" % (self.model_name_lower, self.name.lower())
+
+
+class AlterConstraint(IndexOperation):
+    category = OperationCategory.ALTERATION
+    option_name = "constraints"
+
+    def __init__(self, model_name, name, constraint):
+        self.model_name = model_name
+        self.name = name
+        self.constraint = constraint
+
+    def state_forwards(self, app_label, state):
+        state.alter_constraint(
+            app_label, self.model_name_lower, self.name, self.constraint
+        )
+
+    def database_forwards(self, app_label, schema_editor, from_state, to_state):
+        pass
+
+    def database_backwards(self, app_label, schema_editor, from_state, to_state):
+        pass
+
+    def deconstruct(self):
+        return (
+            self.__class__.__name__,
+            [],
+            {
+                "model_name": self.model_name,
+                "name": self.name,
+                "constraint": self.constraint,
+            },
+        )
+
+    def describe(self):
+        return f"Alter constraint {self.name} on {self.model_name}"
+
+    @property
+    def migration_name_fragment(self):
+        return "alter_%s_%s" % (self.model_name_lower, self.constraint.name.lower())
+
+    def reduce(self, operation, app_label):
+        if (
+            isinstance(operation, (AlterConstraint, RemoveConstraint))
+            and self.model_name_lower == operation.model_name_lower
+            and self.name == operation.name
+        ):
+            return [operation]
+        return super().reduce(operation, app_label)

+ 13 - 0
django/db/migrations/state.py

@@ -211,6 +211,14 @@ class ProjectState:
         model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name]
         self.reload_model(app_label, model_name, delay=True)
 
+    def _alter_option(self, app_label, model_name, option_name, obj_name, alt_obj):
+        model_state = self.models[app_label, model_name]
+        objs = model_state.options[option_name]
+        model_state.options[option_name] = [
+            obj if obj.name != obj_name else alt_obj for obj in objs
+        ]
+        self.reload_model(app_label, model_name, delay=True)
+
     def add_index(self, app_label, model_name, index):
         self._append_option(app_label, model_name, "indexes", index)
 
@@ -237,6 +245,11 @@ class ProjectState:
     def remove_constraint(self, app_label, model_name, constraint_name):
         self._remove_option(app_label, model_name, "constraints", constraint_name)
 
+    def alter_constraint(self, app_label, model_name, constraint_name, constraint):
+        self._alter_option(
+            app_label, model_name, "constraints", constraint_name, constraint
+        )
+
     def add_field(self, app_label, model_name, name, field, preserve_default):
         # If preserve default is off, don't use the default for future state.
         if not preserve_default:

+ 2 - 0
django/db/models/constraints.py

@@ -23,6 +23,8 @@ class BaseConstraint:
     violation_error_code = None
     violation_error_message = None
 
+    non_db_attrs = ("violation_error_code", "violation_error_message")
+
     # RemovedInDjango60Warning: When the deprecation ends, replace with:
     # def __init__(
     #     self, *, name, violation_error_code=None, violation_error_message=None

+ 10 - 0
docs/ref/migration-operations.txt

@@ -278,6 +278,16 @@ the model with ``model_name``.
 
 Removes the constraint named ``name`` from the model with ``model_name``.
 
+``AlterConstraint``
+-------------------
+
+.. versionadded:: 5.2
+
+.. class:: AlterConstraint(model_name, name, constraint)
+
+Alters the constraint named ``name`` of the model with ``model_name`` with the
+new ``constraint`` without affecting the database.
+
 Special Operations
 ==================
 

+ 2 - 1
docs/releases/5.2.txt

@@ -259,7 +259,8 @@ Management Commands
 Migrations
 ~~~~~~~~~~
 
-* ...
+* The new operation :class:`.AlterConstraint` is a no-op operation that alters
+  constraints without dropping and recreating constraints in the database.
 
 Models
 ~~~~~~

+ 65 - 0
tests/migrations/test_autodetector.py

@@ -2969,6 +2969,71 @@ class AutodetectorTests(BaseAutodetectorTests):
             ["CreateModel", "AddField", "AddIndex"],
         )
 
+    def test_alter_constraint(self):
+        book_constraint = models.CheckConstraint(
+            condition=models.Q(title__contains="title"),
+            name="title_contains_title",
+        )
+        book_altered_constraint = models.CheckConstraint(
+            condition=models.Q(title__contains="title"),
+            name="title_contains_title",
+            violation_error_code="error_code",
+        )
+        author_altered_constraint = models.CheckConstraint(
+            condition=models.Q(name__contains="Bob"),
+            name="name_contains_bob",
+            violation_error_message="Name doesn't contain Bob",
+        )
+
+        book_check_constraint = copy.deepcopy(self.book)
+        book_check_constraint_with_error_message = copy.deepcopy(self.book)
+        author_name_check_constraint_with_error_message = copy.deepcopy(
+            self.author_name_check_constraint
+        )
+
+        book_check_constraint.options = {"constraints": [book_constraint]}
+        book_check_constraint_with_error_message.options = {
+            "constraints": [book_altered_constraint]
+        }
+        author_name_check_constraint_with_error_message.options = {
+            "constraints": [author_altered_constraint]
+        }
+
+        changes = self.get_changes(
+            [self.author_name_check_constraint, book_check_constraint],
+            [
+                author_name_check_constraint_with_error_message,
+                book_check_constraint_with_error_message,
+            ],
+        )
+
+        self.assertNumberMigrations(changes, "testapp", 1)
+        self.assertOperationTypes(changes, "testapp", 0, ["AlterConstraint"])
+        self.assertOperationAttributes(
+            changes,
+            "testapp",
+            0,
+            0,
+            model_name="author",
+            name="name_contains_bob",
+            constraint=author_altered_constraint,
+        )
+
+        self.assertNumberMigrations(changes, "otherapp", 1)
+        self.assertOperationTypes(changes, "otherapp", 0, ["AlterConstraint"])
+        self.assertOperationAttributes(
+            changes,
+            "otherapp",
+            0,
+            0,
+            model_name="book",
+            name="title_contains_title",
+            constraint=book_altered_constraint,
+        )
+        self.assertMigrationDependencies(
+            changes, "otherapp", 0, [("testapp", "auto_1")]
+        )
+
     def test_remove_constraints(self):
         """Test change detection of removed constraints."""
         changes = self.get_changes(

+ 75 - 0
tests/migrations/test_operations.py

@@ -4366,6 +4366,81 @@ class OperationTests(OperationTestBase):
             {"model_name": "Pony", "name": "test_remove_constraint_pony_pink_gt_2"},
         )
 
+    def test_alter_constraint(self):
+        constraint = models.UniqueConstraint(
+            fields=["pink"], name="test_alter_constraint_pony_fields_uq"
+        )
+        project_state = self.set_up_test_model(
+            "test_alterconstraint", constraints=[constraint]
+        )
+
+        new_state = project_state.clone()
+        violation_error_message = "Pink isn't unique"
+        uq_constraint = models.UniqueConstraint(
+            fields=["pink"],
+            name="test_alter_constraint_pony_fields_uq",
+            violation_error_message=violation_error_message,
+        )
+        uq_operation = migrations.AlterConstraint(
+            "Pony", "test_alter_constraint_pony_fields_uq", uq_constraint
+        )
+        self.assertEqual(
+            uq_operation.describe(),
+            "Alter constraint test_alter_constraint_pony_fields_uq on Pony",
+        )
+        self.assertEqual(
+            uq_operation.formatted_description(),
+            "~ Alter constraint test_alter_constraint_pony_fields_uq on Pony",
+        )
+        self.assertEqual(
+            uq_operation.migration_name_fragment,
+            "alter_pony_test_alter_constraint_pony_fields_uq",
+        )
+
+        uq_operation.state_forwards("test_alterconstraint", new_state)
+        self.assertEqual(
+            project_state.models["test_alterconstraint", "pony"]
+            .options["constraints"][0]
+            .violation_error_message,
+            "Constraint “%(name)s” is violated.",
+        )
+        self.assertEqual(
+            new_state.models["test_alterconstraint", "pony"]
+            .options["constraints"][0]
+            .violation_error_message,
+            violation_error_message,
+        )
+
+        with connection.schema_editor() as editor, self.assertNumQueries(0):
+            uq_operation.database_forwards(
+                "test_alterconstraint", editor, project_state, new_state
+            )
+        self.assertConstraintExists(
+            "test_alterconstraint_pony",
+            "test_alter_constraint_pony_fields_uq",
+            value=False,
+        )
+        with connection.schema_editor() as editor, self.assertNumQueries(0):
+            uq_operation.database_backwards(
+                "test_alterconstraint", editor, project_state, new_state
+            )
+        self.assertConstraintExists(
+            "test_alterconstraint_pony",
+            "test_alter_constraint_pony_fields_uq",
+            value=False,
+        )
+        definition = uq_operation.deconstruct()
+        self.assertEqual(definition[0], "AlterConstraint")
+        self.assertEqual(definition[1], [])
+        self.assertEqual(
+            definition[2],
+            {
+                "model_name": "Pony",
+                "name": "test_alter_constraint_pony_fields_uq",
+                "constraint": uq_constraint,
+            },
+        )
+
     def test_add_partial_unique_constraint(self):
         project_state = self.set_up_test_model("test_addpartialuniqueconstraint")
         partial_unique_constraint = models.UniqueConstraint(

+ 74 - 0
tests/migrations/test_optimizer.py

@@ -1232,6 +1232,80 @@ class OptimizerTests(OptimizerTestBase):
             ],
         )
 
+    def test_multiple_alter_constraints(self):
+        gt_constraint_violation_msg_added = models.CheckConstraint(
+            condition=models.Q(pink__gt=2),
+            name="pink_gt_2",
+            violation_error_message="ERROR",
+        )
+        gt_constraint_violation_msg_altered = models.CheckConstraint(
+            condition=models.Q(pink__gt=2),
+            name="pink_gt_2",
+            violation_error_message="error",
+        )
+        self.assertOptimizesTo(
+            [
+                migrations.AlterConstraint(
+                    "Pony", "pink_gt_2", gt_constraint_violation_msg_added
+                ),
+                migrations.AlterConstraint(
+                    "Pony", "pink_gt_2", gt_constraint_violation_msg_altered
+                ),
+            ],
+            [
+                migrations.AlterConstraint(
+                    "Pony", "pink_gt_2", gt_constraint_violation_msg_altered
+                )
+            ],
+        )
+        other_constraint_violation_msg = models.CheckConstraint(
+            condition=models.Q(weight__gt=3),
+            name="pink_gt_3",
+            violation_error_message="error",
+        )
+        self.assertDoesNotOptimize(
+            [
+                migrations.AlterConstraint(
+                    "Pony", "pink_gt_2", gt_constraint_violation_msg_added
+                ),
+                migrations.AlterConstraint(
+                    "Pony", "pink_gt_3", other_constraint_violation_msg
+                ),
+            ]
+        )
+
+    def test_alter_remove_constraint(self):
+        self.assertOptimizesTo(
+            [
+                migrations.AlterConstraint(
+                    "Pony",
+                    "pink_gt_2",
+                    models.CheckConstraint(
+                        condition=models.Q(pink__gt=2), name="pink_gt_2"
+                    ),
+                ),
+                migrations.RemoveConstraint("Pony", "pink_gt_2"),
+            ],
+            [migrations.RemoveConstraint("Pony", "pink_gt_2")],
+        )
+
+    def test_add_alter_constraint(self):
+        constraint = models.CheckConstraint(
+            condition=models.Q(pink__gt=2), name="pink_gt_2"
+        )
+        constraint_with_error = models.CheckConstraint(
+            condition=models.Q(pink__gt=2),
+            name="pink_gt_2",
+            violation_error_message="error",
+        )
+        self.assertOptimizesTo(
+            [
+                migrations.AddConstraint("Pony", constraint),
+                migrations.AlterConstraint("Pony", "pink_gt_2", constraint_with_error),
+            ],
+            [migrations.AddConstraint("Pony", constraint_with_error)],
+        )
+
     def test_create_model_add_index(self):
         self.assertOptimizesTo(
             [