Browse Source

Fixed #27021 -- Allowed lookup expressions in annotations, aggregations, and QuerySet.filter().

Thanks Hannes Ljungberg and Simon Charette for reviews.

Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
Ian Foote 4 years ago
parent
commit
f42ccdd835

+ 3 - 3
django/db/backends/oracle/operations.py

@@ -6,7 +6,7 @@ from django.conf import settings
 from django.db import DatabaseError, NotSupportedError
 from django.db.backends.base.operations import BaseDatabaseOperations
 from django.db.backends.utils import strip_quotes, truncate_name
-from django.db.models import AutoField, Exists, ExpressionWrapper
+from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup
 from django.db.models.expressions import RawSQL
 from django.db.models.sql.where import WhereNode
 from django.utils import timezone
@@ -202,7 +202,7 @@ END;
         # Oracle stores empty strings as null. If the field accepts the empty
         # string, undo this to adhere to the Django convention of using
         # the empty string instead of null.
-        if expression.field.empty_strings_allowed:
+        if expression.output_field.empty_strings_allowed:
             converters.append(
                 self.convert_empty_bytes
                 if internal_type == 'BinaryField' else
@@ -639,7 +639,7 @@ END;
         Oracle supports only EXISTS(...) or filters in the WHERE clause, others
         must be compared with True.
         """
-        if isinstance(expression, (Exists, WhereNode)):
+        if isinstance(expression, (Exists, Lookup, WhereNode)):
             return True
         if isinstance(expression, ExpressionWrapper) and expression.conditional:
             return self.conditional_expression_supported_in_where_clause(expression.expression)

+ 3 - 3
django/db/models/expressions.py

@@ -1248,9 +1248,9 @@ class OrderBy(Expression):
         return (template % placeholders).rstrip(), params
 
     def as_oracle(self, compiler, connection):
-        # Oracle doesn't allow ORDER BY EXISTS() unless it's wrapped in
-        # a CASE WHEN.
-        if isinstance(self.expression, Exists):
+        # Oracle doesn't allow ORDER BY EXISTS() or filters unless it's wrapped
+        # in a CASE WHEN.
+        if connection.ops.conditional_expression_supported_in_where_clause(self.expression):
             copy = self.copy()
             copy.expression = Case(
                 When(self.expression, then=True),

+ 3 - 0
django/db/models/fields/related_lookups.py

@@ -22,6 +22,9 @@ class MultiColSource:
     def get_lookup(self, lookup):
         return self.output_field.get_lookup(lookup)
 
+    def resolve_expression(self, *args, **kwargs):
+        return self
+
 
 def get_normalized_value(value, lhs):
     from django.db.models import Model

+ 40 - 31
django/db/models/lookups.py

@@ -1,11 +1,10 @@
 import itertools
 import math
-from copy import copy
 
 from django.core.exceptions import EmptyResultSet
-from django.db.models.expressions import Case, Func, Value, When
+from django.db.models.expressions import Case, Expression, Func, Value, When
 from django.db.models.fields import (
-    CharField, DateTimeField, Field, IntegerField, UUIDField,
+    BooleanField, CharField, DateTimeField, Field, IntegerField, UUIDField,
 )
 from django.db.models.query_utils import RegisterLookupMixin
 from django.utils.datastructures import OrderedSet
@@ -13,7 +12,7 @@ from django.utils.functional import cached_property
 from django.utils.hashable import make_hashable
 
 
-class Lookup:
+class Lookup(Expression):
     lookup_name = None
     prepare_rhs = True
     can_use_none_as_rhs = False
@@ -21,6 +20,7 @@ class Lookup:
     def __init__(self, lhs, rhs):
         self.lhs, self.rhs = lhs, rhs
         self.rhs = self.get_prep_lookup()
+        self.lhs = self.get_prep_lhs()
         if hasattr(self.lhs, 'get_bilateral_transforms'):
             bilateral_transforms = self.lhs.get_bilateral_transforms()
         else:
@@ -72,12 +72,20 @@ class Lookup:
             self.lhs, self.rhs = new_exprs
 
     def get_prep_lookup(self):
-        if hasattr(self.rhs, 'resolve_expression'):
+        if not self.prepare_rhs or hasattr(self.rhs, 'resolve_expression'):
             return self.rhs
-        if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
-            return self.lhs.output_field.get_prep_value(self.rhs)
+        if hasattr(self.lhs, 'output_field'):
+            if hasattr(self.lhs.output_field, 'get_prep_value'):
+                return self.lhs.output_field.get_prep_value(self.rhs)
+        elif self.rhs_is_direct_value():
+            return Value(self.rhs)
         return self.rhs
 
+    def get_prep_lhs(self):
+        if hasattr(self.lhs, 'resolve_expression'):
+            return self.lhs
+        return Value(self.lhs)
+
     def get_db_prep_lookup(self, value, connection):
         return ('%s', [value])
 
@@ -85,7 +93,11 @@ class Lookup:
         lhs = lhs or self.lhs
         if hasattr(lhs, 'resolve_expression'):
             lhs = lhs.resolve_expression(compiler.query)
-        return compiler.compile(lhs)
+        sql, params = compiler.compile(lhs)
+        if isinstance(lhs, Lookup):
+            # Wrapped in parentheses to respect operator precedence.
+            sql = f'({sql})'
+        return sql, params
 
     def process_rhs(self, compiler, connection):
         value = self.rhs
@@ -110,22 +122,12 @@ class Lookup:
     def rhs_is_direct_value(self):
         return not hasattr(self.rhs, 'as_sql')
 
-    def relabeled_clone(self, relabels):
-        new = copy(self)
-        new.lhs = new.lhs.relabeled_clone(relabels)
-        if hasattr(new.rhs, 'relabeled_clone'):
-            new.rhs = new.rhs.relabeled_clone(relabels)
-        return new
-
     def get_group_by_cols(self, alias=None):
-        cols = self.lhs.get_group_by_cols()
-        if hasattr(self.rhs, 'get_group_by_cols'):
-            cols.extend(self.rhs.get_group_by_cols())
+        cols = []
+        for source in self.get_source_expressions():
+            cols.extend(source.get_group_by_cols())
         return cols
 
-    def as_sql(self, compiler, connection):
-        raise NotImplementedError
-
     def as_oracle(self, compiler, connection):
         # Oracle doesn't allow EXISTS() and filters to be compared to another
         # expression unless they're wrapped in a CASE WHEN.
@@ -140,16 +142,8 @@ class Lookup:
         return lookup.as_sql(compiler, connection)
 
     @cached_property
-    def contains_aggregate(self):
-        return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
-
-    @cached_property
-    def contains_over_clause(self):
-        return self.lhs.contains_over_clause or getattr(self.rhs, 'contains_over_clause', False)
-
-    @property
-    def is_summary(self):
-        return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False)
+    def output_field(self):
+        return BooleanField()
 
     @property
     def identity(self):
@@ -163,6 +157,21 @@ class Lookup:
     def __hash__(self):
         return hash(make_hashable(self.identity))
 
+    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
+        c = self.copy()
+        c.is_summary = summarize
+        c.lhs = self.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
+        c.rhs = self.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
+        return c
+
+    def select_format(self, compiler, sql, params):
+        # Wrap filters with a CASE WHEN expression if a database backend
+        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
+        # BY list.
+        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
+            sql = f'CASE WHEN {sql} THEN 1 ELSE 0 END'
+        return sql, params
+
 
 class Transform(RegisterLookupMixin, Func):
     """

+ 3 - 3
django/db/models/sql/query.py

@@ -1262,9 +1262,9 @@ class Query(BaseExpression):
         if hasattr(filter_expr, 'resolve_expression'):
             if not getattr(filter_expr, 'conditional', False):
                 raise TypeError('Cannot filter against a non-conditional expression.')
-            condition = self.build_lookup(
-                ['exact'], filter_expr.resolve_expression(self, allow_joins=allow_joins), True
-            )
+            condition = filter_expr.resolve_expression(self, allow_joins=allow_joins)
+            if not isinstance(condition, Lookup):
+                condition = self.build_lookup(['exact'], condition, True)
             clause = self.where_class()
             clause.add(condition, AND)
             return clause, []

+ 19 - 0
django/db/models/sql/where.py

@@ -208,6 +208,25 @@ class WhereNode(tree.Node):
         clone.resolved = True
         return clone
 
+    @cached_property
+    def output_field(self):
+        from django.db.models import BooleanField
+        return BooleanField()
+
+    def select_format(self, compiler, sql, params):
+        # Wrap filters with a CASE WHEN expression if a database backend
+        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
+        # BY list.
+        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
+            sql = f'CASE WHEN {sql} THEN 1 ELSE 0 END'
+        return sql, params
+
+    def get_db_converters(self, connection):
+        return self.output_field.get_db_converters(connection)
+
+    def get_lookup(self, lookup):
+        return self.output_field.get_lookup(lookup)
+
 
 class NothingNode:
     """A node that matches nothing."""

+ 11 - 0
docs/ref/models/conditional-expressions.txt

@@ -48,6 +48,10 @@ objects that have an ``output_field`` that is a
 :class:`~django.db.models.BooleanField`. The result is provided using the
 ``then`` keyword.
 
+.. versionchanged:: 4.0
+
+    Support for lookup expressions was added.
+
 Some examples::
 
     >>> from django.db.models import F, Q, When
@@ -68,6 +72,13 @@ Some examples::
     ...     account_type=OuterRef('account_type'),
     ... ).exclude(pk=OuterRef('pk')).values('pk')
     >>> When(Exists(non_unique_account_type), then=Value('non unique'))
