Browse Source

Fixed #24747 -- Allowed transforms in QuerySet.order_by() and distinct(*fields).

Matthew Wilkes 7 years ago
parent
commit
2162f0983d

+ 2 - 2
django/db/backends/base/operations.py

@@ -158,7 +158,7 @@ class BaseDatabaseOperations:
         """
         return ''
 
-    def distinct_sql(self, fields):
+    def distinct_sql(self, fields, params):
         """
         Return an SQL DISTINCT clause which removes duplicate rows from the
         result set. If any fields are given, only check the given fields for
@@ -167,7 +167,7 @@ class BaseDatabaseOperations:
         if fields:
             raise NotSupportedError('DISTINCT ON fields is not supported by this database backend')
         else:
-            return 'DISTINCT'
+            return ['DISTINCT'], []
 
     def fetch_returned_insert_id(self, cursor):
         """

+ 4 - 3
django/db/backends/postgresql/operations.py

@@ -207,11 +207,12 @@ class DatabaseOperations(BaseDatabaseOperations):
         """
         return 63
 
-    def distinct_sql(self, fields):
+    def distinct_sql(self, fields, params):
         if fields:
-            return 'DISTINCT ON (%s)' % ', '.join(fields)
+            params = [param for param_list in params for param in param_list]
+            return (['DISTINCT ON (%s)' % ', '.join(fields)], params)
         else:
-            return 'DISTINCT'
+            return ['DISTINCT'], []
 
     def last_executed_query(self, cursor, sql, params):
         # http://initd.org/psycopg/docs/cursor.html#cursor.query

+ 19 - 14
django/db/models/sql/compiler.py

@@ -451,7 +451,7 @@ class SQLCompiler:
                     raise NotSupportedError('{} is not supported on this database backend.'.format(combinator))
                 result, params = self.get_combinator_sql(combinator, self.query.combinator_all)
             else:
-                distinct_fields = self.get_distinct()
+                distinct_fields, distinct_params = self.get_distinct()
                 # This must come after 'select', 'ordering', and 'distinct'
                 # (see docstring of get_from_clause() for details).
                 from_, f_params = self.get_from_clause()
@@ -461,7 +461,12 @@ class SQLCompiler:
                 params = []
 
                 if self.query.distinct:
-                    result.append(self.connection.ops.distinct_sql(distinct_fields))
+                    distinct_result, distinct_params = self.connection.ops.distinct_sql(
+                        distinct_fields,
+                        distinct_params,
+                    )
+                    result += distinct_result
+                    params += distinct_params
 
                 out_cols = []
                 col_idx = 1
@@ -621,21 +626,22 @@ class SQLCompiler:
         This method can alter the tables in the query, and thus it must be
         called before get_from_clause().
         """
-        qn = self.quote_name_unless_alias
-        qn2 = self.connection.ops.quote_name
         result = []
+        params = []
         opts = self.query.get_meta()
 
         for name in self.query.distinct_fields:
             parts = name.split(LOOKUP_SEP)
-            _, targets, alias, joins, path, _ = self._setup_joins(parts, opts, None)
+            _, targets, alias, joins, path, _, transform_function = self._setup_joins(parts, opts, None)
             targets, alias, _ = self.query.trim_joins(targets, joins, path)
             for target in targets:
                 if name in self.query.annotation_select:
                     result.append(name)
                 else:
-                    result.append("%s.%s" % (qn(alias), qn2(target.column)))
-        return result
+                    r, p = self.compile(transform_function(target, alias))
+                    result.append(r)
+                    params.append(p)
+        return result, params
 
     def find_ordering_name(self, name, opts, alias=None, default_order='ASC',
                            already_seen=None):
@@ -647,7 +653,7 @@ class SQLCompiler:
         name, order = get_order_dir(name, default_order)
         descending = order == 'DESC'
         pieces = name.split(LOOKUP_SEP)
-        field, targets, alias, joins, path, opts = self._setup_joins(pieces, opts, alias)
+        field, targets, alias, joins, path, opts, transform_function = self._setup_joins(pieces, opts, alias)
 
         # If we get to this point and the field is a relation to another model,
         # append the default ordering for that model unless the attribute name
@@ -666,7 +672,7 @@ class SQLCompiler:
                                                        order, already_seen))
             return results
         targets, alias, _ = self.query.trim_joins(targets, joins, path)
-        return [(OrderBy(t.get_col(alias), descending=descending), False) for t in targets]
+        return [(OrderBy(transform_function(t, alias), descending=descending), False) for t in targets]
 
     def _setup_joins(self, pieces, opts, alias):
         """
