Browse Source

Fixed #26458 -- Based Avg's default output_field resolution on its source field type.

Thanks Tim for the review and Josh for the input.
Simon Charette 9 years ago
parent
commit
a6074e8908
2 changed files with 12 additions and 4 deletions
  1. 6 4
      django/db/models/aggregates.py
  2. 6 0
      tests/aggregation/tests.py

+ 6 - 4
django/db/models/aggregates.py

@@ -3,7 +3,7 @@ Classes to represent the definitions of aggregate functions.
 """
 from django.core.exceptions import FieldError
 from django.db.models.expressions import Func, Star
-from django.db.models.fields import FloatField, IntegerField
+from django.db.models.fields import DecimalField, FloatField, IntegerField
 
 __all__ = [
     'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
@@ -41,9 +41,11 @@ class Avg(Aggregate):
     function = 'AVG'
     name = 'Avg'
 
-    def __init__(self, expression, **extra):
-        output_field = extra.pop('output_field', FloatField())
-        super(Avg, self).__init__(expression, output_field=output_field, **extra)
+    def _resolve_output_field(self):
+        source_field = self.get_source_fields()[0]
+        if isinstance(source_field, (IntegerField, DecimalField)):
+            self._output_field = FloatField()
+        super(Avg, self)._resolve_output_field()
 
     def as_oracle(self, compiler, connection):
         if self.output_field.get_internal_type() == 'DurationField':

+ 6 - 0
tests/aggregation/tests.py

@@ -496,10 +496,16 @@ class AggregateTestCase(TestCase):
         self.assertEqual(vals, {"num_authors__avg": Approximate(1.66, places=1)})
 
     def test_avg_duration_field(self):
+        # Explicit `output_field`.
         self.assertEqual(
             Publisher.objects.aggregate(Avg('duration', output_field=DurationField())),
             {'duration__avg': datetime.timedelta(days=1, hours=12)}
         )
+        # Implicit `output_field`.
+        self.assertEqual(
+            Publisher.objects.aggregate(Avg('duration')),
+            {'duration__avg': datetime.timedelta(days=1, hours=12)}
+        )
 
     def test_sum_duration_field(self):
         self.assertEqual(