Browse Source

Refs #30158 -- Added alias argument to Expression.get_group_by_cols().

Simon Charette 6 years ago
parent
commit
9dc367dc10

+ 1 - 1
django/db/models/aggregates.py

@@ -64,7 +64,7 @@ class Aggregate(Func):
             return '%s__%s' % (expressions[0].name, self.name.lower())
             return '%s__%s' % (expressions[0].name, self.name.lower())
         raise TypeError("Complex expressions require an alias")
         raise TypeError("Complex expressions require an alias")
 
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return []
         return []
 
 
     def as_sql(self, compiler, connection, **extra_context):
     def as_sql(self, compiler, connection, **extra_context):

+ 10 - 10
django/db/models/expressions.py

@@ -332,7 +332,7 @@ class BaseExpression:
     def copy(self):
     def copy(self):
         return copy.copy(self)
         return copy.copy(self)
 
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         if not self.contains_aggregate:
         if not self.contains_aggregate:
             return [self]
             return [self]
         cols = []
         cols = []
@@ -669,7 +669,7 @@ class Value(Expression):
         c.for_save = for_save
         c.for_save = for_save
         return c
         return c
 
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return []
         return []
 
 
 
 
@@ -694,7 +694,7 @@ class RawSQL(Expression):
     def as_sql(self, compiler, connection):
     def as_sql(self, compiler, connection):
         return '(%s)' % self.sql, self.params
         return '(%s)' % self.sql, self.params
 
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return [self]
         return [self]
 
 
 
 
@@ -737,7 +737,7 @@ class Col(Expression):
     def relabeled_clone(self, relabels):
     def relabeled_clone(self, relabels):
         return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field)
         return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field)
 
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return [self]
         return [self]
 
 
     def get_db_converters(self, connection):
     def get_db_converters(self, connection):
@@ -769,7 +769,7 @@ class SimpleCol(Expression):
         qn = compiler.quote_name_unless_alias
         qn = compiler.quote_name_unless_alias
         return qn(self.target.column), []
         return qn(self.target.column), []
 
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return [self]
         return [self]
 
 
     def get_db_converters(self, connection):
     def get_db_converters(self, connection):
@@ -810,7 +810,7 @@ class Ref(Expression):
     def as_sql(self, compiler, connection):
     def as_sql(self, compiler, connection):
         return connection.ops.quote_name(self.refs), []
         return connection.ops.quote_name(self.refs), []
 
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return [self]
         return [self]
 
 
 
 
@@ -905,7 +905,7 @@ class When(Expression):
         template = template or self.template
         template = template or self.template
         return template % template_params, sql_params
         return template % template_params, sql_params
 
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         # This is not a complete expression and cannot be used in GROUP BY.
         # This is not a complete expression and cannot be used in GROUP BY.
         cols = []
         cols = []
         for source in self.get_source_expressions():
         for source in self.get_source_expressions():
@@ -1171,7 +1171,7 @@ class OrderBy(BaseExpression):
             template = 'IF(ISNULL(%(expression)s),0,1), %(expression)s %(ordering)s '
             template = 'IF(ISNULL(%(expression)s),0,1), %(expression)s %(ordering)s '
         return self.as_sql(compiler, connection, template=template)
         return self.as_sql(compiler, connection, template=template)
 
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         cols = []
         cols = []
         for source in self.get_source_expressions():
         for source in self.get_source_expressions():
             cols.extend(source.get_group_by_cols())
             cols.extend(source.get_group_by_cols())
@@ -1281,7 +1281,7 @@ class Window(Expression):
     def __repr__(self):
     def __repr__(self):
         return '<%s: %s>' % (self.__class__.__name__, self)
         return '<%s: %s>' % (self.__class__.__name__, self)
 
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return []
         return []
 
 
 
 
@@ -1317,7 +1317,7 @@ class WindowFrame(Expression):
     def __repr__(self):
     def __repr__(self):
         return '<%s: %s>' % (self.__class__.__name__, self)
         return '<%s: %s>' % (self.__class__.__name__, self)
 
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return []
         return []
 
 
     def __str__(self):
     def __str__(self):

+ 1 - 1
django/db/models/lookups.py

@@ -104,7 +104,7 @@ class Lookup:
             new.rhs = new.rhs.relabeled_clone(relabels)
             new.rhs = new.rhs.relabeled_clone(relabels)
         return new
         return new
 
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         cols = self.lhs.get_group_by_cols()
         cols = self.lhs.get_group_by_cols()
         if hasattr(self.rhs, 'get_group_by_cols'):
         if hasattr(self.rhs, 'get_group_by_cols'):
             cols.extend(self.rhs.get_group_by_cols())
             cols.extend(self.rhs.get_group_by_cols())

+ 17 - 3
django/db/models/sql/query.py

@@ -8,6 +8,8 @@ all about the internals of models in order to get the information it needs.
 """
 """
 import difflib
 import difflib
 import functools
 import functools
+import inspect
+import warnings
 from collections import Counter, namedtuple
 from collections import Counter, namedtuple
 from collections.abc import Iterator, Mapping
 from collections.abc import Iterator, Mapping
 from itertools import chain, count, product
 from itertools import chain, count, product
