Bläddra i källkod

Fixed #26430 -- Fixed coalesced aggregation of empty result sets.

Disable the EmptyResultSet optimization when performing aggregation as
it might interfere with coalescence.
Simon Charette 3 år sedan
förälder
incheckning
f3112fde98
3 ändrade filer med 48 tillägg och 9 borttagningar
  1. 13 4
      django/db/models/sql/compiler.py
  2. 3 5
      django/db/models/sql/query.py
  3. 32 0
      tests/aggregation/tests.py

+ 13 - 4
django/db/models/sql/compiler.py

@@ -26,10 +26,13 @@ class SQLCompiler:
         re.MULTILINE | re.DOTALL,
     )
 
-    def __init__(self, query, connection, using):
+    def __init__(self, query, connection, using, elide_empty=True):
         self.query = query
         self.connection = connection
         self.using = using
+        # Some queries, e.g. coalesced aggregation, need to be executed even if
+        # they would return an empty result set.
+        self.elide_empty = elide_empty
         self.quote_cache = {'*': '*'}
         # The select, klass_info, and annotations are needed by QuerySet.iterator()
         # these are set as a side-effect of executing the query. Note that we calculate
@@ -458,7 +461,7 @@ class SQLCompiler:
     def get_combinator_sql(self, combinator, all):
         features = self.connection.features
         compilers = [
-            query.get_compiler(self.using, self.connection)
+            query.get_compiler(self.using, self.connection, self.elide_empty)
             for query in self.query.combined_queries if not query.is_empty()
         ]
         if not features.supports_slicing_ordering_in_compound:
@@ -535,7 +538,13 @@ class SQLCompiler:
                 # This must come after 'select', 'ordering', and 'distinct'
                 # (see docstring of get_from_clause() for details).
                 from_, f_params = self.get_from_clause()
-                where, w_params = self.compile(self.where) if self.where is not None else ("", [])
+                try:
+                    where, w_params = self.compile(self.where) if self.where is not None else ('', [])
+                except EmptyResultSet:
+                    if self.elide_empty:
+                        raise
+                    # Use a predicate that's always False.
+                    where, w_params = '0 = 1', []
                 having, h_params = self.compile(self.having) if self.having is not None else ("", [])
                 result = ['SELECT']
                 params = []
@@ -1652,7 +1661,7 @@ class SQLAggregateCompiler(SQLCompiler):
         params = tuple(params)
 
         inner_query_sql, inner_query_params = self.query.inner_query.get_compiler(
-            self.using
+            self.using, elide_empty=self.elide_empty,
         ).as_sql(with_col_aliases=True)
         sql = 'SELECT %s FROM (%s) subquery' % (sql, inner_query_sql)
         params = params + inner_query_params

+ 3 - 5
django/db/models/sql/query.py

@@ -273,12 +273,12 @@ class Query(BaseExpression):
         memo[id(self)] = result
         return result
 
-    def get_compiler(self, using=None, connection=None):
+    def get_compiler(self, using=None, connection=None, elide_empty=True):
         if using is None and connection is None:
             raise ValueError("Need either using or connection")
         if using:
             connection = connections[using]
-        return connection.ops.compiler(self.compiler)(self, connection, using)
+        return connection.ops.compiler(self.compiler)(self, connection, using, elide_empty)
 
     def get_meta(self):
         """
@@ -494,10 +494,8 @@ class Query(BaseExpression):
         outer_query.clear_limits()
         outer_query.select_for_update = False
         outer_query.select_related = False
-        compiler = outer_query.get_compiler(using)
+        compiler = outer_query.get_compiler(using, elide_empty=False)
         result = compiler.execute_sql(SINGLE)
-        if result is None:
-            result = [None] * len(outer_query.annotation_select)
 
         converters = compiler.get_converters(outer_query.annotation_select.values())
         result = next(compiler.apply_converters((result,), converters))

+ 32 - 0
tests/aggregation/tests.py

@@ -8,6 +8,7 @@ from django.db.models import (
     Avg, Case, Count, DecimalField, DurationField, Exists, F, FloatField,
     IntegerField, Max, Min, OuterRef, Subquery, Sum, Value, When,
 )
+from django.db.models.expressions import RawSQL
 from django.db.models.functions import Coalesce, Greatest
 from django.test import TestCase
 from django.test.testcases import skipUnlessDBFeature
@@ -1340,3 +1341,34 @@ class AggregateTestCase(TestCase):
             ('Stuart Russell', 1),
             ('Peter Norvig', 2),
         ], lambda a: (a.name, a.contact_count), ordered=False)
+
+    def test_coalesced_empty_result_set(self):
+        self.assertEqual(
+            Publisher.objects.none().aggregate(
+                sum_awards=Coalesce(Sum('num_awards'), 0),
+            )['sum_awards'],
+            0,
+        )
+        # Multiple expressions.
+        self.assertEqual(
+            Publisher.objects.none().aggregate(
+                sum_awards=Coalesce(Sum('num_awards'), None, 0),
+            )['sum_awards'],
+            0,
+        )
+        # Nested coalesce.
+        self.assertEqual(
+            Publisher.objects.none().aggregate(
+                sum_awards=Coalesce(Coalesce(Sum('num_awards'), None), 0),
+            )['sum_awards'],
+            0,
+        )
+        # Expression coalesce.
+        self.assertIsInstance(
+            Store.objects.none().aggregate(
+                latest_opening=Coalesce(
+                    Max('original_opening'), RawSQL('CURRENT_TIMESTAMP', []),
+                ),
+            )['latest_opening'],
+            datetime.datetime,
+        )