test_tuple_lookups.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. import itertools
  2. from django.db import NotSupportedError
  3. from django.db.models import F
  4. from django.db.models.fields.tuple_lookups import (
  5. TupleExact,
  6. TupleGreaterThan,
  7. TupleGreaterThanOrEqual,
  8. TupleIn,
  9. TupleIsNull,
  10. TupleLessThan,
  11. TupleLessThanOrEqual,
  12. )
  13. from django.db.models.lookups import In
  14. from django.test import TestCase, skipUnlessDBFeature
  15. from .models import Contact, Customer
  16. class TupleLookupsTests(TestCase):
  17. @classmethod
  18. def setUpTestData(cls):
  19. super().setUpTestData()
  20. cls.customer_1 = Customer.objects.create(customer_id=1, company="a")
  21. cls.customer_2 = Customer.objects.create(customer_id=1, company="b")
  22. cls.customer_3 = Customer.objects.create(customer_id=2, company="c")
  23. cls.customer_4 = Customer.objects.create(customer_id=3, company="d")
  24. cls.customer_5 = Customer.objects.create(customer_id=1, company="e")
  25. cls.contact_1 = Contact.objects.create(customer=cls.customer_1)
  26. cls.contact_2 = Contact.objects.create(customer=cls.customer_1)
  27. cls.contact_3 = Contact.objects.create(customer=cls.customer_2)
  28. cls.contact_4 = Contact.objects.create(customer=cls.customer_3)
  29. cls.contact_5 = Contact.objects.create(customer=cls.customer_1)
  30. cls.contact_6 = Contact.objects.create(customer=cls.customer_5)
  31. def test_exact(self):
  32. test_cases = (
  33. (self.customer_1, (self.contact_1, self.contact_2, self.contact_5)),
  34. (self.customer_2, (self.contact_3,)),
  35. (self.customer_3, (self.contact_4,)),
  36. (self.customer_4, ()),
  37. (self.customer_5, (self.contact_6,)),
  38. )
  39. for customer, contacts in test_cases:
  40. with self.subTest(
  41. "filter(customer=customer)",
  42. customer=customer,
  43. contacts=contacts,
  44. ):
  45. self.assertSequenceEqual(
  46. Contact.objects.filter(customer=customer).order_by("id"), contacts
  47. )
  48. with self.subTest(
  49. "filter(TupleExact)",
  50. customer=customer,
  51. contacts=contacts,
  52. ):
  53. lhs = (F("customer_code"), F("company_code"))
  54. rhs = (customer.customer_id, customer.company)
  55. lookup = TupleExact(lhs, rhs)
  56. self.assertSequenceEqual(
  57. Contact.objects.filter(lookup).order_by("id"), contacts
  58. )
  59. def test_exact_subquery(self):
  60. with self.assertRaisesMessage(
  61. NotSupportedError, "'exact' doesn't support multi-column subqueries."
  62. ):
  63. subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
  64. self.assertSequenceEqual(
  65. Contact.objects.filter(customer=subquery).order_by("id"), ()
  66. )
  67. def test_in(self):
  68. cust_1, cust_2, cust_3, cust_4, cust_5 = (
  69. self.customer_1,
  70. self.customer_2,
  71. self.customer_3,
  72. self.customer_4,
  73. self.customer_5,
  74. )
  75. c1, c2, c3, c4, c5, c6 = (
  76. self.contact_1,
  77. self.contact_2,
  78. self.contact_3,
  79. self.contact_4,
  80. self.contact_5,
  81. self.contact_6,
  82. )
  83. test_cases = (
  84. ((), ()),
  85. ((cust_1,), (c1, c2, c5)),
  86. ((cust_1, cust_2), (c1, c2, c3, c5)),
  87. ((cust_1, cust_2, cust_3), (c1, c2, c3, c4, c5)),
  88. ((cust_1, cust_2, cust_3, cust_4), (c1, c2, c3, c4, c5)),
  89. ((cust_1, cust_2, cust_3, cust_4, cust_5), (c1, c2, c3, c4, c5, c6)),
  90. )
  91. for customers, contacts in test_cases:
  92. with self.subTest(
  93. "filter(customer__in=customers)",
  94. customers=customers,
  95. contacts=contacts,
  96. ):
  97. self.assertSequenceEqual(
  98. Contact.objects.filter(customer__in=customers).order_by("id"),
  99. contacts,
  100. )
  101. with self.subTest(
  102. "filter(TupleIn)",
  103. customers=customers,
  104. contacts=contacts,
  105. ):
  106. lhs = (F("customer_code"), F("company_code"))
  107. rhs = [(c.customer_id, c.company) for c in customers]
  108. lookup = TupleIn(lhs, rhs)
  109. self.assertSequenceEqual(
  110. Contact.objects.filter(lookup).order_by("id"), contacts
  111. )
  112. @skipUnlessDBFeature("allow_sliced_subqueries_with_in")
  113. def test_in_subquery(self):
  114. subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
  115. self.assertSequenceEqual(
  116. Contact.objects.filter(customer__in=subquery).order_by("id"),
  117. (self.contact_1, self.contact_2, self.contact_5),
  118. )
  119. def test_tuple_in_subquery_must_be_query(self):
  120. lhs = (F("customer_code"), F("company_code"))
  121. # If rhs is any non-Query object with an as_sql() function.
  122. rhs = In(F("customer_code"), [1, 2, 3])
  123. with self.assertRaisesMessage(
  124. ValueError,
  125. "'in' subquery lookup of ('customer_code', 'company_code') "
  126. "must be a Query object (received 'In')",
  127. ):
  128. TupleIn(lhs, rhs)
  129. def test_tuple_in_subquery_must_have_2_fields(self):
  130. lhs = (F("customer_code"), F("company_code"))
  131. rhs = Customer.objects.values_list("customer_id").query
  132. with self.assertRaisesMessage(
  133. ValueError,
  134. "'in' subquery lookup of ('customer_code', 'company_code') "
  135. "must have 2 fields (received 1)",
  136. ):
  137. TupleIn(lhs, rhs)
  138. def test_tuple_in_subquery(self):
  139. customers = Customer.objects.values_list("customer_id", "company")
  140. test_cases = (
  141. (self.customer_1, (self.contact_1, self.contact_2, self.contact_5)),
  142. (self.customer_2, (self.contact_3,)),
  143. (self.customer_3, (self.contact_4,)),
  144. (self.customer_4, ()),
  145. (self.customer_5, (self.contact_6,)),
  146. )
  147. for customer, contacts in test_cases:
  148. lhs = (F("customer_code"), F("company_code"))
  149. rhs = customers.filter(id=customer.id).query
  150. lookup = TupleIn(lhs, rhs)
  151. qs = Contact.objects.filter(lookup).order_by("id")
  152. with self.subTest(customer=customer.id, query=str(qs.query)):
  153. self.assertSequenceEqual(qs, contacts)
  154. def test_tuple_in_rhs_must_be_collection_of_tuples_or_lists(self):
  155. test_cases = (
  156. (1, 2, 3),
  157. ((1, 2), (3, 4), None),
  158. )
  159. for rhs in test_cases:
  160. with self.subTest(rhs=rhs):
  161. with self.assertRaisesMessage(
  162. ValueError,
  163. "'in' lookup of ('customer_code', 'company_code') "
  164. "must be a collection of tuples or lists",
  165. ):
  166. TupleIn((F("customer_code"), F("company_code")), rhs)
  167. def test_tuple_in_rhs_must_have_2_elements_each(self):
  168. test_cases = (
  169. ((),),
  170. ((1,),),
  171. ((1, 2, 3),),
  172. )
  173. for rhs in test_cases:
  174. with self.subTest(rhs=rhs):
  175. with self.assertRaisesMessage(
  176. ValueError,
  177. "'in' lookup of ('customer_code', 'company_code') "
  178. "must have 2 elements each",
  179. ):
  180. TupleIn((F("customer_code"), F("company_code")), rhs)
  181. def test_lt(self):
  182. c1, c2, c3, c4, c5, c6 = (
  183. self.contact_1,
  184. self.contact_2,
  185. self.contact_3,
  186. self.contact_4,
  187. self.contact_5,
  188. self.contact_6,
  189. )
  190. test_cases = (
  191. (self.customer_1, ()),
  192. (self.customer_2, (c1, c2, c5)),
  193. (self.customer_5, (c1, c2, c3, c5)),
  194. (self.customer_3, (c1, c2, c3, c5, c6)),
  195. (self.customer_4, (c1, c2, c3, c4, c5, c6)),
  196. )
  197. for customer, contacts in test_cases:
  198. with self.subTest(
  199. "filter(customer__lt=customer)",
  200. customer=customer,
  201. contacts=contacts,
  202. ):
  203. self.assertSequenceEqual(
  204. Contact.objects.filter(customer__lt=customer).order_by("id"),
  205. contacts,
  206. )
  207. with self.subTest(
  208. "filter(TupleLessThan)",
  209. customer=customer,
  210. contacts=contacts,
  211. ):
  212. lhs = (F("customer_code"), F("company_code"))
  213. rhs = (customer.customer_id, customer.company)
  214. lookup = TupleLessThan(lhs, rhs)
  215. self.assertSequenceEqual(
  216. Contact.objects.filter(lookup).order_by("id"), contacts
  217. )
  218. def test_lt_subquery(self):
  219. with self.assertRaisesMessage(
  220. NotSupportedError, "'lt' doesn't support multi-column subqueries."
  221. ):
  222. subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
  223. self.assertSequenceEqual(
  224. Contact.objects.filter(customer__lt=subquery).order_by("id"), ()
  225. )
  226. def test_lte(self):
  227. c1, c2, c3, c4, c5, c6 = (
  228. self.contact_1,
  229. self.contact_2,
  230. self.contact_3,
  231. self.contact_4,
  232. self.contact_5,
  233. self.contact_6,
  234. )
  235. test_cases = (
  236. (self.customer_1, (c1, c2, c5)),
  237. (self.customer_2, (c1, c2, c3, c5)),
  238. (self.customer_5, (c1, c2, c3, c5, c6)),
  239. (self.customer_3, (c1, c2, c3, c4, c5, c6)),
  240. (self.customer_4, (c1, c2, c3, c4, c5, c6)),
  241. )
  242. for customer, contacts in test_cases:
  243. with self.subTest(
  244. "filter(customer__lte=customer)",
  245. customer=customer,
  246. contacts=contacts,
  247. ):
  248. self.assertSequenceEqual(
  249. Contact.objects.filter(customer__lte=customer).order_by("id"),
  250. contacts,
  251. )
  252. with self.subTest(
  253. "filter(TupleLessThanOrEqual)",
  254. customer=customer,
  255. contacts=contacts,
  256. ):
  257. lhs = (F("customer_code"), F("company_code"))
  258. rhs = (customer.customer_id, customer.company)
  259. lookup = TupleLessThanOrEqual(lhs, rhs)
  260. self.assertSequenceEqual(
  261. Contact.objects.filter(lookup).order_by("id"), contacts
  262. )
  263. def test_lte_subquery(self):
  264. with self.assertRaisesMessage(
  265. NotSupportedError, "'lte' doesn't support multi-column subqueries."
  266. ):
  267. subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
  268. self.assertSequenceEqual(
  269. Contact.objects.filter(customer__lte=subquery).order_by("id"), ()
  270. )
  271. def test_gt(self):
  272. test_cases = (
  273. (self.customer_1, (self.contact_3, self.contact_4, self.contact_6)),
  274. (self.customer_2, (self.contact_4, self.contact_6)),
  275. (self.customer_5, (self.contact_4,)),
  276. (self.customer_3, ()),
  277. (self.customer_4, ()),
  278. )
  279. for customer, contacts in test_cases:
  280. with self.subTest(
  281. "filter(customer__gt=customer)",
  282. customer=customer,
  283. contacts=contacts,
  284. ):
  285. self.assertSequenceEqual(
  286. Contact.objects.filter(customer__gt=customer).order_by("id"),
  287. contacts,
  288. )
  289. with self.subTest(
  290. "filter(TupleGreaterThan)",
  291. customer=customer,
  292. contacts=contacts,
  293. ):
  294. lhs = (F("customer_code"), F("company_code"))
  295. rhs = (customer.customer_id, customer.company)
  296. lookup = TupleGreaterThan(lhs, rhs)
  297. self.assertSequenceEqual(
  298. Contact.objects.filter(lookup).order_by("id"), contacts
  299. )
  300. def test_gt_subquery(self):
  301. with self.assertRaisesMessage(
  302. NotSupportedError, "'gt' doesn't support multi-column subqueries."
  303. ):
  304. subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
  305. self.assertSequenceEqual(
  306. Contact.objects.filter(customer__gt=subquery).order_by("id"), ()
  307. )
  308. def test_gte(self):
  309. c1, c2, c3, c4, c5, c6 = (
  310. self.contact_1,
  311. self.contact_2,
  312. self.contact_3,
  313. self.contact_4,
  314. self.contact_5,
  315. self.contact_6,
  316. )
  317. test_cases = (
  318. (self.customer_1, (c1, c2, c3, c4, c5, c6)),
  319. (self.customer_2, (c3, c4, c6)),
  320. (self.customer_5, (c4, c6)),
  321. (self.customer_3, (c4,)),
  322. (self.customer_4, ()),
  323. )
  324. for customer, contacts in test_cases:
  325. with self.subTest(
  326. "filter(customer__gte=customer)",
  327. customer=customer,
  328. contacts=contacts,
  329. ):
  330. self.assertSequenceEqual(
  331. Contact.objects.filter(customer__gte=customer).order_by("pk"),
  332. contacts,
  333. )
  334. with self.subTest(
  335. "filter(TupleGreaterThanOrEqual)",
  336. customer=customer,
  337. contacts=contacts,
  338. ):
  339. lhs = (F("customer_code"), F("company_code"))
  340. rhs = (customer.customer_id, customer.company)
  341. lookup = TupleGreaterThanOrEqual(lhs, rhs)
  342. self.assertSequenceEqual(
  343. Contact.objects.filter(lookup).order_by("id"), contacts
  344. )
  345. def test_gte_subquery(self):
  346. with self.assertRaisesMessage(
  347. NotSupportedError, "'gte' doesn't support multi-column subqueries."
  348. ):
  349. subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
  350. self.assertSequenceEqual(
  351. Contact.objects.filter(customer__gte=subquery).order_by("id"), ()
  352. )
  353. def test_isnull(self):
  354. contacts = (
  355. self.contact_1,
  356. self.contact_2,
  357. self.contact_3,
  358. self.contact_4,
  359. self.contact_5,
  360. self.contact_6,
  361. )
  362. with self.subTest("filter(customer__isnull=True)"):
  363. self.assertSequenceEqual(
  364. Contact.objects.filter(customer__isnull=True).order_by("id"),
  365. (),
  366. )
  367. with self.subTest("filter(TupleIsNull(True))"):
  368. lhs = (F("customer_code"), F("company_code"))
  369. lookup = TupleIsNull(lhs, True)
  370. self.assertSequenceEqual(
  371. Contact.objects.filter(lookup).order_by("id"),
  372. (),
  373. )
  374. with self.subTest("filter(customer__isnull=False)"):
  375. self.assertSequenceEqual(
  376. Contact.objects.filter(customer__isnull=False).order_by("id"),
  377. contacts,
  378. )
  379. with self.subTest("filter(TupleIsNull(False))"):
  380. lhs = (F("customer_code"), F("company_code"))
  381. lookup = TupleIsNull(lhs, False)
  382. self.assertSequenceEqual(
  383. Contact.objects.filter(lookup).order_by("id"),
  384. contacts,
  385. )
  386. def test_isnull_subquery(self):
  387. with self.assertRaisesMessage(
  388. NotSupportedError, "'isnull' doesn't support multi-column subqueries."
  389. ):
  390. subquery = Customer.objects.filter(id=0)[:1]
  391. self.assertSequenceEqual(
  392. Contact.objects.filter(customer__isnull=subquery).order_by("id"), ()
  393. )
  394. def test_lookup_errors(self):
  395. m_2_elements = "'%s' lookup of 'customer' must have 2 elements"
  396. m_2_elements_each = "'in' lookup of 'customer' must have 2 elements each"
  397. test_cases = (
  398. ({"customer": 1}, m_2_elements % "exact"),
  399. ({"customer": (1, 2, 3)}, m_2_elements % "exact"),
  400. ({"customer__in": (1, 2, 3)}, m_2_elements_each),
  401. ({"customer__in": ("foo", "bar")}, m_2_elements_each),
  402. ({"customer__gt": 1}, m_2_elements % "gt"),
  403. ({"customer__gt": (1, 2, 3)}, m_2_elements % "gt"),
  404. ({"customer__gte": 1}, m_2_elements % "gte"),
  405. ({"customer__gte": (1, 2, 3)}, m_2_elements % "gte"),
  406. ({"customer__lt": 1}, m_2_elements % "lt"),
  407. ({"customer__lt": (1, 2, 3)}, m_2_elements % "lt"),
  408. ({"customer__lte": 1}, m_2_elements % "lte"),
  409. ({"customer__lte": (1, 2, 3)}, m_2_elements % "lte"),
  410. )
  411. for kwargs, message in test_cases:
  412. with (
  413. self.subTest(kwargs=kwargs),
  414. self.assertRaisesMessage(ValueError, message),
  415. ):
  416. Contact.objects.get(**kwargs)
  417. def test_tuple_lookup_names(self):
  418. test_cases = (
  419. (TupleExact, "exact"),
  420. (TupleGreaterThan, "gt"),
  421. (TupleGreaterThanOrEqual, "gte"),
  422. (TupleLessThan, "lt"),
  423. (TupleLessThanOrEqual, "lte"),
  424. (TupleIn, "in"),
  425. (TupleIsNull, "isnull"),
  426. )
  427. for lookup_class, lookup_name in test_cases:
  428. with self.subTest(lookup_name):
  429. self.assertEqual(lookup_class.lookup_name, lookup_name)
  430. def test_tuple_lookup_rhs_must_be_tuple_or_list(self):
  431. test_cases = itertools.product(
  432. (
  433. TupleExact,
  434. TupleGreaterThan,
  435. TupleGreaterThanOrEqual,
  436. TupleLessThan,
  437. TupleLessThanOrEqual,
  438. TupleIn,
  439. ),
  440. (
  441. 0,
  442. 1,
  443. None,
  444. True,
  445. False,
  446. {"foo": "bar"},
  447. ),
  448. )
  449. for lookup_cls, rhs in test_cases:
  450. lookup_name = lookup_cls.lookup_name
  451. with self.subTest(lookup_name=lookup_name, rhs=rhs):
  452. with self.assertRaisesMessage(
  453. ValueError,
  454. f"'{lookup_name}' lookup of ('customer_code', 'company_code') "
  455. "must be a tuple or a list",
  456. ):
  457. lookup_cls((F("customer_code"), F("company_code")), rhs)
  458. def test_tuple_lookup_rhs_must_have_2_elements(self):
  459. test_cases = itertools.product(
  460. (
  461. TupleExact,
  462. TupleGreaterThan,
  463. TupleGreaterThanOrEqual,
  464. TupleLessThan,
  465. TupleLessThanOrEqual,
  466. ),
  467. (
  468. [],
  469. [1],
  470. [1, 2, 3],
  471. (),
  472. (1,),
  473. (1, 2, 3),
  474. ),
  475. )
  476. for lookup_cls, rhs in test_cases:
  477. lookup_name = lookup_cls.lookup_name
  478. with self.subTest(lookup_name=lookup_name, rhs=rhs):
  479. with self.assertRaisesMessage(
  480. ValueError,
  481. f"'{lookup_name}' lookup of ('customer_code', 'company_code') "
  482. "must have 2 elements",
  483. ):
  484. lookup_cls((F("customer_code"), F("company_code")), rhs)