Browse Source

Automatically update reference index when objects are saved/deleted

Karl Hobley 2 years ago
parent
commit
212b2e19a6

+ 1 - 11
wagtail/management/commands/rebuild_references_index.py

@@ -1,7 +1,6 @@
 from django.apps import apps
 from django.core.management.base import BaseCommand
 from django.db import transaction
-from modelcluster.fields import ParentalKey
 
 from wagtail.models import ReferenceIndex
 
@@ -29,16 +28,7 @@ class Command(BaseCommand):
             ReferenceIndex.objects.all().delete()
 
             for model in apps.get_models():
-                if not ReferenceIndex._model_could_have_outbound_references(model):
-                    continue
-
-                # Don't check any models that have a parental key, references from these will be collected from the parent
-                if any(
-                    [
-                        isinstance(field, ParentalKey)
-                        for field in model._meta.get_fields()
-                    ]
-                ):
+                if not ReferenceIndex.model_is_indexible(model):
                     continue
 
                 self.stdout.write(str(model))

+ 27 - 3
wagtail/models/__init__.py

@@ -4767,10 +4767,19 @@ class ReferenceIndex(models.Model):
             )
 
     @classmethod
-    def _model_could_have_outbound_references(cls, model):
+    def model_is_indexible(cls, model, allow_child_models=False):
+        """
+        Returns True if the given model may have outbound references that we would be interested in recording in the index.
+        """
         if getattr(model, "wagtail_reference_index_ignore", False):
             return False
 
+        # Don't check any models that have a parental key, references from these will be collected from the parent
+        if not allow_child_models and any(
+            [isinstance(field, ParentalKey) for field in model._meta.get_fields()]
+        ):
+            return False
+
         for field in model._meta.get_fields():
             if field.is_relation and field.many_to_one:
                 if getattr(field, "wagtail_reference_index_ignore", False):
@@ -4791,8 +4800,9 @@ class ReferenceIndex(models.Model):
 
         if issubclass(model, ClusterableModel):
             for child_relation in get_all_child_relations(model):
-                if cls._model_could_have_outbound_references(
-                    child_relation.related_model
+                if cls.model_is_indexible(
+                    child_relation.related_model,
+                    allow_child_models=True,
                 ):
                     return True
 
@@ -4913,6 +4923,20 @@ class ReferenceIndex(models.Model):
             id__in=[existing_references[reference] for reference in deleted_references]
         ).delete()
 
