2
0
Эх сурвалжийг харах

Fixed #31530 -- Added system checks for invalid model field names in CheckConstraint.check and UniqueConstraint.condition.

Hasan Ramezani 4 жил өмнө
parent
commit
b7b7df5fbc

+ 63 - 3
django/db/models/base.py

@@ -28,7 +28,7 @@ from django.db.models.fields.related import (
 from django.db.models.functions import Coalesce
 from django.db.models.manager import Manager
 from django.db.models.options import Options
-from django.db.models.query import Q
+from django.db.models.query import F, Q
 from django.db.models.signals import (
     class_prepared, post_init, post_save, pre_init, pre_save,
 )
@@ -1878,6 +1878,22 @@ class Model(metaclass=ModelBase):
 
         return errors
 
+    @classmethod
+    def _get_expr_references(cls, expr):
+        if isinstance(expr, Q):
+            for child in expr.children:
+                if isinstance(child, tuple):
+                    lookup, value = child
+                    yield tuple(lookup.split(LOOKUP_SEP))
+                    yield from cls._get_expr_references(value)
+                else:
+                    yield from cls._get_expr_references(child)
+        elif isinstance(expr, F):
+            yield tuple(expr.name.split(LOOKUP_SEP))
+        elif hasattr(expr, 'get_source_expressions'):
+            for src_expr in expr.get_source_expressions():
+                yield from cls._get_expr_references(src_expr)
+
     @classmethod
     def _check_constraints(cls, databases):
         errors = []
@@ -1960,10 +1976,54 @@ class Model(metaclass=ModelBase):
                         id='models.W039',
                     )
                 )
-            fields = chain.from_iterable(
+            fields = set(chain.from_iterable(
                 (*constraint.fields, *constraint.include)
                 for constraint in cls._meta.constraints if isinstance(constraint, UniqueConstraint)
-            )
+            ))
+            references = set()
+            for constraint in cls._meta.constraints:
+                if isinstance(constraint, UniqueConstraint):
+                    if (
+                        connection.features.supports_partial_indexes or
+                        'supports_partial_indexes' not in cls._meta.required_db_features
+                    ) and isinstance(constraint.condition, Q):
+                        references.update(cls._get_expr_references(constraint.condition))
+                elif isinstance(constraint, CheckConstraint):
+                    if (
+                        connection.features.supports_table_check_constraints or
+                        'supports_table_check_constraints' not in cls._meta.required_db_features
+                    ) and isinstance(constraint.check, Q):
+                        references.update(cls._get_expr_references(constraint.check))
+            for field_name, *lookups in references:
+                # pk is an alias that won't be found by opts.get_field.
+                if field_name != 'pk':
+                    fields.add(field_name)
+                if not lookups:
+                    # If it has no lookups it cannot result in a JOIN.
+                    continue
+                try:
+                    if field_name == 'pk':
+                        field = cls._meta.pk
+                    else:
+                        field = cls._meta.get_field(field_name)
+                    if not field.is_relation or field.many_to_many or field.one_to_many:
+                        continue
+                except FieldDoesNotExist:
+                    continue
+                # JOIN must happen at the first lookup.
+                first_lookup = lookups[0]
+                if (
+                    field.get_transform(first_lookup) is None and
+                    field.get_lookup(first_lookup) is None
+                ):
+                    errors.append(
+                        checks.Error(
+                            "'constraints' refers to the joined field '%s'."
+                            % LOOKUP_SEP.join([field_name] + lookups),
+                            obj=cls,
+                            id='models.E041',
+                        )
+                    )
             errors.extend(cls._check_local_fields(fields, 'constraints'))
         return errors
 

+ 1 - 0
docs/ref/checks.txt

@@ -364,6 +364,7 @@ Models
   non-key columns.
 * **models.W040**: ``<database>`` does not support indexes with non-key
   columns.
+* **models.E041**: ``constraints`` refers to the joined field ``<field name>``.
 
 Security
 --------

+ 232 - 0
tests/invalid_models_tests/test_models.py

@@ -1534,6 +1534,192 @@ class ConstraintsTests(TestCase):
                 constraints = [models.CheckConstraint(check=models.Q(age__gte=18), name='is_adult')]
         self.assertEqual(Model.check(databases=self.databases), [])
 
