2
0
Эх сурвалжийг харах

Fixed #28010 -- Added FOR UPDATE OF support to QuerySet.select_for_update().

Ran Benita 7 жил өмнө
parent
commit
b9f7dce84b

+ 4 - 0
django/db/backends/base/features.py

@@ -36,6 +36,10 @@ class BaseDatabaseFeatures:
     has_select_for_update = False
     has_select_for_update_nowait = False
     has_select_for_update_skip_locked = False
+    has_select_for_update_of = False
+    # Does the database's SELECT FOR UPDATE OF syntax require a column rather
+    # than a table?
+    select_for_update_of_column = False
 
     supports_select_related = True
 

+ 6 - 7
django/db/backends/base/operations.py

@@ -177,16 +177,15 @@ class BaseDatabaseOperations:
         """
         return []
 
-    def for_update_sql(self, nowait=False, skip_locked=False):
+    def for_update_sql(self, nowait=False, skip_locked=False, of=()):
         """
         Return the FOR UPDATE SQL clause to lock rows for an update operation.
         """
-        if nowait:
-            return 'FOR UPDATE NOWAIT'
-        elif skip_locked:
-            return 'FOR UPDATE SKIP LOCKED'
-        else:
-            return 'FOR UPDATE'
+        return 'FOR UPDATE%s%s%s' % (
+            ' OF %s' % ', '.join(of) if of else '',
+            ' NOWAIT' if nowait else '',
+            ' SKIP LOCKED' if skip_locked else '',
+        )
 
     def last_executed_query(self, cursor, sql, params):
         """

+ 2 - 0
django/db/backends/oracle/features.py

@@ -9,6 +9,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     has_select_for_update = True
     has_select_for_update_nowait = True
     has_select_for_update_skip_locked = True
+    has_select_for_update_of = True
+    select_for_update_of_column = True
     can_return_id_from_insert = True
     allow_sliced_subqueries = False
     can_introspect_autofield = True

+ 1 - 0
django/db/backends/postgresql/features.py

@@ -13,6 +13,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     can_defer_constraint_checks = True
     has_select_for_update = True
     has_select_for_update_nowait = True
+    has_select_for_update_of = True
     has_bulk_insert = True
     uses_savepoints = True
     can_release_savepoints = True

+ 2 - 1
django/db/models/query.py

@@ -839,7 +839,7 @@ class QuerySet:
             return self
         return self._combinator_query('difference', *other_qs)
 
