Browse Source

Fixed #24629 -- Unified Transform and Expression APIs

Josh Smeaton 9 years ago
parent
commit
534aaf56f4

+ 2 - 2
django/contrib/postgres/fields/hstore.py

@@ -81,14 +81,14 @@ class KeyTransformFactory(object):
 
 
 @HStoreField.register_lookup
-class KeysTransform(lookups.FunctionTransform):
+class KeysTransform(Transform):
     lookup_name = 'keys'
     function = 'akeys'
     output_field = ArrayField(TextField())
 
 
 @HStoreField.register_lookup
-class ValuesTransform(lookups.FunctionTransform):
+class ValuesTransform(Transform):
     lookup_name = 'values'
     function = 'avals'
     output_field = ArrayField(TextField())

+ 3 - 3
django/contrib/postgres/fields/ranges.py

@@ -173,7 +173,7 @@ class AdjacentToLookup(lookups.PostgresSimpleLookup):
 
 
 @RangeField.register_lookup
-class RangeStartsWith(lookups.FunctionTransform):
+class RangeStartsWith(models.Transform):
     lookup_name = 'startswith'
     function = 'lower'
 
@@ -183,7 +183,7 @@ class RangeStartsWith(lookups.FunctionTransform):
 
 
 @RangeField.register_lookup
-class RangeEndsWith(lookups.FunctionTransform):
+class RangeEndsWith(models.Transform):
     lookup_name = 'endswith'
     function = 'upper'
 
@@ -193,7 +193,7 @@ class RangeEndsWith(lookups.FunctionTransform):
 
 
 @RangeField.register_lookup
-class IsEmpty(lookups.FunctionTransform):
+class IsEmpty(models.Transform):
     lookup_name = 'isempty'
     function = 'isempty'
     output_field = models.BooleanField()

+ 1 - 7
django/contrib/postgres/lookups.py

@@ -9,12 +9,6 @@ class PostgresSimpleLookup(Lookup):
         return '%s %s %s' % (lhs, self.operator, rhs), params
 
 
-class FunctionTransform(Transform):
-    def as_sql(self, qn, connection):
-        lhs, params = qn.compile(self.lhs)
-        return "%s(%s)" % (self.function, lhs), params
-
-
 class DataContains(PostgresSimpleLookup):
     lookup_name = 'contains'
     operator = '@>'
@@ -45,7 +39,7 @@ class HasAnyKeys(PostgresSimpleLookup):
     operator = '?|'
 
 
-class Unaccent(FunctionTransform):
+class Unaccent(Transform):
     bilateral = True
     lookup_name = 'unaccent'
     function = 'UNACCENT'

+ 1 - 164
django/db/models/fields/__init__.py

@@ -20,10 +20,7 @@ from django.core import checks, exceptions, validators
 # purposes.
 from django.core.exceptions import FieldDoesNotExist  # NOQA
 from django.db import connection, connections, router
-from django.db.models.lookups import (
-    Lookup, RegisterLookupMixin, Transform, default_lookups,
-)
-from django.db.models.query_utils import QueryWrapper
+from django.db.models.query_utils import QueryWrapper, RegisterLookupMixin
 from django.utils import six, timezone
 from django.utils.datastructures import DictWrapper
 from django.utils.dateparse import (
@@ -120,7 +117,6 @@ class Field(RegisterLookupMixin):
         'unique_for_date': _("%(field_label)s must be unique for "
                              "%(date_field_label)s %(lookup_type)s."),
     }
-    class_lookups = default_lookups.copy()
     system_check_deprecated_details = None
     system_check_removed_details = None
 
@@ -1492,22 +1488,6 @@ class DateTimeField(DateField):
         return super(DateTimeField, self).formfield(**defaults)
 
 
-@DateTimeField.register_lookup
-class DateTimeDateTransform(Transform):
-    lookup_name = 'date'
-
-    @cached_property
-    def output_field(self):
-        return DateField()
-
-    def as_sql(self, compiler, connection):
-        lhs, lhs_params = compiler.compile(self.lhs)
-        tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
-        sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
-        lhs_params.extend(tz_params)
-        return sql, lhs_params
-
-
 class DecimalField(Field):
     empty_strings_allowed = False
     default_error_messages = {
@@ -2450,146 +2430,3 @@ class UUIDField(Field):
         }
         defaults.update(kwargs)
         return super(UUIDField, self).formfield(**defaults)
