Browse Source

Ensure cursors are closed when no longer needed.

This commit touchs various parts of the code base and test framework. Any
found usage of opening a cursor for the sake of initializing a connection
has been replaced with 'ensure_connection()'.
Michael Manfre 11 năm trước cách đây
mục cha
commit
3ffeb93186

+ 7 - 7
django/contrib/gis/db/backends/postgis/creation.py

@@ -11,10 +11,10 @@ class PostGISCreation(DatabaseCreation):
     @cached_property
     def template_postgis(self):
         template_postgis = getattr(settings, 'POSTGIS_TEMPLATE', 'template_postgis')
-        cursor = self.connection.cursor()
-        cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,))
-        if cursor.fetchone():
-            return template_postgis
+        with self.connection.cursor() as cursor:
+            cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,))
+            if cursor.fetchone():
+                return template_postgis
         return None
 
     def sql_indexes_for_field(self, model, f, style):
@@ -88,8 +88,8 @@ class PostGISCreation(DatabaseCreation):
             # Connect to the test database in order to create the postgis extension
             self.connection.close()
             self.connection.settings_dict["NAME"] = test_database_name
-            cursor = self.connection.cursor()
-            cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis")
-            cursor.connection.commit()
+            with self.connection.cursor() as cursor:
+                cursor.execute("CREATE EXTENSION IF NOT EXISTS postgis")
+                cursor.connection.commit()
 
         return test_database_name

+ 2 - 3
django/contrib/gis/db/backends/spatialite/creation.py

@@ -55,9 +55,8 @@ class SpatiaLiteCreation(DatabaseCreation):
 
         call_command('createcachetable', database=self.connection.alias)
 
-        # Get a cursor (even though we don't need one yet). This has
-        # the side effect of initializing the test database.
-        self.connection.cursor()
+        # Ensure a connection for the side effect of initializing the test database.
+        self.connection.ensure_connection()
 
         return test_database_name
 

+ 3 - 3
django/contrib/sites/management.py

@@ -33,9 +33,9 @@ def create_default_site(app_config, verbosity=2, interactive=True, db=DEFAULT_DB
         if sequence_sql:
             if verbosity >= 2:
                 print("Resetting sequence")
-            cursor = connections[db].cursor()
-            for command in sequence_sql:
-                cursor.execute(command)
+            with connections[db].cursor() as cursor:
+                for command in sequence_sql:
+                    cursor.execute(command)
 
         Site.objects.clear_cache()
 

+ 65 - 64
django/core/cache/backends/db.py

@@ -59,11 +59,11 @@ class DatabaseCache(BaseDatabaseCache):
         self.validate_key(key)
         db = router.db_for_read(self.cache_model_class)
         table = connections[db].ops.quote_name(self._table)
-        cursor = connections[db].cursor()
 
-        cursor.execute("SELECT cache_key, value, expires FROM %s "
-                       "WHERE cache_key = %%s" % table, [key])
-        row = cursor.fetchone()
+        with connections[db].cursor() as cursor:
+            cursor.execute("SELECT cache_key, value, expires FROM %s "
+                           "WHERE cache_key = %%s" % table, [key])
+            row = cursor.fetchone()
         if row is None:
             return default
         now = timezone.now()
@@ -75,9 +75,9 @@ class DatabaseCache(BaseDatabaseCache):
             expires = typecast_timestamp(str(expires))
         if expires < now:
             db = router.db_for_write(self.cache_model_class)
-            cursor = connections[db].cursor()
-            cursor.execute("DELETE FROM %s "
-                           "WHERE cache_key = %%s" % table, [key])
+            with connections[db].cursor() as cursor:
+                cursor.execute("DELETE FROM %s "
+                               "WHERE cache_key = %%s" % table, [key])
             return default
         value = connections[db].ops.process_clob(row[1])
         return pickle.loads(base64.b64decode(force_bytes(value)))
@@ -96,55 +96,55 @@ class DatabaseCache(BaseDatabaseCache):
         timeout = self.get_backend_timeout(timeout)
         db = router.db_for_write(self.cache_model_class)
         table = connections[db].ops.quote_name(self._table)
-        cursor = connections[db].cursor()
 
-        cursor.execute("SELECT COUNT(*) FROM %s" % table)
-        num = cursor.fetchone()[0]
-        now = timezone.now()
-        now = now.replace(microsecond=0)
-        if timeout is None:
-            exp = datetime.max
-        elif settings.USE_TZ:
-            exp = datetime.utcfromtimestamp(timeout)
-        else:
-            exp = datetime.fromtimestamp(timeout)
-        exp = exp.replace(microsecond=0)
-        if num > self._max_entries:
-            self._cull(db, cursor, now)
-        pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
-        b64encoded = base64.b64encode(pickled)
-        # The DB column is expecting a string, so make sure the value is a
-        # string, not bytes. Refs #19274.
-        if six.PY3:
-            b64encoded = b64encoded.decode('latin1')
-        try:
-            # Note: typecasting for datetimes is needed by some 3rd party
-            # database backends. All core backends work without typecasting,
-            # so be careful about changes here - test suite will NOT pick
-            # regressions.
-            with transaction.atomic(using=db):
-                cursor.execute("SELECT cache_key, expires FROM %s "
-                               "WHERE cache_key = %%s" % table, [key])
-                result = cursor.fetchone()
-                if result:
-                    current_expires = result[1]
-                    if (connections[db].features.needs_datetime_string_cast and not
-                            isinstance(current_expires, datetime)):
-                        current_expires = typecast_timestamp(str(current_expires))
-                exp = connections[db].ops.value_to_db_datetime(exp)
-                if result and (mode == 'set' or (mode == 'add' and current_expires < now)):
-                    cursor.execute("UPDATE %s SET value = %%s, expires = %%s "
-                                   "WHERE cache_key = %%s" % table,
-                                   [b64encoded, exp, key])
-                else:
-                    cursor.execute("INSERT INTO %s (cache_key, value, expires) "
-                                   "VALUES (%%s, %%s, %%s)" % table,
-                                   [key, b64encoded, exp])
-        except DatabaseError:
-            # To be threadsafe, updates/inserts are allowed to fail silently
-            return False
-        else:
-            return True
+        with connections[db].cursor() as cursor:
+            cursor.execute("SELECT COUNT(*) FROM %s" % table)
+            num = cursor.fetchone()[0]
+            now = timezone.now()
+            now = now.replace(microsecond=0)
+            if timeout is None:
+                exp = datetime.max
+            elif settings.USE_TZ:
+                exp = datetime.utcfromtimestamp(timeout)
+            else:
+                exp = datetime.fromtimestamp(timeout)
+            exp = exp.replace(microsecond=0)
+            if num > self._max_entries:
+                self._cull(db, cursor, now)
+            pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
+            b64encoded = base64.b64encode(pickled)
+            # The DB column is expecting a string, so make sure the value is a
+            # string, not bytes. Refs #19274.
+            if six.PY3:
+                b64encoded = b64encoded.decode('latin1')
+            try:
+                # Note: typecasting for datetimes is needed by some 3rd party
+                # database backends. All core backends work without typecasting,
+                # so be careful about changes here - test suite will NOT pick
+                # regressions.
+                with transaction.atomic(using=db):
+                    cursor.execute("SELECT cache_key, expires FROM %s "
+                                   "WHERE cache_key = %%s" % table, [key])
+                    result = cursor.fetchone()
+                    if result:
+                        current_expires = result[1]
+                        if (connections[db].features.needs_datetime_string_cast and not
+                                isinstance(current_expires, datetime)):
+                            current_expires = typecast_timestamp(str(current_expires))
+                    exp = connections[db].ops.value_to_db_datetime(exp)
+                    if result and (mode == 'set' or (mode == 'add' and current_expires < now)):
+                        cursor.execute("UPDATE %s SET value = %%s, expires = %%s "
+                                       "WHERE cache_key = %%s" % table,
+                                       [b64encoded, exp, key])
+                    else:
+                        cursor.execute("INSERT INTO %s (cache_key, value, expires) "
+                                       "VALUES (%%s, %%s, %%s)" % table,
+                                       [key, b64encoded, exp])
+            except DatabaseError:
+                # To be threadsafe, updates/inserts are allowed to fail silently
+                return False
+            else:
+                return True
 
     def delete(self, key, version=None):
         key = self.make_key(key, version=version)
@@ -152,9 +152,9 @@ class DatabaseCache(BaseDatabaseCache):
 
         db = router.db_for_write(self.cache_model_class)
         table = connections[db].ops.quote_name(self._table)
-        cursor = connections[db].cursor()
 
-        cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key])
+        with connections[db].cursor() as cursor:
+            cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key])
 
     def has_key(self, key, version=None):
         key = self.make_key(key, version=version)
@@ -162,17 +162,18 @@ class DatabaseCache(BaseDatabaseCache):
 
         db = router.db_for_read(self.cache_model_class)
         table = connections[db].ops.quote_name(self._table)
-        cursor = connections[db].cursor()
 
         if settings.USE_TZ:
             now = datetime.utcnow()
         else:
             now = datetime.now()
         now = now.replace(microsecond=0)
-        cursor.execute("SELECT cache_key FROM %s "
-                       "WHERE cache_key = %%s and expires > %%s" % table,
-                       [key, connections[db].ops.value_to_db_datetime(now)])
-        return cursor.fetchone() is not None
+
+        with connections[db].cursor() as cursor:
+            cursor.execute("SELECT cache_key FROM %s "
+                           "WHERE cache_key = %%s and expires > %%s" % table,
+                           [key, connections[db].ops.value_to_db_datetime(now)])
+            return cursor.fetchone() is not None
 
     def _cull(self, db, cursor, now):
         if self._cull_frequency == 0:
@@ -197,8 +198,8 @@ class DatabaseCache(BaseDatabaseCache):
     def clear(self):
         db = router.db_for_write(self.cache_model_class)
         table = connections[db].ops.quote_name(self._table)
-        cursor = connections[db].cursor()
-        cursor.execute('DELETE FROM %s' % table)
+        with connections[db].cursor() as cursor:
+            cursor.execute('DELETE FROM %s' % table)
 
 
 # For backwards compatibility

+ 9 - 9
django/core/management/commands/createcachetable.py

@@ -72,14 +72,14 @@ class Command(BaseCommand):
             full_statement.append('    %s%s' % (line, ',' if i < len(table_output) - 1 else ''))
         full_statement.append(');')
         with transaction.commit_on_success_unless_managed():
-            curs = connection.cursor()
-            try:
-                curs.execute("\n".join(full_statement))
-            except DatabaseError as e:
-                raise CommandError(
-                    "Cache table '%s' could not be created.\nThe error was: %s." %
-                    (tablename, force_text(e)))
-            for statement in index_output:
-                curs.execute(statement)
+            with connection.cursor() as curs:
+                try:
+                    curs.execute("\n".join(full_statement))
+                except DatabaseError as e:
+                    raise CommandError(
+                        "Cache table '%s' could not be created.\nThe error was: %s." %
+                        (tablename, force_text(e)))
+                for statement in index_output:
+                    curs.execute(statement)
         if self.verbosity > 1:
             self.stdout.write("Cache table '%s' created." % tablename)

+ 3 - 3
django/core/management/commands/flush.py

@@ -64,9 +64,9 @@ Are you sure you want to do this?
         if confirm == 'yes':
             try:
                 with transaction.commit_on_success_unless_managed():
-                    cursor = connection.cursor()
-                    for sql in sql_list:
-                        cursor.execute(sql)
+                    with connection.cursor() as cursor:
+                        for sql in sql_list:
+                            cursor.execute(sql)
             except Exception as e:
                 new_msg = (
                     "Database %s couldn't be flushed. Possible reasons:\n"

+ 99 - 99
django/core/management/commands/inspectdb.py

@@ -37,108 +37,108 @@ class Command(NoArgsCommand):
         table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '')
         strip_prefix = lambda s: s[1:] if s.startswith("u'") else s
 
-        cursor = connection.cursor()
-        yield "# This is an auto-generated Django model module."
-        yield "# You'll have to do the following manually to clean this up:"
-        yield "#   * Rearrange models' order"
-        yield "#   * Make sure each model has one field with primary_key=True"
-        yield "#   * Remove `managed = False` lines for those models you wish to give write DB access"
-        yield "# Feel free to rename the models, but don't rename db_table values or field names."
-        yield "#"
-        yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [app_label]'"
-        yield "# into your database."
-        yield "from __future__ import unicode_literals"
-        yield ''
-        yield 'from %s import models' % self.db_module
-        known_models = []
-        for table_name in connection.introspection.table_names(cursor):
-            if table_name_filter is not None and callable(table_name_filter):
-                if not table_name_filter(table_name):
-                    continue
+        with connection.cursor() as cursor:
+            yield "# This is an auto-generated Django model module."
+            yield "# You'll have to do the following manually to clean this up:"
+            yield "#   * Rearrange models' order"
+            yield "#   * Make sure each model has one field with primary_key=True"
+            yield "#   * Remove `managed = False` lines for those models you wish to give write DB access"
+            yield "# Feel free to rename the models, but don't rename db_table values or field names."
+            yield "#"
+            yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [app_label]'"
+            yield "# into your database."
+            yield "from __future__ import unicode_literals"
             yield ''
-            yield ''
-            yield 'class %s(models.Model):' % table2model(table_name)
-            known_models.append(table2model(table_name))
-            try:
-                relations = connection.introspection.get_relations(cursor, table_name)
-            except NotImplementedError:
-                relations = {}
-            try:
-                indexes = connection.introspection.get_indexes(cursor, table_name)
-            except NotImplementedError:
-                indexes = {}
-            used_column_names = []  # Holds column names used in the table so far
-            for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
-                comment_notes = []  # Holds Field notes, to be displayed in a Python comment.
-                extra_params = OrderedDict()  # Holds Field parameters such as 'db_column'.
-                column_name = row[0]
-                is_relation = i in relations
-
-                att_name, params, notes = self.normalize_col_name(
-                    column_name, used_column_names, is_relation)
-                extra_params.update(params)
-                comment_notes.extend(notes)
-
-                used_column_names.append(att_name)
-
-                # Add primary_key and unique, if necessary.
-                if column_name in indexes:
-                    if indexes[column_name]['primary_key']:
-                        extra_params['primary_key'] = True
-                    elif indexes[column_name]['unique']:
-                        extra_params['unique'] = True
-
-                if is_relation:
-                    rel_to = "self" if relations[i][1] == table_name else table2model(relations[i][1])
-                    if rel_to in known_models:
-                        field_type = 'ForeignKey(%s' % rel_to
-                    else:
-                        field_type = "ForeignKey('%s'" % rel_to
-                else:
-                    # Calling `get_field_type` to get the field type string and any
-                    # additional paramters and notes.
-                    field_type, field_params, field_notes = self.get_field_type(connection, table_name, row)
-                    extra_params.update(field_params)
-                    comment_notes.extend(field_notes)
-
-                    field_type += '('
-
-                # Don't output 'id = meta.AutoField(primary_key=True)', because
-                # that's assumed if it doesn't exist.
-                if att_name == 'id' and extra_params == {'primary_key': True}:
-                    if field_type == 'AutoField(':
+            yield 'from %s import models' % self.db_module
+            known_models = []
+            for table_name in connection.introspection.table_names(cursor):
+                if table_name_filter is not None and callable(table_name_filter):
+                    if not table_name_filter(table_name):
                         continue
-                    elif field_type == 'IntegerField(' and not connection.features.can_introspect_autofield:
-                        comment_notes.append('AutoField?')
-
-                # Add 'null' and 'blank', if the 'null_ok' flag was present in the
-                # table description.
-                if row[6]:  # If it's NULL...
-                    if field_type == 'BooleanField(':
-                        field_type = 'NullBooleanField('
+                yield ''
+                yield ''
+                yield 'class %s(models.Model):' % table2model(table_name)
+                known_models.append(table2model(table_name))
+                try:
+                    relations = connection.introspection.get_relations(cursor, table_name)
+                except NotImplementedError:
+                    relations = {}
+                try:
+                    indexes = connection.introspection.get_indexes(cursor, table_name)
+                except NotImplementedError:
+                    indexes = {}
+                used_column_names = []  # Holds column names used in the table so far
+                for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
+                    comment_notes = []  # Holds Field notes, to be displayed in a Python comment.
+                    extra_params = OrderedDict()  # Holds Field parameters such as 'db_column'.
+                    column_name = row[0]
+                    is_relation = i in relations
+
+                    att_name, params, notes = self.normalize_col_name(
+                        column_name, used_column_names, is_relation)
+                    extra_params.update(params)
+                    comment_notes.extend(notes)
+
+                    used_column_names.append(att_name)
+
+                    # Add primary_key and unique, if necessary.
+                    if column_name in indexes:
+                        if indexes[column_name]['primary_key']:
+                            extra_params['primary_key'] = True
+                        elif indexes[column_name]['unique']:
+                            extra_params['unique'] = True
+
+                    if is_relation:
+                        rel_to = "self" if relations[i][1] == table_name else table2model(relations[i][1])
+                        if rel_to in known_models:
+                            field_type = 'ForeignKey(%s' % rel_to
+                        else:
+                            field_type = "ForeignKey('%s'" % rel_to
                     else:
-                        extra_params['blank'] = True
-                        if not field_type in ('TextField(', 'CharField('):
-                            extra_params['null'] = True
-
-                field_desc = '%s = %s%s' % (
-                    att_name,
-                    # Custom fields will have a dotted path
-                    '' if '.' in field_type else 'models.',
-                    field_type,
-                )
-                if extra_params:
-                    if not field_desc.endswith('('):
-                        field_desc += ', '
-                    field_desc += ', '.join([
-                        '%s=%s' % (k, strip_prefix(repr(v)))
-                        for k, v in extra_params.items()])
-                field_desc += ')'
-                if comment_notes:
-                    field_desc += '  # ' + ' '.join(comment_notes)
-                yield '    %s' % field_desc
-            for meta_line in self.get_meta(table_name):
-                yield meta_line
+                        # Calling `get_field_type` to get the field type string and any
+                        # additional paramters and notes.
+                        field_type, field_params, field_notes = self.get_field_type(connection, table_name, row)
+                        extra_params.update(field_params)
+                        comment_notes.extend(field_notes)
+
+                        field_type += '('
+
+                    # Don't output 'id = meta.AutoField(primary_key=True)', because
+                    # that's assumed if it doesn't exist.
+                    if att_name == 'id' and extra_params == {'primary_key': True}:
+                        if field_type == 'AutoField(':
+                            continue
+                        elif field_type == 'IntegerField(' and not connection.features.can_introspect_autofield:
+                            comment_notes.append('AutoField?')
+
+                    # Add 'null' and 'blank', if the 'null_ok' flag was present in the
+                    # table description.
+                    if row[6]:  # If it's NULL...
+                        if field_type == 'BooleanField(':
+                            field_type = 'NullBooleanField('
+                        else:
+                            extra_params['blank'] = True
+                            if not field_type in ('TextField(', 'CharField('):
+                                extra_params['null'] = True
+
+                    field_desc = '%s = %s%s' % (
+                        att_name,
+                        # Custom fields will have a dotted path
+                        '' if '.' in field_type else 'models.',
+                        field_type,
+                    )
+                    if extra_params:
+                        if not field_desc.endswith('('):
+                            field_desc += ', '
+                        field_desc += ', '.join([
+                            '%s=%s' % (k, strip_prefix(repr(v)))
+                            for k, v in extra_params.items()])
+                    field_desc += ')'
+                    if comment_notes:
+                        field_desc += '  # ' + ' '.join(comment_notes)
+                    yield '    %s' % field_desc
+                for meta_line in self.get_meta(table_name):
+                    yield meta_line
 
     def normalize_col_name(self, col_name, used_column_names, is_relation):
         """

+ 3 - 4
django/core/management/commands/loaddata.py

@@ -100,10 +100,9 @@ class Command(BaseCommand):
             if sequence_sql:
                 if self.verbosity >= 2:
                     self.stdout.write("Resetting sequences\n")
-                cursor = connection.cursor()
-                for line in sequence_sql:
-                    cursor.execute(line)
-                cursor.close()
+                with connection.cursor() as cursor:
+                    for line in sequence_sql:
+                        cursor.execute(line)
 
         if self.verbosity >= 1:
             if self.fixture_object_count == self.loaded_object_count:

+ 99 - 94
django/core/management/commands/migrate.py

@@ -171,105 +171,110 @@ class Command(BaseCommand):
         "Runs the old syncdb-style operation on a list of app_labels."
         cursor = connection.cursor()
 
-        # Get a list of already installed *models* so that references work right.
-        tables = connection.introspection.table_names()
-        seen_models = connection.introspection.installed_models(tables)
-        created_models = set()
-        pending_references = {}
-
-        # Build the manifest of apps and models that are to be synchronized
-        all_models = [
-            (app_config.label,
-                router.get_migratable_models(app_config, connection.alias, include_auto_created=True))
-            for app_config in apps.get_app_configs()
-            if app_config.models_module is not None and app_config.label in app_labels
-        ]
-
-        def model_installed(model):
-            opts = model._meta
-            converter = connection.introspection.table_name_converter
-            # Note that if a model is unmanaged we short-circuit and never try to install it
-            return not ((converter(opts.db_table) in tables) or
-                (opts.auto_created and converter(opts.auto_created._meta.db_table) in tables))
-
-        manifest = OrderedDict(
-            (app_name, list(filter(model_installed, model_list)))
-            for app_name, model_list in all_models
-        )
-
-        create_models = set(itertools.chain(*manifest.values()))
-        emit_pre_migrate_signal(create_models, self.verbosity, self.interactive, connection.alias)
-
-        # Create the tables for each model
-        if self.verbosity >= 1:
-            self.stdout.write("  Creating tables...\n")
-        with transaction.atomic(using=connection.alias, savepoint=False):
-            for app_name, model_list in manifest.items():
-                for model in model_list:
-                    # Create the model's database table, if it doesn't already exist.
-                    if self.verbosity >= 3:
-                        self.stdout.write("    Processing %s.%s model\n" % (app_name, model._meta.object_name))
-                    sql, references = connection.creation.sql_create_model(model, no_style(), seen_models)
-                    seen_models.add(model)
-                    created_models.add(model)
-                    for refto, refs in references.items():
-                        pending_references.setdefault(refto, []).extend(refs)
-                        if refto in seen_models:
-                            sql.extend(connection.creation.sql_for_pending_references(refto, no_style(), pending_references))
-                    sql.extend(connection.creation.sql_for_pending_references(model, no_style(), pending_references))
-                    if self.verbosity >= 1 and sql:
-                        self.stdout.write("    Creating table %s\n" % model._meta.db_table)
-                    for statement in sql:
-                        cursor.execute(statement)
-                    tables.append(connection.introspection.table_name_converter(model._meta.db_table))
-
-        # We force a commit here, as that was the previous behaviour.
-        # If you can prove we don't need this, remove it.
-        transaction.set_dirty(using=connection.alias)
+        try:
+            # Get a list of already installed *models* so that references work right.
+            tables = connection.introspection.table_names(cursor)
+            seen_models = connection.introspection.installed_models(tables)
+            created_models = set()
+            pending_references = {}
+
+            # Build the manifest of apps and models that are to be synchronized
+            all_models = [
+                (app_config.label,
+                    router.get_migratable_models(app_config, connection.alias, include_auto_created=True))
+                for app_config in apps.get_app_configs()
+                if app_config.models_module is not None and app_config.label in app_labels
+            ]
+
+            def model_installed(model):
+                opts = model._meta
+                converter = connection.introspection.table_name_converter
+                # Note that if a model is unmanaged we short-circuit and never try to install it
+                return not ((converter(opts.db_table) in tables) or
+                    (opts.auto_created and converter(opts.auto_created._meta.db_table) in tables))
+
+            manifest = OrderedDict(
+                (app_name, list(filter(model_installed, model_list)))
+                for app_name, model_list in all_models
+            )
+
+            create_models = set(itertools.chain(*manifest.values()))
+            emit_pre_migrate_signal(create_models, self.verbosity, self.interactive, connection.alias)
+
+            # Create the tables for each model
+            if self.verbosity >= 1:
+                self.stdout.write("  Creating tables...\n")
+            with transaction.atomic(using=connection.alias, savepoint=False):
+                for app_name, model_list in manifest.items():
+                    for model in model_list:
+                        # Create the model's database table, if it doesn't already exist.
+                        if self.verbosity >= 3:
+                            self.stdout.write("    Processing %s.%s model\n" % (app_name, model._meta.object_name))
+                        sql, references = connection.creation.sql_create_model(model, no_style(), seen_models)
+                        seen_models.add(model)
+                        created_models.add(model)
+                        for refto, refs in references.items():
+                            pending_references.setdefault(refto, []).extend(refs)
+                            if refto in seen_models:
+                                sql.extend(connection.creation.sql_for_pending_references(refto, no_style(), pending_references))
+                        sql.extend(connection.creation.sql_for_pending_references(model, no_style(), pending_references))
+                        if self.verbosity >= 1 and sql:
+                            self.stdout.write("    Creating table %s\n" % model._meta.db_table)
+                        for statement in sql:
+                            cursor.execute(statement)
+                        tables.append(connection.introspection.table_name_converter(model._meta.db_table))
+
+            # We force a commit here, as that was the previous behaviour.
+            # If you can prove we don't need this, remove it.
+            transaction.set_dirty(using=connection.alias)
+        finally:
+            cursor.close()
 
         # The connection may have been closed by a syncdb handler.
         cursor = connection.cursor()
+        try:
+            # Install custom SQL for the app (but only if this
+            # is a model we've just created)
+            if self.verbosity >= 1:
+                self.stdout.write("  Installing custom SQL...\n")
+            for app_name, model_list in manifest.items():
+                for model in model_list:
+                    if model in created_models:
+                        custom_sql = custom_sql_for_model(model, no_style(), connection)
+                        if custom_sql:
+                            if self.verbosity >= 2:
+                                self.stdout.write("    Installing custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
+                            try:
+                                with transaction.commit_on_success_unless_managed(using=connection.alias):
+                                    for sql in custom_sql:
+                                        cursor.execute(sql)
+                            except Exception as e:
+                                self.stderr.write("    Failed to install custom SQL for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
+                                if self.show_traceback:
+                                    traceback.print_exc()
+                        else:
+                            if self.verbosity >= 3:
+                                self.stdout.write("    No custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
 
-        # Install custom SQL for the app (but only if this
-        # is a model we've just created)
-        if self.verbosity >= 1:
-            self.stdout.write("  Installing custom SQL...\n")
-        for app_name, model_list in manifest.items():
-            for model in model_list:
-                if model in created_models:
-                    custom_sql = custom_sql_for_model(model, no_style(), connection)
-                    if custom_sql:
-                        if self.verbosity >= 2:
-                            self.stdout.write("    Installing custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
-                        try:
-                            with transaction.commit_on_success_unless_managed(using=connection.alias):
-                                for sql in custom_sql:
-                                    cursor.execute(sql)
-                        except Exception as e:
-                            self.stderr.write("    Failed to install custom SQL for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
-                            if self.show_traceback:
-                                traceback.print_exc()
-                    else:
-                        if self.verbosity >= 3:
-                            self.stdout.write("    No custom SQL for %s.%s model\n" % (app_name, model._meta.object_name))
+            if self.verbosity >= 1:
+                self.stdout.write("  Installing indexes...\n")
 
-        if self.verbosity >= 1:
-            self.stdout.write("  Installing indexes...\n")
-
-        # Install SQL indices for all newly created models
-        for app_name, model_list in manifest.items():
-            for model in model_list:
-                if model in created_models:
-                    index_sql = connection.creation.sql_indexes_for_model(model, no_style())
-                    if index_sql:
-                        if self.verbosity >= 2:
-                            self.stdout.write("    Installing index for %s.%s model\n" % (app_name, model._meta.object_name))
-                        try:
-                            with transaction.commit_on_success_unless_managed(using=connection.alias):
-                                for sql in index_sql:
-                                    cursor.execute(sql)
-                        except Exception as e:
-                            self.stderr.write("    Failed to install index for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
+            # Install SQL indices for all newly created models
+            for app_name, model_list in manifest.items():
+                for model in model_list:
+                    if model in created_models:
+                        index_sql = connection.creation.sql_indexes_for_model(model, no_style())
+                        if index_sql:
+                            if self.verbosity >= 2:
+                                self.stdout.write("    Installing index for %s.%s model\n" % (app_name, model._meta.object_name))
+                            try:
+                                with transaction.commit_on_success_unless_managed(using=connection.alias):
+                                    for sql in index_sql:
+                                        cursor.execute(sql)
+                            except Exception as e:
+                                self.stderr.write("    Failed to install index for %s.%s model: %s\n" % (app_name, model._meta.object_name, e))
+        finally:
+            cursor.close()
 
         # Load initial_data fixtures (unless that has been disabled)
         if self.load_initial_data:

+ 33 - 32
django/core/management/sql.py

@@ -67,38 +67,39 @@ def sql_delete(app_config, style, connection):
     except Exception:
         cursor = None
 
-    # Figure out which tables already exist
-    if cursor:
-        table_names = connection.introspection.table_names(cursor)
-    else:
-        table_names = []
-
-    output = []
-
-    # Output DROP TABLE statements for standard application tables.
-    to_delete = set()
-
-    references_to_delete = {}
-    app_models = router.get_migratable_models(app_config, connection.alias, include_auto_created=True)
-    for model in app_models:
-        if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names:
-            # The table exists, so it needs to be dropped
-            opts = model._meta
-            for f in opts.local_fields:
-                if f.rel and f.rel.to not in to_delete:
-                    references_to_delete.setdefault(f.rel.to, []).append((model, f))
-
-            to_delete.add(model)
-
-    for model in app_models:
-        if connection.introspection.table_name_converter(model._meta.db_table) in table_names:
-            output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style))
-
-    # Close database connection explicitly, in case this output is being piped
-    # directly into a database client, to avoid locking issues.
-    if cursor:
-        cursor.close()
-        connection.close()
+    try:
+        # Figure out which tables already exist
+        if cursor:
+            table_names = connection.introspection.table_names(cursor)
+        else:
+            table_names = []
+
+        output = []
+
+        # Output DROP TABLE statements for standard application tables.
+        to_delete = set()
+
+        references_to_delete = {}
+        app_models = router.get_migratable_models(app_config, connection.alias, include_auto_created=True)
+        for model in app_models:
+            if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names:
+                # The table exists, so it needs to be dropped
+                opts = model._meta
+                for f in opts.local_fields:
+                    if f.rel and f.rel.to not in to_delete:
+                        references_to_delete.setdefault(f.rel.to, []).append((model, f))
+
+                to_delete.add(model)
+
+        for model in app_models:
+            if connection.introspection.table_name_converter(model._meta.db_table) in table_names:
+                output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style))
+    finally:
+        # Close database connection explicitly, in case this output is being piped
+        # directly into a database client, to avoid locking issues.
+        if cursor:
+            cursor.close()
+            connection.close()
 
     return output[::-1]  # Reverse it, to deal with table dependencies.
 

+ 17 - 13
django/db/backends/__init__.py

@@ -194,13 +194,16 @@ class BaseDatabaseWrapper(object):
     ##### Backend-specific savepoint management methods #####
 
     def _savepoint(self, sid):
-        self.cursor().execute(self.ops.savepoint_create_sql(sid))
+        with self.cursor() as cursor:
+            cursor.execute(self.ops.savepoint_create_sql(sid))
 
     def _savepoint_rollback(self, sid):
-        self.cursor().execute(self.ops.savepoint_rollback_sql(sid))
+        with self.cursor() as cursor:
+            cursor.execute(self.ops.savepoint_rollback_sql(sid))
 
     def _savepoint_commit(self, sid):
-        self.cursor().execute(self.ops.savepoint_commit_sql(sid))
+        with self.cursor() as cursor:
+            cursor.execute(self.ops.savepoint_commit_sql(sid))
 
     def _savepoint_allowed(self):
         # Savepoints cannot be created outside a transaction
@@ -688,15 +691,15 @@ class BaseDatabaseFeatures(object):
             # otherwise autocommit will cause the confimation to
             # fail.
             self.connection.enter_transaction_management()
-            cursor = self.connection.cursor()
-            cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
-            self.connection.commit()
-            cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
-            self.connection.rollback()
-            cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
-            count, = cursor.fetchone()
-            cursor.execute('DROP TABLE ROLLBACK_TEST')
-            self.connection.commit()
+            with self.connection.cursor() as cursor:
+                cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
+                self.connection.commit()
+                cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
+                self.connection.rollback()
+                cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
+                count, = cursor.fetchone()
+                cursor.execute('DROP TABLE ROLLBACK_TEST')
+                self.connection.commit()
         finally:
             self.connection.leave_transaction_management()
         return count == 0
@@ -1253,7 +1256,8 @@ class BaseDatabaseIntrospection(object):
         in sorting order between databases.
         """
         if cursor is None:
