
Use correct connection when searching (#12508)

- Use the index's preferred DB (as chosen by the database router), not the DB of the model being indexed
- Use write connections for writes, and read connections for reads (see the router sketch below)
Jake Howard, 5 months ago
Commit
6b44838841
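
For context, this is the kind of multi-database setup the fix targets: a hypothetical Django database router (the class name and aliases are illustrative, not part of this commit) that sends reads to a replica and writes to the primary. The backends now consult the router via `router.db_for_read(IndexEntry)` and `router.db_for_write(IndexEntry)` instead of assuming a single alias.

```python
# Hypothetical router for a primary/replica setup -- an illustration of the
# configuration this fix supports, not code from the commit.
class PrimaryReplicaRouter:
    def db_for_read(self, model, **hints):
        # Route reads (including search queries) to the replica.
        return "replica"

    def db_for_write(self, model, **hints):
        # Route writes (index rebuilds, deletions) to the primary.
        return "default"
```

With `DATABASE_ROUTERS = ["path.to.PrimaryReplicaRouter"]` in settings, search queries run against the replica while index writes and atomic rebuilds go to the primary.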

+ 1 - 0
CHANGELOG.txt

@@ -20,6 +20,7 @@ Changelog
  * Fix: Ensure form builder correctly checks for duplicate field names when using a custom related name (John-Scott Atlakson, LB (Ben) Johnston)
  * Fix: Normalize `StreamField.get_default()` to prevent creation forms from breaking (Matt Westcott)
  * Fix: Prevent out-of-order migrations from skipping creation of image/document choose permissions (Matt Westcott)
+ * Fix: Use correct connections on multi-database setups in database search backends (Jake Howard)
  * Docs: Move the model reference page from reference/pages to the references section as it covers all Wagtail core models (Srishti Jaiswal)
  * Docs: Move the panels reference page from references/pages to the references section as panels are available for any model editing, merge panels API into this page (Srishti Jaiswal)
  * Docs: Move the tags documentation to standalone advanced topic, instead of being inside the reference/pages section (Srishti Jaiswal)

+ 1 - 0
docs/releases/6.4.md

@@ -33,6 +33,7 @@ depth: 1
  * Ensure form builder correctly checks for duplicate field names when using a custom related name (John-Scott Atlakson, LB (Ben) Johnston)
  * Normalize `StreamField.get_default()` to prevent creation forms from breaking (Matt Westcott)
  * Prevent out-of-order migrations from skipping creation of image/document choose permissions (Matt Westcott)
+ * Use correct connections on multi-database setups in database search backends (Jake Howard)
 
 ### Documentation
 

+ 25 - 17
wagtail/search/backends/database/mysql/mysql.py

@@ -1,7 +1,12 @@
 import warnings
 from collections import OrderedDict
 
-from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction
+from django.db import (
+    NotSupportedError,
+    connections,
+    router,
+    transaction,
+)
 from django.db.models import Case, When
 from django.db.models.aggregates import Avg, Count
 from django.db.models.constants import LOOKUP_SEP
@@ -151,17 +156,22 @@ class ObjectIndexer:
 
 
 class Index:
-    def __init__(self, backend, db_alias=None):
+    def __init__(self, backend):
         self.backend = backend
         self.name = self.backend.index_name
-        self.db_alias = DEFAULT_DB_ALIAS if db_alias is None else db_alias
-        self.connection = connections[self.db_alias]
-        if self.connection.vendor != "mysql":
+
+        self.read_connection = connections[router.db_for_read(IndexEntry)]
+        self.write_connection = connections[router.db_for_write(IndexEntry)]
+
+        if (
+            self.read_connection.vendor != "mysql"
+            or self.write_connection.vendor != "mysql"
+        ):
             raise NotSupportedError(
-                "You must select a MySQL database " "to use MySQL search."
+                "You must select a MySQL database to use MySQL search."
             )
 
-        self.entries = IndexEntry._default_manager.using(self.db_alias)
+        self.entries = IndexEntry._default_manager.all()
 
     def add_model(self, model):
         pass
@@ -201,11 +211,9 @@ class Index:
         ).update(title_norm=lavg / F("title_length"))
 
     def delete_stale_model_entries(self, model):
-        existing_pks = (
-            model._default_manager.using(self.db_alias)
-            .annotate(object_id=Cast("pk", TextField()))
-            .values("object_id")
-        )
+        existing_pks = model._default_manager.annotate(
+            object_id=Cast("pk", TextField())
+        ).values("object_id")
         content_types_pks = get_descendants_content_types_pks(model)
         stale_entries = self.entries.filter(
             content_type_id__in=content_types_pks
@@ -276,7 +284,7 @@ class Index:
             update_method(content_type_pk, indexers)
 
     def delete_item(self, item):
-        item.index_entries.all()._raw_delete(using=self.db_alias)
+        item.index_entries.all()._raw_delete(using=self.write_connection.alias)
 
     def __str__(self):
         return self.name
@@ -610,7 +618,7 @@ class MySQLSearchRebuilder:
 class MySQLSearchAtomicRebuilder(MySQLSearchRebuilder):
     def __init__(self, index):
         super().__init__(index)
-        self.transaction = transaction.atomic(using=index.db_alias)
+        self.transaction = transaction.atomic(using=index.write_connection.alias)
         self.transaction_opened = False
 
     def start(self):
@@ -650,11 +658,11 @@ class MySQLSearchBackend(BaseSearchBackend):
         if params.get("ATOMIC_REBUILD", False):
             self.rebuilder_class = self.atomic_rebuilder_class
 
-    def get_index_for_model(self, model, db_alias=None):
-        return Index(self, db_alias)
+    def get_index_for_model(self, model):
+        return Index(self)
 
     def get_index_for_object(self, obj):
-        return self.get_index_for_model(obj._meta.model, obj._state.db)
+        return self.get_index_for_model(obj._meta.model)
 
     def reset_index(self):
         for connection in [

+ 32 - 22
wagtail/search/backends/database/postgres/postgres.py

@@ -3,7 +3,12 @@ from collections import OrderedDict
 from functools import reduce
 
 from django.contrib.postgres.search import SearchQuery, SearchRank, SearchVector
-from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction
+from django.db import (
+    NotSupportedError,
+    connections,
+    router,
+    transaction,
+)
 from django.db.models import Avg, Count, F, Manager, Q, TextField, Value
 from django.db.models.constants import LOOKUP_SEP
 from django.db.models.functions import Cast, Length
@@ -163,17 +168,22 @@ class ObjectIndexer:
 
 
 class Index:
-    def __init__(self, backend, db_alias=None):
+    def __init__(self, backend):
         self.backend = backend
         self.name = self.backend.index_name
-        self.db_alias = DEFAULT_DB_ALIAS if db_alias is None else db_alias
-        self.connection = connections[self.db_alias]
-        if self.connection.vendor != "postgresql":
+
+        self.read_connection = connections[router.db_for_read(IndexEntry)]
+        self.write_connection = connections[router.db_for_write(IndexEntry)]
+
+        if (
+            self.read_connection.vendor != "postgresql"
+            or self.write_connection.vendor != "postgresql"
+        ):
             raise NotSupportedError(
-                "You must select a PostgreSQL database " "to use PostgreSQL search."
+                "You must select a PostgreSQL database to use PostgreSQL search."
             )
 
-        self.entries = IndexEntry._default_manager.using(self.db_alias)
+        self.entries = IndexEntry._default_manager.all()
 
     def add_model(self, model):
         pass
@@ -213,11 +223,9 @@ class Index:
         ).update(title_norm=lavg / F("title_length"))
 
     def delete_stale_model_entries(self, model):
-        existing_pks = (
-            model._default_manager.using(self.db_alias)
-            .annotate(object_id=Cast("pk", TextField()))
-            .values("object_id")
-        )
+        existing_pks = model._default_manager.annotate(
+            object_id=Cast("pk", TextField())
+        ).values("object_id")
         content_types_pks = get_descendants_content_types_pks(model)
         stale_entries = self.entries.filter(
             content_type_id__in=content_types_pks
@@ -246,7 +254,9 @@ class Index:
             return
 
         content_type_pk = get_content_type_pk(model)
-        compiler = InsertQuery(IndexEntry).get_compiler(connection=self.connection)
+        compiler = InsertQuery(IndexEntry).get_compiler(
+            connection=self.write_connection
+        )
         title_sql = []
         autocomplete_sql = []
         body_sql = []
@@ -259,7 +269,7 @@ class Index:
             value = compiler.prepare_value(
                 IndexEntry._meta.get_field("title"), indexer.title
             )
-            sql, params = value.as_sql(compiler, self.connection)
+            sql, params = value.as_sql(compiler, self.write_connection)
             title_sql.append(sql)
             data_params.extend(params)
 
@@ -267,7 +277,7 @@ class Index:
             value = compiler.prepare_value(
                 IndexEntry._meta.get_field("autocomplete"), indexer.autocomplete
             )
-            sql, params = value.as_sql(compiler, self.connection)
+            sql, params = value.as_sql(compiler, self.write_connection)
             autocomplete_sql.append(sql)
             data_params.extend(params)
 
@@ -275,7 +285,7 @@ class Index:
             value = compiler.prepare_value(
                 IndexEntry._meta.get_field("body"), indexer.body
             )
-            sql, params = value.as_sql(compiler, self.connection)
+            sql, params = value.as_sql(compiler, self.write_connection)
             body_sql.append(sql)
             data_params.extend(params)
 
@@ -286,7 +296,7 @@ class Index:
             ]
         )
 
-        with self.connection.cursor() as cursor:
+        with self.write_connection.cursor() as cursor:
             cursor.execute(
                 """
                 INSERT INTO %s (content_type_id, object_id, title, autocomplete, body, title_norm)
@@ -304,7 +314,7 @@ class Index:
         self._refresh_title_norms()
 
     def delete_item(self, item):
-        item.index_entries.all()._raw_delete(using=self.db_alias)
+        item.index_entries.all()._raw_delete(using=self.write_connection.alias)
 
     def __str__(self):
         return self.name
@@ -660,7 +670,7 @@ class PostgresSearchRebuilder:
 class PostgresSearchAtomicRebuilder(PostgresSearchRebuilder):
     def __init__(self, index):
         super().__init__(index)
-        self.transaction = transaction.atomic(using=index.db_alias)
+        self.transaction = transaction.atomic(using=index.write_connection.alias)
         self.transaction_opened = False
 
     def start(self):
@@ -701,11 +711,11 @@ class PostgresSearchBackend(BaseSearchBackend):
         if params.get("ATOMIC_REBUILD", False):
             self.rebuilder_class = self.atomic_rebuilder_class
 
-    def get_index_for_model(self, model, db_alias=None):
-        return Index(self, db_alias)
+    def get_index_for_model(self, model):
+        return Index(self)
 
     def get_index_for_object(self, obj):
-        return self.get_index_for_model(obj._meta.model, obj._state.db)
+        return self.get_index_for_model(obj._meta.model)
 
     def reset_index(self):
         for connection in [

+ 25 - 17
wagtail/search/backends/database/sqlite/sqlite.py

@@ -1,7 +1,12 @@
 from collections import OrderedDict
 from functools import reduce
 
-from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction
+from django.db import (
+    NotSupportedError,
+    connections,
+    router,
+    transaction,
+)
 from django.db.models import Avg, Count, F, Manager, Q, TextField
 from django.db.models.constants import LOOKUP_SEP
 from django.db.models.functions import Cast, Length
@@ -145,17 +150,22 @@ class ObjectIndexer:
 
 
 class Index:
-    def __init__(self, backend, db_alias=None):
+    def __init__(self, backend):
         self.backend = backend
         self.name = self.backend.index_name
-        self.db_alias = DEFAULT_DB_ALIAS if db_alias is None else db_alias
-        self.connection = connections[self.db_alias]
-        if self.connection.vendor != "sqlite":
+
+        self.read_connection = connections[router.db_for_read(IndexEntry)]
+        self.write_connection = connections[router.db_for_write(IndexEntry)]
+
+        if (
+            self.read_connection.vendor != "sqlite"
+            or self.write_connection.vendor != "sqlite"
+        ):
             raise NotSupportedError(
-                "You must select a SQLite database " "to use the SQLite search backend."
+                "You must select a SQLite database to use the SQLite search backend."
             )
 
-        self.entries = IndexEntry._default_manager.using(self.db_alias)
+        self.entries = IndexEntry._default_manager.all()
 
     def add_model(self, model):
         pass
@@ -195,11 +205,9 @@ class Index:
         ).update(title_norm=lavg / F("title_length"))
 
     def delete_stale_model_entries(self, model):
-        existing_pks = (
-            model._default_manager.using(self.db_alias)
-            .annotate(object_id=Cast("pk", TextField()))
-            .values("object_id")
-        )
+        existing_pks = model._default_manager.annotate(
+            object_id=Cast("pk", TextField())
+        ).values("object_id")
         content_types_pks = get_descendants_content_types_pks(model)
         stale_entries = self.entries.filter(
             content_type_id__in=content_types_pks
@@ -270,7 +278,7 @@ class Index:
             update_method(content_type_pk, indexers)
 
     def delete_item(self, item):
-        item.index_entries.all()._raw_delete(using=self.db_alias)
+        item.index_entries.all()._raw_delete(using=self.write_connection.alias)
 
     def __str__(self):
         return self.name
@@ -291,7 +299,7 @@ class SQLiteSearchRebuilder:
 class SQLiteSearchAtomicRebuilder(SQLiteSearchRebuilder):
     def __init__(self, index):
         super().__init__(index)
-        self.transaction = transaction.atomic(using=index.db_alias)
+        self.transaction = transaction.atomic(using=index.write_connection.alias)
         self.transaction_opened = False
 
     def start(self):
@@ -673,11 +681,11 @@ class SQLiteSearchBackend(BaseSearchBackend):
         if params.get("ATOMIC_REBUILD", False):
             self.rebuilder_class = self.atomic_rebuilder_class
 
-    def get_index_for_model(self, model, db_alias=None):
-        return Index(self, db_alias)
+    def get_index_for_model(self, model):
+        return Index(self)
 
     def get_index_for_object(self, obj):
-        return self.get_index_for_model(obj._meta.model, obj._state.db)
+        return self.get_index_for_model(obj._meta.model)
 
     def reset_index(self):
         for connection in [
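
To summarize the pattern applied across all three backends, here is a minimal standalone sketch (assuming `IndexEntry` can be imported from `wagtail.search.models`, as these backend modules do) of resolving read and write aliases through Django's router and pinning queries and transactions to them:

```python
from django.db import connections, router, transaction

from wagtail.search.models import IndexEntry  # model whose router preferences decide the aliases

# Ask the router which database serves reads and which serves writes for IndexEntry.
read_alias = router.db_for_read(IndexEntry)    # e.g. "replica" in a primary/replica setup
write_alias = router.db_for_write(IndexEntry)  # e.g. "default"

# Reads (search queries, stale-entry lookups) use the read alias...
entry_count = IndexEntry._default_manager.using(read_alias).count()

# ...while writes and atomic rebuilds are pinned to the write alias.
with transaction.atomic(using=connections[write_alias].alias):
    IndexEntry._default_manager.using(write_alias).none().delete()  # no-op write, shown for shape
```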