tests.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. import json
  2. import xml.etree.ElementTree
  3. from datetime import datetime
  4. from asgiref.sync import async_to_sync, sync_to_async
  5. from django.db import NotSupportedError, connection
  6. from django.db.models import Sum
  7. from django.test import TestCase, skipUnlessDBFeature
  8. from .models import SimpleModel
  9. class AsyncQuerySetTest(TestCase):
  @classmethod
  def setUpTestData(cls):
      """Create three SimpleModel rows whose ``created`` values are 1s apart.

      s1/s2/s3 get ``field`` 1/2/3 and ``created`` 00:00:00/01/02 on
      2022-01-01, so ordering by ``created`` matches ordering by ``field``.
      """
      # Unpacking consumes the generator left to right, so the rows are
      # inserted in the order s1, s2, s3.
      cls.s1, cls.s2, cls.s3 = (
          SimpleModel.objects.create(
              field=offset + 1,
              created=datetime(2022, 1, 1, 0, 0, offset),
          )
          for offset in range(3)
      )
  @staticmethod
  def _get_db_feature(connection_, feature_name):
      """Return ``connection_.features.<feature_name>``.

      Connection access is thread-sensitive and cannot be passed across
      sync/async boundaries, so the attribute lookup is kept inside this
      plain function. Async callers run it via sync_to_async() rather than
      touching ``connection.features`` in the coroutine itself.
      """
      features = connection_.features
      return getattr(features, feature_name)
  async def test_async_iteration(self):
      """An ``async for`` over an ordered QuerySet yields rows in that order."""
      ordered = SimpleModel.objects.order_by("pk")
      collected = [obj async for obj in ordered]
      self.assertEqual(collected, [self.s1, self.s2, self.s3])
  async def test_aiterator(self):
      """QuerySet.aiterator() yields every row; order is not asserted."""
      iterator = SimpleModel.objects.aiterator()
      fetched = [obj async for obj in iterator]
      self.assertCountEqual(fetched, [self.s1, self.s2, self.s3])
  async def test_aiterator_prefetch_related(self):
      """Iterating aiterator() after prefetch_related() raises NotSupportedError."""
      iterator = SimpleModel.objects.prefetch_related("relatedmodels").aiterator()
      msg = "Using QuerySet.aiterator() after prefetch_related() is not supported."
      with self.assertRaisesMessage(NotSupportedError, msg):
          # The error surfaces on iteration, so the loop body never matters.
          async for _ in iterator:
              pass
  async def test_aiterator_invalid_chunk_size(self):
      """aiterator() rejects a chunk_size of zero or below when iterated."""
      msg = "Chunk size must be strictly positive."
      for chunk_size in (0, -1):
          iterator = SimpleModel.objects.aiterator(chunk_size=chunk_size)
          with self.subTest(size=chunk_size):
              with self.assertRaisesMessage(ValueError, msg):
                  async for _ in iterator:
                      pass
  async def test_acount(self):
      """acount() returns the number of rows in the table."""
      self.assertEqual(await SimpleModel.objects.acount(), 3)
  async def test_acount_cached_result(self):
      """acount() on an already-evaluated QuerySet uses its result cache."""
      queryset = SimpleModel.objects.all()
      # Evaluate the queryset to populate the query cache.
      [obj async for obj in queryset]
      self.assertEqual(await queryset.acount(), 3)
      create = sync_to_async(SimpleModel.objects.create)
      await create(field=4, created=datetime(2022, 1, 1, 0, 0, 0))
      # A fourth row now exists in the database, yet the count is unchanged
      # because it is answered from the cache.
      self.assertEqual(await queryset.acount(), 3)
  async def test_aget(self):
      """aget() returns the single row matching the lookup."""
      self.assertEqual(await SimpleModel.objects.aget(field=1), self.s1)
  async def test_acreate(self):
      """acreate() inserts a new row."""
      await SimpleModel.objects.acreate(field=4)
      total = await SimpleModel.objects.acount()
      self.assertEqual(total, 4)
  async def test_aget_or_create(self):
      """aget_or_create() creates and returns a row when none matches."""
      instance, created = await SimpleModel.objects.aget_or_create(field=4)
      self.assertEqual(await SimpleModel.objects.acount(), 4)
      self.assertIs(created, True)
      # The returned object is the newly created row (field=4 matches no
      # fixture row), not merely a discarded value.
      self.assertEqual(instance.field, 4)
  async def test_aupdate_or_create(self):
      """aupdate_or_create() updates a matching row or creates a new one."""
      instance, created = await SimpleModel.objects.aupdate_or_create(
          id=self.s1.id, defaults={"field": 2}
      )
      self.assertEqual(instance, self.s1)
      self.assertIs(created, False)
      # The defaults must be applied to the returned object and persisted.
      self.assertEqual(instance.field, 2)
      refreshed = await SimpleModel.objects.aget(id=self.s1.id)
      self.assertEqual(refreshed.field, 2)
      # No row has field=4, so this call creates one.
      instance, created = await SimpleModel.objects.aupdate_or_create(field=4)
      self.assertEqual(await SimpleModel.objects.acount(), 4)
      self.assertIs(created, True)
      self.assertEqual(instance.field, 4)
  @skipUnlessDBFeature("has_bulk_insert")
  @async_to_sync
  async def test_abulk_create(self):
      """abulk_create() returns one object per instance it was given."""
      # NOTE(review): async_to_sync turns this coroutine into a plain method,
      # presumably so skipUnlessDBFeature's wrapper can invoke it — confirm.
      created = await SimpleModel.objects.abulk_create(
          [SimpleModel(field=value) for value in range(10)]
      )
      self.assertEqual(len(created), 10)
  async def test_abulk_update(self):
      """abulk_update() writes the listed fields of the given objects."""
      queryset = SimpleModel.objects.all()
      # The same queryset object is iterated and then passed on; the test
      # relies on abulk_update() seeing the instances mutated here.
      async for obj in queryset:
          obj.field *= 10
      await SimpleModel.objects.abulk_update(queryset, ["field"])
      rows = [(obj.pk, obj.field) async for obj in SimpleModel.objects.all()]
      self.assertCountEqual(
          rows,
          [(self.s1.pk, 10), (self.s2.pk, 20), (self.s3.pk, 30)],
      )
  async def test_ain_bulk(self):
      """ain_bulk() maps primary keys to instances, optionally filtered by key."""
      everything = await SimpleModel.objects.ain_bulk()
      self.assertEqual(
          everything,
          {obj.pk: obj for obj in (self.s1, self.s2, self.s3)},
      )
      only_s2 = {self.s2.pk: self.s2}
      by_pk = await SimpleModel.objects.ain_bulk([self.s2.pk])
      self.assertEqual(by_pk, only_s2)
      by_id = await SimpleModel.objects.ain_bulk([self.s2.pk], field_name="id")
      self.assertEqual(by_id, only_s2)
  async def test_alatest(self):
      """alatest() honours the ordering field and its direction prefix."""
      cases = (("created", self.s3), ("-created", self.s1))
      for field_name, expected in cases:
          self.assertEqual(await SimpleModel.objects.alatest(field_name), expected)
  async def test_aearliest(self):
      """aearliest() honours the ordering field and its direction prefix."""
      cases = (("created", self.s1), ("-created", self.s3))
      for field_name, expected in cases:
          self.assertEqual(await SimpleModel.objects.aearliest(field_name), expected)
  async def test_afirst(self):
      """afirst() returns the first row, or None when nothing matches."""
      self.assertEqual(await SimpleModel.objects.afirst(), self.s1)
      self.assertIsNone(await SimpleModel.objects.filter(field=4).afirst())
  async def test_alast(self):
      """alast() returns the last row, or None when nothing matches."""
      self.assertEqual(await SimpleModel.objects.alast(), self.s3)
      self.assertIsNone(await SimpleModel.objects.filter(field=4).alast())
  async def test_aaggregate(self):
      """aaggregate() returns a dict keyed by the aggregate's alias."""
      summary = await SimpleModel.objects.aaggregate(total=Sum("field"))
      self.assertEqual(summary, {"total": 6})
  async def test_aexists(self):
      """aexists() reports whether any row matches the filter."""
      self.assertIs(await SimpleModel.objects.filter(field=1).aexists(), True)
      self.assertIs(await SimpleModel.objects.filter(field=4).aexists(), False)
  async def test_acontains(self):
      """acontains() reports whether an instance is in the queryset."""
      self.assertIs(await SimpleModel.objects.acontains(self.s1), True)
      # Unsaved instances are not allowed, so use an ID known not to exist.
      absent = SimpleModel(id=self.s3.id + 1, field=4)
      self.assertIs(await SimpleModel.objects.acontains(absent), False)
  async def test_aupdate(self):
      """aupdate() applies the new value to every row."""
      await SimpleModel.objects.aupdate(field=99)
      distinct_values = {obj.field async for obj in SimpleModel.objects.all()}
      self.assertEqual(distinct_values, {99})
  async def test_adelete(self):
      """adelete() removes only the rows matching the filter."""
      await SimpleModel.objects.filter(field=2).adelete()
      remaining = [obj async for obj in SimpleModel.objects.all()]
      self.assertCountEqual(remaining, [self.s1, self.s3])
  @skipUnlessDBFeature("supports_explaining_query_execution")
  @async_to_sync
  async def test_aexplain(self):
      """aexplain() returns a non-empty string for the default and each
      supported format; XML and JSON output must also parse as such.
      """
      # Feature flags are read through sync_to_async because connection
      # access is thread-sensitive.
      supported_formats = await sync_to_async(self._get_db_feature)(
          connection, "supported_explain_formats"
      )
      for format_ in (None, *supported_formats):
          with self.subTest(format=format_):
              # TODO: Check the captured query when async versions of
              # self.assertNumQueries/CaptureQueriesContext context
              # processors are available.
              result = await SimpleModel.objects.filter(field=1).aexplain(
                  format=format_
              )
              self.assertIsInstance(result, str)
              self.assertTrue(result)
              # A falsy format (the default) normalizes to "" and so matches
              # neither structured format below.
              normalized = (format_ or "").lower()
              if normalized == "xml":
                  try:
                      xml.etree.ElementTree.fromstring(result)
                  except xml.etree.ElementTree.ParseError as e:
                      self.fail(f"QuerySet.aexplain() result is not valid XML: {e}")
              elif normalized == "json":
                  try:
                      json.loads(result)
                  except json.JSONDecodeError as e:
                      self.fail(f"QuerySet.aexplain() result is not valid JSON: {e}")