# tuple_lookups.py (11 KB)


  1. import itertools
  2. from django.core.exceptions import EmptyResultSet
  3. from django.db.models import Field
  4. from django.db.models.expressions import ColPairs, Func, Value
  5. from django.db.models.lookups import (
  6. Exact,
  7. GreaterThan,
  8. GreaterThanOrEqual,
  9. In,
  10. IsNull,
  11. LessThan,
  12. LessThanOrEqual,
  13. )
  14. from django.db.models.sql import Query
  15. from django.db.models.sql.where import AND, OR, WhereNode
  16. class Tuple(Func):
  17. allows_composite_expressions = True
  18. function = ""
  19. output_field = Field()
  20. def __len__(self):
  21. return len(self.source_expressions)
  22. def __iter__(self):
  23. return iter(self.source_expressions)
  24. class TupleLookupMixin:
  25. allows_composite_expressions = True
  26. def get_prep_lookup(self):
  27. self.check_rhs_is_tuple_or_list()
  28. self.check_rhs_length_equals_lhs_length()
  29. return self.rhs
  30. def check_rhs_is_tuple_or_list(self):
  31. if not isinstance(self.rhs, (tuple, list)):
  32. lhs_str = self.get_lhs_str()
  33. raise ValueError(
  34. f"{self.lookup_name!r} lookup of {lhs_str} must be a tuple or a list"
  35. )
  36. def check_rhs_length_equals_lhs_length(self):
  37. len_lhs = len(self.lhs)
  38. if len_lhs != len(self.rhs):
  39. lhs_str = self.get_lhs_str()
  40. raise ValueError(
  41. f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements"
  42. )
  43. def get_lhs_str(self):
  44. if isinstance(self.lhs, ColPairs):
  45. return repr(self.lhs.field.name)
  46. else:
  47. names = ", ".join(repr(f.name) for f in self.lhs)
  48. return f"({names})"
  49. def get_prep_lhs(self):
  50. if isinstance(self.lhs, (tuple, list)):
  51. return Tuple(*self.lhs)
  52. return super().get_prep_lhs()
  53. def process_lhs(self, compiler, connection, lhs=None):
  54. sql, params = super().process_lhs(compiler, connection, lhs)
  55. if not isinstance(self.lhs, Tuple):
  56. sql = f"({sql})"
  57. return sql, params
  58. def process_rhs(self, compiler, connection):
  59. values = [
  60. Value(val, output_field=col.output_field)
  61. for col, val in zip(self.lhs, self.rhs)
  62. ]
  63. return Tuple(*values).as_sql(compiler, connection)
  64. class TupleExact(TupleLookupMixin, Exact):
  65. def as_oracle(self, compiler, connection):
  66. # e.g.: (a, b, c) == (x, y, z) as SQL:
  67. # WHERE a = x AND b = y AND c = z
  68. lookups = [Exact(col, val) for col, val in zip(self.lhs, self.rhs)]
  69. root = WhereNode(lookups, connector=AND)
  70. return root.as_sql(compiler, connection)
  71. class TupleIsNull(TupleLookupMixin, IsNull):
  72. def get_prep_lookup(self):
  73. rhs = self.rhs
  74. if isinstance(rhs, (tuple, list)) and len(rhs) == 1:
  75. rhs = rhs[0]
  76. if isinstance(rhs, bool):
  77. return rhs
  78. raise ValueError(
  79. "The QuerySet value for an isnull lookup must be True or False."
  80. )
  81. def as_sql(self, compiler, connection):
  82. # e.g.: (a, b, c) is None as SQL:
  83. # WHERE a IS NULL OR b IS NULL OR c IS NULL
  84. # e.g.: (a, b, c) is not None as SQL:
  85. # WHERE a IS NOT NULL AND b IS NOT NULL AND c IS NOT NULL
  86. rhs = self.rhs
  87. lookups = [IsNull(col, rhs) for col in self.lhs]
  88. root = WhereNode(lookups, connector=OR if rhs else AND)
  89. return root.as_sql(compiler, connection)
  90. class TupleGreaterThan(TupleLookupMixin, GreaterThan):
  91. def as_oracle(self, compiler, connection):
  92. # e.g.: (a, b, c) > (x, y, z) as SQL:
  93. # WHERE a > x OR (a = x AND (b > y OR (b = y AND c > z)))
  94. lookups = itertools.cycle([GreaterThan, Exact])
  95. connectors = itertools.cycle([OR, AND])
  96. cols_list = [col for col in self.lhs for _ in range(2)]
  97. vals_list = [val for val in self.rhs for _ in range(2)]
  98. cols_iter = iter(cols_list[:-1])
  99. vals_iter = iter(vals_list[:-1])
  100. col = next(cols_iter)
  101. val = next(vals_iter)
  102. lookup = next(lookups)
  103. connector = next(connectors)
  104. root = node = WhereNode([lookup(col, val)], connector=connector)
  105. for col, val in zip(cols_iter, vals_iter):
  106. lookup = next(lookups)
  107. connector = next(connectors)
  108. child = WhereNode([lookup(col, val)], connector=connector)
  109. node.children.append(child)
  110. node = child
  111. return root.as_sql(compiler, connection)
  112. class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual):
  113. def as_oracle(self, compiler, connection):
  114. # e.g.: (a, b, c) >= (x, y, z) as SQL:
  115. # WHERE a > x OR (a = x AND (b > y OR (b = y AND (c > z OR c = z))))
  116. lookups = itertools.cycle([GreaterThan, Exact])
  117. connectors = itertools.cycle([OR, AND])
  118. cols_list = [col for col in self.lhs for _ in range(2)]
  119. vals_list = [val for val in self.rhs for _ in range(2)]
  120. cols_iter = iter(cols_list)
  121. vals_iter = iter(vals_list)
  122. col = next(cols_iter)
  123. val = next(vals_iter)
  124. lookup = next(lookups)
  125. connector = next(connectors)
  126. root = node = WhereNode([lookup(col, val)], connector=connector)
  127. for col, val in zip(cols_iter, vals_iter):
  128. lookup = next(lookups)
  129. connector = next(connectors)
  130. child = WhereNode([lookup(col, val)], connector=connector)
  131. node.children.append(child)
  132. node = child
  133. return root.as_sql(compiler, connection)
  134. class TupleLessThan(TupleLookupMixin, LessThan):
  135. def as_oracle(self, compiler, connection):
  136. # e.g.: (a, b, c) < (x, y, z) as SQL:
  137. # WHERE a < x OR (a = x AND (b < y OR (b = y AND c < z)))
  138. lookups = itertools.cycle([LessThan, Exact])
  139. connectors = itertools.cycle([OR, AND])
  140. cols_list = [col for col in self.lhs for _ in range(2)]
  141. vals_list = [val for val in self.rhs for _ in range(2)]
  142. cols_iter = iter(cols_list[:-1])
  143. vals_iter = iter(vals_list[:-1])
  144. col = next(cols_iter)
  145. val = next(vals_iter)
  146. lookup = next(lookups)
  147. connector = next(connectors)
  148. root = node = WhereNode([lookup(col, val)], connector=connector)
  149. for col, val in zip(cols_iter, vals_iter):
  150. lookup = next(lookups)
  151. connector = next(connectors)
  152. child = WhereNode([lookup(col, val)], connector=connector)
  153. node.children.append(child)
  154. node = child
  155. return root.as_sql(compiler, connection)
  156. class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
  157. def as_oracle(self, compiler, connection):
  158. # e.g.: (a, b, c) <= (x, y, z) as SQL:
  159. # WHERE a < x OR (a = x AND (b < y OR (b = y AND (c < z OR c = z))))
  160. lookups = itertools.cycle([LessThan, Exact])
  161. connectors = itertools.cycle([OR, AND])
  162. cols_list = [col for col in self.lhs for _ in range(2)]
  163. vals_list = [val for val in self.rhs for _ in range(2)]
  164. cols_iter = iter(cols_list)
  165. vals_iter = iter(vals_list)
  166. col = next(cols_iter)
  167. val = next(vals_iter)
  168. lookup = next(lookups)
  169. connector = next(connectors)
  170. root = node = WhereNode([lookup(col, val)], connector=connector)
  171. for col, val in zip(cols_iter, vals_iter):
  172. lookup = next(lookups)
  173. connector = next(connectors)
  174. child = WhereNode([lookup(col, val)], connector=connector)
  175. node.children.append(child)
  176. node = child
  177. return root.as_sql(compiler, connection)
  178. class TupleIn(TupleLookupMixin, In):
  179. def get_prep_lookup(self):
  180. if self.rhs_is_direct_value():
  181. self.check_rhs_is_tuple_or_list()
  182. self.check_rhs_is_collection_of_tuples_or_lists()
  183. self.check_rhs_elements_length_equals_lhs_length()
  184. else:
  185. self.check_rhs_is_query()
  186. self.check_rhs_select_length_equals_lhs_length()
  187. return self.rhs # skip checks from mixin
  188. def check_rhs_is_collection_of_tuples_or_lists(self):
  189. if not all(isinstance(vals, (tuple, list)) for vals in self.rhs):
  190. lhs_str = self.get_lhs_str()
  191. raise ValueError(
  192. f"{self.lookup_name!r} lookup of {lhs_str} "
  193. "must be a collection of tuples or lists"
  194. )
  195. def check_rhs_elements_length_equals_lhs_length(self):
  196. len_lhs = len(self.lhs)
  197. if not all(len_lhs == len(vals) for vals in self.rhs):
  198. lhs_str = self.get_lhs_str()
  199. raise ValueError(
  200. f"{self.lookup_name!r} lookup of {lhs_str} "
  201. f"must have {len_lhs} elements each"
  202. )
  203. def check_rhs_is_query(self):
  204. if not isinstance(self.rhs, Query):
  205. lhs_str = self.get_lhs_str()
  206. rhs_cls = self.rhs.__class__.__name__
  207. raise ValueError(
  208. f"{self.lookup_name!r} subquery lookup of {lhs_str} "
  209. f"must be a Query object (received {rhs_cls!r})"
  210. )
  211. def check_rhs_select_length_equals_lhs_length(self):
  212. len_rhs = len(self.rhs.select)
  213. if len_rhs == 1 and isinstance(self.rhs.select[0], ColPairs):
  214. len_rhs = len(self.rhs.select[0])
  215. len_lhs = len(self.lhs)
  216. if len_rhs != len_lhs:
  217. lhs_str = self.get_lhs_str()
  218. raise ValueError(
  219. f"{self.lookup_name!r} subquery lookup of {lhs_str} "
  220. f"must have {len_lhs} fields (received {len_rhs})"
  221. )
  222. def process_rhs(self, compiler, connection):
  223. rhs = self.rhs
  224. if not rhs:
  225. raise EmptyResultSet
  226. # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
  227. # WHERE (a, b, c) IN ((x1, y1, z1), (x2, y2, z2))
  228. result = []
  229. lhs = self.lhs
  230. for vals in rhs:
  231. result.append(
  232. Tuple(
  233. *[
  234. Value(val, output_field=col.output_field)
  235. for col, val in zip(lhs, vals)
  236. ]
  237. )
  238. )
  239. return Tuple(*result).as_sql(compiler, connection)
  240. def as_sql(self, compiler, connection):
  241. if not self.rhs_is_direct_value():
  242. return self.as_subquery(compiler, connection)
  243. return super().as_sql(compiler, connection)
  244. def as_sqlite(self, compiler, connection):
  245. rhs = self.rhs
  246. if not rhs:
  247. raise EmptyResultSet
  248. if not self.rhs_is_direct_value():
  249. return self.as_subquery(compiler, connection)
  250. # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
  251. # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
  252. root = WhereNode([], connector=OR)
  253. lhs = self.lhs
  254. for vals in rhs:
  255. lookups = [Exact(col, val) for col, val in zip(lhs, vals)]
  256. root.children.append(WhereNode(lookups, connector=AND))
  257. return root.as_sql(compiler, connection)
  258. def as_subquery(self, compiler, connection):
  259. lhs = self.lhs
  260. rhs = self.rhs
  261. if isinstance(lhs, ColPairs):
  262. rhs = rhs.clone()
  263. rhs.set_values([source.name for source in lhs.sources])
  264. lhs = Tuple(lhs)
  265. return compiler.compile(In(lhs, rhs))
  266. tuple_lookups = {
  267. "exact": TupleExact,
  268. "gt": TupleGreaterThan,
  269. "gte": TupleGreaterThanOrEqual,
  270. "lt": TupleLessThan,
  271. "lte": TupleLessThanOrEqual,
  272. "in": TupleIn,
  273. "isnull": TupleIsNull,
  274. }