123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435 |
- import gc
- from unittest.mock import MagicMock, patch
- from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
- from django.db.backends.base.base import BaseDatabaseWrapper
- from django.test import (
- SimpleTestCase,
- TestCase,
- TransactionTestCase,
- skipUnlessDBFeature,
- )
- from django.test.utils import CaptureQueriesContext, override_settings
- from ..models import Person, Square
class DatabaseWrapperTests(SimpleTestCase):
    """Tests for generic DatabaseWrapper behavior on the default connection."""

    def test_repr(self):
        conn = connections[DEFAULT_DB_ALIAS]
        self.assertEqual(
            repr(conn),
            f"<DatabaseWrapper vendor={connection.vendor!r} alias='default'>",
        )

    def test_initialization_class_attributes(self):
        """
        The "initialization" class attributes like client_class and
        creation_class should be set on the class and reflected in the
        corresponding instance attributes of the instantiated backend.
        """
        conn = connections[DEFAULT_DB_ALIAS]
        conn_class = type(conn)
        attr_names = [
            ("client_class", "client"),
            ("creation_class", "creation"),
            ("features_class", "features"),
            ("introspection_class", "introspection"),
            ("ops_class", "ops"),
            ("validation_class", "validation"),
        ]
        for class_attr_name, instance_attr_name in attr_names:
            # subTest() makes a failure report which attribute pair is broken.
            with self.subTest(class_attr_name=class_attr_name):
                class_attr_value = getattr(conn_class, class_attr_name)
                self.assertIsNotNone(class_attr_value)
                instance_attr_value = getattr(conn, instance_attr_name)
                self.assertIsInstance(instance_attr_value, class_attr_value)

    def test_initialization_display_name(self):
        """
        BaseDatabaseWrapper.display_name is the "unknown" placeholder, and the
        active backend replaces it with a real name.
        """
        self.assertEqual(BaseDatabaseWrapper.display_name, "unknown")
        self.assertNotEqual(connection.display_name, "unknown")

    def test_get_database_version(self):
        """The base class requires subclasses to implement the method."""
        # Skip the real __init__ so the abstract base can be instantiated
        # without a settings dict.
        with patch.object(BaseDatabaseWrapper, "__init__", return_value=None):
            msg = (
                "subclasses of BaseDatabaseWrapper may require a "
                "get_database_version() method."
            )
            with self.assertRaisesMessage(NotImplementedError, msg):
                BaseDatabaseWrapper().get_database_version()

    def test_check_database_version_supported_with_none_as_database_version(self):
        """No error is raised when the backend declares no minimum version."""
        with patch.object(connection.features, "minimum_database_version", None):
            connection.check_database_version_supported()

    def test_release_memory_without_garbage_collection(self):
        """
        Using and closing a copied connection, including its
        DatabaseErrorWrapper, must not leave anything for the cyclic garbage
        collector. With DEBUG_SAVEALL, every unreachable object found by
        gc.collect() is appended to gc.garbage, so any new entry fails the test.
        """
        # Restore the collector state afterwards. Cleanups run in LIFO order,
        # so they are registered such that the debug flags are restored and
        # this test's saved objects are discarded *before* automatic
        # collection is re-enabled. Otherwise a collection triggered in
        # between would still run with DEBUG_SAVEALL.
        if gc.isenabled():
            self.addCleanup(gc.enable)

        # Disable automatic garbage collection to control when it's triggered,
        # then run a full collection cycle so that garbage created by earlier
        # code isn't attributed to this test.
        gc.disable()
        gc.collect()

        # gc.collect() never empties gc.garbage, so only entries appended from
        # this point on are inspected (and removed again on cleanup).
        garbage_start = len(gc.garbage)

        def discard_saved_garbage():
            del gc.garbage[garbage_start:]

        self.addCleanup(discard_saved_garbage)
        self.addCleanup(gc.set_debug, gc.get_debug())

        # The garbage list isn't automatically populated to avoid CPU overhead,
        # so debugging needs to be enabled to track all unreachable items and
        # have them stored in `gc.garbage`.
        gc.set_debug(gc.DEBUG_SAVEALL)

        # Create a new connection that will be closed during the test, and also
        # ensure that a `DatabaseErrorWrapper` is created for this connection.
        test_connection = connection.copy()
        with test_connection.wrap_database_errors:
            self.assertEqual(test_connection.queries, [])

        # Close the connection and remove references to it. Anything related
        # to it that isn't freed by reference counting becomes cyclic garbage.
        test_connection.close()
        test_connection = None

        # Enforce garbage collection to populate `gc.garbage` for inspection.
        gc.collect()
        self.assertEqual(gc.garbage[garbage_start:], [])
