123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403 |
- from __future__ import unicode_literals
- from datetime import date
- import unittest
- from django.core.exceptions import FieldError
- from django.db import models
- from django.db import connection
- from django.test import TestCase
- from .models import Author
class Div3Lookup(models.Lookup):
    """
    Lookup matching rows whose column value, taken modulo 3, equals the
    given right-hand side (e.g. ``age__div3=1``).
    """
    lookup_name = 'div3'

    def _compile_sides(self, qn, connection):
        # Compile both operands and merge their params in SQL order
        # (left-hand side first, then right-hand side).
        lhs_sql, merged_params = self.process_lhs(qn, connection)
        rhs_sql, extra_params = self.process_rhs(qn, connection)
        merged_params.extend(extra_params)
        return lhs_sql, rhs_sql, merged_params

    def as_sql(self, qn, connection):
        lhs_sql, rhs_sql, sql_params = self._compile_sides(qn, connection)
        # '%%%%' turns into '%%' here; the escape is still needed because the
        # resulting SQL is interpolated with the params once more later.
        return '%s %%%% 3 = %s' % (lhs_sql, rhs_sql), sql_params

    def as_oracle(self, qn, connection):
        # Oracle variant spells the modulo with the mod() function.
        lhs_sql, rhs_sql, sql_params = self._compile_sides(qn, connection)
        return 'mod(%s, 3) = %s' % (lhs_sql, rhs_sql), sql_params
class Div3Transform(models.Transform):
    """
    Transform yielding ``column % 3`` so that further lookups
    (``__lte``, ``__in``, ...) can be chained onto the remainder.
    """
    lookup_name = 'div3'

    def as_sql(self, qn, connection):
        inner_sql, inner_params = qn.compile(self.lhs)
        # Escaped modulo: see the note on params interpolation in Div3Lookup.
        return '%s %%%% 3' % (inner_sql,), inner_params

    def as_oracle(self, qn, connection):
        inner_sql, inner_params = qn.compile(self.lhs)
        return 'mod(%s, 3)' % (inner_sql,), inner_params
class YearTransform(models.Transform):
    """
    Transform extracting the year from a date column. Its output is typed
    as an IntegerField so integer lookups can be applied to the result.
    """
    lookup_name = 'year'

    @property
    def output_field(self):
        return models.IntegerField()

    def as_sql(self, qn, connection):
        column_sql, column_params = qn.compile(self.lhs)
        extract_sql = connection.ops.date_extract_sql('year', column_sql)
        return extract_sql, column_params
@YearTransform.register_lookup
class YearExact(models.lookups.Lookup):
    """
    Exact lookup for ``__year`` that skips the year extraction and compares
    the original date column against Jan 1st / Dec 31st of the given year.
    """
    lookup_name = 'exact'

    def as_sql(self, qn, connection):
        # self.lhs is the YearTransform; self.lhs.lhs is the date column
        # underneath it, which is what we compare against directly.
        date_sql, date_params = self.process_lhs(qn, connection, self.lhs.lhs)
        year_sql, year_params = self.process_rhs(qn, connection)
        # The SQL below references lhs, rhs, lhs, rhs in that order, so the
        # params must be repeated in exactly the same sequence.
        all_params = (date_params + year_params) * 2
        # PostgreSQL-specific SQL. The bounding dates are built in SQL rather
        # than in Python so that F() references work on the right-hand side.
        template = ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
                    "AND %(lhs)s <= (%(rhs)s || '-12-31')::date")
        return template % {'lhs': date_sql, 'rhs': year_sql}, all_params
@YearTransform.register_lookup
class YearLte(models.lookups.LessThanOrEqual):
    """
    Efficient ``__year__lte`` lookup. It compares the raw date column with
    December 31st of the given year instead of comparing extracted years.
    """
    def as_sql(self, qn, connection):
        # Bypass the YearTransform above us and work on the date column
        # itself; comparing the extracted year cannot be done efficiently.
        column = self.lhs.lhs
        column_sql, sql_params = self.process_lhs(qn, connection, column)
        year_sql, year_params = self.process_rhs(qn, connection)
        sql_params.extend(year_params)
        # Append '-12-31' to the integer year and cast to a date in SQL. This
        # aims for "somecol <= '2013-12-31'" while still working when the
        # right-hand side is a field reference. PostgreSQL syntax.
        upper_bound = "(%s || '-12-31')::date" % (year_sql,)
        return "%s <= %s" % (column_sql, upper_bound), sql_params
