aggregates.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. """
  2. Classes to represent the definitions of aggregate functions.
  3. """
  4. from django.core.exceptions import FieldError
  5. from django.db.models.expressions import Func, Value
  6. from django.db.models.fields import IntegerField, FloatField
  7. __all__ = [
  8. 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
  9. ]
  10. class Aggregate(Func):
  11. contains_aggregate = True
  12. name = None
  13. def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
  14. assert len(self.source_expressions) == 1
  15. c = super(Aggregate, self).resolve_expression(query, allow_joins, reuse, summarize)
  16. if c.source_expressions[0].contains_aggregate and not summarize:
  17. name = self.source_expressions[0].name
  18. raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
  19. c.name, name, name))
  20. c._patch_aggregate(query) # backward-compatibility support
  21. return c
  22. def refs_field(self, aggregate_types, field_types):
  23. try:
  24. return (isinstance(self, aggregate_types) and
  25. isinstance(self.input_field._output_field_or_none, field_types))
  26. except FieldError:
  27. # Sometimes we don't know the input_field's output type (for example,
  28. # doing Sum(F('datetimefield') + F('datefield'), output_type=DateTimeField())
  29. # is OK, but the Expression(F('datetimefield') + F('datefield')) doesn't
  30. # have any output field.
  31. return False
  32. @property
  33. def input_field(self):
  34. return self.source_expressions[0]
  35. @property
  36. def default_alias(self):
  37. if hasattr(self.source_expressions[0], 'name'):
  38. return '%s__%s' % (self.source_expressions[0].name, self.name.lower())
  39. raise TypeError("Complex expressions require an alias")
  40. def get_group_by_cols(self):
  41. return []
  42. def _patch_aggregate(self, query):
  43. """
  44. Helper method for patching 3rd party aggregates that do not yet support
  45. the new way of subclassing. This method should be removed in 2.0
  46. add_to_query(query, alias, col, source, is_summary) will be defined on
  47. legacy aggregates which, in turn, instantiates the SQL implementation of
  48. the aggregate. In all the cases found, the general implementation of
  49. add_to_query looks like:
  50. def add_to_query(self, query, alias, col, source, is_summary):
  51. klass = SQLImplementationAggregate
  52. aggregate = klass(col, source=source, is_summary=is_summary, **self.extra)
  53. query.aggregates[alias] = aggregate
  54. By supplying a known alias, we can get the SQLAggregate out of the
  55. aggregates dict, and use the sql_function and sql_template attributes
  56. to patch *this* aggregate.
  57. """
  58. if not hasattr(self, 'add_to_query') or self.function is not None:
  59. return
  60. placeholder_alias = "_XXXXXXXX_"
  61. self.add_to_query(query, placeholder_alias, None, None, None)
  62. sql_aggregate = query.aggregates.pop(placeholder_alias)
  63. if 'sql_function' not in self.extra and hasattr(sql_aggregate, 'sql_function'):
  64. self.extra['function'] = sql_aggregate.sql_function
  65. if hasattr(sql_aggregate, 'sql_template'):
  66. self.extra['template'] = sql_aggregate.sql_template
  67. class Avg(Aggregate):
  68. function = 'AVG'
  69. name = 'Avg'
  70. def __init__(self, expression, **extra):
  71. super(Avg, self).__init__(expression, output_field=FloatField(), **extra)
  72. def convert_value(self, value, connection):
  73. if value is None:
  74. return value
  75. return float(value)
  76. class Count(Aggregate):
  77. function = 'COUNT'
  78. name = 'Count'
  79. template = '%(function)s(%(distinct)s%(expressions)s)'
  80. def __init__(self, expression, distinct=False, **extra):
  81. if expression == '*':
  82. expression = Value(expression)
  83. expression._output_field = IntegerField()
  84. super(Count, self).__init__(
  85. expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)
  86. def convert_value(self, value, connection):
  87. if value is None:
  88. return 0
  89. return int(value)
  90. class Max(Aggregate):
  91. function = 'MAX'
  92. name = 'Max'
  93. class Min(Aggregate):
  94. function = 'MIN'
  95. name = 'Min'
  96. class StdDev(Aggregate):
  97. name = 'StdDev'
  98. def __init__(self, expression, sample=False, **extra):
  99. self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
  100. super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)
  101. def convert_value(self, value, connection):
  102. if value is None:
  103. return value
  104. return float(value)
  105. class Sum(Aggregate):
  106. function = 'SUM'
  107. name = 'Sum'
  108. class Variance(Aggregate):
  109. name = 'Variance'
  110. def __init__(self, expression, sample=False, **extra):
  111. self.function = 'VAR_SAMP' if sample else 'VAR_POP'
  112. super(Variance, self).__init__(expression, output_field=FloatField(), **extra)
  113. def convert_value(self, value, connection):
  114. if value is None:
  115. return value
  116. return float(value)