Browse Source

[5.0.x] Refs #34944 -- Propagated system checks for GeneratedField.output_field.

Backport of c705625ebff0141ed2b95dd3c8174bda8270a47f from main
Mariusz Felisiak 1 year ago
parent
commit
fcc55f8c26

+ 2 - 0
django/db/models/fields/__init__.py

@@ -286,6 +286,8 @@ class Field(RegisterLookupMixin):
         Check if field name is valid, i.e. 1) does not end with an
         underscore, 2) does not contain "__" and 3) is not "pk".
         """
+        if self.name is None:
+            return []
         if self.name.endswith("_"):
             return [
                 checks.Error(

+ 35 - 1
django/db/models/fields/generated.py

@@ -63,11 +63,45 @@ class GeneratedField(Field):
 
     def check(self, **kwargs):
         databases = kwargs.get("databases") or []
-        return [
+        errors = [
             *super().check(**kwargs),
             *self._check_supported(databases),
             *self._check_persistence(databases),
         ]
+        output_field_clone = self.output_field.clone()
+        output_field_clone.model = self.model
+        output_field_checks = output_field_clone.check(databases=databases)
+        if output_field_checks:
+            separator = "\n    "
+            error_messages = separator.join(
+                f"{output_check.msg} ({output_check.id})"
+                for output_check in output_field_checks
+                if isinstance(output_check, checks.Error)
+            )
+            if error_messages:
+                errors.append(
+                    checks.Error(
+                        "GeneratedField.output_field has errors:"
+                        f"{separator}{error_messages}",
+                        obj=self,
+                        id="fields.E223",
+                    )
+                )
+            warning_messages = separator.join(
+                f"{output_check.msg} ({output_check.id})"
+                for output_check in output_field_checks
+                if isinstance(output_check, checks.Warning)
+            )
+            if warning_messages:
+                errors.append(
+                    checks.Warning(
+                        "GeneratedField.output_field has warnings:"
+                        f"{separator}{warning_messages}",
+                        obj=self,
+                        id="fields.W224",
+                    )
+                )
+        return errors
 
     def _check_supported(self, databases):
         errors = []

+ 2 - 0
docs/ref/checks.txt

@@ -213,6 +213,8 @@ Model fields
   ``GeneratedField``\s.
 * **fields.E222**: ``<database>`` does not support persisted
   ``GeneratedField``\s.
+* **fields.E223**: ``GeneratedField.output_field`` has errors: ...
+* **fields.W224**: ``GeneratedField.output_field`` has warnings: ...
 * **fields.E900**: ``IPAddressField`` has been removed except for support in
   historical migrations.
 * **fields.W900**: ``IPAddressField`` has been deprecated. Support for it

+ 77 - 1
tests/invalid_models_tests/test_ordinary_fields.py

@@ -4,7 +4,7 @@ import uuid
 from django.core.checks import Error
 from django.core.checks import Warning as DjangoWarning
 from django.db import connection, models
-from django.db.models.functions import Coalesce, Pi
+from django.db.models.functions import Coalesce, LPad, Pi
 from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
 from django.test.utils import isolate_apps, override_settings
 from django.utils.functional import lazy
@@ -1336,3 +1336,79 @@ class GeneratedFieldTests(TestCase):
             Model._meta.get_field("field").check(databases={"default"}),
             expected_errors,
         )
+
+    @skipUnlessDBFeature("supports_stored_generated_columns")
+    def test_output_field_check_error(self):
+        class Model(models.Model):
+            value = models.DecimalField(max_digits=5, decimal_places=2)
+            field = models.GeneratedField(
+                expression=models.F("value") * 2,
+                output_field=models.DecimalField(max_digits=-1, decimal_places=-1),
+                db_persist=True,
+            )
+
+        expected_errors = [
+            Error(
+                "GeneratedField.output_field has errors:"
+                "\n    'decimal_places' must be a non-negative integer. (fields.E131)"
+                "\n    'max_digits' must be a positive integer. (fields.E133)",
+                obj=Model._meta.get_field("field"),
+                id="fields.E223",
+            ),
+        ]
+        self.assertEqual(
+            Model._meta.get_field("field").check(databases={"default"}),
+            expected_errors,
+        )
+
+    @skipUnlessDBFeature("supports_stored_generated_columns")
+    def test_output_field_charfield_unlimited_error(self):
+        class Model(models.Model):
+            name = models.CharField(max_length=255)
+            field = models.GeneratedField(
+                expression=LPad("name", 7, models.Value("xy")),
+                output_field=models.CharField(),
+                db_persist=True,
+            )
+
+        expected_errors = (
+            []
+            if connection.features.supports_unlimited_charfield
+            else [
+                Error(
+                    "GeneratedField.output_field has errors:"
+                    "\n    CharFields must define a 'max_length' attribute. "
+                    "(fields.E120)",
+                    obj=Model._meta.get_field("field"),
+                    id="fields.E223",
+                ),
+            ]
+        )
+        self.assertEqual(
+            Model._meta.get_field("field").check(databases={"default"}),
+            expected_errors,
+        )
+
+    @skipUnlessDBFeature("supports_stored_generated_columns")
+    def test_output_field_check_warning(self):
+        class Model(models.Model):
+            value = models.IntegerField()
+            field = models.GeneratedField(
+                expression=models.F("value") * 2,
+                output_field=models.IntegerField(max_length=40),
+                db_persist=True,
+            )
+
+        expected_warnings = [
+            DjangoWarning(
+                "GeneratedField.output_field has warnings:"
+                "\n    'max_length' is ignored when used with IntegerField. "
+                "(fields.W122)",
+                obj=Model._meta.get_field("field"),
+                id="fields.W224",
+            ),
+        ]
+        self.assertEqual(
+            Model._meta.get_field("field").check(databases={"default"}),
+            expected_warnings,
+        )