Browse Source

Refs #30158 -- Removed alias argument for Expression.get_group_by_cols().

Recent refactors allowed GROUP BY aliasing allowed for aliasing to be
entirely handled by the sql.Query.set_group_by and compiler layers.
Simon Charette 2 years ago
parent
commit
c6350d594c

+ 1 - 1
django/db/models/aggregates.py

@@ -97,7 +97,7 @@ class Aggregate(Func):
             return "%s__%s" % (expressions[0].name, self.name.lower())
         raise TypeError("Complex expressions require an alias")
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         return []
 
     def as_sql(self, compiler, connection, **extra_context):

+ 18 - 20
django/db/models/expressions.py

@@ -417,10 +417,8 @@ class BaseExpression:
         )
         return clone
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         if not self.contains_aggregate:
-            if alias:
-                return [Ref(alias, self)]
             return [self]
         cols = []
         for source in self.get_source_expressions():
@@ -858,7 +856,7 @@ class ResolvedOuterRef(F):
     def relabeled_clone(self, relabels):
         return self
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         return []
 
 
@@ -1021,7 +1019,7 @@ class Value(SQLiteNumericMixin, Expression):
         c.for_save = for_save
         return c
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         return []
 
     def _resolve_output_field(self):
@@ -1066,7 +1064,7 @@ class RawSQL(Expression):
     def as_sql(self, compiler, connection):
         return "(%s)" % self.sql, self.params
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         return [self]
 
     def resolve_expression(
@@ -1124,7 +1122,7 @@ class Col(Expression):
             relabels.get(self.alias, self.alias), self.target, self.output_field
         )
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         return [self]
 
     def get_db_converters(self, connection):
@@ -1167,7 +1165,7 @@ class Ref(Expression):
     def as_sql(self, compiler, connection):
         return connection.ops.quote_name(self.refs), []
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         return [self]
 
 
@@ -1238,14 +1236,14 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression):
     def get_source_expressions(self):
         return [self.expression]
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         if isinstance(self.expression, Expression):
             expression = self.expression.copy()
             expression.output_field = self.output_field
-            return expression.get_group_by_cols(alias=alias)
+            return expression.get_group_by_cols()
         # For non-expressions e.g. an SQL WHERE clause, the entire
         # `expression` must be included in the GROUP BY clause.
-        return super().get_group_by_cols(alias=alias)
+        return super().get_group_by_cols()
 
     def as_sql(self, compiler, connection):
         return compiler.compile(self.expression)
@@ -1330,7 +1328,7 @@ class When(Expression):
             *result_params,
         )
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         # This is not a complete expression and cannot be used in GROUP BY.
         cols = []
         for source in self.get_source_expressions():
@@ -1426,10 +1424,10 @@ class Case(SQLiteNumericMixin, Expression):
             sql = connection.ops.unification_cast_sql(self.output_field) % sql
         return sql, sql_params
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         if not self.cases:
-            return self.default.get_group_by_cols(alias)
-        return super().get_group_by_cols(alias)
+            return self.default.get_group_by_cols()
+        return super().get_group_by_cols()
 
 
 class Subquery(BaseExpression, Combinable):
@@ -1480,8 +1478,8 @@ class Subquery(BaseExpression, Combinable):
         sql = template % template_params
         return sql, sql_params
 
-    def get_group_by_cols(self, alias=None):
-        return self.query.get_group_by_cols(alias=alias, wrapper=self)
+    def get_group_by_cols(self):
+        return self.query.get_group_by_cols(wrapper=self)
 
 
 class Exists(Subquery):
@@ -1602,7 +1600,7 @@ class OrderBy(Expression):
             return copy.as_sql(compiler, connection)
         return self.as_sql(compiler, connection)
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         cols = []
         for source in self.get_source_expressions():
             cols.extend(source.get_group_by_cols())
@@ -1732,7 +1730,7 @@ class Window(SQLiteNumericMixin, Expression):
     def __repr__(self):
         return "<%s: %s>" % (self.__class__.__name__, self)
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         group_by_cols = []
         if self.partition_by:
             group_by_cols.extend(self.partition_by.get_group_by_cols())
@@ -1780,7 +1778,7 @@ class WindowFrame(Expression):
     def __repr__(self):
         return "<%s: %s>" % (self.__class__.__name__, self)
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         return []
 
     def __str__(self):

+ 1 - 1
django/db/models/functions/math.py

