filters.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. from django.conf import settings
  2. from django.db import models
  3. from django.shortcuts import get_object_or_404
  4. from rest_framework.filters import BaseFilterBackend
  5. from taggit.managers import TaggableManager
  6. from wagtail.core.models import Locale, Page
  7. from wagtail.search.backends import get_search_backend
  8. from wagtail.search.backends.base import FilterFieldError, OrderByFieldError
  9. from .utils import BadRequestError, parse_boolean
  10. class FieldsFilter(BaseFilterBackend):
  11. def filter_queryset(self, request, queryset, view):
  12. """
  13. This performs field level filtering on the result set
  14. Eg: ?title=James Joyce
  15. """
  16. fields = set(view.get_available_fields(queryset.model, db_fields_only=True))
  17. # Locale is a database field, but we provide a separate filter for it
  18. if 'locale' in fields:
  19. fields.remove('locale')
  20. for field_name, value in request.GET.items():
  21. if field_name in fields:
  22. try:
  23. field = queryset.model._meta.get_field(field_name)
  24. except LookupError:
  25. field = None
  26. # Convert value into python
  27. try:
  28. if isinstance(field, (models.BooleanField, models.NullBooleanField)):
  29. value = parse_boolean(value)
  30. elif isinstance(field, (models.IntegerField, models.AutoField)):
  31. value = int(value)
  32. elif isinstance(field, models.ForeignKey):
  33. value = field.target_field.get_prep_value(value)
  34. except ValueError as e:
  35. raise BadRequestError("field filter error. '%s' is not a valid value for %s (%s)" % (
  36. value,
  37. field_name,
  38. str(e)
  39. ))
  40. if isinstance(field, TaggableManager):
  41. for tag in value.split(','):
  42. queryset = queryset.filter(**{field_name + '__name': tag})
  43. # Stick a message on the queryset to indicate that tag filtering has been performed
  44. # This will let the do_search method know that it must raise an error as searching
  45. # and tag filtering at the same time is not supported
  46. queryset._filtered_by_tag = True
  47. else:
  48. queryset = queryset.filter(**{field_name: value})
  49. return queryset
  50. class OrderingFilter(BaseFilterBackend):
  51. def filter_queryset(self, request, queryset, view):
  52. """
  53. This applies ordering to the result set
  54. Eg: ?order=title
  55. It also supports reverse ordering
  56. Eg: ?order=-title
  57. And random ordering
  58. Eg: ?order=random
  59. """
  60. if 'order' in request.GET:
  61. order_by = request.GET['order']
  62. # Random ordering
  63. if order_by == 'random':
  64. # Prevent ordering by random with offset
  65. if 'offset' in request.GET:
  66. raise BadRequestError("random ordering with offset is not supported")
  67. return queryset.order_by('?')
  68. # Check if reverse ordering is set
  69. if order_by.startswith('-'):
  70. reverse_order = True
  71. order_by = order_by[1:]
  72. else:
  73. reverse_order = False
  74. # Add ordering
  75. if order_by in view.get_available_fields(queryset.model):
  76. queryset = queryset.order_by(order_by)
  77. else:
  78. # Unknown field
  79. raise BadRequestError("cannot order by '%s' (unknown field)" % order_by)
  80. # Reverse order
  81. if reverse_order:
  82. queryset = queryset.reverse()
  83. return queryset
  84. class SearchFilter(BaseFilterBackend):
  85. def filter_queryset(self, request, queryset, view):
  86. """
  87. This performs a full-text search on the result set
  88. Eg: ?search=James Joyce
  89. """
  90. search_enabled = getattr(settings, 'WAGTAILAPI_SEARCH_ENABLED', True)
  91. if 'search' in request.GET:
  92. if not search_enabled:
  93. raise BadRequestError("search is disabled")
  94. # Searching and filtering by tag at the same time is not supported
  95. if getattr(queryset, '_filtered_by_tag', False):
  96. raise BadRequestError("filtering by tag with a search query is not supported")
  97. search_query = request.GET['search']
  98. search_operator = request.GET.get('search_operator', None)
  99. order_by_relevance = 'order' not in request.GET
  100. sb = get_search_backend()
  101. try:
  102. queryset = sb.search(search_query, queryset, operator=search_operator, order_by_relevance=order_by_relevance)
  103. except FilterFieldError as e:
  104. raise BadRequestError("cannot filter by '{}' while searching (field is not indexed)".format(e.field_name))
  105. except OrderByFieldError as e:
  106. raise BadRequestError("cannot order by '{}' while searching (field is not indexed)".format(e.field_name))
  107. return queryset
  108. class ChildOfFilter(BaseFilterBackend):
  109. """
  110. Implements the ?child_of filter used to filter the results to only contain
  111. pages that are direct children of the specified page.
  112. """
  113. def filter_queryset(self, request, queryset, view):
  114. if 'child_of' in request.GET:
  115. try:
  116. parent_page_id = int(request.GET['child_of'])
  117. if parent_page_id < 0:
  118. raise ValueError()
  119. parent_page = view.get_base_queryset().get(id=parent_page_id)
  120. except ValueError:
  121. if request.GET['child_of'] == 'root':
  122. parent_page = view.get_root_page()
  123. else:
  124. raise BadRequestError("child_of must be a positive integer")
  125. except Page.DoesNotExist:
  126. raise BadRequestError("parent page doesn't exist")
  127. queryset = queryset.child_of(parent_page)
  128. # Save the parent page on the queryset. This is required for the page
  129. # explorer, which needs to pass the parent page into
  130. # `construct_explorer_page_queryset` hook functions
  131. queryset._filtered_by_child_of = parent_page
  132. return queryset
  133. class DescendantOfFilter(BaseFilterBackend):
  134. """
  135. Implements the ?decendant_of filter which limits the set of pages to a
  136. particular branch of the page tree.
  137. """
  138. def filter_queryset(self, request, queryset, view):
  139. if 'descendant_of' in request.GET:
  140. if hasattr(queryset, '_filtered_by_child_of'):
  141. raise BadRequestError("filtering by descendant_of with child_of is not supported")
  142. try:
  143. parent_page_id = int(request.GET['descendant_of'])
  144. if parent_page_id < 0:
  145. raise ValueError()
  146. parent_page = view.get_base_queryset().get(id=parent_page_id)
  147. except ValueError:
  148. if request.GET['descendant_of'] == 'root':
  149. parent_page = view.get_root_page()
  150. else:
  151. raise BadRequestError("descendant_of must be a positive integer")
  152. except Page.DoesNotExist:
  153. raise BadRequestError("ancestor page doesn't exist")
  154. queryset = queryset.descendant_of(parent_page)
  155. return queryset
  156. class TranslationOfFilter(BaseFilterBackend):
  157. """
  158. Implements the ?translation_of filter which limits the set of pages to translations
  159. of a page.
  160. """
  161. def filter_queryset(self, request, queryset, view):
  162. if 'translation_of' in request.GET:
  163. try:
  164. page_id = int(request.GET['translation_of'])
  165. if page_id < 0:
  166. raise ValueError()
  167. page = view.get_base_queryset().get(id=page_id)
  168. except ValueError:
  169. if request.GET['translation_of'] == 'root':
  170. page = view.get_root_page()
  171. else:
  172. raise BadRequestError("translation_of must be a positive integer")
  173. except Page.DoesNotExist:
  174. raise BadRequestError("translation_of page doesn't exist")
  175. _filtered_by_child_of = getattr(queryset, '_filtered_by_child_of', None)
  176. queryset = queryset.translation_of(page)
  177. if _filtered_by_child_of:
  178. queryset._filtered_by_child_of = _filtered_by_child_of
  179. return queryset
  180. class LocaleFilter(BaseFilterBackend):
  181. """
  182. Implements the ?locale filter which limits the set of pages to a
  183. particular locale.
  184. """
  185. def filter_queryset(self, request, queryset, view):
  186. if 'locale' in request.GET:
  187. _filtered_by_child_of = getattr(queryset, '_filtered_by_child_of', None)
  188. locale = get_object_or_404(Locale, language_code=request.GET['locale'])
  189. queryset = queryset.filter(locale=locale)
  190. if _filtered_by_child_of:
  191. queryset._filtered_by_child_of = _filtered_by_child_of
  192. return queryset