123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- """
- Classes to represent the definitions of aggregate functions.
- """
- from django.core.exceptions import FieldError
- from django.db.models.expressions import Func, Star
- from django.db.models.fields import DecimalField, FloatField, IntegerField
# Public API of this module: the aggregate base class and each concrete
# aggregate defined below.
__all__ = [
    'Aggregate',
    'Avg',
    'Count',
    'Max',
    'Min',
    'StdDev',
    'Sum',
    'Variance',
]
class Aggregate(Func):
    """
    Base class for aggregate function expressions.

    Marks itself with ``contains_aggregate`` and contributes no columns to
    GROUP BY. Subclasses set ``name`` (used in error messages and in the
    default alias).
    """
    contains_aggregate = True
    name = None

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        """
        Resolve this aggregate through the parent implementation.

        Unless ``summarize`` is true, raise FieldError if any resolved source
        expression is itself an aggregate.
        """
        # Aggregates are not allowed in UPDATE queries, so for_save is
        # intentionally not forwarded to the parent.
        resolved = super().resolve_expression(query, allow_joins, reuse, summarize)
        if summarize:
            return resolved
        for position, source in enumerate(resolved.get_source_expressions()):
            if not source.contains_aggregate:
                continue
            # Describe the offending argument using its unresolved form.
            unresolved = self.get_source_expressions()[position]
            if hasattr(unresolved, 'name'):
                label = unresolved.name
            else:
                label = repr(unresolved)
            raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (resolved.name, label, label))
        return resolved

    @property
    def default_alias(self):
        """
        Return ``<source name>__<lowercased aggregate name>``.

        Only available when there is exactly one source expression and it has
        a ``name`` attribute; otherwise raise TypeError.
        """
        sources = self.get_source_expressions()
        if len(sources) != 1 or not hasattr(sources[0], 'name'):
            raise TypeError("Complex expressions require an alias")
        return '%s__%s' % (sources[0].name, self.name.lower())

    def get_group_by_cols(self):
        """Return an empty list: this expression adds no GROUP BY columns."""
        return []
class Avg(Aggregate):
    """Aggregate rendered with the SQL ``AVG`` function."""
    function = 'AVG'
    name = 'Avg'

    def _resolve_output_field(self):
        """
        Return a FloatField when the first source field is an IntegerField or
        a DecimalField; otherwise defer to the parent implementation.
        """
        first_field = self.get_source_fields()[0]
        if not isinstance(first_field, (IntegerField, DecimalField)):
            return super()._resolve_output_field()
        return FloatField()

    def as_oracle(self, compiler, connection):
        """
        Compile this aggregate for Oracle.

        When the output field's internal type is DurationField, the source
        expression is wrapped as
        ``SecondsToInterval(Avg(IntervalToSeconds(expr)))`` and that is
        compiled instead; otherwise the regular ``as_sql()`` output is used.
        """
        if self.output_field.get_internal_type() != 'DurationField':
            return super().as_sql(compiler, connection)
        source = self.get_source_expressions()[0]
        # Imported locally, as in the original implementation.
        from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
        averaged_seconds = Avg(IntervalToSeconds(source))
        return compiler.compile(SecondsToInterval(averaged_seconds))
class Count(Aggregate):
    """Aggregate rendered with the SQL ``COUNT`` function, optionally DISTINCT."""
    function = 'COUNT'
    name = 'Count'
    template = '%(function)s(%(distinct)s%(expressions)s)'

    def __init__(self, expression, distinct=False, **extra):
        """
        Build a COUNT over ``expression``.

        The literal ``'*'`` is replaced by a ``Star()`` expression. ``distinct``
        is passed on as the SQL fragment ``'DISTINCT '`` (or an empty string)
        that the template interpolates; the output field is an IntegerField.
        """
        target = Star() if expression == '*' else expression
        distinct_sql = 'DISTINCT ' if distinct else ''
        super().__init__(
            target,
            distinct=distinct_sql,
            output_field=IntegerField(),
            **extra
        )

    def _get_repr_options(self):
        """Report ``distinct`` as a boolean derived from the stored SQL fragment."""
        is_distinct = self.extra['distinct'] != ''
        return {'distinct': is_distinct}

    def convert_value(self, value, expression, connection):
        """Return ``value`` as an int, mapping ``None`` to 0."""
        return 0 if value is None else int(value)
class Max(Aggregate):
    """Aggregate rendered with the SQL ``MAX`` function."""
    name = 'Max'
    function = 'MAX'
class Min(Aggregate):
    """Aggregate rendered with the SQL ``MIN`` function."""
    name = 'Min'
    function = 'MIN'
class StdDev(Aggregate):
    """Standard deviation aggregate (``STDDEV_POP`` or ``STDDEV_SAMP``)."""
    name = 'StdDev'

    def __init__(self, expression, sample=False, **extra):
        """
        Select the SQL function from ``sample`` (``STDDEV_SAMP`` when true,
        ``STDDEV_POP`` otherwise) and use a FloatField as the output field.
        """
        if sample:
            self.function = 'STDDEV_SAMP'
        else:
            self.function = 'STDDEV_POP'
        super().__init__(expression, output_field=FloatField(), **extra)

    def _get_repr_options(self):
        """Report whether the sample variant of the function was selected."""
        uses_sample = self.function == 'STDDEV_SAMP'
        return {'sample': uses_sample}

    def convert_value(self, value, expression, connection):
        """Return ``value`` as a float, passing ``None`` through unchanged."""
        if value is not None:
            return float(value)
        return value
class Sum(Aggregate):
    """Aggregate rendered with the SQL ``SUM`` function."""
    function = 'SUM'
    name = 'Sum'

    def as_oracle(self, compiler, connection):
        """
        Compile this aggregate for Oracle.

        When the output field's internal type is DurationField, the source
        expression is wrapped as
        ``SecondsToInterval(Sum(IntervalToSeconds(expr)))`` and that is
        compiled instead; otherwise the regular ``as_sql()`` output is used.
        """
        if self.output_field.get_internal_type() != 'DurationField':
            return super().as_sql(compiler, connection)
        source = self.get_source_expressions()[0]
        # Imported locally, as in the original implementation.
        from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
        summed_seconds = Sum(IntervalToSeconds(source))
        return compiler.compile(SecondsToInterval(summed_seconds))
class Variance(Aggregate):
    """Variance aggregate (``VAR_POP`` or ``VAR_SAMP``)."""
    name = 'Variance'

    def __init__(self, expression, sample=False, **extra):
        """
        Select the SQL function from ``sample`` (``VAR_SAMP`` when true,
        ``VAR_POP`` otherwise) and use a FloatField as the output field.
        """
        if sample:
            self.function = 'VAR_SAMP'
        else:
            self.function = 'VAR_POP'
        super().__init__(expression, output_field=FloatField(), **extra)

    def _get_repr_options(self):
        """Report whether the sample variant of the function was selected."""
        uses_sample = self.function == 'VAR_SAMP'
        return {'sample': uses_sample}

    def convert_value(self, value, expression, connection):
        """Return ``value`` as a float, passing ``None`` through unchanged."""
        if value is not None:
            return float(value)
        return value
|