@@ -35,6 +37,7 @@ from django.db.models.sql.datastructures import (
 from django.db.models.sql.where import (
 from django.db.models.sql.where import (
     AND, OR, ExtraWhere, NothingNode, WhereNode,
     AND, OR, ExtraWhere, NothingNode, WhereNode,
 )
 )
+from django.utils.deprecation import RemovedInDjango40Warning
 from django.utils.functional import cached_property
 from django.utils.functional import cached_property
 from django.utils.tree import Node
 from django.utils.tree import Node
 
 
@@ -1818,9 +1821,20 @@ class Query:
         """
         """
         group_by = list(self.select)
         group_by = list(self.select)
         if self.annotation_select:
         if self.annotation_select:
-            for annotation in self.annotation_select.values():
+            for alias, annotation in self.annotation_select.items():
-                for col in annotation.get_group_by_cols():
+                try:
-                    group_by.append(col)
+                    inspect.getcallargs(annotation.get_group_by_cols, alias=alias)
+                except TypeError:
+                    annotation_class = annotation.__class__
+                    msg = (
+                        '`alias=None` must be added to the signature of '
+                        '%s.%s.get_group_by_cols().'
+                    ) % (annotation_class.__module__, annotation_class.__qualname__)
+                    warnings.warn(msg, category=RemovedInDjango40Warning)
+                    group_by_cols = annotation.get_group_by_cols()
+                else:
+                    group_by_cols = annotation.get_group_by_cols(alias=alias)
+                group_by.extend(group_by_cols)
         self.group_by = tuple(group_by)
         self.group_by = tuple(group_by)
 
 
     def add_select_related(self, fields):
     def add_select_related(self, fields):

+ 1 - 1
django/db/models/sql/where.py

@@ -114,7 +114,7 @@ class WhereNode(tree.Node):
                 sql_string = '(%s)' % sql_string
                 sql_string = '(%s)' % sql_string
         return sql_string, result_params
         return sql_string, result_params
 
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         cols = []
         cols = []
         for child in self.children:
         for child in self.children:
             cols.extend(child.get_group_by_cols())
             cols.extend(child.get_group_by_cols())

+ 3 - 0
docs/internals/deprecation.txt

@@ -27,6 +27,9 @@ details on these changes.
 * ``django.views.i18n.set_language()`` will no longer set the user language in
 * ``django.views.i18n.set_language()`` will no longer set the user language in
   ``request.session`` (key ``django.utils.translation.LANGUAGE_SESSION_KEY``).
   ``request.session`` (key ``django.utils.translation.LANGUAGE_SESSION_KEY``).
 
 
+* ``alias=None`` will be required in the signature of
+  ``django.db.models.Expression.get_group_by_cols()`` subclasses.
+
 .. _deprecation-removed-in-3.1:
 .. _deprecation-removed-in-3.1:
 
 
 3.1
 3.1

+ 7 - 2
docs/ref/models/expressions.txt

@@ -974,12 +974,17 @@ calling the appropriate methods on the wrapped expression.
         A hook allowing the expression to coerce ``value`` into a more
         A hook allowing the expression to coerce ``value`` into a more
         appropriate type.
         appropriate type.
 
 
-    .. method:: get_group_by_cols()
+    .. method:: get_group_by_cols(alias=None)
 
 
         Responsible for returning the list of columns references by
         Responsible for returning the list of columns references by
         this expression. ``get_group_by_cols()`` should be called on any
         this expression. ``get_group_by_cols()`` should be called on any
         nested expressions. ``F()`` objects, in particular, hold a reference
         nested expressions. ``F()`` objects, in particular, hold a reference
-        to a column.
+        to a column. The ``alias`` parameter will be ``None`` unless the
+        expression has been annotated and is used for grouping.
+
+        .. versionchanged:: 3.0
+
+            The ``alias`` parameter was added.
 
 
     .. method:: asc(nulls_first=False, nulls_last=False)
     .. method:: asc(nulls_first=False, nulls_last=False)
 
 

+ 3 - 0
docs/releases/3.0.txt

@@ -366,6 +366,9 @@ Miscellaneous
   in the session in Django 4.0. Since Django 2.1, the language is always stored
   in the session in Django 4.0. Since Django 2.1, the language is always stored
   in the :setting:`LANGUAGE_COOKIE_NAME` cookie.
   in the :setting:`LANGUAGE_COOKIE_NAME` cookie.
 
 
+* ``alias=None`` is added to the signature of
+  :meth:`.Expression.get_group_by_cols`.
+
 .. _removed-features-3.0:
 .. _removed-features-3.0:
 
 
 Features removed in 3.0
 Features removed in 3.0

+ 24 - 0
tests/expressions/test_deprecation.py

@@ -0,0 +1,24 @@
+from django.db.models import Count, Func
+from django.test import SimpleTestCase
+from django.utils.deprecation import RemovedInDjango40Warning
+
+from .models import Employee
+
+
+class MissingAliasFunc(Func):
+    template = '1'
+
+    def get_group_by_cols(self):
+        return []
+
+
+class GetGroupByColsTest(SimpleTestCase):
+    def test_missing_alias(self):
+        msg = (
+            '`alias=None` must be added to the signature of '
+            'expressions.test_deprecation.MissingAliasFunc.get_group_by_cols().'
+        )
+        with self.assertRaisesMessage(RemovedInDjango40Warning, msg):
+            Employee.objects.values(
+                one=MissingAliasFunc(),
+            ).annotate(cnt=Count('company_ceo_set'))