-            cursor = self.connection.cursor()
+            with self.connection.cursor() as cursor:
+                return sorted(self.get_table_list(cursor))
         return sorted(self.get_table_list(cursor))
 
     def get_table_list(self, cursor):

+ 35 - 36
django/db/backends/creation.py

@@ -378,9 +378,8 @@ class BaseDatabaseCreation(object):
 
         call_command('createcachetable', database=self.connection.alias)
 
-        # Get a cursor (even though we don't need one yet). This has
-        # the side effect of initializing the test database.
-        self.connection.cursor()
+        # Ensure a connection for the side effect of initializing the test database.
+        self.connection.ensure_connection()
 
         return test_database_name
 
@@ -406,34 +405,34 @@ class BaseDatabaseCreation(object):
         qn = self.connection.ops.quote_name
 
         # Create the test database and connect to it.
-        cursor = self._nodb_connection.cursor()
-        try:
-            cursor.execute(
-                "CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
-        except Exception as e:
-            sys.stderr.write(
-                "Got an error creating the test database: %s\n" % e)
-            if not autoclobber:
-                confirm = input(
-                    "Type 'yes' if you would like to try deleting the test "
-                    "database '%s', or 'no' to cancel: " % test_database_name)
-            if autoclobber or confirm == 'yes':
-                try:
-                    if verbosity >= 1:
-                        print("Destroying old test database '%s'..."
-                              % self.connection.alias)
-                    cursor.execute(
-                        "DROP DATABASE %s" % qn(test_database_name))
-                    cursor.execute(
-                        "CREATE DATABASE %s %s" % (qn(test_database_name),
-                                                   suffix))
-                except Exception as e:
-                    sys.stderr.write(
-                        "Got an error recreating the test database: %s\n" % e)
-                    sys.exit(2)
-            else:
-                print("Tests cancelled.")
-                sys.exit(1)
+        with self._nodb_connection.cursor() as cursor:
+            try:
+                cursor.execute(
+                    "CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
+            except Exception as e:
+                sys.stderr.write(
+                    "Got an error creating the test database: %s\n" % e)
+                if not autoclobber:
+                    confirm = input(
+                        "Type 'yes' if you would like to try deleting the test "
+                        "database '%s', or 'no' to cancel: " % test_database_name)
+                if autoclobber or confirm == 'yes':
+                    try:
+                        if verbosity >= 1:
+                            print("Destroying old test database '%s'..."
+                                  % self.connection.alias)
+                        cursor.execute(
+                            "DROP DATABASE %s" % qn(test_database_name))
+                        cursor.execute(
+                            "CREATE DATABASE %s %s" % (qn(test_database_name),
+                                                       suffix))
+                    except Exception as e:
+                        sys.stderr.write(
+                            "Got an error recreating the test database: %s\n" % e)
+                        sys.exit(2)
+                else:
+                    print("Tests cancelled.")
+                    sys.exit(1)
 
         return test_database_name
 
@@ -461,11 +460,11 @@ class BaseDatabaseCreation(object):
         # ourselves. Connect to the previous database (not the test database)
         # to do so, because it's not allowed to delete a database while being
         # connected to it.
-        cursor = self._nodb_connection.cursor()
-        # Wait to avoid "database is being accessed by other users" errors.
-        time.sleep(1)
-        cursor.execute("DROP DATABASE %s"
-                       % self.connection.ops.quote_name(test_database_name))
+        with self._nodb_connection.cursor() as cursor:
+            # Wait to avoid "database is being accessed by other users" errors.
+            time.sleep(1)
+            cursor.execute("DROP DATABASE %s"
+                           % self.connection.ops.quote_name(test_database_name))
 
     def set_autocommit(self):
         """

+ 18 - 19
django/db/backends/mysql/base.py

@@ -180,15 +180,15 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     @cached_property
     def _mysql_storage_engine(self):
         "Internal method used in Django tests. Don't rely on this from your code"
-        cursor = self.connection.cursor()
-        cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)')
-        # This command is MySQL specific; the second column
-        # will tell you the default table type of the created
-        # table. Since all Django's test tables will have the same
-        # table type, that's enough to evaluate the feature.
-        cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'")
-        result = cursor.fetchone()
-        cursor.execute('DROP TABLE INTROSPECT_TEST')
+        with self.connection.cursor() as cursor:
+            cursor.execute('CREATE TABLE INTROSPECT_TEST (X INT)')
+            # This command is MySQL specific; the second column
+            # will tell you the default table type of the created
+            # table. Since all Django's test tables will have the same
+            # table type, that's enough to evaluate the feature.
+            cursor.execute("SHOW TABLE STATUS WHERE Name='INTROSPECT_TEST'")
+            result = cursor.fetchone()
+            cursor.execute('DROP TABLE INTROSPECT_TEST')
         return result[1]
 
     @cached_property
@@ -207,9 +207,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
             return False
 
         # Test if the time zone definitions are installed.
-        cursor = self.connection.cursor()
-        cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1")
-        return cursor.fetchone() is not None
+        with self.connection.cursor() as cursor:
+            cursor.execute("SELECT 1 FROM mysql.time_zone LIMIT 1")
+            return cursor.fetchone() is not None
 
 
 class DatabaseOperations(BaseDatabaseOperations):
@@ -461,13 +461,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         return conn
 
     def init_connection_state(self):
-        cursor = self.connection.cursor()
-        # SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
-        # on a recently-inserted row will return when the field is tested for
-        # NULL.  Disabling this value brings this aspect of MySQL in line with
-        # SQL standards.
-        cursor.execute('SET SQL_AUTO_IS_NULL = 0')
-        cursor.close()
+        with self.connection.cursor() as cursor:
+            # SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
+            # on a recently-inserted row will return when the field is tested for
+            # NULL.  Disabling this value brings this aspect of MySQL in line with
+            # SQL standards.
+            cursor.execute('SET SQL_AUTO_IS_NULL = 0')
 
     def create_cursor(self):
         cursor = self.connection.cursor()

+ 2 - 2
django/db/backends/oracle/base.py

@@ -353,8 +353,8 @@ WHEN (new.%(col_name)s IS NULL)
     def regex_lookup(self, lookup_type):
         # If regex_lookup is called before it's been initialized, then create
         # a cursor to initialize it and recur.
-        self.connection.cursor()
-        return self.connection.ops.regex_lookup(lookup_type)
+        with self.connection.cursor():
+            return self.connection.ops.regex_lookup(lookup_type)
 
     def return_insert_id(self):
         return "RETURNING %s INTO %%s", (InsertIdVar(),)

+ 4 - 2
django/db/backends/postgresql_psycopg2/base.py

@@ -149,8 +149,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
 
             if conn_tz != tz:
                 cursor = self.connection.cursor()
-                cursor.execute(self.ops.set_time_zone_sql(), [tz])
-                cursor.close()
+                try:
+                    cursor.execute(self.ops.set_time_zone_sql(), [tz])
+                finally:
+                    cursor.close()
                 # Commit after setting the time zone (see #17062)
                 if not self.get_autocommit():
                     self.connection.commit()

+ 3 - 3
django/db/backends/postgresql_psycopg2/version.py

@@ -39,6 +39,6 @@ def get_version(connection):
     if hasattr(connection, 'server_version'):
         return connection.server_version
     else:
-        cursor = connection.cursor()
-        cursor.execute("SELECT version()")
-        return _parse_version(cursor.fetchone()[0])
+        with connection.cursor() as cursor:
+            cursor.execute("SELECT version()")
+            return _parse_version(cursor.fetchone()[0])

+ 4 - 4
django/db/backends/schema.py

@@ -86,14 +86,13 @@ class BaseDatabaseSchemaEditor(object):
         """
         Executes the given SQL statement, with optional parameters.
         """
-        # Get the cursor
-        cursor = self.connection.cursor()
         # Log the command we're running, then run it
         logger.debug("%s; (params %r)" % (sql, params))
         if self.collect_sql:
             self.collected_sql.append((sql % tuple(map(self.connection.ops.quote_parameter, params))) + ";")
         else:
-            cursor.execute(sql, params)
+            with self.connection.cursor() as cursor:
+                cursor.execute(sql, params)
 
     def quote_name(self, name):
         return self.connection.ops.quote_name(name)
@@ -791,7 +790,8 @@ class BaseDatabaseSchemaEditor(object):
         Returns all constraint names matching the columns and conditions
         """
         column_names = list(column_names) if column_names else None
-        constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
+        with self.connection.cursor() as cursor:
+            constraints = self.connection.introspection.get_constraints(cursor, model._meta.db_table)
         result = []
         for name, infodict in constraints.items():
             if column_names is None or column_names == infodict['columns']:

+ 8 - 8
django/db/backends/sqlite3/base.py

@@ -122,14 +122,14 @@ class DatabaseFeatures(BaseDatabaseFeatures):
         rule out support for STDDEV. We need to manually check
         whether the call works.
         """
-        cursor = self.connection.cursor()
-        cursor.execute('CREATE TABLE STDDEV_TEST (X INT)')
-        try:
-            cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST')
-            has_support = True
-        except utils.DatabaseError:
-            has_support = False
-        cursor.execute('DROP TABLE STDDEV_TEST')
+        with self.connection.cursor() as cursor:
+            cursor.execute('CREATE TABLE STDDEV_TEST (X INT)')
+            try:
+                cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST')
+                has_support = True
+            except utils.DatabaseError:
+                has_support = False
+            cursor.execute('DROP TABLE STDDEV_TEST')
         return has_support
 
     @cached_property

+ 49 - 44
django/db/models/query.py

@@ -1522,54 +1522,59 @@ class RawQuerySet(object):
 
         query = iter(self.query)
 
-        # Find out which columns are model's fields, and which ones should be
-        # annotated to the model.
-        for pos, column in enumerate(self.columns):
-            if column in self.model_fields:
-                model_init_field_names[self.model_fields[column].attname] = pos
-            else:
-                annotation_fields.append((column, pos))
+        try:
+            # Find out which columns are model's fields, and which ones should be
+            # annotated to the model.
+            for pos, column in enumerate(self.columns):
+                if column in self.model_fields:
+                    model_init_field_names[self.model_fields[column].attname] = pos
+                else:
+                    annotation_fields.append((column, pos))
 
-        # Find out which model's fields are not present in the query.
-        skip = set()
-        for field in self.model._meta.fields:
-            if field.attname not in model_init_field_names:
-                skip.add(field.attname)
-        if skip:
-            if self.model._meta.pk.attname in skip:
-                raise InvalidQuery('Raw query must include the primary key')
-            model_cls = deferred_class_factory(self.model, skip)
-        else:
-            model_cls = self.model
-            # All model's fields are present in the query. So, it is possible
-            # to use *args based model instantation. For each field of the model,
-            # record the query column position matching that field.
-            model_init_field_pos = []
+            # Find out which model's fields are not present in the query.
+            skip = set()
             for field in self.model._meta.fields:
-                model_init_field_pos.append(model_init_field_names[field.attname])
-        if need_resolv_columns:
-            fields = [self.model_fields.get(c, None) for c in self.columns]
-        # Begin looping through the query values.
-        for values in query:
-            if need_resolv_columns:
-                values = compiler.resolve_columns(values, fields)
-            # Associate fields to values
+                if field.attname not in model_init_field_names:
+                    skip.add(field.attname)
             if skip:
-                model_init_kwargs = {}
-                for attname, pos in six.iteritems(model_init_field_names):
-                    model_init_kwargs[attname] = values[pos]
-                instance = model_cls(**model_init_kwargs)
+                if self.model._meta.pk.attname in skip:
+                    raise InvalidQuery('Raw query must include the primary key')
+                model_cls = deferred_class_factory(self.model, skip)
             else:
-                model_init_args = [values[pos] for pos in model_init_field_pos]
-                instance = model_cls(*model_init_args)
-            if annotation_fields:
-                for column, pos in annotation_fields:
-                    setattr(instance, column, values[pos])
-
-            instance._state.db = db
-            instance._state.adding = False
-
-            yield instance
+                model_cls = self.model
+                # All model's fields are present in the query. So, it is possible
+                # to use *args based model instantation. For each field of the model,
+                # record the query column position matching that field.
+                model_init_field_pos = []
+                for field in self.model._meta.fields:
+                    model_init_field_pos.append(model_init_field_names[field.attname])
+            if need_resolv_columns:
+                fields = [self.model_fields.get(c, None) for c in self.columns]
+            # Begin looping through the query values.
+            for values in query:
+                if need_resolv_columns:
+                    values = compiler.resolve_columns(values, fields)
+                # Associate fields to values
+                if skip:
+                    model_init_kwargs = {}
+                    for attname, pos in six.iteritems(model_init_field_names):
+                        model_init_kwargs[attname] = values[pos]
+                    instance = model_cls(**model_init_kwargs)
+                else:
+                    model_init_args = [values[pos] for pos in model_init_field_pos]
+                    instance = model_cls(*model_init_args)
+                if annotation_fields:
+                    for column, pos in annotation_fields:
+                        setattr(instance, column, values[pos])
+
+                instance._state.db = db
+                instance._state.adding = False
+
+                yield instance
+        finally:
+            # Done iterating the Query. If it has its own cursor, close it.
+            if hasattr(self.query, 'cursor') and self.query.cursor:
+                self.query.cursor.close()
 
     def __repr__(self):
         text = self.raw_query

+ 11 - 10
django/db/models/sql/compiler.py

@@ -1,4 +1,5 @@
 import datetime
+import sys
 
 from django.conf import settings
 from django.core.exceptions import FieldError
@@ -777,7 +778,7 @@ class SQLCompiler(object):
         cursor = self.connection.cursor()
         try:
             cursor.execute(sql, params)
-        except:
+        except Exception:
             cursor.close()
             raise
 
@@ -908,15 +909,15 @@ class SQLInsertCompiler(SQLCompiler):
     def execute_sql(self, return_id=False):
         assert not (return_id and len(self.query.objs) != 1)
         self.return_id = return_id
-        cursor = self.connection.cursor()
-        for sql, params in self.as_sql():
-            cursor.execute(sql, params)
-        if not (return_id and cursor):
-            return
-        if self.connection.features.can_return_id_from_insert:
-            return self.connection.ops.fetch_returned_insert_id(cursor)
-        return self.connection.ops.last_insert_id(cursor,
-                self.query.get_meta().db_table, self.query.get_meta().pk.column)
+        with self.connection.cursor() as cursor:
+            for sql, params in self.as_sql():
+                cursor.execute(sql, params)
+            if not (return_id and cursor):
+                return
+            if self.connection.features.can_return_id_from_insert:
+                return self.connection.ops.fetch_returned_insert_id(cursor)
+            return self.connection.ops.last_insert_id(cursor,
+                    self.query.get_meta().db_table, self.query.get_meta().pk.column)
 
 
 class SQLDeleteCompiler(SQLCompiler):

+ 28 - 22
tests/backends/tests.py

@@ -59,9 +59,9 @@ class OracleChecks(unittest.TestCase):
         # stored procedure through our cursor wrapper.
         from django.db.backends.oracle.base import convert_unicode
 
-        cursor = connection.cursor()
-        cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'),
-                        [convert_unicode('_django_testing!')])
+        with connection.cursor() as cursor:
+            cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'),
+                            [convert_unicode('_django_testing!')])
 
     @unittest.skipUnless(connection.vendor == 'oracle',
                          "No need to check Oracle cursor semantics")
@@ -70,31 +70,31 @@ class OracleChecks(unittest.TestCase):
         # as query parameters.
         from django.db.backends.oracle.base import Database
 
-        cursor = connection.cursor()
-        var = cursor.var(Database.STRING)
-        cursor.execute("BEGIN %s := 'X'; END; ", [var])
-        self.assertEqual(var.getvalue(), 'X')
+        with connection.cursor() as cursor:
+            var = cursor.var(Database.STRING)
+            cursor.execute("BEGIN %s := 'X'; END; ", [var])
+            self.assertEqual(var.getvalue(), 'X')
 
     @unittest.skipUnless(connection.vendor == 'oracle',
                          "No need to check Oracle cursor semantics")
     def test_long_string(self):
         # If the backend is Oracle, test that we can save a text longer
         # than 4000 chars and read it properly
-        c = connection.cursor()
-        c.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
-        long_str = ''.join(six.text_type(x) for x in xrange(4000))
-        c.execute('INSERT INTO ltext VALUES (%s)', [long_str])
-        c.execute('SELECT text FROM ltext')
-        row = c.fetchone()
-        self.assertEqual(long_str, row[0].read())
-        c.execute('DROP TABLE ltext')
+        with connection.cursor() as cursor:
+            cursor.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
+            long_str = ''.join(six.text_type(x) for x in xrange(4000))
+            cursor.execute('INSERT INTO ltext VALUES (%s)', [long_str])
+            cursor.execute('SELECT text FROM ltext')
+            row = cursor.fetchone()
+            self.assertEqual(long_str, row[0].read())
+            cursor.execute('DROP TABLE ltext')
 
     @unittest.skipUnless(connection.vendor == 'oracle',
                          "No need to check Oracle connection semantics")
     def test_client_encoding(self):
         # If the backend is Oracle, test that the client encoding is set
         # correctly.  This was broken under Cygwin prior to r14781.
-        connection.cursor()  # Ensure the connection is initialized.
+        self.connection.ensure_connection()
         self.assertEqual(connection.connection.encoding, "UTF-8")
         self.assertEqual(connection.connection.nencoding, "UTF-8")
 
@@ -103,12 +103,12 @@ class OracleChecks(unittest.TestCase):
     def test_order_of_nls_parameters(self):
         # an 'almost right' datetime should work with configured
         # NLS parameters as per #18465.
-        c = connection.cursor()
-        query = "select 1 from dual where '1936-12-29 00:00' < sysdate"
-        # Test that the query succeeds without errors - pre #18465 this
-        # wasn't the case.
-        c.execute(query)
-        self.assertEqual(c.fetchone()[0], 1)
+        with connection.cursor() as cursor:
+            query = "select 1 from dual where '1936-12-29 00:00' < sysdate"
+            # Test that the query succeeds without errors - pre #18465 this
+            # wasn't the case.
+            cursor.execute(query)
+            self.assertEqual(cursor.fetchone()[0], 1)
 
 
 class SQLiteTests(TestCase):
@@ -328,6 +328,12 @@ class PostgresVersionTest(TestCase):
             def fetchone(self):
                 return ["PostgreSQL 8.3"]
 
+            def __enter__(self):
+                return self
+
+            def __exit__(self, type, value, traceback):
+                pass
+
         class OlderConnectionMock(object):
             "Mock of psycopg2 (< 2.0.12) connection"
             def cursor(self):

+ 3 - 4
tests/cache/tests.py

@@ -896,10 +896,9 @@ class DBCacheTests(BaseCacheTests, TransactionTestCase):
         management.call_command('createcachetable', verbosity=0, interactive=False)
 
     def drop_table(self):
-        cursor = connection.cursor()
-        table_name = connection.ops.quote_name('test cache table')
-        cursor.execute('DROP TABLE %s' % table_name)
-        cursor.close()
+        with connection.cursor() as cursor:
+            table_name = connection.ops.quote_name('test cache table')
+            cursor.execute('DROP TABLE %s' % table_name)
 
     def test_zero_cull(self):
         self._perform_cull_test(caches['zero_cull'], 50, 18)

+ 8 - 8
tests/custom_methods/models.py

@@ -30,11 +30,11 @@ class Article(models.Model):
         database query for the sake of demonstration.
         """
         from django.db import connection
-        cursor = connection.cursor()
-        cursor.execute("""
-            SELECT id, headline, pub_date
-            FROM custom_methods_article
-            WHERE pub_date = %s
-                AND id != %s""", [connection.ops.value_to_db_date(self.pub_date),
-                                  self.id])
-        return [self.__class__(*row) for row in cursor.fetchall()]
+        with connection.cursor() as cursor:
+            cursor.execute("""
+                SELECT id, headline, pub_date
+                FROM custom_methods_article
+                WHERE pub_date = %s
+                    AND id != %s""", [connection.ops.value_to_db_date(self.pub_date),
+                                      self.id])
+            return [self.__class__(*row) for row in cursor.fetchall()]

+ 3 - 3
tests/initial_sql_regress/tests.py

@@ -28,9 +28,9 @@ class InitialSQLTests(TestCase):
         connection = connections[DEFAULT_DB_ALIAS]
         custom_sql = custom_sql_for_model(Simple, no_style(), connection)
         self.assertEqual(len(custom_sql), 9)
-        cursor = connection.cursor()
-        for sql in custom_sql:
-            cursor.execute(sql)
+        with connection.cursor() as cursor:
+            for sql in custom_sql:
+                cursor.execute(sql)
         self.assertEqual(Simple.objects.count(), 9)
         self.assertEqual(
             Simple.objects.get(name__contains='placeholders').name,

+ 30 - 30
tests/introspection/tests.py

@@ -23,17 +23,17 @@ class IntrospectionTests(TestCase):
                      "'%s' isn't in table_list()." % Article._meta.db_table)
 
     def test_django_table_names(self):
-        cursor = connection.cursor()
-        cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
-        tl = connection.introspection.django_table_names()
-        cursor.execute("DROP TABLE django_ixn_test_table;")
-        self.assertTrue('django_ixn_testcase_table' not in tl,
-                     "django_table_names() returned a non-Django table")
+        with connection.cursor() as cursor:
+            cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
+            tl = connection.introspection.django_table_names()
+            cursor.execute("DROP TABLE django_ixn_test_table;")
+            self.assertTrue('django_ixn_testcase_table' not in tl,
+                         "django_table_names() returned a non-Django table")
 
     def test_django_table_names_retval_type(self):
         # Ticket #15216
-        cursor = connection.cursor()
-        cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
+        with connection.cursor() as cursor:
+            cursor.execute('CREATE TABLE django_ixn_test_table (id INTEGER);')
 
         tl = connection.introspection.django_table_names(only_existing=True)
         self.assertIs(type(tl), list)
@@ -53,14 +53,14 @@ class IntrospectionTests(TestCase):
                      'Reporter sequence not found in sequence_list()')
 
     def test_get_table_description_names(self):
-        cursor = connection.cursor()
-        desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
+        with connection.cursor() as cursor:
+            desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
         self.assertEqual([r[0] for r in desc],
                          [f.column for f in Reporter._meta.fields])
 
     def test_get_table_description_types(self):
-        cursor = connection.cursor()
-        desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
+        with connection.cursor() as cursor:
+            desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
         # The MySQL exception is due to the cursor.description returning the same constant for
         # text and blob columns. TODO: use information_schema database to retrieve the proper
         # field type on MySQL
@@ -75,8 +75,8 @@ class IntrospectionTests(TestCase):
     # inspect the length of character columns).
     @expectedFailureOnOracle
     def test_get_table_description_col_lengths(self):
-        cursor = connection.cursor()
-        desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
+        with connection.cursor() as cursor:
+            desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
         self.assertEqual(
             [r[3] for r in desc if datatype(r[1], r) == 'CharField'],
             [30, 30, 75]
@@ -87,8 +87,8 @@ class IntrospectionTests(TestCase):
     # so its idea about null_ok in cursor.description is different from ours.
     @skipIfDBFeature('interprets_empty_strings_as_nulls')
     def test_get_table_description_nullable(self):
-        cursor = connection.cursor()
-        desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
+        with connection.cursor() as cursor:
+            desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
         self.assertEqual(
             [r[6] for r in desc],
             [False, False, False, False, True, True]
@@ -97,15 +97,15 @@ class IntrospectionTests(TestCase):
     # Regression test for #9991 - 'real' types in postgres
     @skipUnlessDBFeature('has_real_datatype')
     def test_postgresql_real_type(self):
-        cursor = connection.cursor()
-        cursor.execute("CREATE TABLE django_ixn_real_test_table (number REAL);")
-        desc = connection.introspection.get_table_description(cursor, 'django_ixn_real_test_table')
-        cursor.execute('DROP TABLE django_ixn_real_test_table;')
+        with connection.cursor() as cursor:
+            cursor.execute("CREATE TABLE django_ixn_real_test_table (number REAL);")
+            desc = connection.introspection.get_table_description(cursor, 'django_ixn_real_test_table')
+            cursor.execute('DROP TABLE django_ixn_real_test_table;')
         self.assertEqual(datatype(desc[0][1], desc[0]), 'FloatField')
 
     def test_get_relations(self):
-        cursor = connection.cursor()
-        relations = connection.introspection.get_relations(cursor, Article._meta.db_table)
+        with connection.cursor() as cursor:
+            relations = connection.introspection.get_relations(cursor, Article._meta.db_table)
 
         # Older versions of MySQL don't have the chops to report on this stuff,
         # so just skip it if no relations come back. If they do, though, we
@@ -117,21 +117,21 @@ class IntrospectionTests(TestCase):
 
     @skipUnlessDBFeature('can_introspect_foreign_keys')
     def test_get_key_columns(self):
-        cursor = connection.cursor()
-        key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table)
+        with connection.cursor() as cursor:
+            key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table)
         self.assertEqual(
             set(key_columns),
             set([('reporter_id', Reporter._meta.db_table, 'id'),
                  ('response_to_id', Article._meta.db_table, 'id')]))
 
     def test_get_primary_key_column(self):
-        cursor = connection.cursor()
-        primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table)
+        with connection.cursor() as cursor:
+            primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table)
         self.assertEqual(primary_key_column, 'id')
 
     def test_get_indexes(self):
