test_aggregates.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681
  1. from django.db.models import (
  2. CharField, F, Func, IntegerField, OuterRef, Q, Subquery, Value,
  3. )
  4. from django.db.models.fields.json import KeyTextTransform, KeyTransform
  5. from django.db.models.functions import Cast, Concat, Substr
  6. from django.test.utils import Approximate, ignore_warnings
  7. from django.utils.deprecation import RemovedInDjango50Warning
  8. from . import PostgreSQLTestCase
  9. from .models import AggregateTestModel, StatTestModel
  10. try:
  11. from django.contrib.postgres.aggregates import (
  12. ArrayAgg, BitAnd, BitOr, BoolAnd, BoolOr, Corr, CovarPop, JSONBAgg,
  13. RegrAvgX, RegrAvgY, RegrCount, RegrIntercept, RegrR2, RegrSlope,
  14. RegrSXX, RegrSXY, RegrSYY, StatAggregate, StringAgg,
  15. )
  16. from django.contrib.postgres.fields import ArrayField
  17. except ImportError:
  18. pass # psycopg2 is not installed
  19. class TestGeneralAggregate(PostgreSQLTestCase):
  20. @classmethod
  21. def setUpTestData(cls):
  22. cls.aggs = AggregateTestModel.objects.bulk_create([
  23. AggregateTestModel(boolean_field=True, char_field='Foo1', integer_field=0),
  24. AggregateTestModel(
  25. boolean_field=False,
  26. char_field='Foo2',
  27. integer_field=1,
  28. json_field={'lang': 'pl'},
  29. ),
  30. AggregateTestModel(
  31. boolean_field=False,
  32. char_field='Foo4',
  33. integer_field=2,
  34. json_field={'lang': 'en'},
  35. ),
  36. AggregateTestModel(
  37. boolean_field=True,
  38. char_field='Foo3',
  39. integer_field=0,
  40. json_field={'breed': 'collie'},
  41. ),
  42. ])
  43. @ignore_warnings(category=RemovedInDjango50Warning)
  44. def test_empty_result_set(self):
  45. AggregateTestModel.objects.all().delete()
  46. tests = [
  47. (ArrayAgg('char_field'), []),
  48. (ArrayAgg('integer_field'), []),
  49. (ArrayAgg('boolean_field'), []),
  50. (BitAnd('integer_field'), None),
  51. (BitOr('integer_field'), None),
  52. (BoolAnd('boolean_field'), None),
  53. (BoolOr('boolean_field'), None),
  54. (JSONBAgg('integer_field'), []),
  55. (StringAgg('char_field', delimiter=';'), ''),
  56. ]
  57. for aggregation, expected_result in tests:
  58. with self.subTest(aggregation=aggregation):
  59. # Empty result with non-execution optimization.
  60. with self.assertNumQueries(0):
  61. values = AggregateTestModel.objects.none().aggregate(
  62. aggregation=aggregation,
  63. )
  64. self.assertEqual(values, {'aggregation': expected_result})
  65. # Empty result when query must be executed.
  66. with self.assertNumQueries(1):
  67. values = AggregateTestModel.objects.aggregate(
  68. aggregation=aggregation,
  69. )
  70. self.assertEqual(values, {'aggregation': expected_result})
  71. def test_default_argument(self):
  72. AggregateTestModel.objects.all().delete()
  73. tests = [
  74. (ArrayAgg('char_field', default=['<empty>']), ['<empty>']),
  75. (ArrayAgg('integer_field', default=[0]), [0]),
  76. (ArrayAgg('boolean_field', default=[False]), [False]),
  77. (BitAnd('integer_field', default=0), 0),
  78. (BitOr('integer_field', default=0), 0),
  79. (BoolAnd('boolean_field', default=False), False),
  80. (BoolOr('boolean_field', default=False), False),
  81. (JSONBAgg('integer_field', default=Value('["<empty>"]')), ['<empty>']),
  82. (StringAgg('char_field', delimiter=';', default=Value('<empty>')), '<empty>'),
  83. ]
  84. for aggregation, expected_result in tests:
  85. with self.subTest(aggregation=aggregation):
  86. # Empty result with non-execution optimization.
  87. with self.assertNumQueries(0):
  88. values = AggregateTestModel.objects.none().aggregate(
  89. aggregation=aggregation,
  90. )
  91. self.assertEqual(values, {'aggregation': expected_result})
  92. # Empty result when query must be executed.
  93. with self.assertNumQueries(1):
  94. values = AggregateTestModel.objects.aggregate(
  95. aggregation=aggregation,
  96. )
  97. self.assertEqual(values, {'aggregation': expected_result})
  98. def test_convert_value_deprecation(self):
  99. AggregateTestModel.objects.all().delete()
  100. queryset = AggregateTestModel.objects.all()
  101. with self.assertWarnsMessage(RemovedInDjango50Warning, ArrayAgg.deprecation_msg):
  102. queryset.aggregate(aggregation=ArrayAgg('boolean_field'))
  103. with self.assertWarnsMessage(RemovedInDjango50Warning, JSONBAgg.deprecation_msg):
  104. queryset.aggregate(aggregation=JSONBAgg('integer_field'))
  105. with self.assertWarnsMessage(RemovedInDjango50Warning, StringAgg.deprecation_msg):
  106. queryset.aggregate(aggregation=StringAgg('char_field', delimiter=';'))
  107. # No warnings raised if default argument provided.
  108. self.assertEqual(
  109. queryset.aggregate(aggregation=ArrayAgg('boolean_field', default=None)),
  110. {'aggregation': None},
  111. )
  112. self.assertEqual(
  113. queryset.aggregate(aggregation=JSONBAgg('integer_field', default=None)),
  114. {'aggregation': None},
  115. )
  116. self.assertEqual(
  117. queryset.aggregate(
  118. aggregation=StringAgg('char_field', delimiter=';', default=None),
  119. ),
  120. {'aggregation': None},
  121. )
  122. self.assertEqual(
  123. queryset.aggregate(aggregation=ArrayAgg('boolean_field', default=Value([]))),
  124. {'aggregation': []},
  125. )
  126. self.assertEqual(
  127. queryset.aggregate(aggregation=JSONBAgg('integer_field', default=Value('[]'))),
  128. {'aggregation': []},
  129. )
  130. self.assertEqual(
  131. queryset.aggregate(
  132. aggregation=StringAgg('char_field', delimiter=';', default=Value('')),
  133. ),
  134. {'aggregation': ''},
  135. )
  136. def test_array_agg_charfield(self):
  137. values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field'))
  138. self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']})
  139. def test_array_agg_charfield_ordering(self):
  140. ordering_test_cases = (
  141. (F('char_field').desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
  142. (F('char_field').asc(), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
  143. (F('char_field'), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
  144. ([F('boolean_field'), F('char_field').desc()], ['Foo4', 'Foo2', 'Foo3', 'Foo1']),
  145. ((F('boolean_field'), F('char_field').desc()), ['Foo4', 'Foo2', 'Foo3', 'Foo1']),
  146. ('char_field', ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
  147. ('-char_field', ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
  148. (Concat('char_field', Value('@')), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
  149. (Concat('char_field', Value('@')).desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
  150. (
  151. (Substr('char_field', 1, 1), F('integer_field'), Substr('char_field', 4, 1).desc()),
  152. ['Foo3', 'Foo1', 'Foo2', 'Foo4'],
  153. ),
  154. )
  155. for ordering, expected_output in ordering_test_cases:
  156. with self.subTest(ordering=ordering, expected_output=expected_output):
  157. values = AggregateTestModel.objects.aggregate(
  158. arrayagg=ArrayAgg('char_field', ordering=ordering)
  159. )
  160. self.assertEqual(values, {'arrayagg': expected_output})
  161. def test_array_agg_integerfield(self):
  162. values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('integer_field'))
  163. self.assertEqual(values, {'arrayagg': [0, 1, 2, 0]})
  164. def test_array_agg_integerfield_ordering(self):
  165. values = AggregateTestModel.objects.aggregate(
  166. arrayagg=ArrayAgg('integer_field', ordering=F('integer_field').desc())
  167. )
  168. self.assertEqual(values, {'arrayagg': [2, 1, 0, 0]})
  169. def test_array_agg_booleanfield(self):
  170. values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('boolean_field'))
  171. self.assertEqual(values, {'arrayagg': [True, False, False, True]})
  172. def test_array_agg_booleanfield_ordering(self):
  173. ordering_test_cases = (
  174. (F('boolean_field').asc(), [False, False, True, True]),
  175. (F('boolean_field').desc(), [True, True, False, False]),
  176. (F('boolean_field'), [False, False, True, True]),
  177. )
  178. for ordering, expected_output in ordering_test_cases:
  179. with self.subTest(ordering=ordering, expected_output=expected_output):
  180. values = AggregateTestModel.objects.aggregate(
  181. arrayagg=ArrayAgg('boolean_field', ordering=ordering)
  182. )
  183. self.assertEqual(values, {'arrayagg': expected_output})
  184. def test_array_agg_jsonfield(self):
  185. values = AggregateTestModel.objects.aggregate(
  186. arrayagg=ArrayAgg(
  187. KeyTransform('lang', 'json_field'),
  188. filter=Q(json_field__lang__isnull=False),
  189. ),
  190. )
  191. self.assertEqual(values, {'arrayagg': ['pl', 'en']})
  192. def test_array_agg_jsonfield_ordering(self):
  193. values = AggregateTestModel.objects.aggregate(
  194. arrayagg=ArrayAgg(
  195. KeyTransform('lang', 'json_field'),
  196. filter=Q(json_field__lang__isnull=False),
  197. ordering=KeyTransform('lang', 'json_field'),
  198. ),
  199. )
  200. self.assertEqual(values, {'arrayagg': ['en', 'pl']})
  201. def test_array_agg_filter(self):
  202. values = AggregateTestModel.objects.aggregate(
  203. arrayagg=ArrayAgg('integer_field', filter=Q(integer_field__gt=0)),
  204. )
  205. self.assertEqual(values, {'arrayagg': [1, 2]})
  206. def test_array_agg_lookups(self):
  207. aggr1 = AggregateTestModel.objects.create()
  208. aggr2 = AggregateTestModel.objects.create()
  209. StatTestModel.objects.create(related_field=aggr1, int1=1, int2=0)
  210. StatTestModel.objects.create(related_field=aggr1, int1=2, int2=0)
  211. StatTestModel.objects.create(related_field=aggr2, int1=3, int2=0)
  212. StatTestModel.objects.create(related_field=aggr2, int1=4, int2=0)
  213. qs = StatTestModel.objects.values('related_field').annotate(
  214. array=ArrayAgg('int1')
  215. ).filter(array__overlap=[2]).values_list('array', flat=True)
  216. self.assertCountEqual(qs.get(), [1, 2])
  217. def test_bit_and_general(self):
  218. values = AggregateTestModel.objects.filter(
  219. integer_field__in=[0, 1]).aggregate(bitand=BitAnd('integer_field'))
  220. self.assertEqual(values, {'bitand': 0})
  221. def test_bit_and_on_only_true_values(self):
  222. values = AggregateTestModel.objects.filter(
  223. integer_field=1).aggregate(bitand=BitAnd('integer_field'))
  224. self.assertEqual(values, {'bitand': 1})
  225. def test_bit_and_on_only_false_values(self):
  226. values = AggregateTestModel.objects.filter(
  227. integer_field=0).aggregate(bitand=BitAnd('integer_field'))
  228. self.assertEqual(values, {'bitand': 0})
  229. def test_bit_or_general(self):
  230. values = AggregateTestModel.objects.filter(
  231. integer_field__in=[0, 1]).aggregate(bitor=BitOr('integer_field'))
  232. self.assertEqual(values, {'bitor': 1})
  233. def test_bit_or_on_only_true_values(self):
  234. values = AggregateTestModel.objects.filter(
  235. integer_field=1).aggregate(bitor=BitOr('integer_field'))
  236. self.assertEqual(values, {'bitor': 1})
  237. def test_bit_or_on_only_false_values(self):
  238. values = AggregateTestModel.objects.filter(
  239. integer_field=0).aggregate(bitor=BitOr('integer_field'))
  240. self.assertEqual(values, {'bitor': 0})
  241. def test_bool_and_general(self):
  242. values = AggregateTestModel.objects.aggregate(booland=BoolAnd('boolean_field'))
  243. self.assertEqual(values, {'booland': False})
  244. def test_bool_and_q_object(self):
  245. values = AggregateTestModel.objects.aggregate(
  246. booland=BoolAnd(Q(integer_field__gt=2)),
  247. )
  248. self.assertEqual(values, {'booland': False})
  249. def test_bool_or_general(self):
  250. values = AggregateTestModel.objects.aggregate(boolor=BoolOr('boolean_field'))
  251. self.assertEqual(values, {'boolor': True})
  252. def test_bool_or_q_object(self):
  253. values = AggregateTestModel.objects.aggregate(
  254. boolor=BoolOr(Q(integer_field__gt=2)),
  255. )
  256. self.assertEqual(values, {'boolor': False})
  257. def test_string_agg_requires_delimiter(self):
  258. with self.assertRaises(TypeError):
  259. AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field'))
  260. def test_string_agg_delimiter_escaping(self):
  261. values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter="'"))
  262. self.assertEqual(values, {'stringagg': "Foo1'Foo2'Foo4'Foo3"})
  263. def test_string_agg_charfield(self):
  264. values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';'))
  265. self.assertEqual(values, {'stringagg': 'Foo1;Foo2;Foo4;Foo3'})
  266. def test_string_agg_charfield_ordering(self):
  267. ordering_test_cases = (
  268. (F('char_field').desc(), 'Foo4;Foo3;Foo2;Foo1'),
  269. (F('char_field').asc(), 'Foo1;Foo2;Foo3;Foo4'),
  270. (F('char_field'), 'Foo1;Foo2;Foo3;Foo4'),
  271. ('char_field', 'Foo1;Foo2;Foo3;Foo4'),
  272. ('-char_field', 'Foo4;Foo3;Foo2;Foo1'),
  273. (Concat('char_field', Value('@')), 'Foo1;Foo2;Foo3;Foo4'),
  274. (Concat('char_field', Value('@')).desc(), 'Foo4;Foo3;Foo2;Foo1'),
  275. )
  276. for ordering, expected_output in ordering_test_cases:
  277. with self.subTest(ordering=ordering, expected_output=expected_output):
  278. values = AggregateTestModel.objects.aggregate(
  279. stringagg=StringAgg('char_field', delimiter=';', ordering=ordering)
  280. )
  281. self.assertEqual(values, {'stringagg': expected_output})
  282. def test_string_agg_jsonfield_ordering(self):
  283. values = AggregateTestModel.objects.aggregate(
  284. stringagg=StringAgg(
  285. KeyTextTransform('lang', 'json_field'),
  286. delimiter=';',
  287. ordering=KeyTextTransform('lang', 'json_field'),
  288. output_field=CharField(),
  289. ),
  290. )
  291. self.assertEqual(values, {'stringagg': 'en;pl'})
  292. def test_string_agg_filter(self):
  293. values = AggregateTestModel.objects.aggregate(
  294. stringagg=StringAgg(
  295. 'char_field',
  296. delimiter=';',
  297. filter=Q(char_field__endswith='3') | Q(char_field__endswith='1'),
  298. )
  299. )
  300. self.assertEqual(values, {'stringagg': 'Foo1;Foo3'})
  301. def test_orderable_agg_alternative_fields(self):
  302. values = AggregateTestModel.objects.aggregate(
  303. arrayagg=ArrayAgg('integer_field', ordering=F('char_field').asc())
  304. )
  305. self.assertEqual(values, {'arrayagg': [0, 1, 0, 2]})
  306. def test_jsonb_agg(self):
  307. values = AggregateTestModel.objects.aggregate(jsonbagg=JSONBAgg('char_field'))
  308. self.assertEqual(values, {'jsonbagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']})
  309. def test_jsonb_agg_charfield_ordering(self):
  310. ordering_test_cases = (
  311. (F('char_field').desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
  312. (F('char_field').asc(), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
  313. (F('char_field'), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
  314. ('char_field', ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
  315. ('-char_field', ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
  316. (Concat('char_field', Value('@')), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
  317. (Concat('char_field', Value('@')).desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
  318. )
  319. for ordering, expected_output in ordering_test_cases:
  320. with self.subTest(ordering=ordering, expected_output=expected_output):
  321. values = AggregateTestModel.objects.aggregate(
  322. jsonbagg=JSONBAgg('char_field', ordering=ordering),
  323. )
  324. self.assertEqual(values, {'jsonbagg': expected_output})
  325. def test_jsonb_agg_integerfield_ordering(self):
  326. values = AggregateTestModel.objects.aggregate(
  327. jsonbagg=JSONBAgg('integer_field', ordering=F('integer_field').desc()),
  328. )
  329. self.assertEqual(values, {'jsonbagg': [2, 1, 0, 0]})
  330. def test_jsonb_agg_booleanfield_ordering(self):
  331. ordering_test_cases = (
  332. (F('boolean_field').asc(), [False, False, True, True]),
  333. (F('boolean_field').desc(), [True, True, False, False]),
  334. (F('boolean_field'), [False, False, True, True]),
  335. )
  336. for ordering, expected_output in ordering_test_cases:
  337. with self.subTest(ordering=ordering, expected_output=expected_output):
  338. values = AggregateTestModel.objects.aggregate(
  339. jsonbagg=JSONBAgg('boolean_field', ordering=ordering),
  340. )
  341. self.assertEqual(values, {'jsonbagg': expected_output})
  342. def test_jsonb_agg_jsonfield_ordering(self):
  343. values = AggregateTestModel.objects.aggregate(
  344. jsonbagg=JSONBAgg(
  345. KeyTransform('lang', 'json_field'),
  346. filter=Q(json_field__lang__isnull=False),
  347. ordering=KeyTransform('lang', 'json_field'),
  348. ),
  349. )
  350. self.assertEqual(values, {'jsonbagg': ['en', 'pl']})
  351. def test_string_agg_array_agg_ordering_in_subquery(self):
  352. stats = []
  353. for i, agg in enumerate(AggregateTestModel.objects.order_by('char_field')):
  354. stats.append(StatTestModel(related_field=agg, int1=i, int2=i + 1))
  355. stats.append(StatTestModel(related_field=agg, int1=i + 1, int2=i))
  356. StatTestModel.objects.bulk_create(stats)
  357. for aggregate, expected_result in (
  358. (
  359. ArrayAgg('stattestmodel__int1', ordering='-stattestmodel__int2'),
  360. [('Foo1', [0, 1]), ('Foo2', [1, 2]), ('Foo3', [2, 3]), ('Foo4', [3, 4])],
  361. ),
  362. (
  363. StringAgg(
  364. Cast('stattestmodel__int1', CharField()),
  365. delimiter=';',
  366. ordering='-stattestmodel__int2',
  367. ),
  368. [('Foo1', '0;1'), ('Foo2', '1;2'), ('Foo3', '2;3'), ('Foo4', '3;4')],
  369. ),
  370. ):
  371. with self.subTest(aggregate=aggregate.__class__.__name__):
  372. subquery = AggregateTestModel.objects.filter(
  373. pk=OuterRef('pk'),
  374. ).annotate(agg=aggregate).values('agg')
  375. values = AggregateTestModel.objects.annotate(
  376. agg=Subquery(subquery),
  377. ).order_by('char_field').values_list('char_field', 'agg')
  378. self.assertEqual(list(values), expected_result)
  379. def test_string_agg_array_agg_filter_in_subquery(self):
  380. StatTestModel.objects.bulk_create([
  381. StatTestModel(related_field=self.aggs[0], int1=0, int2=5),
  382. StatTestModel(related_field=self.aggs[0], int1=1, int2=4),
  383. StatTestModel(related_field=self.aggs[0], int1=2, int2=3),
  384. ])
  385. for aggregate, expected_result in (
  386. (
  387. ArrayAgg('stattestmodel__int1', filter=Q(stattestmodel__int2__gt=3)),
  388. [('Foo1', [0, 1]), ('Foo2', None)],
  389. ),
  390. (
  391. StringAgg(
  392. Cast('stattestmodel__int2', CharField()),
  393. delimiter=';',
  394. filter=Q(stattestmodel__int1__lt=2),
  395. ),
  396. [('Foo1', '5;4'), ('Foo2', None)],
  397. ),
  398. ):
  399. with self.subTest(aggregate=aggregate.__class__.__name__):
  400. subquery = AggregateTestModel.objects.filter(
  401. pk=OuterRef('pk'),
  402. ).annotate(agg=aggregate).values('agg')
  403. values = AggregateTestModel.objects.annotate(
  404. agg=Subquery(subquery),
  405. ).filter(
  406. char_field__in=['Foo1', 'Foo2'],
  407. ).order_by('char_field').values_list('char_field', 'agg')
  408. self.assertEqual(list(values), expected_result)
  409. def test_string_agg_filter_in_subquery_with_exclude(self):
  410. subquery = AggregateTestModel.objects.annotate(
  411. stringagg=StringAgg(
  412. 'char_field',
  413. delimiter=';',
  414. filter=Q(char_field__endswith='1'),
  415. )
  416. ).exclude(stringagg='').values('id')
  417. self.assertSequenceEqual(
  418. AggregateTestModel.objects.filter(id__in=Subquery(subquery)),
  419. [self.aggs[0]],
  420. )
  421. def test_ordering_isnt_cleared_for_array_subquery(self):
  422. inner_qs = AggregateTestModel.objects.order_by('-integer_field')
  423. qs = AggregateTestModel.objects.annotate(
  424. integers=Func(
  425. Subquery(inner_qs.values('integer_field')),
  426. function='ARRAY',
  427. output_field=ArrayField(base_field=IntegerField()),
  428. ),
  429. )
  430. self.assertSequenceEqual(
  431. qs.first().integers,
  432. inner_qs.values_list('integer_field', flat=True),
  433. )
  434. class TestAggregateDistinct(PostgreSQLTestCase):
  435. @classmethod
  436. def setUpTestData(cls):
  437. AggregateTestModel.objects.create(char_field='Foo')
  438. AggregateTestModel.objects.create(char_field='Foo')
  439. AggregateTestModel.objects.create(char_field='Bar')
  440. def test_string_agg_distinct_false(self):
  441. values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=' ', distinct=False))
  442. self.assertEqual(values['stringagg'].count('Foo'), 2)
  443. self.assertEqual(values['stringagg'].count('Bar'), 1)
  444. def test_string_agg_distinct_true(self):
  445. values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=' ', distinct=True))
  446. self.assertEqual(values['stringagg'].count('Foo'), 1)
  447. self.assertEqual(values['stringagg'].count('Bar'), 1)
  448. def test_array_agg_distinct_false(self):
  449. values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field', distinct=False))
  450. self.assertEqual(sorted(values['arrayagg']), ['Bar', 'Foo', 'Foo'])
  451. def test_array_agg_distinct_true(self):
  452. values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field', distinct=True))
  453. self.assertEqual(sorted(values['arrayagg']), ['Bar', 'Foo'])
  454. def test_jsonb_agg_distinct_false(self):
  455. values = AggregateTestModel.objects.aggregate(
  456. jsonbagg=JSONBAgg('char_field', distinct=False),
  457. )
  458. self.assertEqual(sorted(values['jsonbagg']), ['Bar', 'Foo', 'Foo'])
  459. def test_jsonb_agg_distinct_true(self):
  460. values = AggregateTestModel.objects.aggregate(
  461. jsonbagg=JSONBAgg('char_field', distinct=True),
  462. )
  463. self.assertEqual(sorted(values['jsonbagg']), ['Bar', 'Foo'])
  464. class TestStatisticsAggregate(PostgreSQLTestCase):
  465. @classmethod
  466. def setUpTestData(cls):
  467. StatTestModel.objects.create(
  468. int1=1,
  469. int2=3,
  470. related_field=AggregateTestModel.objects.create(integer_field=0),
  471. )
  472. StatTestModel.objects.create(
  473. int1=2,
  474. int2=2,
  475. related_field=AggregateTestModel.objects.create(integer_field=1),
  476. )
  477. StatTestModel.objects.create(
  478. int1=3,
  479. int2=1,
  480. related_field=AggregateTestModel.objects.create(integer_field=2),
  481. )
  482. # Tests for base class (StatAggregate)
  483. def test_missing_arguments_raises_exception(self):
  484. with self.assertRaisesMessage(ValueError, 'Both y and x must be provided.'):
  485. StatAggregate(x=None, y=None)
  486. def test_correct_source_expressions(self):
  487. func = StatAggregate(x='test', y=13)
  488. self.assertIsInstance(func.source_expressions[0], Value)
  489. self.assertIsInstance(func.source_expressions[1], F)
  490. def test_alias_is_required(self):
  491. class SomeFunc(StatAggregate):
  492. function = 'TEST'
  493. with self.assertRaisesMessage(TypeError, 'Complex aggregates require an alias'):
  494. StatTestModel.objects.aggregate(SomeFunc(y='int2', x='int1'))
  495. # Test aggregates
  496. def test_empty_result_set(self):
  497. StatTestModel.objects.all().delete()
  498. tests = [
  499. (Corr(y='int2', x='int1'), None),
  500. (CovarPop(y='int2', x='int1'), None),
  501. (CovarPop(y='int2', x='int1', sample=True), None),
  502. (RegrAvgX(y='int2', x='int1'), None),
  503. (RegrAvgY(y='int2', x='int1'), None),
  504. (RegrCount(y='int2', x='int1'), 0),
  505. (RegrIntercept(y='int2', x='int1'), None),
  506. (RegrR2(y='int2', x='int1'), None),
  507. (RegrSlope(y='int2', x='int1'), None),
  508. (RegrSXX(y='int2', x='int1'), None),
  509. (RegrSXY(y='int2', x='int1'), None),
  510. (RegrSYY(y='int2', x='int1'), None),
  511. ]
  512. for aggregation, expected_result in tests:
  513. with self.subTest(aggregation=aggregation):
  514. # Empty result with non-execution optimization.
  515. with self.assertNumQueries(0):
  516. values = StatTestModel.objects.none().aggregate(
  517. aggregation=aggregation,
  518. )
  519. self.assertEqual(values, {'aggregation': expected_result})
  520. # Empty result when query must be executed.
  521. with self.assertNumQueries(1):
  522. values = StatTestModel.objects.aggregate(
  523. aggregation=aggregation,
  524. )
  525. self.assertEqual(values, {'aggregation': expected_result})
  526. def test_default_argument(self):
  527. StatTestModel.objects.all().delete()
  528. tests = [
  529. (Corr(y='int2', x='int1', default=0), 0),
  530. (CovarPop(y='int2', x='int1', default=0), 0),
  531. (CovarPop(y='int2', x='int1', sample=True, default=0), 0),
  532. (RegrAvgX(y='int2', x='int1', default=0), 0),
  533. (RegrAvgY(y='int2', x='int1', default=0), 0),
  534. # RegrCount() doesn't support the default argument.
  535. (RegrIntercept(y='int2', x='int1', default=0), 0),
  536. (RegrR2(y='int2', x='int1', default=0), 0),
  537. (RegrSlope(y='int2', x='int1', default=0), 0),
  538. (RegrSXX(y='int2', x='int1', default=0), 0),
  539. (RegrSXY(y='int2', x='int1', default=0), 0),
  540. (RegrSYY(y='int2', x='int1', default=0), 0),
  541. ]
  542. for aggregation, expected_result in tests:
  543. with self.subTest(aggregation=aggregation):
  544. # Empty result with non-execution optimization.
  545. with self.assertNumQueries(0):
  546. values = StatTestModel.objects.none().aggregate(
  547. aggregation=aggregation,
  548. )
  549. self.assertEqual(values, {'aggregation': expected_result})
  550. # Empty result when query must be executed.
  551. with self.assertNumQueries(1):
  552. values = StatTestModel.objects.aggregate(
  553. aggregation=aggregation,
  554. )
  555. self.assertEqual(values, {'aggregation': expected_result})
  556. def test_corr_general(self):
  557. values = StatTestModel.objects.aggregate(corr=Corr(y='int2', x='int1'))
  558. self.assertEqual(values, {'corr': -1.0})
  559. def test_covar_pop_general(self):
  560. values = StatTestModel.objects.aggregate(covarpop=CovarPop(y='int2', x='int1'))
  561. self.assertEqual(values, {'covarpop': Approximate(-0.66, places=1)})
  562. def test_covar_pop_sample(self):
  563. values = StatTestModel.objects.aggregate(covarpop=CovarPop(y='int2', x='int1', sample=True))
  564. self.assertEqual(values, {'covarpop': -1.0})
  565. def test_regr_avgx_general(self):
  566. values = StatTestModel.objects.aggregate(regravgx=RegrAvgX(y='int2', x='int1'))
  567. self.assertEqual(values, {'regravgx': 2.0})
  568. def test_regr_avgy_general(self):
  569. values = StatTestModel.objects.aggregate(regravgy=RegrAvgY(y='int2', x='int1'))
  570. self.assertEqual(values, {'regravgy': 2.0})
  571. def test_regr_count_general(self):
  572. values = StatTestModel.objects.aggregate(regrcount=RegrCount(y='int2', x='int1'))
  573. self.assertEqual(values, {'regrcount': 3})
  574. def test_regr_count_default(self):
  575. msg = 'RegrCount does not allow default.'
  576. with self.assertRaisesMessage(TypeError, msg):
  577. RegrCount(y='int2', x='int1', default=0)
  578. def test_regr_intercept_general(self):
  579. values = StatTestModel.objects.aggregate(regrintercept=RegrIntercept(y='int2', x='int1'))
  580. self.assertEqual(values, {'regrintercept': 4})
  581. def test_regr_r2_general(self):
  582. values = StatTestModel.objects.aggregate(regrr2=RegrR2(y='int2', x='int1'))
  583. self.assertEqual(values, {'regrr2': 1})
  584. def test_regr_slope_general(self):
  585. values = StatTestModel.objects.aggregate(regrslope=RegrSlope(y='int2', x='int1'))
  586. self.assertEqual(values, {'regrslope': -1})
  587. def test_regr_sxx_general(self):
  588. values = StatTestModel.objects.aggregate(regrsxx=RegrSXX(y='int2', x='int1'))
  589. self.assertEqual(values, {'regrsxx': 2.0})
  590. def test_regr_sxy_general(self):
  591. values = StatTestModel.objects.aggregate(regrsxy=RegrSXY(y='int2', x='int1'))
  592. self.assertEqual(values, {'regrsxy': -2.0})
  593. def test_regr_syy_general(self):
  594. values = StatTestModel.objects.aggregate(regrsyy=RegrSYY(y='int2', x='int1'))
  595. self.assertEqual(values, {'regrsyy': 2.0})
  596. def test_regr_avgx_with_related_obj_and_number_as_argument(self):
  597. """
  598. This is more complex test to check if JOIN on field and
  599. number as argument works as expected.
  600. """
  601. values = StatTestModel.objects.aggregate(complex_regravgx=RegrAvgX(y=5, x='related_field__integer_field'))
  602. self.assertEqual(values, {'complex_regravgx': 1.0})