Browse Source

Refs #26430 -- Re-introduced empty aggregation optimization.

The introduction of the Expression.empty_aggregate_value interface
allows the compilation stage to enable the EmptyResultSet optimization
if all the aggregate expressions implement it.

This also removes unnecessary RegrCount/Count.convert_value() methods.
Disabling the empty result set aggregation optimization when it wasn't
appropriate prevented None from being returned for a Count aggregation value.

Thanks Nick Pope for the review.
Simon Charette 3 years ago
parent
commit
9f3cce172f

+ 1 - 3
django/contrib/postgres/aggregates/statistics.py

@@ -36,9 +36,7 @@ class RegrAvgY(StatAggregate):
 class RegrCount(StatAggregate):
 class RegrCount(StatAggregate):
     function = 'REGR_COUNT'
     function = 'REGR_COUNT'
     output_field = IntegerField()
     output_field = IntegerField()
-
-    def convert_value(self, value, expression, connection):
-        return 0 if value is None else value
+    empty_aggregate_value = 0
 
 
 
 
 class RegrIntercept(StatAggregate):
 class RegrIntercept(StatAggregate):

+ 2 - 3
django/db/models/aggregates.py

@@ -20,6 +20,7 @@ class Aggregate(Func):
     filter_template = '%s FILTER (WHERE %%(filter)s)'
     filter_template = '%s FILTER (WHERE %%(filter)s)'
     window_compatible = True
     window_compatible = True
     allow_distinct = False
     allow_distinct = False
+    empty_aggregate_value = None
 
 
     def __init__(self, *expressions, distinct=False, filter=None, **extra):
     def __init__(self, *expressions, distinct=False, filter=None, **extra):
         if distinct and not self.allow_distinct:
         if distinct and not self.allow_distinct:
@@ -107,6 +108,7 @@ class Count(Aggregate):
     name = 'Count'
     name = 'Count'
     output_field = IntegerField()
     output_field = IntegerField()
     allow_distinct = True
     allow_distinct = True
+    empty_aggregate_value = 0
 
 
     def __init__(self, expression, filter=None, **extra):
     def __init__(self, expression, filter=None, **extra):
         if expression == '*':
         if expression == '*':
@@ -115,9 +117,6 @@ class Count(Aggregate):
             raise ValueError('Star cannot be used with filter. Please specify a field.')
             raise ValueError('Star cannot be used with filter. Please specify a field.')
         super().__init__(expression, filter=filter, **extra)
         super().__init__(expression, filter=filter, **extra)
 
 
-    def convert_value(self, value, expression, connection):
-        return 0 if value is None else value
-
 
 
 class Max(Aggregate):
 class Max(Aggregate):
     function = 'MAX'
     function = 'MAX'

+ 5 - 0
django/db/models/expressions.py

@@ -153,6 +153,7 @@ class BaseExpression:
     # aggregate specific fields
     # aggregate specific fields
     is_summary = False
     is_summary = False
     _output_field_resolved_to_none = False
     _output_field_resolved_to_none = False
+    empty_aggregate_value = NotImplemented
     # Can the expression be used in a WHERE clause?
     # Can the expression be used in a WHERE clause?
     filterable = True
     filterable = True
     # Can the expression be used as a source expression in Window?
     # Can the expression be used as a source expression in Window?
@@ -795,6 +796,10 @@ class Value(SQLiteNumericMixin, Expression):
         if isinstance(self.value, UUID):
         if isinstance(self.value, UUID):
             return fields.UUIDField()
             return fields.UUIDField()
 
 
+    @property
+    def empty_aggregate_value(self):
+        return self.value
+
 
 
 class RawSQL(Expression):
 class RawSQL(Expression):
     def __init__(self, sql, params, output_field=None):
     def __init__(self, sql, params, output_field=None):

+ 8 - 0
django/db/models/functions/comparison.py

@@ -65,6 +65,14 @@ class Coalesce(Func):
             raise ValueError('Coalesce must take at least two expressions')
             raise ValueError('Coalesce must take at least two expressions')
         super().__init__(*expressions, **extra)
         super().__init__(*expressions, **extra)
 
 
