Browse Source

Fixed #4102 -- Allow update of specific fields in model.save()

Added the ability to update only part of the model's fields in
model.save() by introducing a new kwarg "update_fields". Thanks
to all the numerous reviewers and commenters in the ticket
Andrei Antoukh 13 năm trước cách đây
mục cha
commit
365853da01

+ 41 - 14
django/db/models/base.py

@@ -11,7 +11,7 @@ from django.core import validators
 from django.db.models.fields import AutoField, FieldDoesNotExist
 from django.db.models.fields.related import (ManyToOneRel,
     OneToOneField, add_lazy_relation)
-from django.db import (connections, router, transaction, DatabaseError,
+from django.db import (router, transaction, DatabaseError,
     DEFAULT_DB_ALIAS)
 from django.db.models.query import Q
 from django.db.models.query_utils import DeferredAttribute
@@ -449,7 +449,8 @@ class Model(object):
             return getattr(self, field_name)
         return getattr(self, field.attname)
 
-    def save(self, force_insert=False, force_update=False, using=None):
+    def save(self, force_insert=False, force_update=False, using=None,
+             update_fields=None):
         """
         Saves the current instance. Override this in a subclass if you want to
         control the saving process.
@@ -458,14 +459,32 @@ class Model(object):
         that the "save" must be an SQL insert or update (or equivalent for
         non-SQL backends), respectively. Normally, they should not be set.
         """
-        if force_insert and force_update:
+        if force_insert and (force_update or update_fields):
             raise ValueError("Cannot force both insert and updating in model saving.")
-        self.save_base(using=using, force_insert=force_insert, force_update=force_update)
 
+        if update_fields is not None:
+            # If update_fields is empty, skip the save. We do also check for
+            # no-op saves later on for inheritance cases. This bailout is
+            # still needed for skipping signal sending.
+            if len(update_fields) == 0:
+                return
+
+            update_fields = frozenset(update_fields)
+            field_names = set([field.name for field in self._meta.fields
+                               if not field.primary_key])
+            non_model_fields = update_fields.difference(field_names)
+
+            if non_model_fields:
+                raise ValueError("The following fields do not exist in this "
+                                 "model or are m2m fields: %s"
+                                 % ', '.join(non_model_fields))
+
+        self.save_base(using=using, force_insert=force_insert,
+                       force_update=force_update, update_fields=update_fields)
     save.alters_data = True
 
     def save_base(self, raw=False, cls=None, origin=None, force_insert=False,
-            force_update=False, using=None):
+                  force_update=False, using=None, update_fields=None):
         """
         Does the heavy-lifting involved in saving. Subclasses shouldn't need to
         override this method. It's separate from save() in order to hide the
@@ -473,7 +492,8 @@ class Model(object):
         ('raw', 'cls', and 'origin').
         """
         using = using or router.db_for_write(self.__class__, instance=self)
-        assert not (force_insert and force_update)
+        assert not (force_insert and (force_update or update_fields))
+        assert update_fields is None or len(update_fields) > 0
         if cls is None:
             cls = self.__class__
             meta = cls._meta
@@ -483,7 +503,8 @@ class Model(object):
             meta = cls._meta
 
         if origin and not meta.auto_created:
-            signals.pre_save.send(sender=origin, instance=self, raw=raw, using=using)
+            signals.pre_save.send(sender=origin, instance=self, raw=raw, using=using,
+                                  update_fields=update_fields)
 
         # If we are in a raw save, save the object exactly as presented.
         # That means that we don't try to be smart about saving attributes
@@ -503,7 +524,8 @@ class Model(object):
                 if field and getattr(self, parent._meta.pk.attname) is None and getattr(self, field.attname) is not None:
                     setattr(self, parent._meta.pk.attname, getattr(self, field.attname))
 
-                self.save_base(cls=parent, origin=org, using=using)
+                self.save_base(cls=parent, origin=org, using=using,
+                               update_fields=update_fields)
 
                 if field:
                     setattr(self, field.attname, self._get_pk_val(parent._meta))
@@ -513,22 +535,27 @@ class Model(object):
         if not meta.proxy:
             non_pks = [f for f in meta.local_fields if not f.primary_key]
 
