test_base.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. import gc
  2. from unittest.mock import MagicMock, patch
  3. from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
  4. from django.db.backends.base.base import BaseDatabaseWrapper
  5. from django.test import (
  6. SimpleTestCase,
  7. TestCase,
  8. TransactionTestCase,
  9. skipUnlessDBFeature,
  10. )
  11. from django.test.utils import CaptureQueriesContext, override_settings
  12. from ..models import Person, Square
  13. class DatabaseWrapperTests(SimpleTestCase):
  14. def test_repr(self):
  15. conn = connections[DEFAULT_DB_ALIAS]
  16. self.assertEqual(
  17. repr(conn),
  18. f"<DatabaseWrapper vendor={connection.vendor!r} alias='default'>",
  19. )
  20. def test_initialization_class_attributes(self):
  21. """
  22. The "initialization" class attributes like client_class and
  23. creation_class should be set on the class and reflected in the
  24. corresponding instance attributes of the instantiated backend.
  25. """
  26. conn = connections[DEFAULT_DB_ALIAS]
  27. conn_class = type(conn)
  28. attr_names = [
  29. ("client_class", "client"),
  30. ("creation_class", "creation"),
  31. ("features_class", "features"),
  32. ("introspection_class", "introspection"),
  33. ("ops_class", "ops"),
  34. ("validation_class", "validation"),
  35. ]
  36. for class_attr_name, instance_attr_name in attr_names:
  37. class_attr_value = getattr(conn_class, class_attr_name)
  38. self.assertIsNotNone(class_attr_value)
  39. instance_attr_value = getattr(conn, instance_attr_name)
  40. self.assertIsInstance(instance_attr_value, class_attr_value)
  41. def test_initialization_display_name(self):
  42. self.assertEqual(BaseDatabaseWrapper.display_name, "unknown")
  43. self.assertNotEqual(connection.display_name, "unknown")
  44. def test_get_database_version(self):
  45. with patch.object(BaseDatabaseWrapper, "__init__", return_value=None):
  46. msg = (
  47. "subclasses of BaseDatabaseWrapper may require a "
  48. "get_database_version() method."
  49. )
  50. with self.assertRaisesMessage(NotImplementedError, msg):
  51. BaseDatabaseWrapper().get_database_version()
  52. def test_check_database_version_supported_with_none_as_database_version(self):
  53. with patch.object(connection.features, "minimum_database_version", None):
  54. connection.check_database_version_supported()
  55. def test_release_memory_without_garbage_collection(self):
  56. # Schedule the restore of the garbage collection settings.
  57. self.addCleanup(gc.set_debug, 0)
  58. self.addCleanup(gc.enable)
  59. # Disable automatic garbage collection to control when it's triggered,
  60. # then run a full collection cycle to ensure `gc.garbage` is empty.
  61. gc.disable()
  62. gc.collect()
  63. # The garbage list isn't automatically populated to avoid CPU overhead,
  64. # so debugging needs to be enabled to track all unreachable items and
  65. # have them stored in `gc.garbage`.
  66. gc.set_debug(gc.DEBUG_SAVEALL)
  67. # Create a new connection that will be closed during the test, and also
  68. # ensure that a `DatabaseErrorWrapper` is created for this connection.
  69. test_connection = connection.copy()
  70. with test_connection.wrap_database_errors:
  71. self.assertEqual(test_connection.queries, [])
  72. # Close the connection and remove references to it. This will mark all
  73. # objects related to the connection as garbage to be collected.
  74. test_connection.close()
  75. test_connection = None
  76. # Enforce garbage collection to populate `gc.garbage` for inspection.
  77. gc.collect()
  78. self.assertEqual(gc.garbage, [])
  79. class DatabaseWrapperLoggingTests(TransactionTestCase):
  80. available_apps = ["backends"]
  81. @override_settings(DEBUG=True)
  82. def test_commit_debug_log(self):
  83. conn = connections[DEFAULT_DB_ALIAS]
  84. with CaptureQueriesContext(conn):
  85. with self.assertLogs("django.db.backends", "DEBUG") as cm:
  86. with transaction.atomic():
  87. Person.objects.create(first_name="first", last_name="last")
  88. self.assertGreaterEqual(len(conn.queries_log), 3)
  89. self.assertEqual(conn.queries_log[-3]["sql"], "BEGIN")
  90. self.assertRegex(
  91. cm.output[0],
  92. r"DEBUG:django.db.backends:\(\d+.\d{3}\) "
  93. rf"BEGIN; args=None; alias={DEFAULT_DB_ALIAS}",
  94. )
  95. self.assertEqual(conn.queries_log[-1]["sql"], "COMMIT")
  96. self.assertRegex(
  97. cm.output[-1],
  98. r"DEBUG:django.db.backends:\(\d+.\d{3}\) "
  99. rf"COMMIT; args=None; alias={DEFAULT_DB_ALIAS}",
  100. )
  101. @override_settings(DEBUG=True)
  102. def test_rollback_debug_log(self):
  103. conn = connections[DEFAULT_DB_ALIAS]
  104. with CaptureQueriesContext(conn):
  105. with self.assertLogs("django.db.backends", "DEBUG") as cm:
  106. with self.assertRaises(Exception), transaction.atomic():
  107. Person.objects.create(first_name="first", last_name="last")
  108. raise Exception("Force rollback")
  109. self.assertEqual(conn.queries_log[-1]["sql"], "ROLLBACK")
  110. self.assertRegex(
  111. cm.output[-1],
  112. r"DEBUG:django.db.backends:\(\d+.\d{3}\) "
  113. rf"ROLLBACK; args=None; alias={DEFAULT_DB_ALIAS}",
  114. )
  115. def test_no_logs_without_debug(self):
  116. with self.assertNoLogs("django.db.backends", "DEBUG"):
  117. with self.assertRaises(Exception), transaction.atomic():
  118. Person.objects.create(first_name="first", last_name="last")
  119. raise Exception("Force rollback")
  120. conn = connections[DEFAULT_DB_ALIAS]
  121. self.assertEqual(len(conn.queries_log), 0)
  122. class ExecuteWrapperTests(TestCase):
  123. @staticmethod
  124. def call_execute(connection, params=None):
  125. ret_val = "1" if params is None else "%s"
  126. sql = "SELECT " + ret_val + connection.features.bare_select_suffix
  127. with connection.cursor() as cursor:
  128. cursor.execute(sql, params)
  129. def call_executemany(self, connection, params=None):
  130. # executemany() must use an update query. Make sure it does nothing
  131. # by putting a false condition in the WHERE clause.
  132. sql = "DELETE FROM {} WHERE 0=1 AND 0=%s".format(Square._meta.db_table)
  133. if params is None:
  134. params = [(i,) for i in range(3)]
  135. with connection.cursor() as cursor:
  136. cursor.executemany(sql, params)
  137. @staticmethod
  138. def mock_wrapper():
  139. return MagicMock(side_effect=lambda execute, *args: execute(*args))
  140. def test_wrapper_invoked(self):
  141. wrapper = self.mock_wrapper()
  142. with connection.execute_wrapper(wrapper):
  143. self.call_execute(connection)
  144. self.assertTrue(wrapper.called)
  145. (_, sql, params, many, context), _ = wrapper.call_args
  146. self.assertIn("SELECT", sql)
  147. self.assertIsNone(params)
  148. self.assertIs(many, False)
  149. self.assertEqual(context["connection"], connection)
  150. def test_wrapper_invoked_many(self):
  151. wrapper = self.mock_wrapper()
  152. with connection.execute_wrapper(wrapper):
  153. self.call_executemany(connection)
  154. self.assertTrue(wrapper.called)
  155. (_, sql, param_list, many, context), _ = wrapper.call_args
  156. self.assertIn("DELETE", sql)
  157. self.assertIsInstance(param_list, (list, tuple))
  158. self.assertIs(many, True)
  159. self.assertEqual(context["connection"], connection)
  160. def test_database_queried(self):
  161. wrapper = self.mock_wrapper()
  162. with connection.execute_wrapper(wrapper):
  163. with connection.cursor() as cursor:
  164. sql = "SELECT 17" + connection.features.bare_select_suffix
  165. cursor.execute(sql)
  166. seventeen = cursor.fetchall()
  167. self.assertEqual(list(seventeen), [(17,)])
  168. self.call_executemany(connection)
  169. def test_nested_wrapper_invoked(self):
  170. outer_wrapper = self.mock_wrapper()
  171. inner_wrapper = self.mock_wrapper()
  172. with (
  173. connection.execute_wrapper(outer_wrapper),
  174. connection.execute_wrapper(inner_wrapper),
  175. ):
  176. self.call_execute(connection)
  177. self.assertEqual(inner_wrapper.call_count, 1)
  178. self.call_executemany(connection)
  179. self.assertEqual(inner_wrapper.call_count, 2)
  180. def test_outer_wrapper_blocks(self):
  181. def blocker(*args):
  182. pass
  183. wrapper = self.mock_wrapper()
  184. c = connection # This alias shortens the next line.
  185. with (
  186. c.execute_wrapper(wrapper),
  187. c.execute_wrapper(blocker),
  188. c.execute_wrapper(wrapper),
  189. ):
  190. with c.cursor() as cursor:
  191. cursor.execute("The database never sees this")
  192. self.assertEqual(wrapper.call_count, 1)
  193. cursor.executemany("The database never sees this %s", [("either",)])
  194. self.assertEqual(wrapper.call_count, 2)
  195. def test_wrapper_gets_sql(self):
  196. wrapper = self.mock_wrapper()
  197. sql = "SELECT 'aloha'" + connection.features.bare_select_suffix
  198. with connection.execute_wrapper(wrapper), connection.cursor() as cursor:
  199. cursor.execute(sql)
  200. (_, reported_sql, _, _, _), _ = wrapper.call_args
  201. self.assertEqual(reported_sql, sql)
  202. def test_wrapper_connection_specific(self):
  203. wrapper = self.mock_wrapper()
  204. with connections["other"].execute_wrapper(wrapper):
  205. self.assertEqual(connections["other"].execute_wrappers, [wrapper])
  206. self.call_execute(connection)
  207. self.assertFalse(wrapper.called)
  208. self.assertEqual(connection.execute_wrappers, [])
  209. self.assertEqual(connections["other"].execute_wrappers, [])
  210. def test_wrapper_debug(self):
  211. def wrap_with_comment(execute, sql, params, many, context):
  212. return execute(f"/* My comment */ {sql}", params, many, context)
  213. with CaptureQueriesContext(connection) as ctx:
  214. with connection.execute_wrapper(wrap_with_comment):
  215. list(Person.objects.all())
  216. last_query = ctx.captured_queries[-1]["sql"]
  217. self.assertTrue(last_query.startswith("/* My comment */"))
  218. class ConnectionHealthChecksTests(SimpleTestCase):
  219. databases = {"default"}
  220. def setUp(self):
  221. # All test cases here need newly configured and created connections.
  222. # Use the default db connection for convenience.
  223. connection.close()
  224. self.addCleanup(connection.close)
  225. def patch_settings_dict(self, conn_health_checks):
  226. self.settings_dict_patcher = patch.dict(
  227. connection.settings_dict,
  228. {
  229. **connection.settings_dict,
  230. "CONN_MAX_AGE": None,
  231. "CONN_HEALTH_CHECKS": conn_health_checks,
  232. },
  233. )
  234. self.settings_dict_patcher.start()
  235. self.addCleanup(self.settings_dict_patcher.stop)
  236. def run_query(self):
  237. with connection.cursor() as cursor:
  238. cursor.execute("SELECT 42" + connection.features.bare_select_suffix)
  239. @skipUnlessDBFeature("test_db_allows_multiple_connections")
  240. def test_health_checks_enabled(self):
  241. self.patch_settings_dict(conn_health_checks=True)
  242. self.assertIsNone(connection.connection)
  243. # Newly created connections are considered healthy without performing
  244. # the health check.
  245. with patch.object(connection, "is_usable", side_effect=AssertionError):
  246. self.run_query()
  247. old_connection = connection.connection
  248. # Simulate request_finished.
  249. connection.close_if_unusable_or_obsolete()
  250. self.assertIs(old_connection, connection.connection)
  251. # Simulate connection health check failing.
  252. with patch.object(
  253. connection, "is_usable", return_value=False
  254. ) as mocked_is_usable:
  255. self.run_query()
  256. new_connection = connection.connection
  257. # A new connection is established.
  258. self.assertIsNot(new_connection, old_connection)
  259. # Only one health check per "request" is performed, so the next
  260. # query will carry on even if the health check fails. Next query
  261. # succeeds because the real connection is healthy and only the
  262. # health check failure is mocked.
  263. self.run_query()
  264. self.assertIs(new_connection, connection.connection)
  265. self.assertEqual(mocked_is_usable.call_count, 1)
  266. # Simulate request_finished.
  267. connection.close_if_unusable_or_obsolete()
  268. # The underlying connection is being reused further with health checks
  269. # succeeding.
  270. self.run_query()
  271. self.run_query()
  272. self.assertIs(new_connection, connection.connection)
  273. @skipUnlessDBFeature("test_db_allows_multiple_connections")
  274. def test_health_checks_enabled_errors_occurred(self):
  275. self.patch_settings_dict(conn_health_checks=True)
  276. self.assertIsNone(connection.connection)
  277. # Newly created connections are considered healthy without performing
  278. # the health check.
  279. with patch.object(connection, "is_usable", side_effect=AssertionError):
  280. self.run_query()
  281. old_connection = connection.connection
  282. # Simulate errors_occurred.
  283. connection.errors_occurred = True
  284. # Simulate request_started (the connection is healthy).
  285. connection.close_if_unusable_or_obsolete()
  286. # Persistent connections are enabled.
  287. self.assertIs(old_connection, connection.connection)
  288. # No additional health checks after the one in
  289. # close_if_unusable_or_obsolete() are executed during this "request"
  290. # when running queries.
  291. with patch.object(connection, "is_usable", side_effect=AssertionError):
  292. self.run_query()
  293. @skipUnlessDBFeature("test_db_allows_multiple_connections")
  294. def test_health_checks_disabled(self):
  295. self.patch_settings_dict(conn_health_checks=False)
  296. self.assertIsNone(connection.connection)
  297. # Newly created connections are considered healthy without performing
  298. # the health check.
  299. with patch.object(connection, "is_usable", side_effect=AssertionError):
  300. self.run_query()
  301. old_connection = connection.connection
  302. # Simulate request_finished.
  303. connection.close_if_unusable_or_obsolete()
  304. # Persistent connections are enabled (connection is not).
  305. self.assertIs(old_connection, connection.connection)
  306. # Health checks are not performed.
  307. with patch.object(connection, "is_usable", side_effect=AssertionError):
  308. self.run_query()
  309. # Health check wasn't performed and the connection is unchanged.
  310. self.assertIs(old_connection, connection.connection)
  311. self.run_query()
  312. # The connection is unchanged after the next query either during
  313. # the current "request".
  314. self.assertIs(old_connection, connection.connection)
  315. @skipUnlessDBFeature("test_db_allows_multiple_connections")
  316. def test_set_autocommit_health_checks_enabled(self):
  317. self.patch_settings_dict(conn_health_checks=True)
  318. self.assertIsNone(connection.connection)
  319. # Newly created connections are considered healthy without performing
  320. # the health check.
  321. with patch.object(connection, "is_usable", side_effect=AssertionError):
  322. # Simulate outermost atomic block: changing autocommit for
  323. # a connection.
  324. connection.set_autocommit(False)
  325. self.run_query()
  326. connection.commit()
  327. connection.set_autocommit(True)
  328. old_connection = connection.connection
  329. # Simulate request_finished.
  330. connection.close_if_unusable_or_obsolete()
  331. # Persistent connections are enabled.
  332. self.assertIs(old_connection, connection.connection)
  333. # Simulate connection health check failing.
  334. with patch.object(
  335. connection, "is_usable", return_value=False
  336. ) as mocked_is_usable:
  337. # Simulate outermost atomic block: changing autocommit for
  338. # a connection.
  339. connection.set_autocommit(False)
  340. new_connection = connection.connection
  341. self.assertIsNot(new_connection, old_connection)
  342. # Only one health check per "request" is performed, so a query will
  343. # carry on even if the health check fails. This query succeeds
  344. # because the real connection is healthy and only the health check
  345. # failure is mocked.
  346. self.run_query()
  347. connection.commit()
  348. connection.set_autocommit(True)
  349. # The connection is unchanged.
  350. self.assertIs(new_connection, connection.connection)
  351. self.assertEqual(mocked_is_usable.call_count, 1)
  352. # Simulate request_finished.
  353. connection.close_if_unusable_or_obsolete()
  354. # The underlying connection is being reused further with health checks
  355. # succeeding.
  356. connection.set_autocommit(False)
  357. self.run_query()
  358. connection.commit()
  359. connection.set_autocommit(True)
  360. self.assertIs(new_connection, connection.connection)
  361. class MultiDatabaseTests(TestCase):
  362. databases = {"default", "other"}
  363. def test_multi_database_init_connection_state_called_once(self):
  364. for db in self.databases:
  365. with self.subTest(database=db):
  366. with patch.object(connections[db], "commit", return_value=None):
  367. with patch.object(
  368. connections[db],
  369. "check_database_version_supported",
  370. ) as mocked_check_database_version_supported:
  371. connections[db].init_connection_state()
  372. after_first_calls = len(
  373. mocked_check_database_version_supported.mock_calls
  374. )
  375. connections[db].init_connection_state()
  376. self.assertEqual(
  377. len(mocked_check_database_version_supported.mock_calls),
  378. after_first_calls,
  379. )