|
@@ -1,3 +1,5 @@
|
|
|
+from itertools import izip
|
|
|
+
|
|
|
from django.core.exceptions import FieldError
|
|
|
from django.db import connections
|
|
|
from django.db import transaction
|
|
@@ -9,6 +11,7 @@ from django.db.models.sql.query import (get_proxied_model, get_order_dir,
|
|
|
select_related_descend, Query)
|
|
|
from django.db.utils import DatabaseError
|
|
|
|
|
|
+
|
|
|
class SQLCompiler(object):
|
|
|
def __init__(self, query, connection, using):
|
|
|
self.query = query
|
|
@@ -794,20 +797,55 @@ class SQLInsertCompiler(SQLCompiler):
|
|
|
qn = self.connection.ops.quote_name
|
|
|
opts = self.query.model._meta
|
|
|
result = ['INSERT INTO %s' % qn(opts.db_table)]
|
|
|
- result.append('(%s)' % ', '.join([qn(c) for c in self.query.columns]))
|
|
|
- values = [self.placeholder(*v) for v in self.query.values]
|
|
|
- result.append('VALUES (%s)' % ', '.join(values))
|
|
|
- params = self.query.params
|
|
|
+
|
|
|
+ has_fields = bool(self.query.fields)
|
|
|
+ fields = self.query.fields if has_fields else [opts.pk]
|
|
|
+ result.append('(%s)' % ', '.join([qn(f.column) for f in fields]))
|
|
|
+
|
|
|
+ if has_fields:
|
|
|
+ params = values = [
|
|
|
+ [
|
|
|
+ f.get_db_prep_save(getattr(obj, f.attname) if self.query.raw else f.pre_save(obj, True), connection=self.connection)
|
|
|
+ for f in fields
|
|
|
+ ]
|
|
|
+ for obj in self.query.objs
|
|
|
+ ]
|
|
|
+ else:
|
|
|
+ values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs]
|
|
|
+ params = [[]]
|
|
|
+ fields = [None]
|
|
|
+ can_bulk = not any(hasattr(field, "get_placeholder") for field in fields) and not self.return_id
|
|
|
+
|
|
|
+ if can_bulk:
|
|
|
+ placeholders = [["%s"] * len(fields)]
|
|
|
+ else:
|
|
|
+ placeholders = [
|
|
|
+ [self.placeholder(field, v) for field, v in izip(fields, val)]
|
|
|
+ for val in values
|
|
|
+ ]
|
|
|
if self.return_id and self.connection.features.can_return_id_from_insert:
|
|
|
+ params = values[0]
|
|
|
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
|
|
|
+ result.append("VALUES (%s)" % ", ".join(placeholders[0]))
|
|
|
r_fmt, r_params = self.connection.ops.return_insert_id()
|
|
|
result.append(r_fmt % col)
|
|
|
- params = params + r_params
|
|
|
- return ' '.join(result), params
|
|
|
+ params += r_params
|
|
|
+ return [(" ".join(result), tuple(params))]
|
|
|
+ if can_bulk and self.connection.features.has_bulk_insert:
|
|
|
+ result.append(self.connection.ops.bulk_insert_sql(fields, len(values)))
|
|
|
+ return [(" ".join(result), tuple([v for val in values for v in val]))]
|
|
|
+ else:
|
|
|
+ return [
|
|
|
+ (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
|
|
|
+ for p, vals in izip(placeholders, params)
|
|
|
+ ]
|
|
|
|
|
|
def execute_sql(self, return_id=False):
|
|
|
+ assert not (return_id and len(self.query.objs) != 1)
|
|
|
self.return_id = return_id
|
|
|
- cursor = super(SQLInsertCompiler, self).execute_sql(None)
|
|
|
+ cursor = self.connection.cursor()
|
|
|
+ for sql, params in self.as_sql():
|
|
|
+ cursor.execute(sql, params)
|
|
|
if not (return_id and cursor):
|
|
|
return
|
|
|
if self.connection.features.can_return_id_from_insert:
|