123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666 |
- import threading
- import time
- from unittest import mock
- from multiple_database.routers import TestRouter
- from django.core.exceptions import FieldError
- from django.db import (
- DatabaseError,
- NotSupportedError,
- connection,
- connections,
- router,
- transaction,
- )
- from django.test import (
- TransactionTestCase,
- override_settings,
- skipIfDBFeature,
- skipUnlessDBFeature,
- )
- from django.test.utils import CaptureQueriesContext
- from .models import (
- City,
- CityCountryProxy,
- Country,
- EUCity,
- EUCountry,
- Person,
- PersonProfile,
- )
class SelectForUpdateTests(TransactionTestCase):
    """
    Tests for QuerySet.select_for_update().

    TransactionTestCase is used so that setUp() runs in autocommit mode and
    its rows are visible to the second connection (self.new_connection) and
    to the worker threads used by the locking tests.
    """

    available_apps = ["select_for_update"]

    def setUp(self):
        # This is executed in autocommit mode so that code in
        # run_select_for_update can see this data.
        self.country1 = Country.objects.create(name="Belgium")
        self.country2 = Country.objects.create(name="France")
        self.city1 = City.objects.create(name="Liberchies", country=self.country1)
        self.city2 = City.objects.create(name="Samois-sur-Seine", country=self.country2)
        self.person = Person.objects.create(
            name="Reinhardt", born=self.city1, died=self.city2
        )
        self.person_profile = PersonProfile.objects.create(person=self.person)

        # We need another database connection in transaction to test that one
        # connection issuing a SELECT ... FOR UPDATE will block.
        self.new_connection = connection.copy()

    def tearDown(self):
        # Best-effort cleanup. AttributeError means start_blocking_transaction()
        # never ran (there is no self.cursor). DatabaseError is ignored because
        # many tests already ended the blocking transaction themselves, and
        # ending it a second time may fail on some backends.
        try:
            self.end_blocking_transaction()
        except (DatabaseError, AttributeError):
            pass
        self.new_connection.close()

    def start_blocking_transaction(self):
        """
        Open a transaction on self.new_connection holding FOR UPDATE locks.

        Disables autocommit on the second connection and runs a raw
        SELECT ... FOR UPDATE over the Person table, leaving the transaction
        open. end_blocking_transaction() must be called to roll it back.
        """
        self.new_connection.set_autocommit(False)
        # Start a blocking transaction. At some point,
        # end_blocking_transaction() should be called.
        self.cursor = self.new_connection.cursor()
        sql = "SELECT * FROM %(db_table)s %(for_update)s;" % {
            "db_table": Person._meta.db_table,
            "for_update": self.new_connection.ops.for_update_sql(),
        }
        self.cursor.execute(sql, ())
        self.cursor.fetchone()

    def end_blocking_transaction(self):
        """
        Roll back the transaction opened by start_blocking_transaction() and
        restore autocommit on self.new_connection.
        """
        # Roll back the blocking transaction.
        self.cursor.close()
        self.new_connection.rollback()
        self.new_connection.set_autocommit(True)

    def has_for_update_sql(self, queries, **kwargs):
        """
        Return True if any of ``queries`` contains the backend's FOR UPDATE
        clause built with the given options.

        ``queries`` is a list of dicts with a "sql" key (such as
        CaptureQueriesContext.captured_queries). ``kwargs`` are forwarded to
        connection.ops.for_update_sql() (nowait, skip_locked, of, no_key, ...).
        """
        for_update_sql = connection.ops.for_update_sql(**kwargs)
        return any(for_update_sql in query["sql"] for query in queries)

    @skipUnlessDBFeature("has_select_for_update")
    def test_for_update_sql_generated(self):
        """
        The backend's FOR UPDATE variant appears in
        generated SQL when select_for_update is invoked.
        """
        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
            list(Person.objects.select_for_update())
        self.assertTrue(self.has_for_update_sql(ctx.captured_queries))

    @skipUnlessDBFeature("has_select_for_update_nowait")
    def test_for_update_sql_generated_nowait(self):
        """
        The backend's FOR UPDATE NOWAIT variant appears in
        generated SQL when select_for_update is invoked.
        """
        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
            list(Person.objects.select_for_update(nowait=True))
        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, nowait=True))

    @skipUnlessDBFeature("has_select_for_update_skip_locked")
    def test_for_update_sql_generated_skip_locked(self):
        """
        The backend's FOR UPDATE SKIP LOCKED variant appears in
        generated SQL when select_for_update is invoked.
        """
        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
            list(Person.objects.select_for_update(skip_locked=True))
        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, skip_locked=True))

    @skipUnlessDBFeature("has_select_for_no_key_update")
    def test_update_sql_generated_no_key(self):
        """
        The backend's FOR NO KEY UPDATE variant appears in generated SQL when
        select_for_update() is invoked.
        """
        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
            list(Person.objects.select_for_update(no_key=True))
        self.assertIs(self.has_for_update_sql(ctx.captured_queries, no_key=True), True)

    @skipUnlessDBFeature("has_select_for_update_of")
    def test_for_update_sql_generated_of(self):
        """
        The backend's FOR UPDATE OF variant appears in the generated SQL when
        select_for_update() is invoked.
        """
        # The second select_for_update() call's of=(...) is what the expected
        # SQL below is built from, so the later call is expected to win.
        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
            list(
                Person.objects.select_related(
                    "born__country",
                )
                .select_for_update(
                    of=("born__country",),
                )
                .select_for_update(of=("self", "born__country"))
            )
        features = connections["default"].features
        if features.select_for_update_of_column:
            # Entries embed '"."' so that, once wrapped by quote_name(), they
            # read as a table.column reference (assumes a double-quote
            # quoting backend when this feature is set).
            expected = [
                'select_for_update_person"."id',
                'select_for_update_country"."entity_ptr_id',
            ]
        else:
            expected = ["select_for_update_person", "select_for_update_country"]
        expected = [connection.ops.quote_name(value) for value in expected]
        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))

    @skipUnlessDBFeature("has_select_for_update_of")
    def test_for_update_sql_model_inheritance_generated_of(self):
        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
            list(EUCountry.objects.select_for_update(of=("self",)))
        if connection.features.select_for_update_of_column:
            expected = ['select_for_update_eucountry"."country_ptr_id']
        else:
            expected = ["select_for_update_eucountry"]
        expected = [connection.ops.quote_name(value) for value in expected]
        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))

    @skipUnlessDBFeature("has_select_for_update_of")
    def test_for_update_sql_model_inheritance_ptr_generated_of(self):
        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
            list(
                EUCountry.objects.select_for_update(
                    of=(
                        "self",
                        "country_ptr",
                    )
                )
            )
        if connection.features.select_for_update_of_column:
            expected = [
                'select_for_update_eucountry"."country_ptr_id',
                'select_for_update_country"."entity_ptr_id',
            ]
        else:
            expected = ["select_for_update_eucountry", "select_for_update_country"]
        expected = [connection.ops.quote_name(value) for value in expected]
        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))

    @skipUnlessDBFeature("has_select_for_update_of")
    def test_for_update_sql_related_model_inheritance_generated_of(self):
        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
            list(
                EUCity.objects.select_related("country").select_for_update(
                    of=("self", "country"),
                )
            )
        if connection.features.select_for_update_of_column:
            expected = [
                'select_for_update_eucity"."id',
                'select_for_update_eucountry"."country_ptr_id',
            ]
        else:
            expected = ["select_for_update_eucity", "select_for_update_eucountry"]
        expected = [connection.ops.quote_name(value) for value in expected]
        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))

    @skipUnlessDBFeature("has_select_for_update_of")
    def test_for_update_sql_model_inheritance_nested_ptr_generated_of(self):
        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
            list(
                EUCity.objects.select_related("country").select_for_update(
                    of=(
                        "self",
                        "country__country_ptr",
                    ),
                )
            )
        if connection.features.select_for_update_of_column:
            expected = [
                'select_for_update_eucity"."id',
                'select_for_update_country"."entity_ptr_id',
            ]
        else:
            expected = ["select_for_update_eucity", "select_for_update_country"]
        expected = [connection.ops.quote_name(value) for value in expected]
        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))

    @skipUnlessDBFeature("has_select_for_update_of")
    def test_for_update_sql_multilevel_model_inheritance_ptr_generated_of(self):
        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
            list(
                EUCountry.objects.select_for_update(
                    of=("country_ptr", "country_ptr__entity_ptr"),
                )
            )
        if connection.features.select_for_update_of_column:
            expected = [
                'select_for_update_country"."entity_ptr_id',
                'select_for_update_entity"."id',
            ]
        else:
            expected = ["select_for_update_country", "select_for_update_entity"]
        expected = [connection.ops.quote_name(value) for value in expected]
        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))

    @skipUnlessDBFeature("has_select_for_update_of")
    def test_for_update_sql_model_proxy_generated_of(self):
        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
            list(
                CityCountryProxy.objects.select_related("country").select_for_update(
                    of=("country",),
                )
            )
        if connection.features.select_for_update_of_column:
            expected = ['select_for_update_country"."entity_ptr_id']
        else:
            expected = ["select_for_update_country"]
        expected = [connection.ops.quote_name(value) for value in expected]
        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))

    @skipUnlessDBFeature("has_select_for_update_of")
    def test_for_update_of_followed_by_values(self):
        with transaction.atomic():
            values = list(Person.objects.select_for_update(of=("self",)).values("pk"))
        self.assertEqual(values, [{"pk": self.person.pk}])

    @skipUnlessDBFeature("has_select_for_update_of")
    def test_for_update_of_followed_by_values_list(self):
        with transaction.atomic():
            values = list(
                Person.objects.select_for_update(of=("self",)).values_list("pk")
            )
        self.assertEqual(values, [(self.person.pk,)])

    @skipUnlessDBFeature("has_select_for_update_of")
    def test_for_update_of_self_when_self_is_not_selected(self):
        """
        select_for_update(of=['self']) when the only columns selected are from
        related tables.
        """
        with transaction.atomic():
            values = list(
                Person.objects.select_related("born")
                .select_for_update(of=("self",))
                .values("born__name")
            )
        self.assertEqual(values, [{"born__name": self.city1.name}])

    @skipUnlessDBFeature(
        "has_select_for_update_of",
        "supports_select_for_update_with_limit",
    )
    def test_for_update_of_with_exists(self):
        with transaction.atomic():
            qs = Person.objects.select_for_update(of=("self", "born"))
            self.assertIs(qs.exists(), True)

    @skipUnlessDBFeature("has_select_for_update_nowait")
    def test_nowait_raises_error_on_block(self):
        """
        If nowait is specified, we expect an error to be raised rather
        than blocking.
        """
        self.start_blocking_transaction()
        status = []
        thread = threading.Thread(
            target=self.run_select_for_update,
            args=(status,),
            kwargs={"nowait": True},
        )
        thread.start()
        # join() already waits for the worker to finish, so no extra sleep is
        # needed before it.
        thread.join()
        self.end_blocking_transaction()
        self.assertIsInstance(status[-1], DatabaseError)

    @skipUnlessDBFeature("has_select_for_update_skip_locked")
    def test_skip_locked_skips_locked_rows(self):
        """
        If skip_locked is specified, the locked row is skipped resulting in
        Person.DoesNotExist.
        """
        self.start_blocking_transaction()
        status = []
        thread = threading.Thread(
            target=self.run_select_for_update,
            args=(status,),
            kwargs={"skip_locked": True},
        )
        thread.start()
        # join() already waits for the worker to finish, so no extra sleep is
        # needed before it.
        thread.join()
        self.end_blocking_transaction()
        self.assertIsInstance(status[-1], Person.DoesNotExist)

    @skipIfDBFeature("has_select_for_update_nowait")
    @skipUnlessDBFeature("has_select_for_update")
    def test_unsupported_nowait_raises_error(self):
        """
        NotSupportedError is raised if a SELECT...FOR UPDATE NOWAIT is run on
        a database backend that supports FOR UPDATE but not NOWAIT.
        """
        with self.assertRaisesMessage(
            NotSupportedError, "NOWAIT is not supported on this database backend."
        ):
            with transaction.atomic():
                Person.objects.select_for_update(nowait=True).get()

    @skipIfDBFeature("has_select_for_update_skip_locked")
    @skipUnlessDBFeature("has_select_for_update")
    def test_unsupported_skip_locked_raises_error(self):
        """
        NotSupportedError is raised if a SELECT...FOR UPDATE SKIP LOCKED is run
        on a database backend that supports FOR UPDATE but not SKIP LOCKED.
        """
        with self.assertRaisesMessage(
            NotSupportedError, "SKIP LOCKED is not supported on this database backend."
        ):
            with transaction.atomic():
                Person.objects.select_for_update(skip_locked=True).get()

    @skipIfDBFeature("has_select_for_update_of")
    @skipUnlessDBFeature("has_select_for_update")
    def test_unsupported_of_raises_error(self):
        """
        NotSupportedError is raised if a SELECT...FOR UPDATE OF... is run on
        a database backend that supports FOR UPDATE but not OF.
        """
        msg = "FOR UPDATE OF is not supported on this database backend."
        with self.assertRaisesMessage(NotSupportedError, msg):
            with transaction.atomic():
                Person.objects.select_for_update(of=("self",)).get()

    @skipIfDBFeature("has_select_for_no_key_update")
    @skipUnlessDBFeature("has_select_for_update")
    def test_unsuported_no_key_raises_error(self):
        """
        NotSupportedError is raised if a SELECT...FOR NO KEY UPDATE... is run
        on a database backend that supports FOR UPDATE but not NO KEY.
        """
        msg = "FOR NO KEY UPDATE is not supported on this database backend."
        with self.assertRaisesMessage(NotSupportedError, msg):
            with transaction.atomic():
                Person.objects.select_for_update(no_key=True).get()

    @skipUnlessDBFeature("has_select_for_update", "has_select_for_update_of")
    def test_unrelated_of_argument_raises_error(self):
        """
        FieldError is raised if a non-relation field is specified in of=(...).
        """
        msg = (
            "Invalid field name(s) given in select_for_update(of=(...)): %s. "
            "Only relational fields followed in the query are allowed. "
            "Choices are: self, born, born__country, "
            "born__country__entity_ptr."
        )
        invalid_of = [
            ("nonexistent",),
            ("name",),
            ("born__nonexistent",),
            ("born__name",),
            ("born__nonexistent", "born__name"),
        ]
        for of in invalid_of:
            with self.subTest(of=of):
                with self.assertRaisesMessage(FieldError, msg % ", ".join(of)):
                    with transaction.atomic():
                        Person.objects.select_related(
                            "born__country"
                        ).select_for_update(of=of).get()

    @skipUnlessDBFeature("has_select_for_update", "has_select_for_update_of")
    def test_related_but_unselected_of_argument_raises_error(self):
        """
        FieldError is raised if a relation field that is not followed in the
        query is specified in of=(...).
        """
        msg = (
            "Invalid field name(s) given in select_for_update(of=(...)): %s. "
            "Only relational fields followed in the query are allowed. "
            "Choices are: self, born, profile."
        )
        for name in ["born__country", "died", "died__country"]:
            with self.subTest(name=name):
                with self.assertRaisesMessage(FieldError, msg % name):
                    with transaction.atomic():
                        Person.objects.select_related("born", "profile").exclude(
                            profile=None
                        ).select_for_update(of=(name,)).get()

    @skipUnlessDBFeature("has_select_for_update", "has_select_for_update_of")
    def test_model_inheritance_of_argument_raises_error_ptr_in_choices(self):
        msg = (
            "Invalid field name(s) given in select_for_update(of=(...)): "
            "name. Only relational fields followed in the query are allowed. "
            "Choices are: self, %s."
        )
        with self.assertRaisesMessage(
            FieldError,
            msg % "country, country__country_ptr, country__country_ptr__entity_ptr",
        ):
            with transaction.atomic():
                EUCity.objects.select_related(
                    "country",
                ).select_for_update(of=("name",)).get()
        with self.assertRaisesMessage(
            FieldError, msg % "country_ptr, country_ptr__entity_ptr"
        ):
            with transaction.atomic():
                EUCountry.objects.select_for_update(of=("name",)).get()

    @skipUnlessDBFeature("has_select_for_update", "has_select_for_update_of")
    def test_model_proxy_of_argument_raises_error_proxy_field_in_choices(self):
        msg = (
            "Invalid field name(s) given in select_for_update(of=(...)): "
            "name. Only relational fields followed in the query are allowed. "
            "Choices are: self, country, country__entity_ptr."
        )
        with self.assertRaisesMessage(FieldError, msg):
            with transaction.atomic():
                CityCountryProxy.objects.select_related(
                    "country",
                ).select_for_update(of=("name",)).get()

    @skipUnlessDBFeature("has_select_for_update", "has_select_for_update_of")
    def test_reverse_one_to_one_of_arguments(self):
        """
        Reverse OneToOneFields may be included in of=(...) as long as NULLs
        are excluded because LEFT JOIN isn't allowed in SELECT FOR UPDATE.
        """
        with transaction.atomic():
            person = (
                Person.objects.select_related(
                    "profile",
                )
                .exclude(profile=None)
                .select_for_update(of=("profile",))
                .get()
            )
            self.assertEqual(person.profile, self.person_profile)

    @skipUnlessDBFeature("has_select_for_update")
    def test_for_update_after_from(self):
        # Force the for_update_after_from feature on for the default backend
        # and check that FOR UPDATE is rendered before the WHERE clause.
        features_class = connections["default"].features.__class__
        attribute_to_patch = "%s.%s.for_update_after_from" % (
            features_class.__module__,
            features_class.__name__,
        )
        with mock.patch(attribute_to_patch, return_value=True):
            with transaction.atomic():
                self.assertIn(
                    "FOR UPDATE WHERE",
                    str(Person.objects.filter(name="foo").select_for_update().query),
                )

    @skipUnlessDBFeature("has_select_for_update")
    def test_for_update_requires_transaction(self):
        """
        A TransactionManagementError is raised
        when a select_for_update query is executed outside of a transaction.
        """
        msg = "select_for_update cannot be used outside of a transaction."
        with self.assertRaisesMessage(transaction.TransactionManagementError, msg):
            list(Person.objects.select_for_update())

    @skipUnlessDBFeature("has_select_for_update")
    def test_for_update_requires_transaction_only_in_execution(self):
        """
        No TransactionManagementError is raised
        when select_for_update is invoked outside of a transaction -
        only when the query is executed.
        """
        people = Person.objects.select_for_update()
        msg = "select_for_update cannot be used outside of a transaction."
        with self.assertRaisesMessage(transaction.TransactionManagementError, msg):
            list(people)

    @skipUnlessDBFeature("supports_select_for_update_with_limit")
    def test_select_for_update_with_limit(self):
        other = Person.objects.create(name="Grappeli", born=self.city1, died=self.city2)
        with transaction.atomic():
            qs = list(Person.objects.order_by("pk").select_for_update()[1:2])
            self.assertEqual(qs[0], other)

    @skipIfDBFeature("supports_select_for_update_with_limit")
    def test_unsupported_select_for_update_with_limit(self):
        msg = (
            "LIMIT/OFFSET is not supported with select_for_update on this database "
            "backend."
        )
        with self.assertRaisesMessage(NotSupportedError, msg):
            with transaction.atomic():
                list(Person.objects.order_by("pk").select_for_update()[1:2])

    def run_select_for_update(self, status, **kwargs):
        """
        Utility method that runs a SELECT FOR UPDATE against all
        Person instances. After the select_for_update, it attempts
        to update the name of the only record, save, and commit.
        This function expects to run in a separate thread.

        ``status`` is a shared list. "started" is appended on entry, and any
        DatabaseError or Person.DoesNotExist raised inside the transaction is
        appended after it. ``kwargs`` are forwarded to select_for_update()
        (e.g. nowait=True, skip_locked=True).
        """
        status.append("started")
        try:
            # We need to enter transaction management again, as this is done on
            # per-thread basis
            with transaction.atomic():
                person = Person.objects.select_for_update(**kwargs).get()
                person.name = "Fred"
                person.save()
        except (DatabaseError, Person.DoesNotExist) as e:
            status.append(e)
        finally:
            # This method is run in a separate thread. It uses its own
            # database connection. Close it without waiting for the GC.
            connection.close()

    @skipUnlessDBFeature("has_select_for_update")
    @skipUnlessDBFeature("supports_transactions")
    def test_block(self):
        """
        A thread running a select_for_update that accesses rows being touched
        by a similar operation on another connection blocks correctly.
        """
        # First, let's start the transaction in our thread.
        self.start_blocking_transaction()

        # Now, try it again using the ORM's select_for_update
        # facility. Do this in a separate thread.
        status = []
        thread = threading.Thread(target=self.run_select_for_update, args=(status,))

        # The thread should immediately block, but we'll sleep
        # for a bit to make sure.
        thread.start()
        sanity_count = 0
        while len(status) != 1 and sanity_count < 10:
            sanity_count += 1
            time.sleep(1)
        # Judge by what the worker reported rather than by the counter: the
        # worker may append "started" during the final sleep, and checking
        # sanity_count alone would then misreport success as a failure.
        if len(status) != 1:
            raise ValueError("Thread did not run and block")

        # Check the person hasn't been updated. Since this isn't
        # using FOR UPDATE, it won't block.
        p = Person.objects.get(pk=self.person.pk)
        self.assertEqual("Reinhardt", p.name)

        # When we end our blocking transaction, our thread should
        # be able to continue.
        self.end_blocking_transaction()
        thread.join(5.0)

        # Check the thread has finished. Assuming it has, we should
        # find that it has updated the person's name.
        self.assertFalse(thread.is_alive())

        # We must commit the transaction to ensure that MySQL gets a fresh read,
        # since by default it runs in REPEATABLE READ mode
        transaction.commit()

        p = Person.objects.get(pk=self.person.pk)
        self.assertEqual("Fred", p.name)

    @skipUnlessDBFeature("has_select_for_update")
    def test_raw_lock_not_available(self):
        """
        Running a raw query which can't obtain a FOR UPDATE lock raises
        the correct exception
        """
        self.start_blocking_transaction()

        def raw(status):
            # Worker: run a raw SELECT ... FOR UPDATE NOWAIT and record any
            # DatabaseError in ``status``.
            try:
                list(
                    Person.objects.raw(
                        "SELECT * FROM %s %s"
                        % (
                            Person._meta.db_table,
                            connection.ops.for_update_sql(nowait=True),
                        )
                    )
                )
            except DatabaseError as e:
                status.append(e)
            finally:
                # This method is run in a separate thread. It uses its own
                # database connection. Close it without waiting for the GC.
                # Connection cannot be closed on Oracle because cursor is still
                # open.
                if connection.vendor != "oracle":
                    connection.close()

        status = []
        thread = threading.Thread(target=raw, kwargs={"status": status})
        thread.start()
        # join() already waits for the worker to finish, so no extra sleep is
        # needed before it.
        thread.join()
        self.end_blocking_transaction()
        self.assertIsInstance(status[-1], DatabaseError)

    @skipUnlessDBFeature("has_select_for_update")
    @override_settings(DATABASE_ROUTERS=[TestRouter()])
    def test_select_for_update_on_multidb(self):
        query = Person.objects.select_for_update()
        self.assertEqual(router.db_for_write(Person), query.db)

    @skipUnlessDBFeature("has_select_for_update")
    def test_select_for_update_with_get(self):
        with transaction.atomic():
            person = Person.objects.select_for_update().get(name="Reinhardt")
        self.assertEqual(person.name, "Reinhardt")

    def test_nowait_and_skip_locked(self):
        with self.assertRaisesMessage(
            ValueError, "The nowait option cannot be used with skip_locked."
        ):
            Person.objects.select_for_update(nowait=True, skip_locked=True)

    def test_ordered_select_for_update(self):
        """
        Subqueries should respect ordering as an ORDER BY clause may be useful
        to specify a row locking order to prevent deadlocks (#27193).
        """
        with transaction.atomic():
            qs = Person.objects.filter(
                id__in=Person.objects.order_by("-id").select_for_update()
            )
            self.assertIn("ORDER BY", str(qs.query))
|