浏览代码

Refs #30158 -- Added alias argument to Expression.get_group_by_cols().

Simon Charette 6 年之前
父节点
当前提交
9dc367dc10

+ 1 - 1
django/db/models/aggregates.py

@@ -64,7 +64,7 @@ class Aggregate(Func):
             return '%s__%s' % (expressions[0].name, self.name.lower())
         raise TypeError("Complex expressions require an alias")
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return []
 
     def as_sql(self, compiler, connection, **extra_context):

+ 10 - 10
django/db/models/expressions.py

@@ -332,7 +332,7 @@ class BaseExpression:
     def copy(self):
         return copy.copy(self)
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         if not self.contains_aggregate:
             return [self]
         cols = []
@@ -669,7 +669,7 @@ class Value(Expression):
         c.for_save = for_save
         return c
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return []
 
 
@@ -694,7 +694,7 @@ class RawSQL(Expression):
     def as_sql(self, compiler, connection):
         return '(%s)' % self.sql, self.params
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return [self]
 
 
@@ -737,7 +737,7 @@ class Col(Expression):
     def relabeled_clone(self, relabels):
         return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field)
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return [self]
 
     def get_db_converters(self, connection):
@@ -769,7 +769,7 @@ class SimpleCol(Expression):
         qn = compiler.quote_name_unless_alias
         return qn(self.target.column), []
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return [self]
 
     def get_db_converters(self, connection):
@@ -810,7 +810,7 @@ class Ref(Expression):
     def as_sql(self, compiler, connection):
         return connection.ops.quote_name(self.refs), []
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return [self]
 
 
@@ -905,7 +905,7 @@ class When(Expression):
         template = template or self.template
         return template % template_params, sql_params
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         # This is not a complete expression and cannot be used in GROUP BY.
         cols = []
         for source in self.get_source_expressions():
@@ -1171,7 +1171,7 @@ class OrderBy(BaseExpression):
             template = 'IF(ISNULL(%(expression)s),0,1), %(expression)s %(ordering)s '
         return self.as_sql(compiler, connection, template=template)
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         cols = []
         for source in self.get_source_expressions():
             cols.extend(source.get_group_by_cols())
@@ -1281,7 +1281,7 @@ class Window(Expression):
     def __repr__(self):
         return '<%s: %s>' % (self.__class__.__name__, self)
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return []
 
 
@@ -1317,7 +1317,7 @@ class WindowFrame(Expression):
     def __repr__(self):
         return '<%s: %s>' % (self.__class__.__name__, self)
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         return []
 
     def __str__(self):

+ 1 - 1
django/db/models/lookups.py

@@ -104,7 +104,7 @@ class Lookup:
             new.rhs = new.rhs.relabeled_clone(relabels)
         return new
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         cols = self.lhs.get_group_by_cols()
         if hasattr(self.rhs, 'get_group_by_cols'):
             cols.extend(self.rhs.get_group_by_cols())

+ 17 - 3
django/db/models/sql/query.py

@@ -8,6 +8,8 @@ all about the internals of models in order to get the information it needs.
 """
 import difflib
 import functools
+import inspect
+import warnings
 from collections import Counter, namedtuple
 from collections.abc import Iterator, Mapping
 from itertools import chain, count, product
@@ -35,6 +37,7 @@ from django.db.models.sql.datastructures import (
 from django.db.models.sql.where import (
     AND, OR, ExtraWhere, NothingNode, WhereNode,
 )
+from django.utils.deprecation import RemovedInDjango40Warning
 from django.utils.functional import cached_property
 from django.utils.tree import Node
 
@@ -1818,9 +1821,20 @@ class Query:
         """
         group_by = list(self.select)
         if self.annotation_select:
-            for annotation in self.annotation_select.values():
-                for col in annotation.get_group_by_cols():
-                    group_by.append(col)
+            for alias, annotation in self.annotation_select.items():
+                try:
+                    inspect.getcallargs(annotation.get_group_by_cols, alias=alias)
+                except TypeError:
+                    annotation_class = annotation.__class__
+                    msg = (
+                        '`alias=None` must be added to the signature of '
+                        '%s.%s.get_group_by_cols().'
+                    ) % (annotation_class.__module__, annotation_class.__qualname__)
+                    warnings.warn(msg, category=RemovedInDjango40Warning)
+                    group_by_cols = annotation.get_group_by_cols()
+                else:
+                    group_by_cols = annotation.get_group_by_cols(alias=alias)
+                group_by.extend(group_by_cols)
         self.group_by = tuple(group_by)
 
     def add_select_related(self, fields):

+ 1 - 1
django/db/models/sql/where.py

@@ -114,7 +114,7 @@ class WhereNode(tree.Node):
                 sql_string = '(%s)' % sql_string
         return sql_string, result_params
 
-    def get_group_by_cols(self):
+    def get_group_by_cols(self, alias=None):
         cols = []
         for child in self.children:
             cols.extend(child.get_group_by_cols())

+ 3 - 0
docs/internals/deprecation.txt

@@ -27,6 +27,9 @@ details on these changes.
 * ``django.views.i18n.set_language()`` will no longer set the user language in
   ``request.session`` (key ``django.utils.translation.LANGUAGE_SESSION_KEY``).
 
+* ``alias=None`` will be required in the signature of
+  ``django.db.models.Expression.get_group_by_cols()`` subclasses.
+
 .. _deprecation-removed-in-3.1:
 
 3.1

+ 7 - 2
docs/ref/models/expressions.txt

@@ -974,12 +974,17 @@ calling the appropriate methods on the wrapped expression.
         A hook allowing the expression to coerce ``value`` into a more
         appropriate type.
 
-    .. method:: get_group_by_cols()
+    .. method:: get_group_by_cols(alias=None)
 
         Responsible for returning the list of columns references by
         this expression. ``get_group_by_cols()`` should be called on any
         nested expressions. ``F()`` objects, in particular, hold a reference
-        to a column.
+        to a column. The ``alias`` parameter will be ``None`` unless the
+        expression has been annotated and is used for grouping.
+
+        .. versionchanged:: 3.0
+
+            The ``alias`` parameter was added.
 
     .. method:: asc(nulls_first=False, nulls_last=False)
 

+ 3 - 0
docs/releases/3.0.txt

@@ -366,6 +366,9 @@ Miscellaneous
   in the session in Django 4.0. Since Django 2.1, the language is always stored
   in the :setting:`LANGUAGE_COOKIE_NAME` cookie.
 
+* ``alias=None`` is added to the signature of
+  :meth:`.Expression.get_group_by_cols`.
+
 .. _removed-features-3.0:
 
 Features removed in 3.0

+ 24 - 0
tests/expressions/test_deprecation.py

@@ -0,0 +1,24 @@
+from django.db.models import Count, Func
+from django.test import SimpleTestCase
+from django.utils.deprecation import RemovedInDjango40Warning
+
+from .models import Employee
+
+
+class MissingAliasFunc(Func):
+    template = '1'
+
+    def get_group_by_cols(self):
+        return []
+
+
+class GetGroupByColsTest(SimpleTestCase):
+    def test_missing_alias(self):
+        msg = (
+            '`alias=None` must be added to the signature of '
+            'expressions.test_deprecation.MissingAliasFunc.get_group_by_cols().'
+        )
+        with self.assertRaisesMessage(RemovedInDjango40Warning, msg):
+            Employee.objects.values(
+                one=MissingAliasFunc(),
+            ).annotate(cnt=Count('company_ceo_set'))