Quellcode durchsuchen

Extract references from rich text

Karl Hobley vor 2 Jahren
Ursprung
Commit
a9db3e966b

+ 5 - 0
wagtail/blocks/field_block.py

@@ -16,6 +16,7 @@ from wagtail.coreutils import camelcase_to_underscore, resolve_model_string
 from wagtail.rich_text import (
     RichText,
     RichTextMaxLengthValidator,
+    extract_references_from_rich_text,
     get_text_for_indexing,
 )
 from wagtail.telepath import Adapter, register
@@ -701,6 +702,10 @@ class RichTextBlock(FieldBlock):
         source = force_str(value.source)
         return [get_text_for_indexing(source)]
 
+    def extract_references(self, value):
+        # Extracts any references to images/pages/embeds
+        yield from extract_references_from_rich_text(force_str(value.source))
+
     class Meta:
         icon = "doc-full"
 

+ 8 - 1
wagtail/fields.py

@@ -8,7 +8,11 @@ from django.db.models.fields.json import KeyTransform
 from django.utils.encoding import force_str
 
 from wagtail.blocks import Block, BlockField, StreamBlock, StreamValue
-from wagtail.rich_text import RichTextMaxLengthValidator, get_text_for_indexing
+from wagtail.rich_text import (
+    RichTextMaxLengthValidator,
+    extract_references_from_rich_text,
+    get_text_for_indexing,
+)
 from wagtail.utils.deprecation import RemovedInWagtail50Warning
 
 
@@ -52,6 +56,9 @@ class RichTextField(models.TextField):
         source = force_str(value)
         return [get_text_for_indexing(source)]
 
+    def extract_references(self, value):
+        yield from extract_references_from_rich_text(force_str(value))
+
 
 # https://github.com/django/django/blob/64200c14e0072ba0ffef86da46b2ea82fd1e019a/django/db/models/fields/subclassing.py#L31-L44
 class Creator:

+ 4 - 0
wagtail/images/rich_text/__init__.py

@@ -27,3 +27,7 @@ class ImageEmbedHandler(EmbedHandler):
 
         image_format = get_image_format(attrs["format"])
         return image_format.image_to_html(image, attrs.get("alt", ""))
+
+    @classmethod
+    def extract_references(cls, attrs):
+        yield cls.get_model(), attrs["id"], "", ""

+ 13 - 0
wagtail/images/tests/test_rich_text.py

@@ -1,6 +1,7 @@
 from bs4 import BeautifulSoup
 from django.test import TestCase
 
+from wagtail.fields import RichTextField
 from wagtail.images.rich_text import ImageEmbedHandler as FrontendImageEmbedHandler
 from wagtail.images.rich_text.editor_html import (
     ImageEmbedHandler as EditorHtmlImageEmbedHandler,
@@ -129,3 +130,15 @@ class TestFrontendImageEmbedHandler(TestCase, WagtailTestUtils):
         self.assertTagInHTML(
             '<img class="richtext-image left" alt="" />', result, allow_extra_attrs=True
         )
+
+
+class TestExtractReferencesWithImage(TestCase, WagtailTestUtils):
+    def test_extract_references(self):
+        self.assertEqual(
+            list(
+                RichTextField().extract_references(
+                    '<embed alt="Olivia Ava" embedtype="image" format="left" id="52"/>'
+                )
+            ),
+            [(Image, "52", "", "")],
+        )

+ 2 - 2
wagtail/models/__init__.py

@@ -75,7 +75,7 @@ from wagtail.coreutils import (
     get_supported_content_language_variant,
     resolve_model_string,
 )
-from wagtail.fields import StreamField
+from wagtail.fields import RichTextField, StreamField
 from wagtail.forms import TaskStateCommentForm
 from wagtail.locks import BasicLock, ScheduledForPublishLock, WorkflowLock
 from wagtail.log_actions import log
@@ -4757,7 +4757,7 @@ class ReferenceIndex(models.Model):
                         value
                     ), field.name, field.name
 
