浏览代码

Fixed #35339 -- Fixed PostgreSQL aggregate's filter and order_by params order.

Updated OrderableAggMixin.as_sql() to separate the order_by parameters
from the filter parameters. Previously, the parameters and SQL were
calculated by the Aggregate parent class, resulting in a mixture of
order_by and filter parameters.

Thanks Simon Charette for the review.
Chris Muthig 11 月之前
父节点
当前提交
c8df2f9941
共有 2 个文件被更改,包括 32 次插入7 次删除
  1. 21 6
      django/contrib/postgres/aggregates/mixins.py
  2. 11 1
      tests/postgres_tests/test_aggregates.py

+ 21 - 6
django/contrib/postgres/aggregates/mixins.py

@@ -1,3 +1,4 @@
+from django.core.exceptions import FullResultSet
 from django.db.models.expressions import OrderByList
 
 
@@ -24,9 +25,23 @@ class OrderableAggMixin:
         return super().set_source_expressions(exprs)
 
     def as_sql(self, compiler, connection):
-        if self.order_by is not None:
-            order_by_sql, order_by_params = compiler.compile(self.order_by)
-        else:
-            order_by_sql, order_by_params = "", ()
-        sql, sql_params = super().as_sql(compiler, connection, ordering=order_by_sql)
-        return sql, (*sql_params, *order_by_params)
+        *source_exprs, filtering_expr, ordering_expr = self.get_source_expressions()
+
+        order_by_sql = ""
+        order_by_params = []
+        if ordering_expr is not None:
+            order_by_sql, order_by_params = compiler.compile(ordering_expr)
+
+        filter_params = []
+        if filtering_expr is not None:
+            try:
+                _, filter_params = compiler.compile(filtering_expr)
+            except FullResultSet:
+                pass
+
+        source_params = []
+        for source_expr in source_exprs:
+            source_params += compiler.compile(source_expr)[1]
+
+        sql, _ = super().as_sql(compiler, connection, ordering=order_by_sql)
+        return sql, (*source_params, *order_by_params, *filter_params)

+ 11 - 1
tests/postgres_tests/test_aggregates.py

@@ -12,7 +12,7 @@ from django.db.models import (
     Window,
 )
 from django.db.models.fields.json import KeyTextTransform, KeyTransform
-from django.db.models.functions import Cast, Concat, Substr
+from django.db.models.functions import Cast, Concat, LPad, Substr
 from django.test import skipUnlessDBFeature
 from django.test.utils import Approximate
 from django.utils import timezone
@@ -238,6 +238,16 @@ class TestGeneralAggregate(PostgreSQLTestCase):
         )
         self.assertEqual(values, {"arrayagg": ["en", "pl"]})
 
+    def test_array_agg_filter_and_ordering_params(self):
+        values = AggregateTestModel.objects.aggregate(
+            arrayagg=ArrayAgg(
+                "char_field",
+                filter=Q(json_field__has_key="lang"),
+                ordering=LPad(Cast("integer_field", CharField()), 2, Value("0")),
+            )
+        )
+        self.assertEqual(values, {"arrayagg": ["Foo2", "Foo4"]})
+
     def test_array_agg_filter(self):
         values = AggregateTestModel.objects.aggregate(
             arrayagg=ArrayAgg("integer_field", filter=Q(integer_field__gt=0)),