-        cursor = connection.cursor()
-        indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table)
+        with connection.cursor() as cursor:
+            indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table)
         self.assertEqual(indexes['reporter_id'], {'unique': False, 'primary_key': False})
 
     def test_get_indexes_multicol(self):
@@ -139,8 +139,8 @@ class IntrospectionTests(TestCase):
         Test that multicolumn indexes are not included in the introspection
         results.
         """
-        cursor = connection.cursor()
-        indexes = connection.introspection.get_indexes(cursor, Reporter._meta.db_table)
+        with connection.cursor() as cursor:
+            indexes = connection.introspection.get_indexes(cursor, Reporter._meta.db_table)
         self.assertNotIn('first_name', indexes)
         self.assertIn('id', indexes)
 

+ 21 - 14
tests/migrations/test_base.py

@@ -9,33 +9,40 @@ class MigrationTestBase(TransactionTestCase):
 
     available_apps = ["migrations"]
 
+    def get_table_description(self, table):
+        with connection.cursor() as cursor:
+            return connection.introspection.get_table_description(cursor, table)
+
     def assertTableExists(self, table):
-        self.assertIn(table, connection.introspection.get_table_list(connection.cursor()))
+        with connection.cursor() as cursor:
+            self.assertIn(table, connection.introspection.get_table_list(cursor))
 
     def assertTableNotExists(self, table):
-        self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor()))
+        with connection.cursor() as cursor:
+            self.assertNotIn(table, connection.introspection.get_table_list(cursor))
 
     def assertColumnExists(self, table, column):
-        self.assertIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
+        self.assertIn(column, [c.name for c in self.get_table_description(table)])
 
     def assertColumnNotExists(self, table, column):
-        self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
+        self.assertNotIn(column, [c.name for c in self.get_table_description(table)])
 
     def assertColumnNull(self, table, column):
-        self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], True)
+        self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], True)
 
     def assertColumnNotNull(self, table, column):
-        self.assertEqual([c.null_ok for c in connection.introspection.get_table_description(connection.cursor(), table) if c.name == column][0], False)
+        self.assertEqual([c.null_ok for c in self.get_table_description(table) if c.name == column][0], False)
 
     def assertIndexExists(self, table, columns, value=True):
-        self.assertEqual(
-            value,
-            any(
-                c["index"]
-                for c in connection.introspection.get_constraints(connection.cursor(), table).values()
-                if c['columns'] == list(columns)
-            ),
-        )
+        with connection.cursor() as cursor:
+            self.assertEqual(
+                value,
+                any(
+                    c["index"]
+                    for c in connection.introspection.get_constraints(cursor, table).values()
+                    if c['columns'] == list(columns)
+                ),
+            )
 
     def assertIndexNotExists(self, table, columns):
         return self.assertIndexExists(table, columns, False)

+ 36 - 36
tests/migrations/test_operations.py

@@ -19,15 +19,15 @@ class OperationTests(MigrationTestBase):
         Creates a test model state and database table.
         """
         # Delete the tables if they already exist