+            if update_fields:
+                non_pks = [f for f in non_pks if f.name in update_fields]
+
             # First, try an UPDATE. If that doesn't update anything, do an INSERT.
             pk_val = self._get_pk_val(meta)
             pk_set = pk_val is not None
             record_exists = True
             manager = cls._base_manager
             if pk_set:
-                # Determine whether a record with the primary key already exists.
-                if (force_update or (not force_insert and
+                # Determine if we should do an update (pk already exists, forced update,
+                # no force_insert)
+                if ((force_update or update_fields) or (not force_insert and
                         manager.using(using).filter(pk=pk_val).exists())):
-                    # It does already exist, so do an UPDATE.
                     if force_update or non_pks:
                         values = [(f, None, (raw and getattr(self, f.attname) or f.pre_save(self, False))) for f in non_pks]
                         if values:
                             rows = manager.using(using).filter(pk=pk_val)._update(values)
                             if force_update and not rows:
                                 raise DatabaseError("Forced update did not affect any rows.")
+                            if update_fields and not rows:
+                                raise DatabaseError("Save with update_fields did not affect any rows.")
                 else:
                     record_exists = False
             if not pk_set or not record_exists:
@@ -541,7 +568,7 @@ class Model(object):
 
                 fields = meta.local_fields
                 if not pk_set:
-                    if force_update:
+                    if force_update or update_fields:
                         raise ValueError("Cannot force an update in save() with no primary key.")
                     fields = [f for f in fields if not isinstance(f, AutoField)]
 
@@ -561,8 +588,8 @@ class Model(object):
 
         # Signal that the save is complete
         if origin and not meta.auto_created:
-            signals.post_save.send(sender=origin, instance=self,
-                created=(not record_exists), raw=raw, using=using)
+            signals.post_save.send(sender=origin, instance=self, created=(not record_exists),
+                                   update_fields=update_fields, raw=raw, using=using)
 
 
     save_base.alters_data = True

+ 2 - 2
django/db/models/signals.py

@@ -5,8 +5,8 @@ class_prepared = Signal(providing_args=["class"])
 pre_init = Signal(providing_args=["instance", "args", "kwargs"])
 post_init = Signal(providing_args=["instance"])
 
-pre_save = Signal(providing_args=["instance", "raw", "using"])
-post_save = Signal(providing_args=["instance", "raw", "created", "using"])
+pre_save = Signal(providing_args=["instance", "raw", "using", "update_fields"])
+post_save = Signal(providing_args=["instance", "raw", "created", "using", "update_fields"])
 
 pre_delete = Signal(providing_args=["instance", "using"])
 post_delete = Signal(providing_args=["instance", "using"])

+ 23 - 1
docs/ref/models/instances.txt

@@ -135,7 +135,7 @@ Saving objects
 
 To save an object back to the database, call ``save()``:
 
-.. method:: Model.save([force_insert=False, force_update=False, using=DEFAULT_DB_ALIAS])
+.. method:: Model.save([force_insert=False, force_update=False, using=DEFAULT_DB_ALIAS, update_fields=None])
 
 .. versionadded:: 1.2
    The ``using`` argument was added.
@@ -289,6 +289,8 @@ almost always do the right thing and trying to override that will lead to
 errors that are difficult to track down. This feature is for advanced use
 only.
 
+Using ``update_fields`` will force an update similarly to ``force_update``.
+
 Updating attributes based on existing fields
 --------------------------------------------
 
@@ -334,6 +336,26 @@ For more details, see the documentation on :ref:`F() expressions
 <query-expressions>` and their :ref:`use in update queries
 <topics-db-queries-update>`.
 
+Specifying which fields to save
+-------------------------------
+
+.. versionadded:: 1.5
+
+If ``save()`` is passed a list of field names in keyword argument
+``update_fields``, only the fields named in that list will be updated.
+This may be desirable if you want to update just one or a few fields on
+an object. There will be a slight performance benefit from preventing
+all of the model fields from being updated in the database. For example:
+
+    product.name = 'Name changed again'
+    product.save(update_fields=['name'])
+
+The ``update_fields`` argument can be any iterable containing strings. An
+empty ``update_fields`` iterable will skip the save. A value of None will
+perform an update on all fields.
+
+Specifying ``update_fields`` will force an update.
+
 Deleting objects
 ================
 

+ 12 - 0
docs/ref/signals.txt

@@ -123,6 +123,12 @@ Arguments sent with this signal:
 ``using``
     The database alias being used.
 
+.. versionadded:: 1.5
+
+``update_fields``
+    The set of fields to update explicitly specified in the ``save()`` method.
+    ``None`` if this argument was not used in the ``save()`` call.
+
 post_save
 ---------
 
@@ -154,6 +160,12 @@ Arguments sent with this signal:
 ``using``
     The database alias being used.
 
+.. versionadded:: 1.5
+
+``update_fields``
+    The set of fields to update explicitly specified in the ``save()`` method.
+    ``None`` if this argument was not used in the ``save()`` call.
+
 pre_delete
 ----------
 

+ 11 - 0
docs/releases/1.5.txt

@@ -33,6 +33,17 @@ version compatible with Python 2.6.
 What's new in Django 1.5
 ========================
 
+Support for saving a subset of model's fields
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The method :meth:`Model.save() <django.db.models.Model.save()>` has a new
+keyword argument ``update_fields``. By using this argument it is possible to
+save only a select list of model's fields. This can be useful for performance
+reasons or when trying to avoid overwriting concurrent changes.
+
+See the :meth:`Model.save() <django.db.models.Model.save()>` documentation for
+more details.
+
 Minor features
 ~~~~~~~~~~~~~~
 

+ 0 - 0
tests/modeltests/update_only_fields/__init__.py


+ 37 - 0
tests/modeltests/update_only_fields/models.py

@@ -0,0 +1,37 @@
+
+from django.db import models
+
+GENDER_CHOICES = (
+    ('M', 'Male'),
+    ('F', 'Female'),
+)
+
+class Account(models.Model):
+    num = models.IntegerField()
+
+
+class Person(models.Model):
+    name = models.CharField(max_length=20)
+    gender = models.CharField(max_length=1, choices=GENDER_CHOICES)
+
+    def __unicode__(self):
+        return self.name
+
+
+class Employee(Person):
+    employee_num = models.IntegerField(default=0)
+    profile = models.ForeignKey('Profile', related_name='profiles', null=True)
+    accounts = models.ManyToManyField('Account', related_name='employees', blank=True, null=True)
+
+
+class Profile(models.Model):
+    name = models.CharField(max_length=200)
+    salary = models.FloatField(default=1000.0)
+
+    def __unicode__(self):
+        return self.name
+
+
+class ProxyEmployee(Employee):
+    class Meta:
+        proxy = True

+ 146 - 0
tests/modeltests/update_only_fields/tests.py

@@ -0,0 +1,146 @@
+from __future__ import absolute_import
+
+from django.test import TestCase
+from django.db.models.signals import pre_save, post_save
+from .models import Person, Employee, ProxyEmployee, Profile, Account
+
+
+class UpdateOnlyFieldsTests(TestCase):
+    def test_update_fields_basic(self):
+        s = Person.objects.create(name='Sara', gender='F')
+        self.assertEqual(s.gender, 'F')
+
+        s.gender = 'M'
+        s.name = 'Ian'
+        s.save(update_fields=['name'])
+
+        s = Person.objects.get(pk=s.pk)
+        self.assertEqual(s.gender, 'F')
+        self.assertEqual(s.name, 'Ian')
+
+    def test_update_fields_m2m(self):
+        profile_boss = Profile.objects.create(name='Boss', salary=3000)
+        e1 = Employee.objects.create(name='Sara', gender='F',
+            employee_num=1, profile=profile_boss)
+
+        a1 = Account.objects.create(num=1)
+        a2 = Account.objects.create(num=2)
+
+        e1.accounts = [a1,a2]
+
+        with self.assertRaises(ValueError):
+            e1.save(update_fields=['accounts'])
+
+    def test_update_fields_inheritance(self):
+        profile_boss = Profile.objects.create(name='Boss', salary=3000)
+        profile_receptionist = Profile.objects.create(name='Receptionist', salary=1000)
+
+        e1 = Employee.objects.create(name='Sara', gender='F',
+            employee_num=1, profile=profile_boss)
+
+        e1.name = 'Ian'
+        e1.gender = 'M'
+        e1.save(update_fields=['name'])
+
+        e2 = Employee.objects.get(pk=e1.pk)
+        self.assertEqual(e2.name, 'Ian')
+        self.assertEqual(e2.gender, 'F')
+        self.assertEqual(e2.profile, profile_boss)
+
+        e2.profile = profile_receptionist
+        e2.name = 'Sara'
+        e2.save(update_fields=['profile'])
+
+        e3 = Employee.objects.get(pk=e1.pk)
+        self.assertEqual(e3.name, 'Ian')
+        self.assertEqual(e3.profile, profile_receptionist)
+
+    def test_update_fields_inheritance_with_proxy_model(self):
+        profile_boss = Profile.objects.create(name='Boss', salary=3000)
+        profile_receptionist = Profile.objects.create(name='Receptionist', salary=1000)
+
+        e1 = ProxyEmployee.objects.create(name='Sara', gender='F',
+            employee_num=1, profile=profile_boss)
+
+        e1.name = 'Ian'
+        e1.gender = 'M'
+        e1.save(update_fields=['name'])
+
+        e2 = ProxyEmployee.objects.get(pk=e1.pk)
+        self.assertEqual(e2.name, 'Ian')
+        self.assertEqual(e2.gender, 'F')
+        self.assertEqual(e2.profile, profile_boss)
+
+        e2.profile = profile_receptionist
+        e2.name = 'Sara'
+        e2.save(update_fields=['profile'])
+
+        e3 = ProxyEmployee.objects.get(pk=e1.pk)
+        self.assertEqual(e3.name, 'Ian')
+        self.assertEqual(e3.profile, profile_receptionist)
+
+    def test_update_fields_signals(self):
+        p = Person.objects.create(name='Sara', gender='F')
+        pre_save_data = []
+        def pre_save_receiver(**kwargs):
+            pre_save_data.append(kwargs['update_fields'])
+        pre_save.connect(pre_save_receiver)
+        post_save_data = []
+        def post_save_receiver(**kwargs):
+            post_save_data.append(kwargs['update_fields'])
+        post_save.connect(post_save_receiver)
+        p.save(update_fields=['name'])
+        self.assertEqual(len(pre_save_data), 1)
+        self.assertEqual(len(pre_save_data[0]), 1)
+        self.assertTrue('name' in pre_save_data[0])
+        self.assertEqual(len(post_save_data), 1)
+        self.assertEqual(len(post_save_data[0]), 1)
+        self.assertTrue('name' in post_save_data[0])
+
+    def test_update_fields_incorrect_params(self):
+        s = Person.objects.create(name='Sara', gender='F')
+
+        with self.assertRaises(ValueError):
+            s.save(update_fields=['first_name'])
+
+        with self.assertRaises(ValueError):
+            s.save(update_fields="name")
+
+    def test_empty_update_fields(self):
+        s = Person.objects.create(name='Sara', gender='F')
+        pre_save_data = []
+        def pre_save_receiver(**kwargs):
+            pre_save_data.append(kwargs['update_fields'])
+        pre_save.connect(pre_save_receiver)
+        post_save_data = []
+        def post_save_receiver(**kwargs):
+            post_save_data.append(kwargs['update_fields'])
+        post_save.connect(post_save_receiver)
+        # Save is skipped.
+        with self.assertNumQueries(0):
+            s.save(update_fields=[])
+        # Signals were skipped, too...
+        self.assertEqual(len(pre_save_data), 0)
+        self.assertEqual(len(post_save_data), 0)
+
+    def test_num_queries_inheritance(self):
+        s = Employee.objects.create(name='Sara', gender='F')
+        s.employee_num = 1
+        s.name = 'Emily'
+        with self.assertNumQueries(1):
+            s.save(update_fields=['employee_num'])
+        s = Employee.objects.get(pk=s.pk)
+        self.assertEqual(s.employee_num, 1)
+        self.assertEqual(s.name, 'Sara')
+        s.employee_num = 2
+        s.name = 'Emily'
+        with self.assertNumQueries(1):
+            s.save(update_fields=['name'])
+        s = Employee.objects.get(pk=s.pk)
+        self.assertEqual(s.name, 'Emily')
+        self.assertEqual(s.employee_num, 1)
+        # A little sanity check that we actually did updates...
+        self.assertEqual(Employee.objects.count(), 1)
+        self.assertEqual(Person.objects.count(), 1)
+        with self.assertNumQueries(2):
+            s.save(update_fields=['name', 'employee_num'])