-
-
-class DateTransform(Transform):
-    def as_sql(self, compiler, connection):
-        sql, params = compiler.compile(self.lhs)
-        lhs_output_field = self.lhs.output_field
-        if isinstance(lhs_output_field, DateTimeField):
-            tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
-            sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
-            params.extend(tz_params)
-        elif isinstance(lhs_output_field, DateField):
-            sql = connection.ops.date_extract_sql(self.lookup_name, sql)
-        elif isinstance(lhs_output_field, TimeField):
-            sql = connection.ops.time_extract_sql(self.lookup_name, sql)
-        else:
-            raise ValueError('DateTransform only valid on Date/Time/DateTimeFields')
-        return sql, params
-
-    @cached_property
-    def output_field(self):
-        return IntegerField()
-
-
-class YearTransform(DateTransform):
-    lookup_name = 'year'
-
-
-class YearLookup(Lookup):
-    def year_lookup_bounds(self, connection, year):
-        output_field = self.lhs.lhs.output_field
-        if isinstance(output_field, DateTimeField):
-            bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
-        else:
-            bounds = connection.ops.year_lookup_bounds_for_date_field(year)
-        return bounds
-
-
-@YearTransform.register_lookup
-class YearExact(YearLookup):
-    lookup_name = 'exact'
-
-    def as_sql(self, compiler, connection):
-        # We will need to skip the extract part and instead go
-        # directly with the originating field, that is self.lhs.lhs.
-        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
-        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
-        bounds = self.year_lookup_bounds(connection, rhs_params[0])
-        params.extend(bounds)
-        return '%s BETWEEN %%s AND %%s' % lhs_sql, params
-
-
-class YearComparisonLookup(YearLookup):
-    def as_sql(self, compiler, connection):
-        # We will need to skip the extract part and instead go
-        # directly with the originating field, that is self.lhs.lhs.
-        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
-        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
-        rhs_sql = self.get_rhs_op(connection, rhs_sql)
-        start, finish = self.year_lookup_bounds(connection, rhs_params[0])
-        params.append(self.get_bound(start, finish))
-        return '%s %s' % (lhs_sql, rhs_sql), params
-
-    def get_rhs_op(self, connection, rhs):
-        return connection.operators[self.lookup_name] % rhs
-
-    def get_bound(self):
-        raise NotImplementedError(
-            'subclasses of YearComparisonLookup must provide a get_bound() method'
-        )
-
-
-@YearTransform.register_lookup
-class YearGt(YearComparisonLookup):
-    lookup_name = 'gt'
-
-    def get_bound(self, start, finish):
-        return finish
-
-
-@YearTransform.register_lookup
-class YearGte(YearComparisonLookup):
-    lookup_name = 'gte'
-
-    def get_bound(self, start, finish):
-        return start
-
-
-@YearTransform.register_lookup
-class YearLt(YearComparisonLookup):
-    lookup_name = 'lt'
-
-    def get_bound(self, start, finish):
-        return start
-
-
-@YearTransform.register_lookup
-class YearLte(YearComparisonLookup):
-    lookup_name = 'lte'
-
-    def get_bound(self, start, finish):
-        return finish
-
-
-class MonthTransform(DateTransform):
-    lookup_name = 'month'
-
-
-class DayTransform(DateTransform):
-    lookup_name = 'day'
-
-
-class WeekDayTransform(DateTransform):
-    lookup_name = 'week_day'
-
-
-class HourTransform(DateTransform):
-    lookup_name = 'hour'
-
-
-class MinuteTransform(DateTransform):
-    lookup_name = 'minute'
-
-
-class SecondTransform(DateTransform):
-    lookup_name = 'second'
-
-
-DateField.register_lookup(YearTransform)
-DateField.register_lookup(MonthTransform)
-DateField.register_lookup(DayTransform)
-DateField.register_lookup(WeekDayTransform)
-
-TimeField.register_lookup(HourTransform)
-TimeField.register_lookup(MinuteTransform)
-TimeField.register_lookup(SecondTransform)
-
-DateTimeField.register_lookup(YearTransform)
-DateTimeField.register_lookup(MonthTransform)
-DateTimeField.register_lookup(DayTransform)
-DateTimeField.register_lookup(WeekDayTransform)
-DateTimeField.register_lookup(HourTransform)
-DateTimeField.register_lookup(MinuteTransform)
-DateTimeField.register_lookup(SecondTransform)

+ 9 - 5
django/db/models/functions.py

@@ -1,8 +1,9 @@
 """
 Classes that represent database functions.
 """
-from django.db.models import DateTimeField, IntegerField
-from django.db.models.expressions import Func, Value
+from django.db.models import (
+    DateTimeField, Func, IntegerField, Transform, Value,
+)
 
 
 class Coalesce(Func):
@@ -123,9 +124,10 @@ class Least(Func):
         return super(Least, self).as_sql(compiler, connection, function='MIN')
 
 
-class Length(Func):
+class Length(Transform):
     """Returns the number of characters in the expression"""
     function = 'LENGTH'
+    lookup_name = 'length'
 
     def __init__(self, expression, **extra):
         output_field = extra.pop('output_field', IntegerField())
@@ -136,8 +138,9 @@ class Length(Func):
         return super(Length, self).as_sql(compiler, connection)
 
 
-class Lower(Func):
+class Lower(Transform):
     function = 'LOWER'
+    lookup_name = 'lower'
 
     def __init__(self, expression, **extra):
         super(Lower, self).__init__(expression, **extra)
@@ -188,8 +191,9 @@ class Substr(Func):
         return super(Substr, self).as_sql(compiler, connection)
 
 
-class Upper(Func):
+class Upper(Transform):
     function = 'UPPER'
+    lookup_name = 'upper'
 
     def __init__(self, expression, **extra):
         super(Upper, self).__init__(expression, **extra)

+ 240 - 157
django/db/models/lookups.py

@@ -1,101 +1,17 @@
-import inspect
 from copy import copy
 
+from django.conf import settings
+from django.db.models.expressions import Func, Value
+from django.db.models.fields import (
+    DateField, DateTimeField, Field, IntegerField, TimeField,
+)
+from django.db.models.query_utils import RegisterLookupMixin
+from django.utils import timezone
 from django.utils.functional import cached_property
 from django.utils.six.moves import range
 