class SQLFunc(models.Lookup):
    """
    Placeholder lookup/transform that renders as ``<name>()``. Used to check
    that results from overridden get_lookup()/get_transform() are honoured.
    """
    def __init__(self, name, *args, **kwargs):
        super(SQLFunc, self).__init__(*args, **kwargs)
        self.name = name

    @property
    def output_field(self):
        return CustomField()

    def as_sql(self, qn, connection):
        # The name is bound as a query parameter rather than inlined; the
        # tests in this module only inspect str(query), where it is
        # substituted into the placeholder.
        return '%s()', [self.name]
class SQLFuncFactory(object):
    """
    Callable returned in place of a lookup/transform class; calling it
    builds an SQLFunc bound to the stored function name.
    """

    def __init__(self, name):
        self.name = name

    def __call__(self, *args, **kwargs):
        # Invoked wherever the lookup/transform class would be instantiated;
        # the function name is prepended to the forwarded arguments.
        return SQLFunc(self.name, *args, **kwargs)
class CustomField(models.TextField):
    """
    TextField that resolves ``lookupfunc_<name>`` lookups and
    ``transformfunc_<name>`` transforms to SQLFunc factories. All other
    names fall back to the normal TextField resolution.
    """
    def get_lookup(self, lookup_name):
        if lookup_name.startswith('lookupfunc_'):
            # Only the text after the first underscore (the function name)
            # is needed; the prefix itself is discarded.
            _, _, func_name = lookup_name.partition('_')
            return SQLFuncFactory(func_name)
        return super(CustomField, self).get_lookup(lookup_name)

    def get_transform(self, lookup_name):
        if lookup_name.startswith('transformfunc_'):
            _, _, func_name = lookup_name.partition('_')
            return SQLFuncFactory(func_name)
        return super(CustomField, self).get_transform(lookup_name)
class CustomModel(models.Model):
    # A single CustomField column, queried by CustomisedMethodsTests to check
    # that the field's overridden get_lookup()/get_transform() are used.
    field = CustomField()
class InMonth(models.lookups.Lookup):
    """
    Matches rows whose date column falls in the same calendar month as the
    given value.

    It is not registered at import time; test_birthdate_month registers it
    on DateField temporarily. The SQL uses PostgreSQL's date_trunc() and
    interval syntax.
    """
    lookup_name = 'inmonth'

    def as_sql(self, qn, connection):
        col_sql, col_params = self.process_lhs(qn, connection)
        value_sql, value_params = self.process_rhs(qn, connection)
        # Half-open range [start of value's month, start of next month).
        # Both operands appear twice (lhs, rhs, lhs, rhs), so the params are
        # repeated in that same order.
        sql = ("%s >= date_trunc('month', %s) and "
               "%s < date_trunc('month', %s) + interval '1 months'"
               % (col_sql, value_sql, col_sql, value_sql))
        return sql, (col_params + value_params) * 2
class LookupTests(TestCase):
    """Custom lookups and transforms registered temporarily on model fields."""

    def _assert_authors(self, queryset, expected):
        # Compare model instances directly rather than their repr().
        self.assertQuerysetEqual(queryset, expected, lambda x: x)

    def test_basic_lookup(self):
        a1, a2, a3, a4 = [
            Author.objects.create(name='a%d' % age, age=age)
            for age in (1, 2, 3, 4)
        ]
        models.IntegerField.register_lookup(Div3Lookup)
        try:
            self._assert_authors(Author.objects.filter(age__div3=0), [a3])
            self._assert_authors(
                Author.objects.filter(age__div3=1).order_by('age'), [a1, a4])
            self._assert_authors(Author.objects.filter(age__div3=2), [a2])
            self._assert_authors(Author.objects.filter(age__div3=3), [])
        finally:
            models.IntegerField._unregister_lookup(Div3Lookup)

    @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
    def test_birthdate_month(self):
        a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
        a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
        a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
        a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
        models.DateField.register_lookup(InMonth)
        try:
            # (probe date, authors born in the same month as the probe)
            expectations = (
                (date(2012, 1, 15), [a3]),
                (date(2012, 2, 1), [a2]),
                (date(1981, 2, 28), [a1]),
                (date(2012, 3, 12), [a4]),
                (date(2012, 4, 1), []),
            )
            for probe, expected in expectations:
                self._assert_authors(
                    Author.objects.filter(birthdate__inmonth=probe), expected)
        finally:
            models.DateField._unregister_lookup(InMonth)

    def test_div3_extract(self):
        models.IntegerField.register_lookup(Div3Transform)
        try:
            a1, a2, a3, a4 = [
                Author.objects.create(name='a%d' % age, age=age)
                for age in (1, 2, 3, 4)
            ]
            baseqs = Author.objects.order_by('name')
            self._assert_authors(baseqs.filter(age__div3=2), [a2])
            self._assert_authors(
                baseqs.filter(age__div3__lte=3), [a1, a2, a3, a4])
            self._assert_authors(
                baseqs.filter(age__div3__in=[0, 2]), [a2, a3])
        finally:
            models.IntegerField._unregister_lookup(Div3Transform)