+    @property
+    def empty_aggregate_value(self):
+        for expression in self.get_source_expressions():
+            result = expression.empty_aggregate_value
+            if result is NotImplemented or result is not None:
+                return result
+        return None
+
     def as_oracle(self, compiler, connection, **extra_context):
     def as_oracle(self, compiler, connection, **extra_context):
         # Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
         # Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
         # so convert all fields to NCLOB when that type is expected.
         # so convert all fields to NCLOB when that type is expected.

+ 8 - 1
django/db/models/sql/query.py

@@ -490,12 +490,19 @@ class Query(BaseExpression):
             self.default_cols = False
             self.default_cols = False
             self.extra = {}
             self.extra = {}
 
 
+        empty_aggregate_result = [
+            expression.empty_aggregate_value
+            for expression in outer_query.annotation_select.values()
+        ]
+        elide_empty = not any(result is NotImplemented for result in empty_aggregate_result)
         outer_query.clear_ordering(force=True)
         outer_query.clear_ordering(force=True)
         outer_query.clear_limits()
         outer_query.clear_limits()
         outer_query.select_for_update = False
         outer_query.select_for_update = False
         outer_query.select_related = False
         outer_query.select_related = False
-        compiler = outer_query.get_compiler(using, elide_empty=False)
+        compiler = outer_query.get_compiler(using, elide_empty=elide_empty)
         result = compiler.execute_sql(SINGLE)
         result = compiler.execute_sql(SINGLE)
+        if result is None:
+            result = empty_aggregate_result
 
 
         converters = compiler.get_converters(outer_query.annotation_select.values())
         converters = compiler.get_converters(outer_query.annotation_select.values())
         result = next(compiler.apply_converters((result,), converters))
         result = next(compiler.apply_converters((result,), converters))

+ 17 - 0
docs/ref/models/expressions.txt

@@ -410,6 +410,14 @@ The ``Aggregate`` API is as follows:
         allows passing a ``distinct`` keyword argument. If set to ``False``
         allows passing a ``distinct`` keyword argument. If set to ``False``
         (default), ``TypeError`` is raised if ``distinct=True`` is passed.
         (default), ``TypeError`` is raised if ``distinct=True`` is passed.
 
 
+    .. attribute:: empty_aggregate_value
+
+        .. versionadded:: 4.0
+
+        Overrides :attr:`~django.db.models.Expression.empty_aggregate_value`
+        with ``None``, since most aggregate functions result in ``NULL`` when
+        applied to an empty result set.
+
 The ``expressions`` positional arguments can include expressions, transforms of
 The ``expressions`` positional arguments can include expressions, transforms of
 the model field, or the names of model fields. They will be converted to a
 the model field, or the names of model fields. They will be converted to a
 string and used as the ``expressions`` placeholder within the ``template``.
 string and used as the ``expressions`` placeholder within the ``template``.
@@ -950,6 +958,15 @@ calling the appropriate methods on the wrapped expression.
         in :class:`~django.db.models.expressions.Window`. Defaults to
         in :class:`~django.db.models.expressions.Window`. Defaults to
         ``False``.
         ``False``.
 
 
+    .. attribute:: empty_aggregate_value
+
+        .. versionadded:: 4.0
+
+        Tells Django which value should be returned when the expression is used
+        to :meth:`aggregate <django.db.models.query.QuerySet.aggregate>` over
+        an empty result set. Defaults to :py:data:`NotImplemented`, which
+        forces the expression to be computed on the database.
+
     .. method:: resolve_expression(query=None, allow_joins=True, reuse=None, summarize=False, for_save=False)
     .. method:: resolve_expression(query=None, allow_joins=True, reuse=None, summarize=False, for_save=False)
 
 
         Provides the chance to do any pre-processing or validation of
         Provides the chance to do any pre-processing or validation of

+ 3 - 0
docs/releases/4.0.txt

@@ -265,6 +265,9 @@ Models
 
 
 * :meth:`.QuerySet.bulk_update` now returns the number of objects updated.
 * :meth:`.QuerySet.bulk_update` now returns the number of objects updated.
 
 
+* The new :attr:`.Aggregate.empty_aggregate_value` attribute allows specifying
+  a value to return when the aggregation is used over an empty result set.
+
 Requests and Responses
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~~~~~~
 
 

