|
@@ -1,5 +1,4 @@
|
|
|
from copy import copy
|
|
|
-from itertools import repeat
|
|
|
import inspect
|
|
|
|
|
|
from django.conf import settings
|
|
@@ -7,6 +6,8 @@ from django.utils import timezone
|
|
|
from django.utils.functional import cached_property
|
|
|
from django.utils.six.moves import xrange
|
|
|
|
|
|
+from .query_utils import QueryWrapper
|
|
|
+
|
|
|
|
|
|
class RegisterLookupMixin(object):
|
|
|
def _get_lookup(self, lookup_name):
|
|
@@ -57,6 +58,9 @@ class RegisterLookupMixin(object):
|
|
|
|
|
|
|
|
|
class Transform(RegisterLookupMixin):
|
|
|
+
|
|
|
+ bilateral = False
|
|
|
+
|
|
|
def __init__(self, lhs, lookups):
|
|
|
self.lhs = lhs
|
|
|
self.init_lookups = lookups[:]
|
|
@@ -78,9 +82,42 @@ class Transform(RegisterLookupMixin):
|
|
|
class Lookup(RegisterLookupMixin):
|
|
|
lookup_name = None
|
|
|
|
|
|
- def __init__(self, lhs, rhs):
|
|
|
+ def __init__(self, lhs, rhs, bilateral_transforms=None):
|
|
|
self.lhs, self.rhs = lhs, rhs
|
|
|
self.rhs = self.get_prep_lookup()
|
|
|
+ if bilateral_transforms is None:
|
|
|
+ bilateral_transforms = []
|
|
|
+ if bilateral_transforms:
|
|
|
+ # We should warn the user as soon as possible if he is trying to apply
|
|
|
+ # a bilateral transformation on a nested QuerySet: that won't work.
|
|
|
+ # We need to import QuerySet here so as to avoid circular
|
|
|
+ from django.db.models.query import QuerySet
|
|
|
+ if isinstance(rhs, QuerySet):
|
|
|
+ raise NotImplementedError("Bilateral transformations on nested querysets are not supported.")
|
|
|
+ self.bilateral_transforms = bilateral_transforms
|
|
|
+
|
|
|
+ def apply_bilateral_transforms(self, value):
|
|
|
+ for transform, lookups in self.bilateral_transforms:
|
|
|
+ value = transform(value, lookups)
|
|
|
+ return value
|
|
|
+
|
|
|
+ def batch_process_rhs(self, qn, connection, rhs=None):
|
|
|
+ if rhs is None:
|
|
|
+ rhs = self.rhs
|
|
|
+ if self.bilateral_transforms:
|
|
|
+ sqls, sqls_params = [], []
|
|
|
+ for p in rhs:
|
|
|
+ value = QueryWrapper('%s',
|
|
|
+ [self.lhs.output_field.get_db_prep_value(p, connection)])
|
|
|
+ value = self.apply_bilateral_transforms(value)
|
|
|
+ sql, sql_params = qn.compile(value)
|
|
|
+ sqls.append(sql)
|
|
|
+ sqls_params.extend(sql_params)
|
|
|
+ else:
|
|
|
+ params = self.lhs.output_field.get_db_prep_lookup(
|
|
|
+ self.lookup_name, rhs, connection, prepared=True)
|
|
|
+ sqls, sqls_params = ['%s'] * len(params), params
|
|
|
+ return sqls, sqls_params
|
|
|
|
|
|
def get_prep_lookup(self):
|
|
|
return self.lhs.output_field.get_prep_lookup(self.lookup_name, self.rhs)
|
|
@@ -96,6 +133,13 @@ class Lookup(RegisterLookupMixin):
|
|
|
|
|
|
def process_rhs(self, qn, connection):
|
|
|
value = self.rhs
|
|
|
+ if self.bilateral_transforms:
|
|
|
+ if self.rhs_is_direct_value():
|
|
|
+ # Do not call get_db_prep_lookup here as the value will be
|
|
|
+ # transformed before being used for lookup
|
|
|
+ value = QueryWrapper("%s",
|
|
|
+ [self.lhs.output_field.get_db_prep_value(value, connection)])
|
|
|
+ value = self.apply_bilateral_transforms(value)
|
|
|
# Due to historical reasons there are a couple of different
|
|
|
# ways to produce sql here. get_compiler is likely a Query
|
|
|
# instance, _as_sql QuerySet and as_sql just something with
|
|
@@ -203,15 +247,19 @@ default_lookups['lte'] = LessThanOrEqual
|
|
|
class In(BuiltinLookup):
|
|
|
lookup_name = 'in'
|
|
|
|
|
|
- def get_db_prep_lookup(self, value, connection):
|
|
|
- params = self.lhs.output_field.get_db_prep_lookup(
|
|
|
- self.lookup_name, value, connection, prepared=True)
|
|
|
- if not params:
|
|
|
- # TODO: check why this leads to circular import
|
|
|
- from django.db.models.sql.datastructures import EmptyResultSet
|
|
|
- raise EmptyResultSet
|
|
|
- placeholder = '(' + ', '.join('%s' for p in params) + ')'
|
|
|
- return (placeholder, params)
|
|
|
+ def process_rhs(self, qn, connection):
|
|
|
+ if self.rhs_is_direct_value():
|
|
|
+ # rhs should be an iterable, we use batch_process_rhs
|
|
|
+ # to prepare/transform those values
|
|
|
+ rhs = list(self.rhs)
|
|
|
+ if not rhs:
|
|
|
+ from django.db.models.sql.datastructures import EmptyResultSet
|
|
|
+ raise EmptyResultSet
|
|
|
+ sqls, sqls_params = self.batch_process_rhs(qn, connection, rhs)
|
|
|
+ placeholder = '(' + ', '.join(sqls) + ')'
|
|
|
+ return (placeholder, sqls_params)
|
|
|
+ else:
|
|
|
+ return super(In, self).process_rhs(qn, connection)
|
|
|
|
|
|
def get_rhs_op(self, connection, rhs):
|
|
|
return 'IN %s' % rhs
|
|
@@ -220,8 +268,10 @@ class In(BuiltinLookup):
|
|
|
max_in_list_size = connection.ops.max_in_list_size()
|
|
|
if self.rhs_is_direct_value() and (max_in_list_size and
|
|
|
len(self.rhs) > max_in_list_size):
|
|
|
- rhs, rhs_params = self.process_rhs(qn, connection)
|
|
|
+ # This is a special case for Oracle which limits the number of elements
|
|
|
+ # which can appear in an 'IN' clause.
|
|
|
lhs, lhs_params = self.process_lhs(qn, connection)
|
|
|
+ rhs, rhs_params = self.batch_process_rhs(qn, connection)
|
|
|
in_clause_elements = ['(']
|
|
|
params = []
|
|
|
for offset in xrange(0, len(rhs_params), max_in_list_size):
|
|
@@ -229,11 +279,12 @@ class In(BuiltinLookup):
|
|
|
in_clause_elements.append(' OR ')
|
|
|
in_clause_elements.append('%s IN (' % lhs)
|
|
|
params.extend(lhs_params)
|
|
|
- group_size = min(len(rhs_params) - offset, max_in_list_size)
|
|
|
- param_group = ', '.join(repeat('%s', group_size))
|
|
|
+ sqls = rhs[offset: offset + max_in_list_size]
|
|
|
+ sqls_params = rhs_params[offset: offset + max_in_list_size]
|
|
|
+ param_group = ', '.join(sqls)
|
|
|
in_clause_elements.append(param_group)
|
|
|
in_clause_elements.append(')')
|
|
|
- params.extend(rhs_params[offset: offset + max_in_list_size])
|
|
|
+ params.extend(sqls_params)
|
|
|
in_clause_elements.append(')')
|
|
|
return ''.join(in_clause_elements), params
|
|
|
else:
|
|
@@ -252,10 +303,10 @@ class PatternLookup(BuiltinLookup):
|
|
|
# we need to add the % pattern match to the lookup by something like
|
|
|
# col LIKE othercol || '%%'
|
|
|
# So, for Python values we don't need any special pattern, but for
|
|
|
- # SQL reference values we need the correct pattern added.
|
|
|
- value = self.rhs
|
|
|
- if (hasattr(value, 'get_compiler') or hasattr(value, 'as_sql')
|
|
|
- or hasattr(value, '_as_sql')):
|
|
|
+ # SQL reference values or SQL transformations we need the correct
|
|
|
+ # pattern added.
|
|
|
+ if (hasattr(self.rhs, 'get_compiler') or hasattr(self.rhs, 'as_sql')
|
|
|
+ or hasattr(self.rhs, '_as_sql') or self.bilateral_transforms):
|
|
|
return connection.pattern_ops[self.lookup_name] % rhs
|
|
|
else:
|
|
|
return super(PatternLookup, self).get_rhs_op(connection, rhs)
|
|
@@ -291,8 +342,20 @@ class Year(Between):
|
|
|
default_lookups['year'] = Year
|
|
|
|
|
|
|
|
|
-class Range(Between):
|
|
|
+class Range(BuiltinLookup):
|
|
|
lookup_name = 'range'
|
|
|
+
|
|
|
+ def get_rhs_op(self, connection, rhs):
|
|
|
+ return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
|
|
|
+
|
|
|
+ def process_rhs(self, qn, connection):
|
|
|
+ if self.rhs_is_direct_value():
|
|
|
+ # rhs should be an iterable of 2 values, we use batch_process_rhs
|
|
|
+ # to prepare/transform those values
|
|
|
+ return self.batch_process_rhs(qn, connection)
|
|
|
+ else:
|
|
|
+ return super(Range, self).process_rhs(qn, connection)
|
|
|
+
|
|
|
default_lookups['range'] = Range
|
|
|
|
|
|
|