|
@@ -47,7 +47,8 @@ class TupleLookupMixin:
|
|
|
self.check_rhs_is_tuple_or_list()
|
|
|
self.check_rhs_length_equals_lhs_length()
|
|
|
else:
|
|
|
- self.check_rhs_is_outer_ref()
|
|
|
+ self.check_rhs_is_supported_expression()
|
|
|
+ super().get_prep_lookup()
|
|
|
return self.rhs
|
|
|
|
|
|
def check_rhs_is_tuple_or_list(self):
|
|
@@ -65,13 +66,13 @@ class TupleLookupMixin:
|
|
|
f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements"
|
|
|
)
|
|
|
|
|
|
- def check_rhs_is_outer_ref(self):
|
|
|
- if not isinstance(self.rhs, ResolvedOuterRef):
|
|
|
+ def check_rhs_is_supported_expression(self):
|
|
|
+ if not isinstance(self.rhs, (ResolvedOuterRef, Query)):
|
|
|
lhs_str = self.get_lhs_str()
|
|
|
rhs_cls = self.rhs.__class__.__name__
|
|
|
raise ValueError(
|
|
|
f"{self.lookup_name!r} subquery lookup of {lhs_str} "
|
|
|
- f"only supports OuterRef objects (received {rhs_cls!r})"
|
|
|
+ f"only supports OuterRef and QuerySet objects (received {rhs_cls!r})"
|
|
|
)
|
|
|
|
|
|
def get_lhs_str(self):
|
|
@@ -101,11 +102,14 @@ class TupleLookupMixin:
|
|
|
return compiler.compile(Tuple(*args))
|
|
|
else:
|
|
|
sql, params = compiler.compile(self.rhs)
|
|
|
- if not isinstance(self.rhs, ColPairs):
|
|
|
+ if isinstance(self.rhs, ColPairs):
|
|
|
+ return "(%s)" % sql, params
|
|
|
+ elif isinstance(self.rhs, Query):
|
|
|
+ return super().process_rhs(compiler, connection)
|
|
|
+ else:
|
|
|
raise ValueError(
|
|
|
"Composite field lookups only work with composite expressions."
|
|
|
)
|
|
|
- return "(%s)" % sql, params
|
|
|
|
|
|
def get_fallback_sql(self, compiler, connection):
|
|
|
raise NotImplementedError(
|
|
@@ -121,6 +125,8 @@ class TupleLookupMixin:
|
|
|
|
|
|
class TupleExact(TupleLookupMixin, Exact):
|
|
|
def get_fallback_sql(self, compiler, connection):
|
|
|
+ if isinstance(self.rhs, Query):
|
|
|
+ return super(TupleLookupMixin, self).as_sql(compiler, connection)
|
|
|
# Process right-hand-side to trigger sanitization.
|
|
|
self.process_rhs(compiler, connection)
|
|
|
# e.g.: (a, b, c) == (x, y, z) as SQL:
|
|
@@ -273,7 +279,7 @@ class TupleIn(TupleLookupMixin, In):
|
|
|
self.check_rhs_elements_length_equals_lhs_length()
|
|
|
else:
|
|
|
self.check_rhs_is_query()
|
|
|
- self.check_rhs_select_length_equals_lhs_length()
|
|
|
+ super(TupleLookupMixin, self).get_prep_lookup()
|
|
|
|
|
|
return self.rhs # skip checks from mixin
|
|
|
|
|
@@ -303,19 +309,10 @@ class TupleIn(TupleLookupMixin, In):
|
|
|
f"must be a Query object (received {rhs_cls!r})"
|
|
|
)
|
|
|
|
|
|
- def check_rhs_select_length_equals_lhs_length(self):
|
|
|
- len_rhs = len(self.rhs.select)
|
|
|
- if len_rhs == 1 and isinstance(self.rhs.select[0], ColPairs):
|
|
|
- len_rhs = len(self.rhs.select[0])
|
|
|
- len_lhs = len(self.lhs)
|
|
|
- if len_rhs != len_lhs:
|
|
|
- lhs_str = self.get_lhs_str()
|
|
|
- raise ValueError(
|
|
|
- f"{self.lookup_name!r} subquery lookup of {lhs_str} "
|
|
|
- f"must have {len_lhs} fields (received {len_rhs})"
|
|
|
- )
|
|
|
-
|
|
|
def process_rhs(self, compiler, connection):
|
|
|
+ if not self.rhs_is_direct_value():
|
|
|
+ return super(TupleLookupMixin, self).process_rhs(compiler, connection)
|
|
|
+
|
|
|
rhs = self.rhs
|
|
|
if not rhs:
|
|
|
raise EmptyResultSet
|
|
@@ -337,19 +334,12 @@ class TupleIn(TupleLookupMixin, In):
|
|
|
|
|
|
return compiler.compile(Tuple(*result))
|
|
|
|
|
|
- def as_subquery_sql(self, compiler, connection):
|
|
|
- lhs = self.lhs
|
|
|
- rhs = self.rhs
|
|
|
- if isinstance(lhs, ColPairs):
|
|
|
- rhs = rhs.clone()
|
|
|
- rhs.set_values([source.name for source in lhs.sources])
|
|
|
- lhs = Tuple(lhs)
|
|
|
- return compiler.compile(In(lhs, rhs))
|
|
|
-
|
|
|
def get_fallback_sql(self, compiler, connection):
|
|
|
rhs = self.rhs
|
|
|
if not rhs:
|
|
|
raise EmptyResultSet
|
|
|
+ if not self.rhs_is_direct_value():
|
|
|
+ return super(TupleLookupMixin, self).as_sql(compiler, connection)
|
|
|
|
|
|
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
|
|
|
# WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
|
|
@@ -362,11 +352,6 @@ class TupleIn(TupleLookupMixin, In):
|
|
|
|
|
|
return root.as_sql(compiler, connection)
|
|
|
|
|
|
- def as_sql(self, compiler, connection):
|
|
|
- if not self.rhs_is_direct_value():
|
|
|
- return self.as_subquery_sql(compiler, connection)
|
|
|
- return super().as_sql(compiler, connection)
|
|
|
-
|
|
|
|
|
|
tuple_lookups = {
|
|
|
"exact": TupleExact,
|