Browse Source

Simplify Page.copy() (#6277)

* Use Django modelcluster's copy_all_child_relations method

* page.specific.__class__ => page.specific_class

* Use child_object_map as returned by modelcluster for revision rewriting

* Use modelcluster to commit child relations

* Use a callback instead of a method for _save_copy_instance

* Make CopyMixin work on non-MTI models

* Make gathering exclude_fields the job of the callee

._copy() no longer depends on any custom attributes in the base class!

* Converted CopyMixin into some utility methods (and renamed some stuff)

* Don't commit the new page in _copy

* Refactor _copy_m2m_relations to be more standalone

* Merge _make_copy into _copy

Not really useful outside _copy

* Give unused variable a name

* Version-bump django-modelcluster to 5.1

* Address review feedback

Co-authored-by: Matt Westcott <matt@west.co.tt>
Karl Hobley 4 years ago
parent
commit
519c0c332d
3 changed files with 83 additions and 128 deletions
  1. 1 1
      setup.py
  2. 0 2
      wagtail/core/migrations/0047_add_workflow_models.py
  3. 82 125
      wagtail/core/models.py

+ 1 - 1
setup.py

@@ -22,7 +22,7 @@ except ImportError:
 
 install_requires = [
     "Django>=2.2,<3.2",
-    "django-modelcluster>=5.0.2,<6.0",
+    "django-modelcluster>=5.1,<6.0",
     "django-taggit>=1.0,<2.0",
     "django-treebeard>=4.2.0,<5.0",
     "djangorestframework>=3.11.1,<4.0",

+ 0 - 2
wagtail/core/migrations/0047_add_workflow_models.py

@@ -4,7 +4,6 @@ from django.conf import settings
 from django.db import migrations, models
 import django.db.models.deletion
 import modelcluster.fields
-import wagtail.core.models
 
 
 class Migration(migrations.Migration):
@@ -45,7 +44,6 @@ class Migration(migrations.Migration):
                 'verbose_name': 'Task state',
                 'verbose_name_plural': 'Task states',
             },
-            bases=(wagtail.core.models.MultiTableCopyMixin, models.Model),
         ),
         migrations.CreateModel(
             name='Workflow',

+ 82 - 125
wagtail/core/models.py

@@ -1,6 +1,5 @@
 import json
 import logging
-from collections import defaultdict
 from io import StringIO
 from urllib.parse import urlparse
 
@@ -51,131 +50,82 @@ logger = logging.getLogger('wagtail.core')
 PAGE_TEMPLATE_VAR = 'page'
 
 
-class MultiTableCopyMixin:
-    default_exclude_fields_in_copy = ['id']
-
-    def _get_field_dictionaries(self, exclude_fields=None):
-        """Get dictionaries representing the model: one with all non m2m fields, and one containing the m2m fields"""
-        specific_self = self.specific
-        exclude_fields = exclude_fields or []
-        specific_dict = {}
-        specific_m2m_dict = {}
-
-        for field in specific_self._meta.get_fields():
-            # Ignore explicitly excluded fields
-            if field.name in exclude_fields:
-                continue
+def _extract_field_data(source, exclude_fields=None):
+    """
+    Get dictionaries representing the model's field data.
 
-            # Ignore reverse relations
-            if field.auto_created:
-                continue
+    This excludes many to many fields (which are handled by _copy_m2m_relations)'
+    """
+    exclude_fields = exclude_fields or []
+    data_dict = {}
 
-            # Copy parental m2m relations
-            # Otherwise add them to the m2m dict to be set after saving
-            if field.many_to_many:
-                if isinstance(field, ParentalManyToManyField):
-                    parental_field = getattr(specific_self, field.name)
-                    if hasattr(parental_field, 'all'):
-                        values = parental_field.all()
-                        if values:
-                            specific_dict[field.name] = values
-                else:
-                    try:
-                        # Do not copy m2m links with a through model that has a ParentalKey to the model being copied - these will be copied as child objects
-                        through_model_parental_links = [field for field in field.through._meta.get_fields() if isinstance(field, ParentalKey) and (field.related_model == specific_self.__class__ or field.related_model in specific_self._meta.parents)]
-                        if through_model_parental_links:
-                            continue
-                    except AttributeError:
-                        pass
-                    specific_m2m_dict[field.name] = getattr(specific_self, field.name).all()
-                continue
+    for field in source._meta.get_fields():
+        # Ignore explicitly excluded fields
+        if field.name in exclude_fields:
+            continue
 
-            # Ignore parent links (page_ptr)
-            if isinstance(field, models.OneToOneField) and field.remote_field.parent_link:
-                continue
+        # Ignore reverse relations
+        if field.auto_created:
+            continue
 
-            specific_dict[field.name] = getattr(specific_self, field.name)
+        # Copy parental m2m relations
+        if field.many_to_many:
+            if isinstance(field, ParentalManyToManyField):
+                parental_field = getattr(source, field.name)
+                if hasattr(parental_field, 'all'):
+                    values = parental_field.all()
+                    if values:
+                        data_dict[field.name] = values
+            continue
 
-        return specific_dict, specific_m2m_dict
+        # Ignore parent links (page_ptr)
+        if isinstance(field, models.OneToOneField) and field.remote_field.parent_link:
+            continue
 
-    def _get_copy_instance(self, specific_dict, specific_m2m_dict, update_attrs=None):
-        """Create a copy instance (without saving) from dictionaries of the model's fields, and update any attributes in update_attrs"""
+        data_dict[field.name] = getattr(source, field.name)
 
-        if not update_attrs:
-            update_attrs = {}
+    return data_dict
 
-        specific_class = self.specific.__class__
 
-        copy_instance = specific_class(**specific_dict)
+def _copy_m2m_relations(source, target, exclude_fields=None, update_attrs=None):
+    """
+    Copies non-ParentalManyToMany m2m relations
+    """
+    update_attrs = update_attrs or {}
+    exclude_fields = exclude_fields or []
 
-        if update_attrs:
-            for field, value in update_attrs.items():
-                if field in specific_m2m_dict:
+    for field in source._meta.get_fields():
+        # Copy m2m relations. Ignore explicitly excluded fields, reverse relations, and Parental m2m fields.
+        if field.many_to_many and field.name not in exclude_fields and not field.auto_created and not isinstance(field, ParentalManyToManyField):
+            try:
+                # Do not copy m2m links with a through model that has a ParentalKey to the model being copied - these will be copied as child objects
+                through_model_parental_links = [field for field in field.through._meta.get_fields() if isinstance(field, ParentalKey) and (field.related_model == source.__class__ or field.related_model in source._meta.parents)]
+                if through_model_parental_links:
                     continue
-                setattr(copy_instance, field, value)
-
-        return copy_instance
+            except AttributeError:
+                pass
 
-    def _save_copy_instance(self, instance, **kwargs):
-        raise NotImplementedError
+            if field.name in update_attrs:
+                value = update_attrs[field.name]
 
-    def _set_m2m_relations(self, instance, specific_m2m_dict, update_attrs=None):
-        """Set non-ParentalManyToMany m2m relations"""
-        if not update_attrs:
-            update_attrs = {}
-        for field_name, value in specific_m2m_dict.items():
-            value = update_attrs.get(field_name, value)
-            getattr(instance, field_name).set(value)
+            else:
+                value = getattr(source, field.name).all()
 
-        return instance
+            getattr(target, field.name).set(value)
 
-    def _copy_child_objects_to_instance(self, instance, exclude_fields=None, process_child_object=None):
-        """Copy objects linked to the model by a ParentalKey, and set this to the new revision"""
 
-        # A dict that maps child objects to their new ids
-        # Used to remap child object ids in revisions
-        child_object_id_map = defaultdict(dict)
-        exclude_fields = exclude_fields or []
-        specific_self = self.specific
-        for child_relation in get_all_child_relations(specific_self):
-            accessor_name = child_relation.get_accessor_name()
+def _copy(source, exclude_fields=None, update_attrs=None):
+    data_dict = _extract_field_data(source, exclude_fields=exclude_fields)
+    target = source.__class__(**data_dict)
 
-            if accessor_name in exclude_fields:
+    if update_attrs:
+        for field, value in update_attrs.items():
+            if field not in data_dict:
                 continue
+            setattr(target, field, value)
 
-            parental_key_name = child_relation.field.attname
-            child_objects = getattr(specific_self, accessor_name, None)
-
-            if child_objects:
-                for child_object in child_objects.all():
-                    old_pk = child_object.pk
-                    child_object.pk = None
-                    setattr(child_object, parental_key_name, instance.id)
-
-                    if process_child_object is not None:
-                        process_child_object(specific_self, instance, child_relation, child_object)
-
-                    child_object.save()
-
-                    # Add mapping to new primary key (so we can apply this change to revisions)
-                    child_object_id_map[accessor_name][old_pk] = child_object.pk
-
-        return child_object_id_map
-
-    def _copy(self, exclude_fields=None, update_attrs=None, process_child_object=None, **kwargs):
-        exclude_fields = self.default_exclude_fields_in_copy + self.specific.exclude_fields_in_copy + (exclude_fields or [])
-
-        specific_dict, specific_m2m_dict = self._get_field_dictionaries(exclude_fields=exclude_fields)
-
-        copy_instance = self._get_copy_instance(specific_dict, specific_m2m_dict, update_attrs=update_attrs)
-
-        copy_instance = self._save_copy_instance(copy_instance, **kwargs)
-
-        copy_instance = self._set_m2m_relations(copy_instance, specific_m2m_dict, update_attrs)
-
-        child_object_id_map = self._copy_child_objects_to_instance(copy_instance, exclude_fields=exclude_fields, process_child_object=process_child_object)
-
-        return copy_instance, child_object_id_map
+    child_object_map = source.copy_all_child_relations(target, exclude=exclude_fields)
+    return target, child_object_map
 
 
 class SiteManager(models.Manager):
@@ -388,7 +338,7 @@ class AbstractPage(TreebeardPathFixMixin, MP_Node):
         abstract = True
 
 
-class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase):
+class Page(AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase):
     title = models.CharField(
         verbose_name=_('title'),
         max_length=255,
@@ -1428,7 +1378,7 @@ class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, m
         :param log_action flag for logging the action. Pass None to skip logging.
             Can be passed an action string. Defaults to 'wagtail.copy'
         """
-
+        exclude_fields = self.default_exclude_fields_in_copy + self.exclude_fields_in_copy + (exclude_fields or [])
         specific_self = self.specific
         if keep_live:
             base_update_attrs = {}
@@ -1447,7 +1397,22 @@ class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, m
         if update_attrs:
             base_update_attrs.update(update_attrs)
 
-        page_copy, child_object_id_map = self._copy(exclude_fields=exclude_fields, update_attrs=base_update_attrs, to=to, recursive=recursive, process_child_object=process_child_object)
+        page_copy, child_object_map = _copy(specific_self, exclude_fields=exclude_fields, update_attrs=base_update_attrs)
+
+        # Save copied child objects and run process_child_object on them if we need to
+        for (child_relation, old_pk), child_object in child_object_map.items():
+            if process_child_object:
+                process_child_object(specific_self, page_copy, child_relation, child_object)
+
+        # Save the new page
+        if to:
+            if recursive and (to == self or to.is_descendant_of(self)):
+                raise Exception("You cannot copy a tree branch recursively into itself")
+            page_copy = to.add_child(instance=page_copy)
+        else:
+            page_copy = self.add_sibling(instance=page_copy)
+
+        _copy_m2m_relations(specific_self, page_copy, exclude_fields=exclude_fields, update_attrs=base_update_attrs)
 
         # Copy revisions
         if copy_revisions:
@@ -1476,7 +1441,8 @@ class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, m
                         # Remap primary key to copied versions
                         # If the primary key is not recognised (eg, the child object has been deleted from the database)
                         # set the primary key to None
-                        child_object['pk'] = child_object_id_map[accessor_name].get(child_object['pk'], None)
+                        copied_child_object = child_object_map.get((child_relation, child_object['pk']))
+                        child_object['pk'] = copied_child_object.pk if copied_child_object else None
 
                 revision.content_json = json.dumps(revision_content)
 
@@ -1542,15 +1508,6 @@ class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, m
 
         return page_copy
 
-    def _save_copy_instance(self, instance, to=None, recursive=False, **kwargs):
-        if to:
-            if recursive and (to == self or to.is_descendant_of(self)):
-                raise Exception("You cannot copy a tree branch recursively into itself")
-            instance = to.add_child(instance=instance)
-        else:
-            instance = self.add_sibling(instance=instance)
-        return instance
-
     copy.alters_data = True
 
     def permissions_for_user(self, user):
@@ -3489,7 +3446,7 @@ class TaskStateManager(models.Manager):
         return states
 
 
-class TaskState(MultiTableCopyMixin, models.Model):
+class TaskState(models.Model):
     """Tracks the status of a given Task for a particular page revision."""
     STATUS_IN_PROGRESS = 'in_progress'
     STATUS_APPROVED = 'approved'
@@ -3526,6 +3483,7 @@ class TaskState(MultiTableCopyMixin, models.Model):
         on_delete=models.CASCADE
     )
     exclude_fields_in_copy = []
+    default_exclude_fields_in_copy = ['id']
 
     objects = TaskStateManager()
 
@@ -3630,11 +3588,10 @@ class TaskState(MultiTableCopyMixin, models.Model):
     def copy(self, update_attrs=None, exclude_fields=None):
         """Copy this task state, excluding the attributes in the ``exclude_fields`` list and updating any attributes to values
         specified in the ``update_attrs`` dictionary of ``attribute``: ``new value`` pairs"""
-        copy_instance, _ = self._copy(exclude_fields, update_attrs)
-        return copy_instance
-
-    def _save_copy_instance(self, instance, **kwargs):
+        exclude_fields = self.default_exclude_fields_in_copy + self.exclude_fields_in_copy + (exclude_fields or [])
+        instance, child_object_map = _copy(self.specific, exclude_fields, update_attrs)
         instance.save()
+        _copy_m2m_relations(self, instance, exclude_fields=exclude_fields)
         return instance
 
     def get_comment(self):