Browse Source

Fixed #34280 -- Allowed specifying different field values for create operation in QuerySet.update_or_create().

tschilling 2 years ago
parent
commit
c5808470aa

+ 15 - 7
django/db/models/query.py

@@ -926,25 +926,32 @@ class QuerySet(AltersData):
             **kwargs,
         )
 
-    def update_or_create(self, defaults=None, **kwargs):
+    def update_or_create(self, defaults=None, create_defaults=None, **kwargs):
         """
         Look up an object with the given kwargs, updating one with defaults
-        if it exists, otherwise create a new one.
+        if it exists, otherwise create a new one. Optionally, an object can
+        be created with different values than defaults by using
+        create_defaults.
         Return a tuple (object, created), where created is a boolean
         specifying whether an object was created.
         """
-        defaults = defaults or {}
+        if create_defaults is None:
+            update_defaults = create_defaults = defaults or {}
+        else:
+            update_defaults = defaults or {}
         self._for_write = True
         with transaction.atomic(using=self.db):
             # Lock the row so that a concurrent update is blocked until
             # update_or_create() has performed its save.
-            obj, created = self.select_for_update().get_or_create(defaults, **kwargs)
+            obj, created = self.select_for_update().get_or_create(
+                create_defaults, **kwargs
+            )
             if created:
                 return obj, created
-            for k, v in resolve_callables(defaults):
+            for k, v in resolve_callables(update_defaults):
                 setattr(obj, k, v)
 
-            update_fields = set(defaults)
+            update_fields = set(update_defaults)
             concrete_field_names = self.model._meta._non_pk_concrete_field_names
             # update_fields does not support non-concrete fields.
             if concrete_field_names.issuperset(update_fields):
@@ -964,9 +971,10 @@ class QuerySet(AltersData):
                 obj.save(using=self.db)
         return obj, False
 
-    async def aupdate_or_create(self, defaults=None, **kwargs):
+    async def aupdate_or_create(self, defaults=None, create_defaults=None, **kwargs):
         return await sync_to_async(self.update_or_create)(
             defaults=defaults,
+            create_defaults=create_defaults,
             **kwargs,
         )
 

+ 15 - 5
docs/ref/models/querysets.txt

@@ -2263,14 +2263,18 @@ whenever a request to a page has a side effect on your data. For more, see
 ``update_or_create()``
 ~~~~~~~~~~~~~~~~~~~~~~
 
-.. method:: update_or_create(defaults=None, **kwargs)
-.. method:: aupdate_or_create(defaults=None, **kwargs)
+.. method:: update_or_create(defaults=None, create_defaults=None, **kwargs)
+.. method:: aupdate_or_create(defaults=None, create_defaults=None, **kwargs)
 
 *Asynchronous version*: ``aupdate_or_create()``
 
 A convenience method for updating an object with the given ``kwargs``, creating
-a new one if necessary. The ``defaults`` is a dictionary of (field, value)
-pairs used to update the object. The values in ``defaults`` can be callables.
+a new one if necessary. Both ``create_defaults`` and ``defaults`` are
+dictionaries of (field, value) pairs. The values in both ``create_defaults``
+and ``defaults`` can be callables. ``defaults`` is used to update the object
+while ``create_defaults`` are used for the create operation. If
+``create_defaults`` is not supplied, ``defaults`` will be used for the create
+operation.
 
 Returns a tuple of ``(object, created)``, where ``object`` is the created or
 updated object and ``created`` is a boolean specifying whether a new object was
@@ -2283,6 +2287,7 @@ the given ``kwargs``. If a match is found, it updates the fields passed in the
 This is meant as a shortcut to boilerplatish code. For example::
 
     defaults = {'first_name': 'Bob'}
+    create_defaults = {'first_name': 'Bob', 'birthday': date(1940, 10, 9)}
     try:
         obj = Person.objects.get(first_name='John', last_name='Lennon')
         for key, value in defaults.items():
@@ -2290,7 +2295,7 @@ This is meant as a shortcut to boilerplatish code. For example::
         obj.save()
     except Person.DoesNotExist:
         new_values = {'first_name': 'John', 'last_name': 'Lennon'}
-        new_values.update(defaults)
+        new_values.update(create_defaults)
         obj = Person(**new_values)
         obj.save()
 
@@ -2300,6 +2305,7 @@ The above example can be rewritten using ``update_or_create()`` like so::
     obj, created = Person.objects.update_or_create(
         first_name='John', last_name='Lennon',
         defaults={'first_name': 'Bob'},
+        create_defaults={'first_name': 'Bob', 'birthday': date(1940, 10, 9)},
     )
 
 For a detailed description of how names passed in ``kwargs`` are resolved, see
