Browse Source

Fixed #28853 -- Updated connection.cursor() uses to use a context manager.

Jon Dufresne 7 years ago
parent
commit
7a6fbf36b1

+ 1 - 5
django/contrib/gis/db/backends/mysql/introspection.py

@@ -11,8 +11,7 @@ class MySQLIntrospection(DatabaseIntrospection):
     data_types_reverse[FIELD_TYPE.GEOMETRY] = 'GeometryField'
 
     def get_geometry_type(self, table_name, geo_col):
-        cursor = self.connection.cursor()
-        try:
+        with self.connection.cursor() as cursor:
             # In order to get the specific geometry type of the field,
             # we introspect on the table definition using `DESCRIBE`.
             cursor.execute('DESCRIBE %s' %
@@ -27,9 +26,6 @@ class MySQLIntrospection(DatabaseIntrospection):
                     field_type = OGRGeomType(typ).django
                     field_params = {}
                     break
-        finally:
-            cursor.close()
-
         return field_type, field_params
 
     def supports_spatial_index(self, cursor, table_name):

+ 1 - 5
django/contrib/gis/db/backends/oracle/introspection.py

@@ -11,8 +11,7 @@ class OracleIntrospection(DatabaseIntrospection):
     data_types_reverse[cx_Oracle.OBJECT] = 'GeometryField'
 
     def get_geometry_type(self, table_name, geo_col):
-        cursor = self.connection.cursor()
-        try:
+        with self.connection.cursor() as cursor:
             # Querying USER_SDO_GEOM_METADATA to get the SRID and dimension information.
             try:
                 cursor.execute(
@@ -40,7 +39,4 @@ class OracleIntrospection(DatabaseIntrospection):
             dim = dim.size()
             if dim != 2:
                 field_params['dim'] = dim
-        finally:
-            cursor.close()
-
         return field_type, field_params

+ 2 - 10
django/contrib/gis/db/backends/postgis/introspection.py

@@ -59,15 +59,11 @@ class PostGISIntrospection(DatabaseIntrospection):
         # to query the PostgreSQL pg_type table corresponding to the
         # PostGIS custom data types.
         oid_sql = 'SELECT "oid" FROM "pg_type" WHERE "typname" = %s'
-        cursor = self.connection.cursor()
-        try:
+        with self.connection.cursor() as cursor:
             for field_type in field_types:
                 cursor.execute(oid_sql, (field_type[0],))
                 for result in cursor.fetchall():
                     postgis_types[result[0]] = field_type[1]
-        finally:
-            cursor.close()
-
         return postgis_types
 
     def get_field_type(self, data_type, description):
@@ -88,8 +84,7 @@ class PostGISIntrospection(DatabaseIntrospection):
         PointField or a PolygonField).  Thus, this routine queries the PostGIS
         metadata tables to determine the geometry type.
         """
-        cursor = self.connection.cursor()
-        try:
+        with self.connection.cursor() as cursor:
             try:
                 # First seeing if this geometry column is in the `geometry_columns`
                 cursor.execute('SELECT "coord_dimension", "srid", "type" '
@@ -122,7 +117,4 @@ class PostGISIntrospection(DatabaseIntrospection):
                 field_params['srid'] = srid
             if dim != 2:
                 field_params['dim'] = dim
-        finally:
-            cursor.close()
-
         return field_type, field_params

+ 1 - 5
django/contrib/gis/db/backends/spatialite/introspection.py

@@ -25,8 +25,7 @@ class SpatiaLiteIntrospection(DatabaseIntrospection):
     data_types_reverse = GeoFlexibleFieldLookupDict()
 
     def get_geometry_type(self, table_name, geo_col):
-        cursor = self.connection.cursor()
-        try:
+        with self.connection.cursor() as cursor:
             # Querying the `geometry_columns` table to get additional metadata.
             cursor.execute('SELECT coord_dimension, srid, geometry_type '
                            'FROM geometry_columns '
@@ -55,9 +54,6 @@ class SpatiaLiteIntrospection(DatabaseIntrospection):
                 field_params['srid'] = srid
             if (isinstance(dim, str) and 'Z' in dim) or dim == 3:
                 field_params['dim'] = 3
-        finally:
-            cursor.close()
-
         return field_type, field_params
 
     def get_constraints(self, cursor, table_name):

+ 2 - 3
django/db/backends/base/base.py

@@ -573,11 +573,10 @@ class BaseDatabaseWrapper:
         Provide a cursor: with self.temporary_connection() as cursor: ...
         """
         must_close = self.connection is None
-        cursor = self.cursor()
         try:
-            yield cursor
+            with self.cursor() as cursor:
+                yield cursor
         finally:
-            cursor.close()
             if must_close:
                 self.close()
 

+ 14 - 15
django/db/backends/base/introspection.py

@@ -116,21 +116,20 @@ class BaseDatabaseIntrospection:
         from django.db import router
 
         sequence_list = []
-        cursor = self.connection.cursor()
-
-        for app_config in apps.get_app_configs():
-            for model in router.get_migratable_models(app_config, self.connection.alias):
-                if not model._meta.managed:
-                    continue
-                if model._meta.swapped:
-                    continue
-                sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields))
-                for f in model._meta.local_many_to_many:
-                    # If this is an m2m using an intermediate table,
-                    # we don't need to reset the sequence.
-                    if f.remote_field.through is None:
-                        sequence = self.get_sequences(cursor, f.m2m_db_table())
-                        sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}])
+        with self.connection.cursor() as cursor:
+            for app_config in apps.get_app_configs():
+                for model in router.get_migratable_models(app_config, self.connection.alias):
+                    if not model._meta.managed:
+                        continue
+                    if model._meta.swapped:
+                        continue
+                    sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields))
+                    for f in model._meta.local_many_to_many:
+                        # If this is an m2m using an intermediate table,
+                        # we don't need to reset the sequence.
+                        if f.remote_field.through is None:
+                            sequence = self.get_sequences(cursor, f.m2m_db_table())
+                            sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}])
         return sequence_list
 
     def get_sequences(self, cursor, table_name, table_fields=()):