-from .query_utils import QueryWrapper
-
-
-class RegisterLookupMixin(object):
-    def _get_lookup(self, lookup_name):
-        try:
-            return self.class_lookups[lookup_name]
-        except KeyError:
-            # To allow for inheritance, check parent class' class_lookups.
-            for parent in inspect.getmro(self.__class__):
-                if 'class_lookups' not in parent.__dict__:
-                    continue
-                if lookup_name in parent.class_lookups:
-                    return parent.class_lookups[lookup_name]
-        except AttributeError:
-            # This class didn't have any class_lookups
-            pass
-        return None
-
-    def get_lookup(self, lookup_name):
-        found = self._get_lookup(lookup_name)
-        if found is None and hasattr(self, 'output_field'):
-            return self.output_field.get_lookup(lookup_name)
-        if found is not None and not issubclass(found, Lookup):
-            return None
-        return found
-
-    def get_transform(self, lookup_name):
-        found = self._get_lookup(lookup_name)
-        if found is None and hasattr(self, 'output_field'):
-            return self.output_field.get_transform(lookup_name)
-        if found is not None and not issubclass(found, Transform):
-            return None
-        return found
-
-    @classmethod
-    def register_lookup(cls, lookup):
-        if 'class_lookups' not in cls.__dict__:
-            cls.class_lookups = {}
-        cls.class_lookups[lookup.lookup_name] = lookup
-        return lookup
-
-    @classmethod
-    def _unregister_lookup(cls, lookup):
-        """
-        Removes given lookup from cls lookups. Meant to be used in
-        tests only.
-        """
-        del cls.class_lookups[lookup.lookup_name]
-
-
-class Transform(RegisterLookupMixin):
 
-    bilateral = False
-
-    def __init__(self, lhs, lookups):
-        self.lhs = lhs
-        self.init_lookups = lookups[:]
-
-    def as_sql(self, compiler, connection):
-        raise NotImplementedError
-
-    @cached_property
-    def output_field(self):
-        return self.lhs.output_field
-
-    def copy(self):
-        return copy(self)
-
-    def relabeled_clone(self, relabels):
-        copy = self.copy()
-        copy.lhs = self.lhs.relabeled_clone(relabels)
-        return copy
-
-    def get_group_by_cols(self):
-        return self.lhs.get_group_by_cols()
-
-    def get_bilateral_transforms(self):
-        if hasattr(self.lhs, 'get_bilateral_transforms'):
-            bilateral_transforms = self.lhs.get_bilateral_transforms()
-        else:
-            bilateral_transforms = []
-        if self.bilateral:
-            bilateral_transforms.append((self.__class__, self.init_lookups))
-        return bilateral_transforms
-
-    @cached_property
-    def contains_aggregate(self):
-        return self.lhs.contains_aggregate
-
-
-class Lookup(RegisterLookupMixin):
+class Lookup(object):
     lookup_name = None
 
     def __init__(self, lhs, rhs):
@@ -115,8 +31,8 @@ class Lookup(RegisterLookupMixin):
         self.bilateral_transforms = bilateral_transforms
 
     def apply_bilateral_transforms(self, value):
-        for transform, lookups in self.bilateral_transforms:
-            value = transform(value, lookups)
+        for transform in self.bilateral_transforms:
+            value = transform(value)
         return value
 
     def batch_process_rhs(self, compiler, connection, rhs=None):
@@ -125,9 +41,9 @@ class Lookup(RegisterLookupMixin):
         if self.bilateral_transforms:
             sqls, sqls_params = [], []
             for p in rhs:
-                value = QueryWrapper('%s',
-                    [self.lhs.output_field.get_db_prep_value(p, connection)])
+                value = Value(p, output_field=self.lhs.output_field)
                 value = self.apply_bilateral_transforms(value)
+                value = value.resolve_expression(compiler.query)
                 sql, sql_params = compiler.compile(value)
                 sqls.append(sql)
                 sqls_params.extend(sql_params)
@@ -155,9 +71,9 @@ class Lookup(RegisterLookupMixin):
             if self.rhs_is_direct_value():
                 # Do not call get_db_prep_lookup here as the value will be
                 # transformed before being used for lookup
-                value = QueryWrapper("%s",
-                    [self.lhs.output_field.get_db_prep_value(value, connection)])
+                value = Value(value, output_field=self.lhs.output_field)
             value = self.apply_bilateral_transforms(value)
+            value = value.resolve_expression(compiler.query)
         # Due to historical reasons there are a couple of different
         # ways to produce sql here. get_compiler is likely a Query
         # instance, _as_sql QuerySet and as_sql just something with
@@ -201,6 +117,31 @@ class Lookup(RegisterLookupMixin):
         return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
 
 
