浏览代码

Refs #28477 -- Reduced complexity of aggregation over qualify queries.

Simon Charette 2 年之前
父节点
当前提交
a9d2d8d1c3
共有 2 个文件被更改,包括 34 次插入22 次删除
  1. 20 14
      django/db/models/sql/query.py
  2. 14 8
      tests/expressions_window/tests.py

+ 20 - 14
django/db/models/sql/query.py

@@ -447,13 +447,15 @@ class Query(BaseExpression):
             if alias not in added_aggregate_names
         }
         # Existing usage of aggregation can be determined by the presence of
-        # selected aggregate and window annotations but also by filters against
-        # aliased aggregate and windows via HAVING / QUALIFY.
-        has_existing_aggregation = any(
-            getattr(annotation, "contains_aggregate", True)
-            or getattr(annotation, "contains_over_clause", True)
-            for annotation in existing_annotations.values()
-        ) or any(self.where.split_having_qualify()[1:])
+        # selected aggregates but also by filters against aliased aggregates.
+        _, having, qualify = self.where.split_having_qualify()
+        has_existing_aggregation = (
+            any(
+                getattr(annotation, "contains_aggregate", True)
+                for annotation in existing_annotations.values()
+            )
+            or having
+        )
         # Decide if we need to use a subquery.
         #
         # Existing aggregations would cause incorrect results as
@@ -468,6 +470,7 @@ class Query(BaseExpression):
             isinstance(self.group_by, tuple)
             or self.is_sliced
             or has_existing_aggregation
+            or qualify
             or self.distinct
             or self.combinator
         ):
@@ -494,13 +497,16 @@ class Query(BaseExpression):
                         self.model._meta.pk.get_col(inner_query.get_initial_alias()),
                     )
                 inner_query.default_cols = False
-                # Mask existing annotations that are not referenced by
-                # aggregates to be pushed to the outer query.
-                annotation_mask = set()
-                for name in added_aggregate_names:
-                    annotation_mask.add(name)
-                    annotation_mask |= inner_query.annotations[name].get_refs()
-                inner_query.set_annotation_mask(annotation_mask)
+                if not qualify:
+                    # Mask existing annotations that are not referenced by
+                    # aggregates to be pushed to the outer query unless
+                    # filtering against window functions is involved as it
+                    # requires complex realising.
+                    annotation_mask = set()
+                    for name in added_aggregate_names:
+                        annotation_mask.add(name)
+                        annotation_mask |= inner_query.annotations[name].get_refs()
+                    inner_query.set_annotation_mask(annotation_mask)
 
             relabels = {t: "subquery" for t in inner_query.alias_map}
             relabels[None] = "subquery"

+ 14 - 8
tests/expressions_window/tests.py

@@ -42,6 +42,7 @@ from django.db.models.functions import (
 )
 from django.db.models.lookups import Exact
 from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
+from django.test.utils import CaptureQueriesContext
 
 from .models import Classification, Detail, Employee, PastEmployeeDepartment
 
@@ -1157,16 +1158,21 @@ class WindowFunctionTests(TestCase):
         )
 
     def test_filter_count(self):
-        self.assertEqual(
-            Employee.objects.annotate(
-                department_salary_rank=Window(
-                    Rank(), partition_by="department", order_by="-salary"
+        with CaptureQueriesContext(connection) as ctx:
+            self.assertEqual(
+                Employee.objects.annotate(
+                    department_salary_rank=Window(
+                        Rank(), partition_by="department", order_by="-salary"
+                    )
                 )
+                .filter(department_salary_rank=1)
+                .count(),
+                5,
             )
-            .filter(department_salary_rank=1)
-            .count(),
-            5,
-        )
+        self.assertEqual(len(ctx.captured_queries), 1)
+        sql = ctx.captured_queries[0]["sql"].lower()
+        self.assertEqual(sql.count("select"), 3)
+        self.assertNotIn("group by", sql)
 
     @skipUnlessDBFeature("supports_frame_range_fixed_distance")
     def test_range_n_preceding_and_following(self):