+ 29 - 28
django/db/backends/mysql/base.py

@@ -294,36 +294,37 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         Backends can override this method if they can more directly apply
         constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE")
         """
-        cursor = self.cursor()
-        if table_names is None:
-            table_names = self.introspection.table_names(cursor)
-        for table_name in table_names:
-            primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
-            if not primary_key_column_name:
-                continue
-            key_columns = self.introspection.get_key_columns(cursor, table_name)
-            for column_name, referenced_table_name, referenced_column_name in key_columns:
-                cursor.execute(
-                    """
-                    SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
-                    LEFT JOIN `%s` as REFERRED
-                    ON (REFERRING.`%s` = REFERRED.`%s`)
-                    WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
-                    """ % (
-                        primary_key_column_name, column_name, table_name,
-                        referenced_table_name, column_name, referenced_column_name,
-                        column_name, referenced_column_name,
-                    )
-                )
-                for bad_row in cursor.fetchall():
-                    raise utils.IntegrityError(
-                        "The row in table '%s' with primary key '%s' has an invalid "
-                        "foreign key: %s.%s contains a value '%s' that does not have a corresponding value in %s.%s."
-                        % (
-                            table_name, bad_row[0], table_name, column_name,
-                            bad_row[1], referenced_table_name, referenced_column_name,
+        with self.cursor() as cursor:
+            if table_names is None:
+                table_names = self.introspection.table_names(cursor)
+            for table_name in table_names:
+                primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
+                if not primary_key_column_name:
+                    continue
+                key_columns = self.introspection.get_key_columns(cursor, table_name)
+                for column_name, referenced_table_name, referenced_column_name in key_columns:
+                    cursor.execute(
+                        """
+                        SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
+                        LEFT JOIN `%s` as REFERRED
+                        ON (REFERRING.`%s` = REFERRED.`%s`)
+                        WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
+                        """ % (
+                            primary_key_column_name, column_name, table_name,
+                            referenced_table_name, column_name, referenced_column_name,
+                            column_name, referenced_column_name,
                         )
                     )
+                    for bad_row in cursor.fetchall():
+                        raise utils.IntegrityError(
+                            "The row in table '%s' with primary key '%s' has an invalid "
+                            "foreign key: %s.%s contains a value '%s' that does not "
+                            "have a corresponding value in %s.%s."
+                            % (
+                                table_name, bad_row[0], table_name, column_name,
+                                bad_row[1], referenced_table_name, referenced_column_name,
+                            )
+                        )
 
     def is_usable(self):
         try:

+ 71 - 76
django/db/backends/oracle/creation.py

@@ -30,75 +30,72 @@ class DatabaseCreation(BaseDatabaseCreation):
 
     def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
         parameters = self._get_test_db_params()
-        cursor = self._maindb_connection.cursor()
-        if self._test_database_create():
-            try:
-                self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
-            except Exception as e:
-                if 'ORA-01543' not in str(e):
-                    # All errors except "tablespace already exists" cancel tests
-                    sys.stderr.write("Got an error creating the test database: %s\n" % e)
-                    sys.exit(2)
-                if not autoclobber:
-                    confirm = input(
-                        "It appears the test database, %s, already exists. "
-                        "Type 'yes' to delete it, or 'no' to cancel: " % parameters['user'])
-                if autoclobber or confirm == 'yes':
-                    if verbosity >= 1:
-                        print("Destroying old test database for alias '%s'..." % self.connection.alias)
-                    try:
-                        self._execute_test_db_destruction(cursor, parameters, verbosity)
-                    except DatabaseError as e:
-                        if 'ORA-29857' in str(e):
-                            self._handle_objects_preventing_db_destruction(cursor, parameters,
-                                                                           verbosity, autoclobber)
-                        else:
-                            # Ran into a database error that isn't about leftover objects in the tablespace
+        with self._maindb_connection.cursor() as cursor:
+            if self._test_database_create():
+                try:
+                    self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
+                except Exception as e:
+                    if 'ORA-01543' not in str(e):
+                        # All errors except "tablespace already exists" cancel tests
+                        sys.stderr.write("Got an error creating the test database: %s\n" % e)
+                        sys.exit(2)
+                    if not autoclobber:
+                        confirm = input(
+                            "It appears the test database, %s, already exists. "
+                            "Type 'yes' to delete it, or 'no' to cancel: " % parameters['user'])
+                    if autoclobber or confirm == 'yes':
+                        if verbosity >= 1:
+                            print("Destroying old test database for alias '%s'..." % self.connection.alias)
+                        try:
+                            self._execute_test_db_destruction(cursor, parameters, verbosity)
+                        except DatabaseError as e:
+                            if 'ORA-29857' in str(e):
+                                self._handle_objects_preventing_db_destruction(cursor, parameters,
+                                                                               verbosity, autoclobber)
+                            else:
+                                # Ran into a database error that isn't about leftover objects in the tablespace
+                                sys.stderr.write("Got an error destroying the old test database: %s\n" % e)
+                                sys.exit(2)
+                        except Exception as e:
                             sys.stderr.write("Got an error destroying the old test database: %s\n" % e)
                             sys.exit(2)
-                    except Exception as e:
-                        sys.stderr.write("Got an error destroying the old test database: %s\n" % e)
-                        sys.exit(2)
-                    try:
-                        self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
-                    except Exception as e:
-                        sys.stderr.write("Got an error recreating the test database: %s\n" % e)
-                        sys.exit(2)
-                else:
-                    print("Tests cancelled.")
-                    sys.exit(1)
+                        try:
+                            self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
+                        except Exception as e:
+                            sys.stderr.write("Got an error recreating the test database: %s\n" % e)
+                            sys.exit(2)
+                    else:
+                        print("Tests cancelled.")
+                        sys.exit(1)
 
-        if self._test_user_create():
-            if verbosity >= 1:
-                print("Creating test user...")
-            try:
-                self._create_test_user(cursor, parameters, verbosity, keepdb)
-            except Exception as e:
-                if 'ORA-01920' not in str(e):
-                    # All errors except "user already exists" cancel tests
-                    sys.stderr.write("Got an error creating the test user: %s\n" % e)
-                    sys.exit(2)
-                if not autoclobber:
-                    confirm = input(
-                        "It appears the test user, %s, already exists. Type "
-                        "'yes' to delete it, or 'no' to cancel: " % parameters['user'])
-                if autoclobber or confirm == 'yes':
-                    try:
-                        if verbosity >= 1:
-                            print("Destroying old test user...")
-                        self._destroy_test_user(cursor, parameters, verbosity)
-                        if verbosity >= 1:
-                            print("Creating test user...")
-                        self._create_test_user(cursor, parameters, verbosity, keepdb)
-                    except Exception as e:
-                        sys.stderr.write("Got an error recreating the test user: %s\n" % e)
+            if self._test_user_create():
+                if verbosity >= 1:
+                    print("Creating test user...")
+                try:
+                    self._create_test_user(cursor, parameters, verbosity, keepdb)
+                except Exception as e:
+                    if 'ORA-01920' not in str(e):
+                        # All errors except "user already exists" cancel tests
+                        sys.stderr.write("Got an error creating the test user: %s\n" % e)
                         sys.exit(2)
-                else:
-                    print("Tests cancelled.")
-                    sys.exit(1)
-
-        # Cursor must be closed before closing connection.
-        cursor.close()
+                    if not autoclobber:
+                        confirm = input(
+                            "It appears the test user, %s, already exists. Type "
+                            "'yes' to delete it, or 'no' to cancel: " % parameters['user'])
+                    if autoclobber or confirm == 'yes':
+                        try:
+                            if verbosity >= 1:
+                                print("Destroying old test user...")
+                            self._destroy_test_user(cursor, parameters, verbosity)
+                            if verbosity >= 1:
+                                print("Creating test user...")
+                            self._create_test_user(cursor, parameters, verbosity, keepdb)
+                        except Exception as e:
+                            sys.stderr.write("Got an error recreating the test user: %s\n" % e)
+                            sys.exit(2)
+                    else:
+                        print("Tests cancelled.")
+                        sys.exit(1)
         self._maindb_connection.close()  # done with main user -- test user and tablespaces created
         self._switch_to_test_user(parameters)
         return self.connection.settings_dict['NAME']
@@ -175,17 +172,15 @@ class DatabaseCreation(BaseDatabaseCreation):
         self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD']
         self.connection.close()
         parameters = self._get_test_db_params()
-        cursor = self._maindb_connection.cursor()
-        if self._test_user_create():
-            if verbosity >= 1:
-                print('Destroying test user...')
-            self._destroy_test_user(cursor, parameters, verbosity)
-        if self._test_database_create():
-            if verbosity >= 1:
-                print('Destroying test database tables...')
-            self._execute_test_db_destruction(cursor, parameters, verbosity)
-        # Cursor must be closed before closing connection.
-        cursor.close()
+        with self._maindb_connection.cursor() as cursor:
+            if self._test_user_create():
+                if verbosity >= 1:
+                    print('Destroying test user...')
+                self._destroy_test_user(cursor, parameters, verbosity)
+            if self._test_database_create():
+                if verbosity >= 1:
+                    print('Destroying test database tables...')
+                self._execute_test_db_destruction(cursor, parameters, verbosity)
         self._maindb_connection.close()
 
     def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False):

+ 29 - 29
django/db/backends/sqlite3/base.py

@@ -237,37 +237,37 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         Backends can override this method if they can more directly apply
         constraint checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE")
         """
