Browse Source

Fixed #34135 -- Added async-compatible interface to related managers.

Jon Janzen 2 years ago
parent
commit
321ecb40f4

+ 20 - 0
django/contrib/contenttypes/fields.py

@@ -689,6 +689,11 @@ def create_generic_related_manager(superclass, rel):
 
         add.alters_data = True
 
+        async def aadd(self, *objs, bulk=True):
+            return await sync_to_async(self.add)(*objs, bulk=bulk)
+
+        aadd.alters_data = True
+
         def remove(self, *objs, bulk=True):
             if not objs:
                 return
@@ -696,11 +701,21 @@ def create_generic_related_manager(superclass, rel):
 
         remove.alters_data = True
 
+        async def aremove(self, *objs, bulk=True):
+            return await sync_to_async(self.remove)(*objs, bulk=bulk)
+
+        aremove.alters_data = True
+
         def clear(self, *, bulk=True):
             self._clear(self, bulk)
 
         clear.alters_data = True
 
+        async def aclear(self, *, bulk=True):
+            return await sync_to_async(self.clear)(bulk=bulk)
+
+        aclear.alters_data = True
+
         def _clear(self, queryset, bulk):
             self._remove_prefetched_objects()
             db = router.db_for_write(self.model, instance=self.instance)
@@ -740,6 +755,11 @@ def create_generic_related_manager(superclass, rel):
 
         set.alters_data = True
 
+        async def aset(self, objs, *, bulk=True, clear=False):
+            return await sync_to_async(self.set)(objs, bulk=bulk, clear=clear)
+
+        aset.alters_data = True
+
         def create(self, **kwargs):
             self._remove_prefetched_objects()
             kwargs[self.content_type_field_name] = self.content_type

+ 44 - 0
django/db/models/fields/related_descriptors.py

@@ -787,6 +787,11 @@ def create_reverse_many_to_one_manager(superclass, rel):
 
         add.alters_data = True
 
+        async def aadd(self, *objs, bulk=True):
+            return await sync_to_async(self.add)(*objs, bulk=bulk)
+
+        aadd.alters_data = True
+
         def create(self, **kwargs):
             self._check_fk_val()
             kwargs[self.field.name] = self.instance
@@ -856,12 +861,22 @@ def create_reverse_many_to_one_manager(superclass, rel):
 
             remove.alters_data = True
 
+            async def aremove(self, *objs, bulk=True):
+                return await sync_to_async(self.remove)(*objs, bulk=bulk)
+
+            aremove.alters_data = True
+
             def clear(self, *, bulk=True):
                 self._check_fk_val()
                 self._clear(self, bulk)
 
             clear.alters_data = True
 
+            async def aclear(self, *, bulk=True):
+                return await sync_to_async(self.clear)(bulk=bulk)
+
+            aclear.alters_data = True
+
             def _clear(self, queryset, bulk):
                 self._remove_prefetched_objects()
                 db = router.db_for_write(self.model, instance=self.instance)
@@ -905,6 +920,11 @@ def create_reverse_many_to_one_manager(superclass, rel):
 
         set.alters_data = True
 
+        async def aset(self, objs, *, bulk=True, clear=False):
+            return await sync_to_async(self.set)(objs=objs, bulk=bulk, clear=clear)
+
+        aset.alters_data = True
+
     return RelatedManager
 
 
@@ -1132,12 +1152,24 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
 
         add.alters_data = True
 
+        async def aadd(self, *objs, through_defaults=None):
+            return await sync_to_async(self.add)(
+                *objs, through_defaults=through_defaults
+            )
+
+        aadd.alters_data = True
+
         def remove(self, *objs):
             self._remove_prefetched_objects()
             self._remove_items(self.source_field_name, self.target_field_name, *objs)
 
         remove.alters_data = True
 
+        async def aremove(self, *objs):
+            return await sync_to_async(self.remove)(*objs)
+
+        aremove.alters_data = True
+
         def clear(self):
             db = router.db_for_write(self.through, instance=self.instance)
             with transaction.atomic(using=db, savepoint=False):
