test_array.py 52 KB


  1. import decimal
  2. import enum
  3. import json
  4. import unittest
  5. import uuid
  6. from django import forms
  7. from django.contrib.admin.utils import display_for_field
  8. from django.core import checks, exceptions, serializers, validators
  9. from django.core.exceptions import FieldError
  10. from django.core.management import call_command
  11. from django.db import IntegrityError, connection, models
  12. from django.db.models.expressions import Exists, OuterRef, RawSQL, Value
  13. from django.db.models.functions import Cast, JSONObject, Upper
  14. from django.test import TransactionTestCase, override_settings, skipUnlessDBFeature
  15. from django.test.utils import isolate_apps
  16. from django.utils import timezone
  17. from . import PostgreSQLSimpleTestCase, PostgreSQLTestCase, PostgreSQLWidgetTestCase
  18. from .models import (
  19. ArrayEnumModel,
  20. ArrayFieldSubclass,
  21. CharArrayModel,
  22. DateTimeArrayModel,
  23. IntegerArrayModel,
  24. NestedIntegerArrayModel,
  25. NullableIntegerArrayModel,
  26. OtherTypesArrayModel,
  27. PostgreSQLModel,
  28. Tag,
  29. )
  30. try:
  31. from django.contrib.postgres.aggregates import ArrayAgg
  32. from django.contrib.postgres.expressions import ArraySubquery
  33. from django.contrib.postgres.fields import ArrayField
  34. from django.contrib.postgres.fields.array import IndexTransform, SliceTransform
  35. from django.contrib.postgres.forms import (
  36. SimpleArrayField,
  37. SplitArrayField,
  38. SplitArrayWidget,
  39. )
  40. from django.db.backends.postgresql.psycopg_any import NumericRange
  41. except ImportError:
  42. pass
  43. @isolate_apps("postgres_tests")
  44. class BasicTests(PostgreSQLSimpleTestCase):
  45. def test_get_field_display(self):
  46. class MyModel(PostgreSQLModel):
  47. field = ArrayField(
  48. models.CharField(max_length=16),
  49. choices=[
  50. ["Media", [(["vinyl", "cd"], "Audio")]],
  51. (("mp3", "mp4"), "Digital"),
  52. ],
  53. )
  54. tests = (
  55. (["vinyl", "cd"], "Audio"),
  56. (("mp3", "mp4"), "Digital"),
  57. (("a", "b"), "('a', 'b')"),
  58. (["c", "d"], "['c', 'd']"),
  59. )
  60. for value, display in tests:
  61. with self.subTest(value=value, display=display):
  62. instance = MyModel(field=value)
  63. self.assertEqual(instance.get_field_display(), display)
  64. def test_get_field_display_nested_array(self):
  65. class MyModel(PostgreSQLModel):
  66. field = ArrayField(
  67. ArrayField(models.CharField(max_length=16)),
  68. choices=[
  69. [
  70. "Media",
  71. [([["vinyl", "cd"], ("x",)], "Audio")],
  72. ],
  73. ((["mp3"], ("mp4",)), "Digital"),
  74. ],
  75. )
  76. tests = (
  77. ([["vinyl", "cd"], ("x",)], "Audio"),
  78. ((["mp3"], ("mp4",)), "Digital"),
  79. ((("a", "b"), ("c",)), "(('a', 'b'), ('c',))"),
  80. ([["a", "b"], ["c"]], "[['a', 'b'], ['c']]"),
  81. )
  82. for value, display in tests:
  83. with self.subTest(value=value, display=display):
  84. instance = MyModel(field=value)
  85. self.assertEqual(instance.get_field_display(), display)
  86. class TestSaveLoad(PostgreSQLTestCase):
  87. def test_integer(self):
  88. instance = IntegerArrayModel(field=[1, 2, 3])
  89. instance.save()
  90. loaded = IntegerArrayModel.objects.get()
  91. self.assertEqual(instance.field, loaded.field)
  92. def test_char(self):
  93. instance = CharArrayModel(field=["hello", "goodbye"])
  94. instance.save()
  95. loaded = CharArrayModel.objects.get()
  96. self.assertEqual(instance.field, loaded.field)
  97. def test_dates(self):
  98. instance = DateTimeArrayModel(
  99. datetimes=[timezone.now()],
  100. dates=[timezone.now().date()],
  101. times=[timezone.now().time()],
  102. )
  103. instance.save()
  104. loaded = DateTimeArrayModel.objects.get()
  105. self.assertEqual(instance.datetimes, loaded.datetimes)
  106. self.assertEqual(instance.dates, loaded.dates)
  107. self.assertEqual(instance.times, loaded.times)
  108. def test_tuples(self):
  109. instance = IntegerArrayModel(field=(1,))
  110. instance.save()
  111. loaded = IntegerArrayModel.objects.get()
  112. self.assertSequenceEqual(instance.field, loaded.field)
  113. def test_integers_passed_as_strings(self):
  114. # This checks that get_prep_value is deferred properly
  115. instance = IntegerArrayModel(field=["1"])
  116. instance.save()
  117. loaded = IntegerArrayModel.objects.get()
  118. self.assertEqual(loaded.field, [1])
  119. def test_default_null(self):
  120. instance = NullableIntegerArrayModel()
  121. instance.save()
  122. loaded = NullableIntegerArrayModel.objects.get(pk=instance.pk)
  123. self.assertIsNone(loaded.field)
  124. self.assertEqual(instance.field, loaded.field)
  125. def test_null_handling(self):
  126. instance = NullableIntegerArrayModel(field=None)
  127. instance.save()
  128. loaded = NullableIntegerArrayModel.objects.get()
  129. self.assertEqual(instance.field, loaded.field)
  130. instance = IntegerArrayModel(field=None)
  131. with self.assertRaises(IntegrityError):
  132. instance.save()
  133. def test_nested(self):
  134. instance = NestedIntegerArrayModel(field=[[1, 2], [3, 4]])
  135. instance.save()
  136. loaded = NestedIntegerArrayModel.objects.get()
  137. self.assertEqual(instance.field, loaded.field)
  138. def test_other_array_types(self):
  139. instance = OtherTypesArrayModel(
  140. ips=["192.168.0.1", "::1"],
  141. uuids=[uuid.uuid4()],
  142. decimals=[decimal.Decimal(1.25), 1.75],
  143. tags=[Tag(1), Tag(2), Tag(3)],
  144. json=[{"a": 1}, {"b": 2}],
  145. int_ranges=[NumericRange(10, 20), NumericRange(30, 40)],
  146. bigint_ranges=[
  147. NumericRange(7000000000, 10000000000),
  148. NumericRange(50000000000, 70000000000),
  149. ],
  150. )
  151. instance.save()
  152. loaded = OtherTypesArrayModel.objects.get()
  153. self.assertEqual(instance.ips, loaded.ips)
  154. self.assertEqual(instance.uuids, loaded.uuids)
  155. self.assertEqual(instance.decimals, loaded.decimals)
  156. self.assertEqual(instance.tags, loaded.tags)
  157. self.assertEqual(instance.json, loaded.json)
  158. self.assertEqual(instance.int_ranges, loaded.int_ranges)
  159. self.assertEqual(instance.bigint_ranges, loaded.bigint_ranges)
  160. def test_null_from_db_value_handling(self):
  161. instance = OtherTypesArrayModel.objects.create(
  162. ips=["192.168.0.1", "::1"],
  163. uuids=[uuid.uuid4()],
  164. decimals=[decimal.Decimal(1.25), 1.75],
  165. tags=None,
  166. )
  167. instance.refresh_from_db()
  168. self.assertIsNone(instance.tags)
  169. self.assertEqual(instance.json, [])
  170. self.assertIsNone(instance.int_ranges)
  171. self.assertIsNone(instance.bigint_ranges)
  172. def test_model_set_on_base_field(self):
  173. instance = IntegerArrayModel()
  174. field = instance._meta.get_field("field")
  175. self.assertEqual(field.model, IntegerArrayModel)
  176. self.assertEqual(field.base_field.model, IntegerArrayModel)
  177. def test_nested_nullable_base_field(self):
  178. instance = NullableIntegerArrayModel.objects.create(
  179. field_nested=[[None, None], [None, None]],
  180. )
  181. self.assertEqual(instance.field_nested, [[None, None], [None, None]])