-        cursor = self.cursor()
-        if table_names is None:
-            table_names = self.introspection.table_names(cursor)
-        for table_name in table_names:
-            primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
-            if not primary_key_column_name:
-                continue
-            key_columns = self.introspection.get_key_columns(cursor, table_name)
-            for column_name, referenced_table_name, referenced_column_name in key_columns:
-                cursor.execute(
-                    """
-                    SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
-                    LEFT JOIN `%s` as REFERRED
-                    ON (REFERRING.`%s` = REFERRED.`%s`)
-                    WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
-                    """
-                    % (
-                        primary_key_column_name, column_name, table_name,
-                        referenced_table_name, column_name, referenced_column_name,
-                        column_name, referenced_column_name,
-                    )
-                )
-                for bad_row in cursor.fetchall():
-                    raise utils.IntegrityError(
-                        "The row in table '%s' with primary key '%s' has an "
-                        "invalid foreign key: %s.%s contains a value '%s' that "
-                        "does not have a corresponding value in %s.%s." % (
-                            table_name, bad_row[0], table_name, column_name,
-                            bad_row[1], referenced_table_name, referenced_column_name,
+        with self.cursor() as cursor:
+            if table_names is None:
+                table_names = self.introspection.table_names(cursor)
+            for table_name in table_names:
+                primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
+                if not primary_key_column_name:
+                    continue
+                key_columns = self.introspection.get_key_columns(cursor, table_name)
+                for column_name, referenced_table_name, referenced_column_name in key_columns:
+                    cursor.execute(
+                        """
+                        SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
+                        LEFT JOIN `%s` as REFERRED
+                        ON (REFERRING.`%s` = REFERRED.`%s`)
+                        WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
+                        """
+                        % (
+                            primary_key_column_name, column_name, table_name,
+                            referenced_table_name, column_name, referenced_column_name,
+                            column_name, referenced_column_name,
                         )
                     )
+                    for bad_row in cursor.fetchall():
+                        raise utils.IntegrityError(
+                            "The row in table '%s' with primary key '%s' has an "
+                            "invalid foreign key: %s.%s contains a value '%s' that "
+                            "does not have a corresponding value in %s.%s." % (
+                                table_name, bad_row[0], table_name, column_name,
+                                bad_row[1], referenced_table_name, referenced_column_name,
+                            )
+                        )
 
     def is_usable(self):
         return True

+ 2 - 1
django/db/migrations/executor.py

@@ -322,7 +322,8 @@ class MigrationExecutor:
         apps = after_state.apps
         found_create_model_migration = False
         found_add_field_migration = False
-        existing_table_names = self.connection.introspection.table_names(self.connection.cursor())
+        with self.connection.cursor() as cursor:
+            existing_table_names = self.connection.introspection.table_names(cursor)
         # Make sure all create model and add field operations are done
         for operation in migration.operations:
             if isinstance(operation, migrations.CreateModel):

+ 3 - 3
django/test/testcases.py

@@ -852,9 +852,9 @@ class TransactionTestCase(SimpleTestCase):
                 no_style(), conn.introspection.sequence_list())
             if sql_list:
                 with transaction.atomic(using=db_name):
-                    cursor = conn.cursor()
-                    for sql in sql_list:
-                        cursor.execute(sql)
+                    with conn.cursor() as cursor:
+                        for sql in sql_list:
+                            cursor.execute(sql)
 
     def _fixture_setup(self):
         for db_name in self._databases_names(include_mirrors=False):

+ 2 - 1
docs/topics/db/multi-db.txt

@@ -664,7 +664,8 @@ object that allows you to retrieve a specific connection using its
 alias::
 
     from django.db import connections
-    cursor = connections['my_db_alias'].cursor()
+    with connections['my_db_alias'].cursor() as cursor:
+        ...
 
 Limitations of multiple databases
 =================================

+ 2 - 2
docs/topics/db/sql.txt

@@ -279,8 +279,8 @@ object that allows you to retrieve a specific connection using its
 alias::
 
     from django.db import connections
-    cursor = connections['my_db_alias'].cursor()
-    # Your code here...
+    with connections['my_db_alias'].cursor() as cursor:
+        # Your code here...
 
 By default, the Python DB API will return results without their field names,
 which means you end up with a ``list`` of values, rather than a ``dict``. At a

+ 12 - 12
tests/backends/postgresql/test_introspection.py

@@ -9,15 +9,15 @@ from ..models import Person
 @unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
 class DatabaseSequenceTests(TestCase):
     def test_get_sequences(self):
-        cursor = connection.cursor()
-        seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
-        self.assertEqual(
-            seqs,
-            [{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}]
-        )
-        cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq')
-        seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
-        self.assertEqual(
-            seqs,
-            [{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}]
-        )
+        with connection.cursor() as cursor:
+            seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
+            self.assertEqual(
+                seqs,
+                [{'table': Person._meta.db_table, 'column': 'id', 'name': 'backends_person_id_seq'}]
+            )
+            cursor.execute('ALTER SEQUENCE backends_person_id_seq RENAME TO pers_seq')
+            seqs = connection.introspection.get_sequences(cursor, Person._meta.db_table)
+            self.assertEqual(
+                seqs,
+                [{'table': Person._meta.db_table, 'column': 'id', 'name': 'pers_seq'}]
+            )

+ 7 - 7
tests/backends/postgresql/tests.py

@@ -44,10 +44,10 @@ class Tests(TestCase):
             # Ensure the database default time zone is different than
             # the time zone in new_connection.settings_dict. We can
             # get the default time zone by reset & show.
-            cursor = new_connection.cursor()
-            cursor.execute("RESET TIMEZONE")
-            cursor.execute("SHOW TIMEZONE")
-            db_default_tz = cursor.fetchone()[0]
+            with new_connection.cursor() as cursor:
+                cursor.execute("RESET TIMEZONE")
+                cursor.execute("SHOW TIMEZONE")
+                db_default_tz = cursor.fetchone()[0]
             new_tz = 'Europe/Paris' if db_default_tz == 'UTC' else 'UTC'
             new_connection.close()
 
@@ -59,12 +59,12 @@ class Tests(TestCase):
             # time zone, run a query and rollback.
             with self.settings(TIME_ZONE=new_tz):
                 new_connection.set_autocommit(False)
-                cursor = new_connection.cursor()
                 new_connection.rollback()
 
                 # Now let's see if the rollback rolled back the SET TIME ZONE.
-                cursor.execute("SHOW TIMEZONE")
-                tz = cursor.fetchone()[0]
+                with new_connection.cursor() as cursor:
+                    cursor.execute("SHOW TIMEZONE")
+                    tz = cursor.fetchone()[0]
                 self.assertEqual(new_tz, tz)
 
         finally:

+ 8 - 8
tests/backends/sqlite/tests.py

@@ -82,11 +82,11 @@ class LastExecutedQueryTest(TestCase):
         # If SQLITE_MAX_VARIABLE_NUMBER (default = 999) has been changed to be
         # greater than SQLITE_MAX_COLUMN (default = 2000), last_executed_query
         # can hit the SQLITE_MAX_COLUMN limit (#26063).
-        cursor = connection.cursor()
-        sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001)
-        params = list(range(2001))
-        # This should not raise an exception.
-        cursor.db.ops.last_executed_query(cursor.cursor, sql, params)
+        with connection.cursor() as cursor:
+            sql = "SELECT MAX(%s)" % ", ".join(["%s"] * 2001)
+            params = list(range(2001))
+            # This should not raise an exception.
+            cursor.db.ops.last_executed_query(cursor.cursor, sql, params)
 
 
 @unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests')