+    def test_check_constraint_pointing_to_missing_field(self):
+        class Model(models.Model):
+            class Meta:
+                required_db_features = {'supports_table_check_constraints'}
+                constraints = [
+                    models.CheckConstraint(
+                        name='name', check=models.Q(missing_field=2),
+                    ),
+                ]
+
+        self.assertEqual(Model.check(databases=self.databases), [
+            Error(
+                "'constraints' refers to the nonexistent field "
+                "'missing_field'.",
+                obj=Model,
+                id='models.E012',
+            ),
+        ] if connection.features.supports_table_check_constraints else [])
+
+    @skipUnlessDBFeature('supports_table_check_constraints')
+    def test_check_constraint_pointing_to_reverse_fk(self):
+        class Model(models.Model):
+            parent = models.ForeignKey('self', models.CASCADE, related_name='parents')
+
+            class Meta:
+                constraints = [
+                    models.CheckConstraint(name='name', check=models.Q(parents=3)),
+                ]
+
+        self.assertEqual(Model.check(databases=self.databases), [
+            Error(
+                "'constraints' refers to the nonexistent field 'parents'.",
+                obj=Model,
+                id='models.E012',
+            ),
+        ])
+
+    @skipUnlessDBFeature('supports_table_check_constraints')
+    def test_check_constraint_pointing_to_m2m_field(self):
+        class Model(models.Model):
+            m2m = models.ManyToManyField('self')
+
+            class Meta:
+                constraints = [
+                    models.CheckConstraint(name='name', check=models.Q(m2m=2)),
+                ]
+
+        self.assertEqual(Model.check(databases=self.databases), [
+            Error(
+                "'constraints' refers to a ManyToManyField 'm2m', but "
+                "ManyToManyFields are not permitted in 'constraints'.",
+                obj=Model,
+                id='models.E013',
+            ),
+        ])
+
+    @skipUnlessDBFeature('supports_table_check_constraints')
+    def test_check_constraint_pointing_to_fk(self):
+        class Target(models.Model):
+            pass
+
+        class Model(models.Model):
+            fk_1 = models.ForeignKey(Target, models.CASCADE, related_name='target_1')
+            fk_2 = models.ForeignKey(Target, models.CASCADE, related_name='target_2')
+
+            class Meta:
+                constraints = [
+                    models.CheckConstraint(
+                        name='name',
+                        check=models.Q(fk_1_id=2) | models.Q(fk_2=2),
+                    ),
+                ]
+
+        self.assertEqual(Model.check(databases=self.databases), [])
+
+    @skipUnlessDBFeature('supports_table_check_constraints')
+    def test_check_constraint_pointing_to_pk(self):
+        class Model(models.Model):
+            age = models.SmallIntegerField()
+
+            class Meta:
+                constraints = [
+                    models.CheckConstraint(
+                        name='name',
+                        check=models.Q(pk__gt=5) & models.Q(age__gt=models.F('pk')),
+                    ),
+                ]
+
+        self.assertEqual(Model.check(databases=self.databases), [])
+
+    @skipUnlessDBFeature('supports_table_check_constraints')
+    def test_check_constraint_pointing_to_non_local_field(self):
+        class Parent(models.Model):
+            field1 = models.IntegerField()
+
+        class Child(Parent):
+            pass
+
+            class Meta:
+                constraints = [
+                    models.CheckConstraint(name='name', check=models.Q(field1=1)),
+                ]
+
+        self.assertEqual(Child.check(databases=self.databases), [
+            Error(
+                "'constraints' refers to field 'field1' which is not local to "
+                "model 'Child'.",
+                hint='This issue may be caused by multi-table inheritance.',
+                obj=Child,
+                id='models.E016',
+            ),
+        ])
+
+    @skipUnlessDBFeature('supports_table_check_constraints')
+    def test_check_constraint_pointing_to_joined_fields(self):
+        class Model(models.Model):
+            name = models.CharField(max_length=10)
+            field1 = models.PositiveSmallIntegerField()
+            field2 = models.PositiveSmallIntegerField()
+            field3 = models.PositiveSmallIntegerField()
+            parent = models.ForeignKey('self', models.CASCADE)
+
+            class Meta:
+                constraints = [
+                    models.CheckConstraint(
+                        name='name1', check=models.Q(
+                            field1__lt=models.F('parent__field1') + models.F('parent__field2')
+                        )
+                    ),
+                    models.CheckConstraint(
+                        name='name2', check=models.Q(name=Lower('parent__name'))
+                    ),
+                    models.CheckConstraint(
+                        name='name3', check=models.Q(parent__field3=models.F('field1'))
+                    ),
+                ]
+
+        joined_fields = ['parent__field1', 'parent__field2', 'parent__field3', 'parent__name']
+        errors = Model.check(databases=self.databases)
+        expected_errors = [
+            Error(
+                "'constraints' refers to the joined field '%s'." % field_name,
+                obj=Model,
+                id='models.E041',
+            ) for field_name in joined_fields
+        ]
+        self.assertCountEqual(errors, expected_errors)
+
+    @skipUnlessDBFeature('supports_table_check_constraints')
+    def test_check_constraint_pointing_to_joined_fields_complex_check(self):
+        class Model(models.Model):
+            name = models.PositiveSmallIntegerField()
+            field1 = models.PositiveSmallIntegerField()
+            field2 = models.PositiveSmallIntegerField()
+            parent = models.ForeignKey('self', models.CASCADE)
+
+            class Meta:
+                constraints = [
+                    models.CheckConstraint(
+                        name='name',
+                        check=models.Q(
+                            (
+                                models.Q(name='test') &
+                                models.Q(field1__lt=models.F('parent__field1'))
+                            ) |
+                            (
+                                models.Q(name__startswith=Lower('parent__name')) &
+                                models.Q(field1__gte=(
+                                    models.F('parent__field1') + models.F('parent__field2')
+                                ))
+                            )
+                        ) | (models.Q(name='test1'))
+                    ),
+                ]
+
+        joined_fields = ['parent__field1', 'parent__field2', 'parent__name']
+        errors = Model.check(databases=self.databases)
+        expected_errors = [
+            Error(
+                "'constraints' refers to the joined field '%s'." % field_name,
+                obj=Model,
+                id='models.E041',
+            ) for field_name in joined_fields
+        ]
+        self.assertCountEqual(errors, expected_errors)
+
     def test_unique_constraint_with_condition(self):
         class Model(models.Model):
             age = models.IntegerField()
