|
@@ -909,17 +909,102 @@ class SQLInsertCompiler(SQLCompiler):
|
|
|
self.return_id = False
|
|
|
super(SQLInsertCompiler, self).__init__(*args, **kwargs)
|
|
|
|
|
|
- def placeholder(self, field, val):
|
|
|
+ def field_as_sql(self, field, val):
|
|
|
+ """
|
|
|
+ Take a field and a value intended to be saved on that field, and
|
|
|
+ return placeholder SQL and accompanying params. Checks for raw values,
|
|
|
+ expressions and fields with get_placeholder() defined in that order.
|
|
|
+
|
|
|
+ When field is None, the value is considered raw and is used as the
|
|
|
+ placeholder, with no corresponding parameters returned.
|
|
|
+ """
|
|
|
if field is None:
|
|
|
|
|
|
- return val
|
|
|
+ sql, params = val, []
|
|
|
+ elif hasattr(val, 'as_sql'):
|
|
|
+
|
|
|
+ sql, params = self.compile(val)
|
|
|
elif hasattr(field, 'get_placeholder'):
|
|
|
|
|
|
|
|
|
- return field.get_placeholder(val, self, self.connection)
|
|
|
+ sql, params = field.get_placeholder(val, self, self.connection), [val]
|
|
|
else:
|
|
|
|
|
|
- return '%s'
|
|
|
+ sql, params = '%s', [val]
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ params = self.connection.ops.modify_insert_params(sql, params)
|
|
|
+
|
|
|
+ return sql, params
|
|
|
+
|
|
|
+ def prepare_value(self, field, value):
|
|
|
+ """
|
|
|
+ Prepare a value to be used in a query by resolving it if it is an
|
|
|
+ expression and otherwise calling the field's get_db_prep_save().
|
|
|
+ """
|
|
|
+ if hasattr(value, 'resolve_expression'):
|
|
|
+ value = value.resolve_expression(self.query, allow_joins=False, for_save=True)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ if value.contains_column_references:
|
|
|
+ raise ValueError(
|
|
|
+ 'Failed to insert expression "%s" on %s. F() expressions '
|
|
|
+ 'can only be used to update, not to insert.' % (value, field)
|
|
|
+ )
|
|
|
+ if value.contains_aggregate:
|
|
|
+ raise FieldError("Aggregate functions are not allowed in this query")
|
|
|
+ else:
|
|
|
+ value = field.get_db_prep_save(value, connection=self.connection)
|
|
|
+ return value
|
|
|
+
|
|
|
+ def pre_save_val(self, field, obj):
|
|
|
+ """
|
|
|
+ Get the given field's value off the given obj. pre_save() is used for
|
|
|
+ things like auto_now on DateTimeField. Skip it if this is a raw query.
|
|
|
+ """
|
|
|
+ if self.query.raw:
|
|
|
+ return getattr(obj, field.attname)
|
|
|
+ return field.pre_save(obj, add=True)
|
|
|
+
|
|
|
+ def assemble_as_sql(self, fields, value_rows):
|
|
|
+ """
|
|
|
+ Take a sequence of N fields and a sequence of M rows of values,
|
|
|
+ generate placeholder SQL and parameters for each field and value, and
|
|
|
+ return a pair containing:
|
|
|
+ * a sequence of M rows of N SQL placeholder strings, and
|
|
|
+ * a sequence of M rows of corresponding parameter values.
|
|
|
+
|
|
|
+ Each placeholder string may contain any number of '%s' interpolation
|
|
|
+ strings, and each parameter row will contain exactly as many params
|
|
|
+ as the total number of '%s's in the corresponding placeholder row.
|
|
|
+ """
|
|
|
+ if not value_rows:
|
|
|
+ return [], []
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ rows_of_fields_as_sql = (
|
|
|
+ (self.field_as_sql(field, v) for field, v in zip(fields, row))
|
|
|
+ for row in value_rows
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ sql_and_param_pair_rows = (zip(*row) for row in rows_of_fields_as_sql)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ placeholder_rows, param_rows = zip(*sql_and_param_pair_rows)
|
|
|
+
|
|
|
+
|
|
|
+ param_rows = [[p for ps in row for p in ps] for row in param_rows]
|
|
|
+
|
|
|
+ return placeholder_rows, param_rows
|
|
|
|
|
|
def as_sql(self):
|
|
|
|
|
@@ -933,35 +1018,27 @@ class SQLInsertCompiler(SQLCompiler):
|
|
|
result.append('(%s)' % ', '.join(qn(f.column) for f in fields))
|
|
|
|
|
|
if has_fields:
|
|
|
- params = values = [
|
|
|
- [
|
|
|
- f.get_db_prep_save(
|
|
|
- getattr(obj, f.attname) if self.query.raw else f.pre_save(obj, True),
|
|
|
- connection=self.connection
|
|
|
- ) for f in fields
|
|
|
- ]
|
|
|
+ value_rows = [
|
|
|
+ [self.prepare_value(field, self.pre_save_val(field, obj)) for field in fields]
|
|
|
for obj in self.query.objs
|
|
|
]
|
|
|
else:
|
|
|
- values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs]
|
|
|
- params = [[]]
|
|
|
+
|
|
|
+ value_rows = [[self.connection.ops.pk_default_value()] for _ in self.query.objs]
|
|
|
fields = [None]
|
|
|
- can_bulk = (not any(hasattr(field, "get_placeholder") for field in fields) and
|
|
|
- not self.return_id and self.connection.features.has_bulk_insert)
|
|
|
|
|
|
- if can_bulk:
|
|
|
- placeholders = [["%s"] * len(fields)]
|
|
|
- else:
|
|
|
- placeholders = [
|
|
|
- [self.placeholder(field, v) for field, v in zip(fields, val)]
|
|
|
- for val in values
|
|
|
- ]
|
|
|
-
|
|
|
- params = self.connection.ops.modify_insert_params(placeholders, params)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ can_bulk = (not self.return_id and self.connection.features.has_bulk_insert)
|
|
|
+
|
|
|
+ placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
|
|
|
+
|
|
|
if self.return_id and self.connection.features.can_return_id_from_insert:
|
|
|
- params = params[0]
|
|
|
+ params = param_rows[0]
|
|
|
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
|
|
|
- result.append("VALUES (%s)" % ", ".join(placeholders[0]))
|
|
|
+ result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
|
|
|
r_fmt, r_params = self.connection.ops.return_insert_id()
|
|
|
|
|
|
|
|
@@ -969,13 +1046,14 @@ class SQLInsertCompiler(SQLCompiler):
|
|
|
result.append(r_fmt % col)
|
|
|
params += r_params
|
|
|
return [(" ".join(result), tuple(params))]
|
|
|
+
|
|
|
if can_bulk:
|
|
|
- result.append(self.connection.ops.bulk_insert_sql(fields, len(values)))
|
|
|
- return [(" ".join(result), tuple(v for val in values for v in val))]
|
|
|
+ result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
|
|
|
+ return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]
|
|
|
else:
|
|
|
return [
|
|
|
(" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
|
|
|
- for p, vals in zip(placeholders, params)
|
|
|
+ for p, vals in zip(placeholder_rows, param_rows)
|
|
|
]
|
|
|
|
|
|
def execute_sql(self, return_id=False):
|
|
@@ -1034,10 +1112,11 @@ class SQLUpdateCompiler(SQLCompiler):
|
|
|
connection=self.connection,
|
|
|
)
|
|
|
else:
|
|
|
- raise TypeError("Database is trying to update a relational field "
|
|
|
- "of type %s with a value of type %s. Make sure "
|
|
|
- "you are setting the correct relations" %
|
|
|
- (field.__class__.__name__, val.__class__.__name__))
|
|
|
+ raise TypeError(
|
|
|
+ "Tried to update field %s with a model instance, %r. "
|
|
|
+ "Use a value compatible with %s."
|
|
|
+ % (field, val, field.__class__.__name__)
|
|
|
+ )
|
|
|
else:
|
|
|
val = field.get_db_prep_save(val, connection=self.connection)
|
|
|
|