test_bulk_update.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. import datetime
  2. from django.core.exceptions import FieldDoesNotExist
  3. from django.db.models import F
  4. from django.db.models.functions import Lower
  5. from django.test import TestCase, skipUnlessDBFeature
  6. from .models import (
  7. Article, CustomDbColumn, CustomPk, Detail, Individual, JSONFieldNullable,
  8. Member, Note, Number, Order, Paragraph, RelatedObject, SingleObject,
  9. SpecialCategory, Tag, Valid,
  10. )
  11. class BulkUpdateNoteTests(TestCase):
  12. @classmethod
  13. def setUpTestData(cls):
  14. cls.notes = [
  15. Note.objects.create(note=str(i), misc=str(i))
  16. for i in range(10)
  17. ]
  18. def create_tags(self):
  19. self.tags = [
  20. Tag.objects.create(name=str(i))
  21. for i in range(10)
  22. ]
  23. def test_simple(self):
  24. for note in self.notes:
  25. note.note = 'test-%s' % note.id
  26. with self.assertNumQueries(1):
  27. Note.objects.bulk_update(self.notes, ['note'])
  28. self.assertCountEqual(
  29. Note.objects.values_list('note', flat=True),
  30. [cat.note for cat in self.notes]
  31. )
  32. def test_multiple_fields(self):
  33. for note in self.notes:
  34. note.note = 'test-%s' % note.id
  35. note.misc = 'misc-%s' % note.id
  36. with self.assertNumQueries(1):
  37. Note.objects.bulk_update(self.notes, ['note', 'misc'])
  38. self.assertCountEqual(
  39. Note.objects.values_list('note', flat=True),
  40. [cat.note for cat in self.notes]
  41. )
  42. self.assertCountEqual(
  43. Note.objects.values_list('misc', flat=True),
  44. [cat.misc for cat in self.notes]
  45. )
  46. def test_batch_size(self):
  47. with self.assertNumQueries(len(self.notes)):
  48. Note.objects.bulk_update(self.notes, fields=['note'], batch_size=1)
  49. def test_unsaved_models(self):
  50. objs = self.notes + [Note(note='test', misc='test')]
  51. msg = 'All bulk_update() objects must have a primary key set.'
  52. with self.assertRaisesMessage(ValueError, msg):
  53. Note.objects.bulk_update(objs, fields=['note'])
  54. def test_foreign_keys_do_not_lookup(self):
  55. self.create_tags()
  56. for note, tag in zip(self.notes, self.tags):
  57. note.tag = tag
  58. with self.assertNumQueries(1):
  59. Note.objects.bulk_update(self.notes, ['tag'])
  60. self.assertSequenceEqual(Note.objects.filter(tag__isnull=False), self.notes)
  61. def test_set_field_to_null(self):
  62. self.create_tags()
  63. Note.objects.update(tag=self.tags[0])
  64. for note in self.notes:
  65. note.tag = None
  66. Note.objects.bulk_update(self.notes, ['tag'])
  67. self.assertCountEqual(Note.objects.filter(tag__isnull=True), self.notes)
  68. def test_set_mixed_fields_to_null(self):
  69. self.create_tags()
  70. midpoint = len(self.notes) // 2
  71. top, bottom = self.notes[:midpoint], self.notes[midpoint:]
  72. for note in top:
  73. note.tag = None
  74. for note in bottom:
  75. note.tag = self.tags[0]
  76. Note.objects.bulk_update(self.notes, ['tag'])
  77. self.assertCountEqual(Note.objects.filter(tag__isnull=True), top)
  78. self.assertCountEqual(Note.objects.filter(tag__isnull=False), bottom)
  79. def test_functions(self):
  80. Note.objects.update(note='TEST')
  81. for note in self.notes:
  82. note.note = Lower('note')
  83. Note.objects.bulk_update(self.notes, ['note'])
  84. self.assertEqual(set(Note.objects.values_list('note', flat=True)), {'test'})
  85. # Tests that use self.notes go here, otherwise put them in another class.
  86. class BulkUpdateTests(TestCase):
  87. def test_no_fields(self):
  88. msg = 'Field names must be given to bulk_update().'
  89. with self.assertRaisesMessage(ValueError, msg):
  90. Note.objects.bulk_update([], fields=[])
  91. def test_invalid_batch_size(self):
  92. msg = 'Batch size must be a positive integer.'
  93. with self.assertRaisesMessage(ValueError, msg):
  94. Note.objects.bulk_update([], fields=['note'], batch_size=-1)
  95. def test_nonexistent_field(self):
  96. with self.assertRaisesMessage(FieldDoesNotExist, "Note has no field named 'nonexistent'"):
  97. Note.objects.bulk_update([], ['nonexistent'])
  98. pk_fields_error = 'bulk_update() cannot be used with primary key fields.'
  99. def test_update_primary_key(self):
  100. with self.assertRaisesMessage(ValueError, self.pk_fields_error):
  101. Note.objects.bulk_update([], ['id'])
  102. def test_update_custom_primary_key(self):
  103. with self.assertRaisesMessage(ValueError, self.pk_fields_error):
  104. CustomPk.objects.bulk_update([], ['name'])
  105. def test_empty_objects(self):
  106. with self.assertNumQueries(0):
  107. rows_updated = Note.objects.bulk_update([], ['note'])
  108. self.assertEqual(rows_updated, 0)
  109. def test_large_batch(self):
  110. Note.objects.bulk_create([
  111. Note(note=str(i), misc=str(i))
  112. for i in range(0, 2000)
  113. ])
  114. notes = list(Note.objects.all())
  115. rows_updated = Note.objects.bulk_update(notes, ['note'])
  116. self.assertEqual(rows_updated, 2000)
  117. def test_updated_rows_when_passing_duplicates(self):
  118. note = Note.objects.create(note='test-note', misc='test')
  119. rows_updated = Note.objects.bulk_update([note, note], ['note'])
  120. self.assertEqual(rows_updated, 1)
  121. # Duplicates in different batches.
  122. rows_updated = Note.objects.bulk_update([note, note], ['note'], batch_size=1)
  123. self.assertEqual(rows_updated, 2)
  124. def test_only_concrete_fields_allowed(self):
  125. obj = Valid.objects.create(valid='test')
  126. detail = Detail.objects.create(data='test')
  127. paragraph = Paragraph.objects.create(text='test')
  128. Member.objects.create(name='test', details=detail)
  129. msg = 'bulk_update() can only be used with concrete fields.'
  130. with self.assertRaisesMessage(ValueError, msg):
  131. Detail.objects.bulk_update([detail], fields=['member'])
  132. with self.assertRaisesMessage(ValueError, msg):
  133. Paragraph.objects.bulk_update([paragraph], fields=['page'])
  134. with self.assertRaisesMessage(ValueError, msg):
  135. Valid.objects.bulk_update([obj], fields=['parent'])
  136. def test_custom_db_columns(self):
  137. model = CustomDbColumn.objects.create(custom_column=1)
  138. model.custom_column = 2
  139. CustomDbColumn.objects.bulk_update([model], fields=['custom_column'])
  140. model.refresh_from_db()
  141. self.assertEqual(model.custom_column, 2)
  142. def test_custom_pk(self):
  143. custom_pks = [
  144. CustomPk.objects.create(name='pk-%s' % i, extra='')
  145. for i in range(10)
  146. ]
  147. for model in custom_pks:
  148. model.extra = 'extra-%s' % model.pk
  149. CustomPk.objects.bulk_update(custom_pks, ['extra'])
  150. self.assertCountEqual(
  151. CustomPk.objects.values_list('extra', flat=True),
  152. [cat.extra for cat in custom_pks]
  153. )
  154. def test_falsey_pk_value(self):
  155. order = Order.objects.create(pk=0, name='test')
  156. order.name = 'updated'
  157. Order.objects.bulk_update([order], ['name'])
  158. order.refresh_from_db()
  159. self.assertEqual(order.name, 'updated')
  160. def test_inherited_fields(self):
  161. special_categories = [
  162. SpecialCategory.objects.create(name=str(i), special_name=str(i))
  163. for i in range(10)
  164. ]
  165. for category in special_categories:
  166. category.name = 'test-%s' % category.id
  167. category.special_name = 'special-test-%s' % category.special_name
  168. SpecialCategory.objects.bulk_update(special_categories, ['name', 'special_name'])
  169. self.assertCountEqual(
  170. SpecialCategory.objects.values_list('name', flat=True),
  171. [cat.name for cat in special_categories]
  172. )
  173. self.assertCountEqual(
  174. SpecialCategory.objects.values_list('special_name', flat=True),
  175. [cat.special_name for cat in special_categories]
  176. )
  177. def test_field_references(self):
  178. numbers = [Number.objects.create(num=0) for _ in range(10)]
  179. for number in numbers:
  180. number.num = F('num') + 1
  181. Number.objects.bulk_update(numbers, ['num'])
  182. self.assertCountEqual(Number.objects.filter(num=1), numbers)
  183. def test_f_expression(self):
  184. notes = [
  185. Note.objects.create(note='test_note', misc='test_misc')
  186. for _ in range(10)
  187. ]
  188. for note in notes:
  189. note.misc = F('note')
  190. Note.objects.bulk_update(notes, ['misc'])
  191. self.assertCountEqual(Note.objects.filter(misc='test_note'), notes)
  192. def test_booleanfield(self):
  193. individuals = [Individual.objects.create(alive=False) for _ in range(10)]
  194. for individual in individuals:
  195. individual.alive = True
  196. Individual.objects.bulk_update(individuals, ['alive'])
  197. self.assertCountEqual(Individual.objects.filter(alive=True), individuals)
  198. def test_ipaddressfield(self):
  199. for ip in ('2001::1', '1.2.3.4'):
  200. with self.subTest(ip=ip):
  201. models = [
  202. CustomDbColumn.objects.create(ip_address='0.0.0.0')
  203. for _ in range(10)
  204. ]
  205. for model in models:
  206. model.ip_address = ip
  207. CustomDbColumn.objects.bulk_update(models, ['ip_address'])
  208. self.assertCountEqual(CustomDbColumn.objects.filter(ip_address=ip), models)
  209. def test_datetime_field(self):
  210. articles = [
  211. Article.objects.create(name=str(i), created=datetime.datetime.today())
  212. for i in range(10)
  213. ]
  214. point_in_time = datetime.datetime(1991, 10, 31)
  215. for article in articles:
  216. article.created = point_in_time
  217. Article.objects.bulk_update(articles, ['created'])
  218. self.assertCountEqual(Article.objects.filter(created=point_in_time), articles)
  219. @skipUnlessDBFeature('supports_json_field')
  220. def test_json_field(self):
  221. JSONFieldNullable.objects.bulk_create([
  222. JSONFieldNullable(json_field={'a': i}) for i in range(10)
  223. ])
  224. objs = JSONFieldNullable.objects.all()
  225. for obj in objs:
  226. obj.json_field = {'c': obj.json_field['a'] + 1}
  227. JSONFieldNullable.objects.bulk_update(objs, ['json_field'])
  228. self.assertCountEqual(JSONFieldNullable.objects.filter(json_field__has_key='c'), objs)
  229. def test_nullable_fk_after_related_save(self):
  230. parent = RelatedObject.objects.create()
  231. child = SingleObject()
  232. parent.single = child
  233. parent.single.save()
  234. RelatedObject.objects.bulk_update([parent], fields=['single'])
  235. self.assertEqual(parent.single_id, parent.single.pk)
  236. parent.refresh_from_db()
  237. self.assertEqual(parent.single, child)
  238. def test_unsaved_parent(self):
  239. parent = RelatedObject.objects.create()
  240. parent.single = SingleObject()
  241. msg = (
  242. "bulk_update() prohibited to prevent data loss due to unsaved "
  243. "related object 'single'."
  244. )
  245. with self.assertRaisesMessage(ValueError, msg):
  246. RelatedObject.objects.bulk_update([parent], fields=['single'])
  247. def test_unspecified_unsaved_parent(self):
  248. parent = RelatedObject.objects.create()
  249. parent.single = SingleObject()
  250. parent.f = 42
  251. RelatedObject.objects.bulk_update([parent], fields=['f'])
  252. parent.refresh_from_db()
  253. self.assertEqual(parent.f, 42)
  254. self.assertIsNone(parent.single)