class YearLteTests(TestCase):
    """
    Tests for YearTransform on DateField together with the YearExact and
    YearLte lookups registered on it.
    """

    def setUp(self):
        models.DateField.register_lookup(YearTransform)
        # Schedule the unregistration immediately. unittest skips tearDown()
        # when setUp() raises (e.g. one of the create() calls below), which
        # would otherwise leave YearTransform registered for later tests.
        self.addCleanup(models.DateField._unregister_lookup, YearTransform)
        self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
        self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
        self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
        self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))

    @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
    def test_year_lte(self):
        baseqs = Author.objects.order_by('name')
        self.assertQuerysetEqual(
            baseqs.filter(birthdate__year__lte=2012),
            [self.a1, self.a2, self.a3, self.a4], lambda x: x)
        self.assertQuerysetEqual(
            baseqs.filter(birthdate__year=2012),
            [self.a2, self.a3, self.a4], lambda x: x)
        # YearExact's ">= ... AND <= ..." form must be used, not a BETWEEN.
        self.assertNotIn('BETWEEN', str(baseqs.filter(birthdate__year=2012).query))
        self.assertQuerysetEqual(
            baseqs.filter(birthdate__year__lte=2011),
            [self.a1], lambda x: x)
        # The non-optimized version works, too.
        self.assertQuerysetEqual(
            baseqs.filter(birthdate__year__lt=2012),
            [self.a1], lambda x: x)

    @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
    def test_year_lte_fexpr(self):
        self.a2.age = 2011
        self.a2.save()
        self.a3.age = 2012
        self.a3.save()
        self.a4.age = 2013
        self.a4.save()
        baseqs = Author.objects.order_by('name')
        self.assertQuerysetEqual(
            baseqs.filter(birthdate__year__lte=models.F('age')),
            [self.a3, self.a4], lambda x: x)
        self.assertQuerysetEqual(
            baseqs.filter(birthdate__year__lt=models.F('age')),
            [self.a4], lambda x: x)

    def test_year_lte_sql(self):
        # This test will just check the generated SQL for __lte. This
        # doesn't require running on PostgreSQL and spots the most likely
        # error - not running YearLte SQL at all.
        baseqs = Author.objects.order_by('name')
        self.assertIn(
            '<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query))
        self.assertIn(
            '-12-31', str(baseqs.filter(birthdate__year__lte=2011).query))

    def test_postgres_year_exact(self):
        baseqs = Author.objects.order_by('name')
        self.assertIn(
            '= (2011 || ', str(baseqs.filter(birthdate__year=2011).query))
        self.assertIn(
            '-12-31', str(baseqs.filter(birthdate__year=2011).query))

    def test_custom_implementation_year_exact(self):
        # Two ways to add a customized implementation for different backends.
        # The vendor-specific method is named e.g. "as_mysql" for MySQL or
        # "as_postgresql" for PostgreSQL; the vendor is only known at runtime.
        vendor_method = 'as_' + connection.vendor

        # First way: monkeypatch a vendor-specific method onto the class.
        def as_custom_sql(self, qn, connection):
            lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
            rhs_sql, rhs_params = self.process_rhs(qn, connection)
            params = lhs_params + rhs_params + lhs_params + rhs_params
            return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
                    "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
                    {'lhs': lhs_sql, 'rhs': rhs_sql}, params)

        # Patch before entering the try so the finally only undoes a patch
        # that was actually applied.
        setattr(YearExact, vendor_method, as_custom_sql)
        try:
            self.assertIn(
                'concat(',
                str(Author.objects.filter(birthdate__year=2012).query))
        finally:
            delattr(YearExact, vendor_method)

        # Second way: subclass the original lookup and register the subclass
        # in place of the original.
        class CustomYearExact(YearExact):
            def as_custom_sql(self, qn, connection):
                lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
                rhs_sql, rhs_params = self.process_rhs(qn, connection)
                params = lhs_params + rhs_params + lhs_params + rhs_params
                return ("%(lhs)s >= str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
                        "AND %(lhs)s <= str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
                        {'lhs': lhs_sql, 'rhs': rhs_sql}, params)

        setattr(CustomYearExact, vendor_method, CustomYearExact.as_custom_sql)
        YearTransform.register_lookup(CustomYearExact)
        try:
            self.assertIn(
                'CONCAT(',
                str(Author.objects.filter(birthdate__year=2012).query))
        finally:
            # Restore the original 'exact' lookup on YearTransform.
            YearTransform._unregister_lookup(CustomYearExact)
            YearTransform.register_lookup(YearExact)
