aggregates.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
"""
Classes that represent the definitions of SQL aggregate functions
(AVG, COUNT, MAX, MIN, STDDEV_POP/STDDEV_SAMP, SUM, VAR_POP/VAR_SAMP)
as query expressions.
"""
  4. from django.core.exceptions import FieldError
  5. from django.db.models.expressions import Func, Star
  6. from django.db.models.fields import DecimalField, FloatField, IntegerField
  7. __all__ = [
  8. 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
  9. ]
  10. class Aggregate(Func):
  11. contains_aggregate = True
  12. name = None
  13. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
  14. # Aggregates are not allowed in UPDATE queries, so ignore for_save
  15. c = super().resolve_expression(query, allow_joins, reuse, summarize)
  16. if not summarize:
  17. expressions = c.get_source_expressions()
  18. for index, expr in enumerate(expressions):
  19. if expr.contains_aggregate:
  20. before_resolved = self.get_source_expressions()[index]
  21. name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
  22. raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
  23. return c
  24. @property
  25. def default_alias(self):
  26. expressions = self.get_source_expressions()
  27. if len(expressions) == 1 and hasattr(expressions[0], 'name'):
  28. return '%s__%s' % (expressions[0].name, self.name.lower())
  29. raise TypeError("Complex expressions require an alias")
  30. def get_group_by_cols(self):
  31. return []
  32. class Avg(Aggregate):
  33. function = 'AVG'
  34. name = 'Avg'
  35. def _resolve_output_field(self):
  36. source_field = self.get_source_fields()[0]
  37. if isinstance(source_field, (IntegerField, DecimalField)):
  38. return FloatField()
  39. return super()._resolve_output_field()
  40. def as_oracle(self, compiler, connection):
  41. if self.output_field.get_internal_type() == 'DurationField':
  42. expression = self.get_source_expressions()[0]
  43. from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
  44. return compiler.compile(
  45. SecondsToInterval(Avg(IntervalToSeconds(expression)))
  46. )
  47. return super().as_sql(compiler, connection)
  48. class Count(Aggregate):
  49. function = 'COUNT'
  50. name = 'Count'
  51. template = '%(function)s(%(distinct)s%(expressions)s)'
  52. def __init__(self, expression, distinct=False, **extra):
  53. if expression == '*':
  54. expression = Star()
  55. super().__init__(
  56. expression, distinct='DISTINCT ' if distinct else '',
  57. output_field=IntegerField(), **extra
  58. )
  59. def _get_repr_options(self):
  60. return {'distinct': self.extra['distinct'] != ''}
  61. def convert_value(self, value, expression, connection):
  62. if value is None:
  63. return 0
  64. return int(value)
  65. class Max(Aggregate):
  66. function = 'MAX'
  67. name = 'Max'
  68. class Min(Aggregate):
  69. function = 'MIN'
  70. name = 'Min'
  71. class StdDev(Aggregate):
  72. name = 'StdDev'
  73. def __init__(self, expression, sample=False, **extra):
  74. self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
  75. super().__init__(expression, output_field=FloatField(), **extra)
  76. def _get_repr_options(self):
  77. return {'sample': self.function == 'STDDEV_SAMP'}
  78. def convert_value(self, value, expression, connection):
  79. if value is None:
  80. return value
  81. return float(value)
  82. class Sum(Aggregate):
  83. function = 'SUM'
  84. name = 'Sum'
  85. def as_oracle(self, compiler, connection):
  86. if self.output_field.get_internal_type() == 'DurationField':
  87. expression = self.get_source_expressions()[0]
  88. from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
  89. return compiler.compile(
  90. SecondsToInterval(Sum(IntervalToSeconds(expression)))
  91. )
  92. return super().as_sql(compiler, connection)
  93. class Variance(Aggregate):
  94. name = 'Variance'
  95. def __init__(self, expression, sample=False, **extra):
  96. self.function = 'VAR_SAMP' if sample else 'VAR_POP'
  97. super().__init__(expression, output_field=FloatField(), **extra)
  98. def _get_repr_options(self):
  99. return {'sample': self.function == 'VAR_SAMP'}
  100. def convert_value(self, value, expression, connection):
  101. if value is None:
  102. return value
  103. return float(value)