Pārlūkot izejas kodu

Fixed #36051 -- Declared arity on aggregate functions.

Follow-up to 4a66a69239c493c05b322815b18c605cd4c96e7c.
Jacob Walls 2 mēneši atpakaļ
vecāks
revīzija
d206d4c200

+ 7 - 0
django/db/models/aggregates.py

@@ -158,6 +158,7 @@ class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
     function = "AVG"
     name = "Avg"
     allow_distinct = True
+    arity = 1
 
 
 class Count(Aggregate):
@@ -166,6 +167,7 @@ class Count(Aggregate):
     output_field = IntegerField()
     allow_distinct = True
     empty_result_set_value = 0
+    arity = 1
     allows_composite_expressions = True
 
     def __init__(self, expression, filter=None, **extra):
@@ -195,15 +197,18 @@ class Count(Aggregate):
 class Max(Aggregate):
     function = "MAX"
     name = "Max"
+    arity = 1
 
 
 class Min(Aggregate):
     function = "MIN"
     name = "Min"
+    arity = 1
 
 
 class StdDev(NumericOutputFieldMixin, Aggregate):
     name = "StdDev"
+    arity = 1
 
     def __init__(self, expression, sample=False, **extra):
         self.function = "STDDEV_SAMP" if sample else "STDDEV_POP"
@@ -217,10 +222,12 @@ class Sum(FixDurationInputMixin, Aggregate):
     function = "SUM"
     name = "Sum"
     allow_distinct = True
+    arity = 1
 
 
 class Variance(NumericOutputFieldMixin, Aggregate):
     name = "Variance"
+    arity = 1
 
     def __init__(self, expression, sample=False, **extra):
         self.function = "VAR_SAMP" if sample else "VAR_POP"

+ 1 - 0
docs/ref/models/expressions.txt

@@ -516,6 +516,7 @@ generated. Here's a brief example::
         function = "SUM"
         template = "%(function)s(%(all_values)s%(expressions)s)"
         allow_distinct = False
+        arity = 1
 
         def __init__(self, expression, all_values=False, **extra):
             super().__init__(expression, all_values="ALL " if all_values else "", **extra)

+ 4 - 0
docs/releases/5.2.txt

@@ -511,6 +511,10 @@ Miscellaneous
 * The minimum supported version of ``oracledb`` is increased from 1.3.2 to
   2.3.0.
 
+* Built-in aggregate functions accepting only one argument (``Avg``, ``Count``,
+  ``Max``, ``Min``, ``StdDev``, ``Sum``, and ``Variance``) now raise
+  :exc:`TypeError` when called with an incorrect number of arguments.
+
 .. _deprecated-features-5.2:
 
 Features deprecated in 5.2

+ 24 - 0
tests/aggregation/tests.py

@@ -1276,6 +1276,8 @@ class AggregateTestCase(TestCase):
             Book.objects.annotate(Max("id")).annotate(Sum("id__max"))
 
         class MyMax(Max):
+            arity = None
+
             def as_sql(self, compiler, connection):
                 self.set_source_expressions(self.get_source_expressions()[0:1])
                 return super().as_sql(compiler, connection)
@@ -1288,6 +1290,7 @@ class AggregateTestCase(TestCase):
     def test_multi_arg_aggregate(self):
         class MyMax(Max):
             output_field = DecimalField()
+            arity = None
 
             def as_sql(self, compiler, connection):
                 copy = self.copy()
@@ -2178,6 +2181,27 @@ class AggregateTestCase(TestCase):
         )
         self.assertEqual(list(author_qs), [337])
 
+    def test_aggregate_arity(self):
+        funcs_with_inherited_constructors = [Avg, Max, Min, Sum]
+        msg = "takes exactly 1 argument (2 given)"
+        for function in funcs_with_inherited_constructors:
+            with (
+                self.subTest(function=function),
+                self.assertRaisesMessage(TypeError, msg),
+            ):
+                function(Value(1), Value(2))
+
+        funcs_with_custom_constructors = [Count, StdDev, Variance]
+        for function in funcs_with_custom_constructors:
+            with self.subTest(function=function):
+                # Extra arguments are rejected via the constructor.
+                with self.assertRaises(TypeError):
+                    function(Value(1), True, Value(2))
+                # If the constructor is skipped, the arity check runs.
+                func_instance = function(Value(1), True)
+                with self.assertRaisesMessage(TypeError, msg):
+                    super(function, func_instance).__init__(Value(1), Value(2))
+
 
 class AggregateAnnotationPruningTests(TestCase):
     @classmethod