@@ -97,9 +97,9 @@ class EscapingChecks(TestCase):
     """
     def test_parameter_escaping(self):
         # '%s' escaping support for sqlite3 (#13648).
-        cursor = connection.cursor()
-        cursor.execute("select strftime('%s', date('now'))")
-        response = cursor.fetchall()[0][0]
+        with connection.cursor() as cursor:
+            cursor.execute("select strftime('%s', date('now'))")
+            response = cursor.fetchall()[0][0]
         # response should be an non-zero integer
         self.assertTrue(int(response))
 

+ 44 - 43
tests/backends/tests.py

@@ -56,8 +56,8 @@ class LastExecutedQueryTest(TestCase):
         last_executed_query should not raise an exception even if no previous
         query has been run.
         """
-        cursor = connection.cursor()
-        connection.ops.last_executed_query(cursor, '', ())
+        with connection.cursor() as cursor:
+            connection.ops.last_executed_query(cursor, '', ())
 
     def test_debug_sql(self):
         list(Reporter.objects.filter(first_name="test"))
@@ -78,16 +78,16 @@ class ParameterHandlingTest(TestCase):
 
     def test_bad_parameter_count(self):
         "An executemany call with too many/not enough parameters will raise an exception (Refs #12612)"
-        cursor = connection.cursor()
-        query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % (
-            connection.introspection.table_name_converter('backends_square'),
-            connection.ops.quote_name('root'),
-            connection.ops.quote_name('square')
-        ))
-        with self.assertRaises(Exception):
-            cursor.executemany(query, [(1, 2, 3)])
-        with self.assertRaises(Exception):
-            cursor.executemany(query, [(1,)])
+        with connection.cursor() as cursor:
+            query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' % (
+                connection.introspection.table_name_converter('backends_square'),
+                connection.ops.quote_name('root'),
+                connection.ops.quote_name('square')
+            ))
+            with self.assertRaises(Exception):
+                cursor.executemany(query, [(1, 2, 3)])
+            with self.assertRaises(Exception):
+                cursor.executemany(query, [(1,)])
 
 
 class LongNameTest(TransactionTestCase):
