Bläddra i källkod

Fixed #27849 -- Added filtering support to aggregates.

Tom 8 år sedan
förälder
incheckning
b78d100fa6

+ 6 - 6
django/contrib/postgres/aggregates/statistics.py

@@ -8,10 +8,10 @@ __all__ = [
 
 
 class StatAggregate(Aggregate):
-    def __init__(self, y, x, output_field=FloatField()):
+    def __init__(self, y, x, output_field=FloatField(), filter=None):
         if not x or not y:
             raise ValueError('Both y and x must be provided.')
-        super().__init__(y, x, output_field=output_field)
+        super().__init__(y, x, output_field=output_field, filter=filter)
 
     def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
         return super().resolve_expression(query, allow_joins, reuse, summarize)
@@ -22,9 +22,9 @@ class Corr(StatAggregate):
 
 
 class CovarPop(StatAggregate):
-    def __init__(self, y, x, sample=False):
+    def __init__(self, y, x, sample=False, filter=None):
         self.function = 'COVAR_SAMP' if sample else 'COVAR_POP'
-        super().__init__(y, x)
+        super().__init__(y, x, filter=filter)
 
 
 class RegrAvgX(StatAggregate):
@@ -38,8 +38,8 @@ class RegrAvgY(StatAggregate):
 class RegrCount(StatAggregate):
     function = 'REGR_COUNT'
 
-    def __init__(self, y, x):
-        super().__init__(y=y, x=x, output_field=IntegerField())
+    def __init__(self, y, x, filter=None):
+        super().__init__(y=y, x=x, output_field=IntegerField(), filter=filter)
 
     def convert_value(self, value, expression, connection):
         if value is None:

+ 4 - 0
django/db/backends/base/features.py

@@ -229,6 +229,10 @@ class BaseDatabaseFeatures:
     supports_select_difference = True
     supports_slicing_ordering_in_compound = False
 
+    # Does the database support SQL 2003 FILTER (WHERE ...) in aggregate
+    # expressions?
+    supports_aggregate_filter_clause = False
+
     # Does the backend support indexing a TextField?
     supports_index_on_text_field = True
 

+ 4 - 0
django/db/backends/postgresql/features.py

@@ -50,6 +50,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
         END;
     $$ LANGUAGE plpgsql;"""
 
+    @cached_property
+    def supports_aggregate_filter_clause(self):
+        return self.connection.pg_version >= 90400
+
     @cached_property
     def has_select_for_update_skip_locked(self):
         return self.connection.pg_version >= 90500

+ 61 - 8
django/db/models/aggregates.py

@@ -2,8 +2,9 @@
 Classes to represent the definitions of aggregate functions.
 """
 from django.core.exceptions import FieldError
-from django.db.models.expressions import Func, Star
+from django.db.models.expressions import Case, Func, Star, When
 from django.db.models.fields import DecimalField, FloatField, IntegerField
+from django.db.models.query_utils import Q
 
 __all__ = [
     'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
@@ -13,12 +14,36 @@ __all__ = [
 class Aggregate(Func):
     contains_aggregate = True
     name = None
+    filter_template = '%s FILTER (WHERE %%(filter)s)'
+
+    def __init__(self, *args, filter=None, **kwargs):
+        self.filter = filter
+        super().__init__(*args, **kwargs)
+
+    def get_source_fields(self):
+        # Don't return the filter expression since it's not a source field.
+        return [e._output_field_or_none for e in super().get_source_expressions()]
+
+    def get_source_expressions(self):
+        source_expressions = super().get_source_expressions()
+        if self.filter:
+            source_expressions += [self.filter]
+        return source_expressions
+
+    def set_source_expressions(self, exprs):
+        if self.filter:
+            self.filter = exprs.pop()
+        return super().set_source_expressions(exprs)
 
     def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
         # Aggregates are not allowed in UPDATE queries, so ignore for_save
         c = super().resolve_expression(query, allow_joins, reuse, summarize)
+        if c.filter:
+            c.filter = c.filter.resolve_expression(query, allow_joins, reuse, summarize)
         if not summarize:
-            expressions = c.get_source_expressions()
+            # Call Aggregate.get_source_expressions() to avoid
+            # returning self.filter and including that in this loop.
+            expressions = super(Aggregate, c).get_source_expressions()
             for index, expr in enumerate(expressions):
                 if expr.contains_aggregate:
                     before_resolved = self.get_source_expressions()[index]
@@ -36,6 +61,29 @@ class Aggregate(Func):
     def get_group_by_cols(self):
         return []
 
+    def as_sql(self, compiler, connection, **extra_context):
+        if self.filter:
+            if connection.features.supports_aggregate_filter_clause:
+                filter_sql, filter_params = self.filter.as_sql(compiler, connection)
+                template = self.filter_template % extra_context.get('template', self.template)
+                sql, params = super().as_sql(compiler, connection, template=template, filter=filter_sql)
+                return sql, params + filter_params
+            else:
+                copy = self.copy()
+                copy.filter = None
+                condition = When(Q())
+                source_expressions = copy.get_source_expressions()
+                condition.set_source_expressions([self.filter, source_expressions[0]])
+                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
+                return super(Aggregate, copy).as_sql(compiler, connection, **extra_context)
+        return super().as_sql(compiler, connection, **extra_context)
+
+    def _get_repr_options(self):
+        options = super()._get_repr_options()
+        if self.filter:
+            options.update({'filter': self.filter})
+        return options
+
 
 class Avg(Aggregate):
     function = 'AVG'
@@ -52,7 +100,7 @@ class Avg(Aggregate):
             expression = self.get_source_expressions()[0]
             from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
             return compiler.compile(
-                SecondsToInterval(Avg(IntervalToSeconds(expression)))
+                SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter))
             )
         return super().as_sql(compiler, connection)
 
@@ -62,16 +110,19 @@ class Count(Aggregate):
     name = 'Count'
     template = '%(function)s(%(distinct)s%(expressions)s)'
 
-    def __init__(self, expression, distinct=False, **extra):
+    def __init__(self, expression, distinct=False, filter=None, **extra):
         if expression == '*':
             expression = Star()
+        if isinstance(expression, Star) and filter is not None:
+            raise ValueError('Star cannot be used with filter. Please specify a field.')
         super().__init__(
             expression, distinct='DISTINCT ' if distinct else '',
-            output_field=IntegerField(), **extra
+            output_field=IntegerField(), filter=filter, **extra
         )
 
     def _get_repr_options(self):
-        return {'distinct': self.extra['distinct'] != ''}
+        options = super()._get_repr_options()
+        return dict(options, distinct=self.extra['distinct'] != '')
 
     def convert_value(self, value, expression, connection):
         if value is None:
@@ -97,7 +148,8 @@ class StdDev(Aggregate):
         super().__init__(expression, output_field=FloatField(), **extra)
 
     def _get_repr_options(self):
-        return {'sample': self.function == 'STDDEV_SAMP'}
+        options = super()._get_repr_options()
+        return dict(options, sample=self.function == 'STDDEV_SAMP')
 
     def convert_value(self, value, expression, connection):
         if value is None:
@@ -127,7 +179,8 @@ class Variance(Aggregate):
         super().__init__(expression, output_field=FloatField(), **extra)
 
     def _get_repr_options(self):
-        return {'sample': self.function == 'VAR_SAMP'}
+        options = super()._get_repr_options()
+        return dict(options, sample=self.function == 'VAR_SAMP')
 
     def convert_value(self, value, expression, connection):
         if value is None:

+ 18 - 18
docs/ref/contrib/postgres/aggregates.txt

@@ -22,7 +22,7 @@ General-purpose aggregation functions
 ``ArrayAgg``
 ------------
 
-.. class:: ArrayAgg(expression, distinct=False, **extra)
+.. class:: ArrayAgg(expression, distinct=False, filter=None, **extra)
 
     Returns a list of values, including nulls, concatenated into an array.
 
@@ -36,7 +36,7 @@ General-purpose aggregation functions
 ``BitAnd``
 ----------
 
-.. class:: BitAnd(expression, **extra)
+.. class:: BitAnd(expression, filter=None, **extra)
 
     Returns an ``int`` of the bitwise ``AND`` of all non-null input values, or
     ``None`` if all values are null.
@@ -44,7 +44,7 @@ General-purpose aggregation functions
 ``BitOr``
 ---------
 
-.. class:: BitOr(expression, **extra)
+.. class:: BitOr(expression, filter=None, **extra)
 
     Returns an ``int`` of the bitwise ``OR`` of all non-null input values, or
     ``None`` if all values are null.
@@ -52,7 +52,7 @@ General-purpose aggregation functions
 ``BoolAnd``
 -----------
 
-.. class:: BoolAnd(expression, **extra)
+.. class:: BoolAnd(expression, filter=None, **extra)
 
     Returns ``True``, if all input values are true, ``None`` if all values are
     null or if there are no values, otherwise ``False`` .
@@ -60,7 +60,7 @@ General-purpose aggregation functions
 ``BoolOr``
 ----------
 
-.. class:: BoolOr(expression, **extra)
+.. class:: BoolOr(expression, filter=None, **extra)
 
     Returns ``True`` if at least one input value is true, ``None`` if all
     values are null or if there are no values, otherwise ``False``.
@@ -68,7 +68,7 @@ General-purpose aggregation functions
 ``JSONBAgg``
 ------------
 
-.. class:: JSONBAgg(expressions, **extra)
+.. class:: JSONBAgg(expressions, filter=None, **extra)
 
     .. versionadded:: 1.11
 
@@ -77,7 +77,7 @@ General-purpose aggregation functions
 ``StringAgg``
 -------------
 
-.. class:: StringAgg(expression, delimiter, distinct=False)
+.. class:: StringAgg(expression, delimiter, distinct=False, filter=None)
 
     Returns the input values concatenated into a string, separated by
     the ``delimiter`` string.
@@ -105,7 +105,7 @@ field or an expression returning a numeric data. Both are required.
 ``Corr``
 --------
 
-.. class:: Corr(y, x)
+.. class:: Corr(y, x, filter=None)
 
     Returns the correlation coefficient as a ``float``, or ``None`` if there
     aren't any matching rows.
@@ -113,7 +113,7 @@ field or an expression returning a numeric data. Both are required.
 ``CovarPop``
 ------------
 
-.. class:: CovarPop(y, x, sample=False)
+.. class:: CovarPop(y, x, sample=False, filter=None)
 
     Returns the population covariance as a ``float``, or ``None`` if there
     aren't any matching rows.
@@ -129,7 +129,7 @@ field or an expression returning a numeric data. Both are required.
 ``RegrAvgX``
 ------------
 
-.. class:: RegrAvgX(y, x)
+.. class:: RegrAvgX(y, x, filter=None)
 
     Returns the average of the independent variable (``sum(x)/N``) as a
     ``float``, or ``None`` if there aren't any matching rows.
@@ -137,7 +137,7 @@ field or an expression returning a numeric data. Both are required.
 ``RegrAvgY``
 ------------
 
-.. class:: RegrAvgY(y, x)
+.. class:: RegrAvgY(y, x, filter=None)
 
     Returns the average of the dependent variable (``sum(y)/N``) as a
     ``float``, or ``None`` if there aren't any matching rows.
@@ -145,7 +145,7 @@ field or an expression returning a numeric data. Both are required.
 ``RegrCount``
 -------------
 
-.. class:: RegrCount(y, x)
+.. class:: RegrCount(y, x, filter=None)
 
     Returns an ``int`` of the number of input rows in which both expressions
     are not null.
@@ -153,7 +153,7 @@ field or an expression returning a numeric data. Both are required.
 ``RegrIntercept``
 -----------------
 
-.. class:: RegrIntercept(y, x)
+.. class:: RegrIntercept(y, x, filter=None)
 
     Returns the y-intercept of the least-squares-fit linear equation determined
     by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
@@ -162,7 +162,7 @@ field or an expression returning a numeric data. Both are required.
 ``RegrR2``
 ----------
 
-.. class:: RegrR2(y, x)
+.. class:: RegrR2(y, x, filter=None)
 
     Returns the square of the correlation coefficient as a ``float``, or
     ``None`` if there aren't any matching rows.
@@ -170,7 +170,7 @@ field or an expression returning a numeric data. Both are required.
 ``RegrSlope``
 -------------
 
-.. class:: RegrSlope(y, x)
+.. class:: RegrSlope(y, x, filter=None)
 
     Returns the slope of the least-squares-fit linear equation determined
     by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
@@ -179,7 +179,7 @@ field or an expression returning a numeric data. Both are required.
 ``RegrSXX``
 -----------
 
-.. class:: RegrSXX(y, x)
+.. class:: RegrSXX(y, x, filter=None)
 
     Returns ``sum(x^2) - sum(x)^2/N`` ("sum of squares" of the independent
     variable) as a ``float``, or ``None`` if there aren't any matching rows.
@@ -187,7 +187,7 @@ field or an expression returning a numeric data. Both are required.
 ``RegrSXY``
 -----------
 
-.. class:: RegrSXY(y, x)
+.. class:: RegrSXY(y, x, filter=None)
 
     Returns ``sum(x*y) - sum(x) * sum(y)/N`` ("sum of products" of independent
     times dependent variable) as a ``float``, or ``None`` if there aren't any
@@ -196,7 +196,7 @@ field or an expression returning a numeric data. Both are required.
 ``RegrSYY``
 -----------
 
-.. class:: RegrSYY(y, x)
+.. class:: RegrSYY(y, x, filter=None)
 
     Returns ``sum(y^2) - sum(y)^2/N`` ("sum of squares" of the dependent
     variable)  as a ``float``, or ``None`` if there aren't any matching rows.

+ 29 - 14
docs/ref/models/conditional-expressions.txt

@@ -184,12 +184,14 @@ their registration dates. We can do this using a conditional expression and the
     >>> Client.objects.values_list('name', 'account_type')
     <QuerySet [('Jane Doe', 'G'), ('James Smith', 'R'), ('Jack Black', 'P')]>
 
+.. _conditional-aggregation:
+
 Conditional aggregation
 -----------------------
 
 What if we want to find out how many clients there are for each
-``account_type``? We can nest conditional expression within
-:ref:`aggregate functions <aggregation-functions>` to achieve this::
+``account_type``? We can use the ``filter`` argument of :ref:`aggregate
+functions <aggregation-functions>` to achieve this::
 
     >>> # Create some more Clients first so we can have something to count
     >>> Client.objects.create(
@@ -207,17 +209,30 @@ What if we want to find out how many clients there are for each
     >>> # Get counts for each value of account_type
     >>> from django.db.models import IntegerField, Sum
     >>> Client.objects.aggregate(
-    ...     regular=Sum(
-    ...         Case(When(account_type=Client.REGULAR, then=1),
-    ...              output_field=IntegerField())
-    ...     ),
-    ...     gold=Sum(
-    ...         Case(When(account_type=Client.GOLD, then=1),
-    ...              output_field=IntegerField())
-    ...     ),
-    ...     platinum=Sum(
-    ...         Case(When(account_type=Client.PLATINUM, then=1),
-    ...              output_field=IntegerField())
-    ...     )
+    ...     regular=Count('pk', filter=Q(account_type=Client.REGULAR)),
+    ...     gold=Count('pk', filter=Q(account_type=Client.GOLD)),
+    ...     platinum=Count('pk', filter=Q(account_type=Client.PLATINUM)),
     ... )
     {'regular': 2, 'gold': 1, 'platinum': 3}
+
+This aggregate produces a query with the SQL 2003 ``FILTER WHERE`` syntax
+on databases that support it:
+
+.. code-block:: sql
+
+    SELECT count('id') FILTER (WHERE account_type=1) as regular,
+           count('id') FILTER (WHERE account_type=2) as gold,
+           count('id') FILTER (WHERE account_type=3) as platinum
+    FROM clients;
+
+On other databases, this is emulated using a ``CASE`` statement:
+
+.. code-block:: sql
+
+    SELECT count(CASE WHEN account_type=1 THEN id ELSE null) as regular,
+           count(CASE WHEN account_type=2 THEN id ELSE null) as gold,
+           count(CASE WHEN account_type=3 THEN id ELSE null) as platinum
+    FROM clients;
+
+The two SQL statements are functionally equivalent but the more explicit
+``FILTER`` may perform better.

+ 9 - 1
docs/ref/models/expressions.txt

@@ -339,7 +339,7 @@ some complex computations::
 
 The ``Aggregate`` API is as follows:
 
-.. class:: Aggregate(expression, output_field=None, **extra)
+.. class:: Aggregate(expression, output_field=None, filter=None, **extra)
 
     .. attribute:: template
 
@@ -370,9 +370,17 @@ should define the desired ``output_field``. For example, adding an
 ``IntegerField()`` and a ``FloatField()`` together should probably have
 ``output_field=FloatField()`` defined.
 
+The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
+used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
+and :ref:`filtering-on-annotations` for example usage.
+
 The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
 into the ``template`` attribute.
 
+.. versionchanged:: 2.0
+
+    The ``filter`` argument was added.
+
 Creating your own Aggregate Functions
 -------------------------------------
 

+ 18 - 7
docs/ref/models/querysets.txt

@@ -3085,6 +3085,17 @@ of the return value
     ``output_field`` if all fields are of the same type. Otherwise, you
     must provide the ``output_field`` yourself.
 
+``filter``
+~~~~~~~~~~
+
+.. versionadded:: 2.0
+
+An optional :class:`Q object <django.db.models.Q>` that's used to filter the
+rows that are aggregated.
+
+See :ref:`conditional-aggregation` and :ref:`filtering-on-annotations` for
+example usage.
+
 ``**extra``
 ~~~~~~~~~~~
 
@@ -3094,7 +3105,7 @@ by the aggregate.
 ``Avg``
 ~~~~~~~
 
-.. class:: Avg(expression, output_field=FloatField(), **extra)
+.. class:: Avg(expression, output_field=FloatField(), filter=None, **extra)
 
     Returns the mean value of the given expression, which must be numeric
     unless you specify a different ``output_field``.
@@ -3106,7 +3117,7 @@ by the aggregate.
 ``Count``
 ~~~~~~~~~
 
-.. class:: Count(expression, distinct=False, **extra)
+.. class:: Count(expression, distinct=False, filter=None, **extra)
 
     Returns the number of objects that are related through the provided
     expression.
@@ -3125,7 +3136,7 @@ by the aggregate.
 ``Max``
 ~~~~~~~
 
-.. class:: Max(expression, output_field=None, **extra)
+.. class:: Max(expression, output_field=None, filter=None, **extra)
 
     Returns the maximum value of the given expression.
 
@@ -3135,7 +3146,7 @@ by the aggregate.
 ``Min``
 ~~~~~~~
 
-.. class:: Min(expression, output_field=None, **extra)
+.. class:: Min(expression, output_field=None, filter=None, **extra)
 
     Returns the minimum value of the given expression.
 
@@ -3145,7 +3156,7 @@ by the aggregate.
 ``StdDev``
 ~~~~~~~~~~
 
-.. class:: StdDev(expression, sample=False, **extra)
+.. class:: StdDev(expression, sample=False, filter=None, **extra)
 
     Returns the standard deviation of the data in the provided expression.
 
@@ -3169,7 +3180,7 @@ by the aggregate.
 ``Sum``
 ~~~~~~~
 
-.. class:: Sum(expression, output_field=None, **extra)
+.. class:: Sum(expression, output_field=None, filter=None, **extra)
 
     Computes the sum of all values of the given expression.
 
@@ -3179,7 +3190,7 @@ by the aggregate.
 ``Variance``
 ~~~~~~~~~~~~
 
-.. class:: Variance(expression, sample=False, **extra)
+.. class:: Variance(expression, sample=False, filter=None, **extra)
 
     Returns the variance of the data in the provided expression.
 

+ 4 - 0
docs/releases/2.0.txt

@@ -273,6 +273,10 @@ Models
   parameters, if the backend supports this feature. Of Django's built-in
   backends, only Oracle supports it.
 
+* The new ``filter`` argument for built-in aggregates allows :ref:`adding
+  different conditionals <conditional-aggregation>` to multiple aggregations
+  over the same fields or relations.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 33 - 0
docs/topics/db/aggregation.txt

@@ -84,6 +84,16 @@ In a hurry? Here's how to do common aggregate queries, assuming the models above
     >>> pubs[0].num_books
     73
 
+    # Each publisher, with a separate count of books with a rating above and below 5
+    >>> from django.db.models import Q
+    >>> above_5 = Count('book', filter=Q(book__rating__gt=5))
+    >>> below_5 = Count('book', filter=Q(book__rating__lte=5))
+    >>> pubs = Publisher.objects.annotate(below_5=below_5).annotate(above_5=above_5)
+    >>> pubs[0].above_5
+    23
+    >>> pubs[0].below_5
+    12
+
     # The top 5 publishers, in order by number of books.
     >>> pubs = Publisher.objects.annotate(num_books=Count('book')).order_by('-num_books')[:5]
     >>> pubs[0].num_books
@@ -324,6 +334,8 @@ title that starts with "Django" using the query::
 
     >>> Book.objects.filter(name__startswith="Django").aggregate(Avg('price'))
 
+.. _filtering-on-annotations:
+
 Filtering on annotations
 ~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -339,6 +351,27 @@ you can issue the query::
 This query generates an annotated result set, and then generates a filter
 based upon that annotation.
 
+If you need two annotations with two separate filters you can use the
+``filter`` argument with any aggregate. For example, to generate a list of
+authors with a count of highly rated books::
+
+    >>> highly_rated = Count('books', filter=Q(books__rating__gte=7))
+    >>> Author.objects.annotate(num_books=Count('books'), highly_rated_books=highly_rated)
+
+Each ``Author`` in the result set will have the ``num_books`` and
+``highly_rated_books`` attributes.
+
+.. admonition:: Choosing between ``filter`` and ``QuerySet.filter()``
+
+    Avoid using the ``filter`` argument with a single annotation or
+    aggregation. It's more efficient to use ``QuerySet.filter()`` to exclude
+    rows. The aggregation ``filter`` argument is only useful when using two or
+    more aggregations over the same relations with different conditionals.
+
+.. versionchanged:: 2.0
+
+    The ``filter`` argument was added to aggregates.
+
 Order of ``annotate()`` and ``filter()`` clauses
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

+ 81 - 0
tests/aggregation/test_filter_argument.py

@@ -0,0 +1,81 @@
+import datetime
+from decimal import Decimal
+
+from django.db.models import Case, Count, F, Q, Sum, When
+from django.test import TestCase
+
+from .models import Author, Book, Publisher
+
+
+class FilteredAggregateTests(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        cls.a1 = Author.objects.create(name='test', age=40)
+        cls.a2 = Author.objects.create(name='test2', age=60)
+        cls.a3 = Author.objects.create(name='test3', age=100)
+        cls.p1 = Publisher.objects.create(name='Apress', num_awards=3, duration=datetime.timedelta(days=1))
+        cls.b1 = Book.objects.create(
+            isbn='159059725', name='The Definitive Guide to Django: Web Development Done Right',
+            pages=447, rating=4.5, price=Decimal('30.00'), contact=cls.a1, publisher=cls.p1,
+            pubdate=datetime.date(2007, 12, 6),
+        )
+        cls.b2 = Book.objects.create(
+            isbn='067232959', name='Sams Teach Yourself Django in 24 Hours',
+            pages=528, rating=3.0, price=Decimal('23.09'), contact=cls.a2, publisher=cls.p1,
+            pubdate=datetime.date(2008, 3, 3),
+        )
+        cls.b3 = Book.objects.create(
+            isbn='159059996', name='Practical Django Projects',
+            pages=600, rating=4.5, price=Decimal('29.69'), contact=cls.a3, publisher=cls.p1,
+            pubdate=datetime.date(2008, 6, 23),
+        )
+        cls.a1.friends.add(cls.a2)
+        cls.a1.friends.add(cls.a3)
+        cls.b1.authors.add(cls.a1)
+        cls.b1.authors.add(cls.a3)
+        cls.b2.authors.add(cls.a2)
+        cls.b3.authors.add(cls.a3)
+
+    def test_filtered_aggregates(self):
+        agg = Sum('age', filter=Q(name__startswith='test'))
+        self.assertEqual(Author.objects.aggregate(age=agg)['age'], 200)
+
+    def test_double_filtered_aggregates(self):
+        agg = Sum('age', filter=Q(Q(name='test2') & ~Q(name='test')))
+        self.assertEqual(Author.objects.aggregate(age=agg)['age'], 60)
+
+    def test_excluded_aggregates(self):
+        agg = Sum('age', filter=~Q(name='test2'))
+        self.assertEqual(Author.objects.aggregate(age=agg)['age'], 140)
+
+    def test_related_aggregates_m2m(self):
+        agg = Sum('friends__age', filter=~Q(friends__name='test'))
+        self.assertEqual(Author.objects.filter(name='test').aggregate(age=agg)['age'], 160)
+
+    def test_related_aggregates_m2m_and_fk(self):
+        q = Q(friends__book__publisher__name='Apress') & ~Q(friends__name='test3')
+        agg = Sum('friends__book__pages', filter=q)
+        self.assertEqual(Author.objects.filter(name='test').aggregate(pages=agg)['pages'], 528)
+
+    def test_plain_annotate(self):
+        agg = Sum('book__pages', filter=Q(book__rating__gt=3))
+        qs = Author.objects.annotate(pages=agg).order_by('pk')
+        self.assertSequenceEqual([a.pages for a in qs], [447, None, 1047])
+
+    def test_filtered_aggregate_on_annotate(self):
+        pages_annotate = Sum('book__pages', filter=Q(book__rating__gt=3))
+        age_agg = Sum('age', filter=Q(total_pages__gte=400))
+        aggregated = Author.objects.annotate(total_pages=pages_annotate).aggregate(summed_age=age_agg)
+        self.assertEqual(aggregated, {'summed_age': 140})
+
+    def test_case_aggregate(self):
+        agg = Sum(
+            Case(When(friends__age=40, then=F('friends__age'))),
+            filter=Q(friends__name__startswith='test'),
+        )
+        self.assertEqual(Author.objects.aggregate(age=agg)['age'], 80)
+
+    def test_sum_star_exception(self):
+        msg = 'Star cannot be used with filter. Please specify a field.'
+        with self.assertRaisesMessage(ValueError, msg):
+            Count('*', filter=Q(age=40))

+ 14 - 1
tests/expressions/tests.py

@@ -5,7 +5,7 @@ from copy import deepcopy
 
 from django.core.exceptions import FieldError
 from django.db import DatabaseError, connection, models, transaction
-from django.db.models import CharField, TimeField, UUIDField
+from django.db.models import CharField, Q, TimeField, UUIDField
 from django.db.models.aggregates import (
     Avg, Count, Max, Min, StdDev, Sum, Variance,
 )
@@ -1369,3 +1369,16 @@ class ReprTests(TestCase):
         self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
         self.assertEqual(repr(Sum('a')), "Sum(F(a))")
         self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")
+
+    def test_filtered_aggregates(self):
+        filter = Q(a=1)
+        self.assertEqual(repr(Avg('a', filter=filter)), "Avg(F(a), filter=(AND: ('a', 1)))")
+        self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), distinct=False, filter=(AND: ('a', 1)))")
+        self.assertEqual(repr(Max('a', filter=filter)), "Max(F(a), filter=(AND: ('a', 1)))")
+        self.assertEqual(repr(Min('a', filter=filter)), "Min(F(a), filter=(AND: ('a', 1)))")
+        self.assertEqual(repr(StdDev('a', filter=filter)), "StdDev(F(a), filter=(AND: ('a', 1)), sample=False)")
+        self.assertEqual(repr(Sum('a', filter=filter)), "Sum(F(a), filter=(AND: ('a', 1)))")
+        self.assertEqual(
+            repr(Variance('a', sample=True, filter=filter)),
+            "Variance(F(a), filter=(AND: ('a', 1)), sample=True)"
+        )

+ 9 - 0
tests/expressions_case/tests.py

@@ -1253,6 +1253,15 @@ class CaseDocumentationExamples(TestCase):
             account_type=Client.PLATINUM,
             registered_on=date.today(),
         )
+        self.assertEqual(
+            Client.objects.aggregate(
+                regular=models.Count('pk', filter=Q(account_type=Client.REGULAR)),
+                gold=models.Count('pk', filter=Q(account_type=Client.GOLD)),
+                platinum=models.Count('pk', filter=Q(account_type=Client.PLATINUM)),
+            ),
+            {'regular': 2, 'gold': 1, 'platinum': 3}
+        )
+        # This was the example before the filter argument was added.
         self.assertEqual(
             Client.objects.aggregate(
                 regular=models.Sum(Case(