-        cursor = connection.cursor()
-        try:
-            cursor.execute("DROP TABLE %s_pony" % app_label)
-        except:
-            pass
-        try:
-            cursor.execute("DROP TABLE %s_stable" % app_label)
-        except:
-            pass
+        with connection.cursor() as cursor:
+            try:
+                cursor.execute("DROP TABLE %s_pony" % app_label)
+            except:
+                pass
+            try:
+                cursor.execute("DROP TABLE %s_stable" % app_label)
+            except:
+                pass
         # Make the "current" state
         operations = [migrations.CreateModel(
             "Pony",
@@ -348,21 +348,21 @@ class OperationTests(MigrationTestBase):
         operation.state_forwards("test_alflpkfk", new_state)
         self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField)
         self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField)
+
+        def assertIdTypeEqualsFkType(self):
+            with connection.cursor() as cursor:
+                id_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_pony") if c.name == "id"][0]
+                fk_type = [c.type_code for c in connection.introspection.get_table_description(cursor, "test_alflpkfk_rider") if c.name == "pony_id"][0]
+            self.assertEqual(id_type, fk_type)
+        assertIdTypeEqualsFkType()
         # Test the database alteration
-        id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
-        fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
-        self.assertEqual(id_type, fk_type)
         with connection.schema_editor() as editor:
             operation.database_forwards("test_alflpkfk", editor, project_state, new_state)
