test_aggregates.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972
  1. from django.db import connection
  2. from django.db.models import (
  3. CharField,
  4. F,
  5. Func,
  6. IntegerField,
  7. JSONField,
  8. OuterRef,
  9. Q,
  10. Subquery,
  11. Value,
  12. Window,
  13. )
  14. from django.db.models.fields.json import KeyTextTransform, KeyTransform
  15. from django.db.models.functions import Cast, Concat, Substr
  16. from django.test import skipUnlessDBFeature
  17. from django.test.utils import Approximate, ignore_warnings
  18. from django.utils import timezone
  19. from django.utils.deprecation import RemovedInDjango50Warning, RemovedInDjango51Warning
  20. from . import PostgreSQLTestCase
  21. from .models import AggregateTestModel, HotelReservation, Room, StatTestModel
  22. try:
  23. from django.contrib.postgres.aggregates import (
  24. ArrayAgg,
  25. BitAnd,
  26. BitOr,
  27. BitXor,
  28. BoolAnd,
  29. BoolOr,
  30. Corr,
  31. CovarPop,
  32. JSONBAgg,
  33. RegrAvgX,
  34. RegrAvgY,
  35. RegrCount,
  36. RegrIntercept,
  37. RegrR2,
  38. RegrSlope,
  39. RegrSXX,
  40. RegrSXY,
  41. RegrSYY,
  42. StatAggregate,
  43. StringAgg,
  44. )
  45. from django.contrib.postgres.fields import ArrayField
  46. except ImportError:
  47. pass # psycopg2 is not installed
  48. class TestGeneralAggregate(PostgreSQLTestCase):
  49. @classmethod
  50. def setUpTestData(cls):
  51. cls.aggs = AggregateTestModel.objects.bulk_create(
  52. [
  53. AggregateTestModel(
  54. boolean_field=True,
  55. char_field="Foo1",
  56. text_field="Text1",
  57. integer_field=0,
  58. ),
  59. AggregateTestModel(
  60. boolean_field=False,
  61. char_field="Foo2",
  62. text_field="Text2",
  63. integer_field=1,
  64. json_field={"lang": "pl"},
  65. ),
  66. AggregateTestModel(
  67. boolean_field=False,
  68. char_field="Foo4",
  69. text_field="Text4",
  70. integer_field=2,
  71. json_field={"lang": "en"},
  72. ),
  73. AggregateTestModel(
  74. boolean_field=True,
  75. char_field="Foo3",
  76. text_field="Text3",
  77. integer_field=0,
  78. json_field={"breed": "collie"},
  79. ),
  80. ]
  81. )
  82. @ignore_warnings(category=RemovedInDjango50Warning)
  83. def test_empty_result_set(self):
  84. AggregateTestModel.objects.all().delete()
  85. tests = [
  86. (ArrayAgg("char_field"), []),
  87. (ArrayAgg("integer_field"), []),
  88. (ArrayAgg("boolean_field"), []),
  89. (BitAnd("integer_field"), None),
  90. (BitOr("integer_field"), None),
  91. (BoolAnd("boolean_field"), None),
  92. (BoolOr("boolean_field"), None),
  93. (JSONBAgg("integer_field"), []),
  94. (StringAgg("char_field", delimiter=";"), ""),
  95. ]
  96. if connection.features.has_bit_xor:
  97. tests.append((BitXor("integer_field"), None))
  98. for aggregation, expected_result in tests:
  99. with self.subTest(aggregation=aggregation):
  100. # Empty result with non-execution optimization.
  101. with self.assertNumQueries(0):
  102. values = AggregateTestModel.objects.none().aggregate(
  103. aggregation=aggregation,
  104. )
  105. self.assertEqual(values, {"aggregation": expected_result})
  106. # Empty result when query must be executed.
  107. with self.assertNumQueries(1):
  108. values = AggregateTestModel.objects.aggregate(
  109. aggregation=aggregation,
  110. )
  111. self.assertEqual(values, {"aggregation": expected_result})
  112. def test_default_argument(self):
  113. AggregateTestModel.objects.all().delete()
  114. tests = [
  115. (ArrayAgg("char_field", default=["<empty>"]), ["<empty>"]),
  116. (ArrayAgg("integer_field", default=[0]), [0]),
  117. (ArrayAgg("boolean_field", default=[False]), [False]),
  118. (BitAnd("integer_field", default=0), 0),
  119. (BitOr("integer_field", default=0), 0),
  120. (BoolAnd("boolean_field", default=False), False),
  121. (BoolOr("boolean_field", default=False), False),
  122. (JSONBAgg("integer_field", default=["<empty>"]), ["<empty>"]),
  123. (
  124. JSONBAgg("integer_field", default=Value(["<empty>"], JSONField())),
  125. ["<empty>"],
  126. ),
  127. (StringAgg("char_field", delimiter=";", default="<empty>"), "<empty>"),
  128. (
  129. StringAgg("char_field", delimiter=";", default=Value("<empty>")),
  130. "<empty>",
  131. ),
  132. ]
  133. if connection.features.has_bit_xor:
  134. tests.append((BitXor("integer_field", default=0), 0))
  135. for aggregation, expected_result in tests:
  136. with self.subTest(aggregation=aggregation):
  137. # Empty result with non-execution optimization.
  138. with self.assertNumQueries(0):
  139. values = AggregateTestModel.objects.none().aggregate(
  140. aggregation=aggregation,
  141. )
  142. self.assertEqual(values, {"aggregation": expected_result})
  143. # Empty result when query must be executed.
  144. with self.assertNumQueries(1):
  145. values = AggregateTestModel.objects.aggregate(
  146. aggregation=aggregation,
  147. )
  148. self.assertEqual(values, {"aggregation": expected_result})
  149. def test_convert_value_deprecation(self):
  150. AggregateTestModel.objects.all().delete()
  151. queryset = AggregateTestModel.objects.all()
  152. with self.assertWarnsMessage(
  153. RemovedInDjango50Warning, ArrayAgg.deprecation_msg
  154. ):
  155. queryset.aggregate(aggregation=ArrayAgg("boolean_field"))
  156. with self.assertWarnsMessage(
  157. RemovedInDjango50Warning, JSONBAgg.deprecation_msg
  158. ):
  159. queryset.aggregate(aggregation=JSONBAgg("integer_field"))
  160. with self.assertWarnsMessage(
  161. RemovedInDjango50Warning, StringAgg.deprecation_msg
  162. ):
  163. queryset.aggregate(aggregation=StringAgg("char_field", delimiter=";"))
  164. # No warnings raised if default argument provided.
  165. self.assertEqual(
  166. queryset.aggregate(aggregation=ArrayAgg("boolean_field", default=None)),
  167. {"aggregation": None},
  168. )
  169. self.assertEqual(
  170. queryset.aggregate(aggregation=JSONBAgg("integer_field", default=None)),
  171. {"aggregation": None},
  172. )
  173. self.assertEqual(
  174. queryset.aggregate(
  175. aggregation=StringAgg("char_field", delimiter=";", default=None),
  176. ),
  177. {"aggregation": None},
  178. )
  179. self.assertEqual(
  180. queryset.aggregate(
  181. aggregation=ArrayAgg("boolean_field", default=Value([]))
  182. ),
  183. {"aggregation": []},
  184. )
  185. self.assertEqual(
  186. queryset.aggregate(aggregation=JSONBAgg("integer_field", default=[])),
  187. {"aggregation": []},
  188. )
  189. self.assertEqual(
  190. queryset.aggregate(
  191. aggregation=StringAgg("char_field", delimiter=";", default=Value("")),
  192. ),
  193. {"aggregation": ""},
  194. )
  195. @ignore_warnings(category=RemovedInDjango51Warning)
  196. def test_jsonb_agg_default_str_value(self):
  197. AggregateTestModel.objects.all().delete()
  198. queryset = AggregateTestModel.objects.all()
  199. self.assertEqual(
  200. queryset.aggregate(
  201. aggregation=JSONBAgg("integer_field", default=Value("<empty>"))
  202. ),
  203. {"aggregation": "<empty>"},
  204. )
  205. def test_jsonb_agg_default_str_value_deprecation(self):
  206. queryset = AggregateTestModel.objects.all()
  207. msg = (
  208. "Passing a Value() with an output_field that isn't a JSONField as "
  209. "JSONBAgg(default) is deprecated. Pass default=Value('<empty>', "
  210. "output_field=JSONField()) instead."
  211. )
  212. with self.assertWarnsMessage(RemovedInDjango51Warning, msg):
  213. queryset.aggregate(
  214. aggregation=JSONBAgg("integer_field", default=Value("<empty>"))
  215. )
  216. with self.assertWarnsMessage(RemovedInDjango51Warning, msg):
  217. queryset.none().aggregate(
  218. aggregation=JSONBAgg("integer_field", default=Value("<empty>"))
  219. ),
  220. @ignore_warnings(category=RemovedInDjango51Warning)
  221. def test_jsonb_agg_default_encoded_json_string(self):
  222. AggregateTestModel.objects.all().delete()
  223. queryset = AggregateTestModel.objects.all()
  224. self.assertEqual(
  225. queryset.aggregate(
  226. aggregation=JSONBAgg("integer_field", default=Value("[]"))
  227. ),
  228. {"aggregation": []},
  229. )
  230. def test_jsonb_agg_default_encoded_json_string_deprecation(self):
  231. queryset = AggregateTestModel.objects.all()
  232. msg = (
  233. "Passing an encoded JSON string as JSONBAgg(default) is deprecated. Pass "
  234. "default=[] instead."
  235. )
  236. with self.assertWarnsMessage(RemovedInDjango51Warning, msg):
  237. queryset.aggregate(
  238. aggregation=JSONBAgg("integer_field", default=Value("[]"))
  239. )
  240. with self.assertWarnsMessage(RemovedInDjango51Warning, msg):
  241. queryset.none().aggregate(
  242. aggregation=JSONBAgg("integer_field", default=Value("[]"))
  243. )
  244. def test_array_agg_charfield(self):
  245. values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg("char_field"))
  246. self.assertEqual(values, {"arrayagg": ["Foo1", "Foo2", "Foo4", "Foo3"]})
  247. def test_array_agg_charfield_ordering(self):
  248. ordering_test_cases = (
  249. (F("char_field").desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]),
  250. (F("char_field").asc(), ["Foo1", "Foo2", "Foo3", "Foo4"]),
  251. (F("char_field"), ["Foo1", "Foo2", "Foo3", "Foo4"]),
  252. (
  253. [F("boolean_field"), F("char_field").desc()],
  254. ["Foo4", "Foo2", "Foo3", "Foo1"],
  255. ),
  256. (
  257. (F("boolean_field"), F("char_field").desc()),
  258. ["Foo4", "Foo2", "Foo3", "Foo1"],
  259. ),
  260. ("char_field", ["Foo1", "Foo2", "Foo3", "Foo4"]),
  261. ("-char_field", ["Foo4", "Foo3", "Foo2", "Foo1"]),
  262. (Concat("char_field", Value("@")), ["Foo1", "Foo2", "Foo3", "Foo4"]),
  263. (Concat("char_field", Value("@")).desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]),
  264. (
  265. (
  266. Substr("char_field", 1, 1),
  267. F("integer_field"),
  268. Substr("char_field", 4, 1).desc(),
  269. ),
  270. ["Foo3", "Foo1", "Foo2", "Foo4"],
  271. ),
  272. )
  273. for ordering, expected_output in ordering_test_cases:
  274. with self.subTest(ordering=ordering, expected_output=expected_output):
  275. values = AggregateTestModel.objects.aggregate(
  276. arrayagg=ArrayAgg("char_field", ordering=ordering)
  277. )
  278. self.assertEqual(values, {"arrayagg": expected_output})
  279. def test_array_agg_integerfield(self):
  280. values = AggregateTestModel.objects.aggregate(
  281. arrayagg=ArrayAgg("integer_field")
  282. )
  283. self.assertEqual(values, {"arrayagg": [0, 1, 2, 0]})
  284. def test_array_agg_integerfield_ordering(self):
  285. values = AggregateTestModel.objects.aggregate(
  286. arrayagg=ArrayAgg("integer_field", ordering=F("integer_field").desc())
  287. )
  288. self.assertEqual(values, {"arrayagg": [2, 1, 0, 0]})
  289. def test_array_agg_booleanfield(self):
  290. values = AggregateTestModel.objects.aggregate(
  291. arrayagg=ArrayAgg("boolean_field")
  292. )
  293. self.assertEqual(values, {"arrayagg": [True, False, False, True]})
  294. def test_array_agg_booleanfield_ordering(self):
  295. ordering_test_cases = (
  296. (F("boolean_field").asc(), [False, False, True, True]),
  297. (F("boolean_field").desc(), [True, True, False, False]),
  298. (F("boolean_field"), [False, False, True, True]),
  299. )
  300. for ordering, expected_output in ordering_test_cases:
  301. with self.subTest(ordering=ordering, expected_output=expected_output):
  302. values = AggregateTestModel.objects.aggregate(
  303. arrayagg=ArrayAgg("boolean_field", ordering=ordering)
  304. )
  305. self.assertEqual(values, {"arrayagg": expected_output})
  306. def test_array_agg_jsonfield(self):
  307. values = AggregateTestModel.objects.aggregate(
  308. arrayagg=ArrayAgg(
  309. KeyTransform("lang", "json_field"),
  310. filter=Q(json_field__lang__isnull=False),
  311. ),
  312. )
  313. self.assertEqual(values, {"arrayagg": ["pl", "en"]})
  314. def test_array_agg_jsonfield_ordering(self):
  315. values = AggregateTestModel.objects.aggregate(
  316. arrayagg=ArrayAgg(
  317. KeyTransform("lang", "json_field"),
  318. filter=Q(json_field__lang__isnull=False),
  319. ordering=KeyTransform("lang", "json_field"),
  320. ),
  321. )
  322. self.assertEqual(values, {"arrayagg": ["en", "pl"]})
  323. def test_array_agg_filter(self):
  324. values = AggregateTestModel.objects.aggregate(
  325. arrayagg=ArrayAgg("integer_field", filter=Q(integer_field__gt=0)),
  326. )
  327. self.assertEqual(values, {"arrayagg": [1, 2]})
  328. def test_array_agg_lookups(self):
  329. aggr1 = AggregateTestModel.objects.create()
  330. aggr2 = AggregateTestModel.objects.create()
  331. StatTestModel.objects.create(related_field=aggr1, int1=1, int2=0)
  332. StatTestModel.objects.create(related_field=aggr1, int1=2, int2=0)
  333. StatTestModel.objects.create(related_field=aggr2, int1=3, int2=0)
  334. StatTestModel.objects.create(related_field=aggr2, int1=4, int2=0)
  335. qs = (
  336. StatTestModel.objects.values("related_field")
  337. .annotate(array=ArrayAgg("int1"))
  338. .filter(array__overlap=[2])
  339. .values_list("array", flat=True)
  340. )
  341. self.assertCountEqual(qs.get(), [1, 2])
  342. def test_bit_and_general(self):
  343. values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate(
  344. bitand=BitAnd("integer_field")
  345. )
  346. self.assertEqual(values, {"bitand": 0})
  347. def test_bit_and_on_only_true_values(self):
  348. values = AggregateTestModel.objects.filter(integer_field=1).aggregate(
  349. bitand=BitAnd("integer_field")
  350. )
  351. self.assertEqual(values, {"bitand": 1})
  352. def test_bit_and_on_only_false_values(self):
  353. values = AggregateTestModel.objects.filter(integer_field=0).aggregate(
  354. bitand=BitAnd("integer_field")
  355. )
  356. self.assertEqual(values, {"bitand": 0})
  357. def test_bit_or_general(self):
  358. values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate(
  359. bitor=BitOr("integer_field")
  360. )
  361. self.assertEqual(values, {"bitor": 1})
  362. def test_bit_or_on_only_true_values(self):
  363. values = AggregateTestModel.objects.filter(integer_field=1).aggregate(
  364. bitor=BitOr("integer_field")
  365. )
  366. self.assertEqual(values, {"bitor": 1})
  367. def test_bit_or_on_only_false_values(self):
  368. values = AggregateTestModel.objects.filter(integer_field=0).aggregate(
  369. bitor=BitOr("integer_field")
  370. )
  371. self.assertEqual(values, {"bitor": 0})
  372. @skipUnlessDBFeature("has_bit_xor")
  373. def test_bit_xor_general(self):
  374. AggregateTestModel.objects.create(integer_field=3)
  375. values = AggregateTestModel.objects.filter(
  376. integer_field__in=[1, 3],
  377. ).aggregate(bitxor=BitXor("integer_field"))
  378. self.assertEqual(values, {"bitxor": 2})
  379. @skipUnlessDBFeature("has_bit_xor")
  380. def test_bit_xor_on_only_true_values(self):
  381. values = AggregateTestModel.objects.filter(
  382. integer_field=1,
  383. ).aggregate(bitxor=BitXor("integer_field"))
  384. self.assertEqual(values, {"bitxor": 1})
  385. @skipUnlessDBFeature("has_bit_xor")
  386. def test_bit_xor_on_only_false_values(self):
  387. values = AggregateTestModel.objects.filter(
  388. integer_field=0,
  389. ).aggregate(bitxor=BitXor("integer_field"))
  390. self.assertEqual(values, {"bitxor": 0})
  391. def test_bool_and_general(self):
  392. values = AggregateTestModel.objects.aggregate(booland=BoolAnd("boolean_field"))
  393. self.assertEqual(values, {"booland": False})
  394. def test_bool_and_q_object(self):
  395. values = AggregateTestModel.objects.aggregate(
  396. booland=BoolAnd(Q(integer_field__gt=2)),
  397. )
  398. self.assertEqual(values, {"booland": False})
  399. def test_bool_or_general(self):
  400. values = AggregateTestModel.objects.aggregate(boolor=BoolOr("boolean_field"))
  401. self.assertEqual(values, {"boolor": True})
  402. def test_bool_or_q_object(self):
  403. values = AggregateTestModel.objects.aggregate(
  404. boolor=BoolOr(Q(integer_field__gt=2)),
  405. )
  406. self.assertEqual(values, {"boolor": False})
  407. def test_string_agg_requires_delimiter(self):
  408. with self.assertRaises(TypeError):
  409. AggregateTestModel.objects.aggregate(stringagg=StringAgg("char_field"))
  410. def test_string_agg_delimiter_escaping(self):
  411. values = AggregateTestModel.objects.aggregate(
  412. stringagg=StringAgg("char_field", delimiter="'")
  413. )
  414. self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
  415. def test_string_agg_charfield(self):
  416. values = AggregateTestModel.objects.aggregate(
  417. stringagg=StringAgg("char_field", delimiter=";")
  418. )
  419. self.assertEqual(values, {"stringagg": "Foo1;Foo2;Foo4;Foo3"})
  420. def test_string_agg_default_output_field(self):
  421. values = AggregateTestModel.objects.aggregate(
  422. stringagg=StringAgg("text_field", delimiter=";"),
  423. )
  424. self.assertEqual(values, {"stringagg": "Text1;Text2;Text4;Text3"})
  425. def test_string_agg_charfield_ordering(self):
  426. ordering_test_cases = (
  427. (F("char_field").desc(), "Foo4;Foo3;Foo2;Foo1"),
  428. (F("char_field").asc(), "Foo1;Foo2;Foo3;Foo4"),
  429. (F("char_field"), "Foo1;Foo2;Foo3;Foo4"),
  430. ("char_field", "Foo1;Foo2;Foo3;Foo4"),
  431. ("-char_field", "Foo4;Foo3;Foo2;Foo1"),
  432. (Concat("char_field", Value("@")), "Foo1;Foo2;Foo3;Foo4"),
  433. (Concat("char_field", Value("@")).desc(), "Foo4;Foo3;Foo2;Foo1"),
  434. )
  435. for ordering, expected_output in ordering_test_cases:
  436. with self.subTest(ordering=ordering, expected_output=expected_output):
  437. values = AggregateTestModel.objects.aggregate(
  438. stringagg=StringAgg("char_field", delimiter=";", ordering=ordering)
  439. )
  440. self.assertEqual(values, {"stringagg": expected_output})
  441. def test_string_agg_jsonfield_ordering(self):
  442. values = AggregateTestModel.objects.aggregate(
  443. stringagg=StringAgg(
  444. KeyTextTransform("lang", "json_field"),
  445. delimiter=";",
  446. ordering=KeyTextTransform("lang", "json_field"),
  447. output_field=CharField(),
  448. ),
  449. )
  450. self.assertEqual(values, {"stringagg": "en;pl"})
  451. def test_string_agg_filter(self):
  452. values = AggregateTestModel.objects.aggregate(
  453. stringagg=StringAgg(
  454. "char_field",
  455. delimiter=";",
  456. filter=Q(char_field__endswith="3") | Q(char_field__endswith="1"),
  457. )
  458. )
  459. self.assertEqual(values, {"stringagg": "Foo1;Foo3"})
  460. def test_orderable_agg_alternative_fields(self):
  461. values = AggregateTestModel.objects.aggregate(
  462. arrayagg=ArrayAgg("integer_field", ordering=F("char_field").asc())
  463. )
  464. self.assertEqual(values, {"arrayagg": [0, 1, 0, 2]})
  465. def test_jsonb_agg(self):
  466. values = AggregateTestModel.objects.aggregate(jsonbagg=JSONBAgg("char_field"))
  467. self.assertEqual(values, {"jsonbagg": ["Foo1", "Foo2", "Foo4", "Foo3"]})
  468. def test_jsonb_agg_charfield_ordering(self):
  469. ordering_test_cases = (
  470. (F("char_field").desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]),
  471. (F("char_field").asc(), ["Foo1", "Foo2", "Foo3", "Foo4"]),
  472. (F("char_field"), ["Foo1", "Foo2", "Foo3", "Foo4"]),
  473. ("char_field", ["Foo1", "Foo2", "Foo3", "Foo4"]),
  474. ("-char_field", ["Foo4", "Foo3", "Foo2", "Foo1"]),
  475. (Concat("char_field", Value("@")), ["Foo1", "Foo2", "Foo3", "Foo4"]),
  476. (Concat("char_field", Value("@")).desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]),
  477. )
  478. for ordering, expected_output in ordering_test_cases:
  479. with self.subTest(ordering=ordering, expected_output=expected_output):
  480. values = AggregateTestModel.objects.aggregate(
  481. jsonbagg=JSONBAgg("char_field", ordering=ordering),
  482. )
  483. self.assertEqual(values, {"jsonbagg": expected_output})
  484. def test_jsonb_agg_integerfield_ordering(self):
  485. values = AggregateTestModel.objects.aggregate(
  486. jsonbagg=JSONBAgg("integer_field", ordering=F("integer_field").desc()),
  487. )
  488. self.assertEqual(values, {"jsonbagg": [2, 1, 0, 0]})
  489. def test_jsonb_agg_booleanfield_ordering(self):
  490. ordering_test_cases = (
  491. (F("boolean_field").asc(), [False, False, True, True]),
  492. (F("boolean_field").desc(), [True, True, False, False]),
  493. (F("boolean_field"), [False, False, True, True]),
  494. )
  495. for ordering, expected_output in ordering_test_cases:
  496. with self.subTest(ordering=ordering, expected_output=expected_output):
  497. values = AggregateTestModel.objects.aggregate(
  498. jsonbagg=JSONBAgg("boolean_field", ordering=ordering),
  499. )
  500. self.assertEqual(values, {"jsonbagg": expected_output})
  501. def test_jsonb_agg_jsonfield_ordering(self):
  502. values = AggregateTestModel.objects.aggregate(
  503. jsonbagg=JSONBAgg(
  504. KeyTransform("lang", "json_field"),
  505. filter=Q(json_field__lang__isnull=False),
  506. ordering=KeyTransform("lang", "json_field"),
  507. ),
  508. )
  509. self.assertEqual(values, {"jsonbagg": ["en", "pl"]})
  510. def test_jsonb_agg_key_index_transforms(self):
  511. room101 = Room.objects.create(number=101)
  512. room102 = Room.objects.create(number=102)
  513. datetimes = [
  514. timezone.datetime(2018, 6, 20),
  515. timezone.datetime(2018, 6, 24),
  516. timezone.datetime(2018, 6, 28),
  517. ]
  518. HotelReservation.objects.create(
  519. datespan=(datetimes[0].date(), datetimes[1].date()),
  520. start=datetimes[0],
  521. end=datetimes[1],
  522. room=room102,
  523. requirements={"double_bed": True, "parking": True},
  524. )
  525. HotelReservation.objects.create(
  526. datespan=(datetimes[1].date(), datetimes[2].date()),
  527. start=datetimes[1],
  528. end=datetimes[2],
  529. room=room102,
  530. requirements={"double_bed": False, "sea_view": True, "parking": False},
  531. )
  532. HotelReservation.objects.create(
  533. datespan=(datetimes[0].date(), datetimes[2].date()),
  534. start=datetimes[0],
  535. end=datetimes[2],
  536. room=room101,
  537. requirements={"sea_view": False},
  538. )
  539. values = (
  540. Room.objects.annotate(
  541. requirements=JSONBAgg(
  542. "hotelreservation__requirements",
  543. ordering="-hotelreservation__start",
  544. )
  545. )
  546. .filter(requirements__0__sea_view=True)
  547. .values("number", "requirements")
  548. )
  549. self.assertSequenceEqual(
  550. values,
  551. [
  552. {
  553. "number": 102,
  554. "requirements": [
  555. {"double_bed": False, "sea_view": True, "parking": False},
  556. {"double_bed": True, "parking": True},
  557. ],
  558. },
  559. ],
  560. )
  561. def test_string_agg_array_agg_ordering_in_subquery(self):
  562. stats = []
  563. for i, agg in enumerate(AggregateTestModel.objects.order_by("char_field")):
  564. stats.append(StatTestModel(related_field=agg, int1=i, int2=i + 1))
  565. stats.append(StatTestModel(related_field=agg, int1=i + 1, int2=i))
  566. StatTestModel.objects.bulk_create(stats)
  567. for aggregate, expected_result in (
  568. (
  569. ArrayAgg("stattestmodel__int1", ordering="-stattestmodel__int2"),
  570. [
  571. ("Foo1", [0, 1]),
  572. ("Foo2", [1, 2]),
  573. ("Foo3", [2, 3]),
  574. ("Foo4", [3, 4]),
  575. ],
  576. ),
  577. (
  578. StringAgg(
  579. Cast("stattestmodel__int1", CharField()),
  580. delimiter=";",
  581. ordering="-stattestmodel__int2",
  582. ),
  583. [("Foo1", "0;1"), ("Foo2", "1;2"), ("Foo3", "2;3"), ("Foo4", "3;4")],
  584. ),
  585. ):
  586. with self.subTest(aggregate=aggregate.__class__.__name__):
  587. subquery = (
  588. AggregateTestModel.objects.filter(
  589. pk=OuterRef("pk"),
  590. )
  591. .annotate(agg=aggregate)
  592. .values("agg")
  593. )
  594. values = (
  595. AggregateTestModel.objects.annotate(
  596. agg=Subquery(subquery),
  597. )
  598. .order_by("char_field")
  599. .values_list("char_field", "agg")
  600. )
  601. self.assertEqual(list(values), expected_result)
  602. def test_string_agg_array_agg_filter_in_subquery(self):
  603. StatTestModel.objects.bulk_create(
  604. [
  605. StatTestModel(related_field=self.aggs[0], int1=0, int2=5),
  606. StatTestModel(related_field=self.aggs[0], int1=1, int2=4),
  607. StatTestModel(related_field=self.aggs[0], int1=2, int2=3),
  608. ]
  609. )
  610. for aggregate, expected_result in (
  611. (
  612. ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2__gt=3)),
  613. [("Foo1", [0, 1]), ("Foo2", None)],
  614. ),
  615. (
  616. StringAgg(
  617. Cast("stattestmodel__int2", CharField()),
  618. delimiter=";",
  619. filter=Q(stattestmodel__int1__lt=2),
  620. ),
  621. [("Foo1", "5;4"), ("Foo2", None)],
  622. ),
  623. ):
  624. with self.subTest(aggregate=aggregate.__class__.__name__):
  625. subquery = (
  626. AggregateTestModel.objects.filter(
  627. pk=OuterRef("pk"),
  628. )
  629. .annotate(agg=aggregate)
  630. .values("agg")
  631. )
  632. values = (
  633. AggregateTestModel.objects.annotate(
  634. agg=Subquery(subquery),
  635. )
  636. .filter(
  637. char_field__in=["Foo1", "Foo2"],
  638. )
  639. .order_by("char_field")
  640. .values_list("char_field", "agg")
  641. )
  642. self.assertEqual(list(values), expected_result)
  643. def test_string_agg_filter_in_subquery_with_exclude(self):
  644. subquery = (
  645. AggregateTestModel.objects.annotate(
  646. stringagg=StringAgg(
  647. "char_field",
  648. delimiter=";",
  649. filter=Q(char_field__endswith="1"),
  650. )
  651. )
  652. .exclude(stringagg="")
  653. .values("id")
  654. )
  655. self.assertSequenceEqual(
  656. AggregateTestModel.objects.filter(id__in=Subquery(subquery)),
  657. [self.aggs[0]],
  658. )
  659. def test_ordering_isnt_cleared_for_array_subquery(self):
  660. inner_qs = AggregateTestModel.objects.order_by("-integer_field")
  661. qs = AggregateTestModel.objects.annotate(
  662. integers=Func(
  663. Subquery(inner_qs.values("integer_field")),
  664. function="ARRAY",
  665. output_field=ArrayField(base_field=IntegerField()),
  666. ),
  667. )
  668. self.assertSequenceEqual(
  669. qs.first().integers,
  670. inner_qs.values_list("integer_field", flat=True),
  671. )
  672. def test_window(self):
  673. self.assertCountEqual(
  674. AggregateTestModel.objects.annotate(
  675. integers=Window(
  676. expression=ArrayAgg("char_field"),
  677. partition_by=F("integer_field"),
  678. )
  679. ).values("integers", "char_field"),
  680. [
  681. {"integers": ["Foo1", "Foo3"], "char_field": "Foo1"},
  682. {"integers": ["Foo1", "Foo3"], "char_field": "Foo3"},
  683. {"integers": ["Foo2"], "char_field": "Foo2"},
  684. {"integers": ["Foo4"], "char_field": "Foo4"},
  685. ],
  686. )
  687. def test_values_list(self):
  688. tests = [ArrayAgg("integer_field"), JSONBAgg("integer_field")]
  689. for aggregation in tests:
  690. with self.subTest(aggregation=aggregation):
  691. self.assertCountEqual(
  692. AggregateTestModel.objects.values_list(aggregation),
  693. [([0],), ([1],), ([2],), ([0],)],
  694. )
  695. class TestAggregateDistinct(PostgreSQLTestCase):
  696. @classmethod
  697. def setUpTestData(cls):
  698. AggregateTestModel.objects.create(char_field="Foo")
  699. AggregateTestModel.objects.create(char_field="Foo")
  700. AggregateTestModel.objects.create(char_field="Bar")
  701. def test_string_agg_distinct_false(self):
  702. values = AggregateTestModel.objects.aggregate(
  703. stringagg=StringAgg("char_field", delimiter=" ", distinct=False)
  704. )
  705. self.assertEqual(values["stringagg"].count("Foo"), 2)
  706. self.assertEqual(values["stringagg"].count("Bar"), 1)
  707. def test_string_agg_distinct_true(self):
  708. values = AggregateTestModel.objects.aggregate(
  709. stringagg=StringAgg("char_field", delimiter=" ", distinct=True)
  710. )
  711. self.assertEqual(values["stringagg"].count("Foo"), 1)
  712. self.assertEqual(values["stringagg"].count("Bar"), 1)
  713. def test_array_agg_distinct_false(self):
  714. values = AggregateTestModel.objects.aggregate(
  715. arrayagg=ArrayAgg("char_field", distinct=False)
  716. )
  717. self.assertEqual(sorted(values["arrayagg"]), ["Bar", "Foo", "Foo"])
  718. def test_array_agg_distinct_true(self):
  719. values = AggregateTestModel.objects.aggregate(
  720. arrayagg=ArrayAgg("char_field", distinct=True)
  721. )
  722. self.assertEqual(sorted(values["arrayagg"]), ["Bar", "Foo"])
  723. def test_jsonb_agg_distinct_false(self):
  724. values = AggregateTestModel.objects.aggregate(
  725. jsonbagg=JSONBAgg("char_field", distinct=False),
  726. )
  727. self.assertEqual(sorted(values["jsonbagg"]), ["Bar", "Foo", "Foo"])
  728. def test_jsonb_agg_distinct_true(self):
  729. values = AggregateTestModel.objects.aggregate(
  730. jsonbagg=JSONBAgg("char_field", distinct=True),
  731. )
  732. self.assertEqual(sorted(values["jsonbagg"]), ["Bar", "Foo"])
  class TestStatisticsAggregate(PostgreSQLTestCase):
      @classmethod
      def setUpTestData(cls):
          # Three rows with int1 = 1, 2, 3 and int2 = 3, 2, 1 (int2 == 4 - int1),
          # so the columns are perfectly negatively correlated and the expected
          # statistics in the tests below can be checked by hand. Each row gets
          # its own AggregateTestModel with integer_field = 0, 1, 2.
          for int1, int2, integer_field in [(1, 3, 0), (2, 2, 1), (3, 1, 2)]:
              StatTestModel.objects.create(
                  int1=int1,
                  int2=int2,
                  related_field=AggregateTestModel.objects.create(
                      integer_field=integer_field
                  ),
              )

      def _check_empty_aggregations(self, tests):
          """
          Assert that each aggregation in ``tests`` yields its expected result
          on an empty result set.

          ``tests`` is an iterable of ``(aggregation, expected_result)`` pairs.
          Each pair is checked in its own subTest, twice: once on ``.none()``,
          where no query may be issued, and once on the plain queryset, where
          exactly one query is issued. Callers must delete all StatTestModel
          rows first so that the executed query also sees no rows.
          """
          for aggregation, expected_result in tests:
              with self.subTest(aggregation=aggregation):
                  # Empty result with non-execution optimization.
                  with self.assertNumQueries(0):
                      values = StatTestModel.objects.none().aggregate(
                          aggregation=aggregation,
                      )
                      self.assertEqual(values, {"aggregation": expected_result})
                  # Empty result when query must be executed.
                  with self.assertNumQueries(1):
                      values = StatTestModel.objects.aggregate(
                          aggregation=aggregation,
                      )
                      self.assertEqual(values, {"aggregation": expected_result})

      # Tests for base class (StatAggregate)

      def test_missing_arguments_raises_exception(self):
          with self.assertRaisesMessage(ValueError, "Both y and x must be provided."):
              StatAggregate(x=None, y=None)

      def test_correct_source_expressions(self):
          func = StatAggregate(x="test", y=13)
          # Source expressions are ordered (y, x): the literal y=13 becomes a
          # Value and the field name x="test" becomes an F() reference.
          self.assertIsInstance(func.source_expressions[0], Value)
          self.assertIsInstance(func.source_expressions[1], F)

      def test_alias_is_required(self):
          class SomeFunc(StatAggregate):
              function = "TEST"

          # Passing the aggregate positionally (without a keyword alias) to
          # aggregate() must be rejected.
          with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):
              StatTestModel.objects.aggregate(SomeFunc(y="int2", x="int1"))

      # Test aggregates

      def test_empty_result_set(self):
          StatTestModel.objects.all().delete()
          self._check_empty_aggregations(
              [
                  (Corr(y="int2", x="int1"), None),
                  (CovarPop(y="int2", x="int1"), None),
                  (CovarPop(y="int2", x="int1", sample=True), None),
                  (RegrAvgX(y="int2", x="int1"), None),
                  (RegrAvgY(y="int2", x="int1"), None),
                  (RegrCount(y="int2", x="int1"), 0),
                  (RegrIntercept(y="int2", x="int1"), None),
                  (RegrR2(y="int2", x="int1"), None),
                  (RegrSlope(y="int2", x="int1"), None),
                  (RegrSXX(y="int2", x="int1"), None),
                  (RegrSXY(y="int2", x="int1"), None),
                  (RegrSYY(y="int2", x="int1"), None),
              ]
          )

      def test_default_argument(self):
          StatTestModel.objects.all().delete()
          self._check_empty_aggregations(
              [
                  (Corr(y="int2", x="int1", default=0), 0),
                  (CovarPop(y="int2", x="int1", default=0), 0),
                  (CovarPop(y="int2", x="int1", sample=True, default=0), 0),
                  (RegrAvgX(y="int2", x="int1", default=0), 0),
                  (RegrAvgY(y="int2", x="int1", default=0), 0),
                  # RegrCount() doesn't support the default argument.
                  (RegrIntercept(y="int2", x="int1", default=0), 0),
                  (RegrR2(y="int2", x="int1", default=0), 0),
                  (RegrSlope(y="int2", x="int1", default=0), 0),
                  (RegrSXX(y="int2", x="int1", default=0), 0),
                  (RegrSXY(y="int2", x="int1", default=0), 0),
                  (RegrSYY(y="int2", x="int1", default=0), 0),
              ]
          )

      def test_corr_general(self):
          values = StatTestModel.objects.aggregate(corr=Corr(y="int2", x="int1"))
          self.assertEqual(values, {"corr": -1.0})

      def test_covar_pop_general(self):
          values = StatTestModel.objects.aggregate(covarpop=CovarPop(y="int2", x="int1"))
          # Deviations from the means (2, 2) are (-1, 1), (0, 0), (1, -1), so
          # the population covariance is (-1 + 0 - 1) / 3 = -2/3. Compare with
          # a tight tolerance so a wrong value such as -0.7 is not accepted.
          self.assertEqual(values, {"covarpop": Approximate(-2 / 3, places=7)})

      def test_covar_pop_sample(self):
          values = StatTestModel.objects.aggregate(
              covarpop=CovarPop(y="int2", x="int1", sample=True)
          )
          # Same sum of products (-2) divided by n - 1 = 2.
          self.assertEqual(values, {"covarpop": -1.0})

      def test_regr_avgx_general(self):
          values = StatTestModel.objects.aggregate(regravgx=RegrAvgX(y="int2", x="int1"))
          self.assertEqual(values, {"regravgx": 2.0})

      def test_regr_avgy_general(self):
          values = StatTestModel.objects.aggregate(regravgy=RegrAvgY(y="int2", x="int1"))
          self.assertEqual(values, {"regravgy": 2.0})

      def test_regr_count_general(self):
          values = StatTestModel.objects.aggregate(
              regrcount=RegrCount(y="int2", x="int1")
          )
          self.assertEqual(values, {"regrcount": 3})

      def test_regr_count_default(self):
          msg = "RegrCount does not allow default."
          with self.assertRaisesMessage(TypeError, msg):
              RegrCount(y="int2", x="int1", default=0)

      def test_regr_intercept_general(self):
          values = StatTestModel.objects.aggregate(
              regrintercept=RegrIntercept(y="int2", x="int1")
          )
          # The fixture satisfies int2 == 4 - int1 exactly.
          self.assertEqual(values, {"regrintercept": 4})

      def test_regr_r2_general(self):
          values = StatTestModel.objects.aggregate(regrr2=RegrR2(y="int2", x="int1"))
          self.assertEqual(values, {"regrr2": 1})

      def test_regr_slope_general(self):
          values = StatTestModel.objects.aggregate(
              regrslope=RegrSlope(y="int2", x="int1")
          )
          self.assertEqual(values, {"regrslope": -1})

      def test_regr_sxx_general(self):
          values = StatTestModel.objects.aggregate(regrsxx=RegrSXX(y="int2", x="int1"))
          self.assertEqual(values, {"regrsxx": 2.0})

      def test_regr_sxy_general(self):
          values = StatTestModel.objects.aggregate(regrsxy=RegrSXY(y="int2", x="int1"))
          self.assertEqual(values, {"regrsxy": -2.0})

      def test_regr_syy_general(self):
          values = StatTestModel.objects.aggregate(regrsyy=RegrSYY(y="int2", x="int1"))
          self.assertEqual(values, {"regrsyy": 2.0})

      def test_regr_avgx_with_related_obj_and_number_as_argument(self):
          """
          Check that a JOIN on a related field combined with a plain number as
          an argument works: y is the constant 5 and x averages the related
          integer_field values 0, 1, 2.
          """
          values = StatTestModel.objects.aggregate(
              complex_regravgx=RegrAvgX(y=5, x="related_field__integer_field")
          )
          self.assertEqual(values, {"complex_regravgx": 1.0})
  878. values = StatTestModel.objects.aggregate(
  879. complex_regravgx=RegrAvgX(y=5, x="related_field__integer_field")
  880. )
  881. self.assertEqual(values, {"complex_regravgx": 1.0})