ソースを参照

Fixed #34944 -- Made GeneratedField.output_field required.

Regression in f333e3513e8bdf5ffeb6eeb63021c230082e6f95.
Mariusz Felisiak 1 年間 前
コミット
5875f03ce6

+ 3 - 9
django/db/models/fields/generated.py

@@ -16,7 +16,7 @@ class GeneratedField(Field):
     _resolved_expression = None
     output_field = None
 
-    def __init__(self, *, expression, db_persist=None, output_field=None, **kwargs):
+    def __init__(self, *, expression, output_field, db_persist=None, **kwargs):
         if kwargs.setdefault("editable", False):
             raise ValueError("GeneratedField cannot be editable.")
         if not kwargs.setdefault("blank", True):
@@ -29,7 +29,7 @@ class GeneratedField(Field):
             raise ValueError("GeneratedField.db_persist must be True or False.")
 
         self.expression = expression
-        self._output_field = output_field
+        self.output_field = output_field
         self.db_persist = db_persist
         super().__init__(**kwargs)
 
@@ -51,11 +51,6 @@ class GeneratedField(Field):
         self._resolved_expression = self.expression.resolve_expression(
             self._query, allow_joins=False
         )
-        self.output_field = (
-            self._output_field
-            if self._output_field is not None
-            else self._resolved_expression.output_field
-        )
         # Register lookups from the output_field class.
         for lookup_name, lookup in self.output_field.get_class_lookups().items():
             self.register_lookup(lookup, lookup_name=lookup_name)
@@ -150,8 +145,7 @@ class GeneratedField(Field):
         del kwargs["editable"]
         kwargs["db_persist"] = self.db_persist
         kwargs["expression"] = self.expression
-        if self._output_field is not None:
-            kwargs["output_field"] = self._output_field
+        kwargs["output_field"] = self.output_field
         return name, path, args, kwargs
 
     def get_internal_type(self):

+ 5 - 7
docs/ref/models/fields.txt

@@ -1237,7 +1237,7 @@ when :attr:`~django.forms.Field.localize` is ``False`` or
 
 .. versionadded:: 5.0
 
-.. class:: GeneratedField(expression, db_persist=None, output_field=None, **kwargs)
+.. class:: GeneratedField(expression, output_field, db_persist=None, **kwargs)
 
 A field that is always computed based on other fields in the model. This field
 is managed and updated by the database itself. Uses the ``GENERATED ALWAYS``
@@ -1259,6 +1259,10 @@ materialized view.
     the model (in the same database table). Generated fields cannot reference
     other generated fields. Database backends can impose further restrictions.
 
+.. attribute:: GeneratedField.output_field
+
+    A model field instance to define the field's data type.
+
 .. attribute:: GeneratedField.db_persist
 
     Determines if the database column should occupy storage as if it were a
@@ -1268,12 +1272,6 @@ materialized view.
     PostgreSQL only supports persisted columns. Oracle only supports virtual
     columns.
 
-.. attribute:: GeneratedField.output_field
-
-    An optional model field instance to define the field's data type. This can
-    be used to customize attributes like the field's collation. By default, the
-    output field is derived from ``expression``.
-
 .. admonition:: Refresh the data
 
     Since the database always computed the value, the object must be reloaded

+ 5 - 1
docs/releases/5.0.txt

@@ -142,7 +142,11 @@ to create a field that is always computed from other fields. For example::
 
     class Square(models.Model):
         side = models.IntegerField()
-        area = models.GeneratedField(expression=F("side") * F("side"), db_persist=True)
+        area = models.GeneratedField(
+            expression=F("side") * F("side"),
+            output_field=models.BigIntegerField(),
+            db_persist=True,
+        )
 
 More options for declaring field choices
 ----------------------------------------

+ 1 - 0
tests/admin_views/models.py

@@ -1147,6 +1147,7 @@ class Square(models.Model):
     area = models.GeneratedField(
         db_persist=True,
         expression=models.F("side") * models.F("side"),
+        output_field=models.BigIntegerField(),
     )
 
     class Meta:

+ 23 - 5
tests/invalid_models_tests/test_ordinary_fields.py

@@ -1216,7 +1216,9 @@ class GeneratedFieldTests(TestCase):
         class Model(models.Model):
             name = models.IntegerField()
             field = models.GeneratedField(
-                expression=models.F("name"), db_persist=db_persist
+                expression=models.F("name"),
+                output_field=models.IntegerField(),
+                db_persist=db_persist,
             )
 
         expected_errors = []
@@ -1252,7 +1254,11 @@ class GeneratedFieldTests(TestCase):
     def test_not_supported_stored_required_db_features(self):
         class Model(models.Model):
             name = models.IntegerField()
-            field = models.GeneratedField(expression=models.F("name"), db_persist=True)
+            field = models.GeneratedField(
+                expression=models.F("name"),
+                output_field=models.IntegerField(),
+                db_persist=True,
+            )
 
             class Meta:
                 required_db_features = {"supports_stored_generated_columns"}
@@ -1262,7 +1268,11 @@ class GeneratedFieldTests(TestCase):
     def test_not_supported_virtual_required_db_features(self):
         class Model(models.Model):
             name = models.IntegerField()
-            field = models.GeneratedField(expression=models.F("name"), db_persist=False)
+            field = models.GeneratedField(
+                expression=models.F("name"),
+                output_field=models.IntegerField(),
+                db_persist=False,
+            )
 
             class Meta:
                 required_db_features = {"supports_virtual_generated_columns"}
@@ -1273,7 +1283,11 @@ class GeneratedFieldTests(TestCase):
     def test_not_supported_virtual(self):
         class Model(models.Model):
             name = models.IntegerField()