@@ -133,9 +133,10 @@ class LongNameTest(TransactionTestCase):
                 'table': VLM._meta.db_table
             },
         ]
-        cursor = connection.cursor()
-        for statement in connection.ops.sql_flush(no_style(), tables, sequences):
-            cursor.execute(statement)
+        sql_list = connection.ops.sql_flush(no_style(), tables, sequences)
+        with connection.cursor() as cursor:
+            for statement in sql_list:
+                cursor.execute(statement)
 
 
 class SequenceResetTest(TestCase):
@@ -146,10 +147,10 @@ class SequenceResetTest(TestCase):
         Post.objects.create(id=10, name='1st post', text='hello world')
 
         # Reset the sequences for the database
-        cursor = connection.cursor()
         commands = connections[DEFAULT_DB_ALIAS].ops.sequence_reset_sql(no_style(), [Post])
-        for sql in commands:
-            cursor.execute(sql)
+        with connection.cursor() as cursor:
+            for sql in commands:
+                cursor.execute(sql)
 
         # If we create a new object now, it should have a PK greater
         # than the PK we specified manually.
@@ -192,14 +193,14 @@ class EscapingChecks(TestCase):
     bare_select_suffix = connection.features.bare_select_suffix
 
     def test_paramless_no_escaping(self):
