Browse Source

Fixed #23493 -- Added bilateral attribute to Transform

Thomas Chaumeny 10 years ago
parent
commit
00aa562884

+ 83 - 20
django/db/models/lookups.py

@@ -1,5 +1,4 @@
 from copy import copy
-from itertools import repeat
 import inspect
 
 from django.conf import settings
@@ -7,6 +6,8 @@ from django.utils import timezone
 from django.utils.functional import cached_property
 from django.utils.six.moves import xrange
 
+from .query_utils import QueryWrapper
+
 
 class RegisterLookupMixin(object):
     def _get_lookup(self, lookup_name):
@@ -57,6 +58,9 @@ class RegisterLookupMixin(object):
 
 
 class Transform(RegisterLookupMixin):
+
+    bilateral = False
+
     def __init__(self, lhs, lookups):
         self.lhs = lhs
         self.init_lookups = lookups[:]
@@ -78,9 +82,42 @@ class Transform(RegisterLookupMixin):
 class Lookup(RegisterLookupMixin):
     lookup_name = None
 
-    def __init__(self, lhs, rhs):
+    def __init__(self, lhs, rhs, bilateral_transforms=None):
         self.lhs, self.rhs = lhs, rhs
         self.rhs = self.get_prep_lookup()
+        if bilateral_transforms is None:
+            bilateral_transforms = []
+        if bilateral_transforms:
+            # Warn the user as early as possible if they are trying to apply
+            # a bilateral transformation on a nested QuerySet: that won't work.
+            # QuerySet is imported here to avoid a circular import.
+            from django.db.models.query import QuerySet
+            if isinstance(rhs, QuerySet):
+                raise NotImplementedError("Bilateral transformations on nested querysets are not supported.")
+        self.bilateral_transforms = bilateral_transforms
+
+    def apply_bilateral_transforms(self, value):
+        for transform, lookups in self.bilateral_transforms:
+            value = transform(value, lookups)
+        return value
+
+    def batch_process_rhs(self, qn, connection, rhs=None):
+        if rhs is None:
+            rhs = self.rhs
+        if self.bilateral_transforms:
+            sqls, sqls_params = [], []
+            for p in rhs:
+                value = QueryWrapper('%s',
+                    [self.lhs.output_field.get_db_prep_value(p, connection)])
+                value = self.apply_bilateral_transforms(value)
+                sql, sql_params = qn.compile(value)
+                sqls.append(sql)
+                sqls_params.extend(sql_params)
+        else:
+            params = self.lhs.output_field.get_db_prep_lookup(
+                self.lookup_name, rhs, connection, prepared=True)
+            sqls, sqls_params = ['%s'] * len(params), params
+        return sqls, sqls_params
 
     def get_prep_lookup(self):
         return self.lhs.output_field.get_prep_lookup(self.lookup_name, self.rhs)
@@ -96,6 +133,13 @@ class Lookup(RegisterLookupMixin):
 
     def process_rhs(self, qn, connection):
         value = self.rhs
+        if self.bilateral_transforms:
+            if self.rhs_is_direct_value():
+                # Do not call get_db_prep_lookup here as the value will be
+                # transformed before being used for lookup
+                value = QueryWrapper("%s",
+                    [self.lhs.output_field.get_db_prep_value(value, connection)])
+            value = self.apply_bilateral_transforms(value)
         # Due to historical reasons there are a couple of different
         # ways to produce sql here. get_compiler is likely a Query
         # instance, _as_sql QuerySet and as_sql just something with
@@ -203,15 +247,19 @@ default_lookups['lte'] = LessThanOrEqual
 class In(BuiltinLookup):
     lookup_name = 'in'
 
