瀏覽代碼

Fixed #24485 -- Allowed combined expressions to set output_field

Josh Smeaton 10 年之前
父節點
當前提交
02a2943e4c

+ 3 - 1
django/db/models/__init__.py

@@ -2,7 +2,9 @@ from functools import wraps
 
 from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured  # NOQA
 from django.db.models.query import Q, QuerySet, Prefetch  # NOQA
-from django.db.models.expressions import Expression, F, Value, Func, Case, When  # NOQA
+from django.db.models.expressions import (  # NOQA
+    Expression, ExpressionWrapper, F, Value, Func, Case, When,
+)
 from django.db.models.manager import Manager  # NOQA
 from django.db.models.base import Model  # NOQA
 from django.db.models.aggregates import *  # NOQA

+ 26 - 3
django/db/models/expressions.py

@@ -126,12 +126,12 @@ class BaseExpression(object):
     # aggregate specific fields
     is_summary = False
 
-    def get_db_converters(self, connection):
-        return [self.convert_value] + self.output_field.get_db_converters(connection)
-
     def __init__(self, output_field=None):
         self._output_field = output_field
 
+    def get_db_converters(self, connection):
+        return [self.convert_value] + self.output_field.get_db_converters(connection)
+
     def get_source_expressions(self):
         return []
 
@@ -656,6 +656,29 @@ class Ref(Expression):
         return [self]
 
 
+class ExpressionWrapper(Expression):
+    """
+    An expression that can wrap another expression so that it can provide
+    extra context to the inner expression, such as the output_field.
+    """
+
+    def __init__(self, expression, output_field):
+        super(ExpressionWrapper, self).__init__(output_field=output_field)
+        self.expression = expression
+
+    def set_source_expressions(self, exprs):
+        self.expression = exprs[0]
+
+    def get_source_expressions(self):
+        return [self.expression]
+
+    def as_sql(self, compiler, connection):
+        return self.expression.as_sql(compiler, connection)
+
+    def __repr__(self):
+        return "{}({})".format(self.__class__.__name__, self.expression)
+
+
 class When(Expression):
     template = 'WHEN %(condition)s THEN %(result)s'
 

+ 34 - 11
docs/ref/models/expressions.txt

@@ -161,6 +161,27 @@ values, rather than on Python values.
 This is documented in :ref:`using F() expressions in queries
 <using-f-expressions-in-filters>`.
 
+.. _using-f-with-annotations:
+
+Using ``F()`` with annotations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+``F()`` can be used to create dynamic fields on your models by combining
+different fields with arithmetic::
+
+    company = Company.objects.annotate(
+        chairs_needed=F('num_employees') - F('num_chairs'))
+
+If the fields that you're combining are of different types you'll need
+to tell Django what kind of field will be returned. Since ``F()`` does not
+directly support ``output_field`` you will need to wrap the expression with
+:class:`ExpressionWrapper`::
+
+    from django.db.models import DateTimeField, ExpressionWrapper, F
+
+    Ticket.objects.annotate(
+        expires=ExpressionWrapper(
+            F('active_at') + F('duration'), output_field=DateTimeField()))
 
 .. _func-expressions:
 
@@ -274,17 +295,6 @@ should define the desired ``output_field``. For example, adding an
 ``IntegerField()`` and a ``FloatField()`` together should probably have
 ``output_field=FloatField()`` defined.
 
-.. note::
-
-    When you need to define the ``output_field`` for ``F`` expression
-    arithmetic between different types, it's necessary to surround the
-    expression in another expression::
-
-        from django.db.models import DateTimeField, Expression, F
-
-        Race.objects.annotate(finish=Expression(
-            F('start') + F('duration'), output_field=DateTimeField()))
-
 .. versionchanged:: 1.8
 
     ``output_field`` is a new parameter.
@@ -343,6 +353,19 @@ instantiating the model field as any arguments relating to data validation
 (``max_length``, ``max_digits``, etc.) will not be enforced on the expression's
 output value.
 
