فهرست منبع

Fixed #28658 -- Added DISTINCT handling to the Aggregate class.

Simon Charette 6 سال پیش
والد
کامیت
bc05547cd8

+ 4 - 6
django/contrib/postgres/aggregates/general.py

@@ -11,14 +11,12 @@ __all__ = [
 class ArrayAgg(OrderableAggMixin, Aggregate):
     function = 'ARRAY_AGG'
     template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)'
+    allow_distinct = True
 
     @property
     def output_field(self):
         return ArrayField(self.source_expressions[0].output_field)
 
-    def __init__(self, expression, distinct=False, **extra):
-        super().__init__(expression, distinct='DISTINCT ' if distinct else '', **extra)
-
     def convert_value(self, value, expression, connection):
         if not value:
             return []
@@ -54,10 +52,10 @@ class JSONBAgg(Aggregate):
 class StringAgg(OrderableAggMixin, Aggregate):
     function = 'STRING_AGG'
     template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s'%(ordering)s)"
+    allow_distinct = True
 
-    def __init__(self, expression, delimiter, distinct=False, **extra):
-        distinct = 'DISTINCT ' if distinct else ''
-        super().__init__(expression, delimiter=delimiter, distinct=distinct, **extra)
+    def __init__(self, expression, delimiter, **extra):
+        super().__init__(expression, delimiter=delimiter, **extra)
 
     def convert_value(self, value, expression, connection):
         if not value:

+ 5 - 0
django/db/backends/sqlite3/operations.py

@@ -57,6 +57,11 @@ class DatabaseOperations(BaseDatabaseOperations):
                             'aggregations on date/time fields in sqlite3 '
                             'since date/time is saved as text.'
                         )