+    >>> # Condition can be created using lookup expressions.
+    >>> from django.db.models.lookups import GreaterThan, LessThan
+    >>> When(
+    ...     GreaterThan(F('registered_on'), date(2014, 1, 1)) &
+    ...     LessThan(F('registered_on'), date(2015, 1, 1)),
+    ...     then='account_type',
+    ... )
 
 Keep in mind that each of these values can be an expression.
 

+ 8 - 0
docs/ref/models/expressions.txt

@@ -25,6 +25,7 @@ Some examples
 
     from django.db.models import Count, F, Value
     from django.db.models.functions import Length, Upper
+    from django.db.models.lookups import GreaterThan
 
     # Find companies that have more employees than chairs.
     Company.objects.filter(num_employees__gt=F('num_chairs'))
@@ -76,6 +77,13 @@ Some examples
         Exists(Employee.objects.filter(company=OuterRef('pk'), salary__gt=10))
     )
 
+    # Lookup expressions can also be used directly in filters
+    Company.objects.filter(GreaterThan(F('num_employees'), F('num_chairs')))
+    # or annotations.
+    Company.objects.annotate(
+        need_chairs=GreaterThan(F('num_employees'), F('num_chairs')),
+    )
+
 Built-in Expressions
 ====================
 

+ 16 - 6
docs/ref/models/lookups.txt

