tests.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738
  1. import time
  2. import unittest
  3. from datetime import date, datetime
  4. from django.core.exceptions import FieldError
  5. from django.db import connection, models
  6. from django.db.models.fields.related_lookups import RelatedGreaterThan
  7. from django.db.models.lookups import EndsWith, StartsWith
  8. from django.test import SimpleTestCase, TestCase, override_settings
  9. from django.test.utils import register_lookup
  10. from django.utils import timezone
  11. from .models import Article, Author, MySQLUnixTimestamp
  12. class Div3Lookup(models.Lookup):
  13. lookup_name = "div3"
  14. def as_sql(self, compiler, connection):
  15. lhs, params = self.process_lhs(compiler, connection)
  16. rhs, rhs_params = self.process_rhs(compiler, connection)
  17. params.extend(rhs_params)
  18. return "(%s) %%%% 3 = %s" % (lhs, rhs), params
  19. def as_oracle(self, compiler, connection):
  20. lhs, params = self.process_lhs(compiler, connection)
  21. rhs, rhs_params = self.process_rhs(compiler, connection)
  22. params.extend(rhs_params)
  23. return "mod(%s, 3) = %s" % (lhs, rhs), params
  24. class Div3Transform(models.Transform):
  25. lookup_name = "div3"
  26. def as_sql(self, compiler, connection):
  27. lhs, lhs_params = compiler.compile(self.lhs)
  28. return "(%s) %%%% 3" % lhs, lhs_params
  29. def as_oracle(self, compiler, connection, **extra_context):
  30. lhs, lhs_params = compiler.compile(self.lhs)
  31. return "mod(%s, 3)" % lhs, lhs_params
  32. class Div3BilateralTransform(Div3Transform):
  33. bilateral = True
  34. class Mult3BilateralTransform(models.Transform):
  35. bilateral = True
  36. lookup_name = "mult3"
  37. def as_sql(self, compiler, connection):
  38. lhs, lhs_params = compiler.compile(self.lhs)
  39. return "3 * (%s)" % lhs, lhs_params
  40. class LastDigitTransform(models.Transform):
  41. lookup_name = "lastdigit"
  42. def as_sql(self, compiler, connection):
  43. lhs, lhs_params = compiler.compile(self.lhs)
  44. return "SUBSTR(CAST(%s AS CHAR(2)), 2, 1)" % lhs, lhs_params
  45. class UpperBilateralTransform(models.Transform):
  46. bilateral = True
  47. lookup_name = "upper"
  48. def as_sql(self, compiler, connection):
  49. lhs, lhs_params = compiler.compile(self.lhs)
  50. return "UPPER(%s)" % lhs, lhs_params
  51. class YearTransform(models.Transform):
  52. # Use a name that avoids collision with the built-in year lookup.
  53. lookup_name = "testyear"
  54. def as_sql(self, compiler, connection):
  55. lhs_sql, params = compiler.compile(self.lhs)
  56. return connection.ops.date_extract_sql("year", lhs_sql, params)
  57. @property
  58. def output_field(self):
  59. return models.IntegerField()
  60. @YearTransform.register_lookup
  61. class YearExact(models.lookups.Lookup):
  62. lookup_name = "exact"
  63. def as_sql(self, compiler, connection):
  64. # We will need to skip the extract part, and instead go
  65. # directly with the originating field, that is self.lhs.lhs
  66. lhs_sql, lhs_params = self.process_lhs(compiler, connection, self.lhs.lhs)
  67. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  68. # Note that we must be careful so that we have params in the
  69. # same order as we have the parts in the SQL.
  70. params = lhs_params + rhs_params + lhs_params + rhs_params
  71. # We use PostgreSQL specific SQL here. Note that we must do the
  72. # conversions in SQL instead of in Python to support F() references.
  73. return (
  74. "%(lhs)s >= (%(rhs)s || '-01-01')::date "
  75. "AND %(lhs)s <= (%(rhs)s || '-12-31')::date"
  76. % {"lhs": lhs_sql, "rhs": rhs_sql},
  77. params,
  78. )
  79. @YearTransform.register_lookup
  80. class YearLte(models.lookups.LessThanOrEqual):
  81. """
  82. The purpose of this lookup is to efficiently compare the year of the field.
  83. """
  84. def as_sql(self, compiler, connection):
  85. # Skip the YearTransform above us (no possibility for efficient
  86. # lookup otherwise).
  87. real_lhs = self.lhs.lhs
  88. lhs_sql, params = self.process_lhs(compiler, connection, real_lhs)
  89. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  90. params.extend(rhs_params)
  91. # Build SQL where the integer year is concatenated with last month
  92. # and day, then convert that to date. (We try to have SQL like:
  93. # WHERE somecol <= '2013-12-31')
  94. # but also make it work if the rhs_sql is field reference.
  95. return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params
  96. class Exactly(models.lookups.Exact):
  97. """
  98. This lookup is used to test lookup registration.
  99. """
  100. lookup_name = "exactly"
  101. def get_rhs_op(self, connection, rhs):
  102. return connection.operators["exact"] % rhs
  103. class SQLFuncMixin:
  104. def as_sql(self, compiler, connection):
  105. return "%s()" % self.name, []
  106. @property
  107. def output_field(self):
  108. return CustomField()
  109. class SQLFuncLookup(SQLFuncMixin, models.Lookup):
  110. def __init__(self, name, *args, **kwargs):
  111. super().__init__(*args, **kwargs)
  112. self.name = name
  113. class SQLFuncTransform(SQLFuncMixin, models.Transform):
  114. def __init__(self, name, *args, **kwargs):
  115. super().__init__(*args, **kwargs)
  116. self.name = name
  117. class SQLFuncFactory:
  118. def __init__(self, key, name):
  119. self.key = key
  120. self.name = name
  121. def __call__(self, *args, **kwargs):
  122. if self.key == "lookupfunc":
  123. return SQLFuncLookup(self.name, *args, **kwargs)
  124. return SQLFuncTransform(self.name, *args, **kwargs)
  125. class CustomField(models.TextField):
  126. def get_lookup(self, lookup_name):
  127. if lookup_name.startswith("lookupfunc_"):
  128. key, name = lookup_name.split("_", 1)
  129. return SQLFuncFactory(key, name)
  130. return super().get_lookup(lookup_name)
  131. def get_transform(self, lookup_name):
  132. if lookup_name.startswith("transformfunc_"):
  133. key, name = lookup_name.split("_", 1)
  134. return SQLFuncFactory(key, name)
  135. return super().get_transform(lookup_name)
  136. class CustomModel(models.Model):
  137. field = CustomField()
  138. # We will register this class temporarily in the test method.
  139. class InMonth(models.lookups.Lookup):
  140. """
  141. InMonth matches if the column's month is the same as value's month.
  142. """
  143. lookup_name = "inmonth"
  144. def as_sql(self, compiler, connection):
  145. lhs, lhs_params = self.process_lhs(compiler, connection)
  146. rhs, rhs_params = self.process_rhs(compiler, connection)
  147. # We need to be careful so that we get the params in right
  148. # places.
  149. params = lhs_params + rhs_params + lhs_params + rhs_params
  150. return (
  151. "%s >= date_trunc('month', %s) and "
  152. "%s < date_trunc('month', %s) + interval '1 months'" % (lhs, rhs, lhs, rhs),
  153. params,
  154. )
  155. class DateTimeTransform(models.Transform):
  156. lookup_name = "as_datetime"
  157. @property
  158. def output_field(self):
  159. return models.DateTimeField()
  160. def as_sql(self, compiler, connection):
  161. lhs, params = compiler.compile(self.lhs)
  162. return "from_unixtime({})".format(lhs), params
  163. class CustomStartsWith(StartsWith):
  164. lookup_name = "sw"
  165. class CustomEndsWith(EndsWith):
  166. lookup_name = "ew"
  167. class RelatedMoreThan(RelatedGreaterThan):
  168. lookup_name = "rmt"
  169. class LookupTests(TestCase):
  170. def test_custom_name_lookup(self):
  171. a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
  172. Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
  173. with register_lookup(models.DateField, YearTransform), register_lookup(
  174. models.DateField, YearTransform, lookup_name="justtheyear"
  175. ), register_lookup(YearTransform, Exactly), register_lookup(
  176. YearTransform, Exactly, lookup_name="isactually"
  177. ):
  178. qs1 = Author.objects.filter(birthdate__testyear__exactly=1981)
  179. qs2 = Author.objects.filter(birthdate__justtheyear__isactually=1981)
  180. self.assertSequenceEqual(qs1, [a1])
  181. self.assertSequenceEqual(qs2, [a1])
  182. def test_custom_exact_lookup_none_rhs(self):
  183. """
  184. __exact=None is transformed to __isnull=True if a custom lookup class
  185. with lookup_name != 'exact' is registered as the `exact` lookup.
  186. """
  187. field = Author._meta.get_field("birthdate")
  188. OldExactLookup = field.get_lookup("exact")
  189. author = Author.objects.create(name="author", birthdate=None)
  190. try:
  191. field.register_lookup(Exactly, "exact")
  192. self.assertEqual(Author.objects.get(birthdate__exact=None), author)
  193. finally:
  194. field.register_lookup(OldExactLookup, "exact")
  195. def test_basic_lookup(self):
  196. a1 = Author.objects.create(name="a1", age=1)
  197. a2 = Author.objects.create(name="a2", age=2)
  198. a3 = Author.objects.create(name="a3", age=3)
  199. a4 = Author.objects.create(name="a4", age=4)
  200. with register_lookup(models.IntegerField, Div3Lookup):
  201. self.assertSequenceEqual(Author.objects.filter(age__div3=0), [a3])
  202. self.assertSequenceEqual(
  203. Author.objects.filter(age__div3=1).order_by("age"), [a1, a4]
  204. )
  205. self.assertSequenceEqual(Author.objects.filter(age__div3=2), [a2])
  206. self.assertSequenceEqual(Author.objects.filter(age__div3=3), [])
  207. @unittest.skipUnless(
  208. connection.vendor == "postgresql", "PostgreSQL specific SQL used"
  209. )
  210. def test_birthdate_month(self):
  211. a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
  212. a2 = Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
  213. a3 = Author.objects.create(name="a3", birthdate=date(2012, 1, 31))
  214. a4 = Author.objects.create(name="a4", birthdate=date(2012, 3, 1))
  215. with register_lookup(models.DateField, InMonth):
  216. self.assertSequenceEqual(
  217. Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)), [a3]
  218. )
  219. self.assertSequenceEqual(
  220. Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)), [a2]
  221. )
  222. self.assertSequenceEqual(
  223. Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)), [a1]
  224. )
  225. self.assertSequenceEqual(
  226. Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)), [a4]
  227. )
  228. self.assertSequenceEqual(
  229. Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)), []
  230. )
  231. def test_div3_extract(self):
  232. with register_lookup(models.IntegerField, Div3Transform):
  233. a1 = Author.objects.create(name="a1", age=1)
  234. a2 = Author.objects.create(name="a2", age=2)
  235. a3 = Author.objects.create(name="a3", age=3)
  236. a4 = Author.objects.create(name="a4", age=4)
  237. baseqs = Author.objects.order_by("name")
  238. self.assertSequenceEqual(baseqs.filter(age__div3=2), [a2])
  239. self.assertSequenceEqual(baseqs.filter(age__div3__lte=3), [a1, a2, a3, a4])
  240. self.assertSequenceEqual(baseqs.filter(age__div3__in=[0, 2]), [a2, a3])
  241. self.assertSequenceEqual(baseqs.filter(age__div3__in=[2, 4]), [a2])
  242. self.assertSequenceEqual(baseqs.filter(age__div3__gte=3), [])
  243. self.assertSequenceEqual(
  244. baseqs.filter(age__div3__range=(1, 2)), [a1, a2, a4]
  245. )
  246. def test_foreignobject_lookup_registration(self):
  247. field = Article._meta.get_field("author")
  248. with register_lookup(models.ForeignObject, Exactly):
  249. self.assertIs(field.get_lookup("exactly"), Exactly)
  250. # ForeignObject should ignore regular Field lookups
  251. with register_lookup(models.Field, Exactly):
  252. self.assertIsNone(field.get_lookup("exactly"))
  253. def test_lookups_caching(self):
  254. field = Article._meta.get_field("author")
  255. # clear and re-cache
  256. field.get_class_lookups.cache_clear()
  257. self.assertNotIn("exactly", field.get_lookups())
  258. # registration should bust the cache
  259. with register_lookup(models.ForeignObject, Exactly):
  260. # getting the lookups again should re-cache
  261. self.assertIn("exactly", field.get_lookups())
  262. # Unregistration should bust the cache.
  263. self.assertNotIn("exactly", field.get_lookups())
  264. class BilateralTransformTests(TestCase):
  265. def test_bilateral_upper(self):
  266. with register_lookup(models.CharField, UpperBilateralTransform):
  267. author1 = Author.objects.create(name="Doe")
  268. author2 = Author.objects.create(name="doe")
  269. author3 = Author.objects.create(name="Foo")
  270. self.assertCountEqual(
  271. Author.objects.filter(name__upper="doe"),
  272. [author1, author2],
  273. )
  274. self.assertSequenceEqual(
  275. Author.objects.filter(name__upper__contains="f"),
  276. [author3],
  277. )
  278. def test_bilateral_inner_qs(self):
  279. with register_lookup(models.CharField, UpperBilateralTransform):
  280. msg = "Bilateral transformations on nested querysets are not implemented."
  281. with self.assertRaisesMessage(NotImplementedError, msg):
  282. Author.objects.filter(
  283. name__upper__in=Author.objects.values_list("name")
  284. )
  285. def test_bilateral_multi_value(self):
  286. with register_lookup(models.CharField, UpperBilateralTransform):
  287. Author.objects.bulk_create(
  288. [
  289. Author(name="Foo"),
  290. Author(name="Bar"),
  291. Author(name="Ray"),
  292. ]
  293. )
  294. self.assertQuerySetEqual(
  295. Author.objects.filter(name__upper__in=["foo", "bar", "doe"]).order_by(
  296. "name"
  297. ),
  298. ["Bar", "Foo"],
  299. lambda a: a.name,
  300. )
  301. def test_div3_bilateral_extract(self):
  302. with register_lookup(models.IntegerField, Div3BilateralTransform):
  303. a1 = Author.objects.create(name="a1", age=1)
  304. a2 = Author.objects.create(name="a2", age=2)
  305. a3 = Author.objects.create(name="a3", age=3)
  306. a4 = Author.objects.create(name="a4", age=4)
  307. baseqs = Author.objects.order_by("name")
  308. self.assertSequenceEqual(baseqs.filter(age__div3=2), [a2])
  309. self.assertSequenceEqual(baseqs.filter(age__div3__lte=3), [a3])
  310. self.assertSequenceEqual(baseqs.filter(age__div3__in=[0, 2]), [a2, a3])
  311. self.assertSequenceEqual(baseqs.filter(age__div3__in=[2, 4]), [a1, a2, a4])
  312. self.assertSequenceEqual(baseqs.filter(age__div3__gte=3), [a1, a2, a3, a4])
  313. self.assertSequenceEqual(
  314. baseqs.filter(age__div3__range=(1, 2)), [a1, a2, a4]
  315. )
  316. def test_bilateral_order(self):
  317. with register_lookup(
  318. models.IntegerField, Mult3BilateralTransform, Div3BilateralTransform
  319. ):
  320. a1 = Author.objects.create(name="a1", age=1)
  321. a2 = Author.objects.create(name="a2", age=2)
  322. a3 = Author.objects.create(name="a3", age=3)
  323. a4 = Author.objects.create(name="a4", age=4)
  324. baseqs = Author.objects.order_by("name")
  325. # mult3__div3 always leads to 0
  326. self.assertSequenceEqual(
  327. baseqs.filter(age__mult3__div3=42), [a1, a2, a3, a4]
  328. )
  329. self.assertSequenceEqual(baseqs.filter(age__div3__mult3=42), [a3])
  330. def test_transform_order_by(self):
  331. with register_lookup(models.IntegerField, LastDigitTransform):
  332. a1 = Author.objects.create(name="a1", age=11)
  333. a2 = Author.objects.create(name="a2", age=23)
  334. a3 = Author.objects.create(name="a3", age=32)
  335. a4 = Author.objects.create(name="a4", age=40)
  336. qs = Author.objects.order_by("age__lastdigit")
  337. self.assertSequenceEqual(qs, [a4, a1, a3, a2])
  338. def test_bilateral_fexpr(self):
  339. with register_lookup(models.IntegerField, Mult3BilateralTransform):
  340. a1 = Author.objects.create(name="a1", age=1, average_rating=3.2)
  341. a2 = Author.objects.create(name="a2", age=2, average_rating=0.5)
  342. a3 = Author.objects.create(name="a3", age=3, average_rating=1.5)
  343. a4 = Author.objects.create(name="a4", age=4)
  344. baseqs = Author.objects.order_by("name")
  345. self.assertSequenceEqual(
  346. baseqs.filter(age__mult3=models.F("age")), [a1, a2, a3, a4]
  347. )
  348. # Same as age >= average_rating
  349. self.assertSequenceEqual(
  350. baseqs.filter(age__mult3__gte=models.F("average_rating")), [a2, a3]
  351. )
  352. @override_settings(USE_TZ=True)
  353. class DateTimeLookupTests(TestCase):
  354. @unittest.skipUnless(connection.vendor == "mysql", "MySQL specific SQL used")
  355. def test_datetime_output_field(self):
  356. with register_lookup(models.PositiveIntegerField, DateTimeTransform):
  357. ut = MySQLUnixTimestamp.objects.create(timestamp=time.time())
  358. y2k = timezone.make_aware(datetime(2000, 1, 1))
  359. self.assertSequenceEqual(
  360. MySQLUnixTimestamp.objects.filter(timestamp__as_datetime__gt=y2k), [ut]
  361. )
  362. class YearLteTests(TestCase):
  363. @classmethod
  364. def setUpTestData(cls):
  365. cls.a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
  366. cls.a2 = Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
  367. cls.a3 = Author.objects.create(name="a3", birthdate=date(2012, 1, 31))
  368. cls.a4 = Author.objects.create(name="a4", birthdate=date(2012, 3, 1))
  369. def setUp(self):
  370. models.DateField.register_lookup(YearTransform)
  371. def tearDown(self):
  372. models.DateField._unregister_lookup(YearTransform)
  373. @unittest.skipUnless(
  374. connection.vendor == "postgresql", "PostgreSQL specific SQL used"
  375. )
  376. def test_year_lte(self):
  377. baseqs = Author.objects.order_by("name")
  378. self.assertSequenceEqual(
  379. baseqs.filter(birthdate__testyear__lte=2012),
  380. [self.a1, self.a2, self.a3, self.a4],
  381. )
  382. self.assertSequenceEqual(
  383. baseqs.filter(birthdate__testyear=2012), [self.a2, self.a3, self.a4]
  384. )
  385. self.assertNotIn("BETWEEN", str(baseqs.filter(birthdate__testyear=2012).query))
  386. self.assertSequenceEqual(
  387. baseqs.filter(birthdate__testyear__lte=2011), [self.a1]
  388. )
  389. # The non-optimized version works, too.
  390. self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lt=2012), [self.a1])
  391. @unittest.skipUnless(
  392. connection.vendor == "postgresql", "PostgreSQL specific SQL used"
  393. )
  394. def test_year_lte_fexpr(self):
  395. self.a2.age = 2011
  396. self.a2.save()
  397. self.a3.age = 2012
  398. self.a3.save()
  399. self.a4.age = 2013
  400. self.a4.save()
  401. baseqs = Author.objects.order_by("name")
  402. self.assertSequenceEqual(
  403. baseqs.filter(birthdate__testyear__lte=models.F("age")), [self.a3, self.a4]
  404. )
  405. self.assertSequenceEqual(
  406. baseqs.filter(birthdate__testyear__lt=models.F("age")), [self.a4]
  407. )
  408. def test_year_lte_sql(self):
  409. # This test will just check the generated SQL for __lte. This
  410. # doesn't require running on PostgreSQL and spots the most likely
  411. # error - not running YearLte SQL at all.
  412. baseqs = Author.objects.order_by("name")
  413. self.assertIn(
  414. "<= (2011 || ", str(baseqs.filter(birthdate__testyear__lte=2011).query)
  415. )
  416. self.assertIn("-12-31", str(baseqs.filter(birthdate__testyear__lte=2011).query))
  417. def test_postgres_year_exact(self):
  418. baseqs = Author.objects.order_by("name")
  419. self.assertIn("= (2011 || ", str(baseqs.filter(birthdate__testyear=2011).query))
  420. self.assertIn("-12-31", str(baseqs.filter(birthdate__testyear=2011).query))
  421. def test_custom_implementation_year_exact(self):
  422. try:
  423. # Two ways to add a customized implementation for different backends:
  424. # First is MonkeyPatch of the class.
  425. def as_custom_sql(self, compiler, connection):
  426. lhs_sql, lhs_params = self.process_lhs(
  427. compiler, connection, self.lhs.lhs
  428. )
  429. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  430. params = lhs_params + rhs_params + lhs_params + rhs_params
  431. return (
  432. "%(lhs)s >= "
  433. "str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
  434. "AND %(lhs)s <= "
  435. "str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')"
  436. % {"lhs": lhs_sql, "rhs": rhs_sql},
  437. params,
  438. )
  439. setattr(YearExact, "as_" + connection.vendor, as_custom_sql)
  440. self.assertIn(
  441. "concat(", str(Author.objects.filter(birthdate__testyear=2012).query)
  442. )
  443. finally:
  444. delattr(YearExact, "as_" + connection.vendor)
  445. try:
  446. # The other way is to subclass the original lookup and register the
  447. # subclassed lookup instead of the original.
  448. class CustomYearExact(YearExact):
  449. # This method should be named "as_mysql" for MySQL,
  450. # "as_postgresql" for postgres and so on, but as we don't know
  451. # which DB we are running on, we need to use setattr.
  452. def as_custom_sql(self, compiler, connection):
  453. lhs_sql, lhs_params = self.process_lhs(
  454. compiler, connection, self.lhs.lhs
  455. )
  456. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  457. params = lhs_params + rhs_params + lhs_params + rhs_params
  458. return (
  459. "%(lhs)s >= "
  460. "str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
  461. "AND %(lhs)s <= "
  462. "str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')"
  463. % {"lhs": lhs_sql, "rhs": rhs_sql},
  464. params,
  465. )
  466. setattr(
  467. CustomYearExact,
  468. "as_" + connection.vendor,
  469. CustomYearExact.as_custom_sql,
  470. )
  471. YearTransform.register_lookup(CustomYearExact)
  472. self.assertIn(
  473. "CONCAT(", str(Author.objects.filter(birthdate__testyear=2012).query)
  474. )
  475. finally:
  476. YearTransform._unregister_lookup(CustomYearExact)
  477. YearTransform.register_lookup(YearExact)
  478. class TrackCallsYearTransform(YearTransform):
  479. # Use a name that avoids collision with the built-in year lookup.
  480. lookup_name = "testyear"
  481. call_order = []
  482. def as_sql(self, compiler, connection):
  483. lhs_sql, params = compiler.compile(self.lhs)
  484. return connection.ops.date_extract_sql("year", lhs_sql), params
  485. @property
  486. def output_field(self):
  487. return models.IntegerField()
  488. def get_lookup(self, lookup_name):
  489. self.call_order.append("lookup")
  490. return super().get_lookup(lookup_name)
  491. def get_transform(self, lookup_name):
  492. self.call_order.append("transform")
  493. return super().get_transform(lookup_name)
  494. class LookupTransformCallOrderTests(SimpleTestCase):
  495. def test_call_order(self):
  496. with register_lookup(models.DateField, TrackCallsYearTransform):
  497. # junk lookup - tries lookup, then transform, then fails
  498. msg = (
  499. "Unsupported lookup 'junk' for IntegerField or join on the field not "
  500. "permitted."
  501. )
  502. with self.assertRaisesMessage(FieldError, msg):
  503. Author.objects.filter(birthdate__testyear__junk=2012)
  504. self.assertEqual(
  505. TrackCallsYearTransform.call_order, ["lookup", "transform"]
  506. )
  507. TrackCallsYearTransform.call_order = []
  508. # junk transform - tries transform only, then fails
  509. with self.assertRaisesMessage(FieldError, msg):
  510. Author.objects.filter(birthdate__testyear__junk__more_junk=2012)
  511. self.assertEqual(TrackCallsYearTransform.call_order, ["transform"])
  512. TrackCallsYearTransform.call_order = []
  513. # Just getting the year (implied __exact) - lookup only
  514. Author.objects.filter(birthdate__testyear=2012)
  515. self.assertEqual(TrackCallsYearTransform.call_order, ["lookup"])
  516. TrackCallsYearTransform.call_order = []
  517. # Just getting the year (explicit __exact) - lookup only
  518. Author.objects.filter(birthdate__testyear__exact=2012)
  519. self.assertEqual(TrackCallsYearTransform.call_order, ["lookup"])
  520. class CustomisedMethodsTests(SimpleTestCase):
  521. def test_overridden_get_lookup(self):
  522. q = CustomModel.objects.filter(field__lookupfunc_monkeys=3)
  523. self.assertIn("monkeys()", str(q.query))
  524. def test_overridden_get_transform(self):
  525. q = CustomModel.objects.filter(field__transformfunc_banana=3)
  526. self.assertIn("banana()", str(q.query))
  527. def test_overridden_get_lookup_chain(self):
  528. q = CustomModel.objects.filter(
  529. field__transformfunc_banana__lookupfunc_elephants=3
  530. )
  531. self.assertIn("elephants()", str(q.query))
  532. def test_overridden_get_transform_chain(self):
  533. q = CustomModel.objects.filter(
  534. field__transformfunc_banana__transformfunc_pear=3
  535. )
  536. self.assertIn("pear()", str(q.query))
  537. class SubqueryTransformTests(TestCase):
  538. def test_subquery_usage(self):
  539. with register_lookup(models.IntegerField, Div3Transform):
  540. Author.objects.create(name="a1", age=1)
  541. a2 = Author.objects.create(name="a2", age=2)
  542. Author.objects.create(name="a3", age=3)
  543. Author.objects.create(name="a4", age=4)
  544. qs = Author.objects.order_by("name").filter(
  545. id__in=Author.objects.filter(age__div3=2)
  546. )
  547. self.assertSequenceEqual(qs, [a2])
  548. class RegisterLookupTests(SimpleTestCase):
  549. def test_class_lookup(self):
  550. author_name = Author._meta.get_field("name")
  551. with register_lookup(models.CharField, CustomStartsWith):
  552. self.assertEqual(author_name.get_lookup("sw"), CustomStartsWith)
  553. self.assertIsNone(author_name.get_lookup("sw"))
  554. def test_instance_lookup(self):
  555. author_name = Author._meta.get_field("name")
  556. author_alias = Author._meta.get_field("alias")
  557. with register_lookup(author_name, CustomStartsWith):
  558. self.assertEqual(author_name.instance_lookups, {"sw": CustomStartsWith})
  559. self.assertEqual(author_name.get_lookup("sw"), CustomStartsWith)
  560. self.assertIsNone(author_alias.get_lookup("sw"))
  561. self.assertIsNone(author_name.get_lookup("sw"))
  562. self.assertEqual(author_name.instance_lookups, {})
  563. self.assertIsNone(author_alias.get_lookup("sw"))
  564. def test_instance_lookup_override_class_lookups(self):
  565. author_name = Author._meta.get_field("name")
  566. author_alias = Author._meta.get_field("alias")
  567. with register_lookup(models.CharField, CustomStartsWith, lookup_name="st_end"):
  568. with register_lookup(author_alias, CustomEndsWith, lookup_name="st_end"):
  569. self.assertEqual(author_name.get_lookup("st_end"), CustomStartsWith)
  570. self.assertEqual(author_alias.get_lookup("st_end"), CustomEndsWith)
  571. self.assertEqual(author_name.get_lookup("st_end"), CustomStartsWith)
  572. self.assertEqual(author_alias.get_lookup("st_end"), CustomStartsWith)
  573. self.assertIsNone(author_name.get_lookup("st_end"))
  574. self.assertIsNone(author_alias.get_lookup("st_end"))
  575. def test_instance_lookup_override(self):
  576. author_name = Author._meta.get_field("name")
  577. with register_lookup(author_name, CustomStartsWith, lookup_name="st_end"):
  578. self.assertEqual(author_name.get_lookup("st_end"), CustomStartsWith)
  579. author_name.register_lookup(CustomEndsWith, lookup_name="st_end")
  580. self.assertEqual(author_name.get_lookup("st_end"), CustomEndsWith)
  581. self.assertIsNone(author_name.get_lookup("st_end"))
  582. def test_lookup_on_transform(self):
  583. transform = Div3Transform
  584. with register_lookup(Div3Transform, CustomStartsWith):
  585. with register_lookup(Div3Transform, CustomEndsWith):
  586. self.assertEqual(
  587. transform.get_lookups(),
  588. {"sw": CustomStartsWith, "ew": CustomEndsWith},
  589. )
  590. self.assertEqual(transform.get_lookups(), {"sw": CustomStartsWith})
  591. self.assertEqual(transform.get_lookups(), {})
  592. def test_transform_on_field(self):
  593. author_name = Author._meta.get_field("name")
  594. author_alias = Author._meta.get_field("alias")
  595. with register_lookup(models.CharField, Div3Transform):
  596. self.assertEqual(author_alias.get_transform("div3"), Div3Transform)
  597. self.assertEqual(author_name.get_transform("div3"), Div3Transform)
  598. with register_lookup(author_alias, Div3Transform):
  599. self.assertEqual(author_alias.get_transform("div3"), Div3Transform)
  600. self.assertIsNone(author_name.get_transform("div3"))
  601. self.assertIsNone(author_alias.get_transform("div3"))
  602. self.assertIsNone(author_name.get_transform("div3"))
  603. def test_related_lookup(self):
  604. article_author = Article._meta.get_field("author")
  605. with register_lookup(models.Field, CustomStartsWith):
  606. self.assertIsNone(article_author.get_lookup("sw"))
  607. with register_lookup(models.ForeignKey, RelatedMoreThan):
  608. self.assertEqual(article_author.get_lookup("rmt"), RelatedMoreThan)
  609. def test_instance_related_lookup(self):
  610. article_author = Article._meta.get_field("author")
  611. with register_lookup(article_author, RelatedMoreThan):
  612. self.assertEqual(article_author.get_lookup("rmt"), RelatedMoreThan)
  613. self.assertIsNone(article_author.get_lookup("rmt"))