-        cursor = connection.cursor()
-        cursor.execute("SELECT '%s'" + self.bare_select_suffix)
-        self.assertEqual(cursor.fetchall()[0][0], '%s')
+        with connection.cursor() as cursor:
+            cursor.execute("SELECT '%s'" + self.bare_select_suffix)
+            self.assertEqual(cursor.fetchall()[0][0], '%s')
 
     def test_parameter_escaping(self):
-        cursor = connection.cursor()
-        cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',))
-        self.assertEqual(cursor.fetchall()[0], ('%', '%d'))
+        with connection.cursor() as cursor:
+            cursor.execute("SELECT '%%', %s" + self.bare_select_suffix, ('%d',))
+            self.assertEqual(cursor.fetchall()[0], ('%', '%d'))
 
 
 @override_settings(DEBUG=True)
@@ -215,7 +216,6 @@ class BackendTestCase(TransactionTestCase):
         self.create_squares(args, 'format', True)
 
     def create_squares(self, args, paramstyle, multiple):
-        cursor = connection.cursor()
         opts = Square._meta
         tbl = connection.introspection.table_name_converter(opts.db_table)
         f1 = connection.ops.quote_name(opts.get_field('root').column)
@@ -226,10 +226,11 @@ class BackendTestCase(TransactionTestCase):
             query = 'INSERT INTO %s (%s, %s) VALUES (%%(root)s, %%(square)s)' % (tbl, f1, f2)
         else:
             raise ValueError("unsupported paramstyle in test")
