tests.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
import time
import unittest
from datetime import date, datetime

from django.core.exceptions import FieldError
from django.db import connection, models
from django.test import SimpleTestCase, TestCase, override_settings
from django.test.utils import register_lookup
from django.utils import timezone

from .models import Article, Author, MySQLUnixTimestamp
  10. class Div3Lookup(models.Lookup):
  11. lookup_name = 'div3'
  12. def as_sql(self, compiler, connection):
  13. lhs, params = self.process_lhs(compiler, connection)
  14. rhs, rhs_params = self.process_rhs(compiler, connection)
  15. params.extend(rhs_params)
  16. return '(%s) %%%% 3 = %s' % (lhs, rhs), params
  17. def as_oracle(self, compiler, connection):
  18. lhs, params = self.process_lhs(compiler, connection)
  19. rhs, rhs_params = self.process_rhs(compiler, connection)
  20. params.extend(rhs_params)
  21. return 'mod(%s, 3) = %s' % (lhs, rhs), params
  22. class Div3Transform(models.Transform):
  23. lookup_name = 'div3'
  24. def as_sql(self, compiler, connection):
  25. lhs, lhs_params = compiler.compile(self.lhs)
  26. return '(%s) %%%% 3' % lhs, lhs_params
  27. def as_oracle(self, compiler, connection, **extra_context):
  28. lhs, lhs_params = compiler.compile(self.lhs)
  29. return 'mod(%s, 3)' % lhs, lhs_params
  30. class Div3BilateralTransform(Div3Transform):
  31. bilateral = True
  32. class Mult3BilateralTransform(models.Transform):
  33. bilateral = True
  34. lookup_name = 'mult3'
  35. def as_sql(self, compiler, connection):
  36. lhs, lhs_params = compiler.compile(self.lhs)
  37. return '3 * (%s)' % lhs, lhs_params
  38. class LastDigitTransform(models.Transform):
  39. lookup_name = 'lastdigit'
  40. def as_sql(self, compiler, connection):
  41. lhs, lhs_params = compiler.compile(self.lhs)
  42. return 'SUBSTR(CAST(%s AS CHAR(2)), 2, 1)' % lhs, lhs_params
  43. class UpperBilateralTransform(models.Transform):
  44. bilateral = True
  45. lookup_name = 'upper'
  46. def as_sql(self, compiler, connection):
  47. lhs, lhs_params = compiler.compile(self.lhs)
  48. return 'UPPER(%s)' % lhs, lhs_params
  49. class YearTransform(models.Transform):
  50. # Use a name that avoids collision with the built-in year lookup.
  51. lookup_name = 'testyear'
  52. def as_sql(self, compiler, connection):
  53. lhs_sql, params = compiler.compile(self.lhs)
  54. return connection.ops.date_extract_sql('year', lhs_sql), params
  55. @property
  56. def output_field(self):
  57. return models.IntegerField()
  58. @YearTransform.register_lookup
  59. class YearExact(models.lookups.Lookup):
  60. lookup_name = 'exact'
  61. def as_sql(self, compiler, connection):
  62. # We will need to skip the extract part, and instead go
  63. # directly with the originating field, that is self.lhs.lhs
  64. lhs_sql, lhs_params = self.process_lhs(compiler, connection, self.lhs.lhs)
  65. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  66. # Note that we must be careful so that we have params in the
  67. # same order as we have the parts in the SQL.
  68. params = lhs_params + rhs_params + lhs_params + rhs_params
  69. # We use PostgreSQL specific SQL here. Note that we must do the
  70. # conversions in SQL instead of in Python to support F() references.
  71. return ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
  72. "AND %(lhs)s <= (%(rhs)s || '-12-31')::date" %
  73. {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
  74. @YearTransform.register_lookup
  75. class YearLte(models.lookups.LessThanOrEqual):
  76. """
  77. The purpose of this lookup is to efficiently compare the year of the field.
  78. """
  79. def as_sql(self, compiler, connection):
  80. # Skip the YearTransform above us (no possibility for efficient
  81. # lookup otherwise).
  82. real_lhs = self.lhs.lhs
  83. lhs_sql, params = self.process_lhs(compiler, connection, real_lhs)
  84. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  85. params.extend(rhs_params)
  86. # Build SQL where the integer year is concatenated with last month
  87. # and day, then convert that to date. (We try to have SQL like:
  88. # WHERE somecol <= '2013-12-31')
  89. # but also make it work if the rhs_sql is field reference.
  90. return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params
  91. class Exactly(models.lookups.Exact):
  92. """
  93. This lookup is used to test lookup registration.
  94. """
  95. lookup_name = 'exactly'
  96. def get_rhs_op(self, connection, rhs):
  97. return connection.operators['exact'] % rhs
  98. class SQLFuncMixin:
  99. def as_sql(self, compiler, connection):
  100. return '%s()' % self.name, []
  101. @property
  102. def output_field(self):
  103. return CustomField()
  104. class SQLFuncLookup(SQLFuncMixin, models.Lookup):
  105. def __init__(self, name, *args, **kwargs):
  106. super().__init__(*args, **kwargs)
  107. self.name = name
  108. class SQLFuncTransform(SQLFuncMixin, models.Transform):
  109. def __init__(self, name, *args, **kwargs):
  110. super().__init__(*args, **kwargs)
  111. self.name = name
  112. class SQLFuncFactory:
  113. def __init__(self, key, name):
  114. self.key = key
  115. self.name = name
  116. def __call__(self, *args, **kwargs):
  117. if self.key == 'lookupfunc':
  118. return SQLFuncLookup(self.name, *args, **kwargs)
  119. return SQLFuncTransform(self.name, *args, **kwargs)
  120. class CustomField(models.TextField):
  121. def get_lookup(self, lookup_name):
  122. if lookup_name.startswith('lookupfunc_'):
  123. key, name = lookup_name.split('_', 1)
  124. return SQLFuncFactory(key, name)
  125. return super().get_lookup(lookup_name)
  126. def get_transform(self, lookup_name):
  127. if lookup_name.startswith('transformfunc_'):
  128. key, name = lookup_name.split('_', 1)
  129. return SQLFuncFactory(key, name)
  130. return super().get_transform(lookup_name)
  131. class CustomModel(models.Model):
  132. field = CustomField()
  133. # We will register this class temporarily in the test method.
  134. class InMonth(models.lookups.Lookup):
  135. """
  136. InMonth matches if the column's month is the same as value's month.
  137. """
  138. lookup_name = 'inmonth'
  139. def as_sql(self, compiler, connection):
  140. lhs, lhs_params = self.process_lhs(compiler, connection)
  141. rhs, rhs_params = self.process_rhs(compiler, connection)
  142. # We need to be careful so that we get the params in right
  143. # places.
  144. params = lhs_params + rhs_params + lhs_params + rhs_params
  145. return ("%s >= date_trunc('month', %s) and "
  146. "%s < date_trunc('month', %s) + interval '1 months'" %
  147. (lhs, rhs, lhs, rhs), params)
  148. class DateTimeTransform(models.Transform):
  149. lookup_name = 'as_datetime'
  150. @property
  151. def output_field(self):
  152. return models.DateTimeField()
  153. def as_sql(self, compiler, connection):
  154. lhs, params = compiler.compile(self.lhs)
  155. return 'from_unixtime({})'.format(lhs), params
  156. class LookupTests(TestCase):
  157. def test_custom_name_lookup(self):
  158. a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
  159. Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
  160. with register_lookup(models.DateField, YearTransform), \
  161. register_lookup(models.DateField, YearTransform, lookup_name='justtheyear'), \
  162. register_lookup(YearTransform, Exactly), \
  163. register_lookup(YearTransform, Exactly, lookup_name='isactually'):
  164. qs1 = Author.objects.filter(birthdate__testyear__exactly=1981)
  165. qs2 = Author.objects.filter(birthdate__justtheyear__isactually=1981)
  166. self.assertSequenceEqual(qs1, [a1])
  167. self.assertSequenceEqual(qs2, [a1])
  168. def test_custom_exact_lookup_none_rhs(self):
  169. """
  170. __exact=None is transformed to __isnull=True if a custom lookup class
  171. with lookup_name != 'exact' is registered as the `exact` lookup.
  172. """
  173. field = Author._meta.get_field('birthdate')
  174. OldExactLookup = field.get_lookup('exact')
  175. author = Author.objects.create(name='author', birthdate=None)
  176. try:
  177. field.register_lookup(Exactly, 'exact')
  178. self.assertEqual(Author.objects.get(birthdate__exact=None), author)
  179. finally:
  180. field.register_lookup(OldExactLookup, 'exact')
  181. def test_basic_lookup(self):
  182. a1 = Author.objects.create(name='a1', age=1)
  183. a2 = Author.objects.create(name='a2', age=2)
  184. a3 = Author.objects.create(name='a3', age=3)
  185. a4 = Author.objects.create(name='a4', age=4)
  186. with register_lookup(models.IntegerField, Div3Lookup):
  187. self.assertSequenceEqual(Author.objects.filter(age__div3=0), [a3])
  188. self.assertSequenceEqual(Author.objects.filter(age__div3=1).order_by('age'), [a1, a4])
  189. self.assertSequenceEqual(Author.objects.filter(age__div3=2), [a2])
  190. self.assertSequenceEqual(Author.objects.filter(age__div3=3), [])
  191. @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
  192. def test_birthdate_month(self):
  193. a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
  194. a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
  195. a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
  196. a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
  197. with register_lookup(models.DateField, InMonth):
  198. self.assertSequenceEqual(Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)), [a3])
  199. self.assertSequenceEqual(Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)), [a2])
  200. self.assertSequenceEqual(Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)), [a1])
  201. self.assertSequenceEqual(Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)), [a4])
  202. self.assertSequenceEqual(Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)), [])
  203. def test_div3_extract(self):
  204. with register_lookup(models.IntegerField, Div3Transform):
  205. a1 = Author.objects.create(name='a1', age=1)
  206. a2 = Author.objects.create(name='a2', age=2)
  207. a3 = Author.objects.create(name='a3', age=3)
  208. a4 = Author.objects.create(name='a4', age=4)
  209. baseqs = Author.objects.order_by('name')
  210. self.assertSequenceEqual(baseqs.filter(age__div3=2), [a2])
  211. self.assertSequenceEqual(baseqs.filter(age__div3__lte=3), [a1, a2, a3, a4])
  212. self.assertSequenceEqual(baseqs.filter(age__div3__in=[0, 2]), [a2, a3])
  213. self.assertSequenceEqual(baseqs.filter(age__div3__in=[2, 4]), [a2])
  214. self.assertSequenceEqual(baseqs.filter(age__div3__gte=3), [])
  215. self.assertSequenceEqual(baseqs.filter(age__div3__range=(1, 2)), [a1, a2, a4])
  216. def test_foreignobject_lookup_registration(self):
  217. field = Article._meta.get_field('author')
  218. with register_lookup(models.ForeignObject, Exactly):
  219. self.assertIs(field.get_lookup('exactly'), Exactly)
  220. # ForeignObject should ignore regular Field lookups
  221. with register_lookup(models.Field, Exactly):
  222. self.assertIsNone(field.get_lookup('exactly'))
  223. def test_lookups_caching(self):
  224. field = Article._meta.get_field('author')
  225. # clear and re-cache
  226. field.get_lookups.cache_clear()
  227. self.assertNotIn('exactly', field.get_lookups())
  228. # registration should bust the cache
  229. with register_lookup(models.ForeignObject, Exactly):
  230. # getting the lookups again should re-cache
  231. self.assertIn('exactly', field.get_lookups())
  232. class BilateralTransformTests(TestCase):
  233. def test_bilateral_upper(self):
  234. with register_lookup(models.CharField, UpperBilateralTransform):
  235. author1 = Author.objects.create(name='Doe')
  236. author2 = Author.objects.create(name='doe')
  237. author3 = Author.objects.create(name='Foo')
  238. self.assertCountEqual(
  239. Author.objects.filter(name__upper='doe'),
  240. [author1, author2],
  241. )
  242. self.assertSequenceEqual(
  243. Author.objects.filter(name__upper__contains='f'),
  244. [author3],
  245. )
  246. def test_bilateral_inner_qs(self):
  247. with register_lookup(models.CharField, UpperBilateralTransform):
  248. msg = 'Bilateral transformations on nested querysets are not implemented.'
  249. with self.assertRaisesMessage(NotImplementedError, msg):
  250. Author.objects.filter(name__upper__in=Author.objects.values_list('name'))
  251. def test_bilateral_multi_value(self):
  252. with register_lookup(models.CharField, UpperBilateralTransform):
  253. Author.objects.bulk_create([
  254. Author(name='Foo'),
  255. Author(name='Bar'),
  256. Author(name='Ray'),
  257. ])
  258. self.assertQuerysetEqual(
  259. Author.objects.filter(name__upper__in=['foo', 'bar', 'doe']).order_by('name'),
  260. ['Bar', 'Foo'],
  261. lambda a: a.name
  262. )
  263. def test_div3_bilateral_extract(self):
  264. with register_lookup(models.IntegerField, Div3BilateralTransform):
  265. a1 = Author.objects.create(name='a1', age=1)
  266. a2 = Author.objects.create(name='a2', age=2)
  267. a3 = Author.objects.create(name='a3', age=3)
  268. a4 = Author.objects.create(name='a4', age=4)
  269. baseqs = Author.objects.order_by('name')
  270. self.assertSequenceEqual(baseqs.filter(age__div3=2), [a2])
  271. self.assertSequenceEqual(baseqs.filter(age__div3__lte=3), [a3])
  272. self.assertSequenceEqual(baseqs.filter(age__div3__in=[0, 2]), [a2, a3])
  273. self.assertSequenceEqual(baseqs.filter(age__div3__in=[2, 4]), [a1, a2, a4])
  274. self.assertSequenceEqual(baseqs.filter(age__div3__gte=3), [a1, a2, a3, a4])
  275. self.assertSequenceEqual(baseqs.filter(age__div3__range=(1, 2)), [a1, a2, a4])
  276. def test_bilateral_order(self):
  277. with register_lookup(models.IntegerField, Mult3BilateralTransform, Div3BilateralTransform):
  278. a1 = Author.objects.create(name='a1', age=1)
  279. a2 = Author.objects.create(name='a2', age=2)
  280. a3 = Author.objects.create(name='a3', age=3)
  281. a4 = Author.objects.create(name='a4', age=4)
  282. baseqs = Author.objects.order_by('name')
  283. # mult3__div3 always leads to 0
  284. self.assertSequenceEqual(baseqs.filter(age__mult3__div3=42), [a1, a2, a3, a4])
  285. self.assertSequenceEqual(baseqs.filter(age__div3__mult3=42), [a3])
  286. def test_transform_order_by(self):
  287. with register_lookup(models.IntegerField, LastDigitTransform):
  288. a1 = Author.objects.create(name='a1', age=11)
  289. a2 = Author.objects.create(name='a2', age=23)
  290. a3 = Author.objects.create(name='a3', age=32)
  291. a4 = Author.objects.create(name='a4', age=40)
  292. qs = Author.objects.order_by('age__lastdigit')
  293. self.assertSequenceEqual(qs, [a4, a1, a3, a2])
  294. def test_bilateral_fexpr(self):
  295. with register_lookup(models.IntegerField, Mult3BilateralTransform):
  296. a1 = Author.objects.create(name='a1', age=1, average_rating=3.2)
  297. a2 = Author.objects.create(name='a2', age=2, average_rating=0.5)
  298. a3 = Author.objects.create(name='a3', age=3, average_rating=1.5)
  299. a4 = Author.objects.create(name='a4', age=4)
  300. baseqs = Author.objects.order_by('name')
  301. self.assertSequenceEqual(baseqs.filter(age__mult3=models.F('age')), [a1, a2, a3, a4])
  302. # Same as age >= average_rating
  303. self.assertSequenceEqual(baseqs.filter(age__mult3__gte=models.F('average_rating')), [a2, a3])
  304. @override_settings(USE_TZ=True)
  305. class DateTimeLookupTests(TestCase):
  306. @unittest.skipUnless(connection.vendor == 'mysql', "MySQL specific SQL used")
  307. def test_datetime_output_field(self):
  308. with register_lookup(models.PositiveIntegerField, DateTimeTransform):
  309. ut = MySQLUnixTimestamp.objects.create(timestamp=time.time())
  310. y2k = timezone.make_aware(datetime(2000, 1, 1))
  311. self.assertSequenceEqual(MySQLUnixTimestamp.objects.filter(timestamp__as_datetime__gt=y2k), [ut])
  312. class YearLteTests(TestCase):
  313. @classmethod
  314. def setUpTestData(cls):
  315. cls.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
  316. cls.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
  317. cls.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
  318. cls.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
  319. def setUp(self):
  320. models.DateField.register_lookup(YearTransform)
  321. def tearDown(self):
  322. models.DateField._unregister_lookup(YearTransform)
  323. @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
  324. def test_year_lte(self):
  325. baseqs = Author.objects.order_by('name')
  326. self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lte=2012), [self.a1, self.a2, self.a3, self.a4])
  327. self.assertSequenceEqual(baseqs.filter(birthdate__testyear=2012), [self.a2, self.a3, self.a4])
  328. self.assertNotIn('BETWEEN', str(baseqs.filter(birthdate__testyear=2012).query))
  329. self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lte=2011), [self.a1])
  330. # The non-optimized version works, too.
  331. self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lt=2012), [self.a1])
  332. @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
  333. def test_year_lte_fexpr(self):
  334. self.a2.age = 2011
  335. self.a2.save()
  336. self.a3.age = 2012
  337. self.a3.save()
  338. self.a4.age = 2013
  339. self.a4.save()
  340. baseqs = Author.objects.order_by('name')
  341. self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lte=models.F('age')), [self.a3, self.a4])
  342. self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lt=models.F('age')), [self.a4])
  343. def test_year_lte_sql(self):
  344. # This test will just check the generated SQL for __lte. This
  345. # doesn't require running on PostgreSQL and spots the most likely
  346. # error - not running YearLte SQL at all.
  347. baseqs = Author.objects.order_by('name')
  348. self.assertIn(
  349. '<= (2011 || ', str(baseqs.filter(birthdate__testyear__lte=2011).query))
  350. self.assertIn(
  351. '-12-31', str(baseqs.filter(birthdate__testyear__lte=2011).query))
  352. def test_postgres_year_exact(self):
  353. baseqs = Author.objects.order_by('name')
  354. self.assertIn(
  355. '= (2011 || ', str(baseqs.filter(birthdate__testyear=2011).query))
  356. self.assertIn(
  357. '-12-31', str(baseqs.filter(birthdate__testyear=2011).query))
  358. def test_custom_implementation_year_exact(self):
  359. try:
  360. # Two ways to add a customized implementation for different backends:
  361. # First is MonkeyPatch of the class.
  362. def as_custom_sql(self, compiler, connection):
  363. lhs_sql, lhs_params = self.process_lhs(compiler, connection, self.lhs.lhs)
  364. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  365. params = lhs_params + rhs_params + lhs_params + rhs_params
  366. return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
  367. "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
  368. {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
  369. setattr(YearExact, 'as_' + connection.vendor, as_custom_sql)
  370. self.assertIn(
  371. 'concat(',
  372. str(Author.objects.filter(birthdate__testyear=2012).query))
  373. finally:
  374. delattr(YearExact, 'as_' + connection.vendor)
  375. try:
  376. # The other way is to subclass the original lookup and register the subclassed
  377. # lookup instead of the original.
  378. class CustomYearExact(YearExact):
  379. # This method should be named "as_mysql" for MySQL, "as_postgresql" for postgres
  380. # and so on, but as we don't know which DB we are running on, we need to use
  381. # setattr.
  382. def as_custom_sql(self, compiler, connection):
  383. lhs_sql, lhs_params = self.process_lhs(compiler, connection, self.lhs.lhs)
  384. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  385. params = lhs_params + rhs_params + lhs_params + rhs_params
  386. return ("%(lhs)s >= str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
  387. "AND %(lhs)s <= str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
  388. {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
  389. setattr(CustomYearExact, 'as_' + connection.vendor, CustomYearExact.as_custom_sql)
  390. YearTransform.register_lookup(CustomYearExact)
  391. self.assertIn(
  392. 'CONCAT(',
  393. str(Author.objects.filter(birthdate__testyear=2012).query))
  394. finally:
  395. YearTransform._unregister_lookup(CustomYearExact)
  396. YearTransform.register_lookup(YearExact)
  397. class TrackCallsYearTransform(YearTransform):
  398. # Use a name that avoids collision with the built-in year lookup.
  399. lookup_name = 'testyear'
  400. call_order = []
  401. def as_sql(self, compiler, connection):
  402. lhs_sql, params = compiler.compile(self.lhs)
  403. return connection.ops.date_extract_sql('year', lhs_sql), params
  404. @property
  405. def output_field(self):
  406. return models.IntegerField()
  407. def get_lookup(self, lookup_name):
  408. self.call_order.append('lookup')
  409. return super().get_lookup(lookup_name)
  410. def get_transform(self, lookup_name):
  411. self.call_order.append('transform')
  412. return super().get_transform(lookup_name)
  413. class LookupTransformCallOrderTests(SimpleTestCase):
  414. def test_call_order(self):
  415. with register_lookup(models.DateField, TrackCallsYearTransform):
  416. # junk lookup - tries lookup, then transform, then fails
  417. msg = "Unsupported lookup 'junk' for IntegerField or join on the field not permitted."
  418. with self.assertRaisesMessage(FieldError, msg):
  419. Author.objects.filter(birthdate__testyear__junk=2012)
  420. self.assertEqual(TrackCallsYearTransform.call_order,
  421. ['lookup', 'transform'])
  422. TrackCallsYearTransform.call_order = []
  423. # junk transform - tries transform only, then fails
  424. with self.assertRaisesMessage(FieldError, msg):
  425. Author.objects.filter(birthdate__testyear__junk__more_junk=2012)
  426. self.assertEqual(TrackCallsYearTransform.call_order,
  427. ['transform'])
  428. TrackCallsYearTransform.call_order = []
  429. # Just getting the year (implied __exact) - lookup only
  430. Author.objects.filter(birthdate__testyear=2012)
  431. self.assertEqual(TrackCallsYearTransform.call_order,
  432. ['lookup'])
  433. TrackCallsYearTransform.call_order = []
  434. # Just getting the year (explicit __exact) - lookup only
  435. Author.objects.filter(birthdate__testyear__exact=2012)
  436. self.assertEqual(TrackCallsYearTransform.call_order,
  437. ['lookup'])
  438. class CustomisedMethodsTests(SimpleTestCase):
  439. def test_overridden_get_lookup(self):
  440. q = CustomModel.objects.filter(field__lookupfunc_monkeys=3)
  441. self.assertIn('monkeys()', str(q.query))
  442. def test_overridden_get_transform(self):
  443. q = CustomModel.objects.filter(field__transformfunc_banana=3)
  444. self.assertIn('banana()', str(q.query))
  445. def test_overridden_get_lookup_chain(self):
  446. q = CustomModel.objects.filter(field__transformfunc_banana__lookupfunc_elephants=3)
  447. self.assertIn('elephants()', str(q.query))
  448. def test_overridden_get_transform_chain(self):
  449. q = CustomModel.objects.filter(field__transformfunc_banana__transformfunc_pear=3)
  450. self.assertIn('pear()', str(q.query))
  451. class SubqueryTransformTests(TestCase):
  452. def test_subquery_usage(self):
  453. with register_lookup(models.IntegerField, Div3Transform):
  454. Author.objects.create(name='a1', age=1)
  455. a2 = Author.objects.create(name='a2', age=2)
  456. Author.objects.create(name='a3', age=3)
  457. Author.objects.create(name='a4', age=4)
  458. qs = Author.objects.order_by('name').filter(id__in=Author.objects.filter(age__div3=2))
  459. self.assertSequenceEqual(qs, [a2])