|
@@ -3,7 +3,7 @@ Classes to represent the definitions of aggregate functions.
|
|
|
"""
|
|
|
from django.core.exceptions import FieldError
|
|
|
from django.db.models.expressions import Func, Star
|
|
|
-from django.db.models.fields import FloatField, IntegerField
|
|
|
+from django.db.models.fields import DecimalField, FloatField, IntegerField
|
|
|
|
|
|
__all__ = [
|
|
|
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
|
|
@@ -41,9 +41,11 @@ class Avg(Aggregate):
|
|
|
function = 'AVG'
|
|
|
name = 'Avg'
|
|
|
|
|
|
- def __init__(self, expression, **extra):
|
|
|
- output_field = extra.pop('output_field', FloatField())
|
|
|
- super(Avg, self).__init__(expression, output_field=output_field, **extra)
|
|
|
+ def _resolve_output_field(self):
|
|
|
+ source_field = self.get_source_fields()[0]
|
|
|
+ if isinstance(source_field, (IntegerField, DecimalField)):
|
|
|
+ self._output_field = FloatField()
|
|
|
+ super(Avg, self)._resolve_output_field()
|
|
|
|
|
|
def as_oracle(self, compiler, connection):
|
|
|
if self.output_field.get_internal_type() == 'DurationField':
|