-    def select_for_update(self, nowait=False, skip_locked=False):
+    def select_for_update(self, nowait=False, skip_locked=False, of=()):
         """
         Return a new QuerySet instance that will select objects with a
         FOR UPDATE lock.
@@ -851,6 +851,7 @@ class QuerySet:
         obj.query.select_for_update = True
         obj.query.select_for_update_nowait = nowait
         obj.query.select_for_update_skip_locked = skip_locked
+        obj.query.select_for_update_of = of
         return obj
 
     def select_related(self, *fields):

+ 64 - 3
django/db/models/sql/compiler.py

@@ -1,3 +1,4 @@
+import collections
 import re
 from itertools import chain
 
@@ -472,14 +473,21 @@ class SQLCompiler:
                         )
                     nowait = self.query.select_for_update_nowait
                     skip_locked = self.query.select_for_update_skip_locked
-                    # If it's a NOWAIT/SKIP LOCKED query but the backend
-                    # doesn't support it, raise a DatabaseError to prevent a
+                    of = self.query.select_for_update_of
+                    # If it's a NOWAIT/SKIP LOCKED/OF query but the backend
+                    # doesn't support it, raise NotSupportedError to prevent a
                     # possible deadlock.
                     if nowait and not self.connection.features.has_select_for_update_nowait:
                         raise NotSupportedError('NOWAIT is not supported on this database backend.')
                     elif skip_locked and not self.connection.features.has_select_for_update_skip_locked:
                         raise NotSupportedError('SKIP LOCKED is not supported on this database backend.')
-                    for_update_part = self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked)
+                    elif of and not self.connection.features.has_select_for_update_of:
+                        raise NotSupportedError('FOR UPDATE OF is not supported on this database backend.')
+                    for_update_part = self.connection.ops.for_update_sql(
+                        nowait=nowait,
+                        skip_locked=skip_locked,
+                        of=self.get_select_for_update_of_arguments(),
+                    )
 
                 if for_update_part and self.connection.features.for_update_after_from:
                     result.append(for_update_part)
@@ -832,6 +840,59 @@ class SQLCompiler:
                 )
         return related_klass_infos
 
+    def get_select_for_update_of_arguments(self):
+        """
+        Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
+        the query.
+        """
+        def _get_field_choices():
+            """Yield all allowed field paths in breadth-first search order."""
+            queue = collections.deque([(None, self.klass_info)])
+            while queue:
+                parent_path, klass_info = queue.popleft()
+                if parent_path is None:
+                    path = []
+                    yield 'self'
+                else:
+                    path = parent_path + [klass_info['field'].name]
+                    yield LOOKUP_SEP.join(path)
+                queue.extend(
+                    (path, klass_info)
+                    for klass_info in klass_info.get('related_klass_infos', [])
+                )
+        result = []
+        invalid_names = []
+        for name in self.query.select_for_update_of:
+            parts = [] if name == 'self' else name.split(LOOKUP_SEP)
+            klass_info = self.klass_info
+            for part in parts:
+                for related_klass_info in klass_info.get('related_klass_infos', []):
+                    if related_klass_info['field'].name == part:
+                        klass_info = related_klass_info
+                        break
+                else:
+                    klass_info = None
+                    break
+            if klass_info is None:
+                invalid_names.append(name)
+                continue
+            select_index = klass_info['select_fields'][0]
+            col = self.select[select_index][0]
+            if self.connection.features.select_for_update_of_column:
+                result.append(self.compile(col)[0])
+            else:
+                result.append(self.quote_name_unless_alias(col.alias))
+        if invalid_names:
+            raise FieldError(
+                'Invalid field name(s) given in select_for_update(of=(...)): %s. '
+                'Only relational fields followed in the query are allowed. '
+                'Choices are: %s.' % (
+                    ', '.join(invalid_names),
+                    ', '.join(_get_field_choices()),
+                )
+            )
+        return result
+
     def deferred_to_columns(self):
         """
         Convert the self.deferred_loading data structure to mapping of table

+ 2 - 0
django/db/models/sql/query.py

@@ -161,6 +161,7 @@ class Query:
         self.select_for_update = False
         self.select_for_update_nowait = False
         self.select_for_update_skip_locked = False
+        self.select_for_update_of = ()
 
         self.select_related = False
         # Arbitrary limit for select_related to prevents infinite recursion.
@@ -288,6 +289,7 @@ class Query:
         obj.select_for_update = self.select_for_update
         obj.select_for_update_nowait = self.select_for_update_nowait
         obj.select_for_update_skip_locked = self.select_for_update_skip_locked
+        obj.select_for_update_of = self.select_for_update_of
         obj.select_related = self.select_related
         obj.values_select = self.values_select
         obj._annotations = self._annotations.copy() if self._annotations is not None else None

+ 3 - 3
docs/ref/databases.txt

@@ -629,9 +629,9 @@ both MySQL and Django will attempt to convert the values from UTC to local time.
 Row locking with ``QuerySet.select_for_update()``
 -------------------------------------------------
 
-MySQL does not support the ``NOWAIT`` and ``SKIP LOCKED`` options to the
-``SELECT ... FOR UPDATE`` statement. If ``select_for_update()`` is used with
-``nowait=True`` or ``skip_locked=True``, then a
+MySQL does not support the ``NOWAIT``, ``SKIP LOCKED``, and ``OF`` options to
+the ``SELECT ... FOR UPDATE`` statement. If ``select_for_update()`` is used
+with ``nowait=True``, ``skip_locked=True``, or ``of`` then a
 :exc:`~django.db.NotSupportedError` is raised.
 
 Automatic typecasting can cause unexpected results

+ 17 - 6
docs/ref/models/querysets.txt

@@ -1611,7 +1611,7 @@ For example::
 ``select_for_update()``
 ~~~~~~~~~~~~~~~~~~~~~~~
 
-.. method:: select_for_update(nowait=False, skip_locked=False)
+.. method:: select_for_update(nowait=False, skip_locked=False, of=())
 
 Returns a queryset that will lock rows until the end of the transaction,
 generating a ``SELECT ... FOR UPDATE`` SQL statement on supported databases.
@@ -1635,14 +1635,21 @@ queryset is evaluated. You can also ignore locked rows by using
 ``select_for_update()`` with both options enabled will result in a
 :exc:`ValueError`.
 