+class Transform(RegisterLookupMixin, Func):
+    """
+    RegisterLookupMixin() is first so that get_lookup() and get_transform()
+    first examine self and then check output_field.
+    """
+    bilateral = False
+
+    def __init__(self, expression, **extra):
+        # Restrict Transform to allow only a single expression.
+        super(Transform, self).__init__(expression, **extra)
+
+    @property
+    def lhs(self):
+        return self.get_source_expressions()[0]
+
+    def get_bilateral_transforms(self):
+        if hasattr(self.lhs, 'get_bilateral_transforms'):
+            bilateral_transforms = self.lhs.get_bilateral_transforms()
+        else:
+            bilateral_transforms = []
+        if self.bilateral:
+            bilateral_transforms.append(self.__class__)
+        return bilateral_transforms
+
+
 class BuiltinLookup(Lookup):
     def process_lhs(self, compiler, connection, lhs=None):
         lhs_sql, params = super(BuiltinLookup, self).process_lhs(
@@ -223,12 +164,9 @@ class BuiltinLookup(Lookup):
         return connection.operators[self.lookup_name] % rhs
 
 
-default_lookups = {}
-
-
 class Exact(BuiltinLookup):
     lookup_name = 'exact'
-default_lookups['exact'] = Exact
+Field.register_lookup(Exact)
 
 
 class IExact(BuiltinLookup):
@@ -241,27 +179,27 @@ class IExact(BuiltinLookup):
         return rhs, params
 
 
-default_lookups['iexact'] = IExact
+Field.register_lookup(IExact)
 
 
 class GreaterThan(BuiltinLookup):
     lookup_name = 'gt'
-default_lookups['gt'] = GreaterThan
+Field.register_lookup(GreaterThan)
 
 
 class GreaterThanOrEqual(BuiltinLookup):
     lookup_name = 'gte'
-default_lookups['gte'] = GreaterThanOrEqual
+Field.register_lookup(GreaterThanOrEqual)
 
 
 class LessThan(BuiltinLookup):
     lookup_name = 'lt'
-default_lookups['lt'] = LessThan
+Field.register_lookup(LessThan)
 
 
 class LessThanOrEqual(BuiltinLookup):
     lookup_name = 'lte'
-default_lookups['lte'] = LessThanOrEqual
+Field.register_lookup(LessThanOrEqual)
 
 
 class In(BuiltinLookup):
@@ -286,32 +224,32 @@ class In(BuiltinLookup):
 
     def as_sql(self, compiler, connection):
         max_in_list_size = connection.ops.max_in_list_size()
-        if self.rhs_is_direct_value() and (max_in_list_size and
-                                           len(self.rhs) > max_in_list_size):
-            # This is a special case for Oracle which limits the number of elements
-            # which can appear in an 'IN' clause.
-            lhs, lhs_params = self.process_lhs(compiler, connection)
-            rhs, rhs_params = self.batch_process_rhs(compiler, connection)
-            in_clause_elements = ['(']
-            params = []
-            for offset in range(0, len(rhs_params), max_in_list_size):
-                if offset > 0:
-                    in_clause_elements.append(' OR ')
-                in_clause_elements.append('%s IN (' % lhs)
-                params.extend(lhs_params)
-                sqls = rhs[offset: offset + max_in_list_size]
-                sqls_params = rhs_params[offset: offset + max_in_list_size]
-                param_group = ', '.join(sqls)
-                in_clause_elements.append(param_group)
-                in_clause_elements.append(')')
-                params.extend(sqls_params)
-            in_clause_elements.append(')')
-            return ''.join(in_clause_elements), params
-        else:
-            return super(In, self).as_sql(compiler, connection)
-
+        if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
+            return self.split_parameter_list_as_sql(compiler, connection)
+        return super(In, self).as_sql(compiler, connection)
 
-default_lookups['in'] = In
+    def split_parameter_list_as_sql(self, compiler, connection):
+        # This is a special case for databases which limit the number of
+        # elements which can appear in an 'IN' clause.
+        max_in_list_size = connection.ops.max_in_list_size()
+        lhs, lhs_params = self.process_lhs(compiler, connection)
+        rhs, rhs_params = self.batch_process_rhs(compiler, connection)
+        in_clause_elements = ['(']
+        params = []
+        for offset in range(0, len(rhs_params), max_in_list_size):
+            if offset > 0:
+                in_clause_elements.append(' OR ')
+            in_clause_elements.append('%s IN (' % lhs)
+            params.extend(lhs_params)
+            sqls = rhs[offset: offset + max_in_list_size]
+            sqls_params = rhs_params[offset: offset + max_in_list_size]
+            param_group = ', '.join(sqls)
+            in_clause_elements.append(param_group)
+            in_clause_elements.append(')')
+            params.extend(sqls_params)
+        in_clause_elements.append(')')
+        return ''.join(in_clause_elements), params
+Field.register_lookup(In)
 
 
 class PatternLookup(BuiltinLookup):
@@ -342,16 +280,12 @@ class Contains(PatternLookup):
         if params and not self.bilateral_transforms:
             params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0])
         return rhs, params
-
-
-default_lookups['contains'] = Contains
+Field.register_lookup(Contains)
 
 
 class IContains(Contains):
     lookup_name = 'icontains'
-
-
-default_lookups['icontains'] = IContains
+Field.register_lookup(IContains)
 
 
 class StartsWith(PatternLookup):
@@ -362,9 +296,7 @@ class StartsWith(PatternLookup):
         if params and not self.bilateral_transforms:
             params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
         return rhs, params
-
-
-default_lookups['startswith'] = StartsWith
+Field.register_lookup(StartsWith)
 
 
 class IStartsWith(PatternLookup):
@@ -375,9 +307,7 @@ class IStartsWith(PatternLookup):
         if params and not self.bilateral_transforms:
             params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
         return rhs, params
-
-
-default_lookups['istartswith'] = IStartsWith
+Field.register_lookup(IStartsWith)
 
 
 class EndsWith(PatternLookup):
@@ -388,9 +318,7 @@ class EndsWith(PatternLookup):
         if params and not self.bilateral_transforms:
             params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
         return rhs, params
-
-
-default_lookups['endswith'] = EndsWith
+Field.register_lookup(EndsWith)
 
 
 class IEndsWith(PatternLookup):
@@ -401,9 +329,7 @@ class IEndsWith(PatternLookup):
         if params and not self.bilateral_transforms:
             params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
         return rhs, params
-
-
-default_lookups['iendswith'] = IEndsWith
+Field.register_lookup(IEndsWith)
 
 
 class Between(BuiltinLookup):
@@ -424,8 +350,7 @@ class Range(BuiltinLookup):
             return self.batch_process_rhs(compiler, connection)
         else:
             return super(Range, self).process_rhs(compiler, connection)
-
-default_lookups['range'] = Range
+Field.register_lookup(Range)
 
 
 class IsNull(BuiltinLookup):
@@ -437,7 +362,7 @@ class IsNull(BuiltinLookup):
             return "%s IS NULL" % sql, params
         else:
             return "%s IS NOT NULL" % sql, params
-default_lookups['isnull'] = IsNull
+Field.register_lookup(IsNull)
 
 
 class Search(BuiltinLookup):
@@ -448,8 +373,7 @@ class Search(BuiltinLookup):
         rhs, rhs_params = self.process_rhs(compiler, connection)
         sql_template = connection.ops.fulltext_search_sql(field_name=lhs)
         return sql_template, lhs_params + rhs_params