@@ -677,10 +683,9 @@ class SQLCompiler:
         match. Executing SQL where this is not true is an error.
         """
         alias = alias or self.query.get_initial_alias()
-        field, targets, opts, joins, path = self.query.setup_joins(
-            pieces, opts, alias)
+        field, targets, opts, joins, path, transform_function = self.query.setup_joins(pieces, opts, alias)
         alias = joins[-1]
-        return field, targets, alias, joins, path, opts
+        return field, targets, alias, joins, path, opts, transform_function
 
     def get_from_clause(self):
         """
@@ -786,7 +791,7 @@ class SQLCompiler:
             }
             related_klass_infos.append(klass_info)
             select_fields = []
-            _, _, _, joins, _ = self.query.setup_joins(
+            _, _, _, joins, _, _ = self.query.setup_joins(
                 [f.name], opts, root_alias)
             alias = joins[-1]
             columns = self.get_default_columns(start_alias=alias, opts=f.remote_field.model._meta)
@@ -843,7 +848,7 @@ class SQLCompiler:
                     break
                 if name in self.query._filtered_relations:
                     fields_found.add(name)
-                    f, _, join_opts, joins, _ = self.query.setup_joins([name], opts, root_alias)
+                    f, _, join_opts, joins, _, _ = self.query.setup_joins([name], opts, root_alias)
                     model = join_opts.model
                     alias = joins[-1]
                     from_parent = issubclass(model, opts.model) and model is not opts.model

+ 50 - 9
django/db/models/sql/query.py

@@ -6,6 +6,7 @@ themselves do not have to (and could be backed by things other than SQL
 databases). The abstraction barrier only works one way: this module has to know
 all about the internals of models in order to get the information it needs.
 """
+import functools
 from collections import Counter, OrderedDict, namedtuple
 from collections.abc import Iterator, Mapping
 from itertools import chain, count, product
@@ -18,6 +19,7 @@ from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
 from django.db.models.aggregates import Count
 from django.db.models.constants import LOOKUP_SEP
 from django.db.models.expressions import Col, Ref
+from django.db.models.fields import Field
 from django.db.models.fields.related_lookups import MultiColSource
 from django.db.models.lookups import Lookup
 from django.db.models.query_utils import (
@@ -56,7 +58,7 @@ def get_children_from_q(q):
 
 JoinInfo = namedtuple(
     'JoinInfo',
-    ('final_field', 'targets', 'opts', 'joins', 'path')
+    ('final_field', 'targets', 'opts', 'joins', 'path', 'transform_function')
 )
 
 
@@ -1429,8 +1431,11 @@ class Query:
         generate a MultiJoin exception.
 
         Return the final field involved in the joins, the target field (used
-        for any 'where' constraint), the final 'opts' value, the joins and the
-        field path travelled to generate the joins.
+        for any 'where' constraint), the final 'opts' value, the joins, the
+        field path traveled to generate the joins, and a transform function
+        that takes a field and alias and is equivalent to `field.get_col(alias)`
+        in the simple case but wraps field transforms if they were included in
+        names.
 
         The target field is the field containing the concrete value. Final
         field can be something different, for example foreign key pointing to
@@ -1439,10 +1444,46 @@ class Query:
         key field for example).
         """
         joins = [alias]
-        # First, generate the path for the names
-        path, final_field, targets, rest = self.names_to_path(
-            names, opts, allow_many, fail_on_missing=True)
-
+        # The transform can't be applied yet, as joins must be trimmed later.
+        # To avoid making every caller of this method look up transforms
+        # directly, compute transforms here and create a partial that
+        # converts fields to the appropriate wrapped version.
+
+        def final_transformer(field, alias):
+            return field.get_col(alias)
+
+        # Try resolving all the names as fields first. If there's an error,
+        # treat trailing names as lookups until a field can be resolved.
+        last_field_exception = None
+        for pivot in range(len(names), 0, -1):
+            try:
+                path, final_field, targets, rest = self.names_to_path(
+                    names[:pivot], opts, allow_many, fail_on_missing=True,
+                )
+            except FieldError as exc:
+                if pivot == 1:
+                    # The first item cannot be a lookup, so it's safe
+                    # to raise the field error here.
+                    raise
+                else:
+                    last_field_exception = exc
+            else:
+                # The transforms are the remaining items that couldn't be
+                # resolved into fields.
+                transforms = names[pivot:]
+                break
+        for name in transforms:
+            def transform(field, alias, *, name, previous):
+                try:
+                    wrapped = previous(field, alias)
+                    return self.try_transform(wrapped, name)
+                except FieldError:
+                    # FieldError is raised if the transform doesn't exist.
+                    if isinstance(final_field, Field) and last_field_exception:
+                        raise last_field_exception
+                    else:
+                        raise
+            final_transformer = functools.partial(transform, name=name, previous=final_transformer)
         # Then, add the path to the query's joins. Note that we can't trim
         # joins at this stage - we will need the information about join type
         # of the trimmed joins.
@@ -1470,7 +1511,7 @@ class Query:
             joins.append(alias)
             if filtered_relation:
                 filtered_relation.path = joins[:]
-        return JoinInfo(final_field, targets, opts, joins, path)
+        return JoinInfo(final_field, targets, opts, joins, path, final_transformer)
 
     def trim_joins(self, targets, joins, path):
         """
@@ -1683,7 +1724,7 @@ class Query:
                     join_info.path,
                 )
                 for target in targets:
-                    cols.append(target.get_col(final_alias))
+                    cols.append(join_info.transform_function(target, final_alias))
             if cols:
                 self.set_select(cols)
         except MultiJoin:

+ 15 - 0
docs/howto/custom-lookups.txt

@@ -138,6 +138,21 @@ SQL::
 Note that in case there is no other lookup specified, Django interprets
 ``change__abs=27`` as ``change__abs__exact=27``.
 
+This also allows the result to be used in ``ORDER BY`` and ``DISTINCT ON``
+clauses. For example ``Experiment.objects.order_by('change__abs')`` generates::
+
+    SELECT ... ORDER BY ABS("experiments"."change") ASC
+
+And on databases that support distinct on fields (such as PostgreSQL),
+``Experiment.objects.distinct('change__abs')`` generates::
+
+    SELECT ... DISTINCT ON (ABS("experiments"."change"))
+
+.. versionchanged:: 2.1
+
+    Ordering and distinct support as described in the last two paragraphs was
+    added.
+
 When looking for which lookups are allowable after the ``Transform`` has been
 applied, Django uses the ``output_field`` attribute. We didn't need to specify
 this here as it didn't change, but supposing we were applying ``AbsoluteValue``

+ 6 - 2
docs/ref/models/expressions.txt

@@ -64,10 +64,14 @@ Some examples
     # Aggregates can contain complex computations also
     Company.objects.annotate(num_offerings=Count(F('products') + F('services')))
 
-    # Expressions can also be used in order_by()
+    # Expressions can also be used in order_by(), either directly
     Company.objects.order_by(Length('name').asc())
     Company.objects.order_by(Length('name').desc())
-
+    # or using the double underscore lookup syntax.
+    from django.db.models import CharField
+    from django.db.models.functions import Length
+    CharField.register_lookup(Length)
+    Company.objects.order_by('name__length')
 
 Built-in Expressions
 ====================

+ 32 - 0
docs/ref/models/querysets.txt

@@ -535,6 +535,19 @@ The ``values()`` method also takes optional keyword arguments,
     >>> Blog.objects.values(lower_name=Lower('name'))
     <QuerySet [{'lower_name': 'beatles blog'}]>
 
+You can use built-in and :doc:`custom lookups </howto/custom-lookups>` in
+the field names passed to ``values()``. For example::
+
+    >>> from django.db.models import CharField
+    >>> from django.db.models.functions import Lower
+    >>> CharField.register_lookup(Lower, 'lower')
+    >>> Blog.objects.values('name__lower')
+    <QuerySet [{'name__lower': 'beatles blog'}]>
+
+.. versionchanged:: 2.1
+
+    Support for lookups was added.
+
 An aggregate within a ``values()`` clause is applied before other arguments
 within the same ``values()`` clause. If you need to group by another value,
 add it to an earlier ``values()`` clause instead. For example::
@@ -580,6 +593,25 @@ A few subtleties that are worth mentioning:
 * Calling :meth:`only()` and :meth:`defer()` after ``values()`` doesn't make
   sense, so doing so will raise a ``NotImplementedError``.
 
+* Combining transforms and aggregates requires the use of two :meth:`annotate`
+  calls, either explicitly or as keyword arguments to :meth:`values`. As above,
+  if the transform has been registered on the relevant field type the first
+  :meth:`annotate` can be omitted, thus the following examples are equivalent::
+
+    >>> from django.db.models import CharField, Count
+    >>> from django.db.models.functions import Lower
+    >>> CharField.register_lookup(Lower, 'lower')
+    >>> Blog.objects.values('entry__authors__name__lower').annotate(entries=Count('entry'))
+    <QuerySet [{'entry__authors__name__lower': 'test author', 'entries': 33}]>
+    >>> Blog.objects.values(
+    ...     entry__authors__name__lower=Lower('entry__authors__name')
+    ... ).annotate(entries=Count('entry'))
+    <QuerySet [{'entry__authors__name__lower': 'test author', 'entries': 33}]>
+    >>> Blog.objects.annotate(
+    ...     entry__authors__name__lower=Lower('entry__authors__name')
+    ... ).values('entry__authors__name__lower').annotate(entries=Count('entry'))
+    <QuerySet [{'entry__authors__name__lower': 'test author', 'entries': 33}]>
+
 It is useful when you know you're only going to need values from a small number
 of the available fields and you won't need the functionality of a model
 instance object. It's more efficient to select only the fields you need to use.

+ 6 - 0
docs/releases/2.1.txt

@@ -187,6 +187,9 @@ Models
 
 * Query expressions can now be negated using a minus sign.
 
+* :meth:`.QuerySet.order_by` and :meth:`distinct(*fields) <.QuerySet.distinct>`
+  now support using field transforms.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 
@@ -242,6 +245,9 @@ Database backend API
 * Renamed the ``allow_sliced_subqueries`` database feature flag to
   ``allow_sliced_subqueries_with_in``.
 
+* ``DatabaseOperations.distinct_sql()`` now requires an additional ``params``
+  argument and returns a tuple of SQL and parameters instead of a SQL string.
+
 :mod:`django.contrib.gis`
 -------------------------
 

+ 1 - 1
tests/backends/base/test_operations.py

@@ -17,7 +17,7 @@ class DatabaseOperationTests(SimpleTestCase):
     def test_distinct_on_fields(self):
         msg = 'DISTINCT ON fields is not supported by this database backend'
         with self.assertRaisesMessage(NotSupportedError, msg):
-            self.ops.distinct_sql(['a', 'b'])
+            self.ops.distinct_sql(['a', 'b'], None)
 
     def test_deferrable_sql(self):
         self.assertEqual(self.ops.deferrable_sql(), '')

+ 17 - 0
tests/custom_lookups/tests.py

@@ -63,6 +63,14 @@ class Mult3BilateralTransform(models.Transform):
         return '3 * (%s)' % lhs, lhs_params
 
 
+class LastDigitTransform(models.Transform):
+    lookup_name = 'lastdigit'
+
+    def as_sql(self, compiler, connection):
+        lhs, lhs_params = compiler.compile(self.lhs)
+        return 'SUBSTR(CAST(%s AS CHAR(2)), 2, 1)' % lhs, lhs_params
+
+
 class UpperBilateralTransform(models.Transform):
     bilateral = True
     lookup_name = 'upper'
@@ -379,6 +387,15 @@ class BilateralTransformTests(TestCase):
             self.assertSequenceEqual(baseqs.filter(age__mult3__div3=42), [a1, a2, a3, a4])
             self.assertSequenceEqual(baseqs.filter(age__div3__mult3=42), [a3])
 
+    def test_transform_order_by(self):
+        with register_lookup(models.IntegerField, LastDigitTransform):
+            a1 = Author.objects.create(name='a1', age=11)
+            a2 = Author.objects.create(name='a2', age=23)
+            a3 = Author.objects.create(name='a3', age=32)
+            a4 = Author.objects.create(name='a4', age=40)
+            qs = Author.objects.order_by('age__lastdigit')
+            self.assertSequenceEqual(qs, [a4, a1, a3, a2])
+
     def test_bilateral_fexpr(self):
         with register_lookup(models.IntegerField, Mult3BilateralTransform):
             a1 = Author.objects.create(name='a1', age=1, average_rating=3.2)

+ 22 - 8
tests/distinct_on_fields/tests.py

@@ -1,4 +1,5 @@
-from django.db.models import Max
+from django.db.models import CharField, Max
+from django.db.models.functions import Lower
 from django.test import TestCase, skipUnlessDBFeature
 
 from .models import Celebrity, Fan, Staff, StaffTag, Tag
@@ -8,19 +9,19 @@ from .models import Celebrity, Fan, Staff, StaffTag, Tag
 @skipUnlessDBFeature('supports_nullable_unique_constraints')
 class DistinctOnTests(TestCase):
     def setUp(self):
-        t1 = Tag.objects.create(name='t1')
-        Tag.objects.create(name='t2', parent=t1)
-        t3 = Tag.objects.create(name='t3', parent=t1)
-        Tag.objects.create(name='t4', parent=t3)
-        Tag.objects.create(name='t5', parent=t3)
+        self.t1 = Tag.objects.create(name='t1')
+        self.t2 = Tag.objects.create(name='t2', parent=self.t1)
+        self.t3 = Tag.objects.create(name='t3', parent=self.t1)
+        self.t4 = Tag.objects.create(name='t4', parent=self.t3)
+        self.t5 = Tag.objects.create(name='t5', parent=self.t3)
 
         self.p1_o1 = Staff.objects.create(id=1, name="p1", organisation="o1")
         self.p2_o1 = Staff.objects.create(id=2, name="p2", organisation="o1")
         self.p3_o1 = Staff.objects.create(id=3, name="p3", organisation="o1")
         self.p1_o2 = Staff.objects.create(id=4, name="p1", organisation="o2")
         self.p1_o1.coworkers.add(self.p2_o1, self.p3_o1)
-        StaffTag.objects.create(staff=self.p1_o1, tag=t1)
-        StaffTag.objects.create(staff=self.p1_o1, tag=t1)
+        StaffTag.objects.create(staff=self.p1_o1, tag=self.t1)
+        StaffTag.objects.create(staff=self.p1_o1, tag=self.t1)
 
         celeb1 = Celebrity.objects.create(name="c1")
         celeb2 = Celebrity.objects.create(name="c2")
@@ -95,6 +96,19 @@ class DistinctOnTests(TestCase):
         c2 = c1.distinct('pk')
         self.assertNotIn('OUTER JOIN', str(c2.query))
 
+    def test_transform(self):
+        new_name = self.t1.name.upper()
+        self.assertNotEqual(self.t1.name, new_name)
+        Tag.objects.create(name=new_name)
+        CharField.register_lookup(Lower)
+        try:
+            self.assertCountEqual(
+                Tag.objects.order_by().distinct('name__lower'),
+                [self.t1, self.t2, self.t3, self.t4, self.t5],
+            )
+        finally:
+            CharField._unregister_lookup(Lower)
+
     def test_distinct_not_implemented_checks(self):
         # distinct + annotate not allowed
         msg = 'annotate() + distinct(fields) is not implemented.'

+ 34 - 0
tests/expressions/tests.py

@@ -1363,6 +1363,40 @@ class ValueTests(TestCase):
             ExpressionList()
 
 
+class FieldTransformTests(TestCase):
+
+    @classmethod
+    def setUpTestData(cls):
+        cls.sday = sday = datetime.date(2010, 6, 25)
+        cls.stime = stime = datetime.datetime(2010, 6, 25, 12, 15, 30, 747000)
+        cls.ex1 = Experiment.objects.create(
+            name='Experiment 1',
+            assigned=sday,
+            completed=sday + datetime.timedelta(2),
+            estimated_time=datetime.timedelta(2),
+            start=stime,
+            end=stime + datetime.timedelta(2),
+        )
+
+    def test_month_aggregation(self):
+        self.assertEqual(
+            Experiment.objects.aggregate(month_count=Count('assigned__month')),
+            {'month_count': 1}
+        )
+
+    def test_transform_in_values(self):
+        self.assertQuerysetEqual(
+            Experiment.objects.values('assigned__month'),
+            ["{'assigned__month': 6}"]
+        )
+
+    def test_multiple_transforms_in_values(self):
+        self.assertQuerysetEqual(
+            Experiment.objects.values('end__date__month'),
+            ["{'end__date__month': 6}"]
+        )
+
+
 class ReprTests(TestCase):
 
     def test_expressions(self):

+ 16 - 0
tests/postgres_tests/test_array.py

@@ -309,6 +309,22 @@ class TestQuerying(PostgreSQLTestCase):
             self.objs[2:3]
         )
 
+    def test_order_by_slice(self):
+        more_objs = (
+            NullableIntegerArrayModel.objects.create(field=[1, 637]),
+            NullableIntegerArrayModel.objects.create(field=[2, 1]),
+            NullableIntegerArrayModel.objects.create(field=[3, -98123]),
+            NullableIntegerArrayModel.objects.create(field=[4, 2]),
+        )
+        self.assertSequenceEqual(
+            NullableIntegerArrayModel.objects.order_by('field__1'),
+            [
+                more_objs[2], more_objs[1], more_objs[3], self.objs[2],
+                self.objs[3], more_objs[0], self.objs[4], self.objs[1],
+                self.objs[0],
+            ]
+        )
+
     @unittest.expectedFailure
     def test_slice_nested(self):
         instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])

+ 12 - 0
tests/postgres_tests/test_hstore.py

@@ -148,6 +148,18 @@ class TestQuerying(HStoreTestCase):
             self.objs[:2]
         )
 
+    def test_order_by_field(self):
+        more_objs = (
+            HStoreModel.objects.create(field={'g': '637'}),
+            HStoreModel.objects.create(field={'g': '002'}),
+            HStoreModel.objects.create(field={'g': '042'}),
+            HStoreModel.objects.create(field={'g': '981'}),
+        )
+        self.assertSequenceEqual(
+            HStoreModel.objects.filter(field__has_key='g').order_by('field__g'),
+            [more_objs[1], more_objs[2], more_objs[0], more_objs[3]]
+        )
+
     def test_keys_contains(self):
         self.assertSequenceEqual(
             HStoreModel.objects.filter(field__keys__contains=['a']),

+ 25 - 0
tests/postgres_tests/test_json.py

@@ -141,6 +141,31 @@ class TestQuerying(PostgreSQLTestCase):
             [self.objs[0]]
         )
 
+    def test_ordering_by_transform(self):
+        objs = [
+            JSONModel.objects.create(field={'ord': 93, 'name': 'bar'}),
+            JSONModel.objects.create(field={'ord': 22.1, 'name': 'foo'}),
+            JSONModel.objects.create(field={'ord': -1, 'name': 'baz'}),
+            JSONModel.objects.create(field={'ord': 21.931902, 'name': 'spam'}),
+            JSONModel.objects.create(field={'ord': -100291029, 'name': 'eggs'}),
+        ]
+        query = JSONModel.objects.filter(field__name__isnull=False).order_by('field__ord')
+        self.assertSequenceEqual(query, [objs[4], objs[2], objs[3], objs[1], objs[0]])
+
+    def test_deep_values(self):
+        query = JSONModel.objects.values_list('field__k__l')
+        self.assertSequenceEqual(
+            query,
+            [
+                (None,), (None,), (None,), (None,), (None,), (None,),
+                (None,), (None,), ('m',), (None,), (None,), (None,),
+            ]
+        )
+
+    def test_deep_distinct(self):
+        query = JSONModel.objects.distinct('field__k__l').values_list('field__k__l')
+        self.assertSequenceEqual(query, [('m',), (None,)])
+
     def test_isnull_key(self):
         # key__isnull works the same as has_key='key'.
         self.assertSequenceEqual(