@@ -177,16 +177,21 @@ following methods:
     comparison between ``lhs`` and ``rhs`` such as ``lhs in rhs`` or
     ``lhs > rhs``.
 
-    The notation to use a lookup in an expression is
-    ``<lhs>__<lookup_name>=<rhs>``.
+    The primary notation to use a lookup in an expression is
+    ``<lhs>__<lookup_name>=<rhs>``. Lookups can also be used directly in
+    ``QuerySet`` filters::
 
-    This class acts as a query expression, but, since it has ``=<rhs>`` on its
-    construction, lookups must always be the end of a lookup expression.
+         Book.objects.filter(LessThan(F('word_count'), 7500))
+
+    …or annotations::
+
+         Book.objects.annotate(is_short_story=LessThan(F('word_count'), 7500))
 
     .. attribute:: lhs
 
-        The left-hand side - what is being looked up. The object must follow
-        the :ref:`Query Expression API <query-expression>`.
+        The left-hand side - what is being looked up. The object typically
+        follows the :ref:`Query Expression API <query-expression>`. It may also
+        be a plain value.
 
     .. attribute:: rhs
 
@@ -213,3 +218,8 @@ following methods:
     .. method:: process_rhs(compiler, connection)
 
         Behaves the same way as :meth:`process_lhs`, for the right-hand side.