-            if isinstance(field, StreamField):
+            if isinstance(field, (StreamField, RichTextField)):
                 value = field.value_from_object(object)
                 if value is not None:
                     yield from (

+ 40 - 23
wagtail/rich_text/__init__.py

@@ -1,4 +1,5 @@
 import re
+from functools import lru_cache
 from html import unescape
 
 from django.core.validators import MaxLengthValidator
@@ -17,36 +18,48 @@ features = FeatureRegistry()
 # from wagtail.rich_text.rewriters along with the embed handlers / link handlers registered
 # with the feature registry
 
-FRONTEND_REWRITER = None
+
+@lru_cache(maxsize=1)
+def get_rewriter():
+    embed_rules = features.get_embed_types()
+    link_rules = features.get_link_types()
+    return MultiRuleRewriter(
+        [
+            LinkRewriter(
+                {
+                    linktype: handler.expand_db_attributes
+                    for linktype, handler in link_rules.items()
+                },
+                {
+                    linktype: handler.extract_references
+                    for linktype, handler in link_rules.items()
+                },
+            ),
+            EmbedRewriter(
+                {
+                    embedtype: handler.expand_db_attributes
+                    for embedtype, handler in embed_rules.items()
+                },
+                {
+                    linktype: handler.extract_references
+                    for linktype, handler in embed_rules.items()
+                },
+            ),
+        ]
+    )
 
 
 def expand_db_html(html):
     """
     Expand database-representation HTML into proper HTML usable on front-end templates
     """
-    global FRONTEND_REWRITER
-
-    if FRONTEND_REWRITER is None:
-        embed_rules = features.get_embed_types()
-        link_rules = features.get_link_types()
-        FRONTEND_REWRITER = MultiRuleRewriter(
-            [
-                LinkRewriter(
-                    {
-                        linktype: handler.expand_db_attributes
-                        for linktype, handler in link_rules.items()
-                    }
-                ),
-                EmbedRewriter(
-                    {
-                        embedtype: handler.expand_db_attributes
-                        for embedtype, handler in embed_rules.items()
-                    }
-                ),
-            ]
-        )
+    rewriter = get_rewriter()
+    return rewriter(html)
 
-    return FRONTEND_REWRITER(html)
+
+def extract_references_from_rich_text(html):
+    rewriter = get_rewriter()
+    yield from rewriter.extract_references(html)
 
 
 def get_text_for_indexing(richtext):
@@ -120,6 +133,10 @@ class EntityHandler:
         """
         raise NotImplementedError
 
+    @classmethod
+    def extract_references(cls, attrs):
+        return []
+
 
 class LinkHandler(EntityHandler):
     pass

+ 4 - 0
wagtail/rich_text/pages.py

@@ -22,3 +22,7 @@ class PageLinkHandler(LinkHandler):
             return '<a href="%s">' % escape(page.localized.specific.url)
         except Page.DoesNotExist:
             return "<a>"
+
+    @classmethod
+    def extract_references(self, attrs):
+        yield Page, attrs["id"], "", ""

+ 32 - 2
wagtail/rich_text/rewriters.py

@@ -32,8 +32,9 @@ class EmbedRewriter:
     returns the HTML fragment.
     """
 
-    def __init__(self, embed_rules):
+    def __init__(self, embed_rules, reference_extractors=None):
         self.embed_rules = embed_rules
+        self.reference_extractors = reference_extractors or {}
 
     def replace_tag(self, match):
         attrs = extract_attrs(match.group(1))
@@ -47,6 +48,17 @@ class EmbedRewriter:
     def __call__(self, html):
         return FIND_EMBED_TAG.sub(self.replace_tag, html)
 
+    def extract_references(self, html):
+        for match in FIND_EMBED_TAG.findall(html):
+            attrs = extract_attrs(match)
+            if (
+                "embedtype" not in attrs
+                or attrs["embedtype"] not in self.reference_extractors
+            ):
+                continue
+
+            yield from self.reference_extractors[attrs["embedtype"]](attrs)
+
 
 class LinkRewriter:
     """
@@ -55,8 +67,9 @@ class LinkRewriter:
     returns the HTML fragment for the opening tag (only).
     """
 
-    def __init__(self, link_rules):
+    def __init__(self, link_rules, reference_extractors=None):
         self.link_rules = link_rules
+        self.reference_extractors = reference_extractors or {}
 
     def replace_tag(self, match):
         attrs = extract_attrs(match.group(1))
@@ -95,6 +108,19 @@ class LinkRewriter:
     def __call__(self, html):
         return FIND_A_TAG.sub(self.replace_tag, html)
 
+    def extract_references(self, html):
+        for match in FIND_A_TAG.findall(html):
+            attrs = extract_attrs(match)
+            if (
+                "linktype" not in attrs
+                or attrs["linktype"] not in self.reference_extractors
+            ):
+                continue
+
+            yield from self.reference_extractors[attrs["linktype"]](attrs)
+
+        return []
+
 
 class MultiRuleRewriter:
     """Rewrites HTML by applying a sequence of rewriter functions"""
@@ -106,3 +132,7 @@ class MultiRuleRewriter:
         for rewrite in self.rewriters:
             html = rewrite(html)
         return html
+
+    def extract_references(self, html):
+        for rewriter in self.rewriters:
+            yield from rewriter.extract_references(html)

+ 6 - 0
wagtail/tests/test_blocks.py

@@ -664,6 +664,12 @@ class TestRichTextBlock(TestCase):
         result = block.get_searchable_content(value)
         self.assertEqual(result, ["mashed potatoes"])
 
+    def test_extract_references(self):
+        block = blocks.RichTextBlock()
+        value = RichText('<a linktype="page" id="1">Link to an internal page</a>')
+
+        self.assertEqual(list(block.extract_references(value)), [(Page, "1", "", "")])
+
 
 class TestChoiceBlock(WagtailTestUtils, SimpleTestCase):
     def setUp(self):

+ 11 - 0
wagtail/tests/test_rich_text.py

@@ -4,6 +4,7 @@ from django.forms.models import modelform_factory
 from django.test import TestCase, override_settings
 from django.utils import translation
 
+from wagtail.fields import RichTextField
 from wagtail.models import Locale, Page
 from wagtail.rich_text import RichText, RichTextMaxLengthValidator, expand_db_html
 from wagtail.rich_text.feature_registry import FeatureRegistry
@@ -276,6 +277,16 @@ class TestRichTextField(TestCase):
         )
         self.assertTrue(form.is_valid())
 
+    def test_extract_references(self):
+        self.assertEqual(
+            list(
+                RichTextField().extract_references(
+                    '<a linktype="page" id="1">Link to an internal page</a>'
+                )
+            ),
+            [(Page, "1", "", "")],
+        )
+
 
 class TestRichTextMaxLengthValidator(TestCase):
     def test_count_characters(self):