|
@@ -12,6 +12,7 @@ from django.db.models.lookups import (
|
|
|
LessThan,
|
|
|
LessThanOrEqual,
|
|
|
)
|
|
|
+from django.db.models.sql import Query
|
|
|
from django.db.models.sql.where import AND, OR, WhereNode
|
|
|
|
|
|
|
|
@@ -211,9 +212,14 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
|
|
|
|
|
|
class TupleIn(TupleLookupMixin, In):
|
|
|
def get_prep_lookup(self):
|
|
|
- self.check_rhs_is_tuple_or_list()
|
|
|
- self.check_rhs_is_collection_of_tuples_or_lists()
|
|
|
- self.check_rhs_elements_length_equals_lhs_length()
|
|
|
+ if self.rhs_is_direct_value():
|
|
|
+ self.check_rhs_is_tuple_or_list()
|
|
|
+ self.check_rhs_is_collection_of_tuples_or_lists()
|
|
|
+ self.check_rhs_elements_length_equals_lhs_length()
|
|
|
+ else:
|
|
|
+ self.check_rhs_is_query()
|
|
|
+ self.check_rhs_select_length_equals_lhs_length()
|
|
|
+
|
|
|
return self.rhs # skip checks from mixin
|
|
|
|
|
|
def check_rhs_is_collection_of_tuples_or_lists(self):
|
|
@@ -233,6 +239,25 @@ class TupleIn(TupleLookupMixin, In):
|
|
|
f"must have {len_lhs} elements each"
|
|
|
)
|
|
|
|
|
|
+ def check_rhs_is_query(self):
|
|
|
+ if not isinstance(self.rhs, Query):
|
|
|
+ lhs_str = self.get_lhs_str()
|
|
|
+ rhs_cls = self.rhs.__class__.__name__
|
|
|
+ raise ValueError(
|
|
|
+ f"{self.lookup_name!r} subquery lookup of {lhs_str} "
|
|
|
+ f"must be a Query object (received {rhs_cls!r})"
|
|
|
+ )
|
|
|
+
|
|
|
+ def check_rhs_select_length_equals_lhs_length(self):
|
|
|
+ len_rhs = len(self.rhs.select)
|
|
|
+ len_lhs = len(self.lhs)
|
|
|
+ if len_rhs != len_lhs:
|
|
|
+ lhs_str = self.get_lhs_str()
|
|
|
+ raise ValueError(
|
|
|
+ f"{self.lookup_name!r} subquery lookup of {lhs_str} "
|
|
|
+ f"must have {len_lhs} fields (received {len_rhs})"
|
|
|
+ )
|
|
|
+
|
|
|
def process_rhs(self, compiler, connection):
|
|
|
rhs = self.rhs
|
|
|
if not rhs:
|
|
@@ -255,10 +280,17 @@ class TupleIn(TupleLookupMixin, In):
|
|
|
|
|
|
return Tuple(*result).as_sql(compiler, connection)
|
|
|
|
|
|
+ def as_sql(self, compiler, connection):
|
|
|
+ if not self.rhs_is_direct_value():
|
|
|
+ return self.as_subquery(compiler, connection)
|
|
|
+ return super().as_sql(compiler, connection)
|
|
|
+
|
|
|
def as_sqlite(self, compiler, connection):
|
|
|
rhs = self.rhs
|
|
|
if not rhs:
|
|
|
raise EmptyResultSet
|
|
|
+ if not self.rhs_is_direct_value():
|
|
|
+ return self.as_subquery(compiler, connection)
|
|
|
|
|
|
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
|
|
|
# WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
|
|
@@ -271,6 +303,9 @@ class TupleIn(TupleLookupMixin, In):
|
|
|
|
|
|
return root.as_sql(compiler, connection)
|
|
|
|
|
|
+ def as_subquery(self, compiler, connection):
|
|
|
+ return compiler.compile(In(self.lhs, self.rhs))
|
|
|
+
|
|
|
|
|
|
tuple_lookups = {
|
|
|
"exact": TupleExact,
|