+
+    .. versionchanged:: 4.0
+
+        Support for using lookups in ``QuerySet`` annotations, aggregations,
+        and directly in filters was added.

+ 3 - 0
docs/releases/4.0.txt

@@ -277,6 +277,9 @@ Models
 * The ``skip_locked`` argument of :meth:`.QuerySet.select_for_update()` is now
   allowed on MariaDB 10.6+.
 
+* :class:`~django.db.models.Lookup` expressions may now be used in ``QuerySet``
+  annotations, aggregations, and directly in filters.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 159 - 2
tests/lookup/tests.py

@@ -6,9 +6,13 @@ from operator import attrgetter
 from django.core.exceptions import FieldError
 from django.db import connection, models
 from django.db.models import (
-    BooleanField, Exists, ExpressionWrapper, F, Max, OuterRef, Q,
+    BooleanField, Case, Exists, ExpressionWrapper, F, Max, OuterRef, Q,
+    Subquery, Value, When,
+)
+from django.db.models.functions import Cast, Substr
+from django.db.models.lookups import (
+    Exact, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual,
 )
-from django.db.models.functions import Substr
 from django.test import TestCase, skipUnlessDBFeature
 from django.test.utils import isolate_apps
 
@@ -1020,3 +1024,156 @@ class LookupTests(TestCase):
             )),
             [stock_1, stock_2],
         )