@@ -2318,6 +2324,10 @@ exists in the database, an :exc:`~django.db.IntegrityError` is raised.
     In older versions, ``update_or_create()`` didn't specify ``update_fields``
     when calling :meth:`Model.save() <django.db.models.Model.save>`.
 
+.. versionchanged:: 5.0
+
+    The ``create_defaults`` argument was added.
+
 ``bulk_create()``
 ~~~~~~~~~~~~~~~~~
 

+ 11 - 1
docs/releases/5.0.txt

@@ -178,7 +178,9 @@ Migrations
 Models
 ~~~~~~
 
-* ...
+* The new ``create_defaults`` argument of :meth:`.QuerySet.update_or_create`
+  and :meth:`.QuerySet.aupdate_or_create` methods allows specifying a different
+  field values for the create operation.
 
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
@@ -238,6 +240,14 @@ backends.
 
 * ...
 
+Using ``create_defaults__exact`` may now be required with ``QuerySet.update_or_create()``
+-----------------------------------------------------------------------------------------
+
+:meth:`.QuerySet.update_or_create` now supports the parameter
+``create_defaults``. As a consequence, any models that have a field named
+``create_defaults`` that are used with an ``update_or_create()`` should specify
+the field in the lookup with ``create_defaults__exact``.
+
 Miscellaneous
 -------------
 

+ 7 - 0
tests/async/test_async_queryset.py

@@ -99,10 +99,17 @@ class AsyncQuerySetTest(TestCase):
             id=self.s1.id, defaults={"field": 2}
         )
         self.assertEqual(instance, self.s1)
+        self.assertEqual(instance.field, 2)
         self.assertIs(created, False)
         instance, created = await SimpleModel.objects.aupdate_or_create(field=4)
         self.assertEqual(await SimpleModel.objects.acount(), 4)
         self.assertIs(created, True)
+        instance, created = await SimpleModel.objects.aupdate_or_create(
+            field=5, defaults={"field": 7}, create_defaults={"field": 6}
+        )
+        self.assertEqual(await SimpleModel.objects.acount(), 5)
+        self.assertIs(created, True)
+        self.assertEqual(instance.field, 6)
 
     @skipUnlessDBFeature("has_bulk_insert")
     @async_to_sync

+ 9 - 3
tests/async/test_async_related_managers.py

@@ -44,12 +44,18 @@ class AsyncRelatedManagersOperationTest(TestCase):
         self.assertIs(created, True)
         self.assertEqual(await self.mtm1.simples.acount(), 1)
         self.assertEqual(new_simple.field, 2)