@@ -1166,6 +1198,11 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
 
         clear.alters_data = True
 
+        async def aclear(self):
+            return await sync_to_async(self.clear)()
+
+        aclear.alters_data = True
+
         def set(self, objs, *, clear=False, through_defaults=None):
             # Force evaluation of `objs` in case it's a queryset whose value
             # could be affected by `manager.clear()`. Refs #19816.
@@ -1200,6 +1237,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
 
         set.alters_data = True
 
+        async def aset(self, objs, *, clear=False, through_defaults=None):
+            return await sync_to_async(self.set)(
+                objs=objs, clear=clear, through_defaults=through_defaults
+            )
+
+        aset.alters_data = True
+
         def create(self, *, through_defaults=None, **kwargs):
             db = router.db_for_write(self.instance.__class__, instance=self.instance)
             new_obj = super(ManyRelatedManager, self.db_manager(db)).create(**kwargs)

+ 36 - 6
docs/ref/models/relations.txt

@@ -37,6 +37,9 @@ Related objects reference
       ``topping.pizza_set`` and on ``pizza.toppings``.
 
     .. method:: add(*objs, bulk=True, through_defaults=None)
+    .. method:: aadd(*objs, bulk=True, through_defaults=None)
+
+        *Asynchronous version*: ``aadd``
 
         Adds the specified model objects to the related object set.
 
@@ -75,6 +78,10 @@ Related objects reference
         dictionary and they will be evaluated once before creating any
         intermediate instance(s).
 
+        .. versionchanged:: 4.2
+
+            ``aadd()`` method was added.
+
     .. method:: create(through_defaults=None, **kwargs)
     .. method:: acreate(through_defaults=None, **kwargs)
 
@@ -118,6 +125,9 @@ Related objects reference
             ``acreate()`` method was added.
 
     .. method:: remove(*objs, bulk=True)
+    .. method:: aremove(*objs, bulk=True)
+
+        *Asynchronous version*: ``aremove``
 
         Removes the specified model objects from the related object set::
 
@@ -157,7 +167,14 @@ Related objects reference
         For many-to-many relationships, the ``bulk`` keyword argument doesn't
         exist.
 
+        .. versionchanged:: 4.2
+
+            ``aremove()`` method was added.
+
     .. method:: clear(bulk=True)
+    .. method:: aclear(bulk=True)
+
+        *Asynchronous version*: ``aclear``
 
         Removes all objects from the related object set::
 
@@ -174,7 +191,14 @@ Related objects reference
         For many-to-many relationships, the ``bulk`` keyword argument doesn't
         exist.
 
+        .. versionchanged:: 4.2
+
+            ``aclear()`` method was added.
+
     .. method:: set(objs, bulk=True, clear=False, through_defaults=None)
+    .. method:: aset(objs, bulk=True, clear=False, through_defaults=None)
+
+        *Asynchronous version*: ``aset``
 
         Replace the set of related objects::
 
@@ -207,13 +231,19 @@ Related objects reference
         dictionary and they will be evaluated once before creating any
         intermediate instance(s).
 
+        .. versionchanged:: 4.2
+
+            ``aset()`` method was added.
+
     .. note::
 
-       Note that ``add()``, ``create()``, ``remove()``, ``clear()``, and
-       ``set()`` all apply database changes immediately for all types of
-       related fields. In other words, there is no need to call ``save()``
-       on either end of the relationship.
+       Note that ``add()``, ``aadd()``, ``create()``, ``acreate()``,
+       ``remove()``, ``aremove()``, ``clear()``, ``aclear()``, ``set()``, and
+       ``aset()`` all apply database changes immediately for all types of
+       related fields. In other words, there is no need to call
+       ``save()``/``asave()`` on either end of the relationship.
 
        If you use :meth:`~django.db.models.query.QuerySet.prefetch_related`,