+
+
+class LookupQueryingTests(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        cls.s1 = Season.objects.create(year=1942, gt=1942)
+        cls.s2 = Season.objects.create(year=1842, gt=1942)
+        cls.s3 = Season.objects.create(year=2042, gt=1942)
+
+    def test_annotate(self):
+        qs = Season.objects.annotate(equal=Exact(F('year'), 1942))
+        self.assertCountEqual(
+            qs.values_list('year', 'equal'),
+            ((1942, True), (1842, False), (2042, False)),
+        )
+
+    def test_alias(self):
+        qs = Season.objects.alias(greater=GreaterThan(F('year'), 1910))
+        self.assertCountEqual(qs.filter(greater=True), [self.s1, self.s3])
+
+    def test_annotate_value_greater_than_value(self):
+        qs = Season.objects.annotate(greater=GreaterThan(Value(40), Value(30)))
+        self.assertCountEqual(
+            qs.values_list('year', 'greater'),
+            ((1942, True), (1842, True), (2042, True)),
+        )
+
+    def test_annotate_field_greater_than_field(self):
+        qs = Season.objects.annotate(greater=GreaterThan(F('year'), F('gt')))
+        self.assertCountEqual(
+            qs.values_list('year', 'greater'),
+            ((1942, False), (1842, False), (2042, True)),
+        )
+
+    def test_annotate_field_greater_than_value(self):
+        qs = Season.objects.annotate(greater=GreaterThan(F('year'), Value(1930)))
+        self.assertCountEqual(
+            qs.values_list('year', 'greater'),
+            ((1942, True), (1842, False), (2042, True)),
+        )
+
+    def test_annotate_field_greater_than_literal(self):
+        qs = Season.objects.annotate(greater=GreaterThan(F('year'), 1930))
+        self.assertCountEqual(
+            qs.values_list('year', 'greater'),
+            ((1942, True), (1842, False), (2042, True)),
+        )
+
+    def test_annotate_literal_greater_than_field(self):
+        qs = Season.objects.annotate(greater=GreaterThan(1930, F('year')))
+        self.assertCountEqual(
+            qs.values_list('year', 'greater'),
+            ((1942, False), (1842, True), (2042, False)),
+        )
+
+    def test_annotate_less_than_float(self):
+        qs = Season.objects.annotate(lesser=LessThan(F('year'), 1942.1))
+        self.assertCountEqual(
+            qs.values_list('year', 'lesser'),
+            ((1942, True), (1842, True), (2042, False)),
+        )
+
+    def test_annotate_greater_than_or_equal(self):
+        qs = Season.objects.annotate(greater=GreaterThanOrEqual(F('year'), 1942))
+        self.assertCountEqual(
+            qs.values_list('year', 'greater'),
+            ((1942, True), (1842, False), (2042, True)),
+        )
+
+    def test_annotate_greater_than_or_equal_float(self):
+        qs = Season.objects.annotate(greater=GreaterThanOrEqual(F('year'), 1942.1))
+        self.assertCountEqual(
+            qs.values_list('year', 'greater'),
+            ((1942, False), (1842, False), (2042, True)),
+        )
+
+    def test_combined_lookups(self):
+        expression = Exact(F('year'), 1942) | GreaterThan(F('year'), 1942)
+        qs = Season.objects.annotate(gte=expression)
+        self.assertCountEqual(
+            qs.values_list('year', 'gte'),
+            ((1942, True), (1842, False), (2042, True)),
+        )
+
+    def test_lookup_in_filter(self):
+        qs = Season.objects.filter(GreaterThan(F('year'), 1910))
+        self.assertCountEqual(qs, [self.s1, self.s3])
+
+    def test_filter_lookup_lhs(self):
+        qs = Season.objects.annotate(before_20=LessThan(F('year'), 2000)).filter(
+            before_20=LessThan(F('year'), 1900),
+        )
+        self.assertCountEqual(qs, [self.s2, self.s3])
+
+    def test_filter_wrapped_lookup_lhs(self):
+        qs = Season.objects.annotate(before_20=ExpressionWrapper(
+            Q(year__lt=2000),
+            output_field=BooleanField(),
+        )).filter(before_20=LessThan(F('year'), 1900)).values_list('year', flat=True)
+        self.assertCountEqual(qs, [1842, 2042])
+
+    def test_filter_exists_lhs(self):
+        qs = Season.objects.annotate(before_20=Exists(
+            Season.objects.filter(pk=OuterRef('pk'), year__lt=2000),
+        )).filter(before_20=LessThan(F('year'), 1900))
+        self.assertCountEqual(qs, [self.s2, self.s3])
+
+    def test_filter_subquery_lhs(self):
+        qs = Season.objects.annotate(before_20=Subquery(
+            Season.objects.filter(pk=OuterRef('pk')).values(
+                lesser=LessThan(F('year'), 2000),
+            ),
+        )).filter(before_20=LessThan(F('year'), 1900))
+        self.assertCountEqual(qs, [self.s2, self.s3])
+
+    def test_combined_lookups_in_filter(self):
+        expression = Exact(F('year'), 1942) | GreaterThan(F('year'), 1942)
+        qs = Season.objects.filter(expression)
+        self.assertCountEqual(qs, [self.s1, self.s3])
+
+    def test_combined_annotated_lookups_in_filter(self):
+        expression = Exact(F('year'), 1942) | GreaterThan(F('year'), 1942)
+        qs = Season.objects.annotate(gte=expression).filter(gte=True)
+        self.assertCountEqual(qs, [self.s1, self.s3])
+
+    def test_combined_annotated_lookups_in_filter_false(self):
+        expression = Exact(F('year'), 1942) | GreaterThan(F('year'), 1942)
+        qs = Season.objects.annotate(gte=expression).filter(gte=False)
+        self.assertSequenceEqual(qs, [self.s2])
+
+    def test_lookup_in_order_by(self):
+        qs = Season.objects.order_by(LessThan(F('year'), 1910), F('year'))
+        self.assertSequenceEqual(qs, [self.s1, self.s3, self.s2])
+
+    @skipUnlessDBFeature('supports_boolean_expr_in_select_clause')
+    def test_aggregate_combined_lookup(self):
+        expression = Cast(GreaterThan(F('year'), 1900), models.IntegerField())
+        qs = Season.objects.aggregate(modern=models.Sum(expression))
+        self.assertEqual(qs['modern'], 2)
+
+    def test_conditional_expression(self):
+        qs = Season.objects.annotate(century=Case(
+            When(
+                GreaterThan(F('year'), 1900) & LessThanOrEqual(F('year'), 2000),
+                then=Value('20th'),
+            ),
+            default=Value('other'),
+        )).values('year', 'century')
+        self.assertCountEqual(qs, [
+            {'year': 1942, 'century': '20th'},
+            {'year': 1842, 'century': 'other'},
+            {'year': 2042, 'century': 'other'},
+        ])