@@ -1578,6 +1764,52 @@ class ConstraintsTests(TestCase):
 
         self.assertEqual(Model.check(databases=self.databases), [])
 
+    def test_unique_constraint_condition_pointing_to_missing_field(self):
+        class Model(models.Model):
+            age = models.SmallIntegerField()
+
+            class Meta:
+                required_db_features = {'supports_partial_indexes'}
+                constraints = [
+                    models.UniqueConstraint(
+                        name='name',
+                        fields=['age'],
+                        condition=models.Q(missing_field=2),
+                    ),
+                ]
+
+        self.assertEqual(Model.check(databases=self.databases), [
+            Error(
+                "'constraints' refers to the nonexistent field "
+                "'missing_field'.",
+                obj=Model,
+                id='models.E012',
+            ),
+        ] if connection.features.supports_partial_indexes else [])
+
+    def test_unique_constraint_condition_pointing_to_joined_fields(self):
+        class Model(models.Model):
+            age = models.SmallIntegerField()
+            parent = models.ForeignKey('self', models.CASCADE)
+
+            class Meta:
+                required_db_features = {'supports_partial_indexes'}
+                constraints = [
+                    models.UniqueConstraint(
+                        name='name',
+                        fields=['age'],
+                        condition=models.Q(parent__age__lt=2),
+                    ),
+                ]
+
+        self.assertEqual(Model.check(databases=self.databases), [
+            Error(
+                "'constraints' refers to the joined field 'parent__age__lt'.",
+                obj=Model,
+                id='models.E041',
+            )
+        ] if connection.features.supports_partial_indexes else [])
+
     def test_deferrable_unique_constraint(self):
         class Model(models.Model):
             age = models.IntegerField()