-        id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
-        fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
-        self.assertEqual(id_type, fk_type)
+        assertIdTypeEqualsFkType()
         # And test reversal
         with connection.schema_editor() as editor:
             operation.database_backwards("test_alflpkfk", editor, new_state, project_state)
-        id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
-        fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
-        self.assertEqual(id_type, fk_type)
+        assertIdTypeEqualsFkType()
 
     def test_rename_field(self):
         """
@@ -400,24 +400,24 @@ class OperationTests(MigrationTestBase):
         self.assertEqual(len(project_state.models["test_alunto", "pony"].options.get("unique_together", set())), 0)
         self.assertEqual(len(new_state.models["test_alunto", "pony"].options.get("unique_together", set())), 1)
         # Make sure we can insert duplicate rows
-        cursor = connection.cursor()
-        cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
-        cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
-        cursor.execute("DELETE FROM test_alunto_pony")
-        # Test the database alteration
-        with connection.schema_editor() as editor:
-            operation.database_forwards("test_alunto", editor, project_state, new_state)
-        cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
-        with self.assertRaises(IntegrityError):
-            with atomic():
-                cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
-        cursor.execute("DELETE FROM test_alunto_pony")
-        # And test reversal
-        with connection.schema_editor() as editor:
-            operation.database_backwards("test_alunto", editor, new_state, project_state)
-        cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
-        cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
-        cursor.execute("DELETE FROM test_alunto_pony")
+        with connection.cursor() as cursor:
+            cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
+            cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
+            cursor.execute("DELETE FROM test_alunto_pony")
+            # Test the database alteration
+            with connection.schema_editor() as editor:
+                operation.database_forwards("test_alunto", editor, project_state, new_state)
+            cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
+            with self.assertRaises(IntegrityError):
+                with atomic():
+                    cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
+            cursor.execute("DELETE FROM test_alunto_pony")
+            # And test reversal
+            with connection.schema_editor() as editor:
+                operation.database_backwards("test_alunto", editor, new_state, project_state)
+            cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
+            cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
+            cursor.execute("DELETE FROM test_alunto_pony")
         # Test flat unique_together
         operation = migrations.AlterUniqueTogether("Pony", ("pink", "weight"))
         operation.state_forwards("test_alunto", new_state)

+ 1 - 1
tests/requests/tests.py

@@ -725,7 +725,7 @@ class DatabaseConnectionHandlingTests(TransactionTestCase):
         # request_finished signal.
         response = self.client.get('/')
         # Make sure there is an open connection
-        connection.cursor()
+        self.connection.ensure_connection()
         connection.enter_transaction_management()
         signals.request_finished.send(sender=response._handler_class)
         self.assertEqual(len(connection.transaction_state), 0)

+ 55 - 41
tests/schema/tests.py

@@ -37,38 +37,38 @@ class SchemaTests(TransactionTestCase):
 
     def delete_tables(self):
         "Deletes all model tables for our models for a clean test environment"
-        cursor = connection.cursor()
-        connection.disable_constraint_checking()
-        table_names = connection.introspection.table_names(cursor)
-        for model in self.models:
-            # Remove any M2M tables first
-            for field in model._meta.local_many_to_many:
+        with connection.cursor() as cursor:
+            connection.disable_constraint_checking()
+            table_names = connection.introspection.table_names(cursor)
+            for model in self.models:
+                # Remove any M2M tables first
+                for field in model._meta.local_many_to_many:
+                    with atomic():
+                        tbl = field.rel.through._meta.db_table
+                        if tbl in table_names:
+                            cursor.execute(connection.schema_editor().sql_delete_table % {
+                                "table": connection.ops.quote_name(tbl),
+                            })
+                            table_names.remove(tbl)
+                # Then remove the main tables
                 with atomic():
-                    tbl = field.rel.through._meta.db_table
+                    tbl = model._meta.db_table
                     if tbl in table_names:
                         cursor.execute(connection.schema_editor().sql_delete_table % {
                             "table": connection.ops.quote_name(tbl),
                         })
                         table_names.remove(tbl)
-            # Then remove the main tables
-            with atomic():
-                tbl = model._meta.db_table
-                if tbl in table_names:
-                    cursor.execute(connection.schema_editor().sql_delete_table % {
-                        "table": connection.ops.quote_name(tbl),
-                    })
-                    table_names.remove(tbl)
         connection.enable_constraint_checking()
 
     def column_classes(self, model):
-        cursor = connection.cursor()
-        columns = dict(
-            (d[0], (connection.introspection.get_field_type(d[1], d), d))
-            for d in connection.introspection.get_table_description(
-                cursor,
-                model._meta.db_table,
+        with connection.cursor() as cursor:
+            columns = dict(
+                (d[0], (connection.introspection.get_field_type(d[1], d), d))
+                for d in connection.introspection.get_table_description(
+                    cursor,
+                    model._meta.db_table,
+                )
             )
-        )
         # SQLite has a different format for field_type
         for name, (type, desc) in columns.items():
             if isinstance(type, tuple):
@@ -78,6 +78,20 @@ class SchemaTests(TransactionTestCase):
             raise DatabaseError("Table does not exist (empty pragma)")
         return columns
 
+    def get_indexes(self, table):
+        """
+        Get the indexes on the table using a new cursor.
+        """
+        with connection.cursor() as cursor:
+            return connection.introspection.get_indexes(cursor, table)
+
+    def get_constraints(self, table):
+        """
+        Get the constraints on a table using a new cursor.
+        """
+        with connection.cursor() as cursor:
+            return connection.introspection.get_constraints(cursor, table)
+
     # Tests
 
     def test_creation_deletion(self):
@@ -127,7 +141,7 @@ class SchemaTests(TransactionTestCase):
                 strict=True,
             )
         # Make sure the new FK constraint is present
-        constraints = connection.introspection.get_constraints(connection.cursor(), Book._meta.db_table)
+        constraints = self.get_constraints(Book._meta.db_table)
         for name, details in constraints.items():
             if details['columns'] == ["author_id"] and details['foreign_key']:
                 self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
@@ -342,7 +356,7 @@ class SchemaTests(TransactionTestCase):
             editor.create_model(TagM2MTest)
             editor.create_model(UniqueTest)
         # Ensure the M2M exists and points to TagM2MTest
-        constraints = connection.introspection.get_constraints(connection.cursor(), BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
+        constraints = self.get_constraints(BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
         if connection.features.supports_foreign_keys:
             for name, details in constraints.items():
                 if details['columns'] == ["tagm2mtest_id"] and details['foreign_key']:
@@ -363,7 +377,7 @@ class SchemaTests(TransactionTestCase):
             # Ensure old M2M is gone
             self.assertRaises(DatabaseError, self.column_classes, BookWithM2M._meta.get_field_by_name("tags")[0].rel.through)
             # Ensure the new M2M exists and points to UniqueTest
-            constraints = connection.introspection.get_constraints(connection.cursor(), new_field.rel.through._meta.db_table)
+            constraints = self.get_constraints(new_field.rel.through._meta.db_table)
             if connection.features.supports_foreign_keys:
                 for name, details in constraints.items():
                     if details['columns'] == ["uniquetest_id"] and details['foreign_key']:
@@ -388,7 +402,7 @@ class SchemaTests(TransactionTestCase):
         with connection.schema_editor() as editor:
             editor.create_model(Author)
         # Ensure the constraint exists
-        constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
+        constraints = self.get_constraints(Author._meta.db_table)
         for name, details in constraints.items():
             if details['columns'] == ["height"] and details['check']:
                 break
@@ -404,7 +418,7 @@ class SchemaTests(TransactionTestCase):
                 new_field,
                 strict=True,
             )
-        constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
+        constraints = self.get_constraints(Author._meta.db_table)
         for name, details in constraints.items():
             if details['columns'] == ["height"] and details['check']:
                 self.fail("Check constraint for height found")
@@ -416,7 +430,7 @@ class SchemaTests(TransactionTestCase):
                 Author._meta.get_field_by_name("height")[0],
                 strict=True,
             )
-        constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
+        constraints = self.get_constraints(Author._meta.db_table)
         for name, details in constraints.items():
             if details['columns'] == ["height"] and details['check']:
                 break
@@ -527,7 +541,7 @@ class SchemaTests(TransactionTestCase):
             False,
             any(
                 c["index"]
-                for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
+                for c in self.get_constraints("schema_tag").values()
                 if c['columns'] == ["slug", "title"]
             ),
         )
@@ -543,7 +557,7 @@ class SchemaTests(TransactionTestCase):
             True,
             any(
                 c["index"]
-                for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
+                for c in self.get_constraints("schema_tag").values()
                 if c['columns'] == ["slug", "title"]
             ),
         )
@@ -561,7 +575,7 @@ class SchemaTests(TransactionTestCase):
             False,
             any(
                 c["index"]
-                for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
+                for c in self.get_constraints("schema_tag").values()
                 if c['columns'] == ["slug", "title"]
             ),
         )
@@ -578,7 +592,7 @@ class SchemaTests(TransactionTestCase):
             True,
             any(
                 c["index"]
-                for c in connection.introspection.get_constraints(connection.cursor(), "schema_tagindexed").values()
+                for c in self.get_constraints("schema_tagindexed").values()
                 if c['columns'] == ["slug", "title"]
             ),
         )
@@ -627,7 +641,7 @@ class SchemaTests(TransactionTestCase):
         # Ensure the table is there and has the right index
         self.assertIn(
             "title",
-            connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
+            self.get_indexes(Book._meta.db_table),
         )
         # Alter to remove the index
         new_field = CharField(max_length=100, db_index=False)
@@ -642,7 +656,7 @@ class SchemaTests(TransactionTestCase):
         # Ensure the table is there and has no index
         self.assertNotIn(
             "title",
-            connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
+            self.get_indexes(Book._meta.db_table),
         )
         # Alter to re-add the index
         with connection.schema_editor() as editor:
@@ -655,7 +669,7 @@ class SchemaTests(TransactionTestCase):
         # Ensure the table is there and has the index again
         self.assertIn(
             "title",
-            connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
+            self.get_indexes(Book._meta.db_table),
         )
         # Add a unique column, verify that creates an implicit index
         with connection.schema_editor() as editor:
@@ -665,7 +679,7 @@ class SchemaTests(TransactionTestCase):
             )
         self.assertIn(
             "slug",
-            connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
+            self.get_indexes(Book._meta.db_table),
         )
         # Remove the unique, check the index goes with it
         new_field2 = CharField(max_length=20, unique=False)
@@ -679,7 +693,7 @@ class SchemaTests(TransactionTestCase):
             )
         self.assertNotIn(
             "slug",
-            connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table),
+            self.get_indexes(Book._meta.db_table),
         )
 
     def test_primary_key(self):
@@ -691,7 +705,7 @@ class SchemaTests(TransactionTestCase):
             editor.create_model(Tag)
         # Ensure the table is there and has the right PK
         self.assertTrue(
-            connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['id']['primary_key'],
+            self.get_indexes(Tag._meta.db_table)['id']['primary_key'],
         )
         # Alter to change the PK
         new_field = SlugField(primary_key=True)
@@ -707,10 +721,10 @@ class SchemaTests(TransactionTestCase):
         # Ensure the PK changed
         self.assertNotIn(
             'id',
-            connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table),
+            self.get_indexes(Tag._meta.db_table),
         )
         self.assertTrue(
-            connection.introspection.get_indexes(connection.cursor(), Tag._meta.db_table)['slug']['primary_key'],
+            self.get_indexes(Tag._meta.db_table)['slug']['primary_key'],
         )
 
     def test_context_manager_exit(self):
@@ -741,7 +755,7 @@ class SchemaTests(TransactionTestCase):
         # Ensure the table is there and has an index on the column
         self.assertIn(
             column_name,
-            connection.introspection.get_indexes(connection.cursor(), BookWithLongName._meta.db_table),
+            self.get_indexes(BookWithLongName._meta.db_table),
         )
 
     def test_creation_deletion_reserved_names(self):

+ 7 - 6
tests/transactions/tests.py

@@ -202,8 +202,9 @@ class AtomicTests(TransactionTestCase):
             # trigger a database error inside an inner atomic without savepoint
             with self.assertRaises(DatabaseError):
                 with transaction.atomic(savepoint=False):
-                    connection.cursor().execute(
-                        "SELECT no_such_col FROM transactions_reporter")
+                    with connection.cursor() as cursor:
+                        cursor.execute(
+                            "SELECT no_such_col FROM transactions_reporter")
             # prevent atomic from rolling back since we're recovering manually
             self.assertTrue(transaction.get_rollback())
             transaction.set_rollback(False)
@@ -534,8 +535,8 @@ class TransactionRollbackTests(IgnoreDeprecationWarningsMixin, TransactionTestCa
     available_apps = ['transactions']
 
     def execute_bad_sql(self):
-        cursor = connection.cursor()
-        cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
+        with connection.cursor() as cursor:
+            cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
 
     @skipUnlessDBFeature('requires_rollback_on_dirty_transaction')
     def test_bad_sql(self):
@@ -678,6 +679,6 @@ class TransactionContextManagerTests(IgnoreDeprecationWarningsMixin, Transaction
         """
         with self.assertRaises(IntegrityError):
             with transaction.commit_on_success():
