123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259 |
- import json
- import xml.etree.ElementTree
- from datetime import datetime
- from asgiref.sync import async_to_sync, sync_to_async
- from django.db import NotSupportedError, connection
- from django.db.models import Prefetch, Sum
- from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
- from .models import RelatedModel, SimpleModel
class AsyncQuerySetTest(TestCase):
    """Exercise the asynchronous QuerySet API (``a*`` methods, async iteration)."""

    @classmethod
    def setUpTestData(cls):
        # Three rows whose ``field`` values (1, 2, 3) and ``created``
        # timestamps (one second apart) both increase with insertion order,
        # each with a single RelatedModel pointing back at it.
        cls.s1 = SimpleModel.objects.create(
            field=1,
            created=datetime(2022, 1, 1, 0, 0, 0),
        )
        cls.s2 = SimpleModel.objects.create(
            field=2,
            created=datetime(2022, 1, 1, 0, 0, 1),
        )
        cls.s3 = SimpleModel.objects.create(
            field=3,
            created=datetime(2022, 1, 1, 0, 0, 2),
        )
        cls.r1 = RelatedModel.objects.create(simple=cls.s1)
        cls.r2 = RelatedModel.objects.create(simple=cls.s2)
        cls.r3 = RelatedModel.objects.create(simple=cls.s3)

    @staticmethod
    def _get_db_feature(connection_, feature_name):
        """Return ``connection_.features.<feature_name>``.

        Wrapper to avoid accessing connection attributes until inside a
        coroutine function. Connection access is thread sensitive and cannot
        be passed across sync/async boundaries, so callers run this through
        ``sync_to_async``.
        """
        return getattr(connection_.features, feature_name)

    async def test_async_iteration(self):
        """``async for`` over a queryset yields its rows in query order."""
        results = []
        async for m in SimpleModel.objects.order_by("pk"):
            results.append(m)
        self.assertEqual(results, [self.s1, self.s2, self.s3])

    async def test_aiterator(self):
        """aiterator() yields every row (order not asserted)."""
        qs = SimpleModel.objects.aiterator()
        results = []
        async for m in qs:
            results.append(m)
        self.assertCountEqual(results, [self.s1, self.s2, self.s3])

    async def test_aiterator_prefetch_related(self):
        """aiterator() honours prefetch_related() with a to_attr Prefetch."""
        results = []
        async for s in SimpleModel.objects.prefetch_related(
            Prefetch("relatedmodel_set", to_attr="prefetched_relatedmodel")
        ).aiterator():
            results.append(s.prefetched_relatedmodel)
        self.assertCountEqual(results, [[self.r1], [self.r2], [self.r3]])

    async def test_aiterator_invalid_chunk_size(self):
        """A non-positive chunk_size raises ValueError."""
        msg = "Chunk size must be strictly positive."
        for size in [0, -1]:
            # The iterator is created outside assertRaisesMessage(): the
            # error is expected when it is consumed, not when it is built.
            qs = SimpleModel.objects.aiterator(chunk_size=size)
            with self.subTest(size=size), self.assertRaisesMessage(ValueError, msg):
                async for _ in qs:
                    pass

    async def test_acount(self):
        count = await SimpleModel.objects.acount()
        self.assertEqual(count, 3)

    async def test_acount_cached_result(self):
        """acount() on an evaluated queryset answers from its result cache."""
        qs = SimpleModel.objects.all()
        # Evaluate the queryset to populate the query cache.
        [x async for x in qs]
        count = await qs.acount()
        self.assertEqual(count, 3)
        await sync_to_async(SimpleModel.objects.create)(
            field=4,
            created=datetime(2022, 1, 1, 0, 0, 0),
        )
        # The query cache is used, so the newly created row is not counted.
        count = await qs.acount()
        self.assertEqual(count, 3)

    async def test_aget(self):
        instance = await SimpleModel.objects.aget(field=1)
        self.assertEqual(instance, self.s1)

    async def test_acreate(self):
        await SimpleModel.objects.acreate(field=4)
        self.assertEqual(await SimpleModel.objects.acount(), 4)

    async def test_aget_or_create(self):
        """aget_or_create() creates and returns a missing row."""
        instance, created = await SimpleModel.objects.aget_or_create(field=4)
        self.assertEqual(await SimpleModel.objects.acount(), 4)
        self.assertIs(created, True)
        self.assertEqual(instance.field, 4)

    async def test_aupdate_or_create(self):
        """aupdate_or_create() updates existing rows and creates missing ones.

        When a row is created, ``create_defaults`` takes precedence over
        ``defaults``.
        """
        # Existing row: updated in place with ``defaults``.
        instance, created = await SimpleModel.objects.aupdate_or_create(
            id=self.s1.id, defaults={"field": 2}
        )
        self.assertEqual(instance, self.s1)
        self.assertEqual(instance.field, 2)
        self.assertIs(created, False)
        # Missing row: created from the lookup kwargs.
        instance, created = await SimpleModel.objects.aupdate_or_create(field=4)
        self.assertEqual(await SimpleModel.objects.acount(), 4)
        self.assertIs(created, True)
        # Missing row with both dicts: create_defaults wins on creation.
        instance, created = await SimpleModel.objects.aupdate_or_create(
            field=5, defaults={"field": 7}, create_defaults={"field": 6}
        )
        self.assertEqual(await SimpleModel.objects.acount(), 5)
        self.assertIs(created, True)
        self.assertEqual(instance.field, 6)

    @skipUnlessDBFeature("has_bulk_insert")
    @async_to_sync
    async def test_abulk_create(self):
        instances = [SimpleModel(field=i) for i in range(10)]
        qs = await SimpleModel.objects.abulk_create(instances)
        self.assertEqual(len(qs), 10)

    @skipUnlessDBFeature("has_bulk_insert", "supports_update_conflicts")
    @skipIfDBFeature("supports_update_conflicts_with_target")
    @async_to_sync
    async def test_update_conflicts_unique_field_unsupported(self):
        """Passing unique_fields to an upsert fails on backends lacking targets."""
        msg = (
            "This database backend does not support updating conflicts with specifying "
            "unique fields that can trigger the upsert."
        )
        with self.assertRaisesMessage(NotSupportedError, msg):
            await SimpleModel.objects.abulk_create(
                [SimpleModel(field=1), SimpleModel(field=2)],
                update_conflicts=True,
                update_fields=["field"],
                unique_fields=["created"],
            )

    async def test_abulk_update(self):
        """abulk_update() persists in-memory changes to the given instances."""
        instances = SimpleModel.objects.all()
        # Async iteration fills the queryset's result cache; the same
        # (modified) instances are then handed to abulk_update().
        async for instance in instances:
            instance.field = instance.field * 10
        await SimpleModel.objects.abulk_update(instances, ["field"])
        qs = [(o.pk, o.field) async for o in SimpleModel.objects.all()]
        self.assertCountEqual(
            qs,
            [(self.s1.pk, 10), (self.s2.pk, 20), (self.s3.pk, 30)],
        )

    async def test_ain_bulk(self):
        """ain_bulk() maps primary keys (or field_name values) to instances."""
        res = await SimpleModel.objects.ain_bulk()
        self.assertEqual(
            res,
            {self.s1.pk: self.s1, self.s2.pk: self.s2, self.s3.pk: self.s3},
        )
        res = await SimpleModel.objects.ain_bulk([self.s2.pk])
        self.assertEqual(res, {self.s2.pk: self.s2})
        res = await SimpleModel.objects.ain_bulk([self.s2.pk], field_name="id")
        self.assertEqual(res, {self.s2.pk: self.s2})

    async def test_alatest(self):
        instance = await SimpleModel.objects.alatest("created")
        self.assertEqual(instance, self.s3)
        instance = await SimpleModel.objects.alatest("-created")
        self.assertEqual(instance, self.s1)

    async def test_aearliest(self):
        instance = await SimpleModel.objects.aearliest("created")
        self.assertEqual(instance, self.s1)
        instance = await SimpleModel.objects.aearliest("-created")
        self.assertEqual(instance, self.s3)

    async def test_afirst(self):
        """afirst() returns the first row, or None for an empty queryset."""
        instance = await SimpleModel.objects.afirst()
        self.assertEqual(instance, self.s1)
        instance = await SimpleModel.objects.filter(field=4).afirst()
        self.assertIsNone(instance)

    async def test_alast(self):
        """alast() returns the last row, or None for an empty queryset."""
        instance = await SimpleModel.objects.alast()
        self.assertEqual(instance, self.s3)
        instance = await SimpleModel.objects.filter(field=4).alast()
        self.assertIsNone(instance)

    async def test_aaggregate(self):
        total = await SimpleModel.objects.aaggregate(total=Sum("field"))
        self.assertEqual(total, {"total": 6})

    async def test_aexists(self):
        check = await SimpleModel.objects.filter(field=1).aexists()
        self.assertIs(check, True)
        check = await SimpleModel.objects.filter(field=4).aexists()
        self.assertIs(check, False)

    async def test_acontains(self):
        check = await SimpleModel.objects.acontains(self.s1)
        self.assertIs(check, True)
        # Unsaved instances are not allowed, so use an ID known not to exist.
        check = await SimpleModel.objects.acontains(
            SimpleModel(id=self.s3.id + 1, field=4)
        )
        self.assertIs(check, False)

    async def test_aupdate(self):
        await SimpleModel.objects.aupdate(field=99)
        qs = [o async for o in SimpleModel.objects.all()]
        values = [instance.field for instance in qs]
        self.assertEqual(set(values), {99})

    async def test_adelete(self):
        await SimpleModel.objects.filter(field=2).adelete()
        qs = [o async for o in SimpleModel.objects.all()]
        self.assertCountEqual(qs, [self.s1, self.s3])

    @skipUnlessDBFeature("supports_explaining_query_execution")
    @async_to_sync
    async def test_aexplain(self):
        """aexplain() returns a non-empty string for every supported format.

        XML and JSON output must additionally parse as valid documents.
        """
        supported_formats = await sync_to_async(self._get_db_feature)(
            connection, "supported_explain_formats"
        )
        # None means no explicit format is requested.
        all_formats = (None, *supported_formats)
        for format_ in all_formats:
            with self.subTest(format=format_):
                # TODO: Check the captured query when async versions of
                # self.assertNumQueries/CaptureQueriesContext context
                # managers are available.
                result = await SimpleModel.objects.filter(field=1).aexplain(
                    format=format_
                )
                self.assertIsInstance(result, str)
                self.assertTrue(result)
                if not format_:
                    continue
                normalized = format_.lower()
                if normalized == "xml":
                    try:
                        xml.etree.ElementTree.fromstring(result)
                    except xml.etree.ElementTree.ParseError as e:
                        self.fail(f"QuerySet.aexplain() result is not valid XML: {e}")
                elif normalized == "json":
                    try:
                        json.loads(result)
                    except json.JSONDecodeError as e:
                        self.fail(f"QuerySet.aexplain() result is not valid JSON: {e}")

    async def test_raw(self):
        """Raw querysets support async iteration."""
        # Take the table name from the model instead of hard-coding it so the
        # query survives a change of app label or db_table. The value is model
        # metadata, not user input, so interpolating it is safe.
        table = SimpleModel._meta.db_table
        sql = f"SELECT id, field FROM {table} WHERE created=%s"
        qs = SimpleModel.objects.raw(sql, [self.s1.created])
        self.assertEqual([o async for o in qs], [self.s1])
|