test_base.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. from unittest.mock import MagicMock, patch
  2. from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
  3. from django.db.backends.base.base import BaseDatabaseWrapper
  4. from django.test import (
  5. SimpleTestCase,
  6. TestCase,
  7. TransactionTestCase,
  8. skipUnlessDBFeature,
  9. )
  10. from django.test.utils import CaptureQueriesContext, override_settings
  11. from ..models import Person, Square
  12. class DatabaseWrapperTests(SimpleTestCase):
  13. def test_repr(self):
  14. conn = connections[DEFAULT_DB_ALIAS]
  15. self.assertEqual(
  16. repr(conn),
  17. f"<DatabaseWrapper vendor={connection.vendor!r} alias='default'>",
  18. )
  19. def test_initialization_class_attributes(self):
  20. """
  21. The "initialization" class attributes like client_class and
  22. creation_class should be set on the class and reflected in the
  23. corresponding instance attributes of the instantiated backend.
  24. """
  25. conn = connections[DEFAULT_DB_ALIAS]
  26. conn_class = type(conn)
  27. attr_names = [
  28. ("client_class", "client"),
  29. ("creation_class", "creation"),
  30. ("features_class", "features"),
  31. ("introspection_class", "introspection"),
  32. ("ops_class", "ops"),
  33. ("validation_class", "validation"),
  34. ]
  35. for class_attr_name, instance_attr_name in attr_names:
  36. class_attr_value = getattr(conn_class, class_attr_name)
  37. self.assertIsNotNone(class_attr_value)
  38. instance_attr_value = getattr(conn, instance_attr_name)
  39. self.assertIsInstance(instance_attr_value, class_attr_value)
  40. def test_initialization_display_name(self):
  41. self.assertEqual(BaseDatabaseWrapper.display_name, "unknown")
  42. self.assertNotEqual(connection.display_name, "unknown")
  43. def test_get_database_version(self):
  44. with patch.object(BaseDatabaseWrapper, "__init__", return_value=None):
  45. msg = (
  46. "subclasses of BaseDatabaseWrapper may require a "
  47. "get_database_version() method."
  48. )
  49. with self.assertRaisesMessage(NotImplementedError, msg):
  50. BaseDatabaseWrapper().get_database_version()
  51. def test_check_database_version_supported_with_none_as_database_version(self):
  52. with patch.object(connection.features, "minimum_database_version", None):
  53. connection.check_database_version_supported()
  54. class DatabaseWrapperLoggingTests(TransactionTestCase):
  55. available_apps = ["backends"]
  56. @override_settings(DEBUG=True)
  57. def test_commit_debug_log(self):
  58. conn = connections[DEFAULT_DB_ALIAS]
  59. with CaptureQueriesContext(conn):
  60. with self.assertLogs("django.db.backends", "DEBUG") as cm:
  61. with transaction.atomic():
  62. Person.objects.create(first_name="first", last_name="last")
  63. self.assertGreaterEqual(len(conn.queries_log), 3)
  64. self.assertEqual(conn.queries_log[-3]["sql"], "BEGIN")
  65. self.assertRegex(
  66. cm.output[0],
  67. r"DEBUG:django.db.backends:\(\d+.\d{3}\) "
  68. rf"BEGIN; args=None; alias={DEFAULT_DB_ALIAS}",
  69. )
  70. self.assertEqual(conn.queries_log[-1]["sql"], "COMMIT")
  71. self.assertRegex(
  72. cm.output[-1],
  73. r"DEBUG:django.db.backends:\(\d+.\d{3}\) "
  74. rf"COMMIT; args=None; alias={DEFAULT_DB_ALIAS}",
  75. )
  76. @override_settings(DEBUG=True)
  77. def test_rollback_debug_log(self):
  78. conn = connections[DEFAULT_DB_ALIAS]
  79. with CaptureQueriesContext(conn):
  80. with self.assertLogs("django.db.backends", "DEBUG") as cm:
  81. with self.assertRaises(Exception), transaction.atomic():
  82. Person.objects.create(first_name="first", last_name="last")
  83. raise Exception("Force rollback")
  84. self.assertEqual(conn.queries_log[-1]["sql"], "ROLLBACK")
  85. self.assertRegex(
  86. cm.output[-1],
  87. r"DEBUG:django.db.backends:\(\d+.\d{3}\) "
  88. rf"ROLLBACK; args=None; alias={DEFAULT_DB_ALIAS}",
  89. )
  90. def test_no_logs_without_debug(self):
  91. with self.assertNoLogs("django.db.backends", "DEBUG"):
  92. with self.assertRaises(Exception), transaction.atomic():
  93. Person.objects.create(first_name="first", last_name="last")
  94. raise Exception("Force rollback")
  95. conn = connections[DEFAULT_DB_ALIAS]
  96. self.assertEqual(len(conn.queries_log), 0)
  97. class ExecuteWrapperTests(TestCase):
  98. @staticmethod
  99. def call_execute(connection, params=None):
  100. ret_val = "1" if params is None else "%s"
  101. sql = "SELECT " + ret_val + connection.features.bare_select_suffix
  102. with connection.cursor() as cursor:
  103. cursor.execute(sql, params)
  104. def call_executemany(self, connection, params=None):
  105. # executemany() must use an update query. Make sure it does nothing
  106. # by putting a false condition in the WHERE clause.
  107. sql = "DELETE FROM {} WHERE 0=1 AND 0=%s".format(Square._meta.db_table)
  108. if params is None:
  109. params = [(i,) for i in range(3)]
  110. with connection.cursor() as cursor:
  111. cursor.executemany(sql, params)
  112. @staticmethod
  113. def mock_wrapper():
  114. return MagicMock(side_effect=lambda execute, *args: execute(*args))
  115. def test_wrapper_invoked(self):
  116. wrapper = self.mock_wrapper()
  117. with connection.execute_wrapper(wrapper):
  118. self.call_execute(connection)
  119. self.assertTrue(wrapper.called)
  120. (_, sql, params, many, context), _ = wrapper.call_args
  121. self.assertIn("SELECT", sql)
  122. self.assertIsNone(params)
  123. self.assertIs(many, False)
  124. self.assertEqual(context["connection"], connection)
  125. def test_wrapper_invoked_many(self):
  126. wrapper = self.mock_wrapper()
  127. with connection.execute_wrapper(wrapper):
  128. self.call_executemany(connection)
  129. self.assertTrue(wrapper.called)
  130. (_, sql, param_list, many, context), _ = wrapper.call_args
  131. self.assertIn("DELETE", sql)
  132. self.assertIsInstance(param_list, (list, tuple))
  133. self.assertIs(many, True)
  134. self.assertEqual(context["connection"], connection)
  135. def test_database_queried(self):
  136. wrapper = self.mock_wrapper()
  137. with connection.execute_wrapper(wrapper):
  138. with connection.cursor() as cursor:
  139. sql = "SELECT 17" + connection.features.bare_select_suffix
  140. cursor.execute(sql)
  141. seventeen = cursor.fetchall()
  142. self.assertEqual(list(seventeen), [(17,)])
  143. self.call_executemany(connection)
  144. def test_nested_wrapper_invoked(self):
  145. outer_wrapper = self.mock_wrapper()
  146. inner_wrapper = self.mock_wrapper()
  147. with connection.execute_wrapper(outer_wrapper), connection.execute_wrapper(
  148. inner_wrapper
  149. ):
  150. self.call_execute(connection)
  151. self.assertEqual(inner_wrapper.call_count, 1)
  152. self.call_executemany(connection)
  153. self.assertEqual(inner_wrapper.call_count, 2)
  154. def test_outer_wrapper_blocks(self):
  155. def blocker(*args):
  156. pass
  157. wrapper = self.mock_wrapper()
  158. c = connection # This alias shortens the next line.
  159. with c.execute_wrapper(wrapper), c.execute_wrapper(blocker), c.execute_wrapper(
  160. wrapper
  161. ):
  162. with c.cursor() as cursor:
  163. cursor.execute("The database never sees this")
  164. self.assertEqual(wrapper.call_count, 1)
  165. cursor.executemany("The database never sees this %s", [("either",)])
  166. self.assertEqual(wrapper.call_count, 2)
  167. def test_wrapper_gets_sql(self):
  168. wrapper = self.mock_wrapper()
  169. sql = "SELECT 'aloha'" + connection.features.bare_select_suffix
  170. with connection.execute_wrapper(wrapper), connection.cursor() as cursor:
  171. cursor.execute(sql)
  172. (_, reported_sql, _, _, _), _ = wrapper.call_args
  173. self.assertEqual(reported_sql, sql)
  174. def test_wrapper_connection_specific(self):
  175. wrapper = self.mock_wrapper()
  176. with connections["other"].execute_wrapper(wrapper):
  177. self.assertEqual(connections["other"].execute_wrappers, [wrapper])
  178. self.call_execute(connection)
  179. self.assertFalse(wrapper.called)
  180. self.assertEqual(connection.execute_wrappers, [])
  181. self.assertEqual(connections["other"].execute_wrappers, [])
  182. class ConnectionHealthChecksTests(SimpleTestCase):
  183. databases = {"default"}
  184. def setUp(self):
  185. # All test cases here need newly configured and created connections.
  186. # Use the default db connection for convenience.
  187. connection.close()
  188. self.addCleanup(connection.close)
  189. def patch_settings_dict(self, conn_health_checks):
  190. self.settings_dict_patcher = patch.dict(
  191. connection.settings_dict,
  192. {
  193. **connection.settings_dict,
  194. "CONN_MAX_AGE": None,
  195. "CONN_HEALTH_CHECKS": conn_health_checks,
  196. },
  197. )
  198. self.settings_dict_patcher.start()
  199. self.addCleanup(self.settings_dict_patcher.stop)
  200. def run_query(self):
  201. with connection.cursor() as cursor:
  202. cursor.execute("SELECT 42" + connection.features.bare_select_suffix)
  203. @skipUnlessDBFeature("test_db_allows_multiple_connections")
  204. def test_health_checks_enabled(self):
  205. self.patch_settings_dict(conn_health_checks=True)
  206. self.assertIsNone(connection.connection)
  207. # Newly created connections are considered healthy without performing
  208. # the health check.
  209. with patch.object(connection, "is_usable", side_effect=AssertionError):
  210. self.run_query()
  211. old_connection = connection.connection
  212. # Simulate request_finished.
  213. connection.close_if_unusable_or_obsolete()
  214. self.assertIs(old_connection, connection.connection)
  215. # Simulate connection health check failing.
  216. with patch.object(
  217. connection, "is_usable", return_value=False
  218. ) as mocked_is_usable:
  219. self.run_query()
  220. new_connection = connection.connection
  221. # A new connection is established.
  222. self.assertIsNot(new_connection, old_connection)
  223. # Only one health check per "request" is performed, so the next
  224. # query will carry on even if the health check fails. Next query
  225. # succeeds because the real connection is healthy and only the
  226. # health check failure is mocked.
  227. self.run_query()
  228. self.assertIs(new_connection, connection.connection)
  229. self.assertEqual(mocked_is_usable.call_count, 1)
  230. # Simulate request_finished.
  231. connection.close_if_unusable_or_obsolete()
  232. # The underlying connection is being reused further with health checks
  233. # succeeding.
  234. self.run_query()
  235. self.run_query()
  236. self.assertIs(new_connection, connection.connection)
  237. @skipUnlessDBFeature("test_db_allows_multiple_connections")
  238. def test_health_checks_enabled_errors_occurred(self):
  239. self.patch_settings_dict(conn_health_checks=True)
  240. self.assertIsNone(connection.connection)
  241. # Newly created connections are considered healthy without performing
  242. # the health check.
  243. with patch.object(connection, "is_usable", side_effect=AssertionError):
  244. self.run_query()
  245. old_connection = connection.connection
  246. # Simulate errors_occurred.
  247. connection.errors_occurred = True
  248. # Simulate request_started (the connection is healthy).
  249. connection.close_if_unusable_or_obsolete()
  250. # Persistent connections are enabled.
  251. self.assertIs(old_connection, connection.connection)
  252. # No additional health checks after the one in
  253. # close_if_unusable_or_obsolete() are executed during this "request"
  254. # when running queries.
  255. with patch.object(connection, "is_usable", side_effect=AssertionError):
  256. self.run_query()
  257. @skipUnlessDBFeature("test_db_allows_multiple_connections")
  258. def test_health_checks_disabled(self):
  259. self.patch_settings_dict(conn_health_checks=False)
  260. self.assertIsNone(connection.connection)
  261. # Newly created connections are considered healthy without performing
  262. # the health check.
  263. with patch.object(connection, "is_usable", side_effect=AssertionError):
  264. self.run_query()
  265. old_connection = connection.connection
  266. # Simulate request_finished.
  267. connection.close_if_unusable_or_obsolete()
  268. # Persistent connections are enabled (connection is not).
  269. self.assertIs(old_connection, connection.connection)
  270. # Health checks are not performed.
  271. with patch.object(connection, "is_usable", side_effect=AssertionError):
  272. self.run_query()
  273. # Health check wasn't performed and the connection is unchanged.
  274. self.assertIs(old_connection, connection.connection)
  275. self.run_query()
  276. # The connection is unchanged after the next query either during
  277. # the current "request".
  278. self.assertIs(old_connection, connection.connection)
  279. @skipUnlessDBFeature("test_db_allows_multiple_connections")
  280. def test_set_autocommit_health_checks_enabled(self):
  281. self.patch_settings_dict(conn_health_checks=True)
  282. self.assertIsNone(connection.connection)
  283. # Newly created connections are considered healthy without performing
  284. # the health check.
  285. with patch.object(connection, "is_usable", side_effect=AssertionError):
  286. # Simulate outermost atomic block: changing autocommit for
  287. # a connection.
  288. connection.set_autocommit(False)
  289. self.run_query()
  290. connection.commit()
  291. connection.set_autocommit(True)
  292. old_connection = connection.connection
  293. # Simulate request_finished.
  294. connection.close_if_unusable_or_obsolete()
  295. # Persistent connections are enabled.
  296. self.assertIs(old_connection, connection.connection)
  297. # Simulate connection health check failing.
  298. with patch.object(
  299. connection, "is_usable", return_value=False
  300. ) as mocked_is_usable:
  301. # Simulate outermost atomic block: changing autocommit for
  302. # a connection.
  303. connection.set_autocommit(False)
  304. new_connection = connection.connection
  305. self.assertIsNot(new_connection, old_connection)
  306. # Only one health check per "request" is performed, so a query will
  307. # carry on even if the health check fails. This query succeeds
  308. # because the real connection is healthy and only the health check
  309. # failure is mocked.
  310. self.run_query()
  311. connection.commit()
  312. connection.set_autocommit(True)
  313. # The connection is unchanged.
  314. self.assertIs(new_connection, connection.connection)
  315. self.assertEqual(mocked_is_usable.call_count, 1)
  316. # Simulate request_finished.
  317. connection.close_if_unusable_or_obsolete()
  318. # The underlying connection is being reused further with health checks
  319. # succeeding.
  320. connection.set_autocommit(False)
  321. self.run_query()
  322. connection.commit()
  323. connection.set_autocommit(True)
  324. self.assertIs(new_connection, connection.connection)
  325. class MultiDatabaseTests(TestCase):
  326. databases = {"default", "other"}
  327. def test_multi_database_init_connection_state_called_once(self):
  328. for db in self.databases:
  329. with self.subTest(database=db):
  330. with patch.object(connections[db], "commit", return_value=None):
  331. with patch.object(
  332. connections[db],
  333. "check_database_version_supported",
  334. ) as mocked_check_database_version_supported:
  335. connections[db].init_connection_state()
  336. after_first_calls = len(
  337. mocked_check_database_version_supported.mock_calls
  338. )
  339. connections[db].init_connection_state()
  340. self.assertEqual(
  341. len(mocked_check_database_version_supported.mock_calls),
  342. after_first_calls,
  343. )