|
@@ -77,7 +77,9 @@ class BaseQuery(object):
|
|
|
self.related_select_cols = []
|
|
|
|
|
|
# SQL aggregate-related attributes
|
|
|
- self.aggregate_select = SortedDict() # Maps alias -> SQL aggregate function
|
|
|
+ self.aggregates = SortedDict() # Maps alias -> SQL aggregate function
|
|
|
+ self.aggregate_select_mask = None
|
|
|
+ self._aggregate_select_cache = None
|
|
|
|
|
|
# Arbitrary maximum limit for select_related. Prevents infinite
|
|
|
# recursion. Can be changed by the depth parameter to select_related().
|
|
@@ -187,7 +189,15 @@ class BaseQuery(object):
|
|
|
obj.distinct = self.distinct
|
|
|
obj.select_related = self.select_related
|
|
|
obj.related_select_cols = []
|
|
|
- obj.aggregate_select = self.aggregate_select.copy()
|
|
|
+ obj.aggregates = self.aggregates.copy()
|
|
|
+ if self.aggregate_select_mask is None:
|
|
|
+ obj.aggregate_select_mask = None
|
|
|
+ else:
|
|
|
+ obj.aggregate_select_mask = self.aggregate_select_mask[:]
|
|
|
+ if self._aggregate_select_cache is None:
|
|
|
+ obj._aggregate_select_cache = None
|
|
|
+ else:
|
|
|
+ obj._aggregate_select_cache = self._aggregate_select_cache.copy()
|
|
|
obj.max_depth = self.max_depth
|
|
|
obj.extra_select = self.extra_select.copy()
|
|
|
obj.extra_tables = self.extra_tables
|
|
@@ -940,14 +950,17 @@ class BaseQuery(object):
|
|
|
"""
|
|
|
assert set(change_map.keys()).intersection(set(change_map.values())) == set()
|
|
|
|
|
|
- # 1. Update references in "select" and "where".
|
|
|
+ # 1. Update references in "select" (normal columns plus aliases),
|
|
|
+ # "group by", "where" and "having".
|
|
|
self.where.relabel_aliases(change_map)
|
|
|
- for pos, col in enumerate(self.select):
|
|
|
- if isinstance(col, (list, tuple)):
|
|
|
- old_alias = col[0]
|
|
|
- self.select[pos] = (change_map.get(old_alias, old_alias), col[1])
|
|
|
- else:
|
|
|
- col.relabel_aliases(change_map)
|
|
|
+ self.having.relabel_aliases(change_map)
|
|
|
+ for columns in (self.select, self.aggregates.values(), self.group_by or []):
|
|
|
+ for pos, col in enumerate(columns):
|
|
|
+ if isinstance(col, (list, tuple)):
|
|
|
+ old_alias = col[0]
|
|
|
+ columns[pos] = (change_map.get(old_alias, old_alias), col[1])
|
|
|
+ else:
|
|
|
+ col.relabel_aliases(change_map)
|
|
|
|
|
|
# 2. Rename the alias in the internal table/alias datastructures.
|
|
|
for old_alias, new_alias in change_map.iteritems():
|
|
@@ -1205,11 +1218,11 @@ class BaseQuery(object):
|
|
|
opts = model._meta
|
|
|
field_list = aggregate.lookup.split(LOOKUP_SEP)
|
|
|
if (len(field_list) == 1 and
|
|
|
- aggregate.lookup in self.aggregate_select.keys()):
|
|
|
+ aggregate.lookup in self.aggregates.keys()):
|
|
|
# Aggregate is over an annotation
|
|
|
field_name = field_list[0]
|
|
|
col = field_name
|
|
|
- source = self.aggregate_select[field_name]
|
|
|
+ source = self.aggregates[field_name]
|
|
|
elif (len(field_list) > 1 or
|
|
|
field_list[0] not in [i.name for i in opts.fields]):
|
|
|
field, source, opts, join_list, last, _ = self.setup_joins(
|
|
@@ -1299,7 +1312,7 @@ class BaseQuery(object):
|
|
|
value = SQLEvaluator(value, self)
|
|
|
having_clause = value.contains_aggregate
|
|
|
|
|
|
- for alias, aggregate in self.aggregate_select.items():
|
|
|
+ for alias, aggregate in self.aggregates.items():
|
|
|
if alias == parts[0]:
|
|
|
entry = self.where_class()
|
|
|
entry.add((aggregate, lookup_type, value), AND)
|
|
@@ -1824,8 +1837,8 @@ class BaseQuery(object):
|
|
|
self.group_by = []
|
|
|
if self.connection.features.allows_group_by_pk:
|
|
|
if len(self.select) == len(self.model._meta.fields):
|
|
|
- self.group_by.append('.'.join([self.model._meta.db_table,
|
|
|
- self.model._meta.pk.column]))
|
|
|
+ self.group_by.append((self.model._meta.db_table,
|
|
|
+ self.model._meta.pk.column))
|
|
|
return
|
|
|
|
|
|
for sel in self.select:
|
|
@@ -1858,7 +1871,11 @@ class BaseQuery(object):
|
|
|
# Distinct handling is done in Count(), so don't do it at this
|
|
|
# level.
|
|
|
self.distinct = False
|
|
|
- self.aggregate_select = {None: count}
|
|
|
+
|
|
|
+ # Set only aggregate to be the count column.
|
|
|
+ # Clear out the select cache to reflect the new unmasked aggregates.
|
|
|
+ self.aggregates = {None: count}
|
|
|
+ self.set_aggregate_mask(None)
|
|
|
|
|
|
def add_select_related(self, fields):
|
|
|
"""
|
|
@@ -1920,6 +1937,29 @@ class BaseQuery(object):
|
|
|
for key in set(self.extra_select).difference(set(names)):
|
|
|
del self.extra_select[key]
|
|
|
|
|
|
+ def set_aggregate_mask(self, names):
|
|
|
+ "Set the mask of aggregates that will actually be returned by the SELECT"
|
|
|
+ self.aggregate_select_mask = names
|
|
|
+ self._aggregate_select_cache = None
|
|
|
+
|
|
|
+ def _aggregate_select(self):
|
|
|
+ """The SortedDict of aggregate columns that are not masked, and should
|
|
|
+ be used in the SELECT clause.
|
|
|
+
|
|
|
+ This result is cached for optimization purposes.
|
|
|
+ """
|
|
|
+ if self._aggregate_select_cache is not None:
|
|
|
+ return self._aggregate_select_cache
|
|
|
+ elif self.aggregate_select_mask is not None:
|
|
|
+ self._aggregate_select_cache = SortedDict([
|
|
|
+ (k,v) for k,v in self.aggregates.items()
|
|
|
+ if k in self.aggregate_select_mask
|
|
|
+ ])
|
|
|
+ return self._aggregate_select_cache
|
|
|
+ else:
|
|
|
+ return self.aggregates
|
|
|
+ aggregate_select = property(_aggregate_select)
|
|
|
+
|
|
|
def set_start(self, start):
|
|
|
"""
|
|
|
Sets the table from which to start joining. The start position is
|