|
@@ -14,6 +14,20 @@ class FilterError(Exception):
|
|
|
|
|
|
|
|
|
class FieldError(Exception):
|
|
|
+ def __init__(self, *args, field_name=None, **kwargs):
|
|
|
+ self.field_name = field_name
|
|
|
+ super(FieldError, self).__init__(*args, **kwargs)
|
|
|
+
|
|
|
+
|
|
|
+class SearchFieldError(FieldError):
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+class FilterFieldError(FieldError):
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+class OrderByFieldError(FieldError):
|
|
|
pass
|
|
|
|
|
|
|
|
@@ -48,18 +62,20 @@ class BaseSearchQueryCompiler:
|
|
|
def _connect_filters(self, filters, connector, negated):
|
|
|
raise NotImplementedError
|
|
|
|
|
|
- def _process_filter(self, field_attname, lookup, value):
|
|
|
+ def _process_filter(self, field_attname, lookup, value, check_only=False):
|
|
|
# Get the field
|
|
|
field = self._get_filterable_field(field_attname)
|
|
|
|
|
|
if field is None:
|
|
|
- raise FieldError(
|
|
|
+ raise FilterFieldError(
|
|
|
'Cannot filter search results with field "' + field_attname + '". Please add index.FilterField(\'' +
|
|
|
- field_attname + '\') to ' + self.queryset.model.__name__ + '.search_fields.'
|
|
|
+ field_attname + '\') to ' + self.queryset.model.__name__ + '.search_fields.',
|
|
|
+ field_name=field_attname
|
|
|
)
|
|
|
|
|
|
# Process the lookup
|
|
|
- result = self._process_lookup(field, lookup, value)
|
|
|
+ if not check_only:
|
|
|
+ result = self._process_lookup(field, lookup, value)
|
|
|
|
|
|
if result is None:
|
|
|
raise FilterError(
|
|
@@ -69,7 +85,7 @@ class BaseSearchQueryCompiler:
|
|
|
|
|
|
return result
|
|
|
|
|
|
- def _get_filters_from_where_node(self, where_node):
|
|
|
+ def _get_filters_from_where_node(self, where_node, check_only=False):
|
|
|
# Check if this is a leaf node
|
|
|
if isinstance(where_node, Lookup):
|
|
|
field_attname = where_node.lhs.target.attname
|
|
@@ -81,7 +97,7 @@ class BaseSearchQueryCompiler:
|
|
|
return
|
|
|
|
|
|
# Process the filter
|
|
|
- return self._process_filter(field_attname, lookup, value)
|
|
|
+ return self._process_filter(field_attname, lookup, value, check_only=check_only)
|
|
|
|
|
|
elif isinstance(where_node, SubqueryConstraint):
|
|
|
raise FilterError('Could not apply filter on search results: Subqueries are not allowed.')
|
|
@@ -90,9 +106,10 @@ class BaseSearchQueryCompiler:
|
|
|
# Get child filters
|
|
|
connector = where_node.connector
|
|
|
child_filters = [self._get_filters_from_where_node(child) for child in where_node.children]
|
|
|
- child_filters = [child_filter for child_filter in child_filters if child_filter]
|
|
|
|
|
|
- return self._connect_filters(child_filters, connector, where_node.negated)
|
|
|
+ if not check_only:
|
|
|
+ child_filters = [child_filter for child_filter in child_filters if child_filter]
|
|
|
+ return self._connect_filters(child_filters, connector, where_node.negated)
|
|
|
|
|
|
else:
|
|
|
raise FilterError('Could not apply filter on search results: Unknown where node: ' + str(type(where_node)))
|
|
@@ -114,13 +131,35 @@ class BaseSearchQueryCompiler:
|
|
|
field = self._get_filterable_field(field_name)
|
|
|
|
|
|
if field is None:
|
|
|
- raise FieldError(
|
|
|
+ raise OrderByFieldError(
|
|
|
'Cannot sort search results with field "' + field_name + '". Please add index.FilterField(\'' +
|
|
|
- field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.'
|
|
|
+ field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.',
|
|
|
+ field_name=field_name
|
|
|
)
|
|
|
|
|
|
yield reverse, field
|
|
|
|
|
|
+ def check(self):
|
|
|
+ # Check search fields
|
|
|
+ if self.fields:
|
|
|
+ allowed_fields = {field.field_name for field in self.queryset.model.get_searchable_search_fields()}
|
|
|
+
|
|
|
+ for field_name in self.fields:
|
|
|
+ if field_name not in allowed_fields:
|
|
|
+ raise SearchFieldError(
|
|
|
+ 'Cannot search with field "' + field_name + '". Please add index.SearchField(\'' +
|
|
|
+ field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.',
|
|
|
+ field_name=field_name
|
|
|
+ )
|
|
|
+
|
|
|
+ # Check where clause
|
|
|
+ # Raises FilterFieldError if an unindexed field is being filtered on
|
|
|
+ self._get_filters_from_where_node(self.queryset.query.where, check_only=True)
|
|
|
+
|
|
|
+ # Check order by
|
|
|
+ # Raises OrderByFieldError if an unindexed field is being used to order by
|
|
|
+ list(self._get_order_by())
|
|
|
+
|
|
|
|
|
|
class BaseSearchResults:
|
|
|
def __init__(self, backend, query_compiler, prefetch_related=None):
|
|
@@ -278,17 +317,6 @@ class BaseSearchBackend:
|
|
|
if query == "":
|
|
|
return EmptySearchResults()
|
|
|
|
|
|
- # Only fields that are indexed as a SearchField can be passed in fields
|
|
|
- if fields:
|
|
|
- allowed_fields = {field.field_name for field in model.get_searchable_search_fields()}
|
|
|
-
|
|
|
- for field_name in fields:
|
|
|
- if field_name not in allowed_fields:
|
|
|
- raise FieldError(
|
|
|
- 'Cannot search with field "' + field_name + '". Please add index.SearchField(\'' +
|
|
|
- field_name + '\') to ' + model.__name__ + '.search_fields.'
|
|
|
- )
|
|
|
-
|
|
|
# Apply filters to queryset
|
|
|
if filters:
|
|
|
queryset = queryset.filter(**filters)
|
|
@@ -302,4 +330,8 @@ class BaseSearchBackend:
|
|
|
search_query = self.query_compiler_class(
|
|
|
queryset, query, fields=fields, operator=operator, order_by_relevance=order_by_relevance
|
|
|
)
|
|
|
+
|
|
|
+ # Check the query
|
|
|
+ search_query.check()
|
|
|
+
|
|
|
return self.results_class(self, search_query)
|