Browse Source

Fixed #36116 -- Optimized multi-column ForwardManyToOne prefetching.

Rely on ColPairs and TupleIn which support a single column to be specified
to avoid special casing ForwardManyToOne.get_prefetch_querysets().

Thanks Jacob Walls for the report.
Simon Charette 2 months ago
parent
commit
626d77e52a

+ 14 - 16
django/db/models/fields/related_descriptors.py

@@ -74,6 +74,8 @@ from django.db import (
     transaction,
 )
 from django.db.models import Manager, Q, Window, signals
+from django.db.models.expressions import ColPairs
+from django.db.models.fields.tuple_lookups import TupleIn
 from django.db.models.functions import RowNumber
 from django.db.models.lookups import GreaterThan, LessThanOrEqual
 from django.db.models.query import QuerySet
@@ -164,23 +166,19 @@ class ForwardManyToOneDescriptor:
         rel_obj_attr = self.field.get_foreign_related_value
         instance_attr = self.field.get_local_related_value
         instances_dict = {instance_attr(inst): inst for inst in instances}
-        related_field = self.field.foreign_related_fields[0]
+        related_fields = self.field.foreign_related_fields
         remote_field = self.field.remote_field
-
-        # FIXME: This will need to be revisited when we introduce support for
-        # composite fields. In the meantime we take this practical approach to
-        # solve a regression on 1.6 when the reverse manager is hidden
-        # (related_name ends with a '+'). Refs #21410.
-        # The check for len(...) == 1 is a special case that allows the query
-        # to be join-less and smaller. Refs #21760.
-        if remote_field.hidden or len(self.field.foreign_related_fields) == 1:
-            query = {
-                "%s__in"
-                % related_field.name: {instance_attr(inst)[0] for inst in instances}
-            }
-        else:
-            query = {"%s__in" % self.field.related_query_name(): instances}
-        queryset = queryset.filter(**query)
+        queryset = queryset.filter(
+            TupleIn(
+                ColPairs(
+                    queryset.model._meta.db_table,
+                    related_fields,
+                    related_fields,
+                    self.field,
+                ),
+                list(instances_dict),
+            )
+        )
         # There can be only one object prefetched for each instance so clear
         # ordering if the query allows it without side effects.
         queryset.query.clear_ordering()

+ 1 - 1
tests/foreign_object/models/person.py

@@ -107,6 +107,6 @@ class Friendship(models.Model):
         Person,
         from_fields=["to_friend_country_id", "to_friend_id"],
         to_fields=["person_country_id", "id"],
-        related_name="to_friend",
+        related_name="to_friend+",
         on_delete=models.CASCADE,
     )

+ 36 - 3
tests/foreign_object/tests.py

@@ -4,7 +4,7 @@ import pickle
 from operator import attrgetter
 
 from django.core.exceptions import FieldError
-from django.db import models
+from django.db import connection, models
 from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
 from django.test.utils import isolate_apps
 from django.utils import translation
@@ -246,7 +246,7 @@ class MultiColumnFKTests(TestCase):
         normal_people = [m.person for m in Membership.objects.order_by("pk")]
         self.assertEqual(people, normal_people)
 
-    def test_prefetch_foreignkey_forward_works(self):
+    def test_prefetch_foreignobject_forward(self):
         Membership.objects.create(
             membership_country=self.usa, person=self.bob, group=self.cia
         )
@@ -263,7 +263,40 @@ class MultiColumnFKTests(TestCase):
         normal_people = [m.person for m in Membership.objects.order_by("pk")]
         self.assertEqual(people, normal_people)
 
-    def test_prefetch_foreignkey_reverse_works(self):
+    def test_prefetch_foreignobject_hidden_forward(self):
+        Friendship.objects.create(
+            from_friend_country=self.usa,
+            from_friend_id=self.bob.id,
+            to_friend_country_id=self.usa.id,
+            to_friend_id=self.george.id,
+        )
+        Friendship.objects.create(
+            from_friend_country=self.usa,
+            from_friend_id=self.bob.id,
+            to_friend_country_id=self.soviet_union.id,
+            to_friend_id=self.sam.id,
+        )
+        with self.assertNumQueries(2) as ctx:
+            friendships = list(
+                Friendship.objects.prefetch_related("to_friend").order_by("pk")
+            )
+        prefetch_sql = ctx[-1]["sql"]
+        # Prefetch queryset should be filtered by all foreign related fields
+        # to prevent extra rows from being eagerly fetched.
+        prefetch_where_sql = prefetch_sql.split("WHERE")[-1]
+        for to_field_name in Friendship.to_friend.field.to_fields:
+            to_field = Person._meta.get_field(to_field_name)
+            with self.subTest(to_field=to_field):
+                self.assertIn(
+                    connection.ops.quote_name(to_field.column),
+                    prefetch_where_sql,
+                )
+        self.assertNotIn(" JOIN ", prefetch_sql)
+        with self.assertNumQueries(0):
+            self.assertEqual(friendships[0].to_friend, self.george)
+            self.assertEqual(friendships[1].to_friend, self.sam)
+
+    def test_prefetch_foreignobject_reverse(self):
         Membership.objects.create(
             membership_country=self.usa, person=self.bob, group=self.cia
         )