2
0

expressions.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import copy
  2. from django.core.exceptions import FieldError
  3. from django.db.models.constants import LOOKUP_SEP
  4. from django.db.models.fields import FieldDoesNotExist
  5. class SQLEvaluator(object):
  6. def __init__(self, expression, query, allow_joins=True, reuse=None):
  7. self.expression = expression
  8. self.opts = query.get_meta()
  9. self.reuse = reuse
  10. self.cols = []
  11. self.expression.prepare(self, query, allow_joins)
  12. def relabeled_clone(self, change_map):
  13. clone = copy.copy(self)
  14. clone.cols = []
  15. for node, col in self.cols:
  16. if hasattr(col, 'relabeled_clone'):
  17. clone.cols.append((node, col.relabeled_clone(change_map)))
  18. else:
  19. clone.cols.append((node,
  20. (change_map.get(col[0], col[0]), col[1])))
  21. return clone
  22. def get_group_by_cols(self):
  23. cols = []
  24. for node, col in self.cols:
  25. if hasattr(node, 'get_group_by_cols'):
  26. cols.extend(node.get_group_by_cols())
  27. elif isinstance(col, tuple):
  28. cols.append(col)
  29. return cols
  30. def prepare(self):
  31. return self
  32. def as_sql(self, qn, connection):
  33. return self.expression.evaluate(self, qn, connection)
  34. #####################################################
  35. # Vistor methods for initial expression preparation #
  36. #####################################################
  37. def prepare_node(self, node, query, allow_joins):
  38. for child in node.children:
  39. if hasattr(child, 'prepare'):
  40. child.prepare(self, query, allow_joins)
  41. def prepare_leaf(self, node, query, allow_joins):
  42. if not allow_joins and LOOKUP_SEP in node.name:
  43. raise FieldError("Joined field references are not permitted in this query")
  44. field_list = node.name.split(LOOKUP_SEP)
  45. if node.name in query.aggregates:
  46. self.cols.append((node, query.aggregate_select[node.name]))
  47. else:
  48. try:
  49. field, sources, opts, join_list, path = query.setup_joins(
  50. field_list, query.get_meta(),
  51. query.get_initial_alias(), self.reuse)
  52. self._used_joins = join_list
  53. targets, _, join_list = query.trim_joins(sources, join_list, path)
  54. if self.reuse is not None:
  55. self.reuse.update(join_list)
  56. for t in targets:
  57. self.cols.append((node, (join_list[-1], t.column)))
  58. except FieldDoesNotExist:
  59. raise FieldError("Cannot resolve keyword %r into field. "
  60. "Choices are: %s" % (self.name,
  61. [f.name for f in self.opts.fields]))
  62. ##################################################
  63. # Vistor methods for final expression evaluation #
  64. ##################################################
  65. def evaluate_node(self, node, qn, connection):
  66. expressions = []
  67. expression_params = []
  68. for child in node.children:
  69. if hasattr(child, 'evaluate'):
  70. sql, params = child.evaluate(self, qn, connection)
  71. else:
  72. sql, params = '%s', (child,)
  73. if len(getattr(child, 'children', [])) > 1:
  74. format = '(%s)'
  75. else:
  76. format = '%s'
  77. if sql:
  78. expressions.append(format % sql)
  79. expression_params.extend(params)
  80. return connection.ops.combine_expression(node.connector, expressions), expression_params
  81. def evaluate_leaf(self, node, qn, connection):
  82. col = None
  83. for n, c in self.cols:
  84. if n is node:
  85. col = c
  86. break
  87. if col is None:
  88. raise ValueError("Given node not found")
  89. if hasattr(col, 'as_sql'):
  90. return col.as_sql(qn, connection)
  91. else:
  92. return '%s.%s' % (qn(col[0]), qn(col[1])), []
  93. def evaluate_date_modifier_node(self, node, qn, connection):
  94. timedelta = node.children.pop()
  95. sql, params = self.evaluate_node(node, qn, connection)
  96. node.children.append(timedelta)
  97. if (timedelta.days == timedelta.seconds == timedelta.microseconds == 0):
  98. return sql, params
  99. return connection.ops.date_interval_sql(sql, node.connector, timedelta), params