test_operations.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. import unittest
  2. from unittest import mock
  3. from migrations.test_base import OperationTestBase
  4. from django.db import NotSupportedError, connection
  5. from django.db.migrations.state import ProjectState
  6. from django.db.models import Index
  7. from django.db.utils import ProgrammingError
  8. from django.test import modify_settings, override_settings, skipUnlessDBFeature
  9. from django.test.utils import CaptureQueriesContext
  10. from . import PostgreSQLTestCase
  11. try:
  12. from django.contrib.postgres.indexes import BrinIndex, BTreeIndex
  13. from django.contrib.postgres.operations import (
  14. AddIndexConcurrently, BloomExtension, CreateCollation, CreateExtension,
  15. RemoveCollation, RemoveIndexConcurrently,
  16. )
  17. except ImportError:
  18. pass
  @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
  @modify_settings(INSTALLED_APPS={'append': 'migrations'})
  class AddIndexConcurrentlyTests(OperationTestBase):
      """Checks for AddIndexConcurrently, applied to the 'Pony' test model."""
      app_label = 'test_add_concurrently'

      # An atomic schema editor must be rejected with NotSupportedError.
      def test_requires_atomic_false(self):
          old_state = self.set_up_test_model(self.app_label)
          updated_state = old_state.clone()
          add_index = AddIndexConcurrently(
              'Pony',
              Index(fields=['pink'], name='pony_pink_idx'),
          )
          expected_error = (
              'The AddIndexConcurrently operation cannot be executed inside '
              'a transaction (set atomic = False on the migration).'
          )
          with self.assertRaisesMessage(NotSupportedError, expected_error), \
                  connection.schema_editor(atomic=True) as editor:
              add_index.database_forwards(self.app_label, editor, old_state, updated_state)

      # Full round trip: description, state change, forwards, backwards and
      # deconstruction of the operation for a plain Index.
      def test_add(self):
          old_state = self.set_up_test_model(self.app_label, index=False)
          pony_table = '{}_pony'.format(self.app_label)
          pink_index = Index(fields=['pink'], name='pony_pink_idx')
          updated_state = old_state.clone()
          add_index = AddIndexConcurrently('Pony', pink_index)
          self.assertEqual(
              add_index.describe(),
              'Concurrently create index pony_pink_idx on field(s) pink of '
              'model Pony'
          )
          # Only the in-memory state changes here; the database is untouched.
          add_index.state_forwards(self.app_label, updated_state)
          pony_options = updated_state.models[self.app_label, 'pony'].options
          self.assertEqual(len(pony_options['indexes']), 1)
          self.assertIndexNotExists(pony_table, ['pink'])
          # Apply the operation to the database.
          with connection.schema_editor(atomic=False) as editor:
              add_index.database_forwards(self.app_label, editor, old_state, updated_state)
          self.assertIndexExists(pony_table, ['pink'])
          # Undo it again.
          with connection.schema_editor(atomic=False) as editor:
              add_index.database_backwards(self.app_label, editor, updated_state, old_state)
          self.assertIndexNotExists(pony_table, ['pink'])
          # The operation deconstructs to keyword arguments only.
          op_name, op_args, op_kwargs = add_index.deconstruct()
          self.assertEqual(op_name, 'AddIndexConcurrently')
          self.assertEqual(op_args, [])
          self.assertEqual(op_kwargs, {'model_name': 'Pony', 'index': pink_index})

      # Same forwards/backwards cycle, using a BRIN index instead of Index.
      def test_add_other_index_type(self):
          old_state = self.set_up_test_model(self.app_label, index=False)
          pony_table = '{}_pony'.format(self.app_label)
          updated_state = old_state.clone()
          add_index = AddIndexConcurrently(
              'Pony',
              BrinIndex(fields=['pink'], name='pony_pink_brin_idx'),
          )
          self.assertIndexNotExists(pony_table, ['pink'])
          # Apply the operation to the database.
          with connection.schema_editor(atomic=False) as editor:
              add_index.database_forwards(self.app_label, editor, old_state, updated_state)
          self.assertIndexExists(pony_table, ['pink'], index_type='brin')
          # Undo it again.
          with connection.schema_editor(atomic=False) as editor:
              add_index.database_backwards(self.app_label, editor, updated_state, old_state)
          self.assertIndexNotExists(pony_table, ['pink'])

      # Same cycle with a BTreeIndex that carries a storage option (fillfactor).
      def test_add_with_options(self):
          old_state = self.set_up_test_model(self.app_label, index=False)
          pony_table = '{}_pony'.format(self.app_label)
          updated_state = old_state.clone()
          btree_index = BTreeIndex(fields=['pink'], name='pony_pink_btree_idx', fillfactor=70)
          add_index = AddIndexConcurrently('Pony', btree_index)
          self.assertIndexNotExists(pony_table, ['pink'])
          # Apply the operation to the database.
          with connection.schema_editor(atomic=False) as editor:
              add_index.database_forwards(self.app_label, editor, old_state, updated_state)
          self.assertIndexExists(pony_table, ['pink'], index_type='btree')
          # Undo it again.
          with connection.schema_editor(atomic=False) as editor:
              add_index.database_backwards(self.app_label, editor, updated_state, old_state)
          self.assertIndexNotExists(pony_table, ['pink'])
  @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
  @modify_settings(INSTALLED_APPS={'append': 'migrations'})
  class RemoveIndexConcurrentlyTests(OperationTestBase):
      """Checks for RemoveIndexConcurrently, applied to the 'Pony' test model."""
      app_label = 'test_rm_concurrently'

      # An atomic schema editor must be rejected with NotSupportedError.
      def test_requires_atomic_false(self):
          old_state = self.set_up_test_model(self.app_label, index=True)
          updated_state = old_state.clone()
          remove_index = RemoveIndexConcurrently('Pony', 'pony_pink_idx')
          expected_error = (
              'The RemoveIndexConcurrently operation cannot be executed inside '
              'a transaction (set atomic = False on the migration).'
          )
          with self.assertRaisesMessage(NotSupportedError, expected_error), \
                  connection.schema_editor(atomic=True) as editor:
              remove_index.database_forwards(self.app_label, editor, old_state, updated_state)

      # Full round trip: description, state change, forwards, backwards and
      # deconstruction of the operation.
      def test_remove(self):
          old_state = self.set_up_test_model(self.app_label, index=True)
          pony_table = '{}_pony'.format(self.app_label)
          self.assertTableExists(pony_table)
          updated_state = old_state.clone()
          remove_index = RemoveIndexConcurrently('Pony', 'pony_pink_idx')
          self.assertEqual(
              remove_index.describe(),
              'Concurrently remove index pony_pink_idx from Pony',
          )
          # Only the in-memory state changes here; the index is still in the
          # database, as the next assertion confirms.
          remove_index.state_forwards(self.app_label, updated_state)
          pony_options = updated_state.models[self.app_label, 'pony'].options
          self.assertEqual(len(pony_options['indexes']), 0)
          self.assertIndexExists(pony_table, ['pink'])
          # Drop the index from the database.
          with connection.schema_editor(atomic=False) as editor:
              remove_index.database_forwards(self.app_label, editor, old_state, updated_state)
          self.assertIndexNotExists(pony_table, ['pink'])
          # Reversing the operation brings the index back.
          with connection.schema_editor(atomic=False) as editor:
              remove_index.database_backwards(self.app_label, editor, updated_state, old_state)
          self.assertIndexExists(pony_table, ['pink'])
          # The operation deconstructs to keyword arguments only.
          op_name, op_args, op_kwargs = remove_index.deconstruct()
          self.assertEqual(op_name, 'RemoveIndexConcurrently')
          self.assertEqual(op_args, [])
          self.assertEqual(op_kwargs, {'model_name': 'Pony', 'name': 'pony_pink_idx'})
  class NoMigrationRouter:
      """
      Database router that refuses every migration.

      The tests in this module install it through DATABASE_ROUTERS to check
      that operations issue no queries when migrating is disallowed.
      """

      def allow_migrate(self, db, app_label, **hints):
          """Return False regardless of the database alias, app label or hints."""
          return False
  @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
  class CreateExtensionTests(PostgreSQLTestCase):
      """Checks for CreateExtension and BloomExtension via the SQL they issue."""
      app_label = 'test_allow_create_extention'

      # With a router that denies migrations, neither direction issues a query.
      @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
      def test_no_allow_migrate(self):
          create_extension = CreateExtension('tablefunc')
          old_state = ProjectState()
          updated_state = old_state.clone()
          # Forwards: the extension must not be created.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              create_extension.database_forwards(self.app_label, editor, old_state, updated_state)
          self.assertEqual(len(queries), 0)
          # Backwards: nothing to drop either.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              create_extension.database_backwards(self.app_label, editor, updated_state, old_state)
          self.assertEqual(len(queries), 0)

      # Without a restricting router the extension is created and then dropped.
      def test_allow_migrate(self):
          create_extension = CreateExtension('tablefunc')
          self.assertEqual(create_extension.migration_name_fragment, 'create_extension_tablefunc')
          old_state = ProjectState()
          updated_state = old_state.clone()
          # Forwards: the CREATE statement is expected as the second query.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              create_extension.database_forwards(self.app_label, editor, old_state, updated_state)
          self.assertEqual(len(queries), 4)
          self.assertIn('CREATE EXTENSION IF NOT EXISTS', queries[1]['sql'])
          # Backwards: the DROP statement is expected as the second query.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              create_extension.database_backwards(self.app_label, editor, updated_state, old_state)
          self.assertEqual(len(queries), 2)
          self.assertIn('DROP EXTENSION IF EXISTS', queries[1]['sql'])

      # An extension that already exists must not be created a second time.
      def test_create_existing_extension(self):
          bloom_extension = BloomExtension()
          self.assertEqual(bloom_extension.migration_name_fragment, 'create_extension_bloom')
          old_state = ProjectState()
          updated_state = old_state.clone()
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              bloom_extension.database_forwards(self.app_label, editor, old_state, updated_state)
          self.assertEqual(len(queries), 3)
          self.assertIn('SELECT', queries[0]['sql'])

      # Reversing when the extension is absent must only run a single SELECT.
      def test_drop_nonexistent_extension(self):
          create_extension = CreateExtension('tablefunc')
          old_state = ProjectState()
          updated_state = old_state.clone()
          # NOTE(review): the states are passed as (old, new) here, whereas the
          # other reversals in this class pass (new, old). Both are empty
          # project states - confirm whether the order is intentional.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              create_extension.database_backwards(self.app_label, editor, old_state, updated_state)
          self.assertEqual(len(queries), 1)
          self.assertIn('SELECT', queries[0]['sql'])
  @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
  class CreateCollationTests(PostgreSQLTestCase):
      """Checks for the CreateCollation operation via the SQL it issues."""
      app_label = 'test_allow_create_collation'

      # With a router that denies migrations, neither direction issues a query.
      @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
      def test_no_allow_migrate(self):
          create_collation = CreateCollation('C_test', locale='C')
          old_state = ProjectState()
          updated_state = old_state.clone()
          # Forwards: the collation must not be created.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              create_collation.database_forwards(self.app_label, editor, old_state, updated_state)
          self.assertEqual(len(queries), 0)
          # Backwards: nothing to drop either.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              create_collation.database_backwards(self.app_label, editor, updated_state, old_state)
          self.assertEqual(len(queries), 0)

      # Round trip: naming, creation, duplicate creation, reversal and
      # deconstruction.
      def test_create(self):
          create_collation = CreateCollation('C_test', locale='C')
          self.assertEqual(create_collation.migration_name_fragment, 'create_collation_c_test')
          self.assertEqual(create_collation.describe(), 'Create collation C_test')
          old_state = ProjectState()
          updated_state = old_state.clone()
          # Forwards issues exactly one CREATE COLLATION statement.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              create_collation.database_forwards(self.app_label, editor, old_state, updated_state)
          self.assertEqual(len(queries), 1)
          self.assertIn('CREATE COLLATION', queries[0]['sql'])
          # A second creation of the same collation must fail.
          with self.assertRaisesMessage(ProgrammingError, 'already exists'), \
                  connection.schema_editor(atomic=True) as editor:
              create_collation.database_forwards(self.app_label, editor, old_state, updated_state)
          # Backwards issues exactly one DROP COLLATION statement.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              create_collation.database_backwards(self.app_label, editor, updated_state, old_state)
          self.assertEqual(len(queries), 1)
          self.assertIn('DROP COLLATION', queries[0]['sql'])
          # The operation deconstructs to keyword arguments only.
          op_name, op_args, op_kwargs = create_collation.deconstruct()
          self.assertEqual(op_name, 'CreateCollation')
          self.assertEqual(op_args, [])
          self.assertEqual(op_kwargs, {'name': 'C_test', 'locale': 'C'})

      # Non-deterministic ICU collation: create, drop, deconstruct.
      @skipUnlessDBFeature('supports_non_deterministic_collations')
      def test_create_non_deterministic_collation(self):
          create_collation = CreateCollation(
              'case_insensitive_test',
              'und-u-ks-level2',
              provider='icu',
              deterministic=False,
          )
          old_state = ProjectState()
          updated_state = old_state.clone()
          # Forwards issues exactly one CREATE COLLATION statement.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              create_collation.database_forwards(self.app_label, editor, old_state, updated_state)
          self.assertEqual(len(queries), 1)
          self.assertIn('CREATE COLLATION', queries[0]['sql'])
          # Backwards issues exactly one DROP COLLATION statement.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              create_collation.database_backwards(self.app_label, editor, updated_state, old_state)
          self.assertEqual(len(queries), 1)
          self.assertIn('DROP COLLATION', queries[0]['sql'])
          # The non-default provider and deterministic flag are both kept.
          op_name, op_args, op_kwargs = create_collation.deconstruct()
          self.assertEqual(op_name, 'CreateCollation')
          self.assertEqual(op_args, [])
          expected_kwargs = {
              'name': 'case_insensitive_test',
              'locale': 'und-u-ks-level2',
              'provider': 'icu',
              'deterministic': False,
          }
          self.assertEqual(op_kwargs, expected_kwargs)

      # Collation backed by the 'icu' provider: create, then drop.
      @skipUnlessDBFeature('supports_alternate_collation_providers')
      def test_create_collation_alternate_provider(self):
          create_collation = CreateCollation(
              'german_phonebook_test',
              provider='icu',
              locale='de-u-co-phonebk',
          )
          old_state = ProjectState()
          updated_state = old_state.clone()
          # Forwards issues exactly one CREATE COLLATION statement.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              create_collation.database_forwards(self.app_label, editor, old_state, updated_state)
          self.assertEqual(len(queries), 1)
          self.assertIn('CREATE COLLATION', queries[0]['sql'])
          # Backwards issues exactly one DROP COLLATION statement.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              create_collation.database_backwards(self.app_label, editor, updated_state, old_state)
          self.assertEqual(len(queries), 1)
          self.assertIn('DROP COLLATION', queries[0]['sql'])

      # With the feature flag patched to False, a non-deterministic collation
      # must be refused.
      def test_nondeterministic_collation_not_supported(self):
          create_collation = CreateCollation(
              'case_insensitive_test',
              provider='icu',
              locale='und-u-ks-level2',
              deterministic=False,
          )
          old_state = ProjectState()
          updated_state = old_state.clone()
          expected_error = 'Non-deterministic collations require PostgreSQL 12+.'
          feature_path = (
              'django.db.backends.postgresql.features.DatabaseFeatures.'
              'supports_non_deterministic_collations'
          )
          # Entered left to right: schema editor, feature patch, expected error.
          with connection.schema_editor(atomic=False) as editor, \
                  mock.patch(feature_path, False), \
                  self.assertRaisesMessage(NotSupportedError, expected_error):
              create_collation.database_forwards(self.app_label, editor, old_state, updated_state)

      # With the feature flag patched to False, the 'icu' provider must be
      # refused.
      def test_collation_with_icu_provider_raises_error(self):
          create_collation = CreateCollation(
              'german_phonebook',
              provider='icu',
              locale='de-u-co-phonebk',
          )
          old_state = ProjectState()
          updated_state = old_state.clone()
          expected_error = 'Non-libc providers require PostgreSQL 10+.'
          feature_path = (
              'django.db.backends.postgresql.features.DatabaseFeatures.'
              'supports_alternate_collation_providers'
          )
          # Entered left to right: schema editor, feature patch, expected error.
          with connection.schema_editor(atomic=False) as editor, \
                  mock.patch(feature_path, False), \
                  self.assertRaisesMessage(NotSupportedError, expected_error):
              create_collation.database_forwards(self.app_label, editor, old_state, updated_state)
  @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
  class RemoveCollationTests(PostgreSQLTestCase):
      """Checks for the RemoveCollation operation via the SQL it issues."""
      app_label = 'test_allow_remove_collation'

      # With a router that denies migrations, neither direction issues a query.
      @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
      def test_no_allow_migrate(self):
          remove_collation = RemoveCollation('C_test', locale='C')
          old_state = ProjectState()
          updated_state = old_state.clone()
          # Forwards: the collation must not be removed.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              remove_collation.database_forwards(self.app_label, editor, old_state, updated_state)
          self.assertEqual(len(queries), 0)
          # Backwards: nothing is re-created either.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              remove_collation.database_backwards(self.app_label, editor, updated_state, old_state)
          self.assertEqual(len(queries), 0)

      # Round trip: removal, duplicate removal, reversal and deconstruction.
      def test_remove(self):
          # Set-up: create the collation that is about to be removed.
          create_collation = CreateCollation('C_test', locale='C')
          old_state = ProjectState()
          updated_state = old_state.clone()
          with connection.schema_editor(atomic=False) as editor:
              create_collation.database_forwards(self.app_label, editor, old_state, updated_state)

          remove_collation = RemoveCollation('C_test', locale='C')
          self.assertEqual(remove_collation.migration_name_fragment, 'remove_collation_c_test')
          self.assertEqual(remove_collation.describe(), 'Remove collation C_test')
          # Fresh states for the removal itself.
          old_state = ProjectState()
          updated_state = old_state.clone()
          # Forwards issues exactly one DROP COLLATION statement.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              remove_collation.database_forwards(self.app_label, editor, old_state, updated_state)
          self.assertEqual(len(queries), 1)
          self.assertIn('DROP COLLATION', queries[0]['sql'])
          # Removing the collation a second time must fail.
          with self.assertRaisesMessage(ProgrammingError, 'does not exist'), \
                  connection.schema_editor(atomic=True) as editor:
              remove_collation.database_forwards(self.app_label, editor, old_state, updated_state)
          # Backwards re-creates it with exactly one CREATE COLLATION statement.
          with CaptureQueriesContext(connection) as queries, \
                  connection.schema_editor(atomic=False) as editor:
              remove_collation.database_backwards(self.app_label, editor, updated_state, old_state)
          self.assertEqual(len(queries), 1)
          self.assertIn('CREATE COLLATION', queries[0]['sql'])
          # The operation deconstructs to keyword arguments only.
          op_name, op_args, op_kwargs = remove_collation.deconstruct()
          self.assertEqual(op_name, 'RemoveCollation')
          self.assertEqual(op_args, [])
          self.assertEqual(op_kwargs, {'name': 'C_test', 'locale': 'C'})