# tests.py (33 KB)

# Lines 1-868 (line-number gutter from the source viewer, collapsed during extraction).
  1. from math import ceil
  2. from operator import attrgetter
  3. from django.core.exceptions import FieldDoesNotExist
  4. from django.db import (
  5. IntegrityError,
  6. NotSupportedError,
  7. OperationalError,
  8. ProgrammingError,
  9. connection,
  10. )
  11. from django.db.models import FileField, Value
  12. from django.db.models.functions import Lower, Now
  13. from django.test import (
  14. TestCase,
  15. override_settings,
  16. skipIfDBFeature,
  17. skipUnlessDBFeature,
  18. )
  19. from django.utils import timezone
  20. from .models import (
  21. BigAutoFieldModel,
  22. Country,
  23. DbDefaultModel,
  24. FieldsWithDbColumns,
  25. NoFields,
  26. NullableFields,
  27. Pizzeria,
  28. ProxyCountry,
  29. ProxyMultiCountry,
  30. ProxyMultiProxyCountry,
  31. ProxyProxyCountry,
  32. RelatedModel,
  33. Restaurant,
  34. SmallAutoFieldModel,
  35. State,
  36. TwoFields,
  37. UpsertConflict,
  38. )
  39. class BulkCreateTests(TestCase):
  def setUp(self):
      """Build four unsaved Country instances shared by the tests."""
      countries = (
          ("United States of America", "US"),
          ("The Netherlands", "NL"),
          ("Germany", "DE"),
          ("Czech Republic", "CZ"),
      )
      self.data = [
          Country(name=country_name, iso_two_letter=code)
          for country_name, code in countries
      ]
  def test_simple(self):
      """bulk_create() returns its input list and persists every object."""
      created = Country.objects.bulk_create(self.data)
      self.assertEqual(created, self.data)
      expected_names = [
          "United States of America",
          "The Netherlands",
          "Germany",
          "Czech Republic",
      ]
      self.assertQuerySetEqual(
          Country.objects.order_by("-name"), expected_names, attrgetter("name")
      )
      # An empty input returns an empty list and leaves the table unchanged.
      self.assertEqual(Country.objects.bulk_create([]), [])
      self.assertEqual(Country.objects.count(), 4)
  @skipUnlessDBFeature("has_bulk_insert")
  def test_efficiency(self):
      """With bulk insert support, all objects are stored in one query."""
      expected_queries = 1
      with self.assertNumQueries(expected_queries):
          Country.objects.bulk_create(self.data)
  @skipUnlessDBFeature("has_bulk_insert")
  def test_long_non_ascii_text(self):
      """
      Non-ASCII values of 2001 to 4000 characters (4002 to 8000 bytes) must
      be bound as a CLOB on Oracle (#22144).
      """
      long_description = "Ж" * 3000
      Country.objects.bulk_create([Country(description=long_description)])
      self.assertEqual(Country.objects.count(), 1)
  @skipUnlessDBFeature("has_bulk_insert")
  def test_long_and_short_text(self):
      """Long and short ASCII/non-ASCII texts can be mixed in one call."""
      specs = [
          ("a" * 4001, "A"),
          ("a", "B"),
          ("Ж" * 2001, "C"),
          ("Ж", "D"),
      ]
      Country.objects.bulk_create(
          [Country(description=text, iso_two_letter=code) for text, code in specs]
      )
      self.assertEqual(Country.objects.count(), 4)
  def test_multi_table_inheritance_unsupported(self):
      """Pizzeria and the ProxyMulti* models are rejected by bulk_create()."""
      expected_message = "Can't bulk create a multi-table inherited model"
      cases = [
          (Pizzeria, {"name": "The Art of Pizza"}),
          (ProxyMultiCountry, {"name": "Fillory", "iso_two_letter": "FL"}),
          (ProxyMultiProxyCountry, {"name": "Fillory", "iso_two_letter": "FL"}),
      ]
      for model, kwargs in cases:
          with self.assertRaisesMessage(ValueError, expected_message):
              model.objects.bulk_create([model(**kwargs)])
  def test_proxy_inheritance_supported(self):
      """Proxy and proxy-of-proxy managers can bulk create objects."""
      ProxyCountry.objects.bulk_create(
          [
              ProxyCountry(name="Qwghlm", iso_two_letter="QW"),
              Country(name="Tortall", iso_two_letter="TA"),
          ]
      )
      by_name = attrgetter("name")
      self.assertQuerySetEqual(
          ProxyCountry.objects.all(), {"Qwghlm", "Tortall"}, by_name, ordered=False
      )
      ProxyProxyCountry.objects.bulk_create(
          [ProxyProxyCountry(name="Netherlands", iso_two_letter="NT")]
      )
      self.assertQuerySetEqual(
          ProxyProxyCountry.objects.all(),
          {"Qwghlm", "Tortall", "Netherlands"},
          by_name,
          ordered=False,
      )
  def test_non_auto_increment_pk(self):
      """State objects keyed by two_letter_code can be bulk created."""
      codes = ["IL", "NY", "CA", "ME"]
      State.objects.bulk_create([State(two_letter_code=code) for code in codes])
      self.assertQuerySetEqual(
          State.objects.order_by("two_letter_code"),
          sorted(codes),
          attrgetter("two_letter_code"),
      )
  @skipUnlessDBFeature("has_bulk_insert")
  def test_non_auto_increment_pk_efficiency(self):
      """Bulk creating State objects takes exactly one query."""
      codes = ["IL", "NY", "CA", "ME"]
      states = [State(two_letter_code=code) for code in codes]
      with self.assertNumQueries(1):
          State.objects.bulk_create(states)
      self.assertQuerySetEqual(
          State.objects.order_by("two_letter_code"),
          sorted(codes),
          attrgetter("two_letter_code"),
      )
  @skipIfDBFeature("allows_auto_pk_0")
  def test_zero_as_autoval(self):
      """
      Backends that reject 0 as an AutoField value (e.g. MySQL without the
      NO_AUTO_VALUE_ON_ZERO SQL mode) make bulk_create() raise ValueError.
      """
      objs = [
          Country(name="Germany", iso_two_letter="DE"),
          Country(id=0, name="Poland", iso_two_letter="PL"),
      ]
      msg = "The database backend does not accept 0 as a value for AutoField."
      with self.assertRaisesMessage(ValueError, msg):
          Country.objects.bulk_create(objs)
  def test_batch_same_vals(self):
      # Regression guard: SQLite once collapsed identical rows into a single
      # insert.
      restaurants = [Restaurant(name="foo"), Restaurant(name="foo")]
      Restaurant.objects.bulk_create(restaurants)
      self.assertEqual(Restaurant.objects.count(), 2)
  def test_large_batch(self):
      """A batch of 1001 objects is stored completely and correctly."""
      total = 1001
      TwoFields.objects.bulk_create(
          [TwoFields(f1=n, f2=n + 1) for n in range(total)]
      )
      self.assertEqual(TwoFields.objects.count(), total)
      middle = TwoFields.objects.filter(f1__gte=450, f1__lte=550)
      self.assertEqual(middle.count(), 101)
      self.assertEqual(TwoFields.objects.filter(f2__gte=901).count(), 101)
  @skipUnlessDBFeature("has_bulk_insert")
  def test_large_single_field_batch(self):
      """Bulk creating 501 field-less Restaurant() objects stores all of them."""
      # SQLite had a problem with more than 500 UNIONed selects in single
      # query.
      restaurants = [Restaurant() for _ in range(501)]
      Restaurant.objects.bulk_create(restaurants)
      # Check the rows were actually stored, not merely that no error was
      # raised (the test previously had no assertion at all).
      self.assertEqual(Restaurant.objects.count(), len(restaurants))
  @skipUnlessDBFeature("has_bulk_insert")
  def test_large_batch_efficiency(self):
      """Inserting 1001 objects needs fewer than 10 queries."""
      objs = [TwoFields(f1=n, f2=n + 1) for n in range(1001)]
      # DEBUG=True makes the connection record executed queries.
      with override_settings(DEBUG=True):
          connection.queries_log.clear()
          TwoFields.objects.bulk_create(objs)
          self.assertLess(len(connection.queries), 10)
  def test_large_batch_mixed(self):
      """
      A large batch can mix objects with an explicit primary key and objects
      that leave it unset.
      """
      TwoFields.objects.bulk_create(
          [
              TwoFields(id=None if n % 2 else n, f1=n, f2=n + 1)
              for n in range(100000, 101000)
          ]
      )
      self.assertEqual(TwoFields.objects.count(), 1000)
      # Auto-assigned IDs are unpredictable; only the explicitly provided
      # (even) IDs can be checked.
      explicit_ids = range(100000, 101000, 2)
      self.assertEqual(TwoFields.objects.filter(id__in=explicit_ids).count(), 500)
      self.assertEqual(TwoFields.objects.exclude(id__in=explicit_ids).count(), 500)
  @skipUnlessDBFeature("has_bulk_insert")
  def test_large_batch_mixed_efficiency(self):
      """
      A large batch mixing set and unset primary keys still needs fewer than
      10 queries.
      """
      objs = [
          TwoFields(id=None if n % 2 else n, f1=n, f2=n + 1)
          for n in range(100000, 101000)
      ]
      # DEBUG=True makes the connection record executed queries.
      with override_settings(DEBUG=True):
          connection.queries_log.clear()
          TwoFields.objects.bulk_create(objs)
          self.assertLess(len(connection.queries), 10)
  def test_explicit_batch_size(self):
      """Every batch_size from 1 up to len(objs) stores all objects."""
      objs = [TwoFields(f1=n, f2=n) for n in range(4)]
      num_objs = len(objs)
      for attempt, batch_size in enumerate((1, 2, 3, num_objs)):
          if attempt:
              # Start each later round from an empty table.
              TwoFields.objects.all().delete()
          TwoFields.objects.bulk_create(objs, batch_size=batch_size)
          self.assertEqual(TwoFields.objects.count(), num_objs)
  def test_empty_model(self):
      """NoFields instances (no declared fields) can be bulk created."""
      NoFields.objects.bulk_create([NoFields(), NoFields()])
      self.assertEqual(NoFields.objects.count(), 2)
  @skipUnlessDBFeature("has_bulk_insert")
  def test_explicit_batch_size_efficiency(self):
      """The number of INSERT queries follows the explicit batch size."""
      objs = [TwoFields(f1=n, f2=n) for n in range(100)]
      # 100 objects in batches of 50 -> two queries.
      with self.assertNumQueries(2):
          TwoFields.objects.bulk_create(objs, 50)
      TwoFields.objects.all().delete()
      # One batch large enough for every object -> one query.
      with self.assertNumQueries(1):
          TwoFields.objects.bulk_create(objs, len(objs))
  @skipUnlessDBFeature("has_bulk_insert")
  def test_explicit_batch_size_respects_max_batch_size(self):
      """A batch_size above the backend's own limit is capped to that limit."""
      objs = [Country(name=f"Country {number}") for number in range(1000)]
      fields = ["name", "iso_two_letter", "description"]
      max_batch_size = max(connection.ops.bulk_batch_size(fields, objs), 1)
      expected_queries = ceil(len(objs) / max_batch_size)
      with self.assertNumQueries(expected_queries):
          Country.objects.bulk_create(objs, batch_size=max_batch_size + 1)
  @skipUnlessDBFeature("has_bulk_insert")
  def test_bulk_insert_expressions(self):
      """Field values may be database expressions evaluated on insert."""
      lowered = Lower(Value("Betty's Beetroot Bar"))
      Restaurant.objects.bulk_create(
          [Restaurant(name="Sam's Shake Shack"), Restaurant(name=lowered)]
      )
      matches = Restaurant.objects.filter(name="betty's beetroot bar")
      self.assertEqual(matches.count(), 1)
  @skipUnlessDBFeature("has_bulk_insert")
  def test_bulk_insert_now(self):
      """Now() can be used as a field value for every object in the batch."""
      NullableFields.objects.bulk_create(
          [NullableFields(datetime_field=Now()) for _ in range(2)]
      )
      populated = NullableFields.objects.filter(datetime_field__isnull=False)
      self.assertEqual(populated.count(), 2)
  @skipUnlessDBFeature("has_bulk_insert")
  def test_bulk_insert_nullable_fields(self):
      """NULL can be mixed with other values in every nullable field."""
      fk_to_auto_fields = {
          "auto_field": NoFields.objects.create(),
          "small_auto_field": SmallAutoFieldModel.objects.create(),
          "big_auto_field": BigAutoFieldModel.objects.create(),
      }
      nullable_fields = [
          field
          for field in NullableFields._meta.get_fields()
          if field.name != "id"
      ]
      # One object per field: the FKs are populated, then that one field is
      # explicitly set to None.
      objs = []
      for field in nullable_fields:
          kwargs = dict(fk_to_auto_fields)
          kwargs[field.name] = None
          objs.append(NullableFields(**kwargs))
      NullableFields.objects.bulk_create(objs)
      self.assertEqual(NullableFields.objects.count(), len(nullable_fields))
      for field in nullable_fields:
          with self.subTest(field=field):
              # A FileField set to None is looked up as an empty string.
              expected = "" if isinstance(field, FileField) else None
              lookup = {field.name: expected}
              self.assertEqual(NullableFields.objects.filter(**lookup).count(), 1)
  @skipUnlessDBFeature("can_return_rows_from_bulk_insert")
  def test_set_pk_and_insert_single_item(self):
      """A single-object insert sets the pk using only one query."""
      with self.assertNumQueries(1):
          countries = Country.objects.bulk_create([self.data[0]])
      self.assertEqual(len(countries), 1)
      (country,) = countries
      self.assertEqual(Country.objects.get(pk=country.pk), country)
  @skipUnlessDBFeature("can_return_rows_from_bulk_insert")
  def test_set_pk_and_query_efficiency(self):
      """All pks are populated by the single insert query."""
      with self.assertNumQueries(1):
          countries = Country.objects.bulk_create(self.data)
      self.assertEqual(len(countries), 4)
      for country in countries:
          self.assertEqual(Country.objects.get(pk=country.pk), country)
  @skipUnlessDBFeature("can_return_rows_from_bulk_insert")
  def test_set_state(self):
      """Objects saved by bulk_create() and save() end with equal _state."""
      country_nl = Country(name="Netherlands", iso_two_letter="NL")
      country_be = Country(name="Belgium", iso_two_letter="BE")
      Country.objects.bulk_create([country_nl])
      country_be.save()
      for attr in ("adding", "db"):
          self.assertEqual(
              getattr(country_nl._state, attr), getattr(country_be._state, attr)
          )
  def test_set_state_with_pk_specified(self):
      """_state matches save() even when the pk is supplied explicitly."""
      california = State(two_letter_code="CA")
      new_york = State(two_letter_code="NY")
      State.objects.bulk_create([california])
      new_york.save()
      self.assertEqual(california._state.adding, new_york._state.adding)
      self.assertEqual(california._state.db, new_york._state.db)
  @skipIfDBFeature("supports_ignore_conflicts")
  def test_ignore_conflicts_value_error(self):
      """
      ignore_conflicts=True raises NotSupportedError on backends that do not
      support ignoring conflicts.
      """
      message = "This database backend does not support ignoring conflicts."
      with self.assertRaisesMessage(NotSupportedError, message):
          # self.data holds Country instances, so use the Country manager
          # (as test_update_conflicts_unsupported does) rather than an
          # unrelated model's manager.
          Country.objects.bulk_create(self.data, ignore_conflicts=True)
  @skipUnlessDBFeature("supports_ignore_conflicts")
  def test_ignore_conflicts_ignore(self):
      """ignore_conflicts=True silently skips conflicting rows."""
      TwoFields.objects.bulk_create([TwoFields(f1=n, f2=n) for n in (1, 2, 3)])
      self.assertEqual(TwoFields.objects.count(), 3)
      # With ignore_conflicts=True, conflicts are ignored.
      conflicting_objects = [TwoFields(f1=2, f2=2), TwoFields(f1=3, f2=3)]
      first_conflict, second_conflict = conflicting_objects
      TwoFields.objects.bulk_create([first_conflict], ignore_conflicts=True)
      TwoFields.objects.bulk_create(conflicting_objects, ignore_conflicts=True)
      self.assertEqual(TwoFields.objects.count(), 3)
      self.assertIsNone(first_conflict.pk)
      self.assertIsNone(second_conflict.pk)
      # New objects are created and conflicts are ignored.
      new_object = TwoFields(f1=4, f2=4)
      TwoFields.objects.bulk_create(
          [*conflicting_objects, new_object], ignore_conflicts=True
      )
      self.assertEqual(TwoFields.objects.count(), 4)
      self.assertIsNone(new_object.pk)
      # Without ignore_conflicts=True, the conflicts raise.
      with self.assertRaises(IntegrityError):
          TwoFields.objects.bulk_create(conflicting_objects)
  def test_nullable_fk_after_parent(self):
      """A FK assigned before its parent is saved resolves after save()."""
      parent = NoFields()
      child = NullableFields(auto_field=parent, integer_field=88)
      parent.save()
      NullableFields.objects.bulk_create([child])
      fetched = NullableFields.objects.get(integer_field=88)
      self.assertEqual(fetched.auto_field, parent)
  @skipUnlessDBFeature("can_return_rows_from_bulk_insert")
  def test_nullable_fk_after_parent_bulk_create(self):
      """A FK to a parent later saved through bulk_create() is resolved."""
      parent = NoFields()
      child = NullableFields(auto_field=parent, integer_field=88)
      NoFields.objects.bulk_create([parent])
      NullableFields.objects.bulk_create([child])
      fetched = NullableFields.objects.get(integer_field=88)
      self.assertEqual(fetched.auto_field, parent)
  def test_unsaved_parent(self):
      """Referencing an unsaved related object makes bulk_create() raise."""
      unsaved = NoFields()
      msg = (
          "bulk_create() prohibited to prevent data loss due to unsaved "
          "related object 'auto_field'."
      )
      with self.assertRaisesMessage(ValueError, msg):
          NullableFields.objects.bulk_create(
              [NullableFields(auto_field=unsaved)]
          )
  def test_invalid_batch_size_exception(self):
      """A negative batch_size is rejected, even when there are no objects."""
      with self.assertRaisesMessage(
          ValueError, "Batch size must be a positive integer."
      ):
          Country.objects.bulk_create([], batch_size=-1)
  @skipIfDBFeature("supports_update_conflicts")
  def test_update_conflicts_unsupported(self):
      """update_conflicts=True raises on backends that lack support for it."""
      with self.assertRaisesMessage(
          NotSupportedError,
          "This database backend does not support updating conflicts.",
      ):
          Country.objects.bulk_create(self.data, update_conflicts=True)
  @skipUnlessDBFeature("supports_ignore_conflicts", "supports_update_conflicts")
  def test_ignore_update_conflicts_exclusive(self):
      """ignore_conflicts and update_conflicts cannot be combined."""
      options = {"ignore_conflicts": True, "update_conflicts": True}
      msg = "ignore_conflicts and update_conflicts are mutually exclusive"
      with self.assertRaisesMessage(ValueError, msg):
          Country.objects.bulk_create(self.data, **options)
  @skipUnlessDBFeature("supports_update_conflicts")
  def test_update_conflicts_no_update_fields(self):
      """update_conflicts=True without update_fields raises ValueError."""
      expected = (
          "Fields that will be updated when a row insertion fails on "
          "conflicts must be provided."
      )
      with self.assertRaisesMessage(ValueError, expected):
          Country.objects.bulk_create(self.data, update_conflicts=True)
  @skipUnlessDBFeature("supports_update_conflicts")
  @skipIfDBFeature("supports_update_conflicts_with_target")
  def test_update_conflicts_unique_field_unsupported(self):
      """unique_fields is rejected where conflict targets aren't supported."""
      objs = [TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)]
      msg = (
          "This database backend does not support updating conflicts with "
          "specifying unique fields that can trigger the upsert."
      )
      with self.assertRaisesMessage(NotSupportedError, msg):
          TwoFields.objects.bulk_create(
              objs, update_conflicts=True, update_fields=["f2"], unique_fields=["f1"]
          )
  @skipUnlessDBFeature("supports_update_conflicts")
  def test_update_conflicts_nonexistent_update_fields(self):
      """An unknown name in update_fields raises FieldDoesNotExist."""
      supports_target = connection.features.supports_update_conflicts_with_target
      unique_fields = ["f1"] if supports_target else None
      msg = "TwoFields has no field named 'nonexistent'"
      with self.assertRaisesMessage(FieldDoesNotExist, msg):
          TwoFields.objects.bulk_create(
              [TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)],
              update_conflicts=True,
              update_fields=["nonexistent"],
              unique_fields=unique_fields,
          )
  @skipUnlessDBFeature(
      "supports_update_conflicts",
      "supports_update_conflicts_with_target",
  )
  def test_update_conflicts_unique_fields_required(self):
      """Backends that support conflict targets require unique_fields."""
      objs = [TwoFields(f1=n, f2=n) for n in (1, 2)]
      msg = "Unique fields that can trigger the upsert must be provided."
      with self.assertRaisesMessage(ValueError, msg):
          TwoFields.objects.bulk_create(
              objs, update_conflicts=True, update_fields=["f1"]
          )
  @skipUnlessDBFeature(
      "supports_update_conflicts",
      "supports_update_conflicts_with_target",
  )
  def test_update_conflicts_invalid_update_fields(self):
      """Only concrete fields are accepted in update_fields."""
      msg = "bulk_create() can only be used with concrete fields in update_fields."

      def assert_rejected(model, objs, update_fields, unique_fields):
          # Every call is expected to fail validation with the same message.
          with self.assertRaisesMessage(ValueError, msg):
              model.objects.bulk_create(
                  objs,
                  update_conflicts=True,
                  update_fields=update_fields,
                  unique_fields=unique_fields,
              )

      # Reverse one-to-one relationship.
      assert_rejected(Country, self.data, ["relatedmodel"], ["pk"])
      # Many-to-many relationship.
      assert_rejected(
          RelatedModel,
          [RelatedModel(country=self.data[0])],
          ["big_auto_fields"],
          ["country"],
      )
  @skipUnlessDBFeature(
      "supports_update_conflicts",
      "supports_update_conflicts_with_target",
  )
  def test_update_conflicts_pk_in_update_fields(self):
      """The primary key may not appear in update_fields."""
      msg = "bulk_create() cannot be used with primary keys in update_fields."
      with self.assertRaisesMessage(ValueError, msg):
          BigAutoFieldModel.objects.bulk_create(
              [BigAutoFieldModel()],
              update_conflicts=True,
              unique_fields=["id"],
              update_fields=["id"],
          )
  @skipUnlessDBFeature(
      "supports_update_conflicts",
      "supports_update_conflicts_with_target",
  )
  def test_update_conflicts_invalid_unique_fields(self):
      """Only concrete fields are accepted in unique_fields."""
      msg = "bulk_create() can only be used with concrete fields in unique_fields."
      # Reverse one-to-one relationship.
      with self.assertRaisesMessage(ValueError, msg):
          Country.objects.bulk_create(
              self.data,
              update_conflicts=True,
              unique_fields=["relatedmodel"],
              update_fields=["name"],
          )
      # Many-to-many relationship.
      with self.assertRaisesMessage(ValueError, msg):
          RelatedModel.objects.bulk_create(
              [RelatedModel(country=self.data[0])],
              update_conflicts=True,
              unique_fields=["big_auto_fields"],
              update_fields=["name"],
          )
  def _test_update_conflicts_two_fields(self, unique_fields):
      """
      Upsert two TwoFields rows that conflict with existing ones and check
      that only their name changed.

      unique_fields is passed straight through to bulk_create().
      """
      TwoFields.objects.bulk_create(
          [TwoFields(f1=1, f2=1, name="a"), TwoFields(f1=2, f2=2, name="b")]
      )
      self.assertEqual(TwoFields.objects.count(), 2)
      updated = [(1, "c"), (2, "d")]
      conflicting_objects = [TwoFields(f1=n, f2=n, name=name) for n, name in updated]
      results = TwoFields.objects.bulk_create(
          conflicting_objects,
          update_conflicts=True,
          unique_fields=unique_fields,
          update_fields=["name"],
      )
      self.assertEqual(len(results), len(conflicting_objects))
      if connection.features.can_return_rows_from_bulk_insert:
          for obj in results:
              self.assertIsNotNone(obj.pk)
      self.assertEqual(TwoFields.objects.count(), 2)
      self.assertCountEqual(
          TwoFields.objects.values("f1", "f2", "name"),
          [{"f1": n, "f2": n, "name": name} for n, name in updated],
      )
  @skipUnlessDBFeature(
      "supports_update_conflicts", "supports_update_conflicts_with_target"
  )
  def test_update_conflicts_two_fields_unique_fields_first(self):
      """Upsert TwoFields with f1 as the conflict target."""
      self._test_update_conflicts_two_fields(unique_fields=["f1"])
  @skipUnlessDBFeature(
      "supports_update_conflicts", "supports_update_conflicts_with_target"
  )
  def test_update_conflicts_two_fields_unique_fields_second(self):
      """Upsert TwoFields with f2 as the conflict target."""
      self._test_update_conflicts_two_fields(unique_fields=["f2"])
  @skipUnlessDBFeature(
      "supports_update_conflicts", "supports_update_conflicts_with_target"
  )
  def test_update_conflicts_unique_fields_pk(self):
      """
      unique_fields=["pk"] upserts on the primary key. Only update_fields
      (name) change, so f1 and f2 keep their stored values.
      """
      TwoFields.objects.bulk_create(
          [TwoFields(f1=1, f2=1, name="a"), TwoFields(f1=2, f2=2, name="b")]
      )
      first = TwoFields.objects.get(f1=1)
      second = TwoFields.objects.get(f1=2)
      conflicting_objects = [
          TwoFields(pk=first.pk, f1=3, f2=3, name="c"),
          TwoFields(pk=second.pk, f1=4, f2=4, name="d"),
      ]
      results = TwoFields.objects.bulk_create(
          conflicting_objects,
          update_conflicts=True,
          unique_fields=["pk"],
          update_fields=["name"],
      )
      self.assertEqual(len(results), len(conflicting_objects))
      if connection.features.can_return_rows_from_bulk_insert:
          for obj in results:
              self.assertIsNotNone(obj.pk)
      self.assertEqual(TwoFields.objects.count(), 2)
      expected_rows = [
          {"f1": 1, "f2": 1, "name": "c"},
          {"f1": 2, "f2": 2, "name": "d"},
      ]
      self.assertCountEqual(
          TwoFields.objects.values("f1", "f2", "name"), expected_rows
      )
  @skipUnlessDBFeature(
      "supports_update_conflicts", "supports_update_conflicts_with_target"
  )
  def test_update_conflicts_two_fields_unique_fields_both(self):
      """Targeting f1 and f2 together is expected to fail in the database."""
      expected_errors = (OperationalError, ProgrammingError)
      with self.assertRaises(expected_errors):
          self._test_update_conflicts_two_fields(["f1", "f2"])
  @skipUnlessDBFeature("supports_update_conflicts")
  @skipIfDBFeature("supports_update_conflicts_with_target")
  def test_update_conflicts_two_fields_no_unique_fields(self):
      """Backends without conflict targets upsert with no unique_fields."""
      self._test_update_conflicts_two_fields(unique_fields=[])
  def _test_update_conflicts_unique_two_fields(self, unique_fields):
      """
      Upsert four Country rows: two conflict with self.data and get a new
      description, two are new. unique_fields is passed to bulk_create().
      """
      germany_description = "Germany is a country in Central Europe."
      czech_description = (
          "The Czech Republic is a landlocked country in Central Europe."
      )
      japan_description = "Japan is an island country in East Asia."
      Country.objects.bulk_create(self.data)
      self.assertEqual(Country.objects.count(), 4)
      new_data = [
          # Conflicting countries.
          Country(
              name="Germany", iso_two_letter="DE", description=germany_description
          ),
          Country(
              name="Czech Republic",
              iso_two_letter="CZ",
              description=czech_description,
          ),
          # New countries.
          Country(name="Australia", iso_two_letter="AU"),
          Country(name="Japan", iso_two_letter="JP", description=japan_description),
      ]
      results = Country.objects.bulk_create(
          new_data,
          update_conflicts=True,
          update_fields=["description"],
          unique_fields=unique_fields,
      )
      self.assertEqual(len(results), len(new_data))
      if connection.features.can_return_rows_from_bulk_insert:
          for obj in results:
              self.assertIsNotNone(obj.pk)
      self.assertEqual(Country.objects.count(), 6)
      self.assertCountEqual(
          Country.objects.values("iso_two_letter", "description"),
          [
              {"iso_two_letter": "US", "description": ""},
              {"iso_two_letter": "NL", "description": ""},
              {"iso_two_letter": "DE", "description": germany_description},
              {"iso_two_letter": "CZ", "description": czech_description},
              {"iso_two_letter": "AU", "description": ""},
              {"iso_two_letter": "JP", "description": japan_description},
          ],
      )
  @skipUnlessDBFeature(
      "supports_update_conflicts", "supports_update_conflicts_with_target"
  )
  def test_update_conflicts_unique_two_fields_unique_fields_both(self):
      """Upsert Country with iso_two_letter and name as the conflict target."""
      self._test_update_conflicts_unique_two_fields(
          unique_fields=["iso_two_letter", "name"]
      )
  @skipUnlessDBFeature(
      "supports_update_conflicts", "supports_update_conflicts_with_target"
  )
  def test_update_conflicts_unique_two_fields_unique_fields_one(self):
      """Targeting only iso_two_letter is expected to fail in the database."""
      expected_errors = (OperationalError, ProgrammingError)
      with self.assertRaises(expected_errors):
          self._test_update_conflicts_unique_two_fields(["iso_two_letter"])
  @skipUnlessDBFeature("supports_update_conflicts")
  @skipIfDBFeature("supports_update_conflicts_with_target")
  def test_update_conflicts_unique_two_fields_unique_no_unique_fields(self):
      """Backends without conflict targets upsert Country with no target."""
      self._test_update_conflicts_unique_two_fields(unique_fields=[])
  def _test_update_conflicts(self, unique_fields):
      """
      Upsert UpsertConflict objects whose number matches existing rows,
      updating name and rank; then repeat with one additional new object.

      unique_fields is passed straight through to bulk_create().
      """

      def upsert(objs):
          # Same upsert options for both rounds.
          return UpsertConflict.objects.bulk_create(
              objs,
              update_conflicts=True,
              update_fields=["name", "rank"],
              unique_fields=unique_fields,
          )

      def check(results, expected_rows):
          # One result per input object, and the table holds exactly
          # expected_rows afterwards.
          self.assertEqual(len(results), len(expected_rows))
          if connection.features.can_return_rows_from_bulk_insert:
              for obj in results:
                  self.assertIsNotNone(obj.pk)
          self.assertEqual(UpsertConflict.objects.count(), len(expected_rows))
          self.assertCountEqual(
              UpsertConflict.objects.values("number", "rank", "name"),
              expected_rows,
          )

      UpsertConflict.objects.bulk_create(
          [
              UpsertConflict(number=1, rank=1, name="John"),
              UpsertConflict(number=2, rank=2, name="Mary"),
              UpsertConflict(number=3, rank=3, name="Hannah"),
          ]
      )
      self.assertEqual(UpsertConflict.objects.count(), 3)
      conflicting_objects = [
          UpsertConflict(number=1, rank=4, name="Steve"),
          UpsertConflict(number=2, rank=2, name="Olivia"),
          UpsertConflict(number=3, rank=1, name="Hannah"),
      ]
      expected_rows = [
          {"number": 1, "rank": 4, "name": "Steve"},
          {"number": 2, "rank": 2, "name": "Olivia"},
          {"number": 3, "rank": 1, "name": "Hannah"},
      ]
      check(upsert(conflicting_objects), expected_rows)
      mark = UpsertConflict(number=4, rank=4, name="Mark")
      check(
          upsert([*conflicting_objects, mark]),
          [*expected_rows, {"number": 4, "rank": 4, "name": "Mark"}],
      )
  @skipUnlessDBFeature(
      "supports_update_conflicts", "supports_update_conflicts_with_target"
  )
  def test_update_conflicts_unique_fields(self):
      """Upsert UpsertConflict with number as the conflict target."""
      target = ["number"]
      self._test_update_conflicts(unique_fields=target)
  @skipUnlessDBFeature("supports_update_conflicts")
  @skipIfDBFeature("supports_update_conflicts_with_target")
  def test_update_conflicts_no_unique_fields(self):
      """Backends without conflict targets upsert with no unique_fields."""
      self._test_update_conflicts(unique_fields=[])
  @skipUnlessDBFeature(
      "supports_update_conflicts", "supports_update_conflicts_with_target"
  )
  def test_update_conflicts_unique_fields_update_fields_db_column(self):
      """
      Upserts work on FieldsWithDbColumns (whose fields presumably use
      custom db_column names - see the model).
      """
      initial = [(1, "a"), (2, "b")]
      FieldsWithDbColumns.objects.bulk_create(
          [FieldsWithDbColumns(rank=rank, name=name) for rank, name in initial]
      )
      self.assertEqual(FieldsWithDbColumns.objects.count(), 2)
      updated = [(1, "c"), (2, "d")]
      conflicting_objects = [
          FieldsWithDbColumns(rank=rank, name=name) for rank, name in updated
      ]
      results = FieldsWithDbColumns.objects.bulk_create(
          conflicting_objects,
          update_conflicts=True,
          unique_fields=["rank"],
          update_fields=["name"],
      )
      self.assertEqual(len(results), len(conflicting_objects))
      if connection.features.can_return_rows_from_bulk_insert:
          for obj in results:
              self.assertIsNotNone(obj.pk)
      self.assertEqual(FieldsWithDbColumns.objects.count(), 2)
      self.assertCountEqual(
          FieldsWithDbColumns.objects.values("rank", "name"),
          [{"rank": rank, "name": name} for rank, name in updated],
      )
  def test_db_default_field_excluded(self):
      """
      created_at (a db_default field) is only sent in the INSERT when at
      least one object overrides it.
      """
      # No object overrides created_at, so its column is left out.
      with self.assertNumQueries(1) as ctx:
          DbDefaultModel.objects.bulk_create(
              [DbDefaultModel(name="foo"), DbDefaultModel(name="bar")]
          )
      created_at_quoted_name = connection.ops.quote_name("created_at")
      # Backends that return rows mention created_at once more in the query
      # (presumably in the RETURNING clause).
      extra = 1 if connection.features.can_return_rows_from_bulk_insert else 0
      self.assertEqual(ctx[0]["sql"].count(created_at_quoted_name), extra)
      # One object overrides created_at, so the column is included.
      with self.assertNumQueries(1) as ctx:
          DbDefaultModel.objects.bulk_create(
              [
                  DbDefaultModel(name="foo", created_at=timezone.now()),
                  DbDefaultModel(name="bar"),
              ]
          )
      self.assertEqual(ctx[0]["sql"].count(created_at_quoted_name), 1 + extra)