tests.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. from __future__ import unicode_literals
  2. from datetime import date
  3. import unittest
  4. from django.core.exceptions import FieldError
  5. from django.db import models
  6. from django.db import connection
  7. from django.test import TestCase
  8. from .models import Author
  9. class Div3Lookup(models.Lookup):
  10. lookup_name = 'div3'
  11. def as_sql(self, qn, connection):
  12. lhs, params = self.process_lhs(qn, connection)
  13. rhs, rhs_params = self.process_rhs(qn, connection)
  14. params.extend(rhs_params)
  15. return '%s %%%% 3 = %s' % (lhs, rhs), params
  16. def as_oracle(self, qn, connection):
  17. lhs, params = self.process_lhs(qn, connection)
  18. rhs, rhs_params = self.process_rhs(qn, connection)
  19. params.extend(rhs_params)
  20. return 'mod(%s, 3) = %s' % (lhs, rhs), params
  21. class Div3Transform(models.Transform):
  22. lookup_name = 'div3'
  23. def as_sql(self, qn, connection):
  24. lhs, lhs_params = qn.compile(self.lhs)
  25. return '%s %%%% 3' % (lhs,), lhs_params
  26. def as_oracle(self, qn, connection):
  27. lhs, lhs_params = qn.compile(self.lhs)
  28. return 'mod(%s, 3)' % lhs, lhs_params
  29. class YearTransform(models.Transform):
  30. lookup_name = 'year'
  31. def as_sql(self, qn, connection):
  32. lhs_sql, params = qn.compile(self.lhs)
  33. return connection.ops.date_extract_sql('year', lhs_sql), params
  34. @property
  35. def output_field(self):
  36. return models.IntegerField()
  37. @YearTransform.register_lookup
  38. class YearExact(models.lookups.Lookup):
  39. lookup_name = 'exact'
  40. def as_sql(self, qn, connection):
  41. # We will need to skip the extract part, and instead go
  42. # directly with the originating field, that is self.lhs.lhs
  43. lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
  44. rhs_sql, rhs_params = self.process_rhs(qn, connection)
  45. # Note that we must be careful so that we have params in the
  46. # same order as we have the parts in the SQL.
  47. params = lhs_params + rhs_params + lhs_params + rhs_params
  48. # We use PostgreSQL specific SQL here. Note that we must do the
  49. # conversions in SQL instead of in Python to support F() references.
  50. return ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
  51. "AND %(lhs)s <= (%(rhs)s || '-12-31')::date" %
  52. {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
  53. @YearTransform.register_lookup
  54. class YearLte(models.lookups.LessThanOrEqual):
  55. """
  56. The purpose of this lookup is to efficiently compare the year of the field.
  57. """
  58. def as_sql(self, qn, connection):
  59. # Skip the YearTransform above us (no possibility for efficient
  60. # lookup otherwise).
  61. real_lhs = self.lhs.lhs
  62. lhs_sql, params = self.process_lhs(qn, connection, real_lhs)
  63. rhs_sql, rhs_params = self.process_rhs(qn, connection)
  64. params.extend(rhs_params)
  65. # Build SQL where the integer year is concatenated with last month
  66. # and day, then convert that to date. (We try to have SQL like:
  67. # WHERE somecol <= '2013-12-31')
  68. # but also make it work if the rhs_sql is field reference.
  69. return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params
  70. class SQLFunc(models.Lookup):
  71. def __init__(self, name, *args, **kwargs):
  72. super(SQLFunc, self).__init__(*args, **kwargs)
  73. self.name = name
  74. def as_sql(self, qn, connection):
  75. return '%s()', [self.name]
  76. @property
  77. def output_field(self):
  78. return CustomField()
  79. class SQLFuncFactory(object):
  80. def __init__(self, name):
  81. self.name = name
  82. def __call__(self, *args, **kwargs):
  83. return SQLFunc(self.name, *args, **kwargs)
  84. class CustomField(models.TextField):
  85. def get_lookup(self, lookup_name):
  86. if lookup_name.startswith('lookupfunc_'):
  87. key, name = lookup_name.split('_', 1)
  88. return SQLFuncFactory(name)
  89. return super(CustomField, self).get_lookup(lookup_name)
  90. def get_transform(self, lookup_name):
  91. if lookup_name.startswith('transformfunc_'):
  92. key, name = lookup_name.split('_', 1)
  93. return SQLFuncFactory(name)
  94. return super(CustomField, self).get_transform(lookup_name)
  95. class CustomModel(models.Model):
  96. field = CustomField()
  97. # We will register this class temporarily in the test method.
  98. class InMonth(models.lookups.Lookup):
  99. """
  100. InMonth matches if the column's month is the same as value's month.
  101. """
  102. lookup_name = 'inmonth'
  103. def as_sql(self, qn, connection):
  104. lhs, lhs_params = self.process_lhs(qn, connection)
  105. rhs, rhs_params = self.process_rhs(qn, connection)
  106. # We need to be careful so that we get the params in right
  107. # places.
  108. params = lhs_params + rhs_params + lhs_params + rhs_params
  109. return ("%s >= date_trunc('month', %s) and "
  110. "%s < date_trunc('month', %s) + interval '1 months'" %
  111. (lhs, rhs, lhs, rhs), params)
  112. class LookupTests(TestCase):
  113. def test_basic_lookup(self):
  114. a1 = Author.objects.create(name='a1', age=1)
  115. a2 = Author.objects.create(name='a2', age=2)
  116. a3 = Author.objects.create(name='a3', age=3)
  117. a4 = Author.objects.create(name='a4', age=4)
  118. models.IntegerField.register_lookup(Div3Lookup)
  119. try:
  120. self.assertQuerysetEqual(
  121. Author.objects.filter(age__div3=0),
  122. [a3], lambda x: x
  123. )
  124. self.assertQuerysetEqual(
  125. Author.objects.filter(age__div3=1).order_by('age'),
  126. [a1, a4], lambda x: x
  127. )
  128. self.assertQuerysetEqual(
  129. Author.objects.filter(age__div3=2),
  130. [a2], lambda x: x
  131. )
  132. self.assertQuerysetEqual(
  133. Author.objects.filter(age__div3=3),
  134. [], lambda x: x
  135. )
  136. finally:
  137. models.IntegerField._unregister_lookup(Div3Lookup)
  138. @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
  139. def test_birthdate_month(self):
  140. a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
  141. a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
  142. a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
  143. a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
  144. models.DateField.register_lookup(InMonth)
  145. try:
  146. self.assertQuerysetEqual(
  147. Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)),
  148. [a3], lambda x: x
  149. )
  150. self.assertQuerysetEqual(
  151. Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)),
  152. [a2], lambda x: x
  153. )
  154. self.assertQuerysetEqual(
  155. Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)),
  156. [a1], lambda x: x
  157. )
  158. self.assertQuerysetEqual(
  159. Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)),
  160. [a4], lambda x: x
  161. )
  162. self.assertQuerysetEqual(
  163. Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)),
  164. [], lambda x: x
  165. )
  166. finally:
  167. models.DateField._unregister_lookup(InMonth)
  168. def test_div3_extract(self):
  169. models.IntegerField.register_lookup(Div3Transform)
  170. try:
  171. a1 = Author.objects.create(name='a1', age=1)
  172. a2 = Author.objects.create(name='a2', age=2)
  173. a3 = Author.objects.create(name='a3', age=3)
  174. a4 = Author.objects.create(name='a4', age=4)
  175. baseqs = Author.objects.order_by('name')
  176. self.assertQuerysetEqual(
  177. baseqs.filter(age__div3=2),
  178. [a2], lambda x: x)
  179. self.assertQuerysetEqual(
  180. baseqs.filter(age__div3__lte=3),
  181. [a1, a2, a3, a4], lambda x: x)
  182. self.assertQuerysetEqual(
  183. baseqs.filter(age__div3__in=[0, 2]),
  184. [a2, a3], lambda x: x)
  185. finally:
  186. models.IntegerField._unregister_lookup(Div3Transform)
  187. class YearLteTests(TestCase):
  188. def setUp(self):
  189. models.DateField.register_lookup(YearTransform)
  190. self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
  191. self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
  192. self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
  193. self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
  194. def tearDown(self):
  195. models.DateField._unregister_lookup(YearTransform)
  196. @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
  197. def test_year_lte(self):
  198. baseqs = Author.objects.order_by('name')
  199. self.assertQuerysetEqual(
  200. baseqs.filter(birthdate__year__lte=2012),
  201. [self.a1, self.a2, self.a3, self.a4], lambda x: x)
  202. self.assertQuerysetEqual(
  203. baseqs.filter(birthdate__year=2012),
  204. [self.a2, self.a3, self.a4], lambda x: x)
  205. self.assertNotIn('BETWEEN', str(baseqs.filter(birthdate__year=2012).query))
  206. self.assertQuerysetEqual(
  207. baseqs.filter(birthdate__year__lte=2011),
  208. [self.a1], lambda x: x)
  209. # The non-optimized version works, too.
  210. self.assertQuerysetEqual(
  211. baseqs.filter(birthdate__year__lt=2012),
  212. [self.a1], lambda x: x)
  213. @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
  214. def test_year_lte_fexpr(self):
  215. self.a2.age = 2011
  216. self.a2.save()
  217. self.a3.age = 2012
  218. self.a3.save()
  219. self.a4.age = 2013
  220. self.a4.save()
  221. baseqs = Author.objects.order_by('name')
  222. self.assertQuerysetEqual(
  223. baseqs.filter(birthdate__year__lte=models.F('age')),
  224. [self.a3, self.a4], lambda x: x)
  225. self.assertQuerysetEqual(
  226. baseqs.filter(birthdate__year__lt=models.F('age')),
  227. [self.a4], lambda x: x)
  228. def test_year_lte_sql(self):
  229. # This test will just check the generated SQL for __lte. This
  230. # doesn't require running on PostgreSQL and spots the most likely
  231. # error - not running YearLte SQL at all.
  232. baseqs = Author.objects.order_by('name')
  233. self.assertIn(
  234. '<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query))
  235. self.assertIn(
  236. '-12-31', str(baseqs.filter(birthdate__year__lte=2011).query))
  237. def test_postgres_year_exact(self):
  238. baseqs = Author.objects.order_by('name')
  239. self.assertIn(
  240. '= (2011 || ', str(baseqs.filter(birthdate__year=2011).query))
  241. self.assertIn(
  242. '-12-31', str(baseqs.filter(birthdate__year=2011).query))
  243. def test_custom_implementation_year_exact(self):
  244. try:
  245. # Two ways to add a customized implementation for different backends:
  246. # First is MonkeyPatch of the class.
  247. def as_custom_sql(self, qn, connection):
  248. lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
  249. rhs_sql, rhs_params = self.process_rhs(qn, connection)
  250. params = lhs_params + rhs_params + lhs_params + rhs_params
  251. return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
  252. "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
  253. {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
  254. setattr(YearExact, 'as_' + connection.vendor, as_custom_sql)
  255. self.assertIn(
  256. 'concat(',
  257. str(Author.objects.filter(birthdate__year=2012).query))
  258. finally:
  259. delattr(YearExact, 'as_' + connection.vendor)
  260. try:
  261. # The other way is to subclass the original lookup and register the subclassed
  262. # lookup instead of the original.
  263. class CustomYearExact(YearExact):
  264. # This method should be named "as_mysql" for MySQL, "as_postgresql" for postgres
  265. # and so on, but as we don't know which DB we are running on, we need to use
  266. # setattr.
  267. def as_custom_sql(self, qn, connection):
  268. lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
  269. rhs_sql, rhs_params = self.process_rhs(qn, connection)
  270. params = lhs_params + rhs_params + lhs_params + rhs_params
  271. return ("%(lhs)s >= str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
  272. "AND %(lhs)s <= str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
  273. {'lhs': lhs_sql, 'rhs': rhs_sql}, params)
  274. setattr(CustomYearExact, 'as_' + connection.vendor, CustomYearExact.as_custom_sql)
  275. YearTransform.register_lookup(CustomYearExact)
  276. self.assertIn(
  277. 'CONCAT(',
  278. str(Author.objects.filter(birthdate__year=2012).query))
  279. finally:
  280. YearTransform._unregister_lookup(CustomYearExact)
  281. YearTransform.register_lookup(YearExact)
  282. class TrackCallsYearTransform(YearTransform):
  283. lookup_name = 'year'
  284. call_order = []
  285. def as_sql(self, qn, connection):
  286. lhs_sql, params = qn.compile(self.lhs)
  287. return connection.ops.date_extract_sql('year', lhs_sql), params
  288. @property
  289. def output_field(self):
  290. return models.IntegerField()
  291. def get_lookup(self, lookup_name):
  292. self.call_order.append('lookup')
  293. return super(TrackCallsYearTransform, self).get_lookup(lookup_name)
  294. def get_transform(self, lookup_name):
  295. self.call_order.append('transform')
  296. return super(TrackCallsYearTransform, self).get_transform(lookup_name)
  297. class LookupTransformCallOrderTests(TestCase):
  298. def test_call_order(self):
  299. models.DateField.register_lookup(TrackCallsYearTransform)
  300. try:
  301. # junk lookup - tries lookup, then transform, then fails
  302. with self.assertRaises(FieldError):
  303. Author.objects.filter(birthdate__year__junk=2012)
  304. self.assertEqual(TrackCallsYearTransform.call_order,
  305. ['lookup', 'transform'])
  306. TrackCallsYearTransform.call_order = []
  307. # junk transform - tries transform only, then fails
  308. with self.assertRaises(FieldError):
  309. Author.objects.filter(birthdate__year__junk__more_junk=2012)
  310. self.assertEqual(TrackCallsYearTransform.call_order,
  311. ['transform'])
  312. TrackCallsYearTransform.call_order = []
  313. # Just getting the year (implied __exact) - lookup only
  314. Author.objects.filter(birthdate__year=2012)
  315. self.assertEqual(TrackCallsYearTransform.call_order,
  316. ['lookup'])
  317. TrackCallsYearTransform.call_order = []
  318. # Just getting the year (explicit __exact) - lookup only
  319. Author.objects.filter(birthdate__year__exact=2012)
  320. self.assertEqual(TrackCallsYearTransform.call_order,
  321. ['lookup'])
  322. finally:
  323. models.DateField._unregister_lookup(TrackCallsYearTransform)
  324. class CustomisedMethodsTests(TestCase):
  325. def test_overridden_get_lookup(self):
  326. q = CustomModel.objects.filter(field__lookupfunc_monkeys=3)
  327. self.assertIn('monkeys()', str(q.query))
  328. def test_overridden_get_transform(self):
  329. q = CustomModel.objects.filter(field__transformfunc_banana=3)
  330. self.assertIn('banana()', str(q.query))
  331. def test_overridden_get_lookup_chain(self):
  332. q = CustomModel.objects.filter(field__transformfunc_banana__lookupfunc_elephants=3)
  333. self.assertIn('elephants()', str(q.query))
  334. def test_overridden_get_transform_chain(self):
  335. q = CustomModel.objects.filter(field__transformfunc_banana__transformfunc_pear=3)
  336. self.assertIn('pear()', str(q.query))