Jelajahi Sumber

Fixed #22288 -- Fixed F() expressions with the __range lookup.

Matthew Wilkes 8 tahun lalu
induk
melakukan
4f138fe5a4

+ 1 - 0
AUTHORS

@@ -498,6 +498,7 @@ answer newbie questions, and generally made Django that much better:
     Matthew Schinckel <matt@schinckel.net>
     Matthew Somerville <matthew-django@dracos.co.uk>
     Matthew Tretter <m@tthewwithanm.com>
+    Matthew Wilkes <matt@matthewwilkes.name>
     Matthias Kestenholz <mk@406.ch>
     Matthias Pronk <django@masida.nl>
     Matt Hoskins <skaffenuk@googlemail.com>

+ 7 - 1
django/contrib/postgres/fields/array.py

@@ -239,7 +239,13 @@ class ArrayInLookup(In):
         values = super(ArrayInLookup, self).get_prep_lookup()
         # In.process_rhs() expects values to be hashable, so convert lists
         # to tuples.
-        return [tuple(value) for value in values]
+        prepared_values = []
+        for value in values:
+            if hasattr(value, 'resolve_expression'):
+                prepared_values.append(value)
+            else:
+                prepared_values.append(tuple(value))
+        return prepared_values
 
 
 class IndexTransform(Transform):

+ 8 - 0
django/db/backends/mysql/operations.py

@@ -155,6 +155,10 @@ class DatabaseOperations(BaseDatabaseOperations):
         if value is None:
             return None
 
+        # Expression values are adapted by the database.
+        if hasattr(value, 'resolve_expression'):
+            return value
+
         # MySQL doesn't support tz-aware datetimes
         if timezone.is_aware(value):
             if settings.USE_TZ:
@@ -171,6 +175,10 @@ class DatabaseOperations(BaseDatabaseOperations):
         if value is None:
             return None
 
+        # Expression values are adapted by the database.
+        if hasattr(value, 'resolve_expression'):
+            return value
+
         # MySQL doesn't support tz-aware times
         if timezone.is_aware(value):
             raise ValueError("MySQL backend does not support timezone-aware times.")

+ 8 - 0
django/db/backends/oracle/operations.py

@@ -408,6 +408,10 @@ WHEN (new.%(col_name)s IS NULL)
         if value is None:
             return None
 
+        # Expression values are adapted by the database.
+        if hasattr(value, 'resolve_expression'):
+            return value
+
         # cx_Oracle doesn't support tz-aware datetimes
         if timezone.is_aware(value):
             if settings.USE_TZ:
@@ -421,6 +425,10 @@ WHEN (new.%(col_name)s IS NULL)
         if value is None:
             return None
 
+        # Expression values are adapted by the database.
+        if hasattr(value, 'resolve_expression'):
+            return value
+
         if isinstance(value, six.string_types):
             return datetime.datetime.strptime(value, '%H:%M:%S')
 

+ 8 - 0
django/db/backends/sqlite3/operations.py

@@ -182,6 +182,10 @@ class DatabaseOperations(BaseDatabaseOperations):
         if value is None:
             return None
 
+        # Expression values are adapted by the database.
+        if hasattr(value, 'resolve_expression'):
+            return value
+
         # SQLite doesn't support tz-aware datetimes
         if timezone.is_aware(value):
             if settings.USE_TZ:
@@ -195,6 +199,10 @@ class DatabaseOperations(BaseDatabaseOperations):
         if value is None:
             return None
 
+        # Expression values are adapted by the database.
+        if hasattr(value, 'resolve_expression'):
+            return value
+
         # SQLite doesn't support tz-aware datetimes
         if timezone.is_aware(value):
             raise ValueError("SQLite backend does not support timezone-aware times.")

+ 52 - 19
django/db/models/lookups.py

@@ -1,3 +1,4 @@
+import itertools
 import math
 import warnings
 from copy import copy