-
-default_lookups['search'] = Search
+Field.register_lookup(Search)
 
 
 class Regex(BuiltinLookup):
@@ -463,9 +387,168 @@ class Regex(BuiltinLookup):
             rhs, rhs_params = self.process_rhs(compiler, connection)
             sql_template = connection.ops.regex_lookup(self.lookup_name)
             return sql_template % (lhs, rhs), lhs_params + rhs_params
-default_lookups['regex'] = Regex
+Field.register_lookup(Regex)
 
 
 class IRegex(Regex):
     lookup_name = 'iregex'
-default_lookups['iregex'] = IRegex
+Field.register_lookup(IRegex)
+
+
+class DateTimeDateTransform(Transform):
+    lookup_name = 'date'
+
+    @cached_property
+    def output_field(self):
+        return DateField()
+
+    def as_sql(self, compiler, connection):
+        lhs, lhs_params = compiler.compile(self.lhs)
+        tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
+        sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
+        lhs_params.extend(tz_params)
+        return sql, lhs_params
+
+
+class DateTransform(Transform):
+    def as_sql(self, compiler, connection):
+        sql, params = compiler.compile(self.lhs)
+        lhs_output_field = self.lhs.output_field
+        if isinstance(lhs_output_field, DateTimeField):
+            tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
+            sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
+            params.extend(tz_params)
+        elif isinstance(lhs_output_field, DateField):
+            sql = connection.ops.date_extract_sql(self.lookup_name, sql)
+        elif isinstance(lhs_output_field, TimeField):
+            sql = connection.ops.time_extract_sql(self.lookup_name, sql)
+        else:
+            raise ValueError('DateTransform only valid on Date/Time/DateTimeFields')
+        return sql, params
+
+    @cached_property
+    def output_field(self):
+        return IntegerField()
+
+
+class YearTransform(DateTransform):
+    lookup_name = 'year'
+
+
+class YearLookup(Lookup):
+    def year_lookup_bounds(self, connection, year):
+        output_field = self.lhs.lhs.output_field
+        if isinstance(output_field, DateTimeField):
+            bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
+        else:
+            bounds = connection.ops.year_lookup_bounds_for_date_field(year)
+        return bounds
+
+
+@YearTransform.register_lookup
+class YearExact(YearLookup):
+    lookup_name = 'exact'
+
+    def as_sql(self, compiler, connection):
+        # We will need to skip the extract part and instead go
+        # directly with the originating field, that is self.lhs.lhs.
+        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
+        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
+        bounds = self.year_lookup_bounds(connection, rhs_params[0])
+        params.extend(bounds)
+        return '%s BETWEEN %%s AND %%s' % lhs_sql, params
+
+
+class YearComparisonLookup(YearLookup):
+    def as_sql(self, compiler, connection):
+        # We will need to skip the extract part and instead go
+        # directly with the originating field, that is self.lhs.lhs.
+        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
+        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
+        rhs_sql = self.get_rhs_op(connection, rhs_sql)
+        start, finish = self.year_lookup_bounds(connection, rhs_params[0])
+        params.append(self.get_bound(start, finish))
+        return '%s %s' % (lhs_sql, rhs_sql), params
+
+    def get_rhs_op(self, connection, rhs):
+        return connection.operators[self.lookup_name] % rhs
+
+    def get_bound(self):
+        raise NotImplementedError(
+            'subclasses of YearComparisonLookup must provide a get_bound() method'
+        )
+
+
+@YearTransform.register_lookup
+class YearGt(YearComparisonLookup):
+    lookup_name = 'gt'
+
+    def get_bound(self, start, finish):
+        return finish
+
+
+@YearTransform.register_lookup
+class YearGte(YearComparisonLookup):
+    lookup_name = 'gte'
+
+    def get_bound(self, start, finish):
+        return start
+
+
+@YearTransform.register_lookup
+class YearLt(YearComparisonLookup):
+    lookup_name = 'lt'
+
+    def get_bound(self, start, finish):
+        return start
+
+
+@YearTransform.register_lookup
+class YearLte(YearComparisonLookup):
+    lookup_name = 'lte'
+
+    def get_bound(self, start, finish):
+        return finish
+
+
+class MonthTransform(DateTransform):
+    lookup_name = 'month'
+
+
+class DayTransform(DateTransform):
+    lookup_name = 'day'
+
+
+class WeekDayTransform(DateTransform):
+    lookup_name = 'week_day'
+
+
+class HourTransform(DateTransform):
+    lookup_name = 'hour'
+
+
+class MinuteTransform(DateTransform):
+    lookup_name = 'minute'
+
+
+class SecondTransform(DateTransform):
+    lookup_name = 'second'
+
+
+DateField.register_lookup(YearTransform)
+DateField.register_lookup(MonthTransform)
+DateField.register_lookup(DayTransform)
+DateField.register_lookup(WeekDayTransform)
+
+TimeField.register_lookup(HourTransform)
+TimeField.register_lookup(MinuteTransform)
+TimeField.register_lookup(SecondTransform)
+
+DateTimeField.register_lookup(DateTimeDateTransform)
+DateTimeField.register_lookup(YearTransform)
+DateTimeField.register_lookup(MonthTransform)
+DateTimeField.register_lookup(DayTransform)
+DateTimeField.register_lookup(WeekDayTransform)
+DateTimeField.register_lookup(HourTransform)
+DateTimeField.register_lookup(MinuteTransform)
+DateTimeField.register_lookup(SecondTransform)

+ 55 - 0
django/db/models/query_utils.py

