浏览代码

Refs #30581 -- Added Q.flatten().

Gagaro 2 年之前
父节点
当前提交
9d04711261
共有 2 个文件被更改,包括 42 次插入1 次删除
  1. 15 0
      django/db/models/query_utils.py
  2. 27 1
      tests/queries/test_q.py

+ 15 - 0
django/db/models/query_utils.py

@@ -95,6 +95,21 @@ class Q(tree.Node):
         query.promote_joins(joins)
         return clause
 
+    def flatten(self):
+        """
+        Recursively yield this Q object and all subexpressions, in depth-first
+        order.
+        """
+        yield self
+        for child in self.children:
+            if isinstance(child, tuple):
+                # Use the lookup.
+                child = child[1]
+            if hasattr(child, "flatten"):
+                yield from child.flatten()
+            else:
+                yield child
+
     def deconstruct(self):
         path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
         if path.startswith("django.db.models.query_utils"):

+ 27 - 1
tests/queries/test_q.py

@@ -1,5 +1,15 @@
-from django.db.models import BooleanField, Exists, F, OuterRef, Q
+from django.db.models import (
+    BooleanField,
+    Exists,
+    ExpressionWrapper,
+    F,
+    OuterRef,
+    Q,
+    Value,
+)
 from django.db.models.expressions import RawSQL
+from django.db.models.functions import Lower
+from django.db.models.sql.where import NothingNode
 from django.test import SimpleTestCase
 
 from .models import Tag
@@ -188,3 +198,19 @@ class QTests(SimpleTestCase):
         q = q1 & q2
         path, args, kwargs = q.deconstruct()
         self.assertEqual(Q(*args, **kwargs), q)
+
+    def test_flatten(self):
+        q = Q()
+        self.assertEqual(list(q.flatten()), [q])
+        q = Q(NothingNode())
+        self.assertEqual(list(q.flatten()), [q, q.children[0]])
+        q = Q(
+            ExpressionWrapper(
+                Q(RawSQL("id = 0", params=(), output_field=BooleanField()))
+                | Q(price=Value("4.55"))
+                | Q(name=Lower("category")),
+                output_field=BooleanField(),
+            )
+        )
+        flatten = list(q.flatten())
+        self.assertEqual(len(flatten), 7)