-        new_simple, created = await self.mtm1.simples.aupdate_or_create(
+        new_simple1, created = await self.mtm1.simples.aupdate_or_create(
             id=new_simple.id, defaults={"field": 3}
         )
         self.assertIs(created, False)
-        self.assertEqual(await self.mtm1.simples.acount(), 1)
-        self.assertEqual(new_simple.field, 3)
+        self.assertEqual(new_simple1.field, 3)
+
+        new_simple2, created = await self.mtm1.simples.aupdate_or_create(
+            field=4, defaults={"field": 6}, create_defaults={"field": 5}
+        )
+        self.assertIs(created, True)
+        self.assertEqual(new_simple2.field, 5)
+        self.assertEqual(await self.mtm1.simples.acount(), 2)
 
     async def test_aupdate_or_create_reverse(self):
         new_relatedmodel, created = await self.s1.relatedmodel_set.aupdate_or_create()

+ 60 - 0
tests/generic_relations/tests.py

@@ -59,6 +59,19 @@ class GenericRelationsTests(TestCase):
         self.assertTrue(created)
         self.assertEqual(count + 1, self.bacon.tags.count())
 
+    def test_generic_update_or_create_when_created_with_create_defaults(self):
+        count = self.bacon.tags.count()
+        tag, created = self.bacon.tags.update_or_create(
+            # Since, the "stinky" tag doesn't exist create
+            # a "juicy" tag.
+            create_defaults={"tag": "juicy"},
+            defaults={"tag": "uncured"},
+            tag="stinky",
+        )
+        self.assertEqual(tag.tag, "juicy")
+        self.assertIs(created, True)
+        self.assertEqual(count + 1, self.bacon.tags.count())
+
     def test_generic_update_or_create_when_updated(self):
         """
         Should be able to use update_or_create from the generic related manager
@@ -74,6 +87,17 @@ class GenericRelationsTests(TestCase):
         self.assertEqual(count + 1, self.bacon.tags.count())
         self.assertEqual(tag.tag, "juicy")
 
+    def test_generic_update_or_create_when_updated_with_defaults(self):
+        count = self.bacon.tags.count()
+        tag = self.bacon.tags.create(tag="stinky")
+        self.assertEqual(count + 1, self.bacon.tags.count())
+        tag, created = self.bacon.tags.update_or_create(
+            create_defaults={"tag": "uncured"}, defaults={"tag": "juicy"}, id=tag.id
+        )
+        self.assertIs(created, False)
+        self.assertEqual(count + 1, self.bacon.tags.count())
+        self.assertEqual(tag.tag, "juicy")
+
     async def test_generic_async_aupdate_or_create(self):
         tag, created = await self.bacon.tags.aupdate_or_create(
             id=self.fatty.id, defaults={"tag": "orange"}
@@ -86,6 +110,22 @@ class GenericRelationsTests(TestCase):
         self.assertEqual(await self.bacon.tags.acount(), 3)
         self.assertEqual(tag.tag, "pink")
 
+    async def test_generic_async_aupdate_or_create_with_create_defaults(self):
+        tag, created = await self.bacon.tags.aupdate_or_create(
+            id=self.fatty.id,
+            create_defaults={"tag": "pink"},
+            defaults={"tag": "orange"},
+        )
+        self.assertIs(created, False)
+        self.assertEqual(tag.tag, "orange")
+        self.assertEqual(await self.bacon.tags.acount(), 2)
+        tag, created = await self.bacon.tags.aupdate_or_create(
+            tag="pink", create_defaults={"tag": "brown"}
+        )
+        self.assertIs(created, True)
+        self.assertEqual(await self.bacon.tags.acount(), 3)
+        self.assertEqual(tag.tag, "brown")
+
     def test_generic_get_or_create_when_created(self):
         """
         Should be able to use get_or_create from the generic related manager
@@ -550,6 +590,26 @@ class GenericRelationsTests(TestCase):
         self.assertFalse(created)
         self.assertEqual(tag.content_object.id, diamond.id)
 
+    def test_update_or_create_defaults_with_create_defaults(self):
+        # update_or_create() should work with virtual fields (content_object).
+        quartz = Mineral.objects.create(name="Quartz", hardness=7)
+        diamond = Mineral.objects.create(name="Diamond", hardness=7)
+        tag, created = TaggedItem.objects.update_or_create(
+            tag="shiny",
+            create_defaults={"content_object": quartz},
+            defaults={"content_object": diamond},
+        )
+        self.assertIs(created, True)
+        self.assertEqual(tag.content_object.id, quartz.id)
+
+        tag, created = TaggedItem.objects.update_or_create(
+            tag="shiny",
+            create_defaults={"content_object": quartz},
+            defaults={"content_object": diamond},
+        )
+        self.assertIs(created, False)
+        self.assertEqual(tag.content_object.id, diamond.id)
+
     def test_query_content_type(self):
         msg = "Field 'content_object' does not generate an automatic reverse relation"
         with self.assertRaisesMessage(FieldError, msg):

+ 1 - 0
tests/get_or_create/models.py

@@ -6,6 +6,7 @@ class Person(models.Model):
     last_name = models.CharField(max_length=100)
     birthday = models.DateField()
     defaults = models.TextField()
+    create_defaults = models.TextField()
 
 
 class DefaultPerson(models.Model):

+ 94 - 11
tests/get_or_create/tests.py

@@ -330,15 +330,24 @@ class UpdateOrCreateTests(TestCase):
         self.assertEqual(p.birthday, date(1940, 10, 10))
 
     def test_create_twice(self):
-        params = {
-            "first_name": "John",
-            "last_name": "Lennon",
-            "birthday": date(1940, 10, 10),
-        }
-        Person.objects.update_or_create(**params)
-        # If we execute the exact same statement, it won't create a Person.
-        p, created = Person.objects.update_or_create(**params)
-        self.assertFalse(created)
+        p, created = Person.objects.update_or_create(
+            first_name="John",
+            last_name="Lennon",
+            create_defaults={"birthday": date(1940, 10, 10)},
+            defaults={"birthday": date(1950, 2, 2)},
+        )
+        self.assertIs(created, True)
+        self.assertEqual(p.birthday, date(1940, 10, 10))
+        # If we execute the exact same statement, it won't create a Person, but
+        # will update the birthday.
+        p, created = Person.objects.update_or_create(
+            first_name="John",
+            last_name="Lennon",
+            create_defaults={"birthday": date(1940, 10, 10)},
+            defaults={"birthday": date(1950, 2, 2)},
+        )
+        self.assertIs(created, False)
+        self.assertEqual(p.birthday, date(1950, 2, 2))
 
     def test_integrity(self):
         """
@@ -391,8 +400,14 @@ class UpdateOrCreateTests(TestCase):
         """
         p = Publisher.objects.create(name="Acme Publishing")
         book, created = p.books.update_or_create(name="The Book of Ed & Fred")
-        self.assertTrue(created)
+        self.assertIs(created, True)
         self.assertEqual(p.books.count(), 1)
+        book, created = p.books.update_or_create(
+            name="Basics of Django", create_defaults={"name": "Advanced Django"}
+        )
+        self.assertIs(created, True)
+        self.assertEqual(book.name, "Advanced Django")
+        self.assertEqual(p.books.count(), 2)
 
     def test_update_with_related_manager(self):
         """
@@ -406,6 +421,14 @@ class UpdateOrCreateTests(TestCase):
         book, created = p.books.update_or_create(defaults={"name": name}, id=book.id)
         self.assertFalse(created)
         self.assertEqual(book.name, name)
+        # create_defaults should be ignored.
+        book, created = p.books.update_or_create(
+            create_defaults={"name": "Basics of Django"},
+            defaults={"name": name},
+            id=book.id,
+        )
+        self.assertIs(created, False)
+        self.assertEqual(book.name, name)
         self.assertEqual(p.books.count(), 1)
 
     def test_create_with_many(self):
@@ -418,8 +441,16 @@ class UpdateOrCreateTests(TestCase):
         book, created = author.books.update_or_create(
             name="The Book of Ed & Fred", publisher=p
         )
-        self.assertTrue(created)
+        self.assertIs(created, True)
         self.assertEqual(author.books.count(), 1)
+        book, created = author.books.update_or_create(
+            name="Basics of Django",
+            publisher=p,
+            create_defaults={"name": "Advanced Django"},
+        )
+        self.assertIs(created, True)
+        self.assertEqual(book.name, "Advanced Django")
+        self.assertEqual(author.books.count(), 2)
 
     def test_update_with_many(self):
         """
@@ -437,6 +468,14 @@ class UpdateOrCreateTests(TestCase):
         )
         self.assertFalse(created)
         self.assertEqual(book.name, name)
