Browse Source

Fixed #30953 -- Made select_for_update() lock queryset's model when using "self" with multi-table inheritance.

Thanks Abhijeet Viswa for the report and initial patch.
Mariusz Felisiak 5 years ago
parent
commit
0107e3d105

+ 51 - 18
django/db/models/sql/compiler.py

@@ -953,6 +953,21 @@ class SQLCompiler:
         Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
         the query.
         """
+        def _get_parent_klass_info(klass_info):
+            return (
+                {
+                    'model': parent_model,
+                    'field': parent_link,
+                    'reverse': False,
+                    'select_fields': [
+                        select_index
+                        for select_index in klass_info['select_fields']
+                        if self.select[select_index][0].target.model == parent_model
+                    ],
+                }
+                for parent_model, parent_link in klass_info['model']._meta.parents.items()
+            )
+
         def _get_field_choices():
             """Yield all allowed field paths in breadth-first search order."""
             queue = collections.deque([(None, self.klass_info)])
@@ -967,6 +982,10 @@ class SQLCompiler:
                         field = field.remote_field
                     path = parent_path + [field.name]
                     yield LOOKUP_SEP.join(path)
+                queue.extend(
+                    (path, klass_info)
+                    for klass_info in _get_parent_klass_info(klass_info)
+                )
                 queue.extend(
                     (path, klass_info)
                     for klass_info in klass_info.get('related_klass_infos', [])
@@ -974,28 +993,42 @@ class SQLCompiler:
         result = []
         invalid_names = []
         for name in self.query.select_for_update_of:
-            parts = [] if name == 'self' else name.split(LOOKUP_SEP)
             klass_info = self.klass_info
-            for part in parts:
-                for related_klass_info in klass_info.get('related_klass_infos', []):
-                    field = related_klass_info['field']
-                    if related_klass_info['reverse']:
-                        field = field.remote_field
-                    if field.name == part:
-                        klass_info = related_klass_info
+            if name == 'self':
+                # Find the first selected column from a base model. If it
+                # doesn't exist, don't lock a base model.
+                for select_index in klass_info['select_fields']:
+                    if self.select[select_index][0].target.model == klass_info['model']:
+                        col = self.select[select_index][0]
                         break
                 else:
-                    klass_info = None
-                    break
-            if klass_info is None:
-                invalid_names.append(name)
-                continue
-            select_index = klass_info['select_fields'][0]
-            col = self.select[select_index][0]
-            if self.connection.features.select_for_update_of_column:
-                result.append(self.compile(col)[0])
+                    col = None
             else:
-                result.append(self.quote_name_unless_alias(col.alias))
+                for part in name.split(LOOKUP_SEP):
+                    klass_infos = (
+                        *klass_info.get('related_klass_infos', []),
+                        *_get_parent_klass_info(klass_info),
+                    )
+                    for related_klass_info in klass_infos:
+                        field = related_klass_info['field']
+                        if related_klass_info['reverse']:
+                            field = field.remote_field
+                        if field.name == part:
+                            klass_info = related_klass_info
+                            break
+                    else:
+                        klass_info = None
+                        break
+                if klass_info is None:
+                    invalid_names.append(name)
+                    continue
+                select_index = klass_info['select_fields'][0]
+                col = self.select[select_index][0]
+            if col is not None:
+                if self.connection.features.select_for_update_of_column:
+                    result.append(self.compile(col)[0])
+                else:
+                    result.append(self.quote_name_unless_alias(col.alias))
         if invalid_names:
             raise FieldError(
                 'Invalid field name(s) given in select_for_update(of=(...)): %s. '

+ 8 - 0
docs/ref/models/querysets.txt

@@ -1692,6 +1692,14 @@ specify the related objects you want to lock in ``select_for_update(of=(...))``
 using the same fields syntax as :meth:`select_related`. Use the value ``'self'``
 to refer to the queryset's model.
 
+.. admonition:: Lock parents models in ``select_for_update(of=(...))``
+
+    If you want to lock parents models when using :ref:`multi-table inheritance
+    <multi-table-inheritance>`, you must specify parent link fields (by default
+    ``<parent_model_name>_ptr``) in the ``of`` argument. For example::
+
+        Restaurant.objects.select_for_update(of=('self', 'place_ptr'))
+
 You can't use ``select_for_update()`` on nullable relations::
 
     >>> Person.objects.select_related('hometown').select_for_update()

+ 6 - 0
docs/releases/2.2.8.txt

@@ -17,3 +17,9 @@ Bugfixes
 * Fixed a regression in Django 2.2.1 that caused a crash when migrating
   permissions for proxy models with a multiple database setup if the
   ``default`` entry was empty (:ticket:`31021`).
+
+* Fixed a data loss possibility in the
+  :meth:`~django.db.models.query.QuerySet.select_for_update()`. When using
+  ``'self'`` in the ``of`` argument with :ref:`multi-table inheritance
+  <multi-table-inheritance>`, a parent model was locked instead of the
+  queryset's model (:ticket:`30953`).

+ 9 - 0
tests/select_for_update/models.py

@@ -5,11 +5,20 @@ class Country(models.Model):
     name = models.CharField(max_length=30)
 
 
+class EUCountry(Country):
+    join_date = models.DateField()
+
+
 class City(models.Model):
     name = models.CharField(max_length=30)
     country = models.ForeignKey(Country, models.CASCADE)
 
 
+class EUCity(models.Model):
+    name = models.CharField(max_length=30)
+    country = models.ForeignKey(EUCountry, models.CASCADE)
+
+
 class Person(models.Model):
     name = models.CharField(max_length=30)
     born = models.ForeignKey(City, models.CASCADE, related_name='+')

+ 61 - 1
tests/select_for_update/tests.py

@@ -15,7 +15,7 @@ from django.test import (
 )
 from django.test.utils import CaptureQueriesContext
 
-from .models import City, Country, Person, PersonProfile
+from .models import City, Country, EUCity, EUCountry, Person, PersonProfile
 
 
 class SelectForUpdateTests(TransactionTestCase):
@@ -119,6 +119,47 @@ class SelectForUpdateTests(TransactionTestCase):
         expected = [connection.ops.quote_name(value) for value in expected]
         self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
 
+    @skipUnlessDBFeature('has_select_for_update_of')
+    def test_for_update_sql_model_inheritance_generated_of(self):
+        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
+            list(EUCountry.objects.select_for_update(of=('self',)))
+        if connection.features.select_for_update_of_column:
+            expected = ['select_for_update_eucountry"."country_ptr_id']
+        else:
+            expected = ['select_for_update_eucountry']
+        expected = [connection.ops.quote_name(value) for value in expected]
+        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
+
+    @skipUnlessDBFeature('has_select_for_update_of')
+    def test_for_update_sql_model_inheritance_ptr_generated_of(self):
+        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
+            list(EUCountry.objects.select_for_update(of=('self', 'country_ptr',)))
+        if connection.features.select_for_update_of_column:
+            expected = [
+                'select_for_update_eucountry"."country_ptr_id',
+                'select_for_update_country"."id',
+            ]
+        else:
+            expected = ['select_for_update_eucountry', 'select_for_update_country']
+        expected = [connection.ops.quote_name(value) for value in expected]
+        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
+
+    @skipUnlessDBFeature('has_select_for_update_of')
+    def test_for_update_sql_model_inheritance_nested_ptr_generated_of(self):
+        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
+            list(EUCity.objects.select_related('country').select_for_update(
+                of=('self', 'country__country_ptr',),
+            ))
+        if connection.features.select_for_update_of_column:
+            expected = [
+                'select_for_update_eucity"."id',
+                'select_for_update_country"."id',
+            ]
+        else:
+            expected = ['select_for_update_eucity', 'select_for_update_country']
+        expected = [connection.ops.quote_name(value) for value in expected]
+        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
+
     @skipUnlessDBFeature('has_select_for_update_of')
     def test_for_update_of_followed_by_values(self):
         with transaction.atomic():
@@ -257,6 +298,25 @@ class SelectForUpdateTests(TransactionTestCase):
                             'born', 'profile',
                         ).exclude(profile=None).select_for_update(of=(name,)).get()
 
+    @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
+    def test_model_inheritance_of_argument_raises_error_ptr_in_choices(self):
+        msg = (
+            'Invalid field name(s) given in select_for_update(of=(...)): '
+            'name. Only relational fields followed in the query are allowed. '
+            'Choices are: self, %s.'
+        )
+        with self.assertRaisesMessage(
+            FieldError,
+            msg % 'country, country__country_ptr',
+        ):
+            with transaction.atomic():
+                EUCity.objects.select_related(
+                    'country',
+                ).select_for_update(of=('name',)).get()
+        with self.assertRaisesMessage(FieldError, msg % 'country_ptr'):
+            with transaction.atomic():
+                EUCountry.objects.select_for_update(of=('name',)).get()
+
     @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
     def test_reverse_one_to_one_of_arguments(self):
         """