Sfoglia il codice sorgente

Refs #33304 -- Enclosed aggregate ordering logic in an expression.

This greatly simplifies the implementation of contrib.postgres'
OrderableAggMixin and allows for reuse in Window expressions.
Simon Charette 3 anni fa
parent
commit
e06dc4571e

+ 14 - 35
django/contrib/postgres/aggregates/mixins.py

@@ -1,48 +1,27 @@
-from django.db.models import F, OrderBy
+from django.db.models.expressions import OrderByList
 
 
 class OrderableAggMixin:
 
     def __init__(self, *expressions, ordering=(), **extra):
-        if not isinstance(ordering, (list, tuple)):
-            ordering = [ordering]
-        ordering = ordering or []
-        # Transform minus sign prefixed strings into an OrderBy() expression.
-        ordering = (
-            (OrderBy(F(o[1:]), descending=True) if isinstance(o, str) and o[0] == '-' else o)
-            for o in ordering
-        )
+        if isinstance(ordering, (list, tuple)):
+            self.order_by = OrderByList(*ordering)
+        else:
+            self.order_by = OrderByList(ordering)
         super().__init__(*expressions, **extra)
-        self.ordering = self._parse_expressions(*ordering)
 
     def resolve_expression(self, *args, **kwargs):
-        self.ordering = [expr.resolve_expression(*args, **kwargs) for expr in self.ordering]
+        self.order_by = self.order_by.resolve_expression(*args, **kwargs)
         return super().resolve_expression(*args, **kwargs)
 
-    def as_sql(self, compiler, connection):
-        if self.ordering:
-            ordering_params = []
-            ordering_expr_sql = []
-            for expr in self.ordering:
-                expr_sql, expr_params = compiler.compile(expr)
-                ordering_expr_sql.append(expr_sql)
-                ordering_params.extend(expr_params)
-            sql, sql_params = super().as_sql(compiler, connection, ordering=(
-                'ORDER BY ' + ', '.join(ordering_expr_sql)
-            ))
-            return sql, (*sql_params, *ordering_params)
-        return super().as_sql(compiler, connection, ordering='')
+    def get_source_expressions(self):
+        return super().get_source_expressions() + [self.order_by]
 
     def set_source_expressions(self, exprs):
-        # Extract the ordering expressions because ORDER BY clause is handled
-        # in a custom way.
-        self.ordering = exprs[self._get_ordering_expressions_index():]
-        return super().set_source_expressions(exprs[:self._get_ordering_expressions_index()])
+        *exprs, self.order_by = exprs
+        return super().set_source_expressions(exprs)
 
-    def get_source_expressions(self):
-        return super().get_source_expressions() + self.ordering
-
-    def _get_ordering_expressions_index(self):
-        """Return the index at which the ordering expressions start."""
-        source_expressions = self.get_source_expressions()
-        return len(source_expressions) - len(self.ordering)
+    def as_sql(self, compiler, connection):
+        order_by_sql, order_by_params = compiler.compile(self.order_by)
+        sql, sql_params = super().as_sql(compiler, connection, ordering=order_by_sql)
+        return sql, (*sql_params, *order_by_params)

+ 22 - 2
django/db/models/expressions.py

@@ -915,8 +915,8 @@ class Ref(Expression):
 class ExpressionList(Func):
     """
     An expression containing multiple expressions. Can be used to provide a
-    list of expressions as an argument to another expression, like an
-    ordering clause.
+    list of expressions as an argument to another expression, like a partition
+    clause.
     """
     template = '%(expressions)s'
 
@@ -933,6 +933,26 @@ class ExpressionList(Func):
         return self.as_sql(compiler, connection, **extra_context)
 
 
+class OrderByList(Func):
+    template = 'ORDER BY %(expressions)s'
+
+    def __init__(self, *expressions, **extra):
+        expressions = (
+            (
+                OrderBy(F(expr[1:]), descending=True)
+                if isinstance(expr, str) and expr[0] == '-'
+                else expr
+            )
+            for expr in expressions
+        )
+        super().__init__(*expressions, **extra)
+
+    def as_sql(self, *args, **kwargs):
+        if not self.source_expressions:
+            return '', ()
+        return super().as_sql(*args, **kwargs)
+
+
 class ExpressionWrapper(SQLiteNumericMixin, Expression):
     """
     An expression that can wrap another expression so that it can provide