Browse Source

Fixed #7596. Added Model.objects.bulk_create, and make use of it in several places. This provides a performance benefit when inserting multiple objects. THanks to Russ for the review, and Simon Meers for the MySQl implementation.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16739 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Alex Gaynor 13 năm trước cách đây
mục cha
commit
7deb25b8dd

+ 9 - 11
django/contrib/auth/management/__init__.py

@@ -46,17 +46,15 @@ def create_permissions(app, created_models, verbosity, **kwargs):
         "content_type", "codename"
     ))
 
-    for ctype, (codename, name) in searched_perms:
-        # If the permissions exists, move on.
-        if (ctype.pk, codename) in all_perms:
-            continue
-        p = auth_app.Permission.objects.create(
-            codename=codename,
-            name=name,
-            content_type=ctype
-        )
-        if verbosity >= 2:
-            print "Adding permission '%s'" % p
+    objs = [
+        auth_app.Permission(codename=codename, name=name, content_type=ctype)
+        for ctype, (codename, name) in searched_perms
+        if (ctype.pk, codename) not in all_perms
+    ]
+    auth_app.Permission.objects.bulk_create(objs)
+    if verbosity >= 2:
+        for obj in objs:
+            print "Adding permission '%s'" % obj
 
 
 def create_superuser(app, created_models, verbosity, **kwargs):

+ 33 - 17
django/contrib/contenttypes/management.py

@@ -8,25 +8,41 @@ def update_contenttypes(app, created_models, verbosity=2, **kwargs):
     entries that no longer have a matching model class.
     """
     ContentType.objects.clear_cache()
-    content_types = list(ContentType.objects.filter(app_label=app.__name__.split('.')[-2]))
     app_models = get_models(app)
     if not app_models:
         return
-    for klass in app_models:
-        opts = klass._meta
-        try:
-            ct = ContentType.objects.get(app_label=opts.app_label,
-                                         model=opts.object_name.lower())
-            content_types.remove(ct)
-        except ContentType.DoesNotExist:
-            ct = ContentType(name=smart_unicode(opts.verbose_name_raw),
-                app_label=opts.app_label, model=opts.object_name.lower())
-            ct.save()
-            if verbosity >= 2:
-                print "Adding content type '%s | %s'" % (ct.app_label, ct.model)
-    # The presence of any remaining content types means the supplied app has an
-    # undefined model. Confirm that the content type is stale before deletion.
-    if content_types:
+    # They all have the same app_label, get the first one.
+    app_label = app_models[0]._meta.app_label
+    app_models = dict(
+        (model._meta.object_name.lower(), model)
+        for model in app_models
+    )
+    # Get all the content types
+    content_types = dict(
+        (ct.model, ct)
+        for ct in ContentType.objects.filter(app_label=app_label)
+    )
+    to_remove = [
+        ct
+        for (model_name, ct) in content_types.iteritems()
+        if model_name not in app_models
+    ]
+
+    cts = ContentType.objects.bulk_create([
+        ContentType(
+            name=smart_unicode(model._meta.verbose_name_raw),
+            app_label=app_label,
+            model=model_name,
+        )
+        for (model_name, model) in app_models.iteritems()
+        if model_name not in content_types
+    ])
+    if verbosity >= 2:
+        for ct in cts:
+            print "Adding content type '%s | %s'" % (ct.app_label, ct.model)
+
+    # Confirm that the content type is stale before deletion.
+    if to_remove:
         if kwargs.get('interactive', False):
             content_type_display = '\n'.join(['    %s | %s' % (ct.app_label, ct.model) for ct in content_types])
             ok_to_delete = raw_input("""The following content types are stale and need to be deleted:
@@ -42,7 +58,7 @@ If you're unsure, answer 'no'.
             ok_to_delete = False
 
         if ok_to_delete == 'yes':
-            for ct in content_types:
+            for ct in to_remove:
                 if verbosity >= 2:
                     print "Deleting stale content type '%s | %s'" % (ct.app_label, ct.model)
                 ct.delete()

+ 2 - 0
django/db/backends/__init__.py

@@ -301,8 +301,10 @@ class BaseDatabaseFeatures(object):
 
     can_use_chunked_reads = True
     can_return_id_from_insert = False
+    has_bulk_insert = False
     uses_autocommit = False
     uses_savepoints = False
+    can_combine_inserts_with_and_without_auto_increment_pk = False
 
     # If True, don't use integer foreign keys referring to, e.g., positive
     # integer primary keys.