@@ -170,6 +171,12 @@ class FieldGetDbPrepValueMixin(object):
     """
     get_db_prep_lookup_value_is_iterable = False
 
+    @classmethod
+    def get_prep_lookup_value(cls, value, output_field):
+        if hasattr(value, '_prepare'):
+            return value._prepare(output_field)
+        return output_field.get_prep_value(value)
+
     def get_db_prep_lookup(self, value, connection):
         # For relational fields, use the output_field of the 'field' attribute.
         field = getattr(self.lhs.output_field, 'field', None)
@@ -191,6 +198,51 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
     """
     get_db_prep_lookup_value_is_iterable = True
 
+    def get_prep_lookup(self):
+        prepared_values = []
+        if hasattr(self.rhs, '_prepare'):
+            # A subquery is like an iterable but its items shouldn't be
+            # prepared independently.
+            return self.rhs._prepare(self.lhs.output_field)
+        for rhs_value in self.rhs:
+            if hasattr(rhs_value, 'resolve_expression'):
+                # An expression will be handled by the database but can coexist
+                # alongside real values.
+                pass
+            elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
+                rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
+            prepared_values.append(rhs_value)
+        return prepared_values
+
+    def process_rhs(self, compiler, connection):
+        if self.rhs_is_direct_value():
+            # rhs should be an iterable of values. Use batch_process_rhs()
+            # to prepare/transform those values.
+            return self.batch_process_rhs(compiler, connection)
+        else:
+            return super(FieldGetDbPrepValueIterableMixin, self).process_rhs(compiler, connection)
+
+    def resolve_expression_parameter(self, compiler, connection, sql, param):
+        params = [param]
+        if hasattr(param, 'resolve_expression'):
+            param = param.resolve_expression(compiler.query)
+        if hasattr(param, 'as_sql'):
+            sql, params = param.as_sql(compiler, connection)
+        return sql, params
+
+    def batch_process_rhs(self, compiler, connection, rhs=None):
+        pre_processed = super(FieldGetDbPrepValueIterableMixin, self).batch_process_rhs(compiler, connection, rhs)
+        # The params list may contain expressions which compile to a
+        # sql/param pair. Zip them to get sql and param pairs that refer to the
+        # same argument and attempt to replace them with the result of
+        # compiling the param step.
+        sql, params = zip(*(
+            self.resolve_expression_parameter(compiler, connection, sql, param)
+            for sql, param in zip(*pre_processed)
+        ))
+        params = itertools.chain.from_iterable(params)
+        return sql, tuple(params)
+
 
 class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
     lookup_name = 'exact'
@@ -255,13 +307,6 @@ IntegerField.register_lookup(IntegerLessThan)
 class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
     lookup_name = 'in'
 
-    def get_prep_lookup(self):
-        if hasattr(self.rhs, '_prepare'):
-            return self.rhs._prepare(self.lhs.output_field)
-        if hasattr(self.lhs.output_field, 'get_prep_value'):
-            return [self.lhs.output_field.get_prep_value(v) for v in self.rhs]
-        return self.rhs
-
     def process_rhs(self, compiler, connection):
         db_rhs = getattr(self.rhs, '_db', None)
         if db_rhs is not None and db_rhs != connection.alias:
@@ -409,21 +454,9 @@ Field.register_lookup(IEndsWith)
 class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
     lookup_name = 'range'
 
-    def get_prep_lookup(self):
-        if hasattr(self.rhs, '_prepare'):
-            return self.rhs._prepare(self.lhs.output_field)
-        return [self.lhs.output_field.get_prep_value(v) for v in self.rhs]
-
     def get_rhs_op(self, connection, rhs):
         return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
 
-    def process_rhs(self, compiler, connection):
-        if self.rhs_is_direct_value():
-            # rhs should be an iterable of 2 values, we use batch_process_rhs
-            # to prepare/transform those values
-            return self.batch_process_rhs(compiler, connection)
-        else:
-            return super(Range, self).process_rhs(compiler, connection)
 Field.register_lookup(Range)
 
 

