test_array.py 55 KB


  1. import decimal
  2. import enum
  3. import json
  4. import unittest
  5. import uuid
  6. from django import forms
  7. from django.contrib.admin.utils import display_for_field
  8. from django.core import checks, exceptions, serializers, validators
  9. from django.core.exceptions import FieldError
  10. from django.core.management import call_command
  11. from django.db import IntegrityError, connection, models
  12. from django.db.models.expressions import Exists, F, OuterRef, RawSQL, Value
  13. from django.db.models.functions import Cast, JSONObject, Upper
  14. from django.test import TransactionTestCase, override_settings, skipUnlessDBFeature
  15. from django.test.utils import isolate_apps
  16. from django.utils import timezone
  17. from . import PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase
  18. from .models import (
  19. ArrayEnumModel,
  20. ArrayFieldSubclass,
  21. CharArrayModel,
  22. DateTimeArrayModel,
  23. IntegerArrayModel,
  24. NestedIntegerArrayModel,
  25. NullableIntegerArrayModel,
  26. OtherTypesArrayModel,
  27. PostgreSQLModel,
  28. Tag,
  29. )
  30. try:
  31. from django.contrib.postgres.aggregates import ArrayAgg
  32. from django.contrib.postgres.expressions import ArraySubquery
  33. from django.contrib.postgres.fields import ArrayField
  34. from django.contrib.postgres.fields.array import IndexTransform, SliceTransform
  35. from django.contrib.postgres.forms import (
  36. SimpleArrayField,
  37. SplitArrayField,
  38. SplitArrayWidget,
  39. )
  40. from django.db.backends.postgresql.psycopg_any import NumericRange
  41. except ImportError:
  42. pass
  43. @isolate_apps("postgres_tests")
  44. class BasicTests(PostgreSQLSimpleTestCase):
  45. def test_get_field_display(self):
  46. class MyModel(PostgreSQLModel):
  47. field = ArrayField(
  48. models.CharField(max_length=16),
  49. choices=[
  50. ["Media", [(["vinyl", "cd"], "Audio")]],
  51. (("mp3", "mp4"), "Digital"),
  52. ],
  53. )
  54. tests = (
  55. (["vinyl", "cd"], "Audio"),
  56. (("mp3", "mp4"), "Digital"),
  57. (("a", "b"), "('a', 'b')"),
  58. (["c", "d"], "['c', 'd']"),
  59. )
  60. for value, display in tests:
  61. with self.subTest(value=value, display=display):
  62. instance = MyModel(field=value)
  63. self.assertEqual(instance.get_field_display(), display)
  64. def test_get_field_display_nested_array(self):
  65. class MyModel(PostgreSQLModel):
  66. field = ArrayField(
  67. ArrayField(models.CharField(max_length=16)),
  68. choices=[
  69. [
  70. "Media",
  71. [([["vinyl", "cd"], ("x",)], "Audio")],
  72. ],
  73. ((["mp3"], ("mp4",)), "Digital"),
  74. ],
  75. )
  76. tests = (
  77. ([["vinyl", "cd"], ("x",)], "Audio"),
  78. ((["mp3"], ("mp4",)), "Digital"),
  79. ((("a", "b"), ("c",)), "(('a', 'b'), ('c',))"),
  80. ([["a", "b"], ["c"]], "[['a', 'b'], ['c']]"),
  81. )
  82. for value, display in tests:
  83. with self.subTest(value=value, display=display):
  84. instance = MyModel(field=value)
  85. self.assertEqual(instance.get_field_display(), display)
  86. class TestSaveLoad(PostgreSQLTestCase):
  87. def test_integer(self):
  88. instance = IntegerArrayModel(field=[1, 2, 3])
  89. instance.save()
  90. loaded = IntegerArrayModel.objects.get()
  91. self.assertEqual(instance.field, loaded.field)
  92. def test_char(self):
  93. instance = CharArrayModel(field=["hello", "goodbye"])
  94. instance.save()
  95. loaded = CharArrayModel.objects.get()
  96. self.assertEqual(instance.field, loaded.field)
  97. def test_dates(self):
  98. instance = DateTimeArrayModel(
  99. datetimes=[timezone.now()],
  100. dates=[timezone.now().date()],
  101. times=[timezone.now().time()],
  102. )
  103. instance.save()
  104. loaded = DateTimeArrayModel.objects.get()
  105. self.assertEqual(instance.datetimes, loaded.datetimes)
  106. self.assertEqual(instance.dates, loaded.dates)
  107. self.assertEqual(instance.times, loaded.times)
  108. def test_tuples(self):
  109. instance = IntegerArrayModel(field=(1,))
  110. instance.save()
  111. loaded = IntegerArrayModel.objects.get()
  112. self.assertSequenceEqual(instance.field, loaded.field)
  113. def test_integers_passed_as_strings(self):
  114. # This checks that get_prep_value is deferred properly
  115. instance = IntegerArrayModel(field=["1"])
  116. instance.save()
  117. loaded = IntegerArrayModel.objects.get()
  118. self.assertEqual(loaded.field, [1])
  119. def test_default_null(self):
  120. instance = NullableIntegerArrayModel()
  121. instance.save()
  122. loaded = NullableIntegerArrayModel.objects.get(pk=instance.pk)
  123. self.assertIsNone(loaded.field)
  124. self.assertEqual(instance.field, loaded.field)
  125. def test_null_handling(self):
  126. instance = NullableIntegerArrayModel(field=None)
  127. instance.save()
  128. loaded = NullableIntegerArrayModel.objects.get()
  129. self.assertEqual(instance.field, loaded.field)
  130. instance = IntegerArrayModel(field=None)
  131. with self.assertRaises(IntegrityError):
  132. instance.save()
  133. def test_nested(self):
  134. instance = NestedIntegerArrayModel(field=[[1, 2], [3, 4]])
  135. instance.save()
  136. loaded = NestedIntegerArrayModel.objects.get()
  137. self.assertEqual(instance.field, loaded.field)
  138. def test_other_array_types(self):
  139. instance = OtherTypesArrayModel(
  140. ips=["192.168.0.1", "::1"],
  141. uuids=[uuid.uuid4()],
  142. decimals=[decimal.Decimal(1.25), 1.75],
  143. tags=[Tag(1), Tag(2), Tag(3)],
  144. json=[{"a": 1}, {"b": 2}],
  145. int_ranges=[NumericRange(10, 20), NumericRange(30, 40)],
  146. bigint_ranges=[
  147. NumericRange(7000000000, 10000000000),
  148. NumericRange(50000000000, 70000000000),
  149. ],
  150. )
  151. instance.save()
  152. loaded = OtherTypesArrayModel.objects.get()
  153. self.assertEqual(instance.ips, loaded.ips)
  154. self.assertEqual(instance.uuids, loaded.uuids)
  155. self.assertEqual(instance.decimals, loaded.decimals)
  156. self.assertEqual(instance.tags, loaded.tags)
  157. self.assertEqual(instance.json, loaded.json)
  158. self.assertEqual(instance.int_ranges, loaded.int_ranges)
  159. self.assertEqual(instance.bigint_ranges, loaded.bigint_ranges)
  160. def test_null_from_db_value_handling(self):
  161. instance = OtherTypesArrayModel.objects.create(
  162. ips=["192.168.0.1", "::1"],
  163. uuids=[uuid.uuid4()],
  164. decimals=[decimal.Decimal(1.25), 1.75],
  165. tags=None,
  166. )
  167. instance.refresh_from_db()
  168. self.assertIsNone(instance.tags)
  169. self.assertEqual(instance.json, [])
  170. self.assertIsNone(instance.int_ranges)
  171. self.assertIsNone(instance.bigint_ranges)
  172. def test_model_set_on_base_field(self):
  173. instance = IntegerArrayModel()
  174. field = instance._meta.get_field("field")
  175. self.assertEqual(field.model, IntegerArrayModel)
  176. self.assertEqual(field.base_field.model, IntegerArrayModel)
  177. def test_nested_nullable_base_field(self):
  178. instance = NullableIntegerArrayModel.objects.create(
  179. field_nested=[[None, None], [None, None]],
  180. )
  181. self.assertEqual(instance.field_nested, [[None, None], [None, None]])