-            field = models.GeneratedField(expression=models.F("name"), db_persist=False)
+            field = models.GeneratedField(
+                expression=models.F("name"),
+                output_field=models.IntegerField(),
+                db_persist=False,
+            )
             a = models.TextField()
 
         excepted_errors = (
@@ -1298,7 +1312,11 @@ class GeneratedFieldTests(TestCase):
     def test_not_supported_stored(self):
         class Model(models.Model):
             name = models.IntegerField()
-            field = models.GeneratedField(expression=models.F("name"), db_persist=True)
+            field = models.GeneratedField(
+                expression=models.F("name"),
+                output_field=models.IntegerField(),
+                db_persist=True,
+            )
             a = models.TextField()
 
         expected_errors = (

+ 22 - 6
tests/migrations/test_operations.py

@@ -5664,10 +5664,14 @@ class OperationTests(OperationTestBase):
     def _test_invalid_generated_field_changes(self, db_persist):
         regular = models.IntegerField(default=1)
         generated_1 = models.GeneratedField(
-            expression=F("pink") + F("pink"), db_persist=db_persist
+            expression=F("pink") + F("pink"),
+            output_field=models.IntegerField(),
+            db_persist=db_persist,
         )
         generated_2 = models.GeneratedField(
-            expression=F("pink") + F("pink") + F("pink"), db_persist=db_persist
+            expression=F("pink") + F("pink") + F("pink"),
+            output_field=models.IntegerField(),
+            db_persist=db_persist,
         )
         tests = [
             ("test_igfc_1", regular, generated_1),
@@ -5707,12 +5711,20 @@ class OperationTests(OperationTestBase):
             migrations.AddField(
                 "Pony",
                 "modified_pink",
-                models.GeneratedField(expression=F("pink"), db_persist=True),
+                models.GeneratedField(
+                    expression=F("pink"),
+                    output_field=models.IntegerField(),
+                    db_persist=True,
+                ),
             ),
             migrations.AlterField(
                 "Pony",
                 "modified_pink",
-                models.GeneratedField(expression=F("pink"), db_persist=False),
+                models.GeneratedField(
+                    expression=F("pink"),
+                    output_field=models.IntegerField(),
+                    db_persist=False,
+                ),
             ),
         ]
         msg = (
@@ -5729,7 +5741,9 @@ class OperationTests(OperationTestBase):
             "Pony",
             "modified_pink",
             models.GeneratedField(
-                expression=F("pink") + F("pink"), db_persist=db_persist
+                expression=F("pink") + F("pink"),
+                output_field=models.IntegerField(),
+                db_persist=db_persist,
             ),
         )
         project_state, new_state = self.make_test_state(app_label, operation)
@@ -5760,7 +5774,9 @@ class OperationTests(OperationTestBase):
             "Pony",
             "modified_pink",
             models.GeneratedField(
-                expression=F("pink") + F("pink"), db_persist=db_persist
+                expression=F("pink") + F("pink"),
+                output_field=models.IntegerField(),
+                db_persist=db_persist,
             ),
         )
         project_state, new_state = self.make_test_state(app_label, operation)

+ 22 - 6
tests/model_fields/models.py

@@ -485,7 +485,11 @@ class UUIDGrandchild(UUIDChild):
 class GeneratedModel(models.Model):
     a = models.IntegerField()
     b = models.IntegerField()
-    field = models.GeneratedField(expression=F("a") + F("b"), db_persist=True)
+    field = models.GeneratedField(
+        expression=F("a") + F("b"),
+        output_field=models.IntegerField(),
+        db_persist=True,
+    )
 
     class Meta:
         required_db_features = {"supports_stored_generated_columns"}
@@ -494,7 +498,11 @@ class GeneratedModel(models.Model):
 class GeneratedModelVirtual(models.Model):
     a = models.IntegerField()
     b = models.IntegerField()
-    field = models.GeneratedField(expression=F("a") + F("b"), db_persist=False)
+    field = models.GeneratedField(
+        expression=F("a") + F("b"),
+        output_field=models.IntegerField(),
+        db_persist=False,
+    )
 
     class Meta:
         required_db_features = {"supports_virtual_generated_columns"}
@@ -503,6 +511,7 @@ class GeneratedModelVirtual(models.Model):
 class GeneratedModelParams(models.Model):
     field = models.GeneratedField(
         expression=Value("Constant", output_field=models.CharField(max_length=10)),
+        output_field=models.CharField(max_length=10),
         db_persist=True,
     )
 
@@ -513,6 +522,7 @@ class GeneratedModelParams(models.Model):
 class GeneratedModelParamsVirtual(models.Model):
     field = models.GeneratedField(
         expression=Value("Constant", output_field=models.CharField(max_length=10)),
+        output_field=models.CharField(max_length=10),
         db_persist=False,
     )
 
@@ -520,7 +530,7 @@ class GeneratedModelParamsVirtual(models.Model):
         required_db_features = {"supports_virtual_generated_columns"}
 
 
-class GeneratedModelOutputField(models.Model):
+class GeneratedModelOutputFieldDbCollation(models.Model):
     name = models.CharField(max_length=10)
     lower_name = models.GeneratedField(
         expression=Lower("name"),
@@ -532,7 +542,7 @@ class GeneratedModelOutputField(models.Model):
         required_db_features = {"supports_stored_generated_columns"}
 
 
-class GeneratedModelOutputFieldVirtual(models.Model):
+class GeneratedModelOutputFieldDbCollationVirtual(models.Model):
     name = models.CharField(max_length=10)
     lower_name = models.GeneratedField(
         expression=Lower("name"),
@@ -547,7 +557,10 @@ class GeneratedModelOutputFieldVirtual(models.Model):
 class GeneratedModelNull(models.Model):
     name = models.CharField(max_length=10, null=True)
     lower_name = models.GeneratedField(
-        expression=Lower("name"), db_persist=True, null=True
+        expression=Lower("name"),
+        output_field=models.CharField(max_length=10),
+        db_persist=True,
+        null=True,
     )
 
     class Meta:
@@ -557,7 +570,10 @@ class GeneratedModelNull(models.Model):
 class GeneratedModelNullVirtual(models.Model):
     name = models.CharField(max_length=10, null=True)
     lower_name = models.GeneratedField(
-        expression=Lower("name"), db_persist=False, null=True
+        expression=Lower("name"),
+        output_field=models.CharField(max_length=10),
+        db_persist=False,
+        null=True,
     )
 
     class Meta:

+ 63 - 18
tests/model_fields/test_generatedfield.py

@@ -1,5 +1,12 @@
 from django.db import IntegrityError, connection
-from django.db.models import F, FloatField, GeneratedField, IntegerField, Model
+from django.db.models import (
+    CharField,
+    F,
+    FloatField,
+    GeneratedField,
+    IntegerField,
+    Model,
+)
 from django.db.models.functions import Lower
 from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
 from django.test.utils import isolate_apps
@@ -8,8 +15,8 @@ from .models import (
     GeneratedModel,
     GeneratedModelNull,
     GeneratedModelNullVirtual,
-    GeneratedModelOutputField,
-    GeneratedModelOutputFieldVirtual,
+    GeneratedModelOutputFieldDbCollation,
+    GeneratedModelOutputFieldDbCollationVirtual,
     GeneratedModelParams,
     GeneratedModelParamsVirtual,
     GeneratedModelVirtual,
@@ -19,41 +26,77 @@ from .models import (
 class BaseGeneratedFieldTests(SimpleTestCase):
     def test_editable_unsupported(self):
         with self.assertRaisesMessage(ValueError, "GeneratedField cannot be editable."):
-            GeneratedField(expression=Lower("name"), editable=True, db_persist=False)
+            GeneratedField(
+                expression=Lower("name"),
+                output_field=CharField(max_length=255),
+                editable=True,
+                db_persist=False,
+            )
 
     def test_blank_unsupported(self):
         with self.assertRaisesMessage(ValueError, "GeneratedField must be blank."):
-            GeneratedField(expression=Lower("name"), blank=False, db_persist=False)
+            GeneratedField(
+                expression=Lower("name"),
+                output_field=CharField(max_length=255),
+                blank=False,
+                db_persist=False,
+            )
 
     def test_default_unsupported(self):
         msg = "GeneratedField cannot have a default."
         with self.assertRaisesMessage(ValueError, msg):
-            GeneratedField(expression=Lower("name"), default="", db_persist=False)
+            GeneratedField(
+                expression=Lower("name"),
+                output_field=CharField(max_length=255),
+                default="",
+                db_persist=False,
+            )
 
     def test_database_default_unsupported(self):
         msg = "GeneratedField cannot have a database default."
         with self.assertRaisesMessage(ValueError, msg):
-            GeneratedField(expression=Lower("name"), db_default="", db_persist=False)
+            GeneratedField(
+                expression=Lower("name"),
+                output_field=CharField(max_length=255),
+                db_default="",
+                db_persist=False,
+            )
 
     def test_db_persist_required(self):
         msg = "GeneratedField.db_persist must be True or False."
         with self.assertRaisesMessage(ValueError, msg):
-            GeneratedField(expression=Lower("name"))
+            GeneratedField(
+                expression=Lower("name"), output_field=CharField(max_length=255)
+            )
         with self.assertRaisesMessage(ValueError, msg):
-            GeneratedField(expression=Lower("name"), db_persist=None)
+            GeneratedField(
+                expression=Lower("name"),
+                output_field=CharField(max_length=255),
+                db_persist=None,
+            )
 
     def test_deconstruct(self):
-        field = GeneratedField(expression=F("a") + F("b"), db_persist=True)
+        field = GeneratedField(
+            expression=F("a") + F("b"), output_field=IntegerField(), db_persist=True
+        )
         _, path, args, kwargs = field.deconstruct()
         self.assertEqual(path, "django.db.models.GeneratedField")
         self.assertEqual(args, [])
-        self.assertEqual(kwargs, {"db_persist": True, "expression": F("a") + F("b")})
+        self.assertEqual(kwargs["db_persist"], True)
+        self.assertEqual(kwargs["expression"], F("a") + F("b"))
+        self.assertEqual(
+            kwargs["output_field"].deconstruct(), IntegerField().deconstruct()
+        )
 
     @isolate_apps("model_fields")
     def test_get_col(self):
         class Square(Model):
             side = IntegerField()
-            area = GeneratedField(expression=F("side") * F("side"), db_persist=True)
+            area = GeneratedField(
+                expression=F("side") * F("side"),
+                output_field=IntegerField(),
+                db_persist=True,
+            )
 
         col = Square._meta.get_field("area").get_col("alias")
         self.assertIsInstance(col.output_field, IntegerField)
@@ -74,7 +117,9 @@ class BaseGeneratedFieldTests(SimpleTestCase):
         class Sum(Model):
             a = IntegerField()
             b = IntegerField()
-            total = GeneratedField(expression=F("a") + F("b"), db_persist=True)
+            total = GeneratedField(
+                expression=F("a") + F("b"), output_field=IntegerField(), db_persist=True
+            )
 
         field = Sum._meta.get_field("total")
         cached_col = field.cached_col
@@ -165,9 +210,9 @@ class GeneratedFieldTestMixin:
         with self.assertNumQueries(0), self.assertRaises(does_not_exist):
             self.base_model.objects.get(field__gte=overflow_value)
 
-    def test_output_field(self):
+    def test_output_field_db_collation(self):
         collation = connection.features.test_collations["virtual"]
-        m = self.output_field_model.objects.create(name="NAME")
+        m = self.output_field_db_collation_model.objects.create(name="NAME")
         field = m._meta.get_field("lower_name")
         db_parameters = field.db_parameters(connection)
         self.assertEqual(db_parameters["collation"], collation)
@@ -178,7 +223,7 @@ class GeneratedFieldTestMixin:
         )
 
     def test_db_type_parameters(self):
-        db_type_parameters = self.output_field_model._meta.get_field(
+        db_type_parameters = self.output_field_db_collation_model._meta.get_field(
             "lower_name"
         ).db_type_parameters(connection)
         self.assertEqual(db_type_parameters["max_length"], 11)
@@ -202,7 +247,7 @@ class GeneratedFieldTestMixin:
 class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
     base_model = GeneratedModel
     nullable_model = GeneratedModelNull
-    output_field_model = GeneratedModelOutputField
+    output_field_db_collation_model = GeneratedModelOutputFieldDbCollation
     params_model = GeneratedModelParams
 
 
@@ -210,5 +255,5 @@ class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
 class VirtualGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
     base_model = GeneratedModelVirtual
     nullable_model = GeneratedModelNullVirtual
-    output_field_model = GeneratedModelOutputFieldVirtual
+    output_field_db_collation_model = GeneratedModelOutputFieldDbCollationVirtual
     params_model = GeneratedModelParamsVirtual

+ 6 - 2
tests/schema/tests.py

@@ -829,7 +829,11 @@ class SchemaTests(TransactionTestCase):
     def test_add_generated_field_with_kt_model(self):
         class GeneratedFieldKTModel(Model):
             data = JSONField()
-            status = GeneratedField(expression=KT("data__status"), db_persist=True)
+            status = GeneratedField(
+                expression=KT("data__status"),
+                output_field=TextField(),
+                db_persist=True,
+            )
 
             class Meta:
                 app_label = "schema"
@@ -844,7 +848,7 @@ class SchemaTests(TransactionTestCase):
 
     @isolate_apps("schema")
     @skipUnlessDBFeature("supports_stored_generated_columns")
-    def test_add_generated_field_with_output_field(self):
+    def test_add_generated_field(self):
         class GeneratedFieldOutputFieldModel(Model):
             price = DecimalField(max_digits=7, decimal_places=2)
             vat_price = GeneratedField(