@@ -7,6 +7,7 @@ circular import difficulties.
 """
 from __future__ import unicode_literals
 
+import inspect
 from collections import namedtuple
 
 from django.apps import apps
@@ -169,6 +170,60 @@ class DeferredAttribute(object):
         return None
 
 
+class RegisterLookupMixin(object):
+    def _get_lookup(self, lookup_name):
+        try:
+            return self.class_lookups[lookup_name]
+        except KeyError:
+            # To allow for inheritance, check parent class' class_lookups.
+            for parent in inspect.getmro(self.__class__):
+                if 'class_lookups' not in parent.__dict__:
+                    continue
+                if lookup_name in parent.class_lookups:
+                    return parent.class_lookups[lookup_name]
+        except AttributeError:
+            # This class didn't have any class_lookups
+            pass
+        return None
+
+    def get_lookup(self, lookup_name):
+        from django.db.models.lookups import Lookup
+        found = self._get_lookup(lookup_name)
+        if found is None and hasattr(self, 'output_field'):
+            return self.output_field.get_lookup(lookup_name)
+        if found is not None and not issubclass(found, Lookup):
+            return None
+        return found
+
+    def get_transform(self, lookup_name):
+        from django.db.models.lookups import Transform
+        found = self._get_lookup(lookup_name)
+        if found is None and hasattr(self, 'output_field'):
+            return self.output_field.get_transform(lookup_name)
+        if found is not None and not issubclass(found, Transform):
+            return None
+        return found
+
+    @classmethod
+    def register_lookup(cls, lookup, lookup_name=None):
+        if lookup_name is None:
+            lookup_name = lookup.lookup_name
+        if 'class_lookups' not in cls.__dict__:
+            cls.class_lookups = {}
+        cls.class_lookups[lookup_name] = lookup
+        return lookup
+
+    @classmethod
+    def _unregister_lookup(cls, lookup, lookup_name=None):
+        """
+        Remove given lookup from cls lookups. For use in tests only as it's
+        not thread-safe.
+        """
+        if lookup_name is None:
+            lookup_name = lookup.lookup_name
+        del cls.class_lookups[lookup_name]
+
+
 def select_related_descend(field, restricted, requested, load_fields, reverse=False):
     """
     Returns True if this field should be used to descend deeper for

+ 1 - 1
django/db/models/sql/aggregates.py

@@ -5,7 +5,7 @@ import copy
 import warnings
 
 from django.db.models.fields import FloatField, IntegerField
-from django.db.models.lookups import RegisterLookupMixin
+from django.db.models.query_utils import RegisterLookupMixin
 from django.utils.deprecation import RemovedInDjango110Warning
 from django.utils.functional import cached_property
 

+ 3 - 3
django/db/models/sql/query.py