-    def get_db_prep_lookup(self, value, connection):
-        params = self.lhs.output_field.get_db_prep_lookup(
-            self.lookup_name, value, connection, prepared=True)
-        if not params:
-            # TODO: check why this leads to circular import
-            from django.db.models.sql.datastructures import EmptyResultSet
-            raise EmptyResultSet
-        placeholder = '(' + ', '.join('%s' for p in params) + ')'
-        return (placeholder, params)
+    def process_rhs(self, qn, connection):
+        if self.rhs_is_direct_value():
+            # rhs should be an iterable, we use batch_process_rhs
+            # to prepare/transform those values
+            rhs = list(self.rhs)
+            if not rhs:
+                from django.db.models.sql.datastructures import EmptyResultSet
+                raise EmptyResultSet
+            sqls, sqls_params = self.batch_process_rhs(qn, connection, rhs)
+            placeholder = '(' + ', '.join(sqls) + ')'
+            return (placeholder, sqls_params)
+        else:
+            return super(In, self).process_rhs(qn, connection)
 
     def get_rhs_op(self, connection, rhs):
         return 'IN %s' % rhs
@@ -220,8 +268,10 @@ class In(BuiltinLookup):
         max_in_list_size = connection.ops.max_in_list_size()
         if self.rhs_is_direct_value() and (max_in_list_size and
                                            len(self.rhs) > max_in_list_size):
-            rhs, rhs_params = self.process_rhs(qn, connection)
+            # This is a special case for Oracle which limits the number of elements
+            # which can appear in an 'IN' clause.
             lhs, lhs_params = self.process_lhs(qn, connection)
+            rhs, rhs_params = self.batch_process_rhs(qn, connection)
             in_clause_elements = ['(']
             params = []
             for offset in xrange(0, len(rhs_params), max_in_list_size):
@@ -229,11 +279,12 @@ class In(BuiltinLookup):
                     in_clause_elements.append(' OR ')
                 in_clause_elements.append('%s IN (' % lhs)
                 params.extend(lhs_params)
-                group_size = min(len(rhs_params) - offset, max_in_list_size)
-                param_group = ', '.join(repeat('%s', group_size))
+                sqls = rhs[offset: offset + max_in_list_size]
+                sqls_params = rhs_params[offset: offset + max_in_list_size]
+                param_group = ', '.join(sqls)
                 in_clause_elements.append(param_group)
                 in_clause_elements.append(')')
-                params.extend(rhs_params[offset: offset + max_in_list_size])
+                params.extend(sqls_params)
             in_clause_elements.append(')')
             return ''.join(in_clause_elements), params
         else:
@@ -252,10 +303,10 @@ class PatternLookup(BuiltinLookup):
         # we need to add the % pattern match to the lookup by something like
         #     col LIKE othercol || '%%'
         # So, for Python values we don't need any special pattern, but for
-        # SQL reference values we need the correct pattern added.
-        value = self.rhs
-        if (hasattr(value, 'get_compiler') or hasattr(value, 'as_sql')
-                or hasattr(value, '_as_sql')):
+        # SQL reference values or SQL transformations we need the correct
+        # pattern added.
+        if (hasattr(self.rhs, 'get_compiler') or hasattr(self.rhs, 'as_sql')
+                or hasattr(self.rhs, '_as_sql') or self.bilateral_transforms):
             return connection.pattern_ops[self.lookup_name] % rhs
         else:
             return super(PatternLookup, self).get_rhs_op(connection, rhs)
@@ -291,8 +342,20 @@ class Year(Between):
 default_lookups['year'] = Year
 
 
-class Range(Between):
+class Range(BuiltinLookup):
     lookup_name = 'range'
+
+    def get_rhs_op(self, connection, rhs):
+        return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
+
+    def process_rhs(self, qn, connection):
+        if self.rhs_is_direct_value():
+            # rhs should be an iterable of 2 values, we use batch_process_rhs
+            # to prepare/transform those values
+            return self.batch_process_rhs(qn, connection)
+        else:
+            return super(Range, self).process_rhs(qn, connection)
+
 default_lookups['range'] = Range
 
 

+ 4 - 1
django/db/models/sql/query.py

@@ -1111,18 +1111,21 @@ class Query(object):
 
     def build_lookup(self, lookups, lhs, rhs):
         lookups = lookups[:]
