
Apply field length normalisation to title in PostgreSQL search

Karl Hobley 4 years ago
parent
commit
ef8429cc1e
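
The gist of the change, per the docstrings and comments in the diff below: each IndexEntry now stores a separate title vector alongside body, plus a title_norm factor set to lavg/ld, where lavg is the average title length across all index entries and ld is this entry's title length, both counted in terms. The title's rank is multiplied by this factor at query time, so shorter, more specific titles are boosted. Illustrative numbers only: with lavg = 4 terms, a 2-term title gets title_norm = 4/2 = 2.0 while a 10-term title gets 4/10 = 0.4, so for the query "bread" a page titled simply "Bread" will tend to outrank one titled "A long list of bread recipes and baking notes" when both otherwise match equally well.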

+ 90 - 33
wagtail/contrib/postgres_search/backend.py

@@ -4,9 +4,9 @@ from functools import reduce
 
 from django.contrib.postgres.search import SearchQuery, SearchRank, SearchVector
 from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction
-from django.db.models import Count, F, Manager, Q, TextField, Value
+from django.db.models import Avg, Count, F, Manager, Q, TextField, Value
 from django.db.models.constants import LOOKUP_SEP
-from django.db.models.functions import Cast
+from django.db.models.functions import Cast, Length
 from django.db.models.sql.subqueries import InsertQuery
 from django.utils.encoding import force_str
 from django.utils.functional import cached_property
@@ -97,15 +97,28 @@ class ObjectIndexer:
         """
         return force_str(self.obj.pk)
 
+    @cached_property
+    def title(self):
+        """
+        Returns all values to index as "title". This is the value of all SearchFields that have the field_name 'title'
+        """
+        texts = []
+        for field in self.search_fields:
+            for current_field, boost, value in self.prepare_field(self.obj, field):
+                if isinstance(current_field, SearchField) and current_field.field_name == 'title':
+                    texts.append((value, boost))
+
+        return self.as_vector(texts)
+
     @cached_property
     def body(self):
         """
-        Returns all values to index as "body". This is the value of all SearchFields
+        Returns all values to index as "body". This is the value of all SearchFields excluding the title
         """
         texts = []
         for field in self.search_fields:
             for current_field, boost, value in self.prepare_field(self.obj, field):
-                if isinstance(current_field, SearchField):
+                if isinstance(current_field, SearchField) and not current_field.field_name == 'title':
                     texts.append((value, boost))
 
         return self.as_vector(texts)
@@ -146,6 +159,18 @@ class Index:
     def refresh(self):
         pass
 
+    def _refresh_title_norms(self):
+        """
+        Refreshes the value of the title_norm field.
+
+        This needs to be set to 'lavg/ld' where:
+         - lavg is the average length of titles in all documents (also in terms)
+         - ld is the length of the title field in this document (in terms)
+        """
+
+        lavg = self.entries.annotate(title_length=Length('title')).aggregate(Avg('title_length'))['title_length__avg']
+        self.entries.annotate(title_length=Length('title')).filter(title_length__gt=0).update(title_norm=lavg / F('title_length'))
+
     def delete_stale_model_entries(self, model):
         existing_pks = (
             model._default_manager.using(self.db_alias)
@@ -171,6 +196,7 @@ class Index:
 
     def add_items_upsert(self, content_type_pk, indexers):
         compiler = InsertQuery(IndexEntry).get_compiler(connection=self.connection)
+        title_sql = []
         autocomplete_sql = []
         body_sql = []
         data_params = []
@@ -178,6 +204,12 @@ class Index:
         for indexer in indexers:
             data_params.extend((content_type_pk, indexer.id))
 
+            # Compile title value
+            value = compiler.prepare_value(IndexEntry._meta.get_field('title'), indexer.title)
+            sql, params = value.as_sql(compiler, self.connection)
+            title_sql.append(sql)
+            data_params.extend(params)
+
             # Compile autocomplete value
             value = compiler.prepare_value(IndexEntry._meta.get_field('autocomplete'), indexer.autocomplete)
             sql, params = value.as_sql(compiler, self.connection)
@@ -191,45 +223,51 @@ class Index:
             data_params.extend(params)
 
         data_sql = ', '.join([
-            '(%%s, %%s, %s, %s)' % (a, b)
-            for a, b in zip(autocomplete_sql, body_sql)
+            '(%%s, %%s, %s, %s, %s, 1.0)' % (a, b, c)
+            for a, b, c in zip(title_sql, autocomplete_sql, body_sql)
         ])
 
         with self.connection.cursor() as cursor:
             cursor.execute("""
-                INSERT INTO %s (content_type_id, object_id, autocomplete, body)
+                INSERT INTO %s (content_type_id, object_id, title, autocomplete, body, title_norm)
                 (VALUES %s)
                 ON CONFLICT (content_type_id, object_id)
-                DO UPDATE SET autocomplete = EXCLUDED.autocomplete,
+                DO UPDATE SET title = EXCLUDED.title,
+                              autocomplete = EXCLUDED.autocomplete,
                               body = EXCLUDED.body
                 """ % (IndexEntry._meta.db_table, data_sql), data_params)
 
+        self._refresh_title_norms()
+
     def add_items_update_then_create(self, content_type_pk, indexers):
         ids_and_data = {}
         for indexer in indexers:
-            ids_and_data[indexer.id] = (indexer.autocomplete, indexer.body)
+            ids_and_data[indexer.id] = (indexer.title, indexer.autocomplete, indexer.body)
 
         index_entries_for_ct = self.entries.filter(content_type_id=content_type_pk)
         indexed_ids = frozenset(
             index_entries_for_ct.filter(object_id__in=ids_and_data.keys()).values_list('object_id', flat=True)
         )
         for indexed_id in indexed_ids:
-            autocomplete, body = ids_and_data[indexed_id]
-            index_entries_for_ct.filter(object_id=indexed_id).update(autocomplete=autocomplete, body=body)
+            title, autocomplete, body = ids_and_data[indexed_id]
+            index_entries_for_ct.filter(object_id=indexed_id).update(title=title, autocomplete=autocomplete, body=body)
 
         to_be_created = []
         for object_id in ids_and_data.keys():
             if object_id not in indexed_ids:
-                autocomplete, body = ids_and_data[object_id]
+                title, autocomplete, body = ids_and_data[object_id]
                 to_be_created.append(IndexEntry(
                     content_type_id=content_type_pk,
                     object_id=object_id,
+                    title=title,
                     autocomplete=autocomplete,
                     body=body
                 ))
 
         self.entries.bulk_create(to_be_created)
 
+        self._refresh_title_norms()
+
     def add_items(self, model, objs):
         search_fields = model.get_search_fields()
         if not search_fields:
@@ -406,25 +444,39 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
             '`%s` is not supported by the PostgreSQL search backend.'
             % query.__class__.__name__)
 
-    def get_index_vector(self, search_query):
-        return F('index_entries__body')
+    def get_index_vectors(self, search_query):
+        return [
+            (F('index_entries__title'), F('index_entries__title_norm')),
+            (F('index_entries__body'), 1.0),
+        ]
 
-    def get_fields_vector(self, search_query):
-        return ADD(
-            SearchVector(
+    def get_fields_vectors(self, search_query):
+        return [
+            (SearchVector(
                 field_lookup,
                 config=search_query.config,
-                weight=get_weight(search_field.boost)
-            )
+            ), search_field.boost)
             for field_lookup, search_field in self.search_fields.items()
-        )
+        ]
 
-    def get_search_vector(self, search_query):
+    def get_search_vectors(self, search_query):
         if self.fields is None:
-            return self.get_index_vector(search_query)
+            return self.get_index_vectors(search_query)
 
         else:
-            return self.get_fields_vector(search_query)
+            return self.get_fields_vectors(search_query)
+
+    def _build_rank_expression(self, vectors, config):
+        rank_expressions = [
+            self.build_tsrank(vector, self.query, config=config) * boost
+            for vector, boost in vectors
+        ]
+
+        rank_expression = rank_expressions[0]
+        for other_rank_expression in rank_expressions[1:]:
+            rank_expression += other_rank_expression
+
+        return rank_expression
 
     def search(self, config, start, stop, score_field=None):
         # TODO: Handle MatchAll nested inside other search query classes.
@@ -435,9 +487,14 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
             return self.queryset.none()
 
         search_query = self.build_tsquery(self.query, config=config)
-        vector = self.get_search_vector(search_query)
-        rank_expression = self.build_tsrank(vector, self.query, config=config)
-        queryset = self.queryset.annotate(_vector_=vector).filter(_vector_=search_query)
+        vectors = self.get_search_vectors(search_query)
+        rank_expression = self._build_rank_expression(vectors, config)
+
+        combined_vector = vectors[0][0]
+        for vector, boost in vectors[1:]:
+            combined_vector = combined_vector._combine(vector, '||', False)
+
+        queryset = self.queryset.annotate(_vector_=combined_vector).filter(_vector_=search_query)
 
         if self.order_by_relevance:
             queryset = queryset.order_by(rank_expression.desc(), '-pk')
@@ -482,18 +539,18 @@ class PostgresAutocompleteQueryCompiler(PostgresSearchQueryCompiler):
     def get_search_fields_for_model(self):
         return self.queryset.model.get_autocomplete_search_fields()
 
-    def get_index_vector(self, search_query):
-        return F('index_entries__autocomplete')
+    def get_index_vectors(self, search_query):
+        return [(F('index_entries__autocomplete'), 1.0)]
 
-    def get_fields_vector(self, search_query):
-        return ADD(
-            SearchVector(
+    def get_fields_vectors(self, search_query):
+        return [
+            (SearchVector(
                 field_lookup,
                 config=search_query.config,
                 weight='D',
-            )
+            ), 1.0)
             for field_lookup, search_field in self.search_fields.items()
-        )
+        ]
 
 
 class PostgresSearchResults(BaseSearchResults):
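
Taken together, for the default "search all fields" path the compiler now filters on a concatenated title-plus-body vector and orders by a weighted sum of two ranks. A minimal standalone sketch of the expressions search() ends up building, assuming build_tsquery/build_tsrank wrap Django's SearchQuery/SearchRank as elsewhere in this backend, and using Wagtail's Page model purely as an example queryset:

    # Sketch only: mirrors what PostgresSearchQueryCompiler.search() constructs
    # when no explicit fields are given; Page is just an illustrative indexed model.
    from django.contrib.postgres.search import SearchQuery, SearchRank
    from django.db.models import F
    from wagtail.core.models import Page

    search_query = SearchQuery('bread', config='english')

    # _build_rank_expression(): per-vector ts_rank() scaled by its boost, then summed.
    # The title rank is scaled by the per-row title_norm column, the body rank by 1.0.
    rank_expression = (
        SearchRank(F('index_entries__title'), search_query) * F('index_entries__title_norm')
        + SearchRank(F('index_entries__body'), search_query) * 1.0
    )

    # search(): the index vectors are concatenated with '||' and the query filters on that.
    combined_vector = F('index_entries__title')._combine(F('index_entries__body'), '||', False)

    results = (
        Page.objects.annotate(_vector_=combined_vector)
        .filter(_vector_=search_query)
        .order_by(rank_expression.desc(), '-pk')
    )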

+ 30 - 0
wagtail/contrib/postgres_search/migrations/0004_title.py

@@ -0,0 +1,30 @@
+# Generated by Django 3.0.6 on 2020-04-24 13:00
+
+import django.contrib.postgres.indexes
+import django.contrib.postgres.search
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('postgres_search', '0002_add_autocomplete'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='indexentry',
+            name='title',
+            field=django.contrib.postgres.search.SearchVectorField(default=''),
+            preserve_default=False,
+        ),
+        migrations.AddIndex(
+            model_name='indexentry',
+            index=django.contrib.postgres.indexes.GinIndex(fields=['title'], name='postgres_se_title_b56f33_gin'),
+        ),
+        migrations.AddField(
+            model_name='indexentry',
+            name='title_norm',
+            field=models.FloatField(default=1.0),
+        ),
+    ]
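
Note on the schema change: because the new title column is added with default='' and preserve_default=False, rows already in the index get an empty title vector (and the title_norm default of 1.0) when this migration runs; they only pick up real values once the entries are rebuilt, for example with Wagtail's update_index management command.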

+ 11 - 4
wagtail/contrib/postgres_search/models.py

@@ -3,7 +3,7 @@ from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelatio
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.postgres.indexes import GinIndex
 from django.contrib.postgres.search import SearchVectorField
-from django.db.models import CASCADE, ForeignKey, Model, TextField
+from django.db import models
 from django.db.models.functions import Cast
 from django.utils.translation import gettext_lazy as _
 
@@ -40,14 +40,20 @@ class TextIDGenericRelation(GenericRelation):
         return []
 
 
-class IndexEntry(Model):
-    content_type = ForeignKey(ContentType, on_delete=CASCADE)
+class IndexEntry(models.Model):
+    content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
     # We do not use an IntegerField since primary keys are not always integers.
-    object_id = TextField()
+    object_id = models.TextField()
     content_object = GenericForeignKey()
 
     # TODO: Add per-object boosting.
     autocomplete = SearchVectorField()
+    title = SearchVectorField()
+    # This field stores the "Title Normalisation Factor"
+    # This factor is multiplied onto the rank of the title field.
+    # This allows us to apply a boost to results with shorter titles,
+    # elevating more specific matches to the top.
+    title_norm = models.FloatField(default=1.0)
     body = SearchVectorField()
 
     class Meta:
@@ -55,6 +61,7 @@ class IndexEntry(Model):
         verbose_name = _('index entry')
         verbose_name_plural = _('index entries')
         indexes = [GinIndex(fields=['autocomplete']),
+                   GinIndex(fields=['title']),
                    GinIndex(fields=['body'])]
 
     def __str__(self):
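
For reference, the title_norm column declared above is what Index._refresh_title_norms() in backend.py recomputes after each bulk insert or update. A standalone sketch of those two ORM statements, assuming IndexEntry is importable as below; note that Django's Length() on a SearchVectorField compiles to Postgres LENGTH(), which for a tsvector returns the number of lexemes, which is presumably why the docstring describes the lengths as being "in terms":

    # Sketch of the normalisation refresh; in the backend, self.entries is an
    # IndexEntry queryset bound to the current database alias.
    from django.db.models import Avg, F
    from django.db.models.functions import Length
    from wagtail.contrib.postgres_search.models import IndexEntry

    entries = IndexEntry.objects.all()

    # lavg: average title length (in lexemes) across all index entries.
    lavg = entries.annotate(title_length=Length('title')).aggregate(
        Avg('title_length'))['title_length__avg']

    # ld: this entry's title length; set title_norm = lavg / ld, skipping empty
    # titles to avoid a division by zero.
    entries.annotate(title_length=Length('title')).filter(title_length__gt=0).update(
        title_norm=lavg / F('title_length'))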