瀏覽代碼

Fixed #31507 -- Added QuerySet.exists() optimizations to compound queries.

David-Wobrock 4 年之前
父節點
當前提交
ba42569d5c
共有 3 個文件被更改,包括 35 次插入8 次删除
  1. 11 4
      django/db/models/expressions.py
  2. 10 3
      django/db/models/sql/query.py
  3. 14 1
      tests/queries/test_qs_combinators.py

+ 11 - 4
django/db/models/expressions.py

@@ -1116,10 +1116,11 @@ class Subquery(Expression):
     def external_aliases(self):
         return self.query.external_aliases
 
-    def as_sql(self, compiler, connection, template=None, **extra_context):
+    def as_sql(self, compiler, connection, template=None, query=None, **extra_context):
         connection.ops.check_expression_support(self)
         template_params = {**self.extra, **extra_context}
-        subquery_sql, sql_params = self.query.as_sql(compiler, connection)
+        query = query or self.query
+        subquery_sql, sql_params = query.as_sql(compiler, connection)
         template_params['subquery'] = subquery_sql[1:-1]
 
         template = template or template_params.get('template', self.template)
@@ -1142,7 +1143,6 @@ class Exists(Subquery):
     def __init__(self, queryset, negated=False, **kwargs):
         self.negated = negated
         super().__init__(queryset, **kwargs)
-        self.query = self.query.exists()
 
     def __invert__(self):
         clone = self.copy()
@@ -1150,7 +1150,14 @@ class Exists(Subquery):
         return clone
 
     def as_sql(self, compiler, connection, template=None, **extra_context):
-        sql, params = super().as_sql(compiler, connection, template, **extra_context)
+        query = self.query.exists(using=connection.alias)
+        sql, params = super().as_sql(
+            compiler,
+            connection,
+            template=template,
+            query=query,
+            **extra_context,
+        )
         if self.negated:
             sql = 'NOT {}'.format(sql)
         return sql, params

+ 10 - 3
django/db/models/sql/query.py

@@ -525,7 +525,7 @@ class Query(BaseExpression):
     def has_filters(self):
         return self.where
 
-    def exists(self):
+    def exists(self, using, limit=True):
         q = self.clone()
         if not q.distinct:
             if q.group_by is True:
@@ -534,14 +534,21 @@ class Query(BaseExpression):
                 # SELECT clause which is about to be cleared.
                 q.set_group_by(allow_aliases=False)
             q.clear_select_clause()
+        if q.combined_queries and q.combinator == 'union':
+            limit_combined = connections[using].features.supports_slicing_ordering_in_compound
+            q.combined_queries = tuple(
+                combined_query.exists(using, limit=limit_combined)
+                for combined_query in q.combined_queries
+            )
         q.clear_ordering(True)
-        q.set_limits(high=1)
+        if limit:
+            q.set_limits(high=1)
         q.add_extra({'a': 1}, None, None, None, None, None)
         q.set_extra_mask(['a'])
         return q
 
     def has_results(self, using):
-        q = self.exists()
+        q = self.exists(using)
         compiler = q.get_compiler(using=using)
         return compiler.has_results()
 

+ 14 - 1
tests/queries/test_qs_combinators.py

@@ -3,6 +3,7 @@ import operator
 from django.db import DatabaseError, NotSupportedError, connection
 from django.db.models import Exists, F, IntegerField, OuterRef, Value
 from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
+from django.test.utils import CaptureQueriesContext
 
 from .models import Number, ReservedName
 
@@ -257,7 +258,19 @@ class QuerySetSetOperationTests(TestCase):
     def test_exists_union(self):
         qs1 = Number.objects.filter(num__gte=5)
         qs2 = Number.objects.filter(num__lte=5)
-        self.assertIs(qs1.union(qs2).exists(), True)
+        with CaptureQueriesContext(connection) as context:
+            self.assertIs(qs1.union(qs2).exists(), True)
+        captured_queries = context.captured_queries
+        self.assertEqual(len(captured_queries), 1)
+        captured_sql = captured_queries[0]['sql']
+        self.assertNotIn(
+            connection.ops.quote_name(Number._meta.pk.column),
+            captured_sql,
+        )
+        self.assertEqual(
+            captured_sql.count(connection.ops.limit_offset_sql(None, 1)),
+            3 if connection.features.supports_slicing_ordering_in_compound else 1
+        )
 
     def test_exists_union_empty_result(self):
         qs = Number.objects.filter(pk__in=[])