+        bilaterals = []
         while lookups:
             lookup = lookups[0]
             if len(lookups) == 1:
                 final_lookup = lhs.get_lookup(lookup)
                 if final_lookup:
-                    return final_lookup(lhs, rhs)
+                    return final_lookup(lhs, rhs, bilaterals)
                 # We didn't find a lookup, so we are going to try get_transform
                 # + get_lookup('exact').
                 lookups.append('exact')
             next = lhs.get_transform(lookup)
             if next:
                 lhs = next(lhs, lookups)
+                if getattr(next, 'bilateral', False):
+                    bilaterals.append((next, lookups))
             else:
                 raise FieldError(
                     "Unsupported lookup '%s' for %s or join on the field not "

+ 43 - 5
docs/howto/custom-lookups.txt

@@ -127,7 +127,7 @@ function ``ABS()`` to transform the value before comparison::
           lhs, params = qn.compile(self.lhs)
           return "ABS(%s)" % lhs, params
 
-Next, lets register it for ``IntegerField``::
+Next, let's register it for ``IntegerField``::
 
   from django.db.models import IntegerField
   IntegerField.register_lookup(AbsoluteValue)
@@ -144,9 +144,7 @@ SQL::
 
     SELECT ... WHERE ABS("experiments"."change") < 27
 
-Subclasses of ``Transform`` usually only operate on the left-hand side of the
-expression. Further lookups will work on the transformed value. Note that in
-this case where there is no other lookup specified, Django interprets
+Note that when no other lookup is specified, Django interprets
 ``change__abs=27`` as ``change__abs__exact=27``.
 
 When looking for which lookups are allowable after the ``Transform`` has been
@@ -197,7 +195,7 @@ Notice also that  as both sides are used multiple times in the query the params
 need to contain ``lhs_params`` and ``rhs_params`` multiple times.
 
 The final query does the inversion (``27`` to ``-27``) directly in the
-database. The reason for doing this is that if the self.rhs is something else
+database. The reason for doing this is that if the ``self.rhs`` is something else
 than a plain integer value (for example an ``F()`` reference) we can't do the
 transformations in Python.
 
@@ -208,6 +206,46 @@ transformations in Python.
     want to add an index on ``abs(change)`` which would allow these queries to
     be very efficient.
 
+A bilateral transformer example
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The ``AbsoluteValue`` example we discussed previously is a transformation which
+applies to the left-hand side of the lookup. There may be some cases where you
+want the transformation to be applied to both the left-hand side and the
+right-hand side. For instance, you may want to filter a queryset by comparing
+both sides only after each has been passed through the same SQL function.
+
+Let's examine a simple example of a case-insensitive transformation here. This
+transformation isn't very useful in practice as Django already comes with a bunch
+of built-in case-insensitive lookups, but it will be a nice demonstration of
+bilateral transformations in a database-agnostic way.
+
+We define an ``UpperCase`` transformer which uses the SQL function ``UPPER()`` to
+transform the values before comparison. We define
+:attr:`bilateral = True <django.db.models.Transform.bilateral>` to indicate that
+this transformation should apply to both ``lhs`` and ``rhs``::
+
+  from django.db.models import Transform
+
+  class UpperCase(Transform):
+      lookup_name = 'upper'
+      bilateral = True
+
+      def as_sql(self, qn, connection):
+          lhs, params = qn.compile(self.lhs)
+          return "UPPER(%s)" % lhs, params
+
+Next, let's register it::
+
+  from django.db.models import CharField, TextField
+  CharField.register_lookup(UpperCase)
+  TextField.register_lookup(UpperCase)
+
+Now, the queryset ``Author.objects.filter(name__upper="doe")`` will generate a
+case-insensitive query like this::
+
+    SELECT ... WHERE UPPER("author"."name") = UPPER('doe')
+
 Writing alternative implementations for existing lookups
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

+ 9 - 0
docs/ref/models/lookups.txt

@@ -129,6 +129,15 @@ Transform reference
     This class follows the :ref:`Query Expression API <query-expression>`, which
     implies that you can use ``<expression>__<transform1>__<transform2>``.
 
+    .. attribute:: bilateral
+
+        .. versionadded:: 1.8
+
+        A boolean indicating whether this transformation should apply to both
+        ``lhs`` and ``rhs``. Bilateral transformations will be applied to ``rhs`` in
+        the same order as they appear in the lookup expression. By default it is set
+        to ``False``. For example usage, see :doc:`/howto/custom-lookups`.
+
     .. attribute:: lhs
 
         The left-hand side - what is being transformed. It must follow the

+ 5 - 0
docs/releases/1.8.txt

@@ -306,6 +306,11 @@ Models
 * :doc:`Custom Lookups</howto/custom-lookups>` can now be registered using
   a decorator pattern.
 
+* The new :attr:`Transform.bilateral <django.db.models.Transform.bilateral>`
+  attribute allows creating bilateral transformations. These transformations
+  are applied to both ``lhs`` and ``rhs`` when used in a lookup expression,
+  providing opportunities for more sophisticated lookups.
+
 Signals
 ^^^^^^^
 

+ 124 - 2
tests/custom_lookups/tests.py

@@ -17,7 +17,7 @@ class Div3Lookup(models.Lookup):
         lhs, params = self.process_lhs(qn, connection)
         rhs, rhs_params = self.process_rhs(qn, connection)
         params.extend(rhs_params)
-        return '%s %%%% 3 = %s' % (lhs, rhs), params
+        return '(%s) %%%% 3 = %s' % (lhs, rhs), params
 
     def as_oracle(self, qn, connection):
         lhs, params = self.process_lhs(qn, connection)
@@ -31,12 +31,32 @@ class Div3Transform(models.Transform):
 
     def as_sql(self, qn, connection):
         lhs, lhs_params = qn.compile(self.lhs)
-        return '%s %%%% 3' % (lhs,), lhs_params
+        return '(%s) %%%% 3' % lhs, lhs_params
 
     def as_oracle(self, qn, connection):
         lhs, lhs_params = qn.compile(self.lhs)
         return 'mod(%s, 3)' % lhs, lhs_params
 
+class Div3BilateralTransform(Div3Transform):
+    bilateral = True
+
+
+class Mult3BilateralTransform(models.Transform):
+    bilateral = True
+    lookup_name = 'mult3'
+
+    def as_sql(self, qn, connection):
+        lhs, lhs_params = qn.compile(self.lhs)
+        return '3 * (%s)' % lhs, lhs_params
+
+class UpperBilateralTransform(models.Transform):
+    bilateral = True
+    lookup_name = 'upper'
+
+    def as_sql(self, qn, connection):
+        lhs, lhs_params = qn.compile(self.lhs)
+        return 'UPPER(%s)' % lhs, lhs_params
+
 
 class YearTransform(models.Transform):
     lookup_name = 'year'
@@ -225,10 +245,112 @@ class LookupTests(TestCase):
             self.assertQuerysetEqual(
                 baseqs.filter(age__div3__in=[0, 2]),
                 [a2, a3], lambda x: x)
+            self.assertQuerysetEqual(
+                baseqs.filter(age__div3__in=[2, 4]),
+                [a2], lambda x: x)
+            self.assertQuerysetEqual(
+                baseqs.filter(age__div3__gte=3),
+                [], lambda x: x)
+            self.assertQuerysetEqual(
+                baseqs.filter(age__div3__range=(1, 2)),
+                [a1, a2, a4], lambda x: x)
         finally:
             models.IntegerField._unregister_lookup(Div3Transform)
 
 
+class BilateralTransformTests(TestCase):
+
+    def test_bilateral_upper(self):
+        models.CharField.register_lookup(UpperBilateralTransform)
+        try:
+            Author.objects.bulk_create([
+                Author(name='Doe'),
+                Author(name='doe'),
+                Author(name='Foo'),
+            ])
+            self.assertQuerysetEqual(
+                Author.objects.filter(name__upper='doe'),
+                ["<Author: Doe>", "<Author: doe>"], ordered=False)
+        finally:
+            models.CharField._unregister_lookup(UpperBilateralTransform)
+
+    def test_bilateral_inner_qs(self):
+        models.CharField.register_lookup(UpperBilateralTransform)
+        try:
+            with self.assertRaises(NotImplementedError):
+                Author.objects.filter(name__upper__in=Author.objects.values_list('name'))
+        finally:
+            models.CharField._unregister_lookup(UpperBilateralTransform)
+
+    def test_div3_bilateral_extract(self):
+        models.IntegerField.register_lookup(Div3BilateralTransform)
+        try:
+            a1 = Author.objects.create(name='a1', age=1)
+            a2 = Author.objects.create(name='a2', age=2)
+            a3 = Author.objects.create(name='a3', age=3)
+            a4 = Author.objects.create(name='a4', age=4)
+            baseqs = Author.objects.order_by('name')
+            self.assertQuerysetEqual(
+                baseqs.filter(age__div3=2),
+                [a2], lambda x: x)
+            self.assertQuerysetEqual(
+                baseqs.filter(age__div3__lte=3),
+                [a3], lambda x: x)
+            self.assertQuerysetEqual(
+                baseqs.filter(age__div3__in=[0, 2]),
+                [a2, a3], lambda x: x)
+            self.assertQuerysetEqual(
+                baseqs.filter(age__div3__in=[2, 4]),
+                [a1, a2, a4], lambda x: x)
+            self.assertQuerysetEqual(
+                baseqs.filter(age__div3__gte=3),
+                [a1, a2, a3, a4], lambda x: x)
+            self.assertQuerysetEqual(
+                baseqs.filter(age__div3__range=(1, 2)),
+                [a1, a2, a4], lambda x: x)
+        finally:
+            models.IntegerField._unregister_lookup(Div3BilateralTransform)
+
+    def test_bilateral_order(self):
+        models.IntegerField.register_lookup(Mult3BilateralTransform)
+        models.IntegerField.register_lookup(Div3BilateralTransform)
+        try:
+            a1 = Author.objects.create(name='a1', age=1)
+            a2 = Author.objects.create(name='a2', age=2)
+            a3 = Author.objects.create(name='a3', age=3)
+            a4 = Author.objects.create(name='a4', age=4)
+            baseqs = Author.objects.order_by('name')
+
+            self.assertQuerysetEqual(
+                baseqs.filter(age__mult3__div3=42),
+                # mult3__div3 always leads to 0
+                [a1, a2, a3, a4], lambda x: x)
+            self.assertQuerysetEqual(
+                baseqs.filter(age__div3__mult3=42),
+                [a3], lambda x: x)
+        finally:
+            models.IntegerField._unregister_lookup(Mult3BilateralTransform)
+            models.IntegerField._unregister_lookup(Div3BilateralTransform)
+
+    def test_bilateral_fexpr(self):
+        models.IntegerField.register_lookup(Mult3BilateralTransform)
+        try:
+            a1 = Author.objects.create(name='a1', age=1, average_rating=3.2)
+            a2 = Author.objects.create(name='a2', age=2, average_rating=0.5)
+            a3 = Author.objects.create(name='a3', age=3, average_rating=1.5)
+            a4 = Author.objects.create(name='a4', age=4)
+            baseqs = Author.objects.order_by('name')
+            self.assertQuerysetEqual(
+                baseqs.filter(age__mult3=models.F('age')),
+                [a1, a2, a3, a4], lambda x: x)
+            self.assertQuerysetEqual(
+                # Same as age >= average_rating
+                baseqs.filter(age__mult3__gte=models.F('average_rating')),
+                [a2, a3], lambda x: x)
+        finally:
+            models.IntegerField._unregister_lookup(Mult3BilateralTransform)
+
+
 class YearLteTests(TestCase):
     def setUp(self):
         models.DateField.register_lookup(YearTransform)