-       the ``add()``, ``remove()``, ``clear()``, and ``set()`` methods clear
-       the prefetched cache.
+       the ``add()``, ``aadd()``, ``remove()``, ``aremove()``, ``clear()``,
+       ``aclear()``, ``set()``, and ``aset()`` methods clear the prefetched
+       cache.

+ 5 - 0
docs/releases/4.2.txt

@@ -243,6 +243,11 @@ Models
   database, using an ``a`` prefix: :meth:`~.Model.adelete`,
   :meth:`~.Model.arefresh_from_db`, and :meth:`~.Model.asave`.
 
+* Related managers now provide asynchronous versions of methods that change a
+  set of related objects, using an ``a`` prefix: :meth:`~.RelatedManager.aadd`,
+  :meth:`~.RelatedManager.aclear`, :meth:`~.RelatedManager.aremove`, and
+  :meth:`~.RelatedManager.aset`.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 5 - 1
docs/topics/async.txt

@@ -97,13 +97,17 @@ Django also supports some asynchronous model methods that use the database::
         book = Book(...)
         await book.asave(using="secondary")
 
+    async def make_book_with_tags(tags, ...):
+        book = await Book.objects.acreate(...)
+        await book.tags.aset(tags)
+
 Transactions do not yet work in async mode. If you have a piece of code that
 needs transactions behavior, we recommend you write that piece as a single
 synchronous function and call it using :func:`sync_to_async`.
 
 .. versionchanged:: 4.2
 
-    Asynchronous model interface was added.
+    Asynchronous model and related manager interfaces were added.
 
 Performance
 -----------

+ 51 - 1
tests/async/test_async_related_managers.py

@@ -1,6 +1,6 @@
 from django.test import TestCase
 
-from .models import ManyToManyModel, SimpleModel
+from .models import ManyToManyModel, RelatedModel, SimpleModel
 
 
 class AsyncRelatedManagersOperationTest(TestCase):
@@ -8,6 +8,8 @@ class AsyncRelatedManagersOperationTest(TestCase):
     def setUpTestData(cls):
         cls.mtm1 = ManyToManyModel.objects.create()
         cls.s1 = SimpleModel.objects.create(field=0)
+        cls.mtm2 = ManyToManyModel.objects.create()
+        cls.mtm2.simples.set([cls.s1])
 
     async def test_acreate(self):
         await self.mtm1.simples.acreate(field=2)
@@ -54,3 +56,51 @@ class AsyncRelatedManagersOperationTest(TestCase):
         self.assertIs(created, True)
         self.assertEqual(await self.s1.relatedmodel_set.acount(), 1)
         self.assertEqual(new_relatedmodel.simple, self.s1)
+
+    async def test_aadd(self):
+        await self.mtm1.simples.aadd(self.s1)
+        self.assertEqual(await self.mtm1.simples.aget(), self.s1)
+
+    async def test_aadd_reverse(self):
+        r1 = await RelatedModel.objects.acreate()
+        await self.s1.relatedmodel_set.aadd(r1, bulk=False)
+        self.assertEqual(await self.s1.relatedmodel_set.aget(), r1)
+
+    async def test_aremove(self):
+        self.assertEqual(await self.mtm2.simples.acount(), 1)
+        await self.mtm2.simples.aremove(self.s1)
+        self.assertEqual(await self.mtm2.simples.acount(), 0)
+
+    async def test_aremove_reverse(self):
+        r1 = await RelatedModel.objects.acreate(simple=self.s1)
+        self.assertEqual(await self.s1.relatedmodel_set.acount(), 1)
+        await self.s1.relatedmodel_set.aremove(r1)
+        self.assertEqual(await self.s1.relatedmodel_set.acount(), 0)
+
+    async def test_aset(self):
+        await self.mtm1.simples.aset([self.s1])
+        self.assertEqual(await self.mtm1.simples.aget(), self.s1)
+        await self.mtm1.simples.aset([])
+        self.assertEqual(await self.mtm1.simples.acount(), 0)
+        await self.mtm1.simples.aset([self.s1], clear=True)
+        self.assertEqual(await self.mtm1.simples.aget(), self.s1)
+
+    async def test_aset_reverse(self):
+        r1 = await RelatedModel.objects.acreate()
+        await self.s1.relatedmodel_set.aset([r1])
+        self.assertEqual(await self.s1.relatedmodel_set.aget(), r1)
+        await self.s1.relatedmodel_set.aset([])
+        self.assertEqual(await self.s1.relatedmodel_set.acount(), 0)
+        await self.s1.relatedmodel_set.aset([r1], bulk=False, clear=True)
+        self.assertEqual(await self.s1.relatedmodel_set.aget(), r1)
+
+    async def test_aclear(self):
+        self.assertEqual(await self.mtm2.simples.acount(), 1)
+        await self.mtm2.simples.aclear()
+        self.assertEqual(await self.mtm2.simples.acount(), 0)
+
+    async def test_aclear_reverse(self):
+        await RelatedModel.objects.acreate(simple=self.s1)
+        self.assertEqual(await self.s1.relatedmodel_set.acount(), 1)
+        await self.s1.relatedmodel_set.aclear(bulk=False)
+        self.assertEqual(await self.s1.relatedmodel_set.acount(), 0)

