Browse Source

Fixed #26804 -- Fixed a race condition in QuerySet.update_or_create().

Jensen Cochran 8 years ago
parent
commit
d44afd8892
3 changed files with 62 additions and 11 deletions
  1. 1 0
      AUTHORS
  2. 10 9
      django/db/models/query.py
  3. 51 2
      tests/get_or_create/tests.py

+ 1 - 0
AUTHORS

@@ -347,6 +347,7 @@ answer newbie questions, and generally made Django that much better:
     Jeff Triplett <jeff.triplett@gmail.com>
     Jens Diemer <django@htfx.de>
     Jens Page
+    Jensen Cochran <jensen.cochran@gmail.com>
     Jeong-Min Lee <falsetru@gmail.com>
     Jérémie Blaser <blaserje@gmail.com>
     Jeremy Carbaugh <jcarbaugh@gmail.com>

+ 10 - 9
django/db/models/query.py

@@ -482,15 +482,16 @@ class QuerySet(object):
         defaults = defaults or {}
         lookup, params = self._extract_model_params(defaults, **kwargs)
         self._for_write = True
-        try:
-            obj = self.get(**lookup)
-        except self.model.DoesNotExist:
-            obj, created = self._create_object_from_params(lookup, params)
-            if created:
-                return obj, created
-        for k, v in six.iteritems(defaults):
-            setattr(obj, k, v() if callable(v) else v)
-        obj.save(using=self.db)
+        with transaction.atomic(using=self.db):
+            try:
+                obj = self.select_for_update().get(**lookup)
+            except self.model.DoesNotExist:
+                obj, created = self._create_object_from_params(lookup, params)
+                if created:
+                    return obj, created
+            for k, v in six.iteritems(defaults):
+                setattr(obj, k, v() if callable(v) else v)
+            obj.save(using=self.db)
         return obj, False
 
     def _create_object_from_params(self, lookup, params):

+ 51 - 2
tests/get_or_create/tests.py

@@ -1,10 +1,14 @@
 from __future__ import unicode_literals
 
+import time
 import traceback
-from datetime import date
+from datetime import date, datetime, timedelta
+from threading import Thread
 
 from django.db import DatabaseError, IntegrityError
-from django.test import TestCase, TransactionTestCase, ignore_warnings
+from django.test import (
+    TestCase, TransactionTestCase, ignore_warnings, skipUnlessDBFeature,
+)
 from django.utils.encoding import DjangoUnicodeDecodeError
 
 from .models import (
@@ -422,3 +426,48 @@ class UpdateOrCreateTests(TestCase):
         )
         self.assertIs(created, False)
         self.assertEqual(obj.last_name, 'NotHarrison')
+
+
+class UpdateOrCreateTransactionTests(TransactionTestCase):
+    available_apps = ['get_or_create']
+
+    @skipUnlessDBFeature('has_select_for_update')
+    @skipUnlessDBFeature('supports_transactions')
+    def test_updates_in_transaction(self):
+        """
+        Objects are selected and updated in a transaction to avoid race
+        conditions. This test forces update_or_create() to hold the lock
+        in another thread for a relatively long time so that it can update
+        while it holds the lock. The updated field isn't a field in 'defaults',
+        so update_or_create() shouldn't have an effect on it.
+        """
+        def birthday_sleep():
+            time.sleep(0.3)
+            return date(1940, 10, 10)
+
+        def update_birthday_slowly():
+            Person.objects.update_or_create(
+                first_name='John', defaults={'birthday': birthday_sleep}
+            )
+
+        Person.objects.create(first_name='John', last_name='Lennon', birthday=date(1940, 10, 9))
+
+        # update_or_create in a separate thread
+        t = Thread(target=update_birthday_slowly)
+        before_start = datetime.now()
+        t.start()
+
+        # Wait for lock to begin
+        time.sleep(0.05)
+
+        # Update during lock
+        Person.objects.filter(first_name='John').update(last_name='NotLennon')
+        after_update = datetime.now()
+
+        # Wait for thread to finish
+        t.join()
+
+        # The update remains and it blocked.
+        updated_person = Person.objects.get(first_name='John')
+        self.assertGreater(after_update - before_start, timedelta(seconds=0.3))
+        self.assertEqual(updated_person.last_name, 'NotLennon')