+By default, ``select_for_update()`` locks all rows that are selected by the
+query. For example, rows of related objects specified in :meth:`select_related`
+are locked in addition to rows of the queryset's model. If this isn't desired,
+specify the related objects you want to lock in ``select_for_update(of=(...))``
+using the same fields syntax as :meth:`select_related`. Use the value ``'self'``
+to refer to the queryset's model.
+
 Currently, the ``postgresql``, ``oracle``, and ``mysql`` database
 backends support ``select_for_update()``. However, MySQL doesn't support the
-``nowait`` and ``skip_locked`` arguments.
+``nowait``, ``skip_locked``, and ``of`` arguments.
 
-Passing ``nowait=True`` or ``skip_locked=True`` to ``select_for_update()``
-using database backends that do not support these options, such as MySQL,
-raises a :exc:`~django.db.NotSupportedError`. This prevents code from
-unexpectedly blocking.
+Passing ``nowait=True``, ``skip_locked=True``, or ``of`` to
+``select_for_update()`` using database backends that do not support these
+options, such as MySQL, raises a :exc:`~django.db.NotSupportedError`. This
+prevents code from unexpectedly blocking.
 
 Evaluating a queryset with ``select_for_update()`` in autocommit mode on
 backends which support ``SELECT ... FOR UPDATE`` is a
@@ -1670,6 +1677,10 @@ raised if ``select_for_update()`` is used in autocommit mode.
 
     The ``skip_locked`` argument was added.
 
+.. versionchanged:: 2.0
+
+    The ``of`` argument was added.
+
 ``raw()``
 ~~~~~~~~~
 

+ 11 - 0
docs/releases/2.0.txt

@@ -252,6 +252,12 @@ Models
   :class:`~django.db.models.functions.datetime.Extract` now works with
   :class:`~django.db.models.DurationField`.
 
+* Added the ``of`` argument to :meth:`.QuerySet.select_for_update()`, supported
+  on PostgreSQL and Oracle, to lock only rows from specific tables rather than
+  all selected tables. It may be helpful particularly when
+  :meth:`~.QuerySet.select_for_update()` is used in conjunction with
+  :meth:`~.QuerySet.select_related()`.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 
@@ -331,6 +337,11 @@ backends.
 * The first argument of ``SchemaEditor._create_index_name()`` is now
   ``table_name`` rather than ``model``.
 
+* To enable ``FOR UPDATE OF`` support, set
+  ``DatabaseFeatures.has_select_for_update_of = True``. If the database
+  requires that the arguments to ``OF`` be columns rather than tables, set
+  ``DatabaseFeatures.select_for_update_of_column = True``.
+
 Dropped support for Oracle 11.2
 -------------------------------
 

+ 11 - 0
tests/select_for_update/models.py

@@ -1,5 +1,16 @@
 from django.db import models
 
 
+class Country(models.Model):
+    name = models.CharField(max_length=30)
+
+
+class City(models.Model):
+    name = models.CharField(max_length=30)
+    country = models.ForeignKey(Country, models.CASCADE)
+
+
 class Person(models.Model):
     name = models.CharField(max_length=30)
+    born = models.ForeignKey(City, models.CASCADE, related_name='+')
+    died = models.ForeignKey(City, models.CASCADE, related_name='+')

+ 83 - 3
tests/select_for_update/tests.py

@@ -4,6 +4,7 @@ from unittest import mock
 
 from multiple_database.routers import TestRouter
 
