tests.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
  1. import threading
  2. import time
  3. from unittest import mock
  4. from multiple_database.routers import TestRouter
  5. from django.core.exceptions import FieldError
  6. from django.db import (
  7. DatabaseError, NotSupportedError, connection, connections, router,
  8. transaction,
  9. )
  10. from django.test import (
  11. TransactionTestCase, override_settings, skipIfDBFeature,
  12. skipUnlessDBFeature,
  13. )
  14. from django.test.utils import CaptureQueriesContext
  15. from .models import City, Country, EUCity, EUCountry, Person, PersonProfile
  16. class SelectForUpdateTests(TransactionTestCase):
  17. available_apps = ['select_for_update']
  18. def setUp(self):
  19. # This is executed in autocommit mode so that code in
  20. # run_select_for_update can see this data.
  21. self.country1 = Country.objects.create(name='Belgium')
  22. self.country2 = Country.objects.create(name='France')
  23. self.city1 = City.objects.create(name='Liberchies', country=self.country1)
  24. self.city2 = City.objects.create(name='Samois-sur-Seine', country=self.country2)
  25. self.person = Person.objects.create(name='Reinhardt', born=self.city1, died=self.city2)
  26. self.person_profile = PersonProfile.objects.create(person=self.person)
  27. # We need another database connection in transaction to test that one
  28. # connection issuing a SELECT ... FOR UPDATE will block.
  29. self.new_connection = connection.copy()
  30. def tearDown(self):
  31. try:
  32. self.end_blocking_transaction()
  33. except (DatabaseError, AttributeError):
  34. pass
  35. self.new_connection.close()
  36. def start_blocking_transaction(self):
  37. self.new_connection.set_autocommit(False)
  38. # Start a blocking transaction. At some point,
  39. # end_blocking_transaction() should be called.
  40. self.cursor = self.new_connection.cursor()
  41. sql = 'SELECT * FROM %(db_table)s %(for_update)s;' % {
  42. 'db_table': Person._meta.db_table,
  43. 'for_update': self.new_connection.ops.for_update_sql(),
  44. }
  45. self.cursor.execute(sql, ())
  46. self.cursor.fetchone()
  47. def end_blocking_transaction(self):
  48. # Roll back the blocking transaction.
  49. self.cursor.close()
  50. self.new_connection.rollback()
  51. self.new_connection.set_autocommit(True)
  52. def has_for_update_sql(self, queries, **kwargs):
  53. # Examine the SQL that was executed to determine whether it
  54. # contains the 'SELECT..FOR UPDATE' stanza.
  55. for_update_sql = connection.ops.for_update_sql(**kwargs)
  56. return any(for_update_sql in query['sql'] for query in queries)
  57. @skipUnlessDBFeature('has_select_for_update')
  58. def test_for_update_sql_generated(self):
  59. """
  60. The backend's FOR UPDATE variant appears in
  61. generated SQL when select_for_update is invoked.
  62. """
  63. with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
  64. list(Person.objects.all().select_for_update())
  65. self.assertTrue(self.has_for_update_sql(ctx.captured_queries))
  66. @skipUnlessDBFeature('has_select_for_update_nowait')
  67. def test_for_update_sql_generated_nowait(self):
  68. """
  69. The backend's FOR UPDATE NOWAIT variant appears in
  70. generated SQL when select_for_update is invoked.
  71. """
  72. with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
  73. list(Person.objects.all().select_for_update(nowait=True))
  74. self.assertTrue(self.has_for_update_sql(ctx.captured_queries, nowait=True))
  75. @skipUnlessDBFeature('has_select_for_update_skip_locked')
  76. def test_for_update_sql_generated_skip_locked(self):
  77. """
  78. The backend's FOR UPDATE SKIP LOCKED variant appears in
  79. generated SQL when select_for_update is invoked.
  80. """
  81. with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
  82. list(Person.objects.all().select_for_update(skip_locked=True))
  83. self.assertTrue(self.has_for_update_sql(ctx.captured_queries, skip_locked=True))
  84. @skipUnlessDBFeature('has_select_for_update_of')
  85. def test_for_update_sql_generated_of(self):
  86. """
  87. The backend's FOR UPDATE OF variant appears in the generated SQL when
  88. select_for_update() is invoked.
  89. """
  90. with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
  91. list(Person.objects.select_related(
  92. 'born__country',
  93. ).select_for_update(
  94. of=('born__country',),
  95. ).select_for_update(
  96. of=('self', 'born__country')
  97. ))
  98. features = connections['default'].features
  99. if features.select_for_update_of_column:
  100. expected = ['select_for_update_person"."id', 'select_for_update_country"."id']
  101. else:
  102. expected = ['select_for_update_person', 'select_for_update_country']
  103. expected = [connection.ops.quote_name(value) for value in expected]
  104. self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
  105. @skipUnlessDBFeature('has_select_for_update_of')
  106. def test_for_update_sql_model_inheritance_generated_of(self):
  107. with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
  108. list(EUCountry.objects.select_for_update(of=('self',)))
  109. if connection.features.select_for_update_of_column:
  110. expected = ['select_for_update_eucountry"."country_ptr_id']
  111. else:
  112. expected = ['select_for_update_eucountry']
  113. expected = [connection.ops.quote_name(value) for value in expected]
  114. self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
  115. @skipUnlessDBFeature('has_select_for_update_of')
  116. def test_for_update_sql_model_inheritance_ptr_generated_of(self):
  117. with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
  118. list(EUCountry.objects.select_for_update(of=('self', 'country_ptr',)))
  119. if connection.features.select_for_update_of_column:
  120. expected = [
  121. 'select_for_update_eucountry"."country_ptr_id',
  122. 'select_for_update_country"."id',
  123. ]
  124. else:
  125. expected = ['select_for_update_eucountry', 'select_for_update_country']
  126. expected = [connection.ops.quote_name(value) for value in expected]
  127. self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
  128. @skipUnlessDBFeature('has_select_for_update_of')
  129. def test_for_update_sql_model_inheritance_nested_ptr_generated_of(self):
  130. with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
  131. list(EUCity.objects.select_related('country').select_for_update(
  132. of=('self', 'country__country_ptr',),
  133. ))
  134. if connection.features.select_for_update_of_column:
  135. expected = [
  136. 'select_for_update_eucity"."id',
  137. 'select_for_update_country"."id',
  138. ]
  139. else:
  140. expected = ['select_for_update_eucity', 'select_for_update_country']
  141. expected = [connection.ops.quote_name(value) for value in expected]
  142. self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
  143. @skipUnlessDBFeature('has_select_for_update_of')
  144. def test_for_update_of_followed_by_values(self):
  145. with transaction.atomic():
  146. values = list(Person.objects.select_for_update(of=('self',)).values('pk'))
  147. self.assertEqual(values, [{'pk': self.person.pk}])
  148. @skipUnlessDBFeature('has_select_for_update_of')
  149. def test_for_update_of_followed_by_values_list(self):
  150. with transaction.atomic():
  151. values = list(Person.objects.select_for_update(of=('self',)).values_list('pk'))
  152. self.assertEqual(values, [(self.person.pk,)])
  153. @skipUnlessDBFeature('has_select_for_update_of')
  154. def test_for_update_of_self_when_self_is_not_selected(self):
  155. """
  156. select_for_update(of=['self']) when the only columns selected are from
  157. related tables.
  158. """
  159. with transaction.atomic():
  160. values = list(Person.objects.select_related('born').select_for_update(of=('self',)).values('born__name'))
  161. self.assertEqual(values, [{'born__name': self.city1.name}])
  162. @skipUnlessDBFeature('has_select_for_update_nowait')
  163. def test_nowait_raises_error_on_block(self):
  164. """
  165. If nowait is specified, we expect an error to be raised rather
  166. than blocking.
  167. """
  168. self.start_blocking_transaction()
  169. status = []
  170. thread = threading.Thread(
  171. target=self.run_select_for_update,
  172. args=(status,),
  173. kwargs={'nowait': True},
  174. )
  175. thread.start()
  176. time.sleep(1)
  177. thread.join()
  178. self.end_blocking_transaction()
  179. self.assertIsInstance(status[-1], DatabaseError)
  180. @skipUnlessDBFeature('has_select_for_update_skip_locked')
  181. def test_skip_locked_skips_locked_rows(self):
  182. """
  183. If skip_locked is specified, the locked row is skipped resulting in
  184. Person.DoesNotExist.
  185. """
  186. self.start_blocking_transaction()
  187. status = []
  188. thread = threading.Thread(
  189. target=self.run_select_for_update,
  190. args=(status,),
  191. kwargs={'skip_locked': True},
  192. )
  193. thread.start()
  194. time.sleep(1)
  195. thread.join()
  196. self.end_blocking_transaction()
  197. self.assertIsInstance(status[-1], Person.DoesNotExist)
  198. @skipIfDBFeature('has_select_for_update_nowait')
  199. @skipUnlessDBFeature('has_select_for_update')
  200. def test_unsupported_nowait_raises_error(self):
  201. """
  202. NotSupportedError is raised if a SELECT...FOR UPDATE NOWAIT is run on
  203. a database backend that supports FOR UPDATE but not NOWAIT.
  204. """
  205. with self.assertRaisesMessage(NotSupportedError, 'NOWAIT is not supported on this database backend.'):
  206. with transaction.atomic():
  207. Person.objects.select_for_update(nowait=True).get()
  208. @skipIfDBFeature('has_select_for_update_skip_locked')
  209. @skipUnlessDBFeature('has_select_for_update')
  210. def test_unsupported_skip_locked_raises_error(self):
  211. """
  212. NotSupportedError is raised if a SELECT...FOR UPDATE SKIP LOCKED is run
  213. on a database backend that supports FOR UPDATE but not SKIP LOCKED.
  214. """
  215. with self.assertRaisesMessage(NotSupportedError, 'SKIP LOCKED is not supported on this database backend.'):
  216. with transaction.atomic():
  217. Person.objects.select_for_update(skip_locked=True).get()
  218. @skipIfDBFeature('has_select_for_update_of')
  219. @skipUnlessDBFeature('has_select_for_update')
  220. def test_unsupported_of_raises_error(self):
  221. """
  222. NotSupportedError is raised if a SELECT...FOR UPDATE OF... is run on
  223. a database backend that supports FOR UPDATE but not OF.
  224. """
  225. msg = 'FOR UPDATE OF is not supported on this database backend.'
  226. with self.assertRaisesMessage(NotSupportedError, msg):
  227. with transaction.atomic():
  228. Person.objects.select_for_update(of=('self',)).get()
  229. @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
  230. def test_unrelated_of_argument_raises_error(self):
  231. """
  232. FieldError is raised if a non-relation field is specified in of=(...).
  233. """
  234. msg = (
  235. 'Invalid field name(s) given in select_for_update(of=(...)): %s. '
  236. 'Only relational fields followed in the query are allowed. '
  237. 'Choices are: self, born, born__country.'
  238. )
  239. invalid_of = [
  240. ('nonexistent',),
  241. ('name',),
  242. ('born__nonexistent',),
  243. ('born__name',),
  244. ('born__nonexistent', 'born__name'),
  245. ]
  246. for of in invalid_of:
  247. with self.subTest(of=of):
  248. with self.assertRaisesMessage(FieldError, msg % ', '.join(of)):
  249. with transaction.atomic():
  250. Person.objects.select_related('born__country').select_for_update(of=of).get()
  251. @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
  252. def test_related_but_unselected_of_argument_raises_error(self):
  253. """
  254. FieldError is raised if a relation field that is not followed in the
  255. query is specified in of=(...).
  256. """
  257. msg = (
  258. 'Invalid field name(s) given in select_for_update(of=(...)): %s. '
  259. 'Only relational fields followed in the query are allowed. '
  260. 'Choices are: self, born, profile.'
  261. )
  262. for name in ['born__country', 'died', 'died__country']:
  263. with self.subTest(name=name):
  264. with self.assertRaisesMessage(FieldError, msg % name):
  265. with transaction.atomic():
  266. Person.objects.select_related(
  267. 'born', 'profile',
  268. ).exclude(profile=None).select_for_update(of=(name,)).get()
  269. @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
  270. def test_model_inheritance_of_argument_raises_error_ptr_in_choices(self):
  271. msg = (
  272. 'Invalid field name(s) given in select_for_update(of=(...)): '
  273. 'name. Only relational fields followed in the query are allowed. '
  274. 'Choices are: self, %s.'
  275. )
  276. with self.assertRaisesMessage(
  277. FieldError,
  278. msg % 'country, country__country_ptr',
  279. ):
  280. with transaction.atomic():
  281. EUCity.objects.select_related(
  282. 'country',
  283. ).select_for_update(of=('name',)).get()
  284. with self.assertRaisesMessage(FieldError, msg % 'country_ptr'):
  285. with transaction.atomic():
  286. EUCountry.objects.select_for_update(of=('name',)).get()
  287. @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
  288. def test_reverse_one_to_one_of_arguments(self):
  289. """
  290. Reverse OneToOneFields may be included in of=(...) as long as NULLs
  291. are excluded because LEFT JOIN isn't allowed in SELECT FOR UPDATE.
  292. """
  293. with transaction.atomic():
  294. person = Person.objects.select_related(
  295. 'profile',
  296. ).exclude(profile=None).select_for_update(of=('profile',)).get()
  297. self.assertEqual(person.profile, self.person_profile)
  298. @skipUnlessDBFeature('has_select_for_update')
  299. def test_for_update_after_from(self):
  300. features_class = connections['default'].features.__class__
  301. attribute_to_patch = "%s.%s.for_update_after_from" % (features_class.__module__, features_class.__name__)
  302. with mock.patch(attribute_to_patch, return_value=True):
  303. with transaction.atomic():
  304. self.assertIn('FOR UPDATE WHERE', str(Person.objects.filter(name='foo').select_for_update().query))
  305. @skipUnlessDBFeature('has_select_for_update')
  306. def test_for_update_requires_transaction(self):
  307. """
  308. A TransactionManagementError is raised
  309. when a select_for_update query is executed outside of a transaction.
  310. """
  311. msg = 'select_for_update cannot be used outside of a transaction.'
  312. with self.assertRaisesMessage(transaction.TransactionManagementError, msg):
  313. list(Person.objects.all().select_for_update())
  314. @skipUnlessDBFeature('has_select_for_update')
  315. def test_for_update_requires_transaction_only_in_execution(self):
  316. """
  317. No TransactionManagementError is raised
  318. when select_for_update is invoked outside of a transaction -
  319. only when the query is executed.
  320. """
  321. people = Person.objects.all().select_for_update()
  322. msg = 'select_for_update cannot be used outside of a transaction.'
  323. with self.assertRaisesMessage(transaction.TransactionManagementError, msg):
  324. list(people)
  325. @skipUnlessDBFeature('supports_select_for_update_with_limit')
  326. def test_select_for_update_with_limit(self):
  327. other = Person.objects.create(name='Grappeli', born=self.city1, died=self.city2)
  328. with transaction.atomic():
  329. qs = list(Person.objects.all().order_by('pk').select_for_update()[1:2])
  330. self.assertEqual(qs[0], other)
  331. @skipIfDBFeature('supports_select_for_update_with_limit')
  332. def test_unsupported_select_for_update_with_limit(self):
  333. msg = 'LIMIT/OFFSET is not supported with select_for_update on this database backend.'
  334. with self.assertRaisesMessage(NotSupportedError, msg):
  335. with transaction.atomic():
  336. list(Person.objects.all().order_by('pk').select_for_update()[1:2])
  337. def run_select_for_update(self, status, **kwargs):
  338. """
  339. Utility method that runs a SELECT FOR UPDATE against all
  340. Person instances. After the select_for_update, it attempts
  341. to update the name of the only record, save, and commit.
  342. This function expects to run in a separate thread.
  343. """
  344. status.append('started')
  345. try:
  346. # We need to enter transaction management again, as this is done on
  347. # per-thread basis
  348. with transaction.atomic():
  349. person = Person.objects.select_for_update(**kwargs).get()
  350. person.name = 'Fred'
  351. person.save()
  352. except (DatabaseError, Person.DoesNotExist) as e:
  353. status.append(e)
  354. finally:
  355. # This method is run in a separate thread. It uses its own
  356. # database connection. Close it without waiting for the GC.
  357. connection.close()
  358. @skipUnlessDBFeature('has_select_for_update')
  359. @skipUnlessDBFeature('supports_transactions')
  360. def test_block(self):
  361. """
  362. A thread running a select_for_update that accesses rows being touched
  363. by a similar operation on another connection blocks correctly.
  364. """
  365. # First, let's start the transaction in our thread.
  366. self.start_blocking_transaction()
  367. # Now, try it again using the ORM's select_for_update
  368. # facility. Do this in a separate thread.
  369. status = []
  370. thread = threading.Thread(
  371. target=self.run_select_for_update, args=(status,)
  372. )
  373. # The thread should immediately block, but we'll sleep
  374. # for a bit to make sure.
  375. thread.start()
  376. sanity_count = 0
  377. while len(status) != 1 and sanity_count < 10:
  378. sanity_count += 1
  379. time.sleep(1)
  380. if sanity_count >= 10:
  381. raise ValueError('Thread did not run and block')
  382. # Check the person hasn't been updated. Since this isn't
  383. # using FOR UPDATE, it won't block.
  384. p = Person.objects.get(pk=self.person.pk)
  385. self.assertEqual('Reinhardt', p.name)
  386. # When we end our blocking transaction, our thread should
  387. # be able to continue.
  388. self.end_blocking_transaction()
  389. thread.join(5.0)
  390. # Check the thread has finished. Assuming it has, we should
  391. # find that it has updated the person's name.
  392. self.assertFalse(thread.is_alive())
  393. # We must commit the transaction to ensure that MySQL gets a fresh read,
  394. # since by default it runs in REPEATABLE READ mode
  395. transaction.commit()
  396. p = Person.objects.get(pk=self.person.pk)
  397. self.assertEqual('Fred', p.name)
  398. @skipUnlessDBFeature('has_select_for_update')
  399. def test_raw_lock_not_available(self):
  400. """
  401. Running a raw query which can't obtain a FOR UPDATE lock raises
  402. the correct exception
  403. """
  404. self.start_blocking_transaction()
  405. def raw(status):
  406. try:
  407. list(
  408. Person.objects.raw(
  409. 'SELECT * FROM %s %s' % (
  410. Person._meta.db_table,
  411. connection.ops.for_update_sql(nowait=True)
  412. )
  413. )
  414. )
  415. except DatabaseError as e:
  416. status.append(e)
  417. finally:
  418. # This method is run in a separate thread. It uses its own
  419. # database connection. Close it without waiting for the GC.
  420. # Connection cannot be closed on Oracle because cursor is still
  421. # open.
  422. if connection.vendor != 'oracle':
  423. connection.close()
  424. status = []
  425. thread = threading.Thread(target=raw, kwargs={'status': status})
  426. thread.start()
  427. time.sleep(1)
  428. thread.join()
  429. self.end_blocking_transaction()
  430. self.assertIsInstance(status[-1], DatabaseError)
  431. @skipUnlessDBFeature('has_select_for_update')
  432. @override_settings(DATABASE_ROUTERS=[TestRouter()])
  433. def test_select_for_update_on_multidb(self):
  434. query = Person.objects.select_for_update()
  435. self.assertEqual(router.db_for_write(Person), query.db)
  436. @skipUnlessDBFeature('has_select_for_update')
  437. def test_select_for_update_with_get(self):
  438. with transaction.atomic():
  439. person = Person.objects.select_for_update().get(name='Reinhardt')
  440. self.assertEqual(person.name, 'Reinhardt')
  441. def test_nowait_and_skip_locked(self):
  442. with self.assertRaisesMessage(ValueError, 'The nowait option cannot be used with skip_locked.'):
  443. Person.objects.select_for_update(nowait=True, skip_locked=True)
  444. def test_ordered_select_for_update(self):
  445. """
  446. Subqueries should respect ordering as an ORDER BY clause may be useful
  447. to specify a row locking order to prevent deadlocks (#27193).
  448. """
  449. with transaction.atomic():
  450. qs = Person.objects.filter(id__in=Person.objects.order_by('-id').select_for_update())
  451. self.assertIn('ORDER BY', str(qs.query))