Просмотр исходного кода

Refs #373 -- Added TupleIn subqueries.

Bendeguz Csirmaz 5 месяцев назад
Родитель
Сommit
f7601aed51
2 измененных файлов с 79 добавлено и 3 удалено
  1. 38 3
      django/db/models/fields/tuple_lookups.py
  2. 41 0
      tests/foreign_object/test_tuple_lookups.py

+ 38 - 3
django/db/models/fields/tuple_lookups.py

@@ -12,6 +12,7 @@ from django.db.models.lookups import (
     LessThan,
     LessThanOrEqual,
 )
+from django.db.models.sql import Query
 from django.db.models.sql.where import AND, OR, WhereNode
 
 
@@ -211,9 +212,14 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
 
 class TupleIn(TupleLookupMixin, In):
     def get_prep_lookup(self):
-        self.check_rhs_is_tuple_or_list()
-        self.check_rhs_is_collection_of_tuples_or_lists()
-        self.check_rhs_elements_length_equals_lhs_length()
+        if self.rhs_is_direct_value():
+            self.check_rhs_is_tuple_or_list()
+            self.check_rhs_is_collection_of_tuples_or_lists()
+            self.check_rhs_elements_length_equals_lhs_length()
+        else:
+            self.check_rhs_is_query()
+            self.check_rhs_select_length_equals_lhs_length()
+
         return self.rhs  # skip checks from mixin
 
     def check_rhs_is_collection_of_tuples_or_lists(self):
@@ -233,6 +239,25 @@ class TupleIn(TupleLookupMixin, In):
                 f"must have {len_lhs} elements each"
             )
 
+    def check_rhs_is_query(self):
+        if not isinstance(self.rhs, Query):
+            lhs_str = self.get_lhs_str()
+            rhs_cls = self.rhs.__class__.__name__
+            raise ValueError(
+                f"{self.lookup_name!r} subquery lookup of {lhs_str} "
+                f"must be a Query object (received {rhs_cls!r})"
+            )
+
+    def check_rhs_select_length_equals_lhs_length(self):
+        len_rhs = len(self.rhs.select)
+        len_lhs = len(self.lhs)
+        if len_rhs != len_lhs:
+            lhs_str = self.get_lhs_str()
+            raise ValueError(
+                f"{self.lookup_name!r} subquery lookup of {lhs_str} "
+                f"must have {len_lhs} fields (received {len_rhs})"
+            )
+
     def process_rhs(self, compiler, connection):
         rhs = self.rhs
         if not rhs:
@@ -255,10 +280,17 @@ class TupleIn(TupleLookupMixin, In):
 
         return Tuple(*result).as_sql(compiler, connection)
 
+    def as_sql(self, compiler, connection):
+        if not self.rhs_is_direct_value():
+            return self.as_subquery(compiler, connection)
+        return super().as_sql(compiler, connection)
+
     def as_sqlite(self, compiler, connection):
         rhs = self.rhs
         if not rhs:
             raise EmptyResultSet
+        if not self.rhs_is_direct_value():
+            return self.as_subquery(compiler, connection)
 
         # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
         # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
@@ -271,6 +303,9 @@ class TupleIn(TupleLookupMixin, In):
 
         return root.as_sql(compiler, connection)
 
+    def as_subquery(self, compiler, connection):
+        return compiler.compile(In(self.lhs, self.rhs))
+
 
 tuple_lookups = {
     "exact": TupleExact,

+ 41 - 0
tests/foreign_object/test_tuple_lookups.py

@@ -11,6 +11,7 @@ from django.db.models.fields.tuple_lookups import (
     TupleLessThan,
     TupleLessThanOrEqual,
 )
+from django.db.models.lookups import In
 from django.test import TestCase, skipUnlessDBFeature
 
 from .models import Contact, Customer
@@ -126,6 +127,46 @@ class TupleLookupsTests(TestCase):
             (self.contact_1, self.contact_2, self.contact_5),
         )
 
+    def test_tuple_in_subquery_must_be_query(self):
+        lhs = (F("customer_code"), F("company_code"))
+        # If rhs is any non-Query object with an as_sql() function.
+        rhs = In(F("customer_code"), [1, 2, 3])
+        with self.assertRaisesMessage(
+            ValueError,
+            "'in' subquery lookup of ('customer_code', 'company_code') "
+            "must be a Query object (received 'In')",
+        ):
+            TupleIn(lhs, rhs)
+
+    def test_tuple_in_subquery_must_have_2_fields(self):
+        lhs = (F("customer_code"), F("company_code"))
+        rhs = Customer.objects.values_list("customer_id").query
+        with self.assertRaisesMessage(
+            ValueError,
+            "'in' subquery lookup of ('customer_code', 'company_code') "
+            "must have 2 fields (received 1)",
+        ):
+            TupleIn(lhs, rhs)
+
+    def test_tuple_in_subquery(self):
+        customers = Customer.objects.values_list("customer_id", "company")
+        test_cases = (
+            (self.customer_1, (self.contact_1, self.contact_2, self.contact_5)),
+            (self.customer_2, (self.contact_3,)),
+            (self.customer_3, (self.contact_4,)),
+            (self.customer_4, ()),
+            (self.customer_5, (self.contact_6,)),
+        )
+
+        for customer, contacts in test_cases:
+            lhs = (F("customer_code"), F("company_code"))
+            rhs = customers.filter(id=customer.id).query
+            lookup = TupleIn(lhs, rhs)
+            qs = Contact.objects.filter(lookup).order_by("id")
+
+            with self.subTest(customer=customer.id, query=str(qs.query)):
+                self.assertSequenceEqual(qs, contacts)
+
     def test_tuple_in_rhs_must_be_collection_of_tuples_or_lists(self):
         test_cases = (
             (1, 2, 3),