+from django.core.exceptions import FieldError
 from django.db import (
     DatabaseError, NotSupportedError, connection, connections, router,
     transaction,
@@ -14,7 +15,7 @@ from django.test import (
 )
 from django.test.utils import CaptureQueriesContext
 
-from .models import Person
+from .models import City, Country, Person
 
 
 class SelectForUpdateTests(TransactionTestCase):
@@ -24,7 +25,11 @@ class SelectForUpdateTests(TransactionTestCase):
     def setUp(self):
         # This is executed in autocommit mode so that code in
         # run_select_for_update can see this data.
-        self.person = Person.objects.create(name='Reinhardt')
+        self.country1 = Country.objects.create(name='Belgium')
+        self.country2 = Country.objects.create(name='France')
+        self.city1 = City.objects.create(name='Liberchies', country=self.country1)
+        self.city2 = City.objects.create(name='Samois-sur-Seine', country=self.country2)
+        self.person = Person.objects.create(name='Reinhardt', born=self.city1, died=self.city2)
 
         # We need another database connection in transaction to test that one
         # connection issuing a SELECT ... FOR UPDATE will block.
@@ -90,6 +95,29 @@ class SelectForUpdateTests(TransactionTestCase):
             list(Person.objects.all().select_for_update(skip_locked=True))
         self.assertTrue(self.has_for_update_sql(ctx.captured_queries, skip_locked=True))
 
+    @skipUnlessDBFeature('has_select_for_update_of')
+    def test_for_update_sql_generated_of(self):
+        """
+        The backend's FOR UPDATE OF variant appears in the generated SQL when
+        select_for_update() is invoked.
+        """
+        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
+            list(Person.objects.select_related(
+                'born__country',
+            ).select_for_update(
+                of=('born__country',),
+            ).select_for_update(
+                of=('self', 'born__country')
+            ))
+        features = connections['default'].features
+        if features.select_for_update_of_column:
+            expected = ['"select_for_update_person"."id"', '"select_for_update_country"."id"']
+        else:
+            expected = ['"select_for_update_person"', '"select_for_update_country"']
+        if features.uppercases_column_names:
+            expected = [value.upper() for value in expected]
+        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
+
     @skipUnlessDBFeature('has_select_for_update_nowait')
     def test_nowait_raises_error_on_block(self):
         """
@@ -152,6 +180,58 @@ class SelectForUpdateTests(TransactionTestCase):
             with transaction.atomic():
                 Person.objects.select_for_update(skip_locked=True).get()
 
+    @skipIfDBFeature('has_select_for_update_of')
+    @skipUnlessDBFeature('has_select_for_update')
+    def test_unsupported_of_raises_error(self):
+        """
+        NotSupportedError is raised if a SELECT...FOR UPDATE OF... is run on
+        a database backend that supports FOR UPDATE but not OF.
+        """
+        msg = 'FOR UPDATE OF is not supported on this database backend.'
+        with self.assertRaisesMessage(NotSupportedError, msg):
+            with transaction.atomic():
+                Person.objects.select_for_update(of=('self',)).get()
+
+    @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
+    def test_unrelated_of_argument_raises_error(self):
+        """
+        FieldError is raised if a non-relation field is specified in of=(...).
+        """
+        msg = (
+            'Invalid field name(s) given in select_for_update(of=(...)): %s. '
+            'Only relational fields followed in the query are allowed. '
+            'Choices are: self, born, born__country.'
+        )
+        invalid_of = [
+            ('nonexistent',),
+            ('name',),
+            ('born__nonexistent',),
+            ('born__name',),
+            ('born__nonexistent', 'born__name'),
+        ]
+        for of in invalid_of:
+            with self.subTest(of=of):
+                with self.assertRaisesMessage(FieldError, msg % ', '.join(of)):
+                    with transaction.atomic():
+                        Person.objects.select_related('born__country').select_for_update(of=of).get()
+
+    @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
+    def test_related_but_unselected_of_argument_raises_error(self):
+        """
+        FieldError is raised if a relation field that is not followed in the
+        query is specified in of=(...).
+        """
+        msg = (
+            'Invalid field name(s) given in select_for_update(of=(...)): %s. '
+            'Only relational fields followed in the query are allowed. '
+            'Choices are: self, born.'
+        )
+        for name in ['born__country', 'died', 'died__country']:
+            with self.subTest(name=name):
+                with self.assertRaisesMessage(FieldError, msg % name):
+                    with transaction.atomic():
+                        Person.objects.select_related('born').select_for_update(of=(name,)).get()
+
     @skipUnlessDBFeature('has_select_for_update')
     def test_for_update_after_from(self):
         features_class = connections['default'].features.__class__
@@ -182,7 +262,7 @@ class SelectForUpdateTests(TransactionTestCase):
 
     @skipUnlessDBFeature('supports_select_for_update_with_limit')
     def test_select_for_update_with_limit(self):
-        other = Person.objects.create(name='Grappeli')
+        other = Person.objects.create(name='Grappeli', born=self.city1, died=self.city2)
         with transaction.atomic():
             qs = list(Person.objects.all().order_by('pk').select_for_update()[1:2])
             self.assertEqual(qs[0], other)