|
@@ -3,7 +3,12 @@ from collections import OrderedDict
|
|
|
from functools import reduce
|
|
|
|
|
|
from django.contrib.postgres.search import SearchQuery, SearchRank, SearchVector
|
|
|
-from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction
|
|
|
+from django.db import (
|
|
|
+ NotSupportedError,
|
|
|
+ connections,
|
|
|
+ router,
|
|
|
+ transaction,
|
|
|
+)
|
|
|
from django.db.models import Avg, Count, F, Manager, Q, TextField, Value
|
|
|
from django.db.models.constants import LOOKUP_SEP
|
|
|
from django.db.models.functions import Cast, Length
|
|
@@ -163,17 +168,22 @@ class ObjectIndexer:
|
|
|
|
|
|
|
|
|
class Index:
|
|
|
- def __init__(self, backend, db_alias=None):
|
|
|
+ def __init__(self, backend):
|
|
|
self.backend = backend
|
|
|
self.name = self.backend.index_name
|
|
|
- self.db_alias = DEFAULT_DB_ALIAS if db_alias is None else db_alias
|
|
|
- self.connection = connections[self.db_alias]
|
|
|
- if self.connection.vendor != "postgresql":
|
|
|
+
|
|
|
+ self.read_connection = connections[router.db_for_read(IndexEntry)]
|
|
|
+ self.write_connection = connections[router.db_for_write(IndexEntry)]
|
|
|
+
|
|
|
+ if (
|
|
|
+ self.read_connection.vendor != "postgresql"
|
|
|
+ or self.write_connection.vendor != "postgresql"
|
|
|
+ ):
|
|
|
raise NotSupportedError(
|
|
|
- "You must select a PostgreSQL database " "to use PostgreSQL search."
|
|
|
+ "You must select a PostgreSQL database to use PostgreSQL search."
|
|
|
)
|
|
|
|
|
|
- self.entries = IndexEntry._default_manager.using(self.db_alias)
|
|
|
+ self.entries = IndexEntry._default_manager.all()
|
|
|
|
|
|
def add_model(self, model):
|
|
|
pass
|
|
@@ -213,11 +223,9 @@ class Index:
|
|
|
).update(title_norm=lavg / F("title_length"))
|
|
|
|
|
|
def delete_stale_model_entries(self, model):
|
|
|
- existing_pks = (
|
|
|
- model._default_manager.using(self.db_alias)
|
|
|
- .annotate(object_id=Cast("pk", TextField()))
|
|
|
- .values("object_id")
|
|
|
- )
|
|
|
+ existing_pks = model._default_manager.annotate(
|
|
|
+ object_id=Cast("pk", TextField())
|
|
|
+ ).values("object_id")
|
|
|
content_types_pks = get_descendants_content_types_pks(model)
|
|
|
stale_entries = self.entries.filter(
|
|
|
content_type_id__in=content_types_pks
|
|
@@ -246,7 +254,9 @@ class Index:
|
|
|
return
|
|
|
|
|
|
content_type_pk = get_content_type_pk(model)
|
|
|
- compiler = InsertQuery(IndexEntry).get_compiler(connection=self.connection)
|
|
|
+ compiler = InsertQuery(IndexEntry).get_compiler(
|
|
|
+ connection=self.write_connection
|
|
|
+ )
|
|
|
title_sql = []
|
|
|
autocomplete_sql = []
|
|
|
body_sql = []
|
|
@@ -259,7 +269,7 @@ class Index:
|
|
|
value = compiler.prepare_value(
|
|
|
IndexEntry._meta.get_field("title"), indexer.title
|
|
|
)
|
|
|
- sql, params = value.as_sql(compiler, self.connection)
|
|
|
+ sql, params = value.as_sql(compiler, self.write_connection)
|
|
|
title_sql.append(sql)
|
|
|
data_params.extend(params)
|
|
|
|
|
@@ -267,7 +277,7 @@ class Index:
|
|
|
value = compiler.prepare_value(
|
|
|
IndexEntry._meta.get_field("autocomplete"), indexer.autocomplete
|
|
|
)
|
|
|
- sql, params = value.as_sql(compiler, self.connection)
|
|
|
+ sql, params = value.as_sql(compiler, self.write_connection)
|
|
|
autocomplete_sql.append(sql)
|
|
|
data_params.extend(params)
|
|
|
|
|
@@ -275,7 +285,7 @@ class Index:
|
|
|
value = compiler.prepare_value(
|
|
|
IndexEntry._meta.get_field("body"), indexer.body
|
|
|
)
|
|
|
- sql, params = value.as_sql(compiler, self.connection)
|
|
|
+ sql, params = value.as_sql(compiler, self.write_connection)
|
|
|
body_sql.append(sql)
|
|
|
data_params.extend(params)
|
|
|
|
|
@@ -286,7 +296,7 @@ class Index:
|
|
|
]
|
|
|
)
|
|
|
|
|
|
- with self.connection.cursor() as cursor:
|
|
|
+ with self.write_connection.cursor() as cursor:
|
|
|
cursor.execute(
|
|
|
"""
|
|
|
INSERT INTO %s (content_type_id, object_id, title, autocomplete, body, title_norm)
|
|
@@ -304,7 +314,7 @@ class Index:
|
|
|
self._refresh_title_norms()
|
|
|
|
|
|
def delete_item(self, item):
|
|
|
- item.index_entries.all()._raw_delete(using=self.db_alias)
|
|
|
+ item.index_entries.all()._raw_delete(using=self.write_connection.alias)
|
|
|
|
|
|
def __str__(self):
|
|
|
return self.name
|
|
@@ -660,7 +670,7 @@ class PostgresSearchRebuilder:
|
|
|
class PostgresSearchAtomicRebuilder(PostgresSearchRebuilder):
|
|
|
def __init__(self, index):
|
|
|
super().__init__(index)
|
|
|
- self.transaction = transaction.atomic(using=index.db_alias)
|
|
|
+ self.transaction = transaction.atomic(using=index.write_connection.alias)
|
|
|
self.transaction_opened = False
|
|
|
|
|
|
def start(self):
|
|
@@ -701,11 +711,11 @@ class PostgresSearchBackend(BaseSearchBackend):
|
|
|
if params.get("ATOMIC_REBUILD", False):
|
|
|
self.rebuilder_class = self.atomic_rebuilder_class
|
|
|
|
|
|
- def get_index_for_model(self, model, db_alias=None):
|
|
|
- return Index(self, db_alias)
|
|
|
+ def get_index_for_model(self, model):
|
|
|
+ return Index(self)
|
|
|
|
|
|
def get_index_for_object(self, obj):
|
|
|
- return self.get_index_for_model(obj._meta.model, obj._state.db)
|
|
|
+ return self.get_index_for_model(obj._meta.model)
|
|
|
|
|
|
def reset_index(self):
|
|
|
for connection in [
|