+ 57 - 27
tests/aggregation/tests.py

@@ -8,7 +8,7 @@ from django.db.models import (
     Avg, Case, Count, DecimalField, DurationField, Exists, F, FloatField,
     Avg, Case, Count, DecimalField, DurationField, Exists, F, FloatField,
     IntegerField, Max, Min, OuterRef, Subquery, Sum, Value, When,
     IntegerField, Max, Min, OuterRef, Subquery, Sum, Value, When,
 )
 )
-from django.db.models.expressions import RawSQL
+from django.db.models.expressions import Func, RawSQL
 from django.db.models.functions import Coalesce, Greatest
 from django.db.models.functions import Coalesce, Greatest
 from django.test import TestCase
 from django.test import TestCase
 from django.test.testcases import skipUnlessDBFeature
 from django.test.testcases import skipUnlessDBFeature
@@ -1342,33 +1342,63 @@ class AggregateTestCase(TestCase):
             ('Peter Norvig', 2),
             ('Peter Norvig', 2),
         ], lambda a: (a.name, a.contact_count), ordered=False)
         ], lambda a: (a.name, a.contact_count), ordered=False)
 
 
+    def test_empty_result_optimization(self):
+        with self.assertNumQueries(0):
+            self.assertEqual(
+                Publisher.objects.none().aggregate(
+                    sum_awards=Sum('num_awards'),
+                    books_count=Count('book'),
+                ), {
+                    'sum_awards': None,
+                    'books_count': 0,
+                }
+            )
+        # Expression without empty_aggregate_value forces queries to be
+        # executed even if they would return an empty result set.
+        raw_books_count = Func('book', function='COUNT')
+        raw_books_count.contains_aggregate = True
+        with self.assertNumQueries(1):
+            self.assertEqual(
+                Publisher.objects.none().aggregate(
+                    sum_awards=Sum('num_awards'),
+                    books_count=raw_books_count,
+                ), {
+                    'sum_awards': None,
+                    'books_count': 0,
+                }
+            )
+
     def test_coalesced_empty_result_set(self):
     def test_coalesced_empty_result_set(self):
-        self.assertEqual(
-            Publisher.objects.none().aggregate(
-                sum_awards=Coalesce(Sum('num_awards'), 0),
-            )['sum_awards'],
-            0,
-        )
+        with self.assertNumQueries(0):
+            self.assertEqual(
+                Publisher.objects.none().aggregate(
+                    sum_awards=Coalesce(Sum('num_awards'), 0),
+                )['sum_awards'],
+                0,
+            )
         # Multiple expressions.
         # Multiple expressions.
-        self.assertEqual(
-            Publisher.objects.none().aggregate(
-                sum_awards=Coalesce(Sum('num_awards'), None, 0),
-            )['sum_awards'],
-            0,
-        )
+        with self.assertNumQueries(0):
+            self.assertEqual(
+                Publisher.objects.none().aggregate(
+                    sum_awards=Coalesce(Sum('num_awards'), None, 0),
+                )['sum_awards'],
+                0,
+            )
         # Nested coalesce.
         # Nested coalesce.
-        self.assertEqual(
-            Publisher.objects.none().aggregate(
-                sum_awards=Coalesce(Coalesce(Sum('num_awards'), None), 0),
-            )['sum_awards'],
-            0,
-        )
+        with self.assertNumQueries(0):
+            self.assertEqual(
+                Publisher.objects.none().aggregate(
+                    sum_awards=Coalesce(Coalesce(Sum('num_awards'), None), 0),
+                )['sum_awards'],
+                0,
+            )
         # Expression coalesce.
         # Expression coalesce.
-        self.assertIsInstance(
-            Store.objects.none().aggregate(
-                latest_opening=Coalesce(
-                    Max('original_opening'), RawSQL('CURRENT_TIMESTAMP', []),
-                ),
-            )['latest_opening'],
-            datetime.datetime,
-        )
+        with self.assertNumQueries(1):
+            self.assertIsInstance(
+                Store.objects.none().aggregate(
+                    latest_opening=Coalesce(
+                        Max('original_opening'), RawSQL('CURRENT_TIMESTAMP', []),
+                    ),
+                )['latest_opening'],
+                datetime.datetime,
+            )