+``ExpressionWrapper()`` expressions
+-----------------------------------
+
+.. class:: ExpressionWrapper(expression, output_field)
+
+.. versionadded:: 1.8
+
+``ExpressionWrapper`` simply surrounds another expression and provides access
+to properties, such as ``output_field``, that may not be available on other
+expressions. ``ExpressionWrapper`` is necessary when using arithmetic on
+``F()`` expressions with different types as described in
+:ref:`using-f-with-annotations`.
+
 Conditional expressions
 -----------------------
 

+ 9 - 0
tests/annotations/models.py

@@ -84,3 +84,12 @@ class Company(models.Model):
         return ('Company(name=%s, motto=%s, ticker_name=%s, description=%s)'
             % (self.name, self.motto, self.ticker_name, self.description)
         )
+
+
+@python_2_unicode_compatible
+class Ticket(models.Model):
+    active_at = models.DateTimeField()
+    duration = models.DurationField()
+
+    def __str__(self):
+        return '{} - {}'.format(self.active_at, self.duration)

+ 21 - 2
tests/annotations/tests.py

@@ -5,13 +5,14 @@ from decimal import Decimal
 
 from django.core.exceptions import FieldDoesNotExist, FieldError
 from django.db.models import (
-    F, BooleanField, CharField, Count, Func, IntegerField, Sum, Value,
+    F, BooleanField, CharField, Count, DateTimeField, ExpressionWrapper, Func,
+    IntegerField, Sum, Value,
 )
 from django.test import TestCase
 from django.utils import six
 
 from .models import (
-    Author, Book, Company, DepartmentStore, Employee, Publisher, Store,
+    Author, Book, Company, DepartmentStore, Employee, Publisher, Store, Ticket,
 )
 
 
@@ -135,6 +136,24 @@ class NonAggregateAnnotationTestCase(TestCase):
         for book in books:
             self.assertEqual(book.num_awards, book.publisher.num_awards)
 
+    def test_mixed_type_annotation_date_interval(self):
+        active = datetime.datetime(2015, 3, 20, 14, 0, 0)
+        duration = datetime.timedelta(hours=1)
+        expires = datetime.datetime(2015, 3, 20, 14, 0, 0) + duration
+        Ticket.objects.create(active_at=active, duration=duration)
+        t = Ticket.objects.annotate(
+            expires=ExpressionWrapper(F('active_at') + F('duration'), output_field=DateTimeField())
+        ).first()
+        self.assertEqual(t.expires, expires)
+
+    def test_mixed_type_annotation_numbers(self):
+        test = self.b1
+        b = Book.objects.annotate(
+            combined=ExpressionWrapper(F('pages') + F('rating'), output_field=IntegerField())
+        ).get(isbn=test.isbn)
+        combined = int(test.pages + test.rating)
+        self.assertEqual(b.combined, combined)
+
     def test_annotate_with_aggregation(self):
         books = Book.objects.annotate(
             is_book=Value(1, output_field=IntegerField()),

+ 6 - 2
tests/expressions/tests.py

@@ -11,8 +11,8 @@ from django.db.models.aggregates import (
     Avg, Count, Max, Min, StdDev, Sum, Variance,
 )
 from django.db.models.expressions import (
-    F, Case, Col, Date, DateTime, Func, OrderBy, Random, RawSQL, Ref, Value,
-    When,
+    F, Case, Col, Date, DateTime, ExpressionWrapper, Func, OrderBy, Random,
+    RawSQL, Ref, Value, When,
 )
 from django.db.models.functions import (
     Coalesce, Concat, Length, Lower, Substr, Upper,
@@ -855,6 +855,10 @@ class ReprTests(TestCase):
         self.assertEqual(repr(DateTime('published', 'exact', utc)), "DateTime(published, exact, %s)" % utc)
         self.assertEqual(repr(F('published')), "F(published)")
         self.assertEqual(repr(F('cost') + F('tax')), "<CombinedExpression: F(cost) + F(tax)>")
+        self.assertEqual(
+            repr(ExpressionWrapper(F('cost') + F('tax'), models.IntegerField())),
+            "ExpressionWrapper(F(cost) + F(tax))"
+        )
         self.assertEqual(repr(Func('published', function='TO_CHAR')), "Func(F(published), function=TO_CHAR)")
         self.assertEqual(repr(OrderBy(Value(1))), 'OrderBy(Value(1), descending=False)')
         self.assertEqual(repr(Random()), "Random()")