class DatabaseWrapperLoggingTests(TransactionTestCase):
    """DEBUG logging of transaction control statements by the backend."""

    available_apps = ["backends"]

    def _assert_debug_log(self, output, sql):
        """
        Assert that the captured log line ``output`` is a DEBUG record from
        ``django.db.backends`` for ``sql`` on the default alias with
        ``args=None``, preceded by a parenthesized number with three decimals
        (presumably the execution time).

        ``sql`` is interpolated into the pattern unescaped, so it must not
        contain regex metacharacters.
        """
        # Dots are escaped so they only match literal dots (in the logger name
        # and the number's decimal point), not arbitrary characters.
        self.assertRegex(
            output,
            r"DEBUG:django\.db\.backends:\(\d+\.\d{3}\) "
            rf"{sql}; args=None; alias={DEFAULT_DB_ALIAS}",
        )

    @override_settings(DEBUG=True)
    def test_commit_debug_log(self):
        conn = connections[DEFAULT_DB_ALIAS]
        with CaptureQueriesContext(conn):
            with self.assertLogs("django.db.backends", "DEBUG") as cm:
                with transaction.atomic():
                    Person.objects.create(first_name="first", last_name="last")
            # BEGIN is expected three entries from the end, i.e. exactly one
            # query runs between it and the final COMMIT.
            self.assertGreaterEqual(len(conn.queries_log), 3)
            self.assertEqual(conn.queries_log[-3]["sql"], "BEGIN")
            self._assert_debug_log(cm.output[0], "BEGIN")
            self.assertEqual(conn.queries_log[-1]["sql"], "COMMIT")
            self._assert_debug_log(cm.output[-1], "COMMIT")

    @override_settings(DEBUG=True)
    def test_rollback_debug_log(self):
        conn = connections[DEFAULT_DB_ALIAS]
        with CaptureQueriesContext(conn):
            with self.assertLogs("django.db.backends", "DEBUG") as cm:
                # Raising inside atomic() forces the transaction to roll back.
                with self.assertRaises(Exception), transaction.atomic():
                    Person.objects.create(first_name="first", last_name="last")
                    raise Exception("Force rollback")
            self.assertEqual(conn.queries_log[-1]["sql"], "ROLLBACK")
            self._assert_debug_log(cm.output[-1], "ROLLBACK")

    def test_no_logs_without_debug(self):
        """With DEBUG off, nothing is logged and no queries are recorded."""
        with self.assertNoLogs("django.db.backends", "DEBUG"):
            with self.assertRaises(Exception), transaction.atomic():
                Person.objects.create(first_name="first", last_name="last")
                raise Exception("Force rollback")
            conn = connections[DEFAULT_DB_ALIAS]
            self.assertEqual(len(conn.queries_log), 0)