+        # create_defaults should be ignored.
+        book, created = author.books.update_or_create(
+            create_defaults={"name": "Basics of Django"},
+            defaults={"name": name},
+            id=book.id,
+        )
+        self.assertIs(created, False)
+        self.assertEqual(book.name, name)
         self.assertEqual(author.books.count(), 1)
 
     def test_defaults_exact(self):
@@ -467,6 +506,34 @@ class UpdateOrCreateTests(TestCase):
         self.assertFalse(created)
         self.assertEqual(obj.defaults, "another testing")
 
+    def test_create_defaults_exact(self):
+        """
+        If you have a field named create_defaults and want to use it as an
+        exact lookup, you need to use 'create_defaults__exact'.
+        """
+        obj, created = Person.objects.update_or_create(
+            first_name="George",
+            last_name="Harrison",
+            create_defaults__exact="testing",
+            create_defaults={
+                "birthday": date(1943, 2, 25),
+                "create_defaults": "testing",
+            },
+        )
+        self.assertIs(created, True)
+        self.assertEqual(obj.create_defaults, "testing")
+        obj, created = Person.objects.update_or_create(
+            first_name="George",
+            last_name="Harrison",
+            create_defaults__exact="testing",
+            create_defaults={
+                "birthday": date(1943, 2, 25),
+                "create_defaults": "another testing",
+            },
+        )
+        self.assertIs(created, False)
+        self.assertEqual(obj.create_defaults, "testing")
+
     def test_create_callable_default(self):
         obj, created = Person.objects.update_or_create(
             first_name="George",
@@ -476,6 +543,16 @@ class UpdateOrCreateTests(TestCase):
         self.assertIs(created, True)
         self.assertEqual(obj.birthday, date(1943, 2, 25))
 
+    def test_create_callable_create_defaults(self):
+        obj, created = Person.objects.update_or_create(
+            first_name="George",
+            last_name="Harrison",
+            defaults={},
+            create_defaults={"birthday": lambda: date(1943, 2, 25)},
+        )
+        self.assertIs(created, True)
+        self.assertEqual(obj.birthday, date(1943, 2, 25))
+
     def test_update_callable_default(self):
         Person.objects.update_or_create(
             first_name="George",
@@ -694,6 +771,12 @@ class InvalidCreateArgumentsTests(TransactionTestCase):
         with self.assertRaisesMessage(FieldError, self.msg):
             Thing.objects.update_or_create(name="a", defaults={"nonexistent": "b"})
 
+    def test_update_or_create_with_invalid_create_defaults(self):
+        with self.assertRaisesMessage(FieldError, self.msg):
+            Thing.objects.update_or_create(
+                name="a", create_defaults={"nonexistent": "b"}
+            )
+
     def test_update_or_create_with_invalid_kwargs(self):
         with self.assertRaisesMessage(FieldError, self.bad_field_msg):
             Thing.objects.update_or_create(name="a", nonexistent="b")