+        if isinstance(expression, aggregates.Aggregate) and len(expression.source_expressions) > 1:
+            raise utils.NotSupportedError(
+                "SQLite doesn't support DISTINCT on aggregate functions "
+                "accepting multiple arguments."
+            )
 
     def date_extract_sql(self, lookup_type, field_name):
         """

+ 14 - 12
django/db/models/aggregates.py

@@ -11,14 +11,19 @@ __all__ = [
 
 
 class Aggregate(Func):
+    template = '%(function)s(%(distinct)s%(expressions)s)'
     contains_aggregate = True
     name = None
     filter_template = '%s FILTER (WHERE %%(filter)s)'
     window_compatible = True
+    allow_distinct = False
 
-    def __init__(self, *args, filter=None, **kwargs):
+    def __init__(self, *expressions, distinct=False, filter=None, **extra):
+        if distinct and not self.allow_distinct:
+            raise TypeError("%s does not allow distinct." % self.__class__.__name__)
+        self.distinct = distinct
         self.filter = filter
-        super().__init__(*args, **kwargs)
+        super().__init__(*expressions, **extra)
 
     def get_source_fields(self):
         # Don't return the filter expression since it's not a source field.
@@ -60,6 +65,7 @@ class Aggregate(Func):
         return []
 
     def as_sql(self, compiler, connection, **extra_context):
+        extra_context['distinct'] = 'DISTINCT' if self.distinct else ''
         if self.filter:
             if connection.features.supports_aggregate_filter_clause:
                 filter_sql, filter_params = self.filter.as_sql(compiler, connection)
@@ -80,8 +86,10 @@ class Aggregate(Func):
 
     def _get_repr_options(self):
         options = super()._get_repr_options()
+        if self.distinct:
+            options['distinct'] = self.distinct
         if self.filter:
-            options.update({'filter': self.filter})
+            options['filter'] = self.filter
         return options
 
 
@@ -114,21 +122,15 @@ class Avg(Aggregate):
 class Count(Aggregate):
     function = 'COUNT'
     name = 'Count'
-    template = '%(function)s(%(distinct)s%(expressions)s)'
     output_field = IntegerField()
+    allow_distinct = True
 
-    def __init__(self, expression, distinct=False, filter=None, **extra):
+    def __init__(self, expression, filter=None, **extra):
         if expression == '*':
             expression = Star()
         if isinstance(expression, Star) and filter is not None:
             raise ValueError('Star cannot be used with filter. Please specify a field.')
-        super().__init__(
-            expression, distinct='DISTINCT ' if distinct else '',
-            filter=filter, **extra
-        )
-
-    def _get_repr_options(self):
-        return {**super()._get_repr_options(), 'distinct': self.extra['distinct'] != ''}
+        super().__init__(expression, filter=filter, **extra)
 
     def convert_value(self, value, expression, connection):
         return 0 if value is None else value

+ 18 - 1
docs/ref/models/expressions.txt

@@ -373,7 +373,7 @@ some complex computations::
 
 The ``Aggregate`` API is as follows:
 
-.. class:: Aggregate(*expressions, output_field=None, filter=None, **extra)
+.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, **extra)
 
     .. attribute:: template
 
@@ -392,6 +392,14 @@ The ``Aggregate`` API is as follows:
         Defaults to ``True`` since most aggregate functions can be used as the
         source expression in :class:`~django.db.models.expressions.Window`.
 
+    .. attribute:: allow_distinct
+
+        .. versionadded:: 2.2
+
+        A class attribute determining whether or not this aggregate function
+        allows passing a ``distinct`` keyword argument. If set to ``False``
+        (default), ``TypeError`` is raised if ``distinct=True`` is passed.
+
 The ``expressions`` positional arguments can include expressions or the names
 of model fields. They will be converted to a string and used as the
 ``expressions`` placeholder within the ``template``.
@@ -409,6 +417,11 @@ should define the desired ``output_field``. For example, adding an
 ``IntegerField()`` and a ``FloatField()`` together should probably have
 ``output_field=FloatField()`` defined.
 
+The ``distinct`` argument determines whether or not the aggregate function
+should be invoked for each distinct value of ``expressions`` (or set of
+values, for multiple ``expressions``). The argument is only supported on
+aggregates that have :attr:`~Aggregate.allow_distinct` set to ``True``.
+
 The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
 used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
 and :ref:`filtering-on-annotations` for example usage.
@@ -416,6 +429,10 @@ and :ref:`filtering-on-annotations` for example usage.
 The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
 into the ``template`` attribute.
 
+.. versionadded:: 2.2
+
+    The ``allow_distinct`` attribute and ``distinct`` argument were added.
+
 Creating your own Aggregate Functions
 -------------------------------------
 

+ 7 - 0
docs/releases/2.2.txt

@@ -239,6 +239,13 @@ Models
 * Added SQLite support for the :class:`~django.db.models.StdDev` and
   :class:`~django.db.models.Variance` functions.
 
+* The handling of ``DISTINCT`` aggregation is added to the
+  :class:`~django.db.models.Aggregate` class. Adding :attr:`allow_distinct =
+  True <django.db.models.Aggregate.allow_distinct>` as a class attribute on
+  ``Aggregate`` subclasses allows a ``distinct`` keyword argument to be
+  specified on initialization to ensure that the aggregate function is only
+  called for each distinct value of ``expressions``.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 2 - 2
tests/aggregation/tests.py

@@ -1026,7 +1026,7 @@ class AggregateTestCase(TestCase):
         # test completely changing how the output is rendered
         def lower_case_function_override(self, compiler, connection):
             sql, params = compiler.compile(self.source_expressions[0])
-            substitutions = {'function': self.function.lower(), 'expressions': sql}
+            substitutions = {'function': self.function.lower(), 'expressions': sql, 'distinct': ''}
             substitutions.update(self.extra)
             return self.template % substitutions, params
         setattr(MySum, 'as_' + connection.vendor, lower_case_function_override)
@@ -1053,7 +1053,7 @@ class AggregateTestCase(TestCase):
 
         # test overriding all parts of the template
         def be_evil(self, compiler, connection):
-            substitutions = {'function': 'MAX', 'expressions': '2'}
+            substitutions = {'function': 'MAX', 'expressions': '2', 'distinct': ''}
             substitutions.update(self.extra)
             return self.template % substitutions, ()
         setattr(MySum, 'as_' + connection.vendor, be_evil)

+ 11 - 0
tests/aggregation_regress/tests.py

@@ -11,6 +11,7 @@ from django.db.models import (
     Avg, Case, Count, DecimalField, F, IntegerField, Max, Q, StdDev, Sum,
     Value, Variance, When,
 )
+from django.db.models.aggregates import Aggregate
 from django.test import (
     TestCase, ignore_warnings, skipUnlessAnyDBFeature, skipUnlessDBFeature,
 )
@@ -1496,6 +1497,16 @@ class AggregationTests(TestCase):
         qs = Author.objects.values_list('age', flat=True).annotate(age_count=Count('age')).filter(age_count__gt=1)
         self.assertSequenceEqual(qs, [29])
 
+    def test_allow_distinct(self):
+        class MyAggregate(Aggregate):
+            pass
+        with self.assertRaisesMessage(TypeError, 'MyAggregate does not allow distinct'):
+            MyAggregate('foo', distinct=True)
+
+        class DistinctAggregate(Aggregate):
+            allow_distinct = True
+        DistinctAggregate('foo', distinct=True)
+
 
 class JoinPromotionTests(TestCase):
     def test_ticket_21150(self):

+ 12 - 0
tests/backends/sqlite/tests.py

@@ -4,6 +4,7 @@ import unittest
 
 from django.db import connection, transaction
 from django.db.models import Avg, StdDev, Sum, Variance
+from django.db.models.aggregates import Aggregate
 from django.db.models.fields import CharField
 from django.db.utils import NotSupportedError
 from django.test import (
@@ -34,6 +35,17 @@ class Tests(TestCase):
                     **{'complex': aggregate('last_modified') + aggregate('last_modified')}
                 )
 
+    def test_distinct_aggregation(self):
+        class DistinctAggregate(Aggregate):
+            allow_distinct = True
+        aggregate = DistinctAggregate('first', 'second', distinct=True)
+        msg = (
+            "SQLite doesn't support DISTINCT on aggregate functions accepting "
+            "multiple arguments."
+        )
+        with self.assertRaisesMessage(NotSupportedError, msg):
+            connection.ops.check_expression_support(aggregate)
+
     def test_memory_db_test_name(self):
         """A named in-memory db should be allowed where supported."""
         from django.db.backends.sqlite3.base import DatabaseWrapper

+ 10 - 3
tests/expressions/tests.py

@@ -1481,18 +1481,22 @@ class ReprTests(SimpleTestCase):
 
     def test_aggregates(self):
         self.assertEqual(repr(Avg('a')), "Avg(F(a))")
-        self.assertEqual(repr(Count('a')), "Count(F(a), distinct=False)")
-        self.assertEqual(repr(Count('*')), "Count('*', distinct=False)")
+        self.assertEqual(repr(Count('a')), "Count(F(a))")
+        self.assertEqual(repr(Count('*')), "Count('*')")
         self.assertEqual(repr(Max('a')), "Max(F(a))")
         self.assertEqual(repr(Min('a')), "Min(F(a))")
         self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
         self.assertEqual(repr(Sum('a')), "Sum(F(a))")
         self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")
 
+    def test_distinct_aggregates(self):
+        self.assertEqual(repr(Count('a', distinct=True)), "Count(F(a), distinct=True)")
+        self.assertEqual(repr(Count('*', distinct=True)), "Count('*', distinct=True)")
+
     def test_filtered_aggregates(self):
         filter = Q(a=1)
         self.assertEqual(repr(Avg('a', filter=filter)), "Avg(F(a), filter=(AND: ('a', 1)))")
-        self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), distinct=False, filter=(AND: ('a', 1)))")
+        self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), filter=(AND: ('a', 1)))")
         self.assertEqual(repr(Max('a', filter=filter)), "Max(F(a), filter=(AND: ('a', 1)))")
         self.assertEqual(repr(Min('a', filter=filter)), "Min(F(a), filter=(AND: ('a', 1)))")
         self.assertEqual(repr(StdDev('a', filter=filter)), "StdDev(F(a), filter=(AND: ('a', 1)), sample=False)")
@@ -1501,6 +1505,9 @@ class ReprTests(SimpleTestCase):
             repr(Variance('a', sample=True, filter=filter)),
             "Variance(F(a), filter=(AND: ('a', 1)), sample=True)"
         )
+        self.assertEqual(
+            repr(Count('a', filter=filter, distinct=True)), "Count(F(a), distinct=True, filter=(AND: ('a', 1)))"
+        )
 
 
 class CombinableTests(SimpleTestCase):