2
0
Эх сурвалжийг харах

Refs #34944 -- Propagated system checks for GeneratedField.output_field.

Mariusz Felisiak 1 жил өмнө
parent
commit
c705625ebf

+ 2 - 0
django/db/models/fields/__init__.py

@@ -286,6 +286,8 @@ class Field(RegisterLookupMixin):
         Check if field name is valid, i.e. 1) does not end with an
         Check if field name is valid, i.e. 1) does not end with an
         underscore, 2) does not contain "__" and 3) is not "pk".
         underscore, 2) does not contain "__" and 3) is not "pk".
         """
         """
+        if self.name is None:
+            return []
         if self.name.endswith("_"):
         if self.name.endswith("_"):
             return [
             return [
                 checks.Error(
                 checks.Error(

+ 35 - 1
django/db/models/fields/generated.py

@@ -63,11 +63,45 @@ class GeneratedField(Field):
 
 
     def check(self, **kwargs):
     def check(self, **kwargs):
         databases = kwargs.get("databases") or []
         databases = kwargs.get("databases") or []
-        return [
+        errors = [
             *super().check(**kwargs),
             *super().check(**kwargs),
             *self._check_supported(databases),
             *self._check_supported(databases),
             *self._check_persistence(databases),
             *self._check_persistence(databases),
         ]
         ]
+        output_field_clone = self.output_field.clone()
+        output_field_clone.model = self.model
+        output_field_checks = output_field_clone.check(databases=databases)
+        if output_field_checks:
+            separator = "\n    "
+            error_messages = separator.join(
+                f"{output_check.msg} ({output_check.id})"
+                for output_check in output_field_checks
+                if isinstance(output_check, checks.Error)
+            )
+            if error_messages:
+                errors.append(
+                    checks.Error(
+                        "GeneratedField.output_field has errors:"
+                        f"{separator}{error_messages}",
+                        obj=self,
+                        id="fields.E223",
+                    )
+                )
+            warning_messages = separator.join(
+                f"{output_check.msg} ({output_check.id})"
+                for output_check in output_field_checks
+                if isinstance(output_check, checks.Warning)
+            )
+            if warning_messages:
+                errors.append(
+                    checks.Warning(
+                        "GeneratedField.output_field has warnings:"
+                        f"{separator}{warning_messages}",
+                        obj=self,
+                        id="fields.W224",
+                    )
+                )
+        return errors
 
 
     def _check_supported(self, databases):
     def _check_supported(self, databases):
         errors = []
         errors = []

+ 2 - 0
docs/ref/checks.txt

@@ -213,6 +213,8 @@ Model fields
   ``GeneratedField``\s.
   ``GeneratedField``\s.
 * **fields.E222**: ``<database>`` does not support persisted
 * **fields.E222**: ``<database>`` does not support persisted
   ``GeneratedField``\s.
   ``GeneratedField``\s.
+* **fields.E223**: ``GeneratedField.output_field`` has errors: ...
+* **fields.W224**: ``GeneratedField.output_field`` has warnings: ...
 * **fields.E900**: ``IPAddressField`` has been removed except for support in
 * **fields.E900**: ``IPAddressField`` has been removed except for support in
   historical migrations.
   historical migrations.
 * **fields.W900**: ``IPAddressField`` has been deprecated. Support for it
 * **fields.W900**: ``IPAddressField`` has been deprecated. Support for it

+ 77 - 1
tests/invalid_models_tests/test_ordinary_fields.py

@@ -4,7 +4,7 @@ import uuid
 from django.core.checks import Error
 from django.core.checks import Error
 from django.core.checks import Warning as DjangoWarning
 from django.core.checks import Warning as DjangoWarning
 from django.db import connection, models
 from django.db import connection, models
-from django.db.models.functions import Coalesce, Pi
+from django.db.models.functions import Coalesce, LPad, Pi
 from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
 from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
 from django.test.utils import isolate_apps, override_settings
 from django.test.utils import isolate_apps, override_settings
 from django.utils.functional import lazy
 from django.utils.functional import lazy
@@ -1336,3 +1336,79 @@ class GeneratedFieldTests(TestCase):
             Model._meta.get_field("field").check(databases={"default"}),
             Model._meta.get_field("field").check(databases={"default"}),
             expected_errors,
             expected_errors,
         )
         )
+
+    @skipUnlessDBFeature("supports_stored_generated_columns")
+    def test_output_field_check_error(self):
+        class Model(models.Model):
+            value = models.DecimalField(max_digits=5, decimal_places=2)
+            field = models.GeneratedField(
+                expression=models.F("value") * 2,
+                output_field=models.DecimalField(max_digits=-1, decimal_places=-1),
+                db_persist=True,
+            )
+
+        expected_errors = [
+            Error(
+                "GeneratedField.output_field has errors:"
+                "\n    'decimal_places' must be a non-negative integer. (fields.E131)"
+                "\n    'max_digits' must be a positive integer. (fields.E133)",
+                obj=Model._meta.get_field("field"),
+                id="fields.E223",
+            ),
+        ]
+        self.assertEqual(
+            Model._meta.get_field("field").check(databases={"default"}),
+            expected_errors,
+        )
+
+    @skipUnlessDBFeature("supports_stored_generated_columns")
+    def test_output_field_charfield_unlimited_error(self):
+        class Model(models.Model):
+            name = models.CharField(max_length=255)
+            field = models.GeneratedField(
+                expression=LPad("name", 7, models.Value("xy")),
+                output_field=models.CharField(),
+                db_persist=True,
+            )
+
+        expected_errors = (
+            []
+            if connection.features.supports_unlimited_charfield
+            else [
+                Error(
+                    "GeneratedField.output_field has errors:"
+                    "\n    CharFields must define a 'max_length' attribute. "
+                    "(fields.E120)",
+                    obj=Model._meta.get_field("field"),
+                    id="fields.E223",
+                ),
+            ]
+        )
+        self.assertEqual(
+            Model._meta.get_field("field").check(databases={"default"}),
+            expected_errors,
+        )
+
+    @skipUnlessDBFeature("supports_stored_generated_columns")
+    def test_output_field_check_warning(self):
+        class Model(models.Model):
+            value = models.IntegerField()
+            field = models.GeneratedField(
+                expression=models.F("value") * 2,
+                output_field=models.IntegerField(max_length=40),
+                db_persist=True,
+            )
+
+        expected_warnings = [
+            DjangoWarning(
+                "GeneratedField.output_field has warnings:"
+                "\n    'max_length' is ignored when used with IntegerField. "
+                "(fields.W122)",
+                obj=Model._meta.get_field("field"),
+                id="fields.W224",
+            ),
+        ]
+        self.assertEqual(
+            Model._meta.get_field("field").check(databases={"default"}),
+            expected_warnings,
+        )