소스 검색

Fixed #36173 -- Stabilized identity of Concat with an explicit output_field.

When Expression.__init__() overrides make use of *args, **kwargs
captures their argument values are respectively bound as a tuple and
dict instances. These composite values might themselves contain values
that require special identity treatments such as Concat(output_field)
as it's a Field instance.

Refs #30628 which introduced bound Field differentiation but lacked
argument captures handling.

Thanks erchenstein for the report.
Simon Charette 1 개월 전
부모
커밋
df2c4952df
3개의 변경된 파일53개의 추가작업 그리고 7개의 파일을 삭제
  1. 16 7
      django/db/models/expressions.py
  2. 14 0
      tests/db_functions/text/test_concat.py
  3. 23 0
      tests/expressions/tests.py

+ 16 - 7
django/db/models/expressions.py

@@ -523,6 +523,18 @@ class Expression(BaseExpression, Combinable):
     def _constructor_signature(cls):
         return inspect.signature(cls.__init__)
 
+    @classmethod
+    def _identity(cls, value):
+        if isinstance(value, tuple):
+            return tuple(map(cls._identity, value))
+        if isinstance(value, dict):
+            return tuple((key, cls._identity(val)) for key, val in value.items())
+        if isinstance(value, fields.Field):
+            if value.name and value.model:
+                return value.model._meta.label, value.name
+            return type(value)
+        return make_hashable(value)
+
     @cached_property
     def identity(self):
         args, kwargs = self._constructor_args
@@ -532,13 +544,10 @@ class Expression(BaseExpression, Combinable):
         next(arguments)
         identity = [self.__class__]
         for arg, value in arguments:
-            if isinstance(value, fields.Field):
-                if value.name and value.model:
-                    value = (value.model._meta.label, value.name)
-                else:
-                    value = type(value)
-            else:
-                value = make_hashable(value)
+            # If __init__() makes use of *args or **kwargs captures `value`
+            # will respectively be a tuple or a dict that must have its
+            # constituents unpacked (mainly if contain Field instances).
+            value = self._identity(value)
             identity.append((arg, value))
         return tuple(identity)
 

+ 14 - 0
tests/db_functions/text/test_concat.py

@@ -107,3 +107,17 @@ class ConcatTests(TestCase):
             ctx.captured_queries[0]["sql"].count("::text"),
             1 if connection.vendor == "postgresql" else 0,
         )
+
+    def test_equal(self):
+        self.assertEqual(
+            Concat("foo", "bar", output_field=TextField()),
+            Concat("foo", "bar", output_field=TextField()),
+        )
+        self.assertNotEqual(
+            Concat("foo", "bar", output_field=TextField()),
+            Concat("foo", "bar", output_field=CharField()),
+        )
+        self.assertNotEqual(
+            Concat("foo", "bar", output_field=TextField()),
+            Concat("bar", "foo", output_field=TextField()),
+        )

+ 23 - 0
tests/expressions/tests.py

@@ -1433,6 +1433,29 @@ class SimpleExpressionTests(SimpleTestCase):
             Expression(TestModel._meta.get_field("other_field")),
         )
 
+        class InitCaptureExpression(Expression):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+
+        # The identity of expressions that obscure their __init__() signature
+        # with *args and **kwargs cannot be determined when bound with
+        # different combinations or *args and **kwargs.
+        self.assertNotEqual(
+            InitCaptureExpression(IntegerField()),
+            InitCaptureExpression(output_field=IntegerField()),
+        )
+
+        # However, they should be considered equal when their bindings are
+        # equal.
+        self.assertEqual(
+            InitCaptureExpression(IntegerField()),
+            InitCaptureExpression(IntegerField()),
+        )
+        self.assertEqual(
+            InitCaptureExpression(output_field=IntegerField()),
+            InitCaptureExpression(output_field=IntegerField()),
+        )
+
     def test_hash(self):
         self.assertEqual(hash(Expression()), hash(Expression()))
         self.assertEqual(