+ 5 - 0
django/db/backends/mysql/base.py

@@ -124,6 +124,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     allows_group_by_pk = True
     related_fields_match_type = True
     allow_sliced_subqueries = False
+    has_bulk_insert = True
     has_select_for_update = True
     has_select_for_update_nowait = False
     supports_forward_references = False
@@ -263,6 +264,10 @@ class DatabaseOperations(BaseDatabaseOperations):
     def max_name_length(self):
         return 64
 
+    def bulk_insert_sql(self, fields, num_values):
+        items_sql = "(%s)" % ", ".join(["%s"] * len(fields))
+        return "VALUES " + ", ".join([items_sql] * num_values)
+
 class DatabaseWrapper(BaseDatabaseWrapper):
     vendor = 'mysql'
     operators = {

+ 1 - 0
django/db/backends/postgresql_psycopg2/base.py

@@ -74,6 +74,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     can_defer_constraint_checks = True
     has_select_for_update = True
     has_select_for_update_nowait = True
+    has_bulk_insert = True
 
 
 class DatabaseWrapper(BaseDatabaseWrapper):

+ 4 - 0
django/db/backends/postgresql_psycopg2/operations.py

@@ -180,3 +180,7 @@ class DatabaseOperations(BaseDatabaseOperations):
 
     def return_insert_id(self):
         return "RETURNING %s", ()
+
+    def bulk_insert_sql(self, fields, num_values):
+        items_sql = "(%s)" % ", ".join(["%s"] * len(fields))
+        return "VALUES " + ", ".join([items_sql] * num_values)

+ 11 - 1
django/db/backends/sqlite3/base.py

@@ -58,6 +58,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     supports_unspecified_pk = True
     supports_1000_query_parameters = False
     supports_mixed_date_datetime_comparisons = False
+    has_bulk_insert = True
+    can_combine_inserts_with_and_without_auto_increment_pk = True
 
     def _supports_stddev(self):
         """Confirm support for STDDEV and related stats functions
@@ -106,7 +108,7 @@ class DatabaseOperations(BaseDatabaseOperations):
         return ""
 
     def pk_default_value(self):
-        return 'NULL'
+        return "NULL"
 
     def quote_name(self, name):
         if name.startswith('"') and name.endswith('"'):
@@ -154,6 +156,14 @@ class DatabaseOperations(BaseDatabaseOperations):
         # No field, or the field isn't known to be a decimal or integer
         return value
 
+    def bulk_insert_sql(self, fields, num_values):
+        res = []
+        res.append("SELECT %s" % ", ".join(
+            "%%s AS %s" % self.quote_name(f.column) for f in fields
+        ))
+        res.extend(["UNION SELECT %s" % ", ".join(["%s"] * len(fields))] * (num_values - 1))
+        return " ".join(res)
+
 class DatabaseWrapper(BaseDatabaseWrapper):
     vendor = 'sqlite'
     # SQLite requires LIKE statements to include an ESCAPE clause if the value

+ 3 - 11
django/db/models/base.py

@@ -540,24 +540,16 @@ class Model(object):
                     order_value = manager.using(using).filter(**{field.name: getattr(self, field.attname)}).count()
                     self._order = order_value
 
+                fields = meta.local_fields
                 if not pk_set:
                     if force_update:
                         raise ValueError("Cannot force an update in save() with no primary key.")
-                    values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection))
-                        for f in meta.local_fields if not isinstance(f, AutoField)]
-                else:
-                    values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection))
-                        for f in meta.local_fields]
+                    fields = [f for f in fields if not isinstance(f, AutoField)]
 
                 record_exists = False
 
                 update_pk = bool(meta.has_auto_field and not pk_set)
-                if values:
-                    # Create a new record.
-                    result = manager._insert(values, return_id=update_pk, using=using)
-                else:
-                    # Create a new record with defaults for everything.
-                    result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True, using=using)
+                result = manager._insert([self], fields=fields, return_id=update_pk, using=using, raw=raw)
 
                 if update_pk:
                     setattr(self, meta.pk.attname, result)

+ 8 - 6
django/db/models/fields/related.py

@@ -430,7 +430,7 @@ class ForeignRelatedObjectsDescriptor(object):
             add.alters_data = True
 
             def create(self, **kwargs):
-                kwargs.update({rel_field.name: instance})
+                kwargs[rel_field.name] = instance
                 db = router.db_for_write(rel_model, instance=instance)
                 return super(RelatedManager, self.db_manager(db)).create(**kwargs)
             create.alters_data = True
@@ -438,7 +438,7 @@ class ForeignRelatedObjectsDescriptor(object):
             def get_or_create(self, **kwargs):
                 # Update kwargs with the related object that this
                 # ForeignRelatedObjectsDescriptor knows about.
-                kwargs.update({rel_field.name: instance})
+                kwargs[rel_field.name] = instance
                 db = router.db_for_write(rel_model, instance=instance)
                 return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
             get_or_create.alters_data = True
@@ -578,11 +578,13 @@ def create_many_related_manager(superclass, rel=False):
                         instance=self.instance, reverse=self.reverse,
                         model=self.model, pk_set=new_ids, using=db)
                 # Add the ones that aren't there already
-                for obj_id in new_ids:
-                    self.through._default_manager.using(db).create(**{
+                self.through._default_manager.using(db).bulk_create([
+                    self.through(**{
                         '%s_id' % source_field_name: self._pk_val,
                         '%s_id' % target_field_name: obj_id,
                     })
+                    for obj_id in new_ids
+                ])
                 if self.reverse or source_field_name == self.source_field_name:
                     # Don't send the signal when we are inserting the
                     # duplicate data row for symmetrical reverse entries.
@@ -701,12 +703,12 @@ class ReverseManyRelatedObjectsDescriptor(object):
     def __init__(self, m2m_field):
         self.field = m2m_field
 
-    def _through(self):
+    @property
+    def through(self):
         # through is provided so that you have easy access to the through
         # model (Book.authors.through) for inlines, etc. This is done as
         # a property to ensure that the fully resolved value is returned.
         return self.field.rel.through
-    through = property(_through)
 
     def __get__(self, instance, instance_type=None):
         if instance is None:

+ 5 - 2
django/db/models/manager.py

@@ -136,6 +136,9 @@ class Manager(object):
     def create(self, **kwargs):
         return self.get_query_set().create(**kwargs)
 
+    def bulk_create(self, *args, **kwargs):
+        return self.get_query_set().bulk_create(*args, **kwargs)
+
     def filter(self, *args, **kwargs):
         return self.get_query_set().filter(*args, **kwargs)
 
@@ -193,8 +196,8 @@ class Manager(object):
     def exists(self, *args, **kwargs):
         return self.get_query_set().exists(*args, **kwargs)
 
-    def _insert(self, values, **kwargs):
-        return insert_query(self.model, values, **kwargs)
+    def _insert(self, objs, fields, **kwargs):
+        return insert_query(self.model, objs, fields, **kwargs)
 
     def _update(self, values, **kwargs):
         return self.get_query_set()._update(values, **kwargs)

+ 39 - 2
django/db/models/query.py

@@ -5,10 +5,12 @@ The main QuerySet implementation. This provides the public API for the ORM.
 import copy
 
 from django.db import connections, router, transaction, IntegrityError
+from django.db.models.fields import AutoField
 from django.db.models.query_utils import (Q, select_related_descend,
     deferred_class_factory, InvalidQuery)
 from django.db.models.deletion import Collector
 from django.db.models import signals, sql
+from django.utils.functional import partition
 
 # Used to control how many objects are worked with at once in some cases (e.g.
 # when deleting objects).
@@ -352,6 +354,41 @@ class QuerySet(object):
         obj.save(force_insert=True, using=self.db)
         return obj
 
+    def bulk_create(self, objs):
+        """
+        Inserts each of the instances into the database. This does *not* call
+        save() on each of the instances, does not send any pre/post save
+        signals, and does not set the primary key attribute if it is an
+        autoincrement field.
+        """
+        # So this case is fun. When you bulk insert you don't get the primary
+        # keys back (if it's an autoincrement), so you can't insert into the
+        # child tables which references this. There are two workarounds, 1)
+        # this could be implemented if you didn't have an autoincrement pk,
+        # and 2) you could do it by doing O(n) normal inserts into the parent
+        # tables to get the primary keys back, and then doing a single bulk
+        # insert into the childmost table. We're punting on these for now
+        # because they are relatively rare cases.
+        if self.model._meta.parents:
+            raise ValueError("Can't bulk create an inherited model")
+        if not objs:
+            return
+        self._for_write = True
+        connection = connections[self.db]
+        fields = self.model._meta.local_fields
+        if (connection.features.can_combine_inserts_with_and_without_auto_increment_pk
+            and self.model._meta.has_auto_field):
+            self.model._base_manager._insert(objs, fields=fields, using=self.db)
+        else:
+            objs_with_pk, objs_without_pk = partition(
+                lambda o: o.pk is None,
+                objs
+            )
+            if objs_with_pk:
+                self.model._base_manager._insert(objs_with_pk, fields=fields, using=self.db)
+            if objs_without_pk:
+                self.model._base_manager._insert(objs_without_pk, fields=[f for f in fields if not isinstance(f, AutoField)], using=self.db)
+
     def get_or_create(self, **kwargs):
         """
         Looks up an object with the given kwargs, creating one if necessary.
@@ -1437,12 +1474,12 @@ class RawQuerySet(object):
                 self._model_fields[converter(column)] = field
         return self._model_fields
 
-def insert_query(model, values, return_id=False, raw_values=False, using=None):
+def insert_query(model, objs, fields, return_id=False, raw=False, using=None):
     """
     Inserts a new record for the given model. This provides an interface to
     the InsertQuery class and is how Model.save() is implemented. It is not
     part of the public API.
     """
     query = sql.InsertQuery(model)
-    query.insert_values(values, raw_values)
+    query.insert_values(fields, objs, raw=raw)
     return query.get_compiler(using=using).execute_sql(return_id)

+ 45 - 7
django/db/models/sql/compiler.py

@@ -1,3 +1,5 @@
+from itertools import izip
+
 from django.core.exceptions import FieldError
 from django.db import connections
 from django.db import transaction
@@ -9,6 +11,7 @@ from django.db.models.sql.query import (get_proxied_model, get_order_dir,
      select_related_descend, Query)
 from django.db.utils import DatabaseError
 
+
 class SQLCompiler(object):
     def __init__(self, query, connection, using):
         self.query = query
@@ -794,20 +797,55 @@ class SQLInsertCompiler(SQLCompiler):
         qn = self.connection.ops.quote_name
         opts = self.query.model._meta
         result = ['INSERT INTO %s' % qn(opts.db_table)]
-        result.append('(%s)' % ', '.join([qn(c) for c in self.query.columns]))
-        values = [self.placeholder(*v) for v in self.query.values]
-        result.append('VALUES (%s)' % ', '.join(values))
-        params = self.query.params
+
+        has_fields = bool(self.query.fields)
+        fields = self.query.fields if has_fields else [opts.pk]
+        result.append('(%s)' % ', '.join([qn(f.column) for f in fields]))
+
+        if has_fields:
+            params = values = [
+                [
+                    f.get_db_prep_save(getattr(obj, f.attname) if self.query.raw else f.pre_save(obj, True), connection=self.connection)
+                    for f in fields
+                ]
+                for obj in self.query.objs
+            ]
+        else:
+            values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs]
+            params = [[]]
+            fields = [None]
+        can_bulk = not any(hasattr(field, "get_placeholder") for field in fields) and not self.return_id
+
+        if can_bulk:
+            placeholders = [["%s"] * len(fields)]
+        else:
+            placeholders = [
+                [self.placeholder(field, v) for field, v in izip(fields, val)]
+                for val in values
+            ]
         if self.return_id and self.connection.features.can_return_id_from_insert:
