|
@@ -12,12 +12,13 @@ from copy import deepcopy
|
|
|
from django.utils.tree import Node
|
|
|
from django.utils.datastructures import SortedDict
|
|
|
from django.utils.encoding import force_unicode
|
|
|
+from django.db.backends.util import truncate_name
|
|
|
from django.db import connection
|
|
|
from django.db.models import signals
|
|
|
from django.db.models.fields import FieldDoesNotExist
|
|
|
from django.db.models.query_utils import select_related_descend
|
|
|
+from django.db.models.sql import aggregates as base_aggregates_module
|
|
|
from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR
|
|
|
-from django.db.models.sql.datastructures import Count
|
|
|
from django.core.exceptions import FieldError
|
|
|
from datastructures import EmptyResultSet, Empty, MultiJoin
|
|
|
from constants import *
|
|
@@ -40,6 +41,7 @@ class BaseQuery(object):
|
|
|
|
|
|
alias_prefix = 'T'
|
|
|
query_terms = QUERY_TERMS
|
|
|
+ aggregates_module = base_aggregates_module
|
|
|
|
|
|
def __init__(self, model, connection, where=WhereNode):
|
|
|
self.model = model
|
|
@@ -73,6 +75,9 @@ class BaseQuery(object):
|
|
|
self.select_related = False
|
|
|
self.related_select_cols = []
|
|
|
|
|
|
+ # SQL aggregate-related attributes
|
|
|
+ self.aggregate_select = SortedDict() # Maps alias -> SQL aggregate function
|
|
|
+
|
|
|
# Arbitrary maximum limit for select_related. Prevents infinite
|
|
|
# recursion. Can be changed by the depth parameter to select_related().
|
|
|
self.max_depth = 5
|
|
@@ -178,6 +183,7 @@ class BaseQuery(object):
|
|
|
obj.distinct = self.distinct
|
|
|
obj.select_related = self.select_related
|
|
|
obj.related_select_cols = []
|
|
|
+ obj.aggregate_select = self.aggregate_select.copy()
|
|
|
obj.max_depth = self.max_depth
|
|
|
obj.extra_select = self.extra_select.copy()
|
|
|
obj.extra_tables = self.extra_tables
|
|
@@ -194,6 +200,35 @@ class BaseQuery(object):
|
|
|
obj._setup_query()
|
|
|
return obj
|
|
|
|
|
|
+ def convert_values(self, value, field):
|
|
|
+ """Convert the database-returned value into a type that is consistent
|
|
|
+ across database backends.
|
|
|
+
|
|
|
+ By default, this defers to the underlying backend operations, but
|
|
|
+ it can be overridden by Query classes for specific backends.
|
|
|
+ """
|
|
|
+ return self.connection.ops.convert_values(value, field)
|
|
|
+
|
|
|
+ def resolve_aggregate(self, value, aggregate):
|
|
|
+ """Resolve the value of aggregates returned by the database to
|
|
|
+ consistent (and reasonable) types.
|
|
|
+
|
|
|
+ This is required because of the predisposition of certain backends
|
|
|
+ to return Decimal and long types when they are not needed.
|
|
|
+ """
|
|
|
+ if value is None:
|
|
|
+ # Return None as-is
|
|
|
+ return value
|
|
|
+ elif aggregate.is_ordinal:
|
|
|
+ # Any ordinal aggregate (e.g., count) returns an int
|
|
|
+ return int(value)
|
|
|
+ elif aggregate.is_computed:
|
|
|
+ # Any computed aggregate (e.g., avg) returns a float
|
|
|
+ return float(value)
|
|
|
+ else:
|
|
|
+ # Return value depends on the type of the field being processed.
|
|
|
+ return self.convert_values(value, aggregate.field)
|
|
|
+
|
|
|
def results_iter(self):
|
|
|
"""
|
|
|
Returns an iterator over the results from executing this query.
|
|
@@ -212,29 +247,78 @@ class BaseQuery(object):
|
|
|
else:
|
|
|
fields = self.model._meta.fields
|
|
|
row = self.resolve_columns(row, fields)
|
|
|
+
|
|
|
+ if self.aggregate_select:
|
|
|
+ aggregate_start = len(self.extra_select.keys()) + len(self.select)
|
|
|
+ row = tuple(row[:aggregate_start]) + tuple([
|
|
|
+ self.resolve_aggregate(value, aggregate)
|
|
|
+ for (alias, aggregate), value
|
|
|
+ in zip(self.aggregate_select.items(), row[aggregate_start:])
|
|
|
+ ])
|
|
|
+
|
|
|
yield row
|
|
|
|
|
|
+ def get_aggregation(self):
|
|
|
+ """
|
|
|
+ Returns the dictionary with the values of the existing aggregations.
|
|
|
+ """
|
|
|
+ if not self.aggregate_select:
|
|
|
+ return {}
|
|
|
+
|
|
|
+ # If there is a group by clause, aggregating does not add useful
|
|
|
+ # information but retrieves only the first row. Aggregate
|
|
|
+ # over the subquery instead.
|
|
|
+ if self.group_by:
|
|
|
+ from subqueries import AggregateQuery
|
|
|
+ query = AggregateQuery(self.model, self.connection)
|
|
|
+
|
|
|
+ obj = self.clone()
|
|
|
+
|
|
|
+ # Remove any aggregates marked for reduction from the subquery
|
|
|
+ # and move them to the outer AggregateQuery.
|
|
|
+ for alias, aggregate in self.aggregate_select.items():
|
|
|
+ if aggregate.is_summary:
|
|
|
+ query.aggregate_select[alias] = aggregate
|
|
|
+ del obj.aggregate_select[alias]
|
|
|
+
|
|
|
+ query.add_subquery(obj)
|
|
|
+ else:
|
|
|
+ query = self
|
|
|
+ self.select = []
|
|
|
+ self.default_cols = False
|
|
|
+ self.extra_select = {}
|
|
|
+
|
|
|
+ query.clear_ordering(True)
|
|
|
+ query.clear_limits()
|
|
|
+ query.select_related = False
|
|
|
+ query.related_select_cols = []
|
|
|
+ query.related_select_fields = []
|
|
|
+
|
|
|
+ return dict([
|
|
|
+ (alias, self.resolve_aggregate(val, aggregate))
|
|
|
+ for (alias, aggregate), val
|
|
|
+ in zip(query.aggregate_select.items(), query.execute_sql(SINGLE))
|
|
|
+ ])
|
|
|
+
|
|
|
def get_count(self):
|
|
|
"""
|
|
|
Performs a COUNT() query using the current filter constraints.
|
|
|
"""
|
|
|
- from subqueries import CountQuery
|
|
|
obj = self.clone()
|
|
|
- obj.clear_ordering(True)
|
|
|
- obj.clear_limits()
|
|
|
- obj.select_related = False
|
|
|
- obj.related_select_cols = []
|
|
|
- obj.related_select_fields = []
|
|
|
- if len(obj.select) > 1:
|
|
|
- obj = self.clone(CountQuery, _query=obj, where=self.where_class(),
|
|
|
- distinct=False)
|
|
|
- obj.select = []
|
|
|
- obj.extra_select = SortedDict()
|
|
|
+ if len(self.select) > 1:
|
|
|
+ # If a select clause exists, then the query has already started to
|
|
|
+ # specify the columns that are to be returned.
|
|
|
+ # In this case, we need to use a subquery to evaluate the count.
|
|
|
+ from subqueries import AggregateQuery
|
|
|
+ subquery = obj
|
|
|
+ subquery.clear_ordering(True)
|
|
|
+ subquery.clear_limits()
|
|
|
+
|
|
|
+ obj = AggregateQuery(obj.model, obj.connection)
|
|
|
+ obj.add_subquery(subquery)
|
|
|
+
|
|
|
obj.add_count_column()
|
|
|
- data = obj.execute_sql(SINGLE)
|
|
|
- if not data:
|
|
|
- return 0
|
|
|
- number = data[0]
|
|
|
+ number = obj.get_aggregation()[None]
|
|
|
|
|
|
# Apply offset and limit constraints manually, since using LIMIT/OFFSET
|
|
|
# in SQL (in variants that provide them) doesn't change the COUNT
|
|
@@ -450,25 +534,41 @@ class BaseQuery(object):
|
|
|
for col in self.select:
|
|
|
if isinstance(col, (list, tuple)):
|
|
|
r = '%s.%s' % (qn(col[0]), qn(col[1]))
|
|
|
- if with_aliases and col[1] in col_aliases:
|
|
|
- c_alias = 'Col%d' % len(col_aliases)
|
|
|
- result.append('%s AS %s' % (r, c_alias))
|
|
|
- aliases.add(c_alias)
|
|
|
- col_aliases.add(c_alias)
|
|
|
+ if with_aliases:
|
|
|
+ if col[1] in col_aliases:
|
|
|
+ c_alias = 'Col%d' % len(col_aliases)
|
|
|
+ result.append('%s AS %s' % (r, c_alias))
|
|
|
+ aliases.add(c_alias)
|
|
|
+ col_aliases.add(c_alias)
|
|
|
+ else:
|
|
|
+ result.append('%s AS %s' % (r, col[1]))
|
|
|
+ aliases.add(r)
|
|
|
+ col_aliases.add(col[1])
|
|
|
else:
|
|
|
result.append(r)
|
|
|
aliases.add(r)
|
|
|
col_aliases.add(col[1])
|
|
|
else:
|
|
|
result.append(col.as_sql(quote_func=qn))
|
|
|
+
|
|
|
if hasattr(col, 'alias'):
|
|
|
aliases.add(col.alias)
|
|
|
col_aliases.add(col.alias)
|
|
|
+
|
|
|
elif self.default_cols:
|
|
|
cols, new_aliases = self.get_default_columns(with_aliases,
|
|
|
col_aliases)
|
|
|
result.extend(cols)
|
|
|
aliases.update(new_aliases)
|
|
|
+
|
|
|
+ result.extend([
|
|
|
+ '%s%s' % (
|
|
|
+ aggregate.as_sql(quote_func=qn),
|
|
|
+ alias is not None and ' AS %s' % qn(alias) or ''
|
|
|
+ )
|
|
|
+ for alias, aggregate in self.aggregate_select.items()
|
|
|
+ ])
|
|
|
+
|
|
|
for table, col in self.related_select_cols:
|
|
|
r = '%s.%s' % (qn(table), qn(col))
|
|
|
if with_aliases and col in col_aliases:
|
|
@@ -538,7 +638,7 @@ class BaseQuery(object):
|
|
|
Returns a list of strings that are joined together to go after the
|
|
|
"FROM" part of the query, as well as a list any extra parameters that
|
|
|
need to be included. Sub-classes, can override this to create a
|
|
|
- from-clause via a "select", for example (e.g. CountQuery).
|
|
|
+ from-clause via a "select".
|
|
|
|
|
|
This should only be called after any SQL construction methods that
|
|
|
might change the tables we need. This means the select columns and
|
|
@@ -635,10 +735,13 @@ class BaseQuery(object):
|
|
|
order = asc
|
|
|
result.append('%s %s' % (field, order))
|
|
|
continue
|
|
|
+ col, order = get_order_dir(field, asc)
|
|
|
+ if col in self.aggregate_select:
|
|
|
+ result.append('%s %s' % (col, order))
|
|
|
+ continue
|
|
|
if '.' in field:
|
|
|
# This came in through an extra(order_by=...) addition. Pass it
|
|
|
# on verbatim.
|
|
|
- col, order = get_order_dir(field, asc)
|
|
|
table, col = col.split('.', 1)
|
|
|
if (table, col) not in processed_pairs:
|
|
|
elt = '%s.%s' % (qn(table), col)
|
|
@@ -657,7 +760,6 @@ class BaseQuery(object):
|
|
|
ordering_aliases.append(elt)
|
|
|
result.append('%s %s' % (elt, order))
|
|
|
else:
|
|
|
- col, order = get_order_dir(field, asc)
|
|
|
elt = qn2(col)
|
|
|
if distinct and col not in select_aliases:
|
|
|
ordering_aliases.append(elt)
|
|
@@ -1068,6 +1170,48 @@ class BaseQuery(object):
|
|
|
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
|
|
|
used, next, restricted, new_nullable, dupe_set, avoid)
|
|
|
|
|
|
+ def add_aggregate(self, aggregate, model, alias, is_summary):
|
|
|
+ """
|
|
|
+ Adds a single aggregate expression to the Query
|
|
|
+ """
|
|
|
+ opts = model._meta
|
|
|
+ field_list = aggregate.lookup.split(LOOKUP_SEP)
|
|
|
+ if (len(field_list) == 1 and
|
|
|
+ aggregate.lookup in self.aggregate_select.keys()):
|
|
|
+ # Aggregate is over an annotation
|
|
|
+ field_name = field_list[0]
|
|
|
+ col = field_name
|
|
|
+ source = self.aggregate_select[field_name]
|
|
|
+ elif (len(field_list) > 1 or
|
|
|
+ field_list[0] not in [i.name for i in opts.fields]):
|
|
|
+ field, source, opts, join_list, last, _ = self.setup_joins(
|
|
|
+ field_list, opts, self.get_initial_alias(), False)
|
|
|
+
|
|
|
+ # Process the join chain to see if it can be trimmed
|
|
|
+ _, _, col, _, join_list = self.trim_joins(source, join_list, last, False)
|
|
|
+
|
|
|
+ # If the aggregate references a model or field that requires a join,
|
|
|
+ # those joins must be LEFT OUTER - empty join rows must be returned
|
|
|
+ # in order for zeros to be returned for those aggregates.
|
|
|
+ for column_alias in join_list:
|
|
|
+ self.promote_alias(column_alias, unconditional=True)
|
|
|
+
|
|
|
+ col = (join_list[-1], col)
|
|
|
+ else:
|
|
|
+ # Aggregate references a normal field
|
|
|
+ field_name = field_list[0]
|
|
|
+ source = opts.get_field(field_name)
|
|
|
+ if not (self.group_by and is_summary):
|
|
|
+ # Only use a column alias if this is a
|
|
|
+ # standalone aggregate, or an annotation
|
|
|
+ col = (opts.db_table, source.column)
|
|
|
+ else:
|
|
|
+ col = field_name
|
|
|
+
|
|
|
+ # Add the aggregate to the query
|
|
|
+ alias = truncate_name(alias, self.connection.ops.max_name_length())
|
|
|
+ aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
|
|
|
+
|
|
|
def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
|
|
|
can_reuse=None, process_extras=True):
|
|
|
"""
|
|
@@ -1119,6 +1263,11 @@ class BaseQuery(object):
|
|
|
elif callable(value):
|
|
|
value = value()
|
|
|
|
|
|
+ for alias, aggregate in self.aggregate_select.items():
|
|
|
+ if alias == parts[0]:
|
|
|
+ self.having.add((aggregate, lookup_type, value), AND)
|
|
|
+ return
|
|
|
+
|
|
|
opts = self.get_meta()
|
|
|
alias = self.get_initial_alias()
|
|
|
allow_many = trim or not negate
|
|
@@ -1131,38 +1280,9 @@ class BaseQuery(object):
|
|
|
self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]),
|
|
|
can_reuse)
|
|
|
return
|
|
|
- final = len(join_list)
|
|
|
- penultimate = last.pop()
|
|
|
- if penultimate == final:
|
|
|
- penultimate = last.pop()
|
|
|
- if trim and len(join_list) > 1:
|
|
|
- extra = join_list[penultimate:]
|
|
|
- join_list = join_list[:penultimate]
|
|
|
- final = penultimate
|
|
|
- penultimate = last.pop()
|
|
|
- col = self.alias_map[extra[0]][LHS_JOIN_COL]
|
|
|
- for alias in extra:
|
|
|
- self.unref_alias(alias)
|
|
|
- else:
|
|
|
- col = target.column
|
|
|
- alias = join_list[-1]
|
|
|
|
|
|
- while final > 1:
|
|
|
- # An optimization: if the final join is against the same column as
|
|
|
- # we are comparing against, we can go back one step in the join
|
|
|
- # chain and compare against the lhs of the join instead (and then
|
|
|
- # repeat the optimization). The result, potentially, involves less
|
|
|
- # table joins.
|
|
|
- join = self.alias_map[alias]
|
|
|
- if col != join[RHS_JOIN_COL]:
|
|
|
- break
|
|
|
- self.unref_alias(alias)
|
|
|
- alias = join[LHS_ALIAS]
|
|
|
- col = join[LHS_JOIN_COL]
|
|
|
- join_list = join_list[:-1]
|
|
|
- final -= 1
|
|
|
- if final == penultimate:
|
|
|
- penultimate = last.pop()
|
|
|
+ # Process the join chain to see if it can be trimmed
|
|
|
+ final, penultimate, col, alias, join_list = self.trim_joins(target, join_list, last, trim)
|
|
|
|
|
|
if (lookup_type == 'isnull' and value is True and not negate and
|
|
|
final > 1):
|
|
@@ -1313,7 +1433,7 @@ class BaseQuery(object):
|
|
|
field, model, direct, m2m = opts.get_field_by_name(f.name)
|
|
|
break
|
|
|
else:
|
|
|
- names = opts.get_all_field_names()
|
|
|
+ names = opts.get_all_field_names() + self.aggregate_select.keys()
|
|
|
raise FieldError("Cannot resolve keyword %r into field. "
|
|
|
"Choices are: %s" % (name, ", ".join(names)))
|
|
|
|
|
@@ -1462,6 +1582,43 @@ class BaseQuery(object):
|
|
|
|
|
|
return field, target, opts, joins, last, extra_filters
|
|
|
|
|
|
+ def trim_joins(self, target, join_list, last, trim):
|
|
|
+ """An optimization: if the final join is against the same column as
|
|
|
+ we are comparing against, we can go back one step in a join
|
|
|
+ chain and compare against the LHS of the join instead (and then
|
|
|
+ repeat the optimization). The result, potentially, involves less
|
|
|
+ table joins.
|
|
|
+
|
|
|
+ Returns a tuple
|
|
|
+ """
|
|
|
+ final = len(join_list)
|
|
|
+ penultimate = last.pop()
|
|
|
+ if penultimate == final:
|
|
|
+ penultimate = last.pop()
|
|
|
+ if trim and len(join_list) > 1:
|
|
|
+ extra = join_list[penultimate:]
|
|
|
+ join_list = join_list[:penultimate]
|
|
|
+ final = penultimate
|
|
|
+ penultimate = last.pop()
|
|
|
+ col = self.alias_map[extra[0]][LHS_JOIN_COL]
|
|
|
+ for alias in extra:
|
|
|
+ self.unref_alias(alias)
|
|
|
+ else:
|
|
|
+ col = target.column
|
|
|
+ alias = join_list[-1]
|
|
|
+ while final > 1:
|
|
|
+ join = self.alias_map[alias]
|
|
|
+ if col != join[RHS_JOIN_COL]:
|
|
|
+ break
|
|
|
+ self.unref_alias(alias)
|
|
|
+ alias = join[LHS_ALIAS]
|
|
|
+ col = join[LHS_JOIN_COL]
|
|
|
+ join_list = join_list[:-1]
|
|
|
+ final -= 1
|
|
|
+ if final == penultimate:
|
|
|
+ penultimate = last.pop()
|
|
|
+ return final, penultimate, col, alias, join_list
|
|
|
+
|
|
|
def update_dupe_avoidance(self, opts, col, alias):
|
|
|
"""
|
|
|
For a column that is one of multiple pointing to the same table, update
|
|
@@ -1554,6 +1711,7 @@ class BaseQuery(object):
|
|
|
"""
|
|
|
alias = self.get_initial_alias()
|
|
|
opts = self.get_meta()
|
|
|
+
|
|
|
try:
|
|
|
for name in field_names:
|
|
|
field, target, u2, joins, u3, u4 = self.setup_joins(
|
|
@@ -1574,7 +1732,7 @@ class BaseQuery(object):
|
|
|
except MultiJoin:
|
|
|
raise FieldError("Invalid field name: '%s'" % name)
|
|
|
except FieldError:
|
|
|
- names = opts.get_all_field_names() + self.extra_select.keys()
|
|
|
+ names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys()
|
|
|
names.sort()
|
|
|
raise FieldError("Cannot resolve keyword %r into field. "
|
|
|
"Choices are: %s" % (name, ", ".join(names)))
|
|
@@ -1609,38 +1767,52 @@ class BaseQuery(object):
|
|
|
if force_empty:
|
|
|
self.default_ordering = False
|
|
|
|
|
|
+ def set_group_by(self):
|
|
|
+ """
|
|
|
+ Expands the GROUP BY clause required by the query.
|
|
|
+
|
|
|
+ This will usually be the set of all non-aggregate fields in the
|
|
|
+ return data. If the database backend supports grouping by the
|
|
|
+ primary key, and the query would be equivalent, the optimization
|
|
|
+ will be made automatically.
|
|
|
+ """
|
|
|
+ if self.connection.features.allows_group_by_pk:
|
|
|
+ if len(self.select) == len(self.model._meta.fields):
|
|
|
+ self.group_by.append('.'.join([self.model._meta.db_table,
|
|
|
+ self.model._meta.pk.column]))
|
|
|
+ return
|
|
|
+
|
|
|
+ for sel in self.select:
|
|
|
+ self.group_by.append(sel)
|
|
|
+
|
|
|
def add_count_column(self):
|
|
|
"""
|
|
|
Converts the query to do count(...) or count(distinct(pk)) in order to
|
|
|
get its size.
|
|
|
"""
|
|
|
- # TODO: When group_by support is added, this needs to be adjusted so
|
|
|
- # that it doesn't totally overwrite the select list.
|
|
|
if not self.distinct:
|
|
|
if not self.select:
|
|
|
- select = Count()
|
|
|
+ count = self.aggregates_module.Count('*', is_summary=True)
|
|
|
else:
|
|
|
assert len(self.select) == 1, \
|
|
|
"Cannot add count col with multiple cols in 'select': %r" % self.select
|
|
|
- select = Count(self.select[0])
|
|
|
+ count = self.aggregates_module.Count(self.select[0])
|
|
|
else:
|
|
|
opts = self.model._meta
|
|
|
if not self.select:
|
|
|
- select = Count((self.join((None, opts.db_table, None, None)),
|
|
|
- opts.pk.column), True)
|
|
|
+ count = self.aggregates_module.Count((self.join((None, opts.db_table, None, None)), opts.pk.column),
|
|
|
+ is_summary=True, distinct=True)
|
|
|
else:
|
|
|
# Because of SQL portability issues, multi-column, distinct
|
|
|
# counts need a sub-query -- see get_count() for details.
|
|
|
assert len(self.select) == 1, \
|
|
|
"Cannot add count col with multiple cols in 'select'."
|
|
|
- select = Count(self.select[0], True)
|
|
|
|
|
|
+ count = self.aggregates_module.Count(self.select[0], distinct=True)
|
|
|
# Distinct handling is done in Count(), so don't do it at this
|
|
|
# level.
|
|
|
self.distinct = False
|
|
|
- self.select = [select]
|
|
|
- self.select_fields = [None]
|
|
|
- self.extra_select = {}
|
|
|
+ self.aggregate_select = {None: count}
|
|
|
|
|
|
def add_select_related(self, fields):
|
|
|
"""
|
|
@@ -1758,7 +1930,6 @@ class BaseQuery(object):
|
|
|
return empty_iter()
|
|
|
else:
|
|
|
return
|
|
|
-
|
|
|
cursor = self.connection.cursor()
|
|
|
cursor.execute(sql, params)
|
|
|
|