class TestQuerying(PostgreSQLTestCase):
    """Lookups, transforms and expressions on ArrayField.

    Most tests query the NullableIntegerArrayModel rows created in
    setUpTestData and compare the result against slices of ``cls.objs``.
    """

    @classmethod
    def setUpTestData(cls):
        # Index in cls.objs -> (order, field):
        #   0: (1, [1])  1: (2, [2])  2: (3, [2, 3])
        #   3: (4, [20, 30, 40])  4: (5, None)
        cls.objs = NullableIntegerArrayModel.objects.bulk_create(
            [
                NullableIntegerArrayModel(order=1, field=[1]),
                NullableIntegerArrayModel(order=2, field=[2]),
                NullableIntegerArrayModel(order=3, field=[2, 3]),
                NullableIntegerArrayModel(order=4, field=[20, 30, 40]),
                NullableIntegerArrayModel(order=5, field=None),
            ]
        )

    def test_empty_list(self):
        # An empty-list Value with an explicit ArrayField output_field can be
        # annotated and compared against a stored [].
        NullableIntegerArrayModel.objects.create(field=[])
        obj = (
            NullableIntegerArrayModel.objects.annotate(
                empty_array=models.Value(
                    [], output_field=ArrayField(models.IntegerField())
                ),
            )
            .filter(field=models.F("empty_array"))
            .get()
        )
        self.assertEqual(obj.field, [])
        self.assertEqual(obj.empty_array, [])

    def test_exact(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__exact=[1]), self.objs[:1]
        )

    def test_exact_null_only_array(self):
        # Arrays made up only of NULLs can still be matched exactly.
        obj = NullableIntegerArrayModel.objects.create(
            field=[None], field_nested=[None, None]
        )
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__exact=[None]), [obj]
        )
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field_nested__exact=[None, None]),
            [obj],
        )

    def test_exact_null_only_nested_array(self):
        # Nested all-NULL arrays with different outer lengths must not match
        # each other.
        obj1 = NullableIntegerArrayModel.objects.create(field_nested=[[None, None]])
        obj2 = NullableIntegerArrayModel.objects.create(
            field_nested=[[None, None], [None, None]],
        )
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field_nested__exact=[[None, None]],
            ),
            [obj1],
        )
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field_nested__exact=[[None, None], [None, None]],
            ),
            [obj2],
        )

    def test_exact_with_expression(self):
        # Array elements on the right-hand side may be expressions.
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__exact=[Value(1)]),
            self.objs[:1],
        )

    def test_exact_charfield(self):
        instance = CharArrayModel.objects.create(field=["text"])
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field=["text"]), [instance]
        )

    def test_exact_nested(self):
        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
        self.assertSequenceEqual(
            NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]), [instance]
        )

    def test_isnull(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__isnull=True), self.objs[-1:]
        )

    def test_gt(self):
        # Every non-NULL array compares greater than [0].
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__gt=[0]), self.objs[:4]
        )

    def test_lt(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__lt=[2]), self.objs[:1]
        )

    def test_in(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]),
            self.objs[:2],
        )

    def test_in_subquery(self):
        # The right-hand side of __in can be a queryset of arrays.
        IntegerArrayModel.objects.create(field=[2, 3])
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field__in=IntegerArrayModel.objects.values_list("field", flat=True)
            ),
            self.objs[2:3],
        )

    @unittest.expectedFailure
    def test_in_including_F_object(self):
        # This test asserts that Array objects passed to filters can be
        # constructed to contain F objects. This currently doesn't work as the
        # psycopg mogrify method that generates the ARRAY() syntax is
        # expecting literals, not column references (#27095).
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__in=[[models.F("id")]]),
            self.objs[:2],
        )

    def test_in_as_F_object(self):
        # Each row's own array is in [field], so every non-NULL row matches.
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__in=[models.F("field")]),
            self.objs[:4],
        )

    def test_contained_by(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]),
            self.objs[:2],
        )

    def test_contained_by_including_F_object(self):
        # The row's own `order` is put in the array, e.g. row 2 checks
        # [2, 3] <@ [3, 2], so rows 0-2 match.
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field__contained_by=[models.F("order"), 2]
            ),
            self.objs[:3],
        )

    def test_contains(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__contains=[2]),
            self.objs[1:3],
        )

    def test_contains_subquery(self):
        IntegerArrayModel.objects.create(field=[2, 3])
        # Sliced values_list subquery as the right-hand side.
        inner_qs = IntegerArrayModel.objects.values_list("field", flat=True)
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__contains=inner_qs[:1]),
            self.objs[2:3],
        )
        # Correlated subquery: [2, 3] contains the outer rows [2] and [2, 3].
        inner_qs = IntegerArrayModel.objects.filter(field__contains=OuterRef("field"))
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(Exists(inner_qs)),
            self.objs[1:3],
        )

    def test_contains_including_expression(self):
        # Value(6) / Value(2) evaluates to 3, so only [2, 3] matches.
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field__contains=[2, Value(6) / Value(2)],
            ),
            self.objs[2:3],
        )

    def test_icontains(self):
        # Using the __icontains lookup with ArrayField is inefficient.
        instance = CharArrayModel.objects.create(field=["FoO"])
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field__icontains="foo"), [instance]
        )

    def test_contains_charfield(self):
        # Regression for #22907
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field__contains=["text"]), []
        )

    def test_contained_by_charfield(self):
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field__contained_by=["text"]), []
        )

    def test_overlap_charfield(self):
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field__overlap=["text"]), []
        )

    def test_overlap_charfield_including_expression(self):
        # Upper(Value("text")) yields "TEXT"; the comparison is case-sensitive,
        # so the row holding only lowercase "text" is excluded.
        obj_1 = CharArrayModel.objects.create(field=["TEXT", "lower text"])
        obj_2 = CharArrayModel.objects.create(field=["lower text", "TEXT"])
        CharArrayModel.objects.create(field=["lower text", "text"])
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(
                field__overlap=[
                    Upper(Value("text")),
                    "other",
                ]
            ),
            [obj_1, obj_2],
        )

    def test_overlap_values(self):
        # Both values_list() and values() subqueries work as the overlap RHS.
        qs = NullableIntegerArrayModel.objects.filter(order__lt=3)
        self.assertCountEqual(
            NullableIntegerArrayModel.objects.filter(
                field__overlap=qs.values_list("field"),
            ),
            self.objs[:3],
        )
        self.assertCountEqual(
            NullableIntegerArrayModel.objects.filter(
                field__overlap=qs.values("field"),
            ),
            self.objs[:3],
        )

    def test_lookups_autofield_array(self):
        # Array lookups applied to an ArrayAgg of primary keys, grouped by the
        # first element of each row's array.
        qs = (
            NullableIntegerArrayModel.objects.filter(
                field__0__isnull=False,
            )
            .values("field__0")
            .annotate(
                arrayagg=ArrayAgg("id"),
            )
            .order_by("field__0")
        )
        tests = (
            ("contained_by", [self.objs[1].pk, self.objs[2].pk, 0], [2]),
            ("contains", [self.objs[2].pk], [2]),
            ("exact", [self.objs[3].pk], [20]),
            ("overlap", [self.objs[1].pk, self.objs[3].pk], [2, 20]),
        )
        for lookup, value, expected in tests:
            with self.subTest(lookup=lookup):
                self.assertSequenceEqual(
                    qs.filter(
                        **{"arrayagg__" + lookup: value},
                    ).values_list("field__0", flat=True),
                    expected,
                )

    @skipUnlessDBFeature("allows_group_by_select_index")
    def test_group_by_order_by_select_index(self):
        # Besides the results, the generated SQL is expected to group and order
        # by select-list position 2 rather than repeating the expression.
        with self.assertNumQueries(1) as ctx:
            self.assertSequenceEqual(
                NullableIntegerArrayModel.objects.filter(
                    field__0__isnull=False,
                )
                .values("field__0")
                .annotate(arrayagg=ArrayAgg("id"))
                .order_by("field__0"),
                [
                    {"field__0": 1, "arrayagg": [self.objs[0].pk]},
                    {"field__0": 2, "arrayagg": [self.objs[1].pk, self.objs[2].pk]},
                    {"field__0": 20, "arrayagg": [self.objs[3].pk]},
                ],
            )
        sql = ctx[0]["sql"]
        self.assertIn("GROUP BY 2", sql)
        self.assertIn("ORDER BY 2", sql)

    def test_order_by_arrayagg_index(self):
        # Each `order` has a single id, so descending ids[0] is reverse creation
        # order (assumes bulk_create pks increase with list order — TODO confirm).
        qs = (
            NullableIntegerArrayModel.objects.values("order")
            .annotate(ids=ArrayAgg("id"))
            .order_by("-ids__0")
        )
        self.assertQuerySetEqual(
            qs, [{"order": obj.order, "ids": [obj.id]} for obj in reversed(self.objs)]
        )

    def test_index(self):
        # field__0 is the first element (0-based on the Python side).
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3]
        )

    def test_index_chained(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__0__lt=3), self.objs[0:3]
        )

    def test_index_nested(self):
        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
        self.assertSequenceEqual(
            NestedIntegerArrayModel.objects.filter(field__0__0=1), [instance]
        )

    @unittest.expectedFailure
    def test_index_used_on_nested_data(self):
        # Indexing a nested array once and comparing to a sub-list is
        # known not to work.
        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
        self.assertSequenceEqual(
            NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), [instance]
        )

    def test_index_transform_expression(self):
        # IndexTransform is given the raw 1-based index: element 1 of "1;2" is
        # "1", which (cast to integer) matches field__0=1.
        expr = RawSQL("string_to_array(%s, ';')", ["1;2"])
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field__0=Cast(
                    IndexTransform(1, models.IntegerField, expr),
                    output_field=models.IntegerField(),
                ),
            ),
            self.objs[:1],
        )

    def test_index_annotation(self):
        # An out-of-range index yields NULL rather than an error.
        qs = NullableIntegerArrayModel.objects.annotate(second=models.F("field__1"))
        self.assertCountEqual(
            qs.values_list("second", flat=True),
            [None, None, None, 3, 30],
        )

    def test_overlap(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]),
            self.objs[0:3],
        )

    def test_len(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__len__lte=2), self.objs[0:3]
        )

    def test_len_empty_array(self):
        # An empty array has length 0 (not NULL).
        obj = NullableIntegerArrayModel.objects.create(field=[])
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__len=0), [obj]
        )

    def test_slice(self):
        # field__<start>_<end> follows Python slice semantics: 0_1 -> [:1].
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__0_1=[2]), self.objs[1:3]
        )
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]), self.objs[2:3]
        )

    def test_order_by_slice(self):
        # Despite the name, this orders by the index transform field__1. Rows
        # without a second element (NULL) sort after all non-NULL values.
        more_objs = (
            NullableIntegerArrayModel.objects.create(field=[1, 637]),
            NullableIntegerArrayModel.objects.create(field=[2, 1]),
            NullableIntegerArrayModel.objects.create(field=[3, -98123]),
            NullableIntegerArrayModel.objects.create(field=[4, 2]),
        )
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.order_by("field__1"),
            [
                more_objs[2],
                more_objs[1],
                more_objs[3],
                self.objs[2],
                self.objs[3],
                more_objs[0],
                self.objs[4],
                self.objs[1],
                self.objs[0],
            ],
        )

    @unittest.expectedFailure
    def test_slice_nested(self):
        # Slicing inside a nested array is known not to work.
        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
        self.assertSequenceEqual(
            NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]), [instance]
        )

    def test_slice_transform_expression(self):
        # SliceTransform is given raw 1-based bounds: elements 2..3 of "9;2;3"
        # are ["2", "3"], matching the row whose first two items are [2, 3].
        expr = RawSQL("string_to_array(%s, ';')", ["9;2;3"])
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field__0_2=SliceTransform(2, 3, expr)
            ),
            self.objs[2:3],
        )

    def test_slice_annotation(self):
        qs = NullableIntegerArrayModel.objects.annotate(
            first_two=models.F("field__0_2"),
        )
        self.assertCountEqual(
            qs.values_list("first_two", flat=True),
            [None, [1], [2], [2, 3], [20, 30]],
        )

    def test_slicing_of_f_expressions(self):
        # F() supports Python slicing/indexing when used in an update. Storing
        # an index result back into the array field yields a one-element list.
        tests = [
            (F("field")[:2], [1, 2]),
            (F("field")[2:], [3, 4]),
            (F("field")[1:3], [2, 3]),
            (F("field")[3], [4]),
            (F("field")[:3][1:], [2, 3]),  # Nested slicing.
            (F("field")[:3][1], [2]),  # Slice then index.
        ]
        for expression, expected in tests:
            with self.subTest(expression=expression, expected=expected):
                instance = IntegerArrayModel.objects.create(field=[1, 2, 3, 4])
                instance.field = expression
                instance.save()
                instance.refresh_from_db()
                self.assertEqual(instance.field, expected)

    def test_slicing_of_f_expressions_with_annotate(self):
        IntegerArrayModel.objects.create(field=[1, 2, 3])
        annotated = IntegerArrayModel.objects.annotate(
            first_two=F("field")[:2],
            after_two=F("field")[2:],
            random_two=F("field")[1:3],
        ).get()
        self.assertEqual(annotated.first_two, [1, 2])
        self.assertEqual(annotated.after_two, [3])
        self.assertEqual(annotated.random_two, [2, 3])

    def test_slicing_of_f_expressions_with_len(self):
        # A row matches when its whole array is as long as its [:1] slice,
        # i.e. single-element arrays.
        queryset = NullableIntegerArrayModel.objects.annotate(
            subarray=F("field")[:1]
        ).filter(field__len=F("subarray__len"))
        self.assertSequenceEqual(queryset, self.objs[:2])

    def test_usage_in_subquery(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                id__in=NullableIntegerArrayModel.objects.filter(field__len=3)
            ),
            [self.objs[3]],
        )

    def test_enum_lookup(self):
        # Enum members are accepted as array elements in lookups.
        class TestEnum(enum.Enum):
            VALUE_1 = "value_1"

        instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1])
        self.assertSequenceEqual(
            ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]),
            [instance],
        )

    def test_unsupported_lookup(self):
        # Malformed slice/index names must raise FieldError, not be ignored.
        msg = (
            "Unsupported lookup '0_bar' for ArrayField or join on the field not "
            "permitted."
        )
        with self.assertRaisesMessage(FieldError, msg):
            list(NullableIntegerArrayModel.objects.filter(field__0_bar=[2]))

        msg = (
            "Unsupported lookup '0bar' for ArrayField or join on the field not "
            "permitted."
        )
        with self.assertRaisesMessage(FieldError, msg):
            list(NullableIntegerArrayModel.objects.filter(field__0bar=[2]))

    def test_grouping_by_annotations_with_array_field_param(self):
        # An annotation whose SQL contains an array parameter can be used as a
        # values()/GROUP BY key.
        value = models.Value([1], output_field=ArrayField(models.IntegerField()))
        self.assertEqual(
            NullableIntegerArrayModel.objects.annotate(
                array_length=models.Func(
                    value,
                    1,
                    function="ARRAY_LENGTH",
                    output_field=models.IntegerField(),
                ),
            )
            .values("array_length")
            .annotate(
                count=models.Count("pk"),
            )
            .get()["array_length"],
            1,
        )

    def test_filter_by_array_subquery(self):
        # Rows whose array length is shared by more than one row: [1] and [2].
        inner_qs = NullableIntegerArrayModel.objects.filter(
            field__len=models.OuterRef("field__len"),
        ).values("field")
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.alias(
                same_sized_fields=ArraySubquery(inner_qs),
            ).filter(same_sized_fields__len__gt=1),
            self.objs[0:2],
        )

    def test_annotated_array_subquery(self):
        # ArraySubquery collects the `order` of every other row.
        inner_qs = NullableIntegerArrayModel.objects.exclude(
            pk=models.OuterRef("pk")
        ).values("order")
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.annotate(
                sibling_ids=ArraySubquery(inner_qs),
            )
            .get(order=1)
            .sibling_ids,
            [2, 3, 4, 5],
        )

    def test_group_by_with_annotated_array_subquery(self):
        # An aggregate over an ArraySubquery annotation: every row has
        # len(objs) - 1 siblings.
        inner_qs = NullableIntegerArrayModel.objects.exclude(
            pk=models.OuterRef("pk")
        ).values("order")
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.annotate(
                sibling_ids=ArraySubquery(inner_qs),
                sibling_count=models.Max("sibling_ids__len"),
            ).values_list("sibling_count", flat=True),
            [len(self.objs) - 1] * len(self.objs),
        )

    def test_annotated_ordered_array_subquery(self):
        # The inner queryset's ordering is preserved in the resulting array.
        inner_qs = NullableIntegerArrayModel.objects.order_by("-order").values("order")
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.annotate(
                ids=ArraySubquery(inner_qs),
            )
            .first()
            .ids,
            [5, 4, 3, 2, 1],
        )

    def test_annotated_array_subquery_with_json_objects(self):
        # ArraySubquery over a JSONObject() values() expression returns a list
        # of dicts, including a JSON null for the NULL array.
        inner_qs = NullableIntegerArrayModel.objects.exclude(
            pk=models.OuterRef("pk")
        ).values(json=JSONObject(order="order", field="field"))
        siblings_json = (
            NullableIntegerArrayModel.objects.annotate(
                siblings_json=ArraySubquery(inner_qs),
            )
            .values_list("siblings_json", flat=True)
            .get(order=1)
        )
        self.assertSequenceEqual(
            siblings_json,
            [
                {"field": [2], "order": 2},
                {"field": [2, 3], "order": 3},
                {"field": [20, 30, 40], "order": 4},
                {"field": None, "order": 5},
            ],
        )
  669. class TestDateTimeExactQuerying(PostgreSQLTestCase):
  670. @classmethod
  671. def setUpTestData(cls):
  672. now = timezone.now()
  673. cls.datetimes = [now]
  674. cls.dates = [now.date()]
  675. cls.times = [now.time()]
  676. cls.objs = [
  677. DateTimeArrayModel.objects.create(
  678. datetimes=cls.datetimes, dates=cls.dates, times=cls.times
  679. ),
  680. ]
  681. def test_exact_datetimes(self):
  682. self.assertSequenceEqual(
  683. DateTimeArrayModel.objects.filter(datetimes=self.datetimes), self.objs
  684. )
  685. def test_exact_dates(self):
  686. self.assertSequenceEqual(
  687. DateTimeArrayModel.objects.filter(dates=self.dates), self.objs
  688. )
  689. def test_exact_times(self):
  690. self.assertSequenceEqual(
  691. DateTimeArrayModel.objects.filter(times=self.times), self.objs
  692. )
  693. class TestOtherTypesExactQuerying(PostgreSQLTestCase):
  694. @classmethod
  695. def setUpTestData(cls):
  696. cls.ips = ["192.168.0.1", "::1"]
  697. cls.uuids = [uuid.uuid4()]
  698. cls.decimals = [decimal.Decimal(1.25), 1.75]
  699. cls.tags = [Tag(1), Tag(2), Tag(3)]
  700. cls.objs = [
  701. OtherTypesArrayModel.objects.create(
  702. ips=cls.ips,
  703. uuids=cls.uuids,
  704. decimals=cls.decimals,
  705. tags=cls.tags,
  706. )
  707. ]
  708. def test_exact_ip_addresses(self):
  709. self.assertSequenceEqual(
  710. OtherTypesArrayModel.objects.filter(ips=self.ips), self.objs
  711. )
  712. def test_exact_uuids(self):
  713. self.assertSequenceEqual(
  714. OtherTypesArrayModel.objects.filter(uuids=self.uuids), self.objs
  715. )
  716. def test_exact_decimals(self):
  717. self.assertSequenceEqual(
  718. OtherTypesArrayModel.objects.filter(decimals=self.decimals), self.objs
  719. )
  720. def test_exact_tags(self):
  721. self.assertSequenceEqual(
  722. OtherTypesArrayModel.objects.filter(tags=self.tags), self.objs
  723. )
  724. @isolate_apps("postgres_tests")
  725. class TestChecks(PostgreSQLSimpleTestCase):
  726. def test_field_checks(self):
  727. class MyModel(PostgreSQLModel):
  728. field = ArrayField(models.CharField(max_length=-1))
  729. model = MyModel()
  730. errors = model.check()
  731. self.assertEqual(len(errors), 1)
  732. # The inner CharField has a non-positive max_length.
  733. self.assertEqual(errors[0].id, "postgres.E001")
  734. self.assertIn("max_length", errors[0].msg)
  735. def test_invalid_base_fields(self):
  736. class MyModel(PostgreSQLModel):
  737. field = ArrayField(
  738. models.ManyToManyField("postgres_tests.IntegerArrayModel")
  739. )
  740. model = MyModel()
  741. errors = model.check()
  742. self.assertEqual(len(errors), 1)
  743. self.assertEqual(errors[0].id, "postgres.E002")
  744. def test_invalid_default(self):
  745. class MyModel(PostgreSQLModel):
  746. field = ArrayField(models.IntegerField(), default=[])
  747. model = MyModel()
  748. self.assertEqual(
  749. model.check(),
  750. [
  751. checks.Warning(
  752. msg=(
  753. "ArrayField default should be a callable instead of an "
  754. "instance so that it's not shared between all field "
  755. "instances."
  756. ),
  757. hint="Use a callable instead, e.g., use `list` instead of `[]`.",
  758. obj=MyModel._meta.get_field("field"),
  759. id="fields.E010",
  760. )
  761. ],
  762. )
  763. def test_valid_default(self):
  764. class MyModel(PostgreSQLModel):
  765. field = ArrayField(models.IntegerField(), default=list)
  766. model = MyModel()
  767. self.assertEqual(model.check(), [])
  768. def test_valid_default_none(self):
  769. class MyModel(PostgreSQLModel):
  770. field = ArrayField(models.IntegerField(), default=None)
  771. model = MyModel()
  772. self.assertEqual(model.check(), [])
  773. def test_nested_field_checks(self):
  774. """
  775. Nested ArrayFields are permitted.
  776. """
  777. class MyModel(PostgreSQLModel):
  778. field = ArrayField(ArrayField(models.CharField(max_length=-1)))
  779. model = MyModel()
  780. errors = model.check()
  781. self.assertEqual(len(errors), 1)
  782. # The inner CharField has a non-positive max_length.
  783. self.assertEqual(errors[0].id, "postgres.E001")
  784. self.assertIn("max_length", errors[0].msg)
  785. def test_choices_tuple_list(self):
  786. class MyModel(PostgreSQLModel):
  787. field = ArrayField(
  788. models.CharField(max_length=16),
  789. choices=[
  790. [
  791. "Media",
  792. [(["vinyl", "cd"], "Audio"), (("vhs", "dvd"), "Video")],
  793. ],
  794. (["mp3", "mp4"], "Digital"),
  795. ],
  796. )
  797. self.assertEqual(MyModel._meta.get_field("field").check(), [])
  798. @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
  799. class TestMigrations(TransactionTestCase):
  800. available_apps = ["postgres_tests"]
  801. def test_deconstruct(self):
  802. field = ArrayField(models.IntegerField())
  803. name, path, args, kwargs = field.deconstruct()
  804. new = ArrayField(*args, **kwargs)
  805. self.assertEqual(type(new.base_field), type(field.base_field))
  806. self.assertIsNot(new.base_field, field.base_field)
  807. def test_deconstruct_with_size(self):
  808. field = ArrayField(models.IntegerField(), size=3)
  809. name, path, args, kwargs = field.deconstruct()
  810. new = ArrayField(*args, **kwargs)
  811. self.assertEqual(new.size, field.size)
  812. def test_deconstruct_args(self):
  813. field = ArrayField(models.CharField(max_length=20))
  814. name, path, args, kwargs = field.deconstruct()
  815. new = ArrayField(*args, **kwargs)
  816. self.assertEqual(new.base_field.max_length, field.base_field.max_length)
  817. def test_subclass_deconstruct(self):
  818. field = ArrayField(models.IntegerField())
  819. name, path, args, kwargs = field.deconstruct()
  820. self.assertEqual(path, "django.contrib.postgres.fields.ArrayField")
  821. field = ArrayFieldSubclass()
  822. name, path, args, kwargs = field.deconstruct()
  823. self.assertEqual(path, "postgres_tests.models.ArrayFieldSubclass")
  824. @override_settings(
  825. MIGRATION_MODULES={
  826. "postgres_tests": "postgres_tests.array_default_migrations",
  827. }
  828. )
  829. def test_adding_field_with_default(self):
  830. # See #22962
  831. table_name = "postgres_tests_integerarraydefaultmodel"
  832. with connection.cursor() as cursor:
  833. self.assertNotIn(table_name, connection.introspection.table_names(cursor))
  834. call_command("migrate", "postgres_tests", verbosity=0)
  835. with connection.cursor() as cursor:
  836. self.assertIn(table_name, connection.introspection.table_names(cursor))
  837. call_command("migrate", "postgres_tests", "zero", verbosity=0)
  838. with connection.cursor() as cursor:
  839. self.assertNotIn(table_name, connection.introspection.table_names(cursor))
  840. @override_settings(
  841. MIGRATION_MODULES={
  842. "postgres_tests": "postgres_tests.array_index_migrations",
  843. }
  844. )
  845. def test_adding_arrayfield_with_index(self):
  846. """
  847. ArrayField shouldn't have varchar_patterns_ops or text_patterns_ops indexes.
  848. """
  849. table_name = "postgres_tests_chartextarrayindexmodel"
  850. call_command("migrate", "postgres_tests", verbosity=0)
  851. with connection.cursor() as cursor:
  852. like_constraint_columns_list = [
  853. v["columns"]
  854. for k, v in list(
  855. connection.introspection.get_constraints(cursor, table_name).items()
  856. )
  857. if k.endswith("_like")
  858. ]
  859. # Only the CharField should have a LIKE index.
  860. self.assertEqual(like_constraint_columns_list, [["char2"]])
  861. # All fields should have regular indexes.
  862. with connection.cursor() as cursor:
  863. indexes = [
  864. c["columns"][0]
  865. for c in connection.introspection.get_constraints(
  866. cursor, table_name
  867. ).values()
  868. if c["index"] and len(c["columns"]) == 1
  869. ]
  870. self.assertIn("char", indexes)
  871. self.assertIn("char2", indexes)
  872. self.assertIn("text", indexes)
  873. call_command("migrate", "postgres_tests", "zero", verbosity=0)
  874. with connection.cursor() as cursor:
  875. self.assertNotIn(table_name, connection.introspection.table_names(cursor))
  876. class TestSerialization(PostgreSQLSimpleTestCase):
  877. test_data = (
  878. '[{"fields": {"field": "[\\"1\\", \\"2\\", null]"}, '
  879. '"model": "postgres_tests.integerarraymodel", "pk": null}]'
  880. )
  881. def test_dumping(self):
  882. instance = IntegerArrayModel(field=[1, 2, None])
  883. data = serializers.serialize("json", [instance])
  884. self.assertEqual(json.loads(data), json.loads(self.test_data))
  885. def test_loading(self):
  886. instance = list(serializers.deserialize("json", self.test_data))[0].object
  887. self.assertEqual(instance.field, [1, 2, None])
  888. class TestValidation(PostgreSQLSimpleTestCase):
  889. def test_unbounded(self):
  890. field = ArrayField(models.IntegerField())
  891. with self.assertRaises(exceptions.ValidationError) as cm:
  892. field.clean([1, None], None)
  893. self.assertEqual(cm.exception.code, "item_invalid")
  894. self.assertEqual(
  895. cm.exception.message % cm.exception.params,
  896. "Item 2 in the array did not validate: This field cannot be null.",
  897. )
  898. def test_blank_true(self):
  899. field = ArrayField(models.IntegerField(blank=True, null=True))
  900. # This should not raise a validation error
  901. field.clean([1, None], None)
  902. def test_with_size(self):
  903. field = ArrayField(models.IntegerField(), size=3)
  904. field.clean([1, 2, 3], None)
  905. with self.assertRaises(exceptions.ValidationError) as cm:
  906. field.clean([1, 2, 3, 4], None)
  907. self.assertEqual(
  908. cm.exception.messages[0],
  909. "List contains 4 items, it should contain no more than 3.",
  910. )
  911. def test_with_size_singular(self):
  912. field = ArrayField(models.IntegerField(), size=1)
  913. field.clean([1], None)
  914. msg = "List contains 2 items, it should contain no more than 1."
  915. with self.assertRaisesMessage(exceptions.ValidationError, msg):
  916. field.clean([1, 2], None)
  917. def test_nested_array_mismatch(self):
  918. field = ArrayField(ArrayField(models.IntegerField()))
  919. field.clean([[1, 2], [3, 4]], None)
  920. with self.assertRaises(exceptions.ValidationError) as cm:
  921. field.clean([[1, 2], [3, 4, 5]], None)
  922. self.assertEqual(cm.exception.code, "nested_array_mismatch")
  923. self.assertEqual(
  924. cm.exception.messages[0], "Nested arrays must have the same length."
  925. )
  926. def test_with_base_field_error_params(self):
  927. field = ArrayField(models.CharField(max_length=2))
  928. with self.assertRaises(exceptions.ValidationError) as cm:
  929. field.clean(["abc"], None)
  930. self.assertEqual(len(cm.exception.error_list), 1)
  931. exception = cm.exception.error_list[0]
  932. self.assertEqual(
  933. exception.message,
  934. "Item 1 in the array did not validate: Ensure this value has at most 2 "
  935. "characters (it has 3).",
  936. )
  937. self.assertEqual(exception.code, "item_invalid")
  938. self.assertEqual(
  939. exception.params,
  940. {"nth": 1, "value": "abc", "limit_value": 2, "show_value": 3},
  941. )
  942. def test_with_validators(self):
  943. field = ArrayField(
  944. models.IntegerField(validators=[validators.MinValueValidator(1)])
  945. )
  946. field.clean([1, 2], None)
  947. with self.assertRaises(exceptions.ValidationError) as cm:
  948. field.clean([0], None)
  949. self.assertEqual(len(cm.exception.error_list), 1)
  950. exception = cm.exception.error_list[0]
  951. self.assertEqual(
  952. exception.message,
  953. "Item 1 in the array did not validate: Ensure this value is greater than "
  954. "or equal to 1.",
  955. )
  956. self.assertEqual(exception.code, "item_invalid")
  957. self.assertEqual(
  958. exception.params, {"nth": 1, "value": 0, "limit_value": 1, "show_value": 0}
  959. )
  960. class TestSimpleFormField(PostgreSQLSimpleTestCase):
  961. def test_valid(self):
  962. field = SimpleArrayField(forms.CharField())
  963. value = field.clean("a,b,c")
  964. self.assertEqual(value, ["a", "b", "c"])
  965. def test_to_python_fail(self):
  966. field = SimpleArrayField(forms.IntegerField())
  967. with self.assertRaises(exceptions.ValidationError) as cm:
  968. field.clean("a,b,9")
  969. self.assertEqual(
  970. cm.exception.messages[0],
  971. "Item 1 in the array did not validate: Enter a whole number.",
  972. )
  973. def test_validate_fail(self):
  974. field = SimpleArrayField(forms.CharField(required=True))
  975. with self.assertRaises(exceptions.ValidationError) as cm:
  976. field.clean("a,b,")
  977. self.assertEqual(
  978. cm.exception.messages[0],
  979. "Item 3 in the array did not validate: This field is required.",
  980. )
  981. def test_validate_fail_base_field_error_params(self):
  982. field = SimpleArrayField(forms.CharField(max_length=2))
  983. with self.assertRaises(exceptions.ValidationError) as cm:
  984. field.clean("abc,c,defg")
  985. errors = cm.exception.error_list
  986. self.assertEqual(len(errors), 2)
  987. first_error = errors[0]
  988. self.assertEqual(
  989. first_error.message,
  990. "Item 1 in the array did not validate: Ensure this value has at most 2 "
  991. "characters (it has 3).",
  992. )
  993. self.assertEqual(first_error.code, "item_invalid")
  994. self.assertEqual(
  995. first_error.params,
  996. {"nth": 1, "value": "abc", "limit_value": 2, "show_value": 3},
  997. )
  998. second_error = errors[1]
  999. self.assertEqual(
  1000. second_error.message,
  1001. "Item 3 in the array did not validate: Ensure this value has at most 2 "
  1002. "characters (it has 4).",
  1003. )
  1004. self.assertEqual(second_error.code, "item_invalid")
  1005. self.assertEqual(
  1006. second_error.params,
  1007. {"nth": 3, "value": "defg", "limit_value": 2, "show_value": 4},
  1008. )
  1009. def test_validators_fail(self):
  1010. field = SimpleArrayField(forms.RegexField("[a-e]{2}"))
  1011. with self.assertRaises(exceptions.ValidationError) as cm:
  1012. field.clean("a,bc,de")
  1013. self.assertEqual(
  1014. cm.exception.messages[0],
  1015. "Item 1 in the array did not validate: Enter a valid value.",
  1016. )
  1017. def test_delimiter(self):
  1018. field = SimpleArrayField(forms.CharField(), delimiter="|")
  1019. value = field.clean("a|b|c")
  1020. self.assertEqual(value, ["a", "b", "c"])
  1021. def test_delimiter_with_nesting(self):
  1022. field = SimpleArrayField(SimpleArrayField(forms.CharField()), delimiter="|")
  1023. value = field.clean("a,b|c,d")
  1024. self.assertEqual(value, [["a", "b"], ["c", "d"]])
  1025. def test_prepare_value(self):
  1026. field = SimpleArrayField(forms.CharField())
  1027. value = field.prepare_value(["a", "b", "c"])
  1028. self.assertEqual(value, "a,b,c")
  1029. def test_max_length(self):
  1030. field = SimpleArrayField(forms.CharField(), max_length=2)
  1031. with self.assertRaises(exceptions.ValidationError) as cm:
  1032. field.clean("a,b,c")
  1033. self.assertEqual(
  1034. cm.exception.messages[0],
  1035. "List contains 3 items, it should contain no more than 2.",
  1036. )
  1037. def test_min_length(self):
  1038. field = SimpleArrayField(forms.CharField(), min_length=4)
  1039. with self.assertRaises(exceptions.ValidationError) as cm:
  1040. field.clean("a,b,c")
  1041. self.assertEqual(
  1042. cm.exception.messages[0],
  1043. "List contains 3 items, it should contain no fewer than 4.",
  1044. )
  1045. def test_min_length_singular(self):
  1046. field = SimpleArrayField(forms.IntegerField(), min_length=2)
  1047. field.clean([1, 2])
  1048. msg = "List contains 1 item, it should contain no fewer than 2."
  1049. with self.assertRaisesMessage(exceptions.ValidationError, msg):
  1050. field.clean([1])
  1051. def test_required(self):
  1052. field = SimpleArrayField(forms.CharField(), required=True)
  1053. with self.assertRaises(exceptions.ValidationError) as cm:
  1054. field.clean("")
  1055. self.assertEqual(cm.exception.messages[0], "This field is required.")
  1056. def test_model_field_formfield(self):
  1057. model_field = ArrayField(models.CharField(max_length=27))
  1058. form_field = model_field.formfield()
  1059. self.assertIsInstance(form_field, SimpleArrayField)
  1060. self.assertIsInstance(form_field.base_field, forms.CharField)
  1061. self.assertEqual(form_field.base_field.max_length, 27)
  1062. def test_model_field_formfield_size(self):
  1063. model_field = ArrayField(models.CharField(max_length=27), size=4)
  1064. form_field = model_field.formfield()
  1065. self.assertIsInstance(form_field, SimpleArrayField)
  1066. self.assertEqual(form_field.max_length, 4)
  1067. def test_model_field_choices(self):
  1068. model_field = ArrayField(models.IntegerField(choices=((1, "A"), (2, "B"))))
  1069. form_field = model_field.formfield()
  1070. self.assertEqual(form_field.clean("1,2"), [1, 2])
  1071. def test_already_converted_value(self):
  1072. field = SimpleArrayField(forms.CharField())
  1073. vals = ["a", "b", "c"]
  1074. self.assertEqual(field.clean(vals), vals)
  1075. def test_has_changed(self):
  1076. field = SimpleArrayField(forms.IntegerField())
  1077. self.assertIs(field.has_changed([1, 2], [1, 2]), False)
  1078. self.assertIs(field.has_changed([1, 2], "1,2"), False)
  1079. self.assertIs(field.has_changed([1, 2], "1,2,3"), True)
  1080. self.assertIs(field.has_changed([1, 2], "a,b"), True)
  1081. def test_has_changed_empty(self):
  1082. field = SimpleArrayField(forms.CharField())
  1083. self.assertIs(field.has_changed(None, None), False)
  1084. self.assertIs(field.has_changed(None, ""), False)
  1085. self.assertIs(field.has_changed(None, []), False)
  1086. self.assertIs(field.has_changed([], None), False)
  1087. self.assertIs(field.has_changed([], ""), False)
  1088. class TestSplitFormField(PostgreSQLSimpleTestCase):
  1089. def test_valid(self):
  1090. class SplitForm(forms.Form):
  1091. array = SplitArrayField(forms.CharField(), size=3)
  1092. data = {"array_0": "a", "array_1": "b", "array_2": "c"}
  1093. form = SplitForm(data)
  1094. self.assertTrue(form.is_valid())
  1095. self.assertEqual(form.cleaned_data, {"array": ["a", "b", "c"]})
  1096. def test_required(self):
  1097. class SplitForm(forms.Form):
  1098. array = SplitArrayField(forms.CharField(), required=True, size=3)
  1099. data = {"array_0": "", "array_1": "", "array_2": ""}
  1100. form = SplitForm(data)
  1101. self.assertFalse(form.is_valid())
  1102. self.assertEqual(form.errors, {"array": ["This field is required."]})
  1103. def test_remove_trailing_nulls(self):
  1104. class SplitForm(forms.Form):
  1105. array = SplitArrayField(
  1106. forms.CharField(required=False), size=5, remove_trailing_nulls=True
  1107. )
  1108. data = {
  1109. "array_0": "a",
  1110. "array_1": "",
  1111. "array_2": "b",
  1112. "array_3": "",
  1113. "array_4": "",
  1114. }
  1115. form = SplitForm(data)
  1116. self.assertTrue(form.is_valid(), form.errors)
  1117. self.assertEqual(form.cleaned_data, {"array": ["a", "", "b"]})
  1118. def test_remove_trailing_nulls_not_required(self):
  1119. class SplitForm(forms.Form):
  1120. array = SplitArrayField(
  1121. forms.CharField(required=False),
  1122. size=2,
  1123. remove_trailing_nulls=True,
  1124. required=False,
  1125. )
  1126. data = {"array_0": "", "array_1": ""}
  1127. form = SplitForm(data)
  1128. self.assertTrue(form.is_valid())
  1129. self.assertEqual(form.cleaned_data, {"array": []})
  1130. def test_required_field(self):
  1131. class SplitForm(forms.Form):
  1132. array = SplitArrayField(forms.CharField(), size=3)
  1133. data = {"array_0": "a", "array_1": "b", "array_2": ""}
  1134. form = SplitForm(data)
  1135. self.assertFalse(form.is_valid())
  1136. self.assertEqual(
  1137. form.errors,
  1138. {
  1139. "array": [
  1140. "Item 3 in the array did not validate: This field is required."
  1141. ]
  1142. },
  1143. )
  1144. def test_invalid_integer(self):
  1145. msg = (
  1146. "Item 2 in the array did not validate: Ensure this value is less than or "
  1147. "equal to 100."
  1148. )
  1149. with self.assertRaisesMessage(exceptions.ValidationError, msg):
  1150. SplitArrayField(forms.IntegerField(max_value=100), size=2).clean([0, 101])
  1151. def test_rendering(self):
  1152. class SplitForm(forms.Form):
  1153. array = SplitArrayField(forms.CharField(), size=3)
  1154. self.assertHTMLEqual(
  1155. str(SplitForm()),
  1156. """
  1157. <div>
  1158. <label for="id_array_0">Array:</label>
  1159. <input id="id_array_0" name="array_0" type="text" required>
  1160. <input id="id_array_1" name="array_1" type="text" required>
  1161. <input id="id_array_2" name="array_2" type="text" required>
  1162. </div>
  1163. """,
  1164. )
  1165. def test_invalid_char_length(self):
  1166. field = SplitArrayField(forms.CharField(max_length=2), size=3)
  1167. with self.assertRaises(exceptions.ValidationError) as cm:
  1168. field.clean(["abc", "c", "defg"])
  1169. self.assertEqual(
  1170. cm.exception.messages,
  1171. [
  1172. "Item 1 in the array did not validate: Ensure this value has at most 2 "
  1173. "characters (it has 3).",
  1174. "Item 3 in the array did not validate: Ensure this value has at most 2 "
  1175. "characters (it has 4).",
  1176. ],
  1177. )
  1178. def test_splitarraywidget_value_omitted_from_data(self):
  1179. class Form(forms.ModelForm):
  1180. field = SplitArrayField(forms.IntegerField(), required=False, size=2)
  1181. class Meta:
  1182. model = IntegerArrayModel
  1183. fields = ("field",)
  1184. form = Form({"field_0": "1", "field_1": "2"})
  1185. self.assertEqual(form.errors, {})
  1186. obj = form.save(commit=False)
  1187. self.assertEqual(obj.field, [1, 2])
  1188. def test_splitarrayfield_has_changed(self):
  1189. class Form(forms.ModelForm):
  1190. field = SplitArrayField(forms.IntegerField(), required=False, size=2)
  1191. class Meta:
  1192. model = IntegerArrayModel
  1193. fields = ("field",)
  1194. tests = [
  1195. ({}, {"field_0": "", "field_1": ""}, True),
  1196. ({"field": None}, {"field_0": "", "field_1": ""}, True),
  1197. ({"field": [1]}, {"field_0": "", "field_1": ""}, True),
  1198. ({"field": [1]}, {"field_0": "1", "field_1": "0"}, True),
  1199. ({"field": [1, 2]}, {"field_0": "1", "field_1": "2"}, False),
  1200. ({"field": [1, 2]}, {"field_0": "a", "field_1": "b"}, True),
  1201. ]
  1202. for initial, data, expected_result in tests:
  1203. with self.subTest(initial=initial, data=data):
  1204. obj = IntegerArrayModel(**initial)
  1205. form = Form(data, instance=obj)
  1206. self.assertIs(form.has_changed(), expected_result)
  1207. def test_splitarrayfield_remove_trailing_nulls_has_changed(self):
  1208. class Form(forms.ModelForm):
  1209. field = SplitArrayField(
  1210. forms.IntegerField(), required=False, size=2, remove_trailing_nulls=True
  1211. )
  1212. class Meta:
  1213. model = IntegerArrayModel
  1214. fields = ("field",)
  1215. tests = [
  1216. ({}, {"field_0": "", "field_1": ""}, False),
  1217. ({"field": None}, {"field_0": "", "field_1": ""}, False),
  1218. ({"field": []}, {"field_0": "", "field_1": ""}, False),
  1219. ({"field": [1]}, {"field_0": "1", "field_1": ""}, False),
  1220. ]
  1221. for initial, data, expected_result in tests:
  1222. with self.subTest(initial=initial, data=data):
  1223. obj = IntegerArrayModel(**initial)
  1224. form = Form(data, instance=obj)
  1225. self.assertIs(form.has_changed(), expected_result)
  1226. class TestSplitFormWidget(PostgreSQLWidgetTestCase):
  1227. def test_get_context(self):
  1228. self.assertEqual(
  1229. SplitArrayWidget(forms.TextInput(), size=2).get_context(
  1230. "name", ["val1", "val2"]
  1231. ),
  1232. {
  1233. "widget": {
  1234. "name": "name",
  1235. "is_hidden": False,
  1236. "required": False,
  1237. "value": "['val1', 'val2']",
  1238. "attrs": {},
  1239. "template_name": "postgres/widgets/split_array.html",
  1240. "subwidgets": [
  1241. {
  1242. "name": "name_0",
  1243. "is_hidden": False,
  1244. "required": False,
  1245. "value": "val1",
  1246. "attrs": {},
  1247. "template_name": "django/forms/widgets/text.html",
  1248. "type": "text",
  1249. },
  1250. {
  1251. "name": "name_1",
  1252. "is_hidden": False,
  1253. "required": False,
  1254. "value": "val2",
  1255. "attrs": {},
  1256. "template_name": "django/forms/widgets/text.html",
  1257. "type": "text",
  1258. },
  1259. ],
  1260. }
  1261. },
  1262. )
  1263. def test_checkbox_get_context_attrs(self):
  1264. context = SplitArrayWidget(
  1265. forms.CheckboxInput(),
  1266. size=2,
  1267. ).get_context("name", [True, False])
  1268. self.assertEqual(context["widget"]["value"], "[True, False]")
  1269. self.assertEqual(
  1270. [subwidget["attrs"] for subwidget in context["widget"]["subwidgets"]],
  1271. [{"checked": True}, {}],
  1272. )
  1273. def test_render(self):
  1274. self.check_html(
  1275. SplitArrayWidget(forms.TextInput(), size=2),
  1276. "array",
  1277. None,
  1278. """
  1279. <input name="array_0" type="text">
  1280. <input name="array_1" type="text">
  1281. """,
  1282. )
  1283. def test_render_attrs(self):
  1284. self.check_html(
  1285. SplitArrayWidget(forms.TextInput(), size=2),
  1286. "array",
  1287. ["val1", "val2"],
  1288. attrs={"id": "foo"},
  1289. html=(
  1290. """
  1291. <input id="foo_0" name="array_0" type="text" value="val1">
  1292. <input id="foo_1" name="array_1" type="text" value="val2">
  1293. """
  1294. ),
  1295. )
  1296. def test_value_omitted_from_data(self):
  1297. widget = SplitArrayWidget(forms.TextInput(), size=2)
  1298. self.assertIs(widget.value_omitted_from_data({}, {}, "field"), True)
  1299. self.assertIs(
  1300. widget.value_omitted_from_data({"field_0": "value"}, {}, "field"), False
  1301. )
  1302. self.assertIs(
  1303. widget.value_omitted_from_data({"field_1": "value"}, {}, "field"), False
  1304. )
  1305. self.assertIs(
  1306. widget.value_omitted_from_data(
  1307. {"field_0": "value", "field_1": "value"}, {}, "field"
  1308. ),
  1309. False,
  1310. )
  1311. class TestAdminUtils(PostgreSQLTestCase):
  1312. empty_value = "-empty-"
  1313. def test_array_display_for_field(self):
  1314. array_field = ArrayField(models.IntegerField())
  1315. display_value = display_for_field(
  1316. [1, 2],
  1317. array_field,
  1318. self.empty_value,
  1319. )
  1320. self.assertEqual(display_value, "1, 2")
  1321. def test_array_with_choices_display_for_field(self):
  1322. array_field = ArrayField(
  1323. models.IntegerField(),
  1324. choices=[
  1325. ([1, 2, 3], "1st choice"),
  1326. ([1, 2], "2nd choice"),
  1327. ],
  1328. )
  1329. display_value = display_for_field(
  1330. [1, 2],
  1331. array_field,
  1332. self.empty_value,
  1333. )
  1334. self.assertEqual(display_value, "2nd choice")
  1335. display_value = display_for_field(
  1336. [99, 99],
  1337. array_field,
  1338. self.empty_value,
  1339. )
  1340. self.assertEqual(display_value, self.empty_value)