+    @classmethod
+    def remove_for_object(cls, object):
+        """
+        Deletes every index row whose base content type and object ID match the given object.
+        """
+        rows_for_object = cls.objects.filter(
+            base_content_type=cls._get_base_content_type(object),
+            object_id=object.pk,
+        )
+        rows_for_object.delete()
+
+    @classmethod
+    def get_references_for_object(cls, object):
+        """
+        Returns a queryset of the index rows recorded against the given object
+        (matched on its base content type and primary key).
+        """
+        # Filter on the relation itself rather than the "_id" column, matching
+        # remove_for_object: _get_base_content_type() is used elsewhere as the
+        # value of base_content_type, not as a raw ID
+        return cls.objects.filter(
+            base_content_type=cls._get_base_content_type(object),
+            object_id=object.pk,
+        )
+
     @classmethod
     def get_references_to(cls, object):
         return cls.objects.filter(

+ 84 - 2
wagtail/signal_handlers.py

@@ -1,10 +1,20 @@
 import logging
+from contextlib import contextmanager
 
+from asgiref.local import Local
 from django.core.cache import cache
-from django.db.models.signals import post_delete, post_save, pre_delete
+from django.db import transaction
+from django.db.models.signals import (
+    post_delete,
+    post_migrate,
+    post_save,
+    pre_delete,
+    pre_migrate,
+)
+from modelcluster.fields import ParentalKey
 
 from wagtail.coreutils import get_locales_display_names
-from wagtail.models import Locale, Page, Site
+from wagtail.models import Locale, Page, ReferenceIndex, Site
 
 logger = logging.getLogger("wagtail")
 
@@ -33,6 +43,70 @@ def reset_locales_display_names_cache(sender, instance, **kwargs):
     get_locales_display_names.cache_clear()
 
 
+reference_index_auto_update_disabled = Local()
+
+
+@contextmanager
+def disable_reference_index_auto_update():
+    """
+    A context manager that can be used to temporarily disable the reference index auto-update signal handlers.

+    For example:

+    with disable_reference_index_auto_update():
+        my_instance.save()  # Reference index will not be updated by this save

+    Blocks may be nested; auto-update stays disabled until the outermost block exits.
+    """
+    # Remember whether an enclosing block has already disabled auto-update
+    already_disabled = getattr(reference_index_auto_update_disabled, "value", False)
+    reference_index_auto_update_disabled.value = True
+    try:
+        yield
+    finally:
+        # Only the outermost block clears the flag. Clearing it from a nested
+        # block would re-enable auto-update for the rest of the enclosing block
+        # and make the enclosing block's own cleanup fail on a missing attribute
+        if not already_disabled:
+            del reference_index_auto_update_disabled.value
+
+
+def update_reference_index_on_save(instance, **kwargs):
+    """
+    Signal handler that refreshes the reference index after an object is saved.

+    Does nothing for raw saves or while auto-update is disabled. If the saved
+    object's model has a ParentalKey, the parental keys are followed upwards and
+    the object at the top of the chain is the one considered for indexing.
+    """
+    # Don't populate reference index while loading fixtures as referenced objects may not be populated yet
+    if kwargs.get("raw", False):
+        return

+    if getattr(reference_index_auto_update_disabled, "value", False):
+        return

+    # If the model is a child model, find the parent instance and index that instead
+    while True:
+        parental_key = next(
+            (
+                field
+                for field in instance._meta.get_fields()
+                if isinstance(field, ParentalKey)
+            ),
+            None,
+        )
+        if parental_key is None:
+            break

+        instance = getattr(instance, parental_key.name)
+        if instance is None:
+            # The child is not attached to a parent, so there is no object to
+            # record references against
+            return

+    if ReferenceIndex.model_is_indexible(type(instance)):
+        with transaction.atomic():
+            ReferenceIndex.create_or_update_for_object(instance)
+
+
+def remove_reference_index_on_delete(instance, **kwargs):
+    """
+    Signal handler that removes the deleted object's rows from the reference
+    index, unless auto-update is currently disabled.
+    """
+    disabled = getattr(reference_index_auto_update_disabled, "value", False)
+    if not disabled:
+        with transaction.atomic():
+            ReferenceIndex.remove_for_object(instance)
+
+
+def connect_reference_index_signal_handlers(**kwargs):
+    """
+    Attaches the reference index handlers to post_save and post_delete for all senders.
+    """
+    for signal, handler in (
+        (post_save, update_reference_index_on_save),
+        (post_delete, remove_reference_index_on_delete),
+    ):
+        signal.connect(handler)
+
+
+def disconnect_reference_index_signal_handlers(**kwargs):
+    """
+    Detaches the reference index handlers from post_save and post_delete.
+    """
+    for signal, handler in (
+        (post_save, update_reference_index_on_save),
+        (post_delete, remove_reference_index_on_delete),
+    ):
+        signal.disconnect(handler)
+
+
 def register_signal_handlers():
     post_save.connect(post_save_site_signal_handler, sender=Site)
     post_delete.connect(post_delete_site_signal_handler, sender=Site)
@@ -42,3 +116,11 @@ def register_signal_handlers():
 
     post_save.connect(reset_locales_display_names_cache, sender=Locale)
     post_delete.connect(reset_locales_display_names_cache, sender=Locale)
+
+    # Reference index signal handlers
+    connect_reference_index_signal_handlers()
+
+    # Disconnect reference index signals while migrations are running
+    # (we don't want to log references in migrations as the ReferenceIndex model might not exist)
+    pre_migrate.connect(disconnect_reference_index_signal_handlers)
+    post_migrate.connect(connect_reference_index_signal_handlers)

+ 7 - 14
wagtail/tests/test_reference_index.py

@@ -50,11 +50,9 @@ class TestCreateOrUpdateForObject(TestCase):
         self.root_page.add_child(instance=self.event_page)
 
     def test(self):
-        ReferenceIndex.create_or_update_for_object(self.event_page)
-
         self.assertSetEqual(
             set(
-                ReferenceIndex.objects.values_list(
+                ReferenceIndex.get_references_for_object(self.event_page).values_list(
                     "to_content_type", "to_object_id", "model_path", "content_path"
                 )
             ),
@@ -87,25 +85,20 @@ class TestCreateOrUpdateForObject(TestCase):
         )
 
     def test_update(self):
-        reference_to_keep = ReferenceIndex.objects.create(
+        reference_to_keep = ReferenceIndex.objects.get(
             base_content_type=ReferenceIndex._get_base_content_type(self.event_page),
             content_type=ContentType.objects.get_for_model(self.event_page),
-            object_id=self.event_page.pk,
-            to_content_type=self.image_content_type,
-            to_object_id=self.test_feed_image.pk,
-            model_path="feed_image",
             content_path="feed_image",
-            content_path_hash=ReferenceIndex._get_content_path_hash("feed_image"),
         )
         reference_to_remove = ReferenceIndex.objects.create(
             base_content_type=ReferenceIndex._get_base_content_type(self.event_page),
             content_type=ContentType.objects.get_for_model(self.event_page),
             object_id=self.event_page.pk,
             to_content_type=self.image_content_type,
-            to_object_id=self.test_image_1.pk,  # Image ID is not used in this field
-            model_path="feed_image",
-            content_path="feed_image",
-            content_path_hash=ReferenceIndex._get_content_path_hash("feed_image"),
+            to_object_id=self.test_image_1.pk,
+            model_path="hero_image",  # Field doesn't exist
+            content_path="hero_image",
+            content_path_hash=ReferenceIndex._get_content_path_hash("hero_image"),
         )
 
         ReferenceIndex.create_or_update_for_object(self.event_page)
@@ -121,7 +114,7 @@ class TestCreateOrUpdateForObject(TestCase):
         # Check that the current stored references are correct
         self.assertSetEqual(
             set(
-                ReferenceIndex.objects.values_list(
+                ReferenceIndex.get_references_for_object(self.event_page).values_list(
                     "to_content_type", "to_object_id", "model_path", "content_path"
                 )
             ),

+ 4 - 2
wagtail/tests/test_streamfield.py

@@ -14,6 +14,7 @@ from wagtail.fields import StreamField
 from wagtail.images.models import Image
 from wagtail.images.tests.utils import get_test_image_file
 from wagtail.rich_text import RichText
+from wagtail.signal_handlers import disable_reference_index_auto_update
 from wagtail.test.testapp.models import (
     BlockCountsStreamModel,
     JSONBlockCountsStreamModel,
@@ -177,8 +178,9 @@ class TestLazyStreamField(TestCase):
 
         # Expect a single UPDATE to update the model, without any additional
         # SELECT related to the image block that has not been accessed.
-        with self.assertNumQueries(1):
-            instance.save()
+        with disable_reference_index_auto_update():
+            with self.assertNumQueries(1):
+                instance.save()
 
 
 class TestJSONLazyStreamField(TestLazyStreamField):