tests.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. import threading
  2. import time
  3. from unittest import mock
  4. from multiple_database.routers import TestRouter
  5. from django.core.exceptions import FieldError
  6. from django.db import (
  7. DatabaseError, NotSupportedError, connection, connections, router,
  8. transaction,
  9. )
  10. from django.test import (
  11. TransactionTestCase, override_settings, skipIfDBFeature,
  12. skipUnlessDBFeature,
  13. )
  14. from django.test.utils import CaptureQueriesContext
  15. from .models import City, Country, Person
  16. class SelectForUpdateTests(TransactionTestCase):
  17. available_apps = ['select_for_update']
  18. def setUp(self):
  19. # This is executed in autocommit mode so that code in
  20. # run_select_for_update can see this data.
  21. self.country1 = Country.objects.create(name='Belgium')
  22. self.country2 = Country.objects.create(name='France')
  23. self.city1 = City.objects.create(name='Liberchies', country=self.country1)
  24. self.city2 = City.objects.create(name='Samois-sur-Seine', country=self.country2)
  25. self.person = Person.objects.create(name='Reinhardt', born=self.city1, died=self.city2)
  26. # We need another database connection in transaction to test that one
  27. # connection issuing a SELECT ... FOR UPDATE will block.
  28. self.new_connection = connection.copy()
  29. def tearDown(self):
  30. try:
  31. self.end_blocking_transaction()
  32. except (DatabaseError, AttributeError):
  33. pass
  34. self.new_connection.close()
  35. def start_blocking_transaction(self):
  36. self.new_connection.set_autocommit(False)
  37. # Start a blocking transaction. At some point,
  38. # end_blocking_transaction() should be called.
  39. self.cursor = self.new_connection.cursor()
  40. sql = 'SELECT * FROM %(db_table)s %(for_update)s;' % {
  41. 'db_table': Person._meta.db_table,
  42. 'for_update': self.new_connection.ops.for_update_sql(),
  43. }
  44. self.cursor.execute(sql, ())
  45. self.cursor.fetchone()
  46. def end_blocking_transaction(self):
  47. # Roll back the blocking transaction.
  48. self.new_connection.rollback()
  49. self.new_connection.set_autocommit(True)
  50. def has_for_update_sql(self, queries, **kwargs):
  51. # Examine the SQL that was executed to determine whether it
  52. # contains the 'SELECT..FOR UPDATE' stanza.
  53. for_update_sql = connection.ops.for_update_sql(**kwargs)
  54. return any(for_update_sql in query['sql'] for query in queries)
  55. @skipUnlessDBFeature('has_select_for_update')
  56. def test_for_update_sql_generated(self):
  57. """
  58. The backend's FOR UPDATE variant appears in
  59. generated SQL when select_for_update is invoked.
  60. """
  61. with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
  62. list(Person.objects.all().select_for_update())
  63. self.assertTrue(self.has_for_update_sql(ctx.captured_queries))
  64. @skipUnlessDBFeature('has_select_for_update_nowait')
  65. def test_for_update_sql_generated_nowait(self):
  66. """
  67. The backend's FOR UPDATE NOWAIT variant appears in
  68. generated SQL when select_for_update is invoked.
  69. """
  70. with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
  71. list(Person.objects.all().select_for_update(nowait=True))
  72. self.assertTrue(self.has_for_update_sql(ctx.captured_queries, nowait=True))
  73. @skipUnlessDBFeature('has_select_for_update_skip_locked')
  74. def test_for_update_sql_generated_skip_locked(self):
  75. """
  76. The backend's FOR UPDATE SKIP LOCKED variant appears in
  77. generated SQL when select_for_update is invoked.
  78. """
  79. with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
  80. list(Person.objects.all().select_for_update(skip_locked=True))
  81. self.assertTrue(self.has_for_update_sql(ctx.captured_queries, skip_locked=True))
  82. @skipUnlessDBFeature('has_select_for_update_of')
  83. def test_for_update_sql_generated_of(self):
  84. """
  85. The backend's FOR UPDATE OF variant appears in the generated SQL when
  86. select_for_update() is invoked.
  87. """
  88. with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
  89. list(Person.objects.select_related(
  90. 'born__country',
  91. ).select_for_update(
  92. of=('born__country',),
  93. ).select_for_update(
  94. of=('self', 'born__country')
  95. ))
  96. features = connections['default'].features
  97. if features.select_for_update_of_column:
  98. expected = ['"select_for_update_person"."id"', '"select_for_update_country"."id"']
  99. else:
  100. expected = ['"select_for_update_person"', '"select_for_update_country"']
  101. if features.uppercases_column_names:
  102. expected = [value.upper() for value in expected]
  103. self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
  104. @skipUnlessDBFeature('has_select_for_update_nowait')
  105. def test_nowait_raises_error_on_block(self):
  106. """
  107. If nowait is specified, we expect an error to be raised rather
  108. than blocking.
  109. """
  110. self.start_blocking_transaction()
  111. status = []
  112. thread = threading.Thread(
  113. target=self.run_select_for_update,
  114. args=(status,),
  115. kwargs={'nowait': True},
  116. )
  117. thread.start()
  118. time.sleep(1)
  119. thread.join()
  120. self.end_blocking_transaction()
  121. self.assertIsInstance(status[-1], DatabaseError)
  122. @skipUnlessDBFeature('has_select_for_update_skip_locked')
  123. def test_skip_locked_skips_locked_rows(self):
  124. """
  125. If skip_locked is specified, the locked row is skipped resulting in
  126. Person.DoesNotExist.
  127. """
  128. self.start_blocking_transaction()
  129. status = []
  130. thread = threading.Thread(
  131. target=self.run_select_for_update,
  132. args=(status,),
  133. kwargs={'skip_locked': True},
  134. )
  135. thread.start()
  136. time.sleep(1)
  137. thread.join()
  138. self.end_blocking_transaction()
  139. self.assertIsInstance(status[-1], Person.DoesNotExist)
  140. @skipIfDBFeature('has_select_for_update_nowait')
  141. @skipUnlessDBFeature('has_select_for_update')
  142. def test_unsupported_nowait_raises_error(self):
  143. """
  144. NotSupportedError is raised if a SELECT...FOR UPDATE NOWAIT is run on
  145. a database backend that supports FOR UPDATE but not NOWAIT.
  146. """
  147. with self.assertRaisesMessage(NotSupportedError, 'NOWAIT is not supported on this database backend.'):
  148. with transaction.atomic():
  149. Person.objects.select_for_update(nowait=True).get()
  150. @skipIfDBFeature('has_select_for_update_skip_locked')
  151. @skipUnlessDBFeature('has_select_for_update')
  152. def test_unsupported_skip_locked_raises_error(self):
  153. """
  154. NotSupportedError is raised if a SELECT...FOR UPDATE SKIP LOCKED is run
  155. on a database backend that supports FOR UPDATE but not SKIP LOCKED.
  156. """
  157. with self.assertRaisesMessage(NotSupportedError, 'SKIP LOCKED is not supported on this database backend.'):
  158. with transaction.atomic():
  159. Person.objects.select_for_update(skip_locked=True).get()
  160. @skipIfDBFeature('has_select_for_update_of')
  161. @skipUnlessDBFeature('has_select_for_update')
  162. def test_unsupported_of_raises_error(self):
  163. """
  164. NotSupportedError is raised if a SELECT...FOR UPDATE OF... is run on
  165. a database backend that supports FOR UPDATE but not OF.
  166. """
  167. msg = 'FOR UPDATE OF is not supported on this database backend.'
  168. with self.assertRaisesMessage(NotSupportedError, msg):
  169. with transaction.atomic():
  170. Person.objects.select_for_update(of=('self',)).get()
  171. @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
  172. def test_unrelated_of_argument_raises_error(self):
  173. """
  174. FieldError is raised if a non-relation field is specified in of=(...).
  175. """
  176. msg = (
  177. 'Invalid field name(s) given in select_for_update(of=(...)): %s. '
  178. 'Only relational fields followed in the query are allowed. '
  179. 'Choices are: self, born, born__country.'
  180. )
  181. invalid_of = [
  182. ('nonexistent',),
  183. ('name',),
  184. ('born__nonexistent',),
  185. ('born__name',),
  186. ('born__nonexistent', 'born__name'),
  187. ]
  188. for of in invalid_of:
  189. with self.subTest(of=of):
  190. with self.assertRaisesMessage(FieldError, msg % ', '.join(of)):
  191. with transaction.atomic():
  192. Person.objects.select_related('born__country').select_for_update(of=of).get()
  193. @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
  194. def test_related_but_unselected_of_argument_raises_error(self):
  195. """
  196. FieldError is raised if a relation field that is not followed in the
  197. query is specified in of=(...).
  198. """
  199. msg = (
  200. 'Invalid field name(s) given in select_for_update(of=(...)): %s. '
  201. 'Only relational fields followed in the query are allowed. '
  202. 'Choices are: self, born.'
  203. )
  204. for name in ['born__country', 'died', 'died__country']:
  205. with self.subTest(name=name):
  206. with self.assertRaisesMessage(FieldError, msg % name):
  207. with transaction.atomic():
  208. Person.objects.select_related('born').select_for_update(of=(name,)).get()
  209. @skipUnlessDBFeature('has_select_for_update')
  210. def test_for_update_after_from(self):
  211. features_class = connections['default'].features.__class__
  212. attribute_to_patch = "%s.%s.for_update_after_from" % (features_class.__module__, features_class.__name__)
  213. with mock.patch(attribute_to_patch, return_value=True):
  214. with transaction.atomic():
  215. self.assertIn('FOR UPDATE WHERE', str(Person.objects.filter(name='foo').select_for_update().query))
  216. @skipUnlessDBFeature('has_select_for_update')
  217. def test_for_update_requires_transaction(self):
  218. """
  219. A TransactionManagementError is raised
  220. when a select_for_update query is executed outside of a transaction.
  221. """
  222. with self.assertRaises(transaction.TransactionManagementError):
  223. list(Person.objects.all().select_for_update())
  224. @skipUnlessDBFeature('has_select_for_update')
  225. def test_for_update_requires_transaction_only_in_execution(self):
  226. """
  227. No TransactionManagementError is raised
  228. when select_for_update is invoked outside of a transaction -
  229. only when the query is executed.
  230. """
  231. people = Person.objects.all().select_for_update()
  232. with self.assertRaises(transaction.TransactionManagementError):
  233. list(people)
  234. @skipUnlessDBFeature('supports_select_for_update_with_limit')
  235. def test_select_for_update_with_limit(self):
  236. other = Person.objects.create(name='Grappeli', born=self.city1, died=self.city2)
  237. with transaction.atomic():
  238. qs = list(Person.objects.all().order_by('pk').select_for_update()[1:2])
  239. self.assertEqual(qs[0], other)
  240. @skipIfDBFeature('supports_select_for_update_with_limit')
  241. def test_unsupported_select_for_update_with_limit(self):
  242. msg = 'LIMIT/OFFSET is not supported with select_for_update on this database backend.'
  243. with self.assertRaisesMessage(NotSupportedError, msg):
  244. with transaction.atomic():
  245. list(Person.objects.all().order_by('pk').select_for_update()[1:2])
  246. def run_select_for_update(self, status, **kwargs):
  247. """
  248. Utility method that runs a SELECT FOR UPDATE against all
  249. Person instances. After the select_for_update, it attempts
  250. to update the name of the only record, save, and commit.
  251. This function expects to run in a separate thread.
  252. """
  253. status.append('started')
  254. try:
  255. # We need to enter transaction management again, as this is done on
  256. # per-thread basis
  257. with transaction.atomic():
  258. person = Person.objects.select_for_update(**kwargs).get()
  259. person.name = 'Fred'
  260. person.save()
  261. except (DatabaseError, Person.DoesNotExist) as e:
  262. status.append(e)
  263. finally:
  264. # This method is run in a separate thread. It uses its own
  265. # database connection. Close it without waiting for the GC.
  266. connection.close()
  267. @skipUnlessDBFeature('has_select_for_update')
  268. @skipUnlessDBFeature('supports_transactions')
  269. def test_block(self):
  270. """
  271. A thread running a select_for_update that accesses rows being touched
  272. by a similar operation on another connection blocks correctly.
  273. """
  274. # First, let's start the transaction in our thread.
  275. self.start_blocking_transaction()
  276. # Now, try it again using the ORM's select_for_update
  277. # facility. Do this in a separate thread.
  278. status = []
  279. thread = threading.Thread(
  280. target=self.run_select_for_update, args=(status,)
  281. )
  282. # The thread should immediately block, but we'll sleep
  283. # for a bit to make sure.
  284. thread.start()
  285. sanity_count = 0
  286. while len(status) != 1 and sanity_count < 10:
  287. sanity_count += 1
  288. time.sleep(1)
  289. if sanity_count >= 10:
  290. raise ValueError('Thread did not run and block')
  291. # Check the person hasn't been updated. Since this isn't
  292. # using FOR UPDATE, it won't block.
  293. p = Person.objects.get(pk=self.person.pk)
  294. self.assertEqual('Reinhardt', p.name)
  295. # When we end our blocking transaction, our thread should
  296. # be able to continue.
  297. self.end_blocking_transaction()
  298. thread.join(5.0)
  299. # Check the thread has finished. Assuming it has, we should
  300. # find that it has updated the person's name.
  301. self.assertFalse(thread.isAlive())
  302. # We must commit the transaction to ensure that MySQL gets a fresh read,
  303. # since by default it runs in REPEATABLE READ mode
  304. transaction.commit()
  305. p = Person.objects.get(pk=self.person.pk)
  306. self.assertEqual('Fred', p.name)
  307. @skipUnlessDBFeature('has_select_for_update')
  308. def test_raw_lock_not_available(self):
  309. """
  310. Running a raw query which can't obtain a FOR UPDATE lock raises
  311. the correct exception
  312. """
  313. self.start_blocking_transaction()
  314. def raw(status):
  315. try:
  316. list(
  317. Person.objects.raw(
  318. 'SELECT * FROM %s %s' % (
  319. Person._meta.db_table,
  320. connection.ops.for_update_sql(nowait=True)
  321. )
  322. )
  323. )
  324. except DatabaseError as e:
  325. status.append(e)
  326. finally:
  327. # This method is run in a separate thread. It uses its own
  328. # database connection. Close it without waiting for the GC.
  329. connection.close()
  330. status = []
  331. thread = threading.Thread(target=raw, kwargs={'status': status})
  332. thread.start()
  333. time.sleep(1)
  334. thread.join()
  335. self.end_blocking_transaction()
  336. self.assertIsInstance(status[-1], DatabaseError)
  337. @skipUnlessDBFeature('has_select_for_update')
  338. @override_settings(DATABASE_ROUTERS=[TestRouter()])
  339. def test_select_for_update_on_multidb(self):
  340. query = Person.objects.select_for_update()
  341. self.assertEqual(router.db_for_write(Person), query.db)
  342. @skipUnlessDBFeature('has_select_for_update')
  343. def test_select_for_update_with_get(self):
  344. with transaction.atomic():
  345. person = Person.objects.select_for_update().get(name='Reinhardt')
  346. self.assertEqual(person.name, 'Reinhardt')
  347. def test_nowait_and_skip_locked(self):
  348. with self.assertRaisesMessage(ValueError, 'The nowait option cannot be used with skip_locked.'):
  349. Person.objects.select_for_update(nowait=True, skip_locked=True)
  350. def test_ordered_select_for_update(self):
  351. """
  352. Subqueries should respect ordering as an ORDER BY clause may be useful
  353. to specify a row locking order to prevent deadlocks (#27193).
  354. """
  355. with transaction.atomic():
  356. qs = Person.objects.filter(id__in=Person.objects.order_by('-id').select_for_update())
  357. self.assertIn('ORDER BY', str(qs.query))