# test_bulk_update.py (8.6 KB)
  1. import datetime
  2. from django.core.exceptions import FieldDoesNotExist
  3. from django.db.models import F
  4. from django.db.models.functions import Lower
  5. from django.test import TestCase
  6. from .models import (
  7. Article, CustomDbColumn, CustomPk, Detail, Individual, Member, Note,
  8. Number, Paragraph, SpecialCategory, Tag, Valid,
  9. )
  10. class BulkUpdateNoteTests(TestCase):
  11. def setUp(self):
  12. self.notes = [
  13. Note.objects.create(note=str(i), misc=str(i))
  14. for i in range(10)
  15. ]
  16. def create_tags(self):
  17. self.tags = [
  18. Tag.objects.create(name=str(i))
  19. for i in range(10)
  20. ]
  21. def test_simple(self):
  22. for note in self.notes:
  23. note.note = 'test-%s' % note.id
  24. with self.assertNumQueries(1):
  25. Note.objects.bulk_update(self.notes, ['note'])
  26. self.assertCountEqual(
  27. Note.objects.values_list('note', flat=True),
  28. [cat.note for cat in self.notes]
  29. )
  30. def test_multiple_fields(self):
  31. for note in self.notes:
  32. note.note = 'test-%s' % note.id
  33. note.misc = 'misc-%s' % note.id
  34. with self.assertNumQueries(1):
  35. Note.objects.bulk_update(self.notes, ['note', 'misc'])
  36. self.assertCountEqual(
  37. Note.objects.values_list('note', flat=True),
  38. [cat.note for cat in self.notes]
  39. )
  40. self.assertCountEqual(
  41. Note.objects.values_list('misc', flat=True),
  42. [cat.misc for cat in self.notes]
  43. )
  44. def test_batch_size(self):
  45. with self.assertNumQueries(len(self.notes)):
  46. Note.objects.bulk_update(self.notes, fields=['note'], batch_size=1)
  47. def test_unsaved_models(self):
  48. objs = self.notes + [Note(note='test', misc='test')]
  49. msg = 'All bulk_update() objects must have a primary key set.'
  50. with self.assertRaisesMessage(ValueError, msg):
  51. Note.objects.bulk_update(objs, fields=['note'])
  52. def test_foreign_keys_do_not_lookup(self):
  53. self.create_tags()
  54. for note, tag in zip(self.notes, self.tags):
  55. note.tag = tag
  56. with self.assertNumQueries(1):
  57. Note.objects.bulk_update(self.notes, ['tag'])
  58. self.assertSequenceEqual(Note.objects.filter(tag__isnull=False), self.notes)
  59. def test_set_field_to_null(self):
  60. self.create_tags()
  61. Note.objects.update(tag=self.tags[0])
  62. for note in self.notes:
  63. note.tag = None
  64. Note.objects.bulk_update(self.notes, ['tag'])
  65. self.assertCountEqual(Note.objects.filter(tag__isnull=True), self.notes)
  66. def test_set_mixed_fields_to_null(self):
  67. self.create_tags()
  68. midpoint = len(self.notes) // 2
  69. top, bottom = self.notes[:midpoint], self.notes[midpoint:]
  70. for note in top:
  71. note.tag = None
  72. for note in bottom:
  73. note.tag = self.tags[0]
  74. Note.objects.bulk_update(self.notes, ['tag'])
  75. self.assertCountEqual(Note.objects.filter(tag__isnull=True), top)
  76. self.assertCountEqual(Note.objects.filter(tag__isnull=False), bottom)
  77. def test_functions(self):
  78. Note.objects.update(note='TEST')
  79. for note in self.notes:
  80. note.note = Lower('note')
  81. Note.objects.bulk_update(self.notes, ['note'])
  82. self.assertEqual(set(Note.objects.values_list('note', flat=True)), {'test'})
  83. # Tests that use self.notes go here, otherwise put them in another class.
  84. class BulkUpdateTests(TestCase):
  85. def test_no_fields(self):
  86. msg = 'Field names must be given to bulk_update().'
  87. with self.assertRaisesMessage(ValueError, msg):
  88. Note.objects.bulk_update([], fields=[])
  89. def test_invalid_batch_size(self):
  90. msg = 'Batch size must be a positive integer.'
  91. with self.assertRaisesMessage(ValueError, msg):
  92. Note.objects.bulk_update([], fields=['note'], batch_size=-1)
  93. def test_nonexistent_field(self):
  94. with self.assertRaisesMessage(FieldDoesNotExist, "Note has no field named 'nonexistent'"):
  95. Note.objects.bulk_update([], ['nonexistent'])
  96. pk_fields_error = 'bulk_update() cannot be used with primary key fields.'
  97. def test_update_primary_key(self):
  98. with self.assertRaisesMessage(ValueError, self.pk_fields_error):
  99. Note.objects.bulk_update([], ['id'])
  100. def test_update_custom_primary_key(self):
  101. with self.assertRaisesMessage(ValueError, self.pk_fields_error):
  102. CustomPk.objects.bulk_update([], ['name'])
  103. def test_empty_objects(self):
  104. with self.assertNumQueries(0):
  105. Note.objects.bulk_update([], ['note'])
  106. def test_large_batch(self):
  107. Note.objects.bulk_create([
  108. Note(note=str(i), misc=str(i))
  109. for i in range(0, 2000)
  110. ])
  111. notes = list(Note.objects.all())
  112. Note.objects.bulk_update(notes, ['note'])
  113. def test_only_concrete_fields_allowed(self):
  114. obj = Valid.objects.create(valid='test')
  115. detail = Detail.objects.create(data='test')
  116. paragraph = Paragraph.objects.create(text='test')
  117. Member.objects.create(name='test', details=detail)
  118. msg = 'bulk_update() can only be used with concrete fields.'
  119. with self.assertRaisesMessage(ValueError, msg):
  120. Detail.objects.bulk_update([detail], fields=['member'])
  121. with self.assertRaisesMessage(ValueError, msg):
  122. Paragraph.objects.bulk_update([paragraph], fields=['page'])
  123. with self.assertRaisesMessage(ValueError, msg):
  124. Valid.objects.bulk_update([obj], fields=['parent'])
  125. def test_custom_db_columns(self):
  126. model = CustomDbColumn.objects.create(custom_column=1)
  127. model.custom_column = 2
  128. CustomDbColumn.objects.bulk_update([model], fields=['custom_column'])
  129. model.refresh_from_db()
  130. self.assertEqual(model.custom_column, 2)
  131. def test_custom_pk(self):
  132. custom_pks = [
  133. CustomPk.objects.create(name='pk-%s' % i, extra='')
  134. for i in range(10)
  135. ]
  136. for model in custom_pks:
  137. model.extra = 'extra-%s' % model.pk
  138. CustomPk.objects.bulk_update(custom_pks, ['extra'])
  139. self.assertCountEqual(
  140. CustomPk.objects.values_list('extra', flat=True),
  141. [cat.extra for cat in custom_pks]
  142. )
  143. def test_inherited_fields(self):
  144. special_categories = [
  145. SpecialCategory.objects.create(name=str(i), special_name=str(i))
  146. for i in range(10)
  147. ]
  148. for category in special_categories:
  149. category.name = 'test-%s' % category.id
  150. category.special_name = 'special-test-%s' % category.special_name
  151. SpecialCategory.objects.bulk_update(special_categories, ['name', 'special_name'])
  152. self.assertCountEqual(
  153. SpecialCategory.objects.values_list('name', flat=True),
  154. [cat.name for cat in special_categories]
  155. )
  156. self.assertCountEqual(
  157. SpecialCategory.objects.values_list('special_name', flat=True),
  158. [cat.special_name for cat in special_categories]
  159. )
  160. def test_field_references(self):
  161. numbers = [Number.objects.create(num=0) for _ in range(10)]
  162. for number in numbers:
  163. number.num = F('num') + 1
  164. Number.objects.bulk_update(numbers, ['num'])
  165. self.assertCountEqual(Number.objects.filter(num=1), numbers)
  166. def test_booleanfield(self):
  167. individuals = [Individual.objects.create(alive=False) for _ in range(10)]
  168. for individual in individuals:
  169. individual.alive = True
  170. Individual.objects.bulk_update(individuals, ['alive'])
  171. self.assertCountEqual(Individual.objects.filter(alive=True), individuals)
  172. def test_ipaddressfield(self):
  173. for ip in ('2001::1', '1.2.3.4'):
  174. with self.subTest(ip=ip):
  175. models = [
  176. CustomDbColumn.objects.create(ip_address='0.0.0.0')
  177. for _ in range(10)
  178. ]
  179. for model in models:
  180. model.ip_address = ip
  181. CustomDbColumn.objects.bulk_update(models, ['ip_address'])
  182. self.assertCountEqual(CustomDbColumn.objects.filter(ip_address=ip), models)
  183. def test_datetime_field(self):
  184. articles = [
  185. Article.objects.create(name=str(i), created=datetime.datetime.today())
  186. for i in range(10)
  187. ]
  188. point_in_time = datetime.datetime(1991, 10, 31)
  189. for article in articles:
  190. article.created = point_in_time
  191. Article.objects.bulk_update(articles, ['created'])
  192. self.assertCountEqual(Article.objects.filter(created=point_in_time), articles)