Browse Source

Fixed #31910 -- Fixed crash of GIS aggregations over subqueries.

Regression was introduced by fff5186 but was due a long standing issue.

AggregateQuery was abusing Query.subquery: bool by stashing its
compiled inner query's SQL for later use in its compiler which made
select_format checks for Query.subquery wrongly assume the provide
query was a subquery.

This patch prevents that from happening by using a dedicated
inner_query attribute which is compiled at a later time by
SQLAggregateCompiler.

Moving the inner query's compilation to SQLAggregateCompiler.compile
had the side effect of addressing a long standing issue with
aggregation subquery pushdown which prevented converters from being
run. This is now fixed as the aggregation_regress adjustments
demonstrate.

Refs #25367.

Thanks Eran Keydar for the report.
Simon Charette 4 years ago
parent
commit
c2d4926702

+ 5 - 2
django/db/models/sql/compiler.py

@@ -1596,8 +1596,11 @@ class SQLAggregateCompiler(SQLCompiler):
         sql = ', '.join(sql)
         params = tuple(params)
 
-        sql = 'SELECT %s FROM (%s) subquery' % (sql, self.query.subquery)
-        params = params + self.query.sub_params
+        inner_query_sql, inner_query_params = self.query.inner_query.get_compiler(
+            self.using
+        ).as_sql(with_col_aliases=True)
+        sql = 'SELECT %s FROM (%s) subquery' % (sql, inner_query_sql)
+        params = params + inner_query_params
         return sql, params
 
 

+ 3 - 11
django/db/models/sql/query.py

@@ -17,9 +17,7 @@ from collections.abc import Iterator, Mapping
 from itertools import chain, count, product
 from string import ascii_uppercase
 
-from django.core.exceptions import (
-    EmptyResultSet, FieldDoesNotExist, FieldError,
-)
+from django.core.exceptions import FieldDoesNotExist, FieldError
 from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
 from django.db.models.aggregates import Count
 from django.db.models.constants import LOOKUP_SEP
@@ -449,8 +447,9 @@ class Query(BaseExpression):
         if (isinstance(self.group_by, tuple) or self.is_sliced or existing_annotations or
                 self.distinct or self.combinator):
             from django.db.models.sql.subqueries import AggregateQuery
-            outer_query = AggregateQuery(self.model)
             inner_query = self.clone()
+            inner_query.subquery = True
+            outer_query = AggregateQuery(self.model, inner_query)
             inner_query.select_for_update = False
             inner_query.select_related = False
             inner_query.set_annotation_mask(self.annotation_select)
@@ -492,13 +491,6 @@ class Query(BaseExpression):
                 # field selected in the inner query, yet we must use a subquery.
                 # So, make sure at least one field is selected.
                 inner_query.select = (self.model._meta.pk.get_col(inner_query.get_initial_alias()),)
-            try:
-                outer_query.add_subquery(inner_query, using)
-            except EmptyResultSet:
-                return {
-                    alias: None
-                    for alias in outer_query.annotation_select
-                }
         else:
             outer_query = self
             self.select = ()

+ 3 - 3
django/db/models/sql/subqueries.py

@@ -157,6 +157,6 @@ class AggregateQuery(Query):
 
     compiler = 'SQLAggregateCompiler'
 
-    def add_subquery(self, query, using):
-        query.subquery = True
-        self.subquery, self.sub_params = query.get_compiler(using).as_sql(with_col_aliases=True)
+    def __init__(self, model, inner_query):
+        self.inner_query = inner_query
+        super().__init__(model)

+ 1 - 1
tests/aggregation_regress/tests.py

@@ -974,7 +974,7 @@ class AggregationTests(TestCase):
     def test_empty_filter_aggregate(self):
         self.assertEqual(
             Author.objects.filter(id__in=[]).annotate(Count("friends")).aggregate(Count("pk")),
-            {"pk__count": None}
+            {"pk__count": 0}
         )
 
     def test_none_call_before_aggregate(self):

+ 14 - 0
tests/gis_tests/geoapp/tests.py

@@ -12,6 +12,7 @@ from django.core.management import call_command
 from django.db import DatabaseError, NotSupportedError, connection
 from django.db.models import F, OuterRef, Subquery
 from django.test import TestCase, skipUnlessDBFeature
+from django.test.utils import CaptureQueriesContext
 
 from ..utils import (
     mariadb, mysql, oracle, postgis, skipUnlessGISLookup, spatialite,
@@ -593,6 +594,19 @@ class GeoQuerySetTest(TestCase):
         qs = City.objects.filter(name='NotACity')
         self.assertIsNone(qs.aggregate(Union('point'))['point__union'])
 
+    @skipUnlessDBFeature('supports_union_aggr')
+    def test_geoagg_subquery(self):
+        ks = State.objects.get(name='Kansas')
+        union = GEOSGeometry('MULTIPOINT(-95.235060 38.971823)')
+        # Use distinct() to force the usage of a subquery for aggregation.
+        with CaptureQueriesContext(connection) as ctx:
+            self.assertIs(union.equals(
+                City.objects.filter(point__within=ks.poly).distinct().aggregate(
+                    Union('point'),
+                )['point__union'],
+            ), True)
+        self.assertIn('subquery', ctx.captured_queries[0]['sql'])
+
     @unittest.skipUnless(
         connection.vendor == 'oracle',
         'Oracle supports tolerance parameter.',