Browse Source

Fixed #31246 -- Fixed locking models in QuerySet.select_for_update(of=()) for related fields and parent link fields with multi-table inheritance.

Partly regression in 0107e3d1058f653f66032f7fd3a0bd61e96bf782.
Abhijeet Viswa 5 years ago
parent
commit
1712a76b9d

+ 1 - 0
AUTHORS

@@ -9,6 +9,7 @@ answer newbie questions, and generally made Django that much better:
     Aaron Swartz <http://www.aaronsw.com/>
     Aaron T. Myers <atmyers@gmail.com>
     Abeer Upadhyay <ab.esquarer@gmail.com>
+    Abhijeet Viswa <abhijeetviswa@gmail.com>
     Abhinav Patil <https://github.com/ubadub/>
     Abhishek Gautam <abhishekg1128@yahoo.com>
     Adam Allred <adam.w.allred@gmail.com>

+ 22 - 15
django/db/models/sql/compiler.py

@@ -961,19 +961,34 @@ class SQLCompiler:
         the query.
         """
         def _get_parent_klass_info(klass_info):
-            return (
-                {
+            for parent_model, parent_link in klass_info['model']._meta.parents.items():
+                parent_list = parent_model._meta.get_parent_list()
+                yield {
                     'model': parent_model,
                     'field': parent_link,
                     'reverse': False,
                     'select_fields': [
                         select_index
                         for select_index in klass_info['select_fields']
-                        if self.select[select_index][0].target.model == parent_model
+                        # Selected columns from a model or its parents.
+                        if (
+                            self.select[select_index][0].target.model == parent_model or
+                            self.select[select_index][0].target.model in parent_list
+                        )
                     ],
                 }
-                for parent_model, parent_link in klass_info['model']._meta.parents.items()
-            )
+
+        def _get_first_selected_col_from_model(klass_info):
+            """
+            Find the first selected column from a model. If it doesn't exist,
+            don't lock a model.
+
+            select_fields is filled recursively, so it also contains fields
+            from the parent models.
+            """
+            for select_index in klass_info['select_fields']:
+                if self.select[select_index][0].target.model == klass_info['model']:
+                    return self.select[select_index][0]
 
         def _get_field_choices():
             """Yield all allowed field paths in breadth-first search order."""
@@ -1002,14 +1017,7 @@ class SQLCompiler:
         for name in self.query.select_for_update_of:
             klass_info = self.klass_info
             if name == 'self':
-                # Find the first selected column from a base model. If it
-                # doesn't exist, don't lock a base model.
-                for select_index in klass_info['select_fields']:
-                    if self.select[select_index][0].target.model == klass_info['model']:
-                        col = self.select[select_index][0]
-                        break
-                else:
-                    col = None
+                col = _get_first_selected_col_from_model(klass_info)
             else:
                 for part in name.split(LOOKUP_SEP):
                     klass_infos = (
@@ -1029,8 +1037,7 @@ class SQLCompiler:
                 if klass_info is None:
                     invalid_names.append(name)
                     continue
-                select_index = klass_info['select_fields'][0]
-                col = self.select[select_index][0]
+                col = _get_first_selected_col_from_model(klass_info)
             if col is not None:
                 if self.connection.features.select_for_update_of_column:
                     result.append(self.compile(col)[0])

+ 5 - 1
docs/releases/2.2.11.txt

@@ -9,4 +9,8 @@ Django 2.2.11 fixes a data loss bug in 2.2.10.
 Bugfixes
 ========
 
-* ...
+* Fixed a data loss possibility in the
+  :meth:`~django.db.models.query.QuerySet.select_for_update`. When using
+  related fields or parent link fields with :ref:`multi-table-inheritance` in
+  the ``of`` argument, the corresponding models were not locked
+  (:ticket:`31246`).

+ 6 - 0
docs/releases/3.0.4.txt

@@ -14,3 +14,9 @@ Bugfixes
 
 * Fixed a regression in Django 3.0 that caused a file response using a
   temporary file to be closed incorrectly (:ticket:`31240`).
+
+* Fixed a data loss possibility in the
+  :meth:`~django.db.models.query.QuerySet.select_for_update`. When using
+  related fields or parent link fields with :ref:`multi-table-inheritance` in
+  the ``of`` argument, the corresponding models were not locked
+  (:ticket:`31246`).

+ 5 - 1
tests/select_for_update/models.py

@@ -1,7 +1,11 @@
 from django.db import models
 
 
-class Country(models.Model):
+class Entity(models.Model):
+    pass
+
+
+class Country(Entity):
     name = models.CharField(max_length=30)
 
 

+ 42 - 6
tests/select_for_update/tests.py

@@ -113,7 +113,10 @@ class SelectForUpdateTests(TransactionTestCase):
             ))
         features = connections['default'].features
         if features.select_for_update_of_column:
-            expected = ['select_for_update_person"."id', 'select_for_update_country"."id']
+            expected = [
+                'select_for_update_person"."id',
+                'select_for_update_country"."entity_ptr_id',
+            ]
         else:
             expected = ['select_for_update_person', 'select_for_update_country']
         expected = [connection.ops.quote_name(value) for value in expected]
@@ -137,13 +140,29 @@ class SelectForUpdateTests(TransactionTestCase):
         if connection.features.select_for_update_of_column:
             expected = [
                 'select_for_update_eucountry"."country_ptr_id',
-                'select_for_update_country"."id',
+                'select_for_update_country"."entity_ptr_id',
             ]
         else:
             expected = ['select_for_update_eucountry', 'select_for_update_country']
         expected = [connection.ops.quote_name(value) for value in expected]
         self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
 
+    @skipUnlessDBFeature('has_select_for_update_of')
+    def test_for_update_sql_related_model_inheritance_generated_of(self):
+        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
+            list(EUCity.objects.select_related('country').select_for_update(
+                of=('self', 'country'),
+            ))
+        if connection.features.select_for_update_of_column:
+            expected = [
+                'select_for_update_eucity"."id',
+                'select_for_update_eucountry"."country_ptr_id',
+            ]
+        else:
+            expected = ['select_for_update_eucity', 'select_for_update_eucountry']
+        expected = [connection.ops.quote_name(value) for value in expected]
+        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
+
     @skipUnlessDBFeature('has_select_for_update_of')
     def test_for_update_sql_model_inheritance_nested_ptr_generated_of(self):
         with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
@@ -153,13 +172,29 @@ class SelectForUpdateTests(TransactionTestCase):
         if connection.features.select_for_update_of_column:
             expected = [
                 'select_for_update_eucity"."id',
-                'select_for_update_country"."id',
+                'select_for_update_country"."entity_ptr_id',
             ]
         else:
             expected = ['select_for_update_eucity', 'select_for_update_country']
         expected = [connection.ops.quote_name(value) for value in expected]
         self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
 
+    @skipUnlessDBFeature('has_select_for_update_of')
+    def test_for_update_sql_multilevel_model_inheritance_ptr_generated_of(self):
+        with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
+            list(EUCountry.objects.select_for_update(
+                of=('country_ptr', 'country_ptr__entity_ptr'),
+            ))
+        if connection.features.select_for_update_of_column:
+            expected = [
+                'select_for_update_country"."entity_ptr_id',
+                'select_for_update_entity"."id',
+            ]
+        else:
+            expected = ['select_for_update_country', 'select_for_update_entity']
+        expected = [connection.ops.quote_name(value) for value in expected]
+        self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
+
     @skipUnlessDBFeature('has_select_for_update_of')
     def test_for_update_of_followed_by_values(self):
         with transaction.atomic():
@@ -264,7 +299,8 @@ class SelectForUpdateTests(TransactionTestCase):
         msg = (
             'Invalid field name(s) given in select_for_update(of=(...)): %s. '
             'Only relational fields followed in the query are allowed. '
-            'Choices are: self, born, born__country.'
+            'Choices are: self, born, born__country, '
+            'born__country__entity_ptr.'
         )
         invalid_of = [
             ('nonexistent',),
@@ -307,13 +343,13 @@ class SelectForUpdateTests(TransactionTestCase):
         )
         with self.assertRaisesMessage(
             FieldError,
-            msg % 'country, country__country_ptr',
+            msg % 'country, country__country_ptr, country__country_ptr__entity_ptr',
         ):
             with transaction.atomic():
                 EUCity.objects.select_related(
                     'country',
                 ).select_for_update(of=('name',)).get()
-        with self.assertRaisesMessage(FieldError, msg % 'country_ptr'):
+        with self.assertRaisesMessage(FieldError, msg % 'country_ptr, country_ptr__entity_ptr'):
             with transaction.atomic():
                 EUCountry.objects.select_for_update(of=('name',)).get()