# test_async_queryset.py 9.8 KB


  1. import json
  2. import xml.etree.ElementTree
  3. from datetime import datetime
  4. from asgiref.sync import async_to_sync, sync_to_async
  5. from django.db import NotSupportedError, connection
  6. from django.db.models import Prefetch, Sum
  7. from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
  8. from .models import RelatedModel, SimpleModel
  9. class AsyncQuerySetTest(TestCase):
  10. @classmethod
  11. def setUpTestData(cls):
  12. cls.s1 = SimpleModel.objects.create(
  13. field=1,
  14. created=datetime(2022, 1, 1, 0, 0, 0),
  15. )
  16. cls.s2 = SimpleModel.objects.create(
  17. field=2,
  18. created=datetime(2022, 1, 1, 0, 0, 1),
  19. )
  20. cls.s3 = SimpleModel.objects.create(
  21. field=3,
  22. created=datetime(2022, 1, 1, 0, 0, 2),
  23. )
  24. cls.r1 = RelatedModel.objects.create(simple=cls.s1)
  25. cls.r2 = RelatedModel.objects.create(simple=cls.s2)
  26. cls.r3 = RelatedModel.objects.create(simple=cls.s3)
  27. @staticmethod
  28. def _get_db_feature(connection_, feature_name):
  29. # Wrapper to avoid accessing connection attributes until inside
  30. # coroutine function. Connection access is thread sensitive and cannot
  31. # be passed across sync/async boundaries.
  32. return getattr(connection_.features, feature_name)
  33. async def test_async_iteration(self):
  34. results = []
  35. async for m in SimpleModel.objects.order_by("pk"):
  36. results.append(m)
  37. self.assertEqual(results, [self.s1, self.s2, self.s3])
  38. async def test_aiterator(self):
  39. qs = SimpleModel.objects.aiterator()
  40. results = []
  41. async for m in qs:
  42. results.append(m)
  43. self.assertCountEqual(results, [self.s1, self.s2, self.s3])
  44. async def test_aiterator_prefetch_related(self):
  45. results = []
  46. async for s in SimpleModel.objects.prefetch_related(
  47. Prefetch("relatedmodel_set", to_attr="prefetched_relatedmodel")
  48. ).aiterator():
  49. results.append(s.prefetched_relatedmodel)
  50. self.assertCountEqual(results, [[self.r1], [self.r2], [self.r3]])
  51. async def test_aiterator_invalid_chunk_size(self):
  52. msg = "Chunk size must be strictly positive."
  53. for size in [0, -1]:
  54. qs = SimpleModel.objects.aiterator(chunk_size=size)
  55. with self.subTest(size=size), self.assertRaisesMessage(ValueError, msg):
  56. async for m in qs:
  57. pass
  58. async def test_acount(self):
  59. count = await SimpleModel.objects.acount()
  60. self.assertEqual(count, 3)
  61. async def test_acount_cached_result(self):
  62. qs = SimpleModel.objects.all()
  63. # Evaluate the queryset to populate the query cache.
  64. [x async for x in qs]
  65. count = await qs.acount()
  66. self.assertEqual(count, 3)
  67. await sync_to_async(SimpleModel.objects.create)(
  68. field=4,
  69. created=datetime(2022, 1, 1, 0, 0, 0),
  70. )
  71. # The query cache is used.
  72. count = await qs.acount()
  73. self.assertEqual(count, 3)
  74. async def test_aget(self):
  75. instance = await SimpleModel.objects.aget(field=1)
  76. self.assertEqual(instance, self.s1)
  77. async def test_acreate(self):
  78. await SimpleModel.objects.acreate(field=4)
  79. self.assertEqual(await SimpleModel.objects.acount(), 4)
  80. async def test_aget_or_create(self):
  81. instance, created = await SimpleModel.objects.aget_or_create(field=4)
  82. self.assertEqual(await SimpleModel.objects.acount(), 4)
  83. self.assertIs(created, True)
  84. async def test_aupdate_or_create(self):
  85. instance, created = await SimpleModel.objects.aupdate_or_create(
  86. id=self.s1.id, defaults={"field": 2}
  87. )
  88. self.assertEqual(instance, self.s1)
  89. self.assertEqual(instance.field, 2)
  90. self.assertIs(created, False)
  91. instance, created = await SimpleModel.objects.aupdate_or_create(field=4)
  92. self.assertEqual(await SimpleModel.objects.acount(), 4)
  93. self.assertIs(created, True)
  94. instance, created = await SimpleModel.objects.aupdate_or_create(
  95. field=5, defaults={"field": 7}, create_defaults={"field": 6}
  96. )
  97. self.assertEqual(await SimpleModel.objects.acount(), 5)
  98. self.assertIs(created, True)
  99. self.assertEqual(instance.field, 6)
  100. @skipUnlessDBFeature("has_bulk_insert")
  101. @async_to_sync
  102. async def test_abulk_create(self):
  103. instances = [SimpleModel(field=i) for i in range(10)]
  104. qs = await SimpleModel.objects.abulk_create(instances)
  105. self.assertEqual(len(qs), 10)
  106. @skipUnlessDBFeature("has_bulk_insert", "supports_update_conflicts")
  107. @skipIfDBFeature("supports_update_conflicts_with_target")
  108. @async_to_sync
  109. async def test_update_conflicts_unique_field_unsupported(self):
  110. msg = (
  111. "This database backend does not support updating conflicts with specifying "
  112. "unique fields that can trigger the upsert."
  113. )
  114. with self.assertRaisesMessage(NotSupportedError, msg):
  115. await SimpleModel.objects.abulk_create(
  116. [SimpleModel(field=1), SimpleModel(field=2)],
  117. update_conflicts=True,
  118. update_fields=["field"],
  119. unique_fields=["created"],
  120. )
  121. async def test_abulk_update(self):
  122. instances = SimpleModel.objects.all()
  123. async for instance in instances:
  124. instance.field = instance.field * 10
  125. await SimpleModel.objects.abulk_update(instances, ["field"])
  126. qs = [(o.pk, o.field) async for o in SimpleModel.objects.all()]
  127. self.assertCountEqual(
  128. qs,
  129. [(self.s1.pk, 10), (self.s2.pk, 20), (self.s3.pk, 30)],
  130. )
  131. async def test_ain_bulk(self):
  132. res = await SimpleModel.objects.ain_bulk()
  133. self.assertEqual(
  134. res,
  135. {self.s1.pk: self.s1, self.s2.pk: self.s2, self.s3.pk: self.s3},
  136. )
  137. res = await SimpleModel.objects.ain_bulk([self.s2.pk])
  138. self.assertEqual(res, {self.s2.pk: self.s2})
  139. res = await SimpleModel.objects.ain_bulk([self.s2.pk], field_name="id")
  140. self.assertEqual(res, {self.s2.pk: self.s2})
  141. async def test_alatest(self):
  142. instance = await SimpleModel.objects.alatest("created")
  143. self.assertEqual(instance, self.s3)
  144. instance = await SimpleModel.objects.alatest("-created")
  145. self.assertEqual(instance, self.s1)
  146. async def test_aearliest(self):
  147. instance = await SimpleModel.objects.aearliest("created")
  148. self.assertEqual(instance, self.s1)
  149. instance = await SimpleModel.objects.aearliest("-created")
  150. self.assertEqual(instance, self.s3)
  151. async def test_afirst(self):
  152. instance = await SimpleModel.objects.afirst()
  153. self.assertEqual(instance, self.s1)
  154. instance = await SimpleModel.objects.filter(field=4).afirst()
  155. self.assertIsNone(instance)
  156. async def test_alast(self):
  157. instance = await SimpleModel.objects.alast()
  158. self.assertEqual(instance, self.s3)
  159. instance = await SimpleModel.objects.filter(field=4).alast()
  160. self.assertIsNone(instance)
  161. async def test_aaggregate(self):
  162. total = await SimpleModel.objects.aaggregate(total=Sum("field"))
  163. self.assertEqual(total, {"total": 6})
  164. async def test_aexists(self):
  165. check = await SimpleModel.objects.filter(field=1).aexists()
  166. self.assertIs(check, True)
  167. check = await SimpleModel.objects.filter(field=4).aexists()
  168. self.assertIs(check, False)
  169. async def test_acontains(self):
  170. check = await SimpleModel.objects.acontains(self.s1)
  171. self.assertIs(check, True)
  172. # Unsaved instances are not allowed, so use an ID known not to exist.
  173. check = await SimpleModel.objects.acontains(
  174. SimpleModel(id=self.s3.id + 1, field=4)
  175. )
  176. self.assertIs(check, False)
  177. async def test_aupdate(self):
  178. await SimpleModel.objects.aupdate(field=99)
  179. qs = [o async for o in SimpleModel.objects.all()]
  180. values = [instance.field for instance in qs]
  181. self.assertEqual(set(values), {99})
  182. async def test_adelete(self):
  183. await SimpleModel.objects.filter(field=2).adelete()
  184. qs = [o async for o in SimpleModel.objects.all()]
  185. self.assertCountEqual(qs, [self.s1, self.s3])
  186. @skipUnlessDBFeature("supports_explaining_query_execution")
  187. @async_to_sync
  188. async def test_aexplain(self):
  189. supported_formats = await sync_to_async(self._get_db_feature)(
  190. connection, "supported_explain_formats"
  191. )
  192. all_formats = (None, *supported_formats)
  193. for format_ in all_formats:
  194. with self.subTest(format=format_):
  195. # TODO: Check the captured query when async versions of
  196. # self.assertNumQueries/CaptureQueriesContext context
  197. # processors are available.
  198. result = await SimpleModel.objects.filter(field=1).aexplain(
  199. format=format_
  200. )
  201. self.assertIsInstance(result, str)
  202. self.assertTrue(result)
  203. if not format_:
  204. continue
  205. if format_.lower() == "xml":
  206. try:
  207. xml.etree.ElementTree.fromstring(result)
  208. except xml.etree.ElementTree.ParseError as e:
  209. self.fail(f"QuerySet.aexplain() result is not valid XML: {e}")
  210. elif format_.lower() == "json":
  211. try:
  212. json.loads(result)
  213. except json.JSONDecodeError as e:
  214. self.fail(f"QuerySet.aexplain() result is not valid JSON: {e}")
  215. async def test_raw(self):
  216. sql = "SELECT id, field FROM async_simplemodel WHERE created=%s"
  217. qs = SimpleModel.objects.raw(sql, [self.s1.created])
  218. self.assertEqual([o async for o in qs], [self.s1])