Browse Source

Fixed #34331 -- Added QuerySet.aiterator() support for prefetch_related().

John Parton 1 year ago
parent
commit
fff14736f1

+ 31 - 7
django/db/models/query.py

@@ -544,19 +544,36 @@ class QuerySet(AltersData):
         An asynchronous iterator over the results from applying this QuerySet
         An asynchronous iterator over the results from applying this QuerySet
         to the database.
         to the database.
         """
         """
-        if self._prefetch_related_lookups:
-            raise NotSupportedError(
-                "Using QuerySet.aiterator() after prefetch_related() is not supported."
-            )
         if chunk_size <= 0:
         if chunk_size <= 0:
             raise ValueError("Chunk size must be strictly positive.")
             raise ValueError("Chunk size must be strictly positive.")
         use_chunked_fetch = not connections[self.db].settings_dict.get(
         use_chunked_fetch = not connections[self.db].settings_dict.get(
             "DISABLE_SERVER_SIDE_CURSORS"
             "DISABLE_SERVER_SIDE_CURSORS"
         )
         )
-        async for item in self._iterable_class(
+        iterable = self._iterable_class(
             self, chunked_fetch=use_chunked_fetch, chunk_size=chunk_size
             self, chunked_fetch=use_chunked_fetch, chunk_size=chunk_size
-        ):
+        )
-            yield item
+        if self._prefetch_related_lookups:
+            results = []
+
+            async for item in iterable:
+                results.append(item)
+                if len(results) >= chunk_size:
+                    await aprefetch_related_objects(
+                        results, *self._prefetch_related_lookups
+                    )
+                    for result in results:
+                        yield result
+                    results.clear()
+
+            if results:
+                await aprefetch_related_objects(
+                    results, *self._prefetch_related_lookups
+                )
+                for result in results:
+                    yield result
+        else:
+            async for item in iterable:
+                yield item
 
 
     def aggregate(self, *args, **kwargs):
     def aggregate(self, *args, **kwargs):
         """
         """
@@ -2387,6 +2404,13 @@ def prefetch_related_objects(model_instances, *related_lookups):
                 obj_list = new_obj_list
                 obj_list = new_obj_list
 
 
 
 
+async def aprefetch_related_objects(model_instances, *related_lookups):
+    """See prefetch_related_objects()."""
+    return await sync_to_async(prefetch_related_objects)(
+        model_instances, *related_lookups
+    )
+
+
 def get_prefetcher(instance, through_attr, to_attr):
 def get_prefetcher(instance, through_attr, to_attr):
     """
     """
     For the attribute 'through_attr' on the given instance, find
     For the attribute 'through_attr' on the given instance, find

+ 10 - 3
docs/ref/models/querysets.txt

@@ -2579,10 +2579,10 @@ evaluated will force it to evaluate again, repeating the query.
 long as ``chunk_size`` is given. Larger values will necessitate fewer queries
 long as ``chunk_size`` is given. Larger values will necessitate fewer queries
 to accomplish the prefetching at the cost of greater memory usage.
 to accomplish the prefetching at the cost of greater memory usage.
 
 