@@ -169,7 +169,7 @@ class Random(NumericOutputFieldMixin, Func):
     def as_sqlite(self, compiler, connection, **extra_context):
         return super().as_sql(compiler, connection, function="RAND", **extra_context)
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         return []
 
 

+ 1 - 1
django/db/models/lookups.py

@@ -128,7 +128,7 @@ class Lookup(Expression):
     def rhs_is_direct_value(self):
         return not hasattr(self.rhs, "as_sql")
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         cols = []
         for source in self.get_source_expressions():
             cols.extend(source.get_group_by_cols())

+ 11 - 8
django/db/models/sql/query.py

@@ -1147,13 +1147,11 @@ class Query(BaseExpression):
             if col.alias in self.external_aliases
         ]
 
-    def get_group_by_cols(self, alias=None, wrapper=None):
+    def get_group_by_cols(self, wrapper=None):
         # If wrapper is referenced by an alias for an explicit GROUP BY through
         # values() a reference to this expression and not the self must be
         # returned to ensure external column references are not grouped against
         # as well.
-        if alias:
-            return [Ref(alias, wrapper or self)]
         external_cols = self.get_external_cols()
         if any(col.possibly_multivalued for col in external_cols):
             return [wrapper or self]
@@ -2247,11 +2245,16 @@ class Query(BaseExpression):
                 self.select = tuple(values_select.values())
                 self.values_select = tuple(values_select)
         group_by = list(self.select)
-        if self.annotation_select:
-            for alias, annotation in self.annotation_select.items():
-                if not allow_aliases or alias in column_names:
-                    alias = None
-                group_by_cols = annotation.get_group_by_cols(alias=alias)
+        for alias, annotation in self.annotation_select.items():
+            if not (group_by_cols := annotation.get_group_by_cols()):
+                continue
+            if (
+                allow_aliases
+                and alias not in column_names
+                and not annotation.contains_aggregate
+            ):
+                group_by.append(Ref(alias, annotation))
+            else:
                 group_by.extend(group_by_cols)
         self.group_by = tuple(group_by)
 

+ 1 - 1
django/db/models/sql/where.py

@@ -178,7 +178,7 @@ class WhereNode(tree.Node):
                 sql_string = "(%s)" % sql_string
         return sql_string, result_params
 
-    def get_group_by_cols(self, alias=None):
+    def get_group_by_cols(self):
         cols = []
         for child in self.children:
             cols.extend(child.get_group_by_cols())

+ 6 - 3
docs/ref/models/expressions.txt

@@ -1036,13 +1036,16 @@ calling the appropriate methods on the wrapped expression.
 
         ``expression`` is the same as ``self``.
 
-    .. method:: get_group_by_cols(alias=None)
+    .. method:: get_group_by_cols()
 
         Responsible for returning the list of columns references by
         this expression. ``get_group_by_cols()`` should be called on any
         nested expressions. ``F()`` objects, in particular, hold a reference
-        to a column. The ``alias`` parameter will be ``None`` unless the
-        expression has been annotated and is used for grouping.
+        to a column.
+
+        .. versionchanged:: 4.2
+
+            The ``alias=None`` keyword argument was removed.
 
     .. method:: asc(nulls_first=None, nulls_last=None)
 

+ 2 - 0
docs/releases/4.2.txt

@@ -329,6 +329,8 @@ Miscellaneous
 * The :option:`makemigrations --check` option no longer creates missing
   migration files.
 
+* The ``alias`` argument for :meth:`.Expression.get_group_by_cols` is removed.
+
 .. _deprecated-features-4.2:
 
 Features deprecated in 4.2

+ 2 - 2
tests/expressions/tests.py

@@ -2525,13 +2525,13 @@ class CombinedExpressionTests(SimpleTestCase):
 class ExpressionWrapperTests(SimpleTestCase):
     def test_empty_group_by(self):
         expr = ExpressionWrapper(Value(3), output_field=IntegerField())
-        self.assertEqual(expr.get_group_by_cols(alias=None), [])
+        self.assertEqual(expr.get_group_by_cols(), [])
 
     def test_non_empty_group_by(self):
         value = Value("f")
         value.output_field = None
         expr = ExpressionWrapper(Lower(value), output_field=IntegerField())
-        group_by_cols = expr.get_group_by_cols(alias=None)
+        group_by_cols = expr.get_group_by_cols()
         self.assertEqual(group_by_cols, [expr.expression])
         self.assertEqual(group_by_cols[0].output_field, expr.output_field)