-                cursor = connection.cursor()
-                cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
+                with connection.cursor() as cursor:
+                    cursor.execute("INSERT INTO transactions_reporter (first_name, last_name) VALUES ('Douglas', 'Adams');")
         transaction.rollback()

+ 6 - 6
tests/transactions_regress/tests.py

@@ -54,8 +54,8 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase
         @commit_on_success
         def raw_sql():
             "Write a record using raw sql under a commit_on_success decorator"
-            cursor = connection.cursor()
-            cursor.execute("INSERT into transactions_regress_mod (fld) values (18)")
+            with connection.cursor() as cursor:
+                cursor.execute("INSERT into transactions_regress_mod (fld) values (18)")
 
         raw_sql()
         # Rollback so that if the decorator didn't commit, the record is unwritten
@@ -143,10 +143,10 @@ class TestTransactionClosing(IgnoreDeprecationWarningsMixin, TransactionTestCase
             (reference). All this under commit_on_success, so the second insert should
             be committed.
             """
-            cursor = connection.cursor()
-            cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
-            transaction.rollback()
-            cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
+            with connection.cursor() as cursor:
+                cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
+                transaction.rollback()
+                cursor.execute("INSERT into transactions_regress_mod (fld) values (2)")
 
         reuse_cursor_ref()
         # Rollback so that if the decorator didn't commit, the record is unwritten