+ 14 - 0
django/db/models/sql/query.py

@@ -990,6 +990,20 @@ class Query(object):
             pre_joins = self.alias_refcount.copy()
             value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
             used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)]
+        elif isinstance(value, (list, tuple)):
+            # The items of the iterable may be expressions and therefore need
+            # to be resolved independently.
+            processed_values = []
+            used_joins = set()
+            for sub_value in value:
+                if hasattr(sub_value, 'resolve_expression'):
+                    pre_joins = self.alias_refcount.copy()
+                    processed_values.append(
+                        sub_value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
+                    )
+                    # The used_joins for a tuple of expressions is the union of
+                    # the used_joins for the individual expressions.
+                    used_joins |= set(k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0))
         # Subqueries need to use a different set of aliases than the
         # outer query. Call bump_prefix to change aliases of the inner
         # query (the value).

+ 3 - 0
docs/releases/1.11.txt

@@ -234,6 +234,9 @@ Models
 * Added support for expressions in :meth:`.QuerySet.values` and
   :meth:`~.QuerySet.values_list`.
 
+* Added support for query expressions on lookups that take multiple arguments,
+  such as ``range``.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 19 - 0
tests/expressions/models.py

@@ -61,6 +61,15 @@ class Experiment(models.Model):
         return self.end - self.start
 
 
+@python_2_unicode_compatible
+class Result(models.Model):
+    experiment = models.ForeignKey(Experiment, models.CASCADE)
+    result_time = models.DateTimeField()
+
+    def __str__(self):
+        return "Result at %s" % self.result_time
+
+
 @python_2_unicode_compatible
 class Time(models.Model):
     time = models.TimeField(null=True)
@@ -69,6 +78,16 @@ class Time(models.Model):
         return "%s" % self.time
 
 
+@python_2_unicode_compatible
+class SimulationRun(models.Model):
+    start = models.ForeignKey(Time, models.CASCADE, null=True)
+    end = models.ForeignKey(Time, models.CASCADE, null=True)
+    midpoint = models.TimeField()
+
+    def __str__(self):
+        return "%s (%s to %s)" % (self.midpoint, self.start, self.end)
+
+
 @python_2_unicode_compatible
 class UUID(models.Model):
     uuid = models.UUIDField(null=True)

+ 144 - 1
tests/expressions/tests.py

@@ -1,6 +1,7 @@
 from __future__ import unicode_literals
 
 import datetime
+import unittest
 import uuid
 from copy import deepcopy
 
