2
0
Эх сурвалжийг харах

Fixed #29725 -- Removed unnecessary join in QuerySet.count() and exists() on a many to many relation.

Co-Authored-By: Shiwei Chen <april.chen.0615@gmail.com>
ontowhee 1 жил өмнө
parent
commit
66e47ac69a

+ 49 - 4
django/db/models/fields/related_descriptors.py

@@ -75,7 +75,7 @@ from django.db import (
     router,
     transaction,
 )
-from django.db.models import Q, Window, signals
+from django.db.models import Manager, Q, Window, signals
 from django.db.models.functions import RowNumber
 from django.db.models.lookups import GreaterThan, LessThanOrEqual
 from django.db.models.query import QuerySet
@@ -1121,6 +1121,12 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
             queryset._defer_next_filter = True
             return queryset._next_is_sticky().filter(**self.core_filters)
 
+        def get_prefetch_cache(self):
+            try:
+                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
+            except (AttributeError, KeyError):
+                return None
+
         def _remove_prefetched_objects(self):
             try:
                 self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name)
@@ -1128,9 +1134,9 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
                 pass  # nothing to clear from cache
 
         def get_queryset(self):
-            try:
-                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
-            except (AttributeError, KeyError):
+            if (cache := self.get_prefetch_cache()) is not None:
+                return cache
+            else:
                 queryset = super().get_queryset()
                 return self._apply_rel_filters(queryset)
 
@@ -1195,6 +1201,45 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
                 False,
             )
 
+        @property
+        def constrained_target(self):
+            # If the through relation's target field's foreign integrity is
+            # enforced, the query can be performed solely against the through
+            # table as the INNER JOIN'ing against target table is unnecessary.
+            if not self.target_field.db_constraint:
+                return None
+            db = router.db_for_read(self.through, instance=self.instance)
+            if not connections[db].features.supports_foreign_keys:
+                return None
+            hints = {"instance": self.instance}
+            manager = self.through._base_manager.db_manager(db, hints=hints)
+            filters = {self.source_field_name: self.instance.pk}
+            # Nullable target rows must be excluded as well as they would have
+            # been filtered out from an INNER JOIN.
+            if self.target_field.null:
+                filters["%s__isnull" % self.target_field_name] = False
+            return manager.filter(**filters)
+
+        def exists(self):
+            if (
+                superclass is Manager
+                and self.get_prefetch_cache() is None
+                and (constrained_target := self.constrained_target) is not None
+            ):
+                return constrained_target.exists()
+            else:
+                return super().exists()
+
+        def count(self):
+            if (
+                superclass is Manager
+                and self.get_prefetch_cache() is None
+                and (constrained_target := self.constrained_target) is not None
+            ):
+                return constrained_target.count()
+            else:
+                return super().count()
+
         def add(self, *objs, through_defaults=None):
             self._remove_prefetched_objects()
             db = router.db_for_write(self.through, instance=self.instance)

+ 12 - 0
tests/many_to_many/models.py

@@ -78,3 +78,15 @@ class InheritedArticleA(AbstractArticle):
 
 class InheritedArticleB(AbstractArticle):
     pass
+
+
+class NullableTargetArticle(models.Model):
+    headline = models.CharField(max_length=100)
+    publications = models.ManyToManyField(
+        Publication, through="NullablePublicationThrough"
+    )
+
+
+class NullablePublicationThrough(models.Model):
+    article = models.ForeignKey(NullableTargetArticle, models.CASCADE)
+    publication = models.ForeignKey(Publication, models.CASCADE, null=True)

+ 90 - 6
tests/many_to_many/tests.py

@@ -1,10 +1,18 @@
 from unittest import mock
 
-from django.db import transaction
+from django.db import connection, transaction
 from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
 from django.utils.deprecation import RemovedInDjango60Warning
 
-from .models import Article, InheritedArticleA, InheritedArticleB, Publication, User
+from .models import (
+    Article,
+    InheritedArticleA,
+    InheritedArticleB,
+    NullablePublicationThrough,
+    NullableTargetArticle,
+    Publication,
+    User,
+)
 
 
 class ManyToManyTests(TestCase):
@@ -558,10 +566,16 @@ class ManyToManyTests(TestCase):
     def test_custom_default_manager_exists_count(self):
         a5 = Article.objects.create(headline="deleted")
         a5.publications.add(self.p2)
-        self.assertEqual(self.p2.article_set.count(), self.p2.article_set.all().count())
-        self.assertEqual(
-            self.p3.article_set.exists(), self.p3.article_set.all().exists()
-        )
+        with self.assertNumQueries(2) as ctx:
+            self.assertEqual(
+                self.p2.article_set.count(), self.p2.article_set.all().count()
+            )
+        self.assertIn("JOIN", ctx.captured_queries[0]["sql"])
+        with self.assertNumQueries(2) as ctx:
+            self.assertEqual(
+                self.p3.article_set.exists(), self.p3.article_set.all().exists()
+            )
+        self.assertIn("JOIN", ctx.captured_queries[0]["sql"])
 
     def test_get_prefetch_queryset_warning(self):
         articles = Article.objects.all()
@@ -582,3 +596,73 @@ class ManyToManyTests(TestCase):
                 instances=articles,
                 querysets=[Publication.objects.all(), Publication.objects.all()],
             )
+
+
+class ManyToManyQueryTests(TestCase):
+    """
+    SQL is optimized to reference the through table without joining against the
+    related table when using count() and exists() functions on a queryset for
+    many to many relations. The optimization applies to the case where there
+    are no filters.
+    """
+
+    @classmethod
+    def setUpTestData(cls):
+        cls.article = Article.objects.create(
+            headline="Django lets you build Web apps easily"
+        )
+        cls.nullable_target_article = NullableTargetArticle.objects.create(
+            headline="The python is good"
+        )
+        NullablePublicationThrough.objects.create(
+            article=cls.nullable_target_article, publication=None
+        )
+
+    @skipUnlessDBFeature("supports_foreign_keys")
+    def test_count_join_optimization(self):
+        with self.assertNumQueries(1) as ctx:
+            self.article.publications.count()
+        self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"])
+
+        with self.assertNumQueries(1) as ctx:
+            self.article.publications.count()
+        self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"])
+        self.assertEqual(self.nullable_target_article.publications.count(), 0)
+
+    def test_count_join_optimization_disabled(self):
+        with (
+            mock.patch.object(connection.features, "supports_foreign_keys", False),
+            self.assertNumQueries(1) as ctx,
+        ):
+            self.article.publications.count()
+
+        self.assertIn("JOIN", ctx.captured_queries[0]["sql"])
+
+    @skipUnlessDBFeature("supports_foreign_keys")
+    def test_exists_join_optimization(self):
+        with self.assertNumQueries(1) as ctx:
+            self.article.publications.exists()
+        self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"])
+
+        self.article.publications.prefetch_related()
+        with self.assertNumQueries(1) as ctx:
+            self.article.publications.exists()
+        self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"])
+        self.assertIs(self.nullable_target_article.publications.exists(), False)
+
+    def test_exists_join_optimization_disabled(self):
+        with (
+            mock.patch.object(connection.features, "supports_foreign_keys", False),
+            self.assertNumQueries(1) as ctx,
+        ):
+            self.article.publications.exists()
+
+        self.assertIn("JOIN", ctx.captured_queries[0]["sql"])
+
+    def test_prefetch_related_no_queries_optimization_disabled(self):
+        qs = Article.objects.prefetch_related("publications")
+        article = qs.get()
+        with self.assertNumQueries(0):
+            article.publications.count()
+        with self.assertNumQueries(0):
+            article.publications.exists()