@@ -1105,9 +1105,9 @@ class Query(object):
         Helper method for build_lookup. Tries to fetch and initialize
         a transform for name parameter from lhs.
         """
-        next = lhs.get_transform(name)
-        if next:
-            return next(lhs, rest_of_lookups)
+        transform_class = lhs.get_transform(name)
+        if transform_class:
+            return transform_class(lhs)
         else:
             raise FieldError(
                 "Unsupported lookup '%s' for %s or join on the field not "

+ 3 - 12
docs/howto/custom-lookups.txt

@@ -120,10 +120,7 @@ function ``ABS()`` to transform the value before comparison::
 
   class AbsoluteValue(Transform):
       lookup_name = 'abs'
-
-      def as_sql(self, compiler, connection):
-          lhs, params = compiler.compile(self.lhs)
-          return "ABS(%s)" % lhs, params
+      function = 'ABS'
 
 Next, let's register it for ``IntegerField``::
 
@@ -157,10 +154,7 @@ be done by adding an ``output_field`` attribute to the transform::
 
     class AbsoluteValue(Transform):
         lookup_name = 'abs'
-
-        def as_sql(self, compiler, connection):
-            lhs, params = compiler.compile(self.lhs)
-            return "ABS(%s)" % lhs, params
+        function = 'ABS'
 
         @property
         def output_field(self):
@@ -243,12 +237,9 @@ this transformation should apply to both ``lhs`` and ``rhs``::
 
   class UpperCase(Transform):
       lookup_name = 'upper'
+      function = 'UPPER'
       bilateral = True
 
-      def as_sql(self, compiler, connection):
-          lhs, params = compiler.compile(self.lhs)
-          return "UPPER(%s)" % lhs, params
-
 Next, let's register it::
 
   from django.db.models import CharField, TextField

+ 24 - 0
docs/ref/models/database-functions.txt

@@ -180,6 +180,18 @@ Usage example::
     >>> print(author.name_length, author.goes_by_length)
     (14, None)
 
+It can also be registered as a transform. For example::
+
+    >>> from django.db.models import CharField
+    >>> from django.db.models.functions import Length
+    >>> CharField.register_lookup(Length, 'length')
+    >>> # Get authors whose name is longer than 7 characters
+    >>> authors = Author.objects.filter(name__length__gt=7)
+
+.. versionchanged:: 1.9
+
+    The ability to register the function as a transform was added.
+
 Lower
 ------
 
@@ -188,6 +200,8 @@ Lower
 Accepts a single text field or expression and returns the lowercase
 representation.
 
+It can also be registered as a transform as described in :class:`Length`.
+
 Usage example::
 
     >>> from django.db.models.functions import Lower
@@ -196,6 +210,10 @@ Usage example::
     >>> print(author.name_lower)
     margaret smith
 
+.. versionchanged:: 1.9
+
+    The ability to register the function as a transform was added.
+
 Now
 ---
 
@@ -246,6 +264,8 @@ Upper
 Accepts a single text field or expression and returns the uppercase
 representation.
 
+It can also be registered as a transform as described in :class:`Length`.
+
 Usage example::
 
     >>> from django.db.models.functions import Upper
@@ -253,3 +273,7 @@ Usage example::
     >>> author = Author.objects.annotate(name_upper=Upper('name')).get()
     >>> print(author.name_upper)
     MARGARET SMITH
+
+.. versionchanged:: 1.9
+
+    The ability to register the function as a transform was added.

+ 15 - 15
docs/ref/models/lookups.txt

@@ -42,12 +42,17 @@ register lookups on itself. The two prominent examples are
 
     A mixin that implements the lookup API on a class.
 
-    .. classmethod:: register_lookup(lookup)
+    .. classmethod:: register_lookup(lookup, lookup_name=None)
 
         Registers a new lookup in the class. For example
         ``DateField.register_lookup(YearExact)`` will register ``YearExact``
         lookup on ``DateField``. It overrides a lookup that already exists with
-        the same name.
+        the same name. ``lookup_name`` will be used for this lookup if
+        provided, otherwise ``lookup.lookup_name`` will be used.
+
+        .. versionchanged:: 1.9
+
+            The ``lookup_name`` parameter was added.
 
     .. method:: get_lookup(lookup_name)
 
@@ -125,7 +130,14 @@ Transform reference
     ``<expression>__<transformation>`` (e.g. ``date__year``).
 
     This class follows the :ref:`Query Expression API <query-expression>`, which
-    implies that you can use ``<expression>__<transform1>__<transform2>``.
+    implies that you can use ``<expression>__<transform1>__<transform2>``. It's
+    a specialized :ref:`Func() expression <func-expressions>` that only accepts
+    one argument.  It can also be used on the right hand side of a filter or
+    directly as an annotation.
+
+    .. versionchanged:: 1.9
+
+        ``Transform`` is now a subclass of ``Func``.
 
     .. attribute:: bilateral
 
@@ -152,18 +164,6 @@ Transform reference
         :class:`~django.db.models.Field` instance. By default is the same as
         its ``lhs.output_field``.
 
-    .. method:: as_sql
-
-        To be overridden; raises :exc:`NotImplementedError`.
-
-    .. method:: get_lookup(lookup_name)
-
-        Same as :meth:`~lookups.RegisterLookupMixin.get_lookup()`.
-
-    .. method:: get_transform(transform_name)
-
-        Same as :meth:`~lookups.RegisterLookupMixin.get_transform()`.
-
 Lookup reference
 ~~~~~~~~~~~~~~~~
 

+ 8 - 0
docs/releases/1.9.txt

@@ -520,6 +520,14 @@ Models
 * Added the :class:`~django.db.models.functions.Now` database function, which
   returns the current date and time.
 
+* :class:`~django.db.models.Transform` is now a subclass of
+  :ref:`Func() <func-expressions>` which allows ``Transform``\s to be used on
+  the right hand side of an expression, just like regular ``Func``\s. This
+  allows registering some database functions like
+  :class:`~django.db.models.functions.Length`,
+  :class:`~django.db.models.functions.Lower`, and
+  :class:`~django.db.models.functions.Upper` as transforms.
+
 * :class:`~django.db.models.SlugField` now accepts an
   :attr:`~django.db.models.SlugField.allow_unicode` argument to allow Unicode
   characters in slugs.

+ 63 - 8
tests/custom_lookups/tests.py

@@ -126,11 +126,17 @@ class YearLte(models.lookups.LessThanOrEqual):
         return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params
 
 
-class SQLFunc(models.Lookup):
-    def __init__(self, name, *args, **kwargs):
-        super(SQLFunc, self).__init__(*args, **kwargs)
-        self.name = name
+class Exactly(models.lookups.Exact):
+    """
+    This lookup is used to test lookup registration.
+    """
+    lookup_name = 'exactly'
 
+    def get_rhs_op(self, connection, rhs):
+        return connection.operators['exact'] % rhs
+
+
+class SQLFuncMixin(object):
     def as_sql(self, compiler, connection):
         return '%s()', [self.name]
 
@@ -139,13 +145,28 @@ class SQLFunc(models.Lookup):
         return CustomField()
 
 
+class SQLFuncLookup(SQLFuncMixin, models.Lookup):
+    def __init__(self, name, *args, **kwargs):
+        super(SQLFuncLookup, self).__init__(*args, **kwargs)
+        self.name = name
+
+
+class SQLFuncTransform(SQLFuncMixin, models.Transform):
+    def __init__(self, name, *args, **kwargs):
+        super(SQLFuncTransform, self).__init__(*args, **kwargs)
+        self.name = name
+
+
 class SQLFuncFactory(object):
 
-    def __init__(self, name):
+    def __init__(self, key, name):
+        self.key = key
         self.name = name
 
     def __call__(self, *args, **kwargs):
-        return SQLFunc(self.name, *args, **kwargs)
+        if self.key == 'lookupfunc':
+            return SQLFuncLookup(self.name, *args, **kwargs)
+        return SQLFuncTransform(self.name, *args, **kwargs)
 
 
 class CustomField(models.TextField):
@@ -153,13 +174,13 @@ class CustomField(models.TextField):
     def get_lookup(self, lookup_name):
         if lookup_name.startswith('lookupfunc_'):
             key, name = lookup_name.split('_', 1)
-            return SQLFuncFactory(name)
+            return SQLFuncFactory(key, name)
         return super(CustomField, self).get_lookup(lookup_name)
 
     def get_transform(self, lookup_name):
         if lookup_name.startswith('transformfunc_'):
             key, name = lookup_name.split('_', 1)
-            return SQLFuncFactory(name)
+            return SQLFuncFactory(key, name)
         return super(CustomField, self).get_transform(lookup_name)
 
 
@@ -200,6 +221,27 @@ class DateTimeTransform(models.Transform):
 
 
 class LookupTests(TestCase):
+
+    def test_custom_name_lookup(self):
+        a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
+        Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
+        custom_lookup_name = 'isactually'
+        custom_transform_name = 'justtheyear'
+        try:
+            models.DateField.register_lookup(YearTransform)
+            models.DateField.register_lookup(YearTransform, custom_transform_name)
+            YearTransform.register_lookup(Exactly)
+            YearTransform.register_lookup(Exactly, custom_lookup_name)
+            qs1 = Author.objects.filter(birthdate__testyear__exactly=1981)
+            qs2 = Author.objects.filter(birthdate__justtheyear__isactually=1981)
+            self.assertQuerysetEqual(qs1, [a1], lambda x: x)
+            self.assertQuerysetEqual(qs2, [a1], lambda x: x)
+        finally:
+            YearTransform._unregister_lookup(Exactly)
+            YearTransform._unregister_lookup(Exactly, custom_lookup_name)
+            models.DateField._unregister_lookup(YearTransform)
+            models.DateField._unregister_lookup(YearTransform, custom_transform_name)
+
     def test_basic_lookup(self):
         a1 = Author.objects.create(name='a1', age=1)
         a2 = Author.objects.create(name='a2', age=2)
@@ -299,6 +341,19 @@ class BilateralTransformTests(TestCase):
             with self.assertRaises(NotImplementedError):
                 Author.objects.filter(name__upper__in=Author.objects.values_list('name'))
 
+    def test_bilateral_multi_value(self):
+        with register_lookup(models.CharField, UpperBilateralTransform):
+            Author.objects.bulk_create([
+                Author(name='Foo'),
+                Author(name='Bar'),
+                Author(name='Ray'),
+            ])
+            self.assertQuerysetEqual(
+                Author.objects.filter(name__upper__in=['foo', 'bar', 'doe']).order_by('name'),
+                ['Bar', 'Foo'],
+                lambda a: a.name
+            )
+
     def test_div3_bilateral_extract(self):
         with register_lookup(models.IntegerField, Div3BilateralTransform):
             a1 = Author.objects.create(name='a1', age=1)

+ 94 - 0
tests/db_functions/tests.py

@@ -547,3 +547,97 @@ class FunctionTests(TestCase):
             ['How to Time Travel'],
             lambda a: a.title
         )
+
+    def test_length_transform(self):
+        try:
+            CharField.register_lookup(Length, 'length')
+            Author.objects.create(name='John Smith', alias='smithj')
+            Author.objects.create(name='Rhonda')
+            authors = Author.objects.filter(name__length__gt=7)
+            self.assertQuerysetEqual(
+                authors.order_by('name'), [
+                    'John Smith',
+                ],
+                lambda a: a.name
+            )
+        finally:
+            CharField._unregister_lookup(Length, 'length')
+
+    def test_lower_transform(self):
+        try:
+            CharField.register_lookup(Lower, 'lower')
+            Author.objects.create(name='John Smith', alias='smithj')
+            Author.objects.create(name='Rhonda')
+            authors = Author.objects.filter(name__lower__exact='john smith')
+            self.assertQuerysetEqual(
+                authors.order_by('name'), [
+                    'John Smith',
+                ],
+                lambda a: a.name
+            )
+        finally:
+            CharField._unregister_lookup(Lower, 'lower')
+
+    def test_upper_transform(self):
+        try:
+            CharField.register_lookup(Upper, 'upper')
+            Author.objects.create(name='John Smith', alias='smithj')
+            Author.objects.create(name='Rhonda')
+            authors = Author.objects.filter(name__upper__exact='JOHN SMITH')
+            self.assertQuerysetEqual(
+                authors.order_by('name'), [
+                    'John Smith',
+                ],
+                lambda a: a.name
+            )
+        finally:
+            CharField._unregister_lookup(Upper, 'upper')
+
+    def test_func_transform_bilateral(self):
+        class UpperBilateral(Upper):
+            bilateral = True
+
+        try:
+            CharField.register_lookup(UpperBilateral, 'upper')
+            Author.objects.create(name='John Smith', alias='smithj')
+            Author.objects.create(name='Rhonda')
+            authors = Author.objects.filter(name__upper__exact='john smith')
+            self.assertQuerysetEqual(
+                authors.order_by('name'), [
+                    'John Smith',
+                ],
+                lambda a: a.name
+            )
+        finally:
+            CharField._unregister_lookup(UpperBilateral, 'upper')
+
+    def test_func_transform_bilateral_multivalue(self):
+        class UpperBilateral(Upper):
+            bilateral = True
+
+        try:
+            CharField.register_lookup(UpperBilateral, 'upper')
+            Author.objects.create(name='John Smith', alias='smithj')
+            Author.objects.create(name='Rhonda')
+            authors = Author.objects.filter(name__upper__in=['john smith', 'rhonda'])
+            self.assertQuerysetEqual(
+                authors.order_by('name'), [
+                    'John Smith',
+                    'Rhonda',
+                ],
+                lambda a: a.name
+            )
+        finally:
+            CharField._unregister_lookup(UpperBilateral, 'upper')
+
+    def test_function_as_filter(self):
+        Author.objects.create(name='John Smith', alias='SMITHJ')
+        Author.objects.create(name='Rhonda')
+        self.assertQuerysetEqual(
+            Author.objects.filter(alias=Upper(V('smithj'))),
+            ['John Smith'], lambda x: x.name
+        )
+        self.assertQuerysetEqual(
+            Author.objects.exclude(alias=Upper(V('smithj'))),
+            ['Rhonda'], lambda x: x.name
+        )