+ 27 - 0
tests/generic_relations/tests.py

@@ -324,6 +324,13 @@ class GenericRelationsTests(TestCase):
         with self.assertRaisesMessage(TypeError, msg):
             self.bacon.tags.add(self.lion)
 
+    async def test_aadd(self):
+        bacon = await Vegetable.objects.acreate(name="Bacon", is_yucky=False)
+        t1 = await TaggedItem.objects.acreate(content_object=self.quartz, tag="shiny")
+        t2 = await TaggedItem.objects.acreate(content_object=self.quartz, tag="fatty")
+        await bacon.tags.aadd(t1, t2, bulk=False)
+        self.assertEqual(await bacon.tags.acount(), 2)
+
     def test_set(self):
         bacon = Vegetable.objects.create(name="Bacon", is_yucky=False)
         fatty = bacon.tags.create(tag="fatty")
@@ -347,6 +354,16 @@ class GenericRelationsTests(TestCase):
         bacon.tags.set([], clear=True)
         self.assertSequenceEqual(bacon.tags.all(), [])
 
+    async def test_aset(self):
+        bacon = await Vegetable.objects.acreate(name="Bacon", is_yucky=False)
+        fatty = await bacon.tags.acreate(tag="fatty")
+        await bacon.tags.aset([fatty])
+        self.assertEqual(await bacon.tags.acount(), 1)
+        await bacon.tags.aset([])
+        self.assertEqual(await bacon.tags.acount(), 0)
+        await bacon.tags.aset([fatty], bulk=False, clear=True)
+        self.assertEqual(await bacon.tags.acount(), 1)
+
     def test_assign(self):
         bacon = Vegetable.objects.create(name="Bacon", is_yucky=False)
         fatty = bacon.tags.create(tag="fatty")
@@ -388,6 +405,10 @@ class GenericRelationsTests(TestCase):
             [self.hairy, self.yellow],
         )
 
+    async def test_aclear(self):
+        await self.bacon.tags.aclear()
+        self.assertEqual(await self.bacon.tags.acount(), 0)
+
     def test_remove(self):
         self.assertSequenceEqual(
             TaggedItem.objects.order_by("tag"),
@@ -400,6 +421,12 @@ class GenericRelationsTests(TestCase):
             [self.hairy, self.salty, self.yellow],
         )
 
+    async def test_aremove(self):
+        await self.bacon.tags.aremove(self.fatty)
+        self.assertEqual(await self.bacon.tags.acount(), 1)
+        await self.bacon.tags.aremove(self.salty)
+        self.assertEqual(await self.bacon.tags.acount(), 0)
+
     def test_generic_relation_related_name_default(self):
         # GenericRelation isn't usable from the reverse side by default.
         msg = (