+            params = values[0]
             col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
+            result.append("VALUES (%s)" % ", ".join(placeholders[0]))
             r_fmt, r_params = self.connection.ops.return_insert_id()
             result.append(r_fmt % col)
-            params = params + r_params
-        return ' '.join(result), params
+            params += r_params
+            return [(" ".join(result), tuple(params))]
+        if can_bulk and self.connection.features.has_bulk_insert:
+            result.append(self.connection.ops.bulk_insert_sql(fields, len(values)))
+            return [(" ".join(result), tuple([v for val in values for v in val]))]
+        else:
+            return [
+                (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
+                for p, vals in izip(placeholders, params)
+            ]
 
     def execute_sql(self, return_id=False):
+        assert not (return_id and len(self.query.objs) != 1)
         self.return_id = return_id
-        cursor = super(SQLInsertCompiler, self).execute_sql(None)
+        cursor = self.connection.cursor()
+        for sql, params in self.as_sql():
+            cursor.execute(sql, params)
         if not (return_id and cursor):
             return
         if self.connection.features.can_return_id_from_insert:

+ 2 - 1
django/db/models/sql/query.py

@@ -8,9 +8,10 @@ all about the internals of models in order to get the information it needs.
 """
 
 import copy
-from django.utils.tree import Node
+
 from django.utils.datastructures import SortedDict
 from django.utils.encoding import force_unicode
+from django.utils.tree import Node
 from django.db import connections, DEFAULT_DB_ALIAS
 from django.db.models import signals
 from django.db.models.fields import FieldDoesNotExist

+ 9 - 17
django/db/models/sql/subqueries.py

@@ -136,20 +136,19 @@ class InsertQuery(Query):
 
     def __init__(self, *args, **kwargs):
         super(InsertQuery, self).__init__(*args, **kwargs)
-        self.columns = []
-        self.values = []
-        self.params = ()
+        self.fields = []
+        self.objs = []
 
     def clone(self, klass=None, **kwargs):
         extras = {
-            'columns': self.columns[:],
-            'values': self.values[:],
-            'params': self.params
+            'fields': self.fields[:],
+            'objs': self.objs[:],
+            'raw': self.raw,
         }
         extras.update(kwargs)
         return super(InsertQuery, self).clone(klass, **extras)
 
-    def insert_values(self, insert_values, raw_values=False):
+    def insert_values(self, fields, objs, raw=False):
         """
         Set up the insert query from the 'insert_values' dictionary. The
         dictionary gives the model field names and their target values.
@@ -159,16 +158,9 @@ class InsertQuery(Query):
         parameters. This provides a way to insert NULL and DEFAULT keywords
         into the query, for example.
         """
-        placeholders, values = [], []
-        for field, val in insert_values:
-            placeholders.append((field, val))
-            self.columns.append(field.column)
-            values.append(val)
-        if raw_values:
-            self.values.extend([(None, v) for v in values])
-        else:
-            self.params += tuple(values)
-            self.values.extend(placeholders)
+        self.fields = fields
+        self.objs = objs
+        self.raw = raw
 
 class DateQuery(Query):
     """

+ 14 - 1
django/utils/functional.py

@@ -275,4 +275,17 @@ class lazy_property(property):
             @wraps(fdel)
             def fdel(instance, name=fdel.__name__):
                 return getattr(instance, name)()
-        return property(fget, fset, fdel, doc)
+        return property(fget, fset, fdel, doc)
+
+def partition(predicate, values):
+    """
+    Splits the values into two sets, based on the return value of the function
+    (True/False). e.g.:
+
+        >>> partition(lambda: x > 3, range(5))
+        [1, 2, 3], [4]
+    """
+    results = ([], [])
+    for item in values:
+        results[predicate(item)].append(item)
+    return results

+ 23 - 0
docs/ref/models/querysets.txt

@@ -1158,6 +1158,29 @@ has a side effect on your data. For more, see `Safe methods`_ in the HTTP spec.
 
 .. _Safe methods: http://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html#sec9.1.1
 
+bulk_create
+~~~~~~~~~~~
+
+.. method:: bulk_create(objs)
+
+This method inserts the provided list of objects into the database in an
+efficient manner (generally only 1 query, no matter how many objects there
+are)::
+
+    >>> Entry.objects.bulk_create([
+    ...     Entry(headline="Django 1.0 Released"),
+    ...     Entry(headline="Django 1.1 Announced"),
+    ...     Entry(headline="Breaking: Django is awesome")
+    ... ])
+
+This has a number of caveats though:
+
+  * The model's ``save()`` method will not be called, and the ``pre_save`` and
+    ``post_save`` signals will not be sent.
+  * It does not work with child models in a multi-table inheritance scenario.
+  * If the model's primary key is an :class:`~django.db.models.AutoField` it
+    does not retrieve and set the primary key attribute, as ``save()`` does.
+
 count
 ~~~~~
 

+ 11 - 0
docs/releases/1.4.txt

@@ -252,6 +252,17 @@ filename. For example, the file ``css/styles.css`` would also be saved as
 See the :class:`~django.contrib.staticfiles.storage.CachedStaticFilesStorage`
 docs for more information.
 
+``Model.objects.bulk_create`` in the ORM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+This method allows for more efficient creation of multiple objects in the ORM.
+It can provide significant performance increases if you have many objects,
+Django makes use of this internally, meaning some operations (such as database
+setup for test suites) has seen a performance benefit as a result.
+
+See the :meth:`~django.db.models.query.QuerySet.bulk_create` docs for more
+information.
+
 Minor features
 ~~~~~~~~~~~~~~
 

+ 30 - 0
docs/topics/db/optimization.txt

@@ -268,3 +268,33 @@ instead of::
 
    entry.blog.id
 
+Insert in bulk
+==============
+
+When creating objects, where possible, use the
+:meth:`~django.db.models.query.QuerySet.bulk_create()` method to reduce the
+number of SQL queries. For example::
+
+    Entry.objects.bulk_create([
+        Entry(headline="Python 3.0 Released"),
+        Entry(headline="Python 3.1 Planned")
+    ])
+
+Is preferable to::
+
+    Entry.objects.create(headline="Python 3.0 Released")
+    Entry.objects.create(headline="Python 3.1 Planned")
+
+Note that there are a number of :meth:`caveats to this method
+<django.db.models.query.QuerySet.bulk_create>`, make sure it is appropriate for
+your use case. This also applies to :class:`ManyToManyFields
+<django.db.models.ManyToManyField>`, doing::
+
+    my_band.members.add(me, my_friend)
+
+Is preferable to::
+
+    my_band.members.add(me)
+    my_band.members.add(my_friend)
+
+Where ``Bands`` and ``Artists`` have a many-to-many relationship.

+ 0 - 0
tests/regressiontests/bulk_create/__init__.py


+ 21 - 0
tests/regressiontests/bulk_create/models.py

@@ -0,0 +1,21 @@
+from django.db import models
+
+
+class Country(models.Model):
+    name = models.CharField(max_length=255)
+    iso_two_letter = models.CharField(max_length=2)
+
+class Place(models.Model):
+    name = models.CharField(max_length=100)
+
+    class Meta:
+        abstract = True
+
+class Restaurant(Place):
+    pass
+
+class Pizzeria(Restaurant):
+    pass
+
+class State(models.Model):
+    two_letter_code = models.CharField(max_length=2, primary_key=True)

+ 54 - 0
tests/regressiontests/bulk_create/tests.py

@@ -0,0 +1,54 @@
+from __future__ import with_statement
+
+from operator import attrgetter
+
+from django.test import TestCase, skipUnlessDBFeature
+
+from models import Country, Restaurant, Pizzeria, State
+
+
+class BulkCreateTests(TestCase):
+    def setUp(self):
+        self.data = [
+            Country(name="United States of America", iso_two_letter="US"),
+            Country(name="The Netherlands", iso_two_letter="NL"),
+            Country(name="Germany", iso_two_letter="DE"),
+            Country(name="Czech Republic", iso_two_letter="CZ")
+        ]
+
+    def test_simple(self):
+        Country.objects.bulk_create(self.data)
+        self.assertQuerysetEqual(Country.objects.order_by("-name"), [
+            "United States of America", "The Netherlands", "Germany", "Czech Republic"
+        ], attrgetter("name"))
+
+    @skipUnlessDBFeature("has_bulk_insert")
+    def test_efficiency(self):
+        with self.assertNumQueries(1):
+            Country.objects.bulk_create(self.data)
+
+    def test_inheritance(self):
+        Restaurant.objects.bulk_create([
+            Restaurant(name="Nicholas's")
+        ])
+        self.assertQuerysetEqual(Restaurant.objects.all(), [
+            "Nicholas's",
+        ], attrgetter("name"))
+        with self.assertRaises(ValueError):
+            Pizzeria.objects.bulk_create([
+                Pizzeria(name="The Art of Pizza")
+            ])
+        self.assertQuerysetEqual(Pizzeria.objects.all(), [])
+        self.assertQuerysetEqual(Restaurant.objects.all(), [
+            "Nicholas's",
+        ], attrgetter("name"))
+
+    def test_non_auto_increment_pk(self):
+        with self.assertNumQueries(1):
+            State.objects.bulk_create([
+                State(two_letter_code=s)
+                for s in ["IL", "NY", "CA", "ME"]
+            ])
+        self.assertQuerysetEqual(State.objects.order_by("two_letter_code"), [
+            "CA", "IL", "ME", "NY",
+        ], attrgetter("two_letter_code"))

+ 2 - 2
tests/regressiontests/db_typecasts/tests.py

@@ -53,10 +53,10 @@ TEST_CASES = {
 
 class DBTypeCasts(unittest.TestCase):
     def test_typeCasts(self):
-        for k, v in TEST_CASES.items():
+        for k, v in TEST_CASES.iteritems():
             for inpt, expected in v:
                 got = getattr(typecasts, k)(inpt)
-                assert got == expected, "In %s: %r doesn't match %r. Got %r instead." % (k, inpt, expected, got)
+                self.assertEqual(got, expected, "In %s: %r doesn't match %r. Got %r instead." % (k, inpt, expected, got))
 
 if __name__ == '__main__':
     unittest.main()