-.. note::
+.. versionchanged:: 5.0
 
 
-    ``aiterator()`` is *not* compatible with previous calls to
+    Support for ``aiterator()`` with previous calls to ``prefetch_related()``
-    ``prefetch_related()``.
+    was added.
 
 
 On some databases (e.g. Oracle, `SQLite
 On some databases (e.g. Oracle, `SQLite
 <https://www.sqlite.org/limits.html#max_variable_number>`_), the maximum number
 <https://www.sqlite.org/limits.html#max_variable_number>`_), the maximum number
@@ -4073,6 +4073,9 @@ attribute:
 ------------------------------
 ------------------------------
 
 
 .. function:: prefetch_related_objects(model_instances, *related_lookups)
 .. function:: prefetch_related_objects(model_instances, *related_lookups)
+.. function:: aprefetch_related_objects(model_instances, *related_lookups)
+
+*Asynchronous version*: ``aprefetch_related_objects()``
 
 
 Prefetches the given lookups on an iterable of model instances. This is useful
 Prefetches the given lookups on an iterable of model instances. This is useful
 in code that receives a list of model instances as opposed to a ``QuerySet``;
 in code that receives a list of model instances as opposed to a ``QuerySet``;
@@ -4091,6 +4094,10 @@ When using multiple databases with ``prefetch_related_objects``, the prefetch
 query will use the database associated with the model instance. This can be
 query will use the database associated with the model instance. This can be
 overridden by using a custom queryset in a related lookup.
 overridden by using a custom queryset in a related lookup.
 
 
+.. versionchanged:: 5.0
+
+    ``aprefetch_related_objects()`` function was added.
+
 ``FilteredRelation()`` objects
 ``FilteredRelation()`` objects
 ------------------------------
 ------------------------------
 
 

+ 6 - 0
docs/releases/5.0.txt

@@ -368,6 +368,12 @@ Models
   :func:`~django.shortcuts.aget_list_or_404` asynchronous shortcuts allow
   :func:`~django.shortcuts.aget_list_or_404` asynchronous shortcuts allow
   asynchronous getting objects.
   asynchronous getting objects.
 
 
+* The new :func:`~django.db.models.aprefetch_related_objects` function allows
+  asynchronous prefetching of model instances.
+
+* :meth:`.QuerySet.aiterator` now supports previous calls to
+  ``prefetch_related()``.
+
 Pagination
 Pagination
 ~~~~~~~~~~
 ~~~~~~~~~~
 
 

+ 11 - 7
tests/async/test_async_queryset.py

@@ -5,10 +5,10 @@ from datetime import datetime
 from asgiref.sync import async_to_sync, sync_to_async
 from asgiref.sync import async_to_sync, sync_to_async
 
 
 from django.db import NotSupportedError, connection
 from django.db import NotSupportedError, connection
-from django.db.models import Sum
+from django.db.models import Prefetch, Sum
 from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
 from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
 
 
-from .models import SimpleModel
+from .models import RelatedModel, SimpleModel
 
 
 
 
 class AsyncQuerySetTest(TestCase):
 class AsyncQuerySetTest(TestCase):
@@ -26,6 +26,9 @@ class AsyncQuerySetTest(TestCase):
             field=3,
             field=3,
             created=datetime(2022, 1, 1, 0, 0, 2),
             created=datetime(2022, 1, 1, 0, 0, 2),
         )
         )
+        cls.r1 = RelatedModel.objects.create(simple=cls.s1)
+        cls.r2 = RelatedModel.objects.create(simple=cls.s2)
+        cls.r3 = RelatedModel.objects.create(simple=cls.s3)
 
 
     @staticmethod
     @staticmethod
     def _get_db_feature(connection_, feature_name):
     def _get_db_feature(connection_, feature_name):
@@ -48,11 +51,12 @@ class AsyncQuerySetTest(TestCase):
         self.assertCountEqual(results, [self.s1, self.s2, self.s3])
         self.assertCountEqual(results, [self.s1, self.s2, self.s3])
 
 
     async def test_aiterator_prefetch_related(self):
     async def test_aiterator_prefetch_related(self):
-        qs = SimpleModel.objects.prefetch_related("relatedmodels").aiterator()
+        results = []
-        msg = "Using QuerySet.aiterator() after prefetch_related() is not supported."
+        async for s in SimpleModel.objects.prefetch_related(
-        with self.assertRaisesMessage(NotSupportedError, msg):
+            Prefetch("relatedmodel_set", to_attr="prefetched_relatedmodel")
-            async for m in qs:
+        ).aiterator():
-                pass
+            results.append(s.prefetched_relatedmodel)
+        self.assertCountEqual(results, [[self.r1], [self.r2], [self.r3]])
 
 
     async def test_aiterator_invalid_chunk_size(self):
     async def test_aiterator_invalid_chunk_size(self):
         msg = "Chunk size must be strictly positive."
         msg = "Chunk size must be strictly positive."