|
@@ -1276,6 +1276,8 @@ class AggregateTestCase(TestCase):
|
|
|
Book.objects.annotate(Max("id")).annotate(Sum("id__max"))
|
|
|
|
|
|
class MyMax(Max):
|
|
|
+ arity = None
|
|
|
+
|
|
|
def as_sql(self, compiler, connection):
|
|
|
self.set_source_expressions(self.get_source_expressions()[0:1])
|
|
|
return super().as_sql(compiler, connection)
|
|
@@ -1288,6 +1290,7 @@ class AggregateTestCase(TestCase):
|
|
|
def test_multi_arg_aggregate(self):
|
|
|
class MyMax(Max):
|
|
|
output_field = DecimalField()
|
|
|
+ arity = None
|
|
|
|
|
|
def as_sql(self, compiler, connection):
|
|
|
copy = self.copy()
|
|
@@ -2178,6 +2181,27 @@ class AggregateTestCase(TestCase):
|
|
|
)
|
|
|
self.assertEqual(list(author_qs), [337])
|
|
|
|
|
|
+ def test_aggregate_arity(self):
|
|
|
+ funcs_with_inherited_constructors = [Avg, Max, Min, Sum]
|
|
|
+ msg = "takes exactly 1 argument (2 given)"
|
|
|
+ for function in funcs_with_inherited_constructors:
|
|
|
+ with (
|
|
|
+ self.subTest(function=function),
|
|
|
+ self.assertRaisesMessage(TypeError, msg),
|
|
|
+ ):
|
|
|
+ function(Value(1), Value(2))
|
|
|
+
|
|
|
+ funcs_with_custom_constructors = [Count, StdDev, Variance]
|
|
|
+ for function in funcs_with_custom_constructors:
|
|
|
+ with self.subTest(function=function):
|
|
|
+ # Extra arguments are rejected via the constructor.
|
|
|
+ with self.assertRaises(TypeError):
|
|
|
+ function(Value(1), True, Value(2))
|
|
|
+ # If the constructor is skipped, the arity check runs.
|
|
|
+ func_instance = function(Value(1), True)
|
|
|
+ with self.assertRaisesMessage(TypeError, msg):
|
|
|
+ super(function, func_instance).__init__(Value(1), Value(2))
|
|
|
+
|
|
|
|
|
|
class AggregateAnnotationPruningTests(TestCase):
|
|
|
@classmethod
|