class TrackCallsYearTransform(YearTransform):
    """
    YearTransform that records, in the class-level ``call_order`` list,
    whether get_lookup() ('lookup') or get_transform() ('transform') was
    consulted, and in which order.
    """
    lookup_name = 'year'
    # Shared by all instances; the call-order test resets it between checks.
    call_order = []

    @property
    def output_field(self):
        return models.IntegerField()

    def as_sql(self, qn, connection):
        col_sql, col_params = qn.compile(self.lhs)
        return connection.ops.date_extract_sql('year', col_sql), col_params

    def get_lookup(self, lookup_name):
        self.call_order.append('lookup')
        parent = super(TrackCallsYearTransform, self)
        return parent.get_lookup(lookup_name)

    def get_transform(self, lookup_name):
        self.call_order.append('transform')
        parent = super(TrackCallsYearTransform, self)
        return parent.get_transform(lookup_name)
class LookupTransformCallOrderTests(TestCase):
    """
    Checks the order in which get_lookup()/get_transform() are consulted
    while resolving ``birthdate__year__...`` filters.
    """

    def _assert_calls_and_reset(self, expected):
        # Compare the recorded calls, then clear them for the next check.
        self.assertEqual(TrackCallsYearTransform.call_order, expected)
        TrackCallsYearTransform.call_order = []

    def test_call_order(self):
        # call_order is class-level state. Start from an empty list instead
        # of relying on whatever an earlier run may have left behind.
        TrackCallsYearTransform.call_order = []
        models.DateField.register_lookup(TrackCallsYearTransform)
        try:
            # junk lookup - tries lookup, then transform, then fails
            with self.assertRaises(FieldError):
                Author.objects.filter(birthdate__year__junk=2012)
            self._assert_calls_and_reset(['lookup', 'transform'])

            # junk transform - tries transform only, then fails
            with self.assertRaises(FieldError):
                Author.objects.filter(birthdate__year__junk__more_junk=2012)
            self._assert_calls_and_reset(['transform'])

            # Just getting the year (implied __exact) - lookup only
            Author.objects.filter(birthdate__year=2012)
            self._assert_calls_and_reset(['lookup'])

            # Just getting the year (explicit __exact) - lookup only
            Author.objects.filter(birthdate__year__exact=2012)
            self._assert_calls_and_reset(['lookup'])
        finally:
            models.DateField._unregister_lookup(TrackCallsYearTransform)
            # Do not leak recorded calls into other tests or reruns.
            TrackCallsYearTransform.call_order = []
class CustomisedMethodsTests(TestCase):
    """
    CustomField's overridden get_lookup()/get_transform() must be used,
    including when lookups and transforms are chained.
    """

    def _sql_for(self, **filters):
        # Render the filtered CustomModel query as SQL text.
        return str(CustomModel.objects.filter(**filters).query)

    def test_overridden_get_lookup(self):
        sql = self._sql_for(field__lookupfunc_monkeys=3)
        self.assertIn('monkeys()', sql)

    def test_overridden_get_transform(self):
        sql = self._sql_for(field__transformfunc_banana=3)
        self.assertIn('banana()', sql)

    def test_overridden_get_lookup_chain(self):
        sql = self._sql_for(field__transformfunc_banana__lookupfunc_elephants=3)
        self.assertIn('elephants()', sql)

    def test_overridden_get_transform_chain(self):
        sql = self._sql_for(field__transformfunc_banana__transformfunc_pear=3)
        self.assertIn('pear()', sql)
|