瀏覽代碼

Fixed #34362 -- Fixed FilteredRelation() crash on conditional expressions.

Thanks zhu for the report and Simon Charette for reviews.
Francesco Panico 2 年之前
父節點
當前提交
59f4754704
共有 2 個文件被更改,包括 123 次插入7 次删除
  1. 38 7
      django/db/models/sql/query.py
  2. 85 0
      tests/filtered_relation/tests.py

+ 38 - 7
django/db/models/sql/query.py

@@ -65,20 +65,52 @@ def get_field_names_from_opts(opts):
     )
 
 
+def get_paths_from_expression(expr):
+    if isinstance(expr, F):
+        yield expr.name
+    elif hasattr(expr, "flatten"):
+        for child in expr.flatten():
+            if isinstance(child, F):
+                yield child.name
+            elif isinstance(child, Q):
+                yield from get_children_from_q(child)
+
+
 def get_children_from_q(q):
     for child in q.children:
         if isinstance(child, Node):
             yield from get_children_from_q(child)
-        else:
-            yield child
+        elif isinstance(child, tuple):
+            lhs, rhs = child
+            yield lhs
+            if hasattr(rhs, "resolve_expression"):
+                yield from get_paths_from_expression(rhs)
+        elif hasattr(child, "resolve_expression"):
+            yield from get_paths_from_expression(child)
 
 
 def get_child_with_renamed_prefix(prefix, replacement, child):
     if isinstance(child, Node):
         return rename_prefix_from_q(prefix, replacement, child)
-    lhs, rhs = child
-    lhs = lhs.replace(prefix, replacement, 1)
-    return lhs, rhs
+    if isinstance(child, tuple):
+        lhs, rhs = child
+        lhs = lhs.replace(prefix, replacement, 1)
+        if not isinstance(rhs, F) and hasattr(rhs, "resolve_expression"):
+            rhs = get_child_with_renamed_prefix(prefix, replacement, rhs)
+        return lhs, rhs
+
+    if isinstance(child, F):
+        child = child.copy()
+        child.name = child.name.replace(prefix, replacement, 1)
+    elif hasattr(child, "resolve_expression"):
+        child = child.copy()
+        child.set_source_expressions(
+            [
+                get_child_with_renamed_prefix(prefix, replacement, grand_child)
+                for grand_child in child.get_source_expressions()
+            ]
+        )
+    return child
 
 
 def rename_prefix_from_q(prefix, replacement, q):
@@ -1618,7 +1650,6 @@ class Query(BaseExpression):
 
     def add_filtered_relation(self, filtered_relation, alias):
         filtered_relation.alias = alias
-        lookups = dict(get_children_from_q(filtered_relation.condition))
         relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type(
             filtered_relation.relation_name
         )
@@ -1627,7 +1658,7 @@ class Query(BaseExpression):
                 "FilteredRelation's relation_name cannot contain lookups "
                 "(got %r)." % filtered_relation.relation_name
             )
-        for lookup in chain(lookups):
+        for lookup in get_children_from_q(filtered_relation.condition):
             lookup_parts, lookup_field_parts, _ = self.solve_lookup_type(lookup)
             shift = 2 if not lookup_parts else 1
             lookup_field_path = lookup_field_parts[:-shift]

+ 85 - 0
tests/filtered_relation/tests.py

@@ -4,9 +4,11 @@ from unittest import mock
 
 from django.db import connection, transaction
 from django.db.models import (
+    BooleanField,
     Case,
     Count,
     DecimalField,
+    ExpressionWrapper,
     F,
     FilteredRelation,
     Q,
@@ -15,6 +17,7 @@ from django.db.models import (
     When,
 )
 from django.db.models.functions import Concat
+from django.db.models.lookups import Exact, IStartsWith
 from django.test import TestCase
 from django.test.testcases import skipUnlessDBFeature
 
@@ -707,6 +710,88 @@ class FilteredRelationTests(TestCase):
             FilteredRelation("book", condition=Q(book__title="b")), mock.ANY
         )
 
+    def test_conditional_expression(self):
+        qs = Author.objects.annotate(
+            the_book=FilteredRelation("book", condition=Q(Value(False))),
+        ).filter(the_book__isnull=False)
+        self.assertSequenceEqual(qs, [])
+
+    def test_expression_outside_relation_name(self):
+        qs = Author.objects.annotate(
+            book_editor=FilteredRelation(
+                "book__editor",
+                condition=Q(
+                    Exact(F("book__author__name"), "Alice"),
+                    Value(True),
+                    book__title__startswith="Poem",
+                ),
+            ),
+        ).filter(book_editor__isnull=False)
+        self.assertSequenceEqual(qs, [self.author1])
+
+    def test_conditional_expression_with_case(self):
+        qs = Book.objects.annotate(
+            alice_author=FilteredRelation(
+                "author",
+                condition=Q(
+                    Case(When(author__name="Alice", then=True), default=False),
+                ),
+            ),
+        ).filter(alice_author__isnull=False)
+        self.assertCountEqual(qs, [self.book1, self.book4])
+
+    def test_conditional_expression_outside_relation_name(self):
+        tests = [
+            Q(Case(When(book__author__name="Alice", then=True), default=False)),
+            Q(
+                ExpressionWrapper(
+                    Q(Value(True), Exact(F("book__author__name"), "Alice")),
+                    output_field=BooleanField(),
+                ),
+            ),
+        ]
+        for condition in tests:
+            with self.subTest(condition=condition):
+                qs = Author.objects.annotate(
+                    book_editor=FilteredRelation("book__editor", condition=condition),
+                ).filter(book_editor__isnull=True)
+                self.assertSequenceEqual(qs, [self.author2, self.author2])
+
+    def test_conditional_expression_with_lookup(self):
+        lookups = [
+            Q(book__title__istartswith="poem"),
+            Q(IStartsWith(F("book__title"), "poem")),
+        ]
+        for condition in lookups:
+            with self.subTest(condition=condition):
+                qs = Author.objects.annotate(
+                    poem_book=FilteredRelation("book", condition=condition)
+                ).filter(poem_book__isnull=False)
+                self.assertSequenceEqual(qs, [self.author1])
+
+    def test_conditional_expression_with_expressionwrapper(self):
+        qs = Author.objects.annotate(
+            poem_book=FilteredRelation(
+                "book",
+                condition=Q(
+                    ExpressionWrapper(
+                        Q(Exact(F("book__title"), "Poem by Alice")),
+                        output_field=BooleanField(),
+                    ),
+                ),
+            ),
+        ).filter(poem_book__isnull=False)
+        self.assertSequenceEqual(qs, [self.author1])
+
+    def test_conditional_expression_with_multiple_fields(self):
+        qs = Author.objects.annotate(
+            my_books=FilteredRelation(
+                "book__author",
+                condition=Q(Exact(F("book__author__name"), F("book__author__name"))),
+            ),
+        ).filter(my_books__isnull=True)
+        self.assertSequenceEqual(qs, [])
+
 
 class FilteredRelationAggregationTests(TestCase):
     @classmethod