-        if multiple:
-            cursor.executemany(query, args)
-        else:
-            cursor.execute(query, args)
+        with connection.cursor() as cursor:
+            if multiple:
+                cursor.executemany(query, args)
+            else:
+                cursor.execute(query, args)
 
     def test_cursor_executemany(self):
         # Test cursor.executemany #4896
@@ -297,18 +298,18 @@ class BackendTestCase(TransactionTestCase):
         Person(first_name="Clark", last_name="Kent").save()
         opts2 = Person._meta
         f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name')
-        cursor = connection.cursor()
-        cursor.execute(
-            'SELECT %s, %s FROM %s ORDER BY %s' % (
-                qn(f3.column),
-                qn(f4.column),
-                connection.introspection.table_name_converter(opts2.db_table),
-                qn(f3.column),
+        with connection.cursor() as cursor:
+            cursor.execute(
+                'SELECT %s, %s FROM %s ORDER BY %s' % (
+                    qn(f3.column),
+                    qn(f4.column),
+                    connection.introspection.table_name_converter(opts2.db_table),
+                    qn(f3.column),
+                )
             )
-        )
-        self.assertEqual(cursor.fetchone(), ('Clark', 'Kent'))
-        self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')])
-        self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')])
+            self.assertEqual(cursor.fetchone(), ('Clark', 'Kent'))
+            self.assertEqual(list(cursor.fetchmany(2)), [('Jane', 'Doe'), ('John', 'Doe')])
+            self.assertEqual(list(cursor.fetchall()), [('Mary', 'Agnelline'), ('Peter', 'Parker')])
 
     def test_unicode_password(self):
         old_password = connection.settings_dict['PASSWORD']
@@ -344,10 +345,10 @@ class BackendTestCase(TransactionTestCase):
 
     def test_duplicate_table_error(self):
         """ Creating an existing table returns a DatabaseError """
-        cursor = connection.cursor()
         query = 'CREATE TABLE %s (id INTEGER);' % Article._meta.db_table
-        with self.assertRaises(DatabaseError):
-            cursor.execute(query)
+        with connection.cursor() as cursor:
+            with self.assertRaises(DatabaseError):
+                cursor.execute(query)
 
     def test_cursor_contextmanager(self):
         """

+ 4 - 4
tests/db_utils/tests.py

@@ -26,10 +26,10 @@ class DatabaseErrorWrapperTests(TestCase):
 
     @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL test')
     def test_reraising_backend_specific_database_exception(self):
-        cursor = connection.cursor()
-        msg = 'table "X" does not exist'
-        with self.assertRaisesMessage(ProgrammingError, msg) as cm:
-            cursor.execute('DROP TABLE "X"')
+        with connection.cursor() as cursor:
+            msg = 'table "X" does not exist'
+            with self.assertRaisesMessage(ProgrammingError, msg) as cm:
+                cursor.execute('DROP TABLE "X"')
         self.assertNotEqual(type(cm.exception), type(cm.exception.__cause__))
         self.assertIsNotNone(cm.exception.__cause__)
         self.assertIsNotNone(cm.exception.__cause__.pgcode)