class TestQuerying(PostgreSQLTestCase):
    """Lookups, transforms and subqueries on ArrayField columns.

    Most tests compare querysets against slices of the five fixture rows
    created in setUpTestData.
    """

    @classmethod
    def setUpTestData(cls):
        # Fixture rows, indexed as cls.objs[0..4]:
        #   0: order=1, field=[1]
        #   1: order=2, field=[2]
        #   2: order=3, field=[2, 3]
        #   3: order=4, field=[20, 30, 40]
        #   4: order=5, field=None
        # NOTE(review): several assertSequenceEqual checks below run on
        # querysets without order_by(). They assume rows come back in
        # insertion order, which is only guaranteed if the model declares a
        # default ordering. Confirm in models.py.
        cls.objs = NullableIntegerArrayModel.objects.bulk_create(
            [
                NullableIntegerArrayModel(order=1, field=[1]),
                NullableIntegerArrayModel(order=2, field=[2]),
                NullableIntegerArrayModel(order=3, field=[2, 3]),
                NullableIntegerArrayModel(order=4, field=[20, 30, 40]),
                NullableIntegerArrayModel(order=5, field=None),
            ]
        )

    def test_empty_list(self):
        # Compare the column with an empty-list Value annotated with an explicit
        # ArrayField(IntegerField()) output_field. Only the extra row created
        # here has field=[].
        NullableIntegerArrayModel.objects.create(field=[])
        obj = (
            NullableIntegerArrayModel.objects.annotate(
                empty_array=models.Value(
                    [], output_field=ArrayField(models.IntegerField())
                ),
            )
            .filter(field=models.F("empty_array"))
            .get()
        )
        self.assertEqual(obj.field, [])
        self.assertEqual(obj.empty_array, [])

    def test_exact(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__exact=[1]), self.objs[:1]
        )

    def test_exact_null_only_array(self):
        # Arrays made only of NULLs must still match exactly.
        obj = NullableIntegerArrayModel.objects.create(
            field=[None], field_nested=[None, None]
        )
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__exact=[None]), [obj]
        )
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field_nested__exact=[None, None]),
            [obj],
        )

    def test_exact_null_only_nested_array(self):
        # The two rows differ only in outer length, and each must match only
        # itself.
        obj1 = NullableIntegerArrayModel.objects.create(field_nested=[[None, None]])
        obj2 = NullableIntegerArrayModel.objects.create(
            field_nested=[[None, None], [None, None]],
        )
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field_nested__exact=[[None, None]],
            ),
            [obj1],
        )
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field_nested__exact=[[None, None], [None, None]],
            ),
            [obj2],
        )

    def test_exact_with_expression(self):
        # List items may be expressions, not just literals.
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__exact=[Value(1)]),
            self.objs[:1],
        )

    def test_exact_charfield(self):
        instance = CharArrayModel.objects.create(field=["text"])
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field=["text"]), [instance]
        )

    def test_exact_nested(self):
        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
        self.assertSequenceEqual(
            NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]), [instance]
        )

    def test_isnull(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__isnull=True), self.objs[-1:]
        )

    def test_gt(self):
        # Array comparison is element-wise/lexicographic. Every non-NULL
        # fixture array is greater than [0]; the NULL row never matches.
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__gt=[0]), self.objs[:4]
        )

    def test_lt(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__lt=[2]), self.objs[:1]
        )

    def test_in(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]),
            self.objs[:2],
        )

    def test_in_subquery(self):
        # The subquery yields the single array [2, 3], so only objs[2] matches.
        IntegerArrayModel.objects.create(field=[2, 3])
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field__in=IntegerArrayModel.objects.values_list("field", flat=True)
            ),
            self.objs[2:3],
        )

    @unittest.expectedFailure
    def test_in_including_F_object(self):
        # This test asserts that Array objects passed to filters can be
        # constructed to contain F objects. This currently doesn't work as the
        # psycopg mogrify method that generates the ARRAY() syntax is
        # expecting literals, not column references (#27095).
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__in=[[models.F("id")]]),
            self.objs[:2],
        )

    def test_in_as_F_object(self):
        # Each non-NULL row is "in" a list holding its own value; the NULL
        # row never compares equal.
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__in=[models.F("field")]),
            self.objs[:4],
        )

    def test_contained_by(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]),
            self.objs[:2],
        )

    def test_contained_by_including_F_object(self):
        # Per row, the right-hand side is [order, 2]:
        # [1] <@ [1, 2], [2] <@ [2, 2], [2, 3] <@ [3, 2], and [20, 30, 40] is
        # not contained in [4, 2].
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field__contained_by=[models.F("order"), 2]
            ),
            self.objs[:3],
        )

    def test_contains(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__contains=[2]),
            self.objs[1:3],
        )

    def test_contains_subquery(self):
        IntegerArrayModel.objects.create(field=[2, 3])
        # Right-hand side is a sliced subquery returning the array [2, 3].
        inner_qs = IntegerArrayModel.objects.values_list("field", flat=True)
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__contains=inner_qs[:1]),
            self.objs[2:3],
        )
        # Reversed direction: outer rows whose field is contained in [2, 3],
        # i.e. [2] and [2, 3].
        inner_qs = IntegerArrayModel.objects.filter(field__contains=OuterRef("field"))
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(Exists(inner_qs)),
            self.objs[1:3],
        )

    def test_contains_including_expression(self):
        # Value(6) / Value(2) evaluates to 3, so the lookup is contains [2, 3].
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field__contains=[2, Value(6) / Value(2)],
            ),
            self.objs[2:3],
        )

    def test_icontains(self):
        # Using the __icontains lookup with ArrayField is inefficient.
        instance = CharArrayModel.objects.create(field=["FoO"])
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field__icontains="foo"), [instance]
        )

    def test_contains_charfield(self):
        # Regression for #22907
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field__contains=["text"]), []
        )

    def test_contained_by_charfield(self):
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field__contained_by=["text"]), []
        )

    def test_overlap_charfield(self):
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field__overlap=["text"]), []
        )

    def test_overlap_charfield_including_expression(self):
        # Upper(Value("text")) is "TEXT". Only the rows holding the uppercase
        # string overlap; the third row has just the lowercase "text".
        obj_1 = CharArrayModel.objects.create(field=["TEXT", "lower text"])
        obj_2 = CharArrayModel.objects.create(field=["lower text", "TEXT"])
        CharArrayModel.objects.create(field=["lower text", "text"])
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(
                field__overlap=[
                    Upper(Value("text")),
                    "other",
                ]
            ),
            [obj_1, obj_2],
        )

    def test_overlap_values(self):
        # qs selects the arrays [1] and [2]. Both values() and values_list()
        # forms must be accepted as the right-hand side of __overlap.
        qs = NullableIntegerArrayModel.objects.filter(order__lt=3)
        self.assertCountEqual(
            NullableIntegerArrayModel.objects.filter(
                field__overlap=qs.values_list("field"),
            ),
            self.objs[:3],
        )
        self.assertCountEqual(
            NullableIntegerArrayModel.objects.filter(
                field__overlap=qs.values("field"),
            ),
            self.objs[:3],
        )

    def test_lookups_autofield_array(self):
        # Group by the first element. ArrayAgg collects pks per group:
        # 1 -> [objs[0]], 2 -> [objs[1], objs[2]], 20 -> [objs[3]].
        qs = (
            NullableIntegerArrayModel.objects.filter(
                field__0__isnull=False,
            )
            .values("field__0")
            .annotate(
                arrayagg=ArrayAgg("id"),
            )
            .order_by("field__0")
        )
        tests = (
            ("contained_by", [self.objs[1].pk, self.objs[2].pk, 0], [2]),
            ("contains", [self.objs[2].pk], [2]),
            ("exact", [self.objs[3].pk], [20]),
            ("overlap", [self.objs[1].pk, self.objs[3].pk], [2, 20]),
        )
        for lookup, value, expected in tests:
            with self.subTest(lookup=lookup):
                self.assertSequenceEqual(
                    qs.filter(
                        **{"arrayagg__" + lookup: value},
                    ).values_list("field__0", flat=True),
                    expected,
                )

    @skipUnlessDBFeature("allows_group_by_refs")
    def test_group_by_order_by_aliases(self):
        with self.assertNumQueries(1) as ctx:
            self.assertSequenceEqual(
                NullableIntegerArrayModel.objects.filter(
                    field__0__isnull=False,
                )
                .values("field__0")
                .annotate(arrayagg=ArrayAgg("id"))
                .order_by("field__0"),
                [
                    {"field__0": 1, "arrayagg": [self.objs[0].pk]},
                    {"field__0": 2, "arrayagg": [self.objs[1].pk, self.objs[2].pk]},
                    {"field__0": 20, "arrayagg": [self.objs[3].pk]},
                ],
            )
        # GROUP BY and ORDER BY must refer to the select alias rather than
        # repeat the index expression.
        alias = connection.ops.quote_name("field__0")
        sql = ctx[0]["sql"]
        self.assertIn(f"GROUP BY {alias}", sql)
        self.assertIn(f"ORDER BY {alias}", sql)

    def test_index(self):
        # ORM indexes are 0-based: field__0 is the first element.
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3]
        )

    def test_index_chained(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__0__lt=3), self.objs[0:3]
        )

    def test_index_nested(self):
        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
        self.assertSequenceEqual(
            NestedIntegerArrayModel.objects.filter(field__0__0=1), [instance]
        )

    @unittest.expectedFailure
    def test_index_used_on_nested_data(self):
        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
        self.assertSequenceEqual(
            NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), [instance]
        )

    def test_index_transform_expression(self):
        # IndexTransform's index is used as-is in SQL. Element 1 of {'1','2'}
        # is '1', which after the Cast matches rows whose field__0 is 1.
        expr = RawSQL("string_to_array(%s, ';')", ["1;2"])
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field__0=Cast(
                    IndexTransform(1, models.IntegerField, expr),
                    output_field=models.IntegerField(),
                ),
            ),
            self.objs[:1],
        )

    def test_index_annotation(self):
        # An out-of-range index yields NULL rather than an error.
        qs = NullableIntegerArrayModel.objects.annotate(second=models.F("field__1"))
        self.assertCountEqual(
            qs.values_list("second", flat=True),
            [None, None, None, 3, 30],
        )

    def test_overlap(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]),
            self.objs[0:3],
        )

    def test_len(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__len__lte=2), self.objs[0:3]
        )

    def test_len_empty_array(self):
        # An empty array must report length 0, not NULL.
        obj = NullableIntegerArrayModel.objects.create(field=[])
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__len=0), [obj]
        )

    def test_slice(self):
        # field__<start>_<end> is a 0-based, end-exclusive slice.
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__0_1=[2]), self.objs[1:3]
        )
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]), self.objs[2:3]
        )

    def test_order_by_slice(self):
        # Sort keys (second element): more[2]=-98123, more[1]=1, more[3]=2,
        # objs[2]=3, objs[3]=30, more[0]=637. The remaining rows have no
        # second element (NULL), which PostgreSQL sorts last in ascending
        # order.
        # NOTE(review): the relative order of the three NULL rows is not
        # fixed by the ORDER BY.
        more_objs = (
            NullableIntegerArrayModel.objects.create(field=[1, 637]),
            NullableIntegerArrayModel.objects.create(field=[2, 1]),
            NullableIntegerArrayModel.objects.create(field=[3, -98123]),
            NullableIntegerArrayModel.objects.create(field=[4, 2]),
        )
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.order_by("field__1"),
            [
                more_objs[2],
                more_objs[1],
                more_objs[3],
                self.objs[2],
                self.objs[3],
                more_objs[0],
                self.objs[4],
                self.objs[1],
                self.objs[0],
            ],
        )

    @unittest.expectedFailure
    def test_slice_nested(self):
        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
        self.assertSequenceEqual(
            NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]), [instance]
        )

    def test_slice_transform_expression(self):
        # SliceTransform(2, 3, ...) selects SQL elements 2..3 of {'9','2','3'},
        # i.e. ['2', '3'], matching rows whose first two elements are [2, 3].
        expr = RawSQL("string_to_array(%s, ';')", ["9;2;3"])
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field__0_2=SliceTransform(2, 3, expr)
            ),
            self.objs[2:3],
        )

    def test_slice_annotation(self):
        qs = NullableIntegerArrayModel.objects.annotate(
            first_two=models.F("field__0_2"),
        )
        self.assertCountEqual(
            qs.values_list("first_two", flat=True),
            [None, [1], [2], [2, 3], [20, 30]],
        )

    def test_usage_in_subquery(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                id__in=NullableIntegerArrayModel.objects.filter(field__len=3)
            ),
            [self.objs[3]],
        )

    def test_enum_lookup(self):
        class TestEnum(enum.Enum):
            VALUE_1 = "value_1"

        instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1])
        self.assertSequenceEqual(
            ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]),
            [instance],
        )

    def test_unsupported_lookup(self):
        # Both a malformed slice ("0_bar") and a malformed index ("0bar") must
        # raise FieldError.
        msg = (
            "Unsupported lookup '0_bar' for ArrayField or join on the field not "
            "permitted."
        )
        with self.assertRaisesMessage(FieldError, msg):
            list(NullableIntegerArrayModel.objects.filter(field__0_bar=[2]))

        msg = (
            "Unsupported lookup '0bar' for ArrayField or join on the field not "
            "permitted."
        )
        with self.assertRaisesMessage(FieldError, msg):
            list(NullableIntegerArrayModel.objects.filter(field__0bar=[2]))

    def test_grouping_by_annotations_with_array_field_param(self):
        # Grouping by an annotation that takes an array-valued parameter must
        # produce valid SQL.
        value = models.Value([1], output_field=ArrayField(models.IntegerField()))
        self.assertEqual(
            NullableIntegerArrayModel.objects.annotate(
                array_length=models.Func(
                    value,
                    1,
                    function="ARRAY_LENGTH",
                    output_field=models.IntegerField(),
                ),
            )
            .values("array_length")
            .annotate(
                count=models.Count("pk"),
            )
            .get()["array_length"],
            1,
        )

    def test_filter_by_array_subquery(self):
        # For each row, collect the fields of rows with the same length.
        # Only [1] and [2] share a length (1), so only they get more than one
        # entry.
        inner_qs = NullableIntegerArrayModel.objects.filter(
            field__len=models.OuterRef("field__len"),
        ).values("field")
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.alias(
                same_sized_fields=ArraySubquery(inner_qs),
            ).filter(same_sized_fields__len__gt=1),
            self.objs[0:2],
        )

    def test_annotated_array_subquery(self):
        # NOTE(review): inner_qs has no order_by(), so [2, 3, 4, 5] relies on
        # the database returning siblings in insertion order.
        inner_qs = NullableIntegerArrayModel.objects.exclude(
            pk=models.OuterRef("pk")
        ).values("order")
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.annotate(
                sibling_ids=ArraySubquery(inner_qs),
            )
            .get(order=1)
            .sibling_ids,
            [2, 3, 4, 5],
        )

    def test_group_by_with_annotated_array_subquery(self):
        # Every row has len(self.objs) - 1 siblings.
        inner_qs = NullableIntegerArrayModel.objects.exclude(
            pk=models.OuterRef("pk")
        ).values("order")
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.annotate(
                sibling_ids=ArraySubquery(inner_qs),
                sibling_count=models.Max("sibling_ids__len"),
            ).values_list("sibling_count", flat=True),
            [len(self.objs) - 1] * len(self.objs),
        )

    def test_annotated_ordered_array_subquery(self):
        # The inner ordering ("-order") is preserved inside the array.
        inner_qs = NullableIntegerArrayModel.objects.order_by("-order").values("order")
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.annotate(
                ids=ArraySubquery(inner_qs),
            )
            .first()
            .ids,
            [5, 4, 3, 2, 1],
        )

    def test_annotated_array_subquery_with_json_objects(self):
        # NOTE(review): like test_annotated_array_subquery, sibling order relies
        # on insertion order because inner_qs is unordered.
        inner_qs = NullableIntegerArrayModel.objects.exclude(
            pk=models.OuterRef("pk")
        ).values(json=JSONObject(order="order", field="field"))
        siblings_json = (
            NullableIntegerArrayModel.objects.annotate(
                siblings_json=ArraySubquery(inner_qs),
            )
            .values_list("siblings_json", flat=True)
            .get(order=1)
        )
        self.assertSequenceEqual(
            siblings_json,
            [
                {"field": [2], "order": 2},
                {"field": [2, 3], "order": 3},
                {"field": [20, 30, 40], "order": 4},
                {"field": None, "order": 5},
            ],
        )
  630. class TestDateTimeExactQuerying(PostgreSQLTestCase):
  631. @classmethod
  632. def setUpTestData(cls):
  633. now = timezone.now()
  634. cls.datetimes = [now]
  635. cls.dates = [now.date()]
  636. cls.times = [now.time()]
  637. cls.objs = [
  638. DateTimeArrayModel.objects.create(
  639. datetimes=cls.datetimes, dates=cls.dates, times=cls.times
  640. ),
  641. ]
  642. def test_exact_datetimes(self):
  643. self.assertSequenceEqual(
  644. DateTimeArrayModel.objects.filter(datetimes=self.datetimes), self.objs
  645. )
  646. def test_exact_dates(self):
  647. self.assertSequenceEqual(
  648. DateTimeArrayModel.objects.filter(dates=self.dates), self.objs
  649. )
  650. def test_exact_times(self):
  651. self.assertSequenceEqual(
  652. DateTimeArrayModel.objects.filter(times=self.times), self.objs
  653. )
  654. class TestOtherTypesExactQuerying(PostgreSQLTestCase):
  655. @classmethod
  656. def setUpTestData(cls):
  657. cls.ips = ["192.168.0.1", "::1"]
  658. cls.uuids = [uuid.uuid4()]
  659. cls.decimals = [decimal.Decimal(1.25), 1.75]
  660. cls.tags = [Tag(1), Tag(2), Tag(3)]
  661. cls.objs = [
  662. OtherTypesArrayModel.objects.create(
  663. ips=cls.ips,
  664. uuids=cls.uuids,
  665. decimals=cls.decimals,
  666. tags=cls.tags,
  667. )
  668. ]
  669. def test_exact_ip_addresses(self):
  670. self.assertSequenceEqual(
  671. OtherTypesArrayModel.objects.filter(ips=self.ips), self.objs
  672. )
  673. def test_exact_uuids(self):
  674. self.assertSequenceEqual(
  675. OtherTypesArrayModel.objects.filter(uuids=self.uuids), self.objs
  676. )
  677. def test_exact_decimals(self):
  678. self.assertSequenceEqual(
  679. OtherTypesArrayModel.objects.filter(decimals=self.decimals), self.objs
  680. )
  681. def test_exact_tags(self):
  682. self.assertSequenceEqual(
  683. OtherTypesArrayModel.objects.filter(tags=self.tags), self.objs
  684. )
  685. @isolate_apps("postgres_tests")
  686. class TestChecks(PostgreSQLSimpleTestCase):
  687. def test_field_checks(self):
  688. class MyModel(PostgreSQLModel):
  689. field = ArrayField(models.CharField(max_length=-1))
  690. model = MyModel()
  691. errors = model.check()
  692. self.assertEqual(len(errors), 1)
  693. # The inner CharField has a non-positive max_length.
  694. self.assertEqual(errors[0].id, "postgres.E001")
  695. self.assertIn("max_length", errors[0].msg)
  696. def test_invalid_base_fields(self):
  697. class MyModel(PostgreSQLModel):
  698. field = ArrayField(
  699. models.ManyToManyField("postgres_tests.IntegerArrayModel")
  700. )
  701. model = MyModel()
  702. errors = model.check()
  703. self.assertEqual(len(errors), 1)
  704. self.assertEqual(errors[0].id, "postgres.E002")
  705. def test_invalid_default(self):
  706. class MyModel(PostgreSQLModel):
  707. field = ArrayField(models.IntegerField(), default=[])
  708. model = MyModel()
  709. self.assertEqual(
  710. model.check(),
  711. [
  712. checks.Warning(
  713. msg=(
  714. "ArrayField default should be a callable instead of an "
  715. "instance so that it's not shared between all field "
  716. "instances."
  717. ),
  718. hint="Use a callable instead, e.g., use `list` instead of `[]`.",
  719. obj=MyModel._meta.get_field("field"),
  720. id="fields.E010",
  721. )
  722. ],
  723. )
  724. def test_valid_default(self):
  725. class MyModel(PostgreSQLModel):
  726. field = ArrayField(models.IntegerField(), default=list)
  727. model = MyModel()
  728. self.assertEqual(model.check(), [])
  729. def test_valid_default_none(self):
  730. class MyModel(PostgreSQLModel):
  731. field = ArrayField(models.IntegerField(), default=None)
  732. model = MyModel()
  733. self.assertEqual(model.check(), [])
  734. def test_nested_field_checks(self):
  735. """
  736. Nested ArrayFields are permitted.
  737. """
  738. class MyModel(PostgreSQLModel):
  739. field = ArrayField(ArrayField(models.CharField(max_length=-1)))
  740. model = MyModel()
  741. errors = model.check()
  742. self.assertEqual(len(errors), 1)
  743. # The inner CharField has a non-positive max_length.
  744. self.assertEqual(errors[0].id, "postgres.E001")
  745. self.assertIn("max_length", errors[0].msg)
  746. def test_choices_tuple_list(self):
  747. class MyModel(PostgreSQLModel):
  748. field = ArrayField(
  749. models.CharField(max_length=16),
  750. choices=[
  751. [
  752. "Media",
  753. [(["vinyl", "cd"], "Audio"), (("vhs", "dvd"), "Video")],
  754. ],
  755. (["mp3", "mp4"], "Digital"),
  756. ],
  757. )
  758. self.assertEqual(MyModel._meta.get_field("field").check(), [])
  759. @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
  760. class TestMigrations(TransactionTestCase):
  761. available_apps = ["postgres_tests"]
  762. def test_deconstruct(self):
  763. field = ArrayField(models.IntegerField())
  764. name, path, args, kwargs = field.deconstruct()
  765. new = ArrayField(*args, **kwargs)
  766. self.assertEqual(type(new.base_field), type(field.base_field))
  767. self.assertIsNot(new.base_field, field.base_field)
  768. def test_deconstruct_with_size(self):
  769. field = ArrayField(models.IntegerField(), size=3)
  770. name, path, args, kwargs = field.deconstruct()
  771. new = ArrayField(*args, **kwargs)
  772. self.assertEqual(new.size, field.size)
  773. def test_deconstruct_args(self):
  774. field = ArrayField(models.CharField(max_length=20))
  775. name, path, args, kwargs = field.deconstruct()
  776. new = ArrayField(*args, **kwargs)
  777. self.assertEqual(new.base_field.max_length, field.base_field.max_length)
  778. def test_subclass_deconstruct(self):
  779. field = ArrayField(models.IntegerField())
  780. name, path, args, kwargs = field.deconstruct()
  781. self.assertEqual(path, "django.contrib.postgres.fields.ArrayField")
  782. field = ArrayFieldSubclass()
  783. name, path, args, kwargs = field.deconstruct()
  784. self.assertEqual(path, "postgres_tests.models.ArrayFieldSubclass")
  785. @override_settings(
  786. MIGRATION_MODULES={
  787. "postgres_tests": "postgres_tests.array_default_migrations",
  788. }
  789. )
  790. def test_adding_field_with_default(self):
  791. # See #22962
  792. table_name = "postgres_tests_integerarraydefaultmodel"
  793. with connection.cursor() as cursor:
  794. self.assertNotIn(table_name, connection.introspection.table_names(cursor))
  795. call_command("migrate", "postgres_tests", verbosity=0)
  796. with connection.cursor() as cursor:
  797. self.assertIn(table_name, connection.introspection.table_names(cursor))
  798. call_command("migrate", "postgres_tests", "zero", verbosity=0)
  799. with connection.cursor() as cursor:
  800. self.assertNotIn(table_name, connection.introspection.table_names(cursor))
  801. @override_settings(
  802. MIGRATION_MODULES={
  803. "postgres_tests": "postgres_tests.array_index_migrations",
  804. }
  805. )
  806. def test_adding_arrayfield_with_index(self):
  807. """
  808. ArrayField shouldn't have varchar_patterns_ops or text_patterns_ops indexes.
  809. """
  810. table_name = "postgres_tests_chartextarrayindexmodel"
  811. call_command("migrate", "postgres_tests", verbosity=0)
  812. with connection.cursor() as cursor:
  813. like_constraint_columns_list = [
  814. v["columns"]
  815. for k, v in list(
  816. connection.introspection.get_constraints(cursor, table_name).items()
  817. )
  818. if k.endswith("_like")
  819. ]
  820. # Only the CharField should have a LIKE index.
  821. self.assertEqual(like_constraint_columns_list, [["char2"]])
  822. # All fields should have regular indexes.
  823. with connection.cursor() as cursor:
  824. indexes = [
  825. c["columns"][0]
  826. for c in connection.introspection.get_constraints(
  827. cursor, table_name
  828. ).values()
  829. if c["index"] and len(c["columns"]) == 1
  830. ]
  831. self.assertIn("char", indexes)
  832. self.assertIn("char2", indexes)
  833. self.assertIn("text", indexes)
  834. call_command("migrate", "postgres_tests", "zero", verbosity=0)
  835. with connection.cursor() as cursor:
  836. self.assertNotIn(table_name, connection.introspection.table_names(cursor))
  837. class TestSerialization(PostgreSQLSimpleTestCase):
  838. test_data = (
  839. '[{"fields": {"field": "[\\"1\\", \\"2\\", null]"}, '
  840. '"model": "postgres_tests.integerarraymodel", "pk": null}]'
  841. )
  842. def test_dumping(self):
  843. instance = IntegerArrayModel(field=[1, 2, None])
  844. data = serializers.serialize("json", [instance])
  845. self.assertEqual(json.loads(data), json.loads(self.test_data))
  846. def test_loading(self):
  847. instance = list(serializers.deserialize("json", self.test_data))[0].object
  848. self.assertEqual(instance.field, [1, 2, None])
  849. class TestValidation(PostgreSQLSimpleTestCase):
  850. def test_unbounded(self):
  851. field = ArrayField(models.IntegerField())
  852. with self.assertRaises(exceptions.ValidationError) as cm:
  853. field.clean([1, None], None)
  854. self.assertEqual(cm.exception.code, "item_invalid")
  855. self.assertEqual(
  856. cm.exception.message % cm.exception.params,
  857. "Item 2 in the array did not validate: This field cannot be null.",
  858. )
  859. def test_blank_true(self):
  860. field = ArrayField(models.IntegerField(blank=True, null=True))
  861. # This should not raise a validation error
  862. field.clean([1, None], None)
  863. def test_with_size(self):
  864. field = ArrayField(models.IntegerField(), size=3)
  865. field.clean([1, 2, 3], None)
  866. with self.assertRaises(exceptions.ValidationError) as cm:
  867. field.clean([1, 2, 3, 4], None)
  868. self.assertEqual(
  869. cm.exception.messages[0],
  870. "List contains 4 items, it should contain no more than 3.",
  871. )
  872. def test_nested_array_mismatch(self):
  873. field = ArrayField(ArrayField(models.IntegerField()))
  874. field.clean([[1, 2], [3, 4]], None)
  875. with self.assertRaises(exceptions.ValidationError) as cm:
  876. field.clean([[1, 2], [3, 4, 5]], None)
  877. self.assertEqual(cm.exception.code, "nested_array_mismatch")
  878. self.assertEqual(
  879. cm.exception.messages[0], "Nested arrays must have the same length."
  880. )
  881. def test_with_base_field_error_params(self):
  882. field = ArrayField(models.CharField(max_length=2))
  883. with self.assertRaises(exceptions.ValidationError) as cm:
  884. field.clean(["abc"], None)
  885. self.assertEqual(len(cm.exception.error_list), 1)
  886. exception = cm.exception.error_list[0]
  887. self.assertEqual(
  888. exception.message,
  889. "Item 1 in the array did not validate: Ensure this value has at most 2 "
  890. "characters (it has 3).",
  891. )
  892. self.assertEqual(exception.code, "item_invalid")
  893. self.assertEqual(
  894. exception.params,
  895. {"nth": 1, "value": "abc", "limit_value": 2, "show_value": 3},
  896. )
  897. def test_with_validators(self):
  898. field = ArrayField(
  899. models.IntegerField(validators=[validators.MinValueValidator(1)])
  900. )
  901. field.clean([1, 2], None)
  902. with self.assertRaises(exceptions.ValidationError) as cm:
  903. field.clean([0], None)
  904. self.assertEqual(len(cm.exception.error_list), 1)
  905. exception = cm.exception.error_list[0]
  906. self.assertEqual(
  907. exception.message,
  908. "Item 1 in the array did not validate: Ensure this value is greater than "
  909. "or equal to 1.",
  910. )
  911. self.assertEqual(exception.code, "item_invalid")
  912. self.assertEqual(
  913. exception.params, {"nth": 1, "value": 0, "limit_value": 1, "show_value": 0}
  914. )
  915. class TestSimpleFormField(PostgreSQLSimpleTestCase):
  916. def test_valid(self):
  917. field = SimpleArrayField(forms.CharField())
  918. value = field.clean("a,b,c")
  919. self.assertEqual(value, ["a", "b", "c"])
  920. def test_to_python_fail(self):
  921. field = SimpleArrayField(forms.IntegerField())
  922. with self.assertRaises(exceptions.ValidationError) as cm:
  923. field.clean("a,b,9")
  924. self.assertEqual(
  925. cm.exception.messages[0],
  926. "Item 1 in the array did not validate: Enter a whole number.",
  927. )
  928. def test_validate_fail(self):
  929. field = SimpleArrayField(forms.CharField(required=True))
  930. with self.assertRaises(exceptions.ValidationError) as cm:
  931. field.clean("a,b,")
  932. self.assertEqual(
  933. cm.exception.messages[0],
  934. "Item 3 in the array did not validate: This field is required.",
  935. )
  936. def test_validate_fail_base_field_error_params(self):
  937. field = SimpleArrayField(forms.CharField(max_length=2))
  938. with self.assertRaises(exceptions.ValidationError) as cm:
  939. field.clean("abc,c,defg")
  940. errors = cm.exception.error_list
  941. self.assertEqual(len(errors), 2)
  942. first_error = errors[0]
  943. self.assertEqual(
  944. first_error.message,
  945. "Item 1 in the array did not validate: Ensure this value has at most 2 "
  946. "characters (it has 3).",
  947. )
  948. self.assertEqual(first_error.code, "item_invalid")
  949. self.assertEqual(
  950. first_error.params,
  951. {"nth": 1, "value": "abc", "limit_value": 2, "show_value": 3},
  952. )
  953. second_error = errors[1]
  954. self.assertEqual(
  955. second_error.message,
  956. "Item 3 in the array did not validate: Ensure this value has at most 2 "
  957. "characters (it has 4).",
  958. )
  959. self.assertEqual(second_error.code, "item_invalid")
  960. self.assertEqual(
  961. second_error.params,
  962. {"nth": 3, "value": "defg", "limit_value": 2, "show_value": 4},
  963. )
  964. def test_validators_fail(self):
  965. field = SimpleArrayField(forms.RegexField("[a-e]{2}"))
  966. with self.assertRaises(exceptions.ValidationError) as cm:
  967. field.clean("a,bc,de")
  968. self.assertEqual(
  969. cm.exception.messages[0],
  970. "Item 1 in the array did not validate: Enter a valid value.",
  971. )
  972. def test_delimiter(self):
  973. field = SimpleArrayField(forms.CharField(), delimiter="|")
  974. value = field.clean("a|b|c")
  975. self.assertEqual(value, ["a", "b", "c"])
  976. def test_delimiter_with_nesting(self):
  977. field = SimpleArrayField(SimpleArrayField(forms.CharField()), delimiter="|")
  978. value = field.clean("a,b|c,d")
  979. self.assertEqual(value, [["a", "b"], ["c", "d"]])
  980. def test_prepare_value(self):
  981. field = SimpleArrayField(forms.CharField())
  982. value = field.prepare_value(["a", "b", "c"])
  983. self.assertEqual(value, "a,b,c")
  984. def test_max_length(self):
  985. field = SimpleArrayField(forms.CharField(), max_length=2)
  986. with self.assertRaises(exceptions.ValidationError) as cm:
  987. field.clean("a,b,c")
  988. self.assertEqual(
  989. cm.exception.messages[0],
  990. "List contains 3 items, it should contain no more than 2.",
  991. )
  992. def test_min_length(self):
  993. field = SimpleArrayField(forms.CharField(), min_length=4)
  994. with self.assertRaises(exceptions.ValidationError) as cm:
  995. field.clean("a,b,c")
  996. self.assertEqual(
  997. cm.exception.messages[0],
  998. "List contains 3 items, it should contain no fewer than 4.",
  999. )
  1000. def test_required(self):
  1001. field = SimpleArrayField(forms.CharField(), required=True)
  1002. with self.assertRaises(exceptions.ValidationError) as cm:
  1003. field.clean("")
  1004. self.assertEqual(cm.exception.messages[0], "This field is required.")
  1005. def test_model_field_formfield(self):
  1006. model_field = ArrayField(models.CharField(max_length=27))
  1007. form_field = model_field.formfield()
  1008. self.assertIsInstance(form_field, SimpleArrayField)
  1009. self.assertIsInstance(form_field.base_field, forms.CharField)
  1010. self.assertEqual(form_field.base_field.max_length, 27)
  1011. def test_model_field_formfield_size(self):
  1012. model_field = ArrayField(models.CharField(max_length=27), size=4)
  1013. form_field = model_field.formfield()
  1014. self.assertIsInstance(form_field, SimpleArrayField)
  1015. self.assertEqual(form_field.max_length, 4)
  1016. def test_model_field_choices(self):
  1017. model_field = ArrayField(models.IntegerField(choices=((1, "A"), (2, "B"))))
  1018. form_field = model_field.formfield()
  1019. self.assertEqual(form_field.clean("1,2"), [1, 2])
  1020. def test_already_converted_value(self):
  1021. field = SimpleArrayField(forms.CharField())
  1022. vals = ["a", "b", "c"]
  1023. self.assertEqual(field.clean(vals), vals)
  1024. def test_has_changed(self):
  1025. field = SimpleArrayField(forms.IntegerField())
  1026. self.assertIs(field.has_changed([1, 2], [1, 2]), False)
  1027. self.assertIs(field.has_changed([1, 2], "1,2"), False)
  1028. self.assertIs(field.has_changed([1, 2], "1,2,3"), True)
  1029. self.assertIs(field.has_changed([1, 2], "a,b"), True)
  1030. def test_has_changed_empty(self):
  1031. field = SimpleArrayField(forms.CharField())
  1032. self.assertIs(field.has_changed(None, None), False)
  1033. self.assertIs(field.has_changed(None, ""), False)
  1034. self.assertIs(field.has_changed(None, []), False)
  1035. self.assertIs(field.has_changed([], None), False)
  1036. self.assertIs(field.has_changed([], ""), False)
  1037. class TestSplitFormField(PostgreSQLSimpleTestCase):
  1038. def test_valid(self):
  1039. class SplitForm(forms.Form):
  1040. array = SplitArrayField(forms.CharField(), size=3)
  1041. data = {"array_0": "a", "array_1": "b", "array_2": "c"}
  1042. form = SplitForm(data)
  1043. self.assertTrue(form.is_valid())
  1044. self.assertEqual(form.cleaned_data, {"array": ["a", "b", "c"]})
  1045. def test_required(self):
  1046. class SplitForm(forms.Form):
  1047. array = SplitArrayField(forms.CharField(), required=True, size=3)
  1048. data = {"array_0": "", "array_1": "", "array_2": ""}
  1049. form = SplitForm(data)
  1050. self.assertFalse(form.is_valid())
  1051. self.assertEqual(form.errors, {"array": ["This field is required."]})
  1052. def test_remove_trailing_nulls(self):
  1053. class SplitForm(forms.Form):
  1054. array = SplitArrayField(
  1055. forms.CharField(required=False), size=5, remove_trailing_nulls=True
  1056. )
  1057. data = {
  1058. "array_0": "a",
  1059. "array_1": "",
  1060. "array_2": "b",
  1061. "array_3": "",
  1062. "array_4": "",
  1063. }
  1064. form = SplitForm(data)
  1065. self.assertTrue(form.is_valid(), form.errors)
  1066. self.assertEqual(form.cleaned_data, {"array": ["a", "", "b"]})
  1067. def test_remove_trailing_nulls_not_required(self):
  1068. class SplitForm(forms.Form):
  1069. array = SplitArrayField(
  1070. forms.CharField(required=False),
  1071. size=2,
  1072. remove_trailing_nulls=True,
  1073. required=False,
  1074. )
  1075. data = {"array_0": "", "array_1": ""}
  1076. form = SplitForm(data)
  1077. self.assertTrue(form.is_valid())
  1078. self.assertEqual(form.cleaned_data, {"array": []})
  1079. def test_required_field(self):
  1080. class SplitForm(forms.Form):
  1081. array = SplitArrayField(forms.CharField(), size=3)
  1082. data = {"array_0": "a", "array_1": "b", "array_2": ""}
  1083. form = SplitForm(data)
  1084. self.assertFalse(form.is_valid())
  1085. self.assertEqual(
  1086. form.errors,
  1087. {
  1088. "array": [
  1089. "Item 3 in the array did not validate: This field is required."
  1090. ]
  1091. },
  1092. )
  1093. def test_invalid_integer(self):
  1094. msg = (
  1095. "Item 2 in the array did not validate: Ensure this value is less than or "
  1096. "equal to 100."
  1097. )
  1098. with self.assertRaisesMessage(exceptions.ValidationError, msg):
  1099. SplitArrayField(forms.IntegerField(max_value=100), size=2).clean([0, 101])
  1100. def test_rendering(self):
  1101. class SplitForm(forms.Form):
  1102. array = SplitArrayField(forms.CharField(), size=3)
  1103. self.assertHTMLEqual(
  1104. str(SplitForm()),
  1105. """
  1106. <div>
  1107. <label for="id_array_0">Array:</label>
  1108. <input id="id_array_0" name="array_0" type="text" required>
  1109. <input id="id_array_1" name="array_1" type="text" required>
  1110. <input id="id_array_2" name="array_2" type="text" required>
  1111. </div>
  1112. """,
  1113. )
  1114. def test_invalid_char_length(self):
  1115. field = SplitArrayField(forms.CharField(max_length=2), size=3)
  1116. with self.assertRaises(exceptions.ValidationError) as cm:
  1117. field.clean(["abc", "c", "defg"])
  1118. self.assertEqual(
  1119. cm.exception.messages,
  1120. [
  1121. "Item 1 in the array did not validate: Ensure this value has at most 2 "
  1122. "characters (it has 3).",
  1123. "Item 3 in the array did not validate: Ensure this value has at most 2 "
  1124. "characters (it has 4).",
  1125. ],
  1126. )
  1127. def test_splitarraywidget_value_omitted_from_data(self):
  1128. class Form(forms.ModelForm):
  1129. field = SplitArrayField(forms.IntegerField(), required=False, size=2)
  1130. class Meta:
  1131. model = IntegerArrayModel
  1132. fields = ("field",)
  1133. form = Form({"field_0": "1", "field_1": "2"})
  1134. self.assertEqual(form.errors, {})
  1135. obj = form.save(commit=False)
  1136. self.assertEqual(obj.field, [1, 2])
  1137. def test_splitarrayfield_has_changed(self):
  1138. class Form(forms.ModelForm):
  1139. field = SplitArrayField(forms.IntegerField(), required=False, size=2)
  1140. class Meta:
  1141. model = IntegerArrayModel
  1142. fields = ("field",)
  1143. tests = [
  1144. ({}, {"field_0": "", "field_1": ""}, True),
  1145. ({"field": None}, {"field_0": "", "field_1": ""}, True),
  1146. ({"field": [1]}, {"field_0": "", "field_1": ""}, True),
  1147. ({"field": [1]}, {"field_0": "1", "field_1": "0"}, True),
  1148. ({"field": [1, 2]}, {"field_0": "1", "field_1": "2"}, False),
  1149. ({"field": [1, 2]}, {"field_0": "a", "field_1": "b"}, True),
  1150. ]
  1151. for initial, data, expected_result in tests:
  1152. with self.subTest(initial=initial, data=data):
  1153. obj = IntegerArrayModel(**initial)
  1154. form = Form(data, instance=obj)
  1155. self.assertIs(form.has_changed(), expected_result)
  1156. def test_splitarrayfield_remove_trailing_nulls_has_changed(self):
  1157. class Form(forms.ModelForm):
  1158. field = SplitArrayField(
  1159. forms.IntegerField(), required=False, size=2, remove_trailing_nulls=True
  1160. )
  1161. class Meta:
  1162. model = IntegerArrayModel
  1163. fields = ("field",)
  1164. tests = [
  1165. ({}, {"field_0": "", "field_1": ""}, False),
  1166. ({"field": None}, {"field_0": "", "field_1": ""}, False),
  1167. ({"field": []}, {"field_0": "", "field_1": ""}, False),
  1168. ({"field": [1]}, {"field_0": "1", "field_1": ""}, False),
  1169. ]
  1170. for initial, data, expected_result in tests:
  1171. with self.subTest(initial=initial, data=data):
  1172. obj = IntegerArrayModel(**initial)
  1173. form = Form(data, instance=obj)
  1174. self.assertIs(form.has_changed(), expected_result)
  1175. class TestSplitFormWidget(PostgreSQLWidgetTestCase):
  1176. def test_get_context(self):
  1177. self.assertEqual(
  1178. SplitArrayWidget(forms.TextInput(), size=2).get_context(
  1179. "name", ["val1", "val2"]
  1180. ),
  1181. {
  1182. "widget": {
  1183. "name": "name",
  1184. "is_hidden": False,
  1185. "required": False,
  1186. "value": "['val1', 'val2']",
  1187. "attrs": {},
  1188. "template_name": "postgres/widgets/split_array.html",
  1189. "subwidgets": [
  1190. {
  1191. "name": "name_0",
  1192. "is_hidden": False,
  1193. "required": False,
  1194. "value": "val1",
  1195. "attrs": {},
  1196. "template_name": "django/forms/widgets/text.html",
  1197. "type": "text",
  1198. },
  1199. {
  1200. "name": "name_1",
  1201. "is_hidden": False,
  1202. "required": False,
  1203. "value": "val2",
  1204. "attrs": {},
  1205. "template_name": "django/forms/widgets/text.html",
  1206. "type": "text",
  1207. },
  1208. ],
  1209. }
  1210. },
  1211. )
  1212. def test_checkbox_get_context_attrs(self):
  1213. context = SplitArrayWidget(
  1214. forms.CheckboxInput(),
  1215. size=2,
  1216. ).get_context("name", [True, False])
  1217. self.assertEqual(context["widget"]["value"], "[True, False]")
  1218. self.assertEqual(
  1219. [subwidget["attrs"] for subwidget in context["widget"]["subwidgets"]],
  1220. [{"checked": True}, {}],
  1221. )
  1222. def test_render(self):
  1223. self.check_html(
  1224. SplitArrayWidget(forms.TextInput(), size=2),
  1225. "array",
  1226. None,
  1227. """
  1228. <input name="array_0" type="text">
  1229. <input name="array_1" type="text">
  1230. """,
  1231. )
  1232. def test_render_attrs(self):
  1233. self.check_html(
  1234. SplitArrayWidget(forms.TextInput(), size=2),
  1235. "array",
  1236. ["val1", "val2"],
  1237. attrs={"id": "foo"},
  1238. html=(
  1239. """
  1240. <input id="foo_0" name="array_0" type="text" value="val1">
  1241. <input id="foo_1" name="array_1" type="text" value="val2">
  1242. """
  1243. ),
  1244. )
  1245. def test_value_omitted_from_data(self):
  1246. widget = SplitArrayWidget(forms.TextInput(), size=2)
  1247. self.assertIs(widget.value_omitted_from_data({}, {}, "field"), True)
  1248. self.assertIs(
  1249. widget.value_omitted_from_data({"field_0": "value"}, {}, "field"), False
  1250. )
  1251. self.assertIs(
  1252. widget.value_omitted_from_data({"field_1": "value"}, {}, "field"), False
  1253. )
  1254. self.assertIs(
  1255. widget.value_omitted_from_data(
  1256. {"field_0": "value", "field_1": "value"}, {}, "field"
  1257. ),
  1258. False,
  1259. )
  1260. class TestAdminUtils(PostgreSQLTestCase):
  1261. empty_value = "-empty-"
  1262. def test_array_display_for_field(self):
  1263. array_field = ArrayField(models.IntegerField())
  1264. display_value = display_for_field(
  1265. [1, 2],
  1266. array_field,
  1267. self.empty_value,
  1268. )
  1269. self.assertEqual(display_value, "1, 2")
  1270. def test_array_with_choices_display_for_field(self):
  1271. array_field = ArrayField(
  1272. models.IntegerField(),
  1273. choices=[
  1274. ([1, 2, 3], "1st choice"),
  1275. ([1, 2], "2nd choice"),
  1276. ],
  1277. )
  1278. display_value = display_for_field(
  1279. [1, 2],
  1280. array_field,
  1281. self.empty_value,
  1282. )
  1283. self.assertEqual(display_value, "2nd choice")
  1284. display_value = display_for_field(
  1285. [99, 99],
  1286. array_field,
  1287. self.empty_value,
  1288. )
  1289. self.assertEqual(display_value, self.empty_value)