Browse Source

Fixed #26617 -- Added distinct argument to contrib.postgres's StringAgg.

Rustam Kashapov 8 years ago
parent
commit
df8412d2e5

+ 4 - 3
django/contrib/postgres/aggregates/general.py

@@ -32,10 +32,11 @@ class BoolOr(Aggregate):
 
 class StringAgg(Aggregate):
     function = 'STRING_AGG'
-    template = "%(function)s(%(expressions)s, '%(delimiter)s')"
+    template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s')"
 
-    def __init__(self, expression, delimiter, **extra):
-        super(StringAgg, self).__init__(expression, delimiter=delimiter, **extra)
+    def __init__(self, expression, delimiter, distinct=False, **extra):
+        distinct = 'DISTINCT ' if distinct else ''
+        super(StringAgg, self).__init__(expression, delimiter=delimiter, distinct=distinct, **extra)
 
     def convert_value(self, value, expression, connection, context):
         if not value:

+ 8 - 1
docs/ref/contrib/postgres/aggregates.txt

@@ -61,7 +61,7 @@ General-purpose aggregation functions
 ``StringAgg``
 -------------
 
-.. class:: StringAgg(expression, delimiter)
+.. class:: StringAgg(expression, delimiter, distinct=False)
 
     Returns the input values concatenated into a string, separated by
     the ``delimiter`` string.
@@ -70,6 +70,13 @@ General-purpose aggregation functions
 
         Required argument. Needs to be a string.
 
+    .. attribute:: distinct
+
+        .. versionadded:: 1.11
+
+        An optional boolean argument that determines if concatenated values
+        will be distinct. Defaults to ``False``.
+
 Aggregate functions for statistics
 ==================================
 

+ 3 - 1
docs/releases/1.11.txt

@@ -81,7 +81,9 @@ Minor features
 :mod:`django.contrib.postgres`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-* ...
+* The new ``distinct`` argument for
+  :class:`~django.contrib.postgres.aggregates.StringAgg` determines if
+  concatenated values will be distinct.
 
 :mod:`django.contrib.redirects`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

+ 18 - 0
tests/postgres_tests/test_aggregates.py

@@ -111,6 +111,24 @@ class TestGeneralAggregate(PostgreSQLTestCase):
         self.assertEqual(values, {'stringagg': ''})
 
 
+class TestStringAggregateDistinct(PostgreSQLTestCase):
+    @classmethod
+    def setUpTestData(cls):
+        AggregateTestModel.objects.create(char_field='Foo')
+        AggregateTestModel.objects.create(char_field='Foo')
+        AggregateTestModel.objects.create(char_field='Bar')
+
+    def test_string_agg_distinct_false(self):
+        values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=' ', distinct=False))
+        self.assertEqual(values['stringagg'].count('Foo'), 2)
+        self.assertEqual(values['stringagg'].count('Bar'), 1)
+
+    def test_string_agg_distinct_true(self):
+        values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=' ', distinct=True))
+        self.assertEqual(values['stringagg'].count('Foo'), 1)
+        self.assertEqual(values['stringagg'].count('Bar'), 1)
+
+
 class TestStatisticsAggregate(PostgreSQLTestCase):
     @classmethod
     def setUpTestData(cls):