@@ -17,11 +18,15 @@ from django.db.models.expressions import (
 from django.db.models.functions import (
     Coalesce, Concat, Length, Lower, Substr, Upper,
 )
+from django.db.models.sql import constants
+from django.db.models.sql.datastructures import Join
 from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
 from django.test.utils import Approximate
 from django.utils import six
 
-from .models import UUID, Company, Employee, Experiment, Number, Time
+from .models import (
+    UUID, Company, Employee, Experiment, Number, Result, SimulationRun, Time,
+)
 
 
 class BasicExpressionsTests(TestCase):
@@ -391,6 +396,144 @@ class BasicExpressionsTests(TestCase):
         self.assertEqual(str(qs.query).count('JOIN'), 2)
 
 
+class IterableLookupInnerExpressionsTests(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        ceo = Employee.objects.create(firstname='Just', lastname='Doit', salary=30)
+        # MySQL requires that the values calculated for expressions don't pass
+        # outside of the field's range, so it's inconvenient to use the values
+        # in the more general tests.
+        Company.objects.create(name='5020 Ltd', num_employees=50, num_chairs=20, ceo=ceo)
+        Company.objects.create(name='5040 Ltd', num_employees=50, num_chairs=40, ceo=ceo)
+        Company.objects.create(name='5050 Ltd', num_employees=50, num_chairs=50, ceo=ceo)
+        Company.objects.create(name='5060 Ltd', num_employees=50, num_chairs=60, ceo=ceo)
+        Company.objects.create(name='99300 Ltd', num_employees=99, num_chairs=300, ceo=ceo)
+
+    def test_in_lookup_allows_F_expressions_and_expressions_for_integers(self):
+        # __in lookups can use F() expressions for integers.
+        queryset = Company.objects.filter(num_employees__in=([F('num_chairs') - 10]))
+        self.assertQuerysetEqual(queryset, ['<Company: 5060 Ltd>'], ordered=False)
+        self.assertQuerysetEqual(
+            Company.objects.filter(num_employees__in=([F('num_chairs') - 10, F('num_chairs') + 10])),
+            ['<Company: 5040 Ltd>', '<Company: 5060 Ltd>'],
+            ordered=False
+        )
+        self.assertQuerysetEqual(
+            Company.objects.filter(
+                num_employees__in=([F('num_chairs') - 10, F('num_chairs'), F('num_chairs') + 10])
+            ),
+            ['<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'],
+            ordered=False
+        )
+
+    def test_expressions_in_lookups_join_choice(self):
+        midpoint = datetime.time(13, 0)
+        t1 = Time.objects.create(time=datetime.time(12, 0))
+        t2 = Time.objects.create(time=datetime.time(14, 0))
+        SimulationRun.objects.create(start=t1, end=t2, midpoint=midpoint)
+        SimulationRun.objects.create(start=t1, end=None, midpoint=midpoint)
+        SimulationRun.objects.create(start=None, end=t2, midpoint=midpoint)
+        SimulationRun.objects.create(start=None, end=None, midpoint=midpoint)
+
+        queryset = SimulationRun.objects.filter(midpoint__range=[F('start__time'), F('end__time')])
+        self.assertQuerysetEqual(
+            queryset,
+            ['<SimulationRun: 13:00:00 (12:00:00 to 14:00:00)>'],
+            ordered=False
+        )
+        for alias in queryset.query.alias_map.values():
+            if isinstance(alias, Join):
+                self.assertEqual(alias.join_type, constants.INNER)
+
+        queryset = SimulationRun.objects.exclude(midpoint__range=[F('start__time'), F('end__time')])
+        self.assertQuerysetEqual(queryset, [], ordered=False)
+        for alias in queryset.query.alias_map.values():
+            if isinstance(alias, Join):
+                self.assertEqual(alias.join_type, constants.LOUTER)
+
+    def test_range_lookup_allows_F_expressions_and_expressions_for_integers(self):
+        # Range lookups can use F() expressions for integers.
+        Company.objects.filter(num_employees__exact=F("num_chairs"))
+        self.assertQuerysetEqual(
+            Company.objects.filter(num_employees__range=(F('num_chairs'), 100)),
+            ['<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>'],
+            ordered=False
+        )
+        self.assertQuerysetEqual(
+            Company.objects.filter(num_employees__range=(F('num_chairs') - 10, F('num_chairs') + 10)),
+            ['<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'],
+            ordered=False
+        )
+        self.assertQuerysetEqual(
+            Company.objects.filter(num_employees__range=(F('num_chairs') - 10, 100)),
+            ['<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'],
+            ordered=False
+        )
+        self.assertQuerysetEqual(
+            Company.objects.filter(num_employees__range=(1, 100)),
+            [
+                '<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>',
+                '<Company: 5060 Ltd>', '<Company: 99300 Ltd>',
+            ],
+            ordered=False
+        )
+
+    @unittest.skipUnless(connection.vendor == 'sqlite',
+                         "This defensive test only works on databases that don't validate parameter types")
+    def test_complex_expressions_do_not_introduce_sql_injection_via_untrusted_string_inclusion(self):
+        """
+        This tests that SQL injection isn't possible using compilation of
+        expressions in iterable filters, as their compilation happens before
+        the main query compilation. It's limited to SQLite, as PostgreSQL,
+        Oracle and other vendors have defense in depth against this by type
+        checking. Testing against SQLite (the most permissive of the built-in
+        databases) demonstrates that the problem doesn't exist while keeping
+        the test simple.
+        """
+        queryset = Company.objects.filter(name__in=[F('num_chairs') + '1)) OR ((1==1'])
+        self.assertQuerysetEqual(queryset, [], ordered=False)
+
+    def test_in_lookup_allows_F_expressions_and_expressions_for_datetimes(self):
+        start = datetime.datetime(2016, 2, 3, 15, 0, 0)
+        end = datetime.datetime(2016, 2, 5, 15, 0, 0)
+        experiment_1 = Experiment.objects.create(
+            name='Integrity testing',
+            assigned=start.date(),
+            start=start,
+            end=end,
+            completed=end.date(),
+            estimated_time=end - start,
+        )
+        experiment_2 = Experiment.objects.create(
+            name='Taste testing',
+            assigned=start.date(),
+            start=start,
+            end=end,
+            completed=end.date(),
+            estimated_time=end - start,
+        )
+        Result.objects.create(
+            experiment=experiment_1,
+            result_time=datetime.datetime(2016, 2, 4, 15, 0, 0),
+        )
+        Result.objects.create(
+            experiment=experiment_1,
+            result_time=datetime.datetime(2016, 3, 10, 2, 0, 0),
+        )
+        Result.objects.create(
+            experiment=experiment_2,
+            result_time=datetime.datetime(2016, 1, 8, 5, 0, 0),
+        )
+
+        within_experiment_time = [F('experiment__start'), F('experiment__end')]
+        queryset = Result.objects.filter(result_time__range=within_experiment_time)
+        self.assertQuerysetEqual(queryset, ["<Result: Result at 2016-02-04 15:00:00>"])
+
+        within_experiment_time = [F('experiment__start'), F('experiment__end')]
+        queryset = Result.objects.filter(result_time__range=within_experiment_time)
+        self.assertQuerysetEqual(queryset, ["<Result: Result at 2016-02-04 15:00:00>"])
+
+
 class ExpressionsTests(TestCase):
 
     def test_F_object_deepcopy(self):

+ 28 - 0
tests/postgres_tests/test_array.py

@@ -173,12 +173,40 @@ class TestQuerying(PostgreSQLTestCase):
             self.objs[:2]
         )
 
+    @unittest.expectedFailure
+    def test_in_including_F_object(self):
+        # This test asserts that Array objects passed to filters can be
+        # constructed to contain F objects. This currently doesn't work as the
+        # psycopg2 mogrify method that generates the ARRAY() syntax is
+        # expecting literals, not column references (#27095).
+        self.assertSequenceEqual(
+            NullableIntegerArrayModel.objects.filter(field__in=[[models.F('id')]]),
+            self.objs[:2]
+        )
+
+    def test_in_as_F_object(self):
+        self.assertSequenceEqual(
+            NullableIntegerArrayModel.objects.filter(field__in=[models.F('field')]),
+            self.objs[:4]
+        )
+
     def test_contained_by(self):
         self.assertSequenceEqual(
             NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]),
             self.objs[:2]
         )
 
+    @unittest.expectedFailure
+    def test_contained_by_including_F_object(self):
+        # This test asserts that Array objects passed to filters can be
+        # constructed to contain F objects. This currently doesn't work as the
+        # psycopg2 mogrify method that generates the ARRAY() syntax is
+        # expecting literals, not column references (#27095).
+        self.assertSequenceEqual(
+            NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F('id'), 2]),
+            self.objs[:2]
+        )
+
     def test_contains(self):
         self.assertSequenceEqual(
             NullableIntegerArrayModel.objects.filter(field__contains=[2]),