class ExecuteWrapperTests(TestCase):
    """Tests for the connection.execute_wrapper() query-interception hook."""

    @staticmethod
    def call_execute(connection, params=None):
        """
        Run a trivial SELECT through cursor.execute() on ``connection``,
        selecting a placeholder when ``params`` is given.
        """
        ret_val = "1" if params is None else "%s"
        sql = "SELECT " + ret_val + connection.features.bare_select_suffix
        with connection.cursor() as cursor:
            cursor.execute(sql, params)

    @staticmethod
    def call_executemany(connection, params=None):
        """
        Run a DELETE that matches no rows through cursor.executemany() on
        ``connection``. Defaults to three single-value parameter tuples.
        """
        # executemany() must use an update query. Make sure it does nothing
        # by putting a false condition in the WHERE clause.
        sql = f"DELETE FROM {Square._meta.db_table} WHERE 0=1 AND 0=%s"
        if params is None:
            params = [(i,) for i in range(3)]
        with connection.cursor() as cursor:
            cursor.executemany(sql, params)

    @staticmethod
    def mock_wrapper():
        """Return a recording wrapper that forwards every call to ``execute``."""
        return MagicMock(side_effect=lambda execute, *args: execute(*args))

    def test_wrapper_invoked(self):
        wrapper = self.mock_wrapper()
        with connection.execute_wrapper(wrapper):
            self.call_execute(connection)
        self.assertTrue(wrapper.called)
        (_, sql, params, many, context), _ = wrapper.call_args
        self.assertIn("SELECT", sql)
        self.assertIsNone(params)
        self.assertIs(many, False)
        self.assertEqual(context["connection"], connection)

    def test_wrapper_invoked_many(self):
        wrapper = self.mock_wrapper()
        with connection.execute_wrapper(wrapper):
            self.call_executemany(connection)
        self.assertTrue(wrapper.called)
        (_, sql, param_list, many, context), _ = wrapper.call_args
        self.assertIn("DELETE", sql)
        self.assertIsInstance(param_list, (list, tuple))
        self.assertIs(many, True)
        self.assertEqual(context["connection"], connection)

    def test_database_queried(self):
        """A pass-through wrapper still lets queries reach the database."""
        wrapper = self.mock_wrapper()
        with connection.execute_wrapper(wrapper):
            with connection.cursor() as cursor:
                sql = "SELECT 17" + connection.features.bare_select_suffix
                cursor.execute(sql)
                seventeen = cursor.fetchall()
                self.assertEqual(list(seventeen), [(17,)])
            self.call_executemany(connection)

    def test_nested_wrapper_invoked(self):
        """Both the outer and the inner wrapper see every query."""
        outer_wrapper = self.mock_wrapper()
        inner_wrapper = self.mock_wrapper()
        with (
            connection.execute_wrapper(outer_wrapper),
            connection.execute_wrapper(inner_wrapper),
        ):
            self.call_execute(connection)
            self.assertEqual(outer_wrapper.call_count, 1)
            self.assertEqual(inner_wrapper.call_count, 1)
            self.call_executemany(connection)
            self.assertEqual(outer_wrapper.call_count, 2)
            self.assertEqual(inner_wrapper.call_count, 2)

    def test_outer_wrapper_blocks(self):
        """A wrapper that doesn't call execute() stops the inner wrappers."""

        def blocker(*args):
            pass

        wrapper = self.mock_wrapper()
        c = connection  # This alias shortens the next line.
        with (
            c.execute_wrapper(wrapper),
            c.execute_wrapper(blocker),
            c.execute_wrapper(wrapper),
        ):
            with c.cursor() as cursor:
                cursor.execute("The database never sees this")
                self.assertEqual(wrapper.call_count, 1)
                cursor.executemany("The database never sees this %s", [("either",)])
                self.assertEqual(wrapper.call_count, 2)

    def test_wrapper_gets_sql(self):
        wrapper = self.mock_wrapper()
        sql = "SELECT 'aloha'" + connection.features.bare_select_suffix
        with connection.execute_wrapper(wrapper), connection.cursor() as cursor:
            cursor.execute(sql)
        (_, reported_sql, _, _, _), _ = wrapper.call_args
        self.assertEqual(reported_sql, sql)

    def test_wrapper_connection_specific(self):
        """A wrapper installed on one alias doesn't affect other connections."""
        wrapper = self.mock_wrapper()
        with connections["other"].execute_wrapper(wrapper):
            self.assertEqual(connections["other"].execute_wrappers, [wrapper])
            self.call_execute(connection)
        self.assertFalse(wrapper.called)
        self.assertEqual(connection.execute_wrappers, [])
        self.assertEqual(connections["other"].execute_wrappers, [])

    def test_wrapper_debug(self):
        """SQL rewritten by a wrapper is what the debug cursor records."""

        def wrap_with_comment(execute, sql, params, many, context):
            return execute(f"/* My comment */ {sql}", params, many, context)

        with CaptureQueriesContext(connection) as ctx:
            with connection.execute_wrapper(wrap_with_comment):
                list(Person.objects.all())
        last_query = ctx.captured_queries[-1]["sql"]
        self.assertTrue(last_query.startswith("/* My comment */"))
class ConnectionHealthChecksTests(SimpleTestCase):
    """
    CONN_HEALTH_CHECKS behavior for a persistent (CONN_MAX_AGE=None) default
    connection, exercised by simulating request boundaries with
    close_if_unusable_or_obsolete().
    """

    databases = {"default"}

    def setUp(self):
        # Every test needs a connection newly created from the patched
        # settings, so start closed and close again afterwards. The default
        # connection is used for convenience.
        connection.close()
        self.addCleanup(connection.close)

    def patch_settings_dict(self, conn_health_checks):
        overrides = {
            **connection.settings_dict,
            "CONN_MAX_AGE": None,
            "CONN_HEALTH_CHECKS": conn_health_checks,
        }
        self.settings_dict_patcher = patch.dict(connection.settings_dict, overrides)
        self.settings_dict_patcher.start()
        self.addCleanup(self.settings_dict_patcher.stop)

    def run_query(self):
        query = "SELECT 42" + connection.features.bare_select_suffix
        with connection.cursor() as cursor:
            cursor.execute(query)

    def _forbid_health_check(self):
        # Any is_usable() call made while this patch is active fails the test.
        return patch.object(connection, "is_usable", side_effect=AssertionError)

    def _fail_health_check(self):
        # Report the connection as unusable without actually breaking it.
        return patch.object(connection, "is_usable", return_value=False)

    def _run_query_in_transaction(self):
        # Mirror an outermost atomic block: leave autocommit, query, commit,
        # then restore autocommit.
        connection.set_autocommit(False)
        self.run_query()
        connection.commit()
        connection.set_autocommit(True)

    def _first_use_without_health_check(self, first_use):
        """
        Call ``first_use()`` on the closed connection while forbidding health
        checks, and return the DB-API connection it opened. Newly created
        connections are considered healthy without performing the check.
        """
        self.assertIsNone(connection.connection)
        with self._forbid_health_check():
            first_use()
        return connection.connection

    @skipUnlessDBFeature("test_db_allows_multiple_connections")
    def test_health_checks_enabled(self):
        self.patch_settings_dict(conn_health_checks=True)
        old_connection = self._first_use_without_health_check(self.run_query)
        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        self.assertIs(old_connection, connection.connection)

        # Simulate connection health check failing.
        with self._fail_health_check() as mocked_is_usable:
            self.run_query()
            new_connection = connection.connection
            # A new connection is established.
            self.assertIsNot(new_connection, old_connection)
            # Only one health check per "request" is performed, so the next
            # query carries on even though the (mocked) check would fail; the
            # real connection is healthy, so the query succeeds.
            self.run_query()
            self.assertIs(new_connection, connection.connection)
        self.assertEqual(mocked_is_usable.call_count, 1)

        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        # With real health checks succeeding, the connection keeps being reused.
        self.run_query()
        self.run_query()
        self.assertIs(new_connection, connection.connection)

    @skipUnlessDBFeature("test_db_allows_multiple_connections")
    def test_health_checks_enabled_errors_occurred(self):
        self.patch_settings_dict(conn_health_checks=True)
        old_connection = self._first_use_without_health_check(self.run_query)
        # Simulate errors_occurred.
        connection.errors_occurred = True
        # Simulate request_started (the connection is healthy).
        connection.close_if_unusable_or_obsolete()
        # Persistent connections are enabled.
        self.assertIs(old_connection, connection.connection)
        # The check done by close_if_unusable_or_obsolete() is the only one in
        # this "request"; running a query doesn't trigger another.
        with self._forbid_health_check():
            self.run_query()

    @skipUnlessDBFeature("test_db_allows_multiple_connections")
    def test_health_checks_disabled(self):
        self.patch_settings_dict(conn_health_checks=False)
        old_connection = self._first_use_without_health_check(self.run_query)
        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        # Persistent connections are enabled, so the connection isn't closed.
        self.assertIs(old_connection, connection.connection)
        # Health checks are not performed at all.
        with self._forbid_health_check():
            self.run_query()
            # No health check ran and the connection is unchanged.
            self.assertIs(old_connection, connection.connection)
            self.run_query()
        # The next query in the same "request" doesn't change it either.
        self.assertIs(old_connection, connection.connection)

    @skipUnlessDBFeature("test_db_allows_multiple_connections")
    def test_set_autocommit_health_checks_enabled(self):
        self.patch_settings_dict(conn_health_checks=True)
        # The first use is an outermost atomic block; a newly created
        # connection still skips the health check.
        old_connection = self._first_use_without_health_check(
            self._run_query_in_transaction
        )
        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        # Persistent connections are enabled.
        self.assertIs(old_connection, connection.connection)

        # Simulate connection health check failing.
        with self._fail_health_check() as mocked_is_usable:
            # Simulate outermost atomic block: changing autocommit for
            # a connection triggers the check and a reconnect.
            connection.set_autocommit(False)
            new_connection = connection.connection
            self.assertIsNot(new_connection, old_connection)
            # Only one health check per "request" is performed, so the query
            # carries on even though the (mocked) check would fail; the real
            # connection is healthy, so it succeeds.
            self.run_query()
            connection.commit()
            connection.set_autocommit(True)
            # The connection is unchanged.
            self.assertIs(new_connection, connection.connection)
        self.assertEqual(mocked_is_usable.call_count, 1)

        # Simulate request_finished.
        connection.close_if_unusable_or_obsolete()
        # With real health checks succeeding, the connection keeps being reused.
        self._run_query_in_transaction()
        self.assertIs(new_connection, connection.connection)
class MultiDatabaseTests(TestCase):
    """init_connection_state() behavior on every configured database alias."""

    databases = {"default", "other"}

    def test_multi_database_init_connection_state_called_once(self):
        # For each alias, a second init_connection_state() call must not add
        # any further check_database_version_supported() calls.
        for alias in self.databases:
            with self.subTest(database=alias):
                conn = connections[alias]
                with (
                    patch.object(conn, "commit", return_value=None),
                    patch.object(conn, "check_database_version_supported") as check,
                ):
                    conn.init_connection_state()
                    calls_after_first_init = len(check.mock_calls)
                    conn.init_connection_state()
                    self.assertEqual(len(check.mock_calls), calls_after_first_init)
|