浏览代码

Fixed #34838 -- Corrected output_field of resolved columns for GeneratedFields.

Thanks Simon Charette for the implementation idea.
Paolo Melchiorre 1 年之前
父节点
当前提交
68d769e691
共有 2 个文件被更改,包括 47 次插入1 次删除
  1. 12 0
      django/db/models/fields/generated.py
  2. 35 1
      tests/model_fields/test_generatedfield.py

+ 12 - 0
django/db/models/fields/generated.py

@@ -1,6 +1,7 @@
 from django.core import checks
 from django.db import connections, router
 from django.db.models.sql import Query
+from django.utils.functional import cached_property
 
 from . import NOT_PROVIDED, Field
 
@@ -32,6 +33,17 @@ class GeneratedField(Field):
         self.db_persist = db_persist
         super().__init__(**kwargs)
 
+    @cached_property
+    def cached_col(self):
+        from django.db.models.expressions import Col
+
+        return Col(self.model._meta.db_table, self, self.output_field)
+
+    def get_col(self, alias, output_field=None):
+        if alias != self.model._meta.db_table and output_field is None:
+            output_field = self.output_field
+        return super().get_col(alias, output_field)
+
     def contribute_to_class(self, *args, **kwargs):
         super().contribute_to_class(*args, **kwargs)
 

+ 35 - 1
tests/model_fields/test_generatedfield.py

@@ -1,6 +1,6 @@
 from django.core.exceptions import FieldError
 from django.db import IntegrityError, connection
-from django.db.models import F, GeneratedField, IntegerField
+from django.db.models import F, FloatField, GeneratedField, IntegerField, Model
 from django.db.models.functions import Lower
 from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
 
@@ -49,6 +49,40 @@ class BaseGeneratedFieldTests(SimpleTestCase):
         self.assertEqual(args, [])
         self.assertEqual(kwargs, {"db_persist": True, "expression": F("a") + F("b")})
 
+    def test_get_col(self):
+        class Square(Model):
+            side = IntegerField()
+            area = GeneratedField(expression=F("side") * F("side"), db_persist=True)
+
+        col = Square._meta.get_field("area").get_col("alias")
+        self.assertIsInstance(col.output_field, IntegerField)
+
+        class FloatSquare(Model):
+            side = IntegerField()
+            area = GeneratedField(
+                expression=F("side") * F("side"),
+                db_persist=True,
+                output_field=FloatField(),
+            )
+
+        col = FloatSquare._meta.get_field("area").get_col("alias")
+        self.assertIsInstance(col.output_field, FloatField)
+
+    def test_cached_col(self):
+        class Sum(Model):
+            a = IntegerField()
+            b = IntegerField()
+            total = GeneratedField(expression=F("a") + F("b"), db_persist=True)
+
+        field = Sum._meta.get_field("total")
+        cached_col = field.cached_col
+        self.assertIs(field.get_col(Sum._meta.db_table), cached_col)
+        self.assertIs(field.get_col(Sum._meta.db_table, field), cached_col)
+        self.assertIsNot(field.get_col("alias"), cached_col)
+        self.assertIsNot(field.get_col(Sum._meta.db_table, IntegerField()), cached_col)
+        self.assertIs(cached_col.target, field)
+        self.assertIsInstance(cached_col.output_field, IntegerField)
+
 
 class GeneratedFieldTestMixin:
     def _refresh_if_needed(self, m):