test_creation.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. import copy
  2. import datetime
  3. import os
  4. from unittest import mock
  5. from django.db import DEFAULT_DB_ALIAS, connection, connections
  6. from django.db.backends.base.creation import TEST_DATABASE_PREFIX, BaseDatabaseCreation
  7. from django.test import SimpleTestCase, TransactionTestCase
  8. from django.test.utils import override_settings
  9. from django.utils.deprecation import RemovedInDjango70Warning
  10. from ..models import (
  11. CircularA,
  12. CircularB,
  13. Object,
  14. ObjectReference,
  15. ObjectSelfReference,
  16. SchoolBus,
  17. SchoolClass,
  18. )
  def get_connection_copy():
      """
      Return a copy of the default database connection that tests may tweak.

      django.db.connection can't be used directly because changing it would
      modify the default connection itself. Only settings_dict is deep-copied
      here; copy.copy() is shallow, so other attributes of the returned object
      (e.g. features) are shared with the default connection unless the
      wrapper customizes copying.
      """
      default_connection = connections[DEFAULT_DB_ALIAS]
      connection_copy = copy.copy(default_connection)
      connection_copy.settings_dict = copy.deepcopy(default_connection.settings_dict)
      return connection_copy
  class TestDbSignatureTests(SimpleTestCase):
      """Check the test database name reported by test_db_signature()."""

      @staticmethod
      def _signature_db_name(test_connection):
          # Element 3 of the signature is the value every test below compares
          # against the expected test database name.
          signature = BaseDatabaseCreation(test_connection).test_db_signature()
          return signature[3]

      def test_default_name(self):
          # With no TEST NAME, the production NAME gets TEST_DATABASE_PREFIX.
          prod_name = "hodor"
          conn = get_connection_copy()
          conn.settings_dict.update(NAME=prod_name, TEST={"NAME": None})
          self.assertEqual(
              self._signature_db_name(conn), TEST_DATABASE_PREFIX + prod_name
          )

      def test_custom_test_name(self):
          # An explicit TEST NAME is reported as-is.
          test_name = "hodor"
          conn = get_connection_copy()
          conn.settings_dict["TEST"] = {"NAME": test_name}
          self.assertEqual(self._signature_db_name(conn), test_name)

      def test_custom_test_name_with_test_prefix(self):
          # A TEST NAME that already starts with TEST_DATABASE_PREFIX is
          # reported unchanged, not prefixed a second time.
          test_name = f"{TEST_DATABASE_PREFIX}hodor"
          conn = get_connection_copy()
          conn.settings_dict["TEST"] = {"NAME": test_name}
          self.assertEqual(self._signature_db_name(conn), test_name)
  @override_settings(INSTALLED_APPS=["backends.base.app_unmigrated"])
  @mock.patch.object(connection, "ensure_connection")
  @mock.patch.object(connection, "prepare_database")
  @mock.patch(
      "django.db.migrations.recorder.MigrationRecorder.has_table", return_value=False
  )
  @mock.patch("django.core.management.commands.migrate.Command.sync_apps")
  class TestDbCreationTests(SimpleTestCase):
      """
      create_test_db()/destroy_test_db() on a copy of the default connection,
      with the backend-specific _create_test_db()/_destroy_test_db() mocked.

      Class-level patches only wrap methods whose names start with "test", so
      the private helpers below are left undecorated. Each test receives the
      class patches after its own method-level mocks: the sync_apps mock
      first, then the remaining class patches collected in *mocked_objects.
      """

      available_apps = ["backends.base.app_unmigrated"]

      def _prepare_creation(self, **test_settings):
          """
          Return ``(creation, old_database_name)`` for a fresh connection copy.

          Keyword arguments, if any, are written into the copy's TEST settings
          before the creation object is built. ``old_database_name`` is the
          copy's NAME, which destroy_test_db() needs afterwards.
          """
          test_connection = get_connection_copy()
          if test_settings:
              test_connection.settings_dict["TEST"].update(test_settings)
          creation = test_connection.creation_class(test_connection)
          if connection.vendor == "oracle":
              # Don't close connection on Oracle.
              creation.connection.close = mock.Mock()
          return creation, test_connection.settings_dict["NAME"]

      @staticmethod
      def _create(creation):
          # Call create_test_db() with the backend's _create_test_db() mocked.
          with mock.patch.object(creation, "_create_test_db"):
              creation.create_test_db(verbosity=0, autoclobber=True)

      @staticmethod
      def _destroy(creation, old_database_name):
          # Call destroy_test_db() with the backend's _destroy_test_db() mocked.
          with mock.patch.object(creation, "_destroy_test_db"):
              creation.destroy_test_db(old_database_name, verbosity=0)

      @mock.patch("django.db.migrations.executor.MigrationExecutor.migrate")
      def test_migrate_test_setting_false(
          self, mocked_migrate, mocked_sync_apps, *mocked_objects
      ):
          creation, old_database_name = self._prepare_creation(MIGRATE=False)
          try:
              self._create(creation)
              # migrate() is still invoked, but with no targets and an empty
              # plan, so no migration is applied.
              mocked_migrate.assert_called()
              migrate_call = mocked_migrate.call_args
              self.assertEqual(migrate_call.args, ([],))
              self.assertEqual(migrate_call.kwargs["plan"], [])
              # The unmigrated app is synced instead.
              mocked_sync_apps.assert_called()
              self.assertEqual(mocked_sync_apps.call_args.args[1], {"app_unmigrated"})
          finally:
              self._destroy(creation, old_database_name)

      @mock.patch("django.db.migrations.executor.MigrationRecorder.ensure_schema")
      def test_migrate_test_setting_false_ensure_schema(
          self,
          mocked_ensure_schema,
          mocked_sync_apps,
          *mocked_objects,
      ):
          creation, old_database_name = self._prepare_creation(MIGRATE=False)
          try:
              self._create(creation)
              # The django_migrations table is not created.
              mocked_ensure_schema.assert_not_called()
              # The unmigrated app is synced.
              mocked_sync_apps.assert_called()
              self.assertEqual(mocked_sync_apps.call_args.args[1], {"app_unmigrated"})
          finally:
              self._destroy(creation, old_database_name)

      @mock.patch("django.db.migrations.executor.MigrationExecutor.migrate")
      def test_migrate_test_setting_true(
          self, mocked_migrate, mocked_sync_apps, *mocked_objects
      ):
          creation, old_database_name = self._prepare_creation(MIGRATE=True)
          try:
              self._create(creation)
              # Migrations run: the app's initial migration is the target and
              # the plan holds exactly one step.
              mocked_migrate.assert_called()
              migrate_call = mocked_migrate.call_args
              self.assertEqual(
                  migrate_call.args, ([("app_unmigrated", "0001_initial")],)
              )
              self.assertEqual(len(migrate_call.kwargs["plan"]), 1)
              # The app is not synced.
              mocked_sync_apps.assert_not_called()
          finally:
              self._destroy(creation, old_database_name)

      @mock.patch.dict(os.environ, {"RUNNING_DJANGOS_TEST_SUITE": ""})
      @mock.patch("django.db.migrations.executor.MigrationExecutor.migrate")
      @mock.patch.object(BaseDatabaseCreation, "mark_expected_failures_and_skips")
      def test_mark_expected_failures_and_skips_call(
          self, mark_expected_failures_and_skips, *mocked_objects
      ):
          """
          mark_expected_failures_and_skips() isn't called unless
          RUNNING_DJANGOS_TEST_SUITE is 'true'.
          """
          creation, old_database_name = self._prepare_creation()
          try:
              self._create(creation)
              mark_expected_failures_and_skips.assert_not_called()
          finally:
              self._destroy(creation, old_database_name)

      @mock.patch("django.db.migrations.executor.MigrationExecutor.migrate")
      @mock.patch.object(BaseDatabaseCreation, "serialize_db_to_string")
      def test_serialize_deprecation(self, serialize_db_to_string, *mocked_objects):
          creation, old_database_name = self._prepare_creation()
          msg = (
              "DatabaseCreation.create_test_db(serialize) is deprecated. Call "
              "DatabaseCreation.serialize_test_db() once all test databases are set up "
              "instead if you need fixtures persistence between tests."
          )
          # Passing serialize at all (True or False) warns, and the warning is
          # attributed to this file; only serialize=True serializes the data.
          for serialize in (True, False):
              serialize_db_to_string.reset_mock()
              try:
                  with (
                      self.assertWarnsMessage(RemovedInDjango70Warning, msg) as ctx,
                      mock.patch.object(creation, "_create_test_db"),
                  ):
                      creation.create_test_db(verbosity=0, serialize=serialize)
                  self.assertEqual(ctx.filename, __file__)
                  if serialize:
                      serialize_db_to_string.assert_called_once_with()
                  else:
                      serialize_db_to_string.assert_not_called()
              finally:
                  self._destroy(creation, old_database_name)
  class TestDeserializeDbFromString(TransactionTestCase):
      """Round trips through serialize_db_to_string()/deserialize_db_from_string()."""

      available_apps = ["backends"]

      @staticmethod
      def _serialize_backends():
          """
          Return connection.creation.serialize_db_to_string() output while the
          backends app is reported as migrated, because only migrated apps are
          serialized.
          """
          with mock.patch("django.db.migrations.loader.MigrationLoader") as loader:
              loader.return_value.migrated_apps = {"backends"}
              return connection.creation.serialize_db_to_string()

      def test_circular_reference(self):
          # deserialize_db_from_string() handles circular references.
          data = """
          [
              {
                  "model": "backends.object",
                  "pk": 1,
                  "fields": {"obj_ref": 1, "related_objects": []}
              },
              {
                  "model": "backends.objectreference",
                  "pk": 1,
                  "fields": {"obj": 1}
              }
          ]
          """
          connection.creation.deserialize_db_from_string(data)
          stored_object = Object.objects.get()
          stored_reference = ObjectReference.objects.get()
          self.assertEqual(stored_object.obj_ref, stored_reference)
          self.assertEqual(stored_reference.obj, stored_object)

      def test_self_reference(self):
          # Rows of a self-referencing model survive a serialize/deserialize
          # round trip with their links intact.
          first = ObjectSelfReference.objects.create(key="X")
          second = ObjectSelfReference.objects.create(key="Y", obj=first)
          first.obj = second
          first.save()
          data = self._serialize_backends()
          ObjectSelfReference.objects.all().delete()
          connection.creation.deserialize_db_from_string(data)
          first = ObjectSelfReference.objects.get(key="X")
          second = ObjectSelfReference.objects.get(key="Y")
          self.assertEqual(first.obj, second)
          self.assertEqual(second.obj, first)

      def test_circular_reference_with_natural_key(self):
          # Circular references between models with natural keys survive a
          # serialize/deserialize round trip.
          circular_a = CircularA.objects.create(key="A")
          circular_b = CircularB.objects.create(key="B", obj=circular_a)
          circular_a.obj = circular_b
          circular_a.save()
          data = self._serialize_backends()
          CircularA.objects.all().delete()
          CircularB.objects.all().delete()
          connection.creation.deserialize_db_from_string(data)
          circular_a = CircularA.objects.get()
          circular_b = CircularB.objects.get()
          self.assertEqual(circular_a.obj, circular_b)
          self.assertEqual(circular_b.obj, circular_a)

      def test_serialize_db_to_string_base_manager(self):
          SchoolClass.objects.create(year=1000, last_updated=datetime.datetime.now())
          data = self._serialize_backends()
          for fragment in ('"model": "backends.schoolclass"', '"year": 1000'):
              self.assertIn(fragment, data)

      def test_serialize_db_to_string_base_manager_with_prefetch_related(self):
          school_class = SchoolClass.objects.create(
              year=2000, last_updated=datetime.datetime.now()
          )
          bus = SchoolBus.objects.create(number=1)
          bus.schoolclasses.add(school_class)
          data = self._serialize_backends()
          expected_fragments = (
              '"model": "backends.schoolbus"',
              '"model": "backends.schoolclass"',
              f'"schoolclasses": [{school_class.pk}]',
          )
          for fragment in expected_fragments:
              self.assertIn(fragment, data)
  class SkipTestClass:
      """Dummy target of the "skip test class" entry used by TestMarkTests."""

      def skip_function(self):
          """Do nothing; exists only so the dummy class has a method."""
  def skip_test_function():
      """Dummy target of the "skip test function" entry used by TestMarkTests."""
  def expected_failure_test_function():
      """Dummy target of the expected-failure entry used by TestMarkTests."""
  class TestMarkTests(SimpleTestCase):
      """Tests for BaseDatabaseCreation.mark_expected_failures_and_skips()."""

      def test_mark_expected_failures_and_skips(self):
          """
          Entries in the connection's django_test_expected_failures and
          django_test_skips features are applied to the dummy module-level
          targets defined in this file.
          """
          test_connection = get_connection_copy()
          creation = BaseDatabaseCreation(test_connection)
          # get_connection_copy() only deep-copies settings_dict; the features
          # object is shared with the default connection. Setting attributes on
          # it would leak these fake expectations into the default connection
          # for the rest of the run, so replace the copy's features with a
          # mock instead.
          creation.connection.features = mock.Mock(
              django_test_expected_failures={
                  "backends.base.test_creation.expected_failure_test_function",
              },
              django_test_skips={
                  "skip test class": {
                      "backends.base.test_creation.SkipTestClass",
                  },
                  "skip test function": {
                      "backends.base.test_creation.skip_test_function",
                  },
              },
          )
          creation.mark_expected_failures_and_skips()
          self.assertIs(
              expected_failure_test_function.__unittest_expecting_failure__,
              True,
          )
          self.assertIs(SkipTestClass.__unittest_skip__, True)
          self.assertEqual(
              SkipTestClass.__unittest_skip_why__,
              "skip test class",
          )
          self.assertIs(skip_test_function.__unittest_skip__, True)
          self.assertEqual(
              skip_test_function.__unittest_skip_why__,
              "skip test function",
          )