view_restrictions.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. """
  2. Base model definitions for validating front-end user access to resources such as pages and
  3. documents. These may be subclassed to accommodate specific models such as Page or Collection,
  4. but the definitions here should remain generic and not depend on the base wagtail.models
  5. module or specific models defined there.
  6. """
  7. from django.conf import settings
  8. from django.contrib.auth.models import Group
  9. from django.db import models
  10. from django.utils.translation import gettext_lazy as _
  11. class BaseViewRestriction(models.Model):
  12. NONE = "none"
  13. PASSWORD = "password"
  14. GROUPS = "groups"
  15. LOGIN = "login"
  16. RESTRICTION_CHOICES = (
  17. (NONE, _("Public")),
  18. (PASSWORD, _("Private, accessible with a shared password")),
  19. (LOGIN, _("Private, accessible to any logged-in users")),
  20. (GROUPS, _("Private, accessible to users in specific groups")),
  21. )
  22. restriction_type = models.CharField(max_length=20, choices=RESTRICTION_CHOICES)
  23. password = models.CharField(
  24. verbose_name=_("shared password"),
  25. max_length=255,
  26. blank=True,
  27. help_text=_(
  28. "Shared passwords should not be used to protect sensitive content. Anyone who has this password will be able to view the content."
  29. ),
  30. )
  31. groups = models.ManyToManyField(Group, verbose_name=_("groups"), blank=True)
  32. def accept_request(self, request):
  33. if self.restriction_type == BaseViewRestriction.PASSWORD:
  34. passed_restrictions = request.session.get(
  35. self.passed_view_restrictions_session_key, []
  36. )
  37. if self.id not in passed_restrictions:
  38. return False
  39. elif self.restriction_type == BaseViewRestriction.LOGIN:
  40. if not request.user.is_authenticated:
  41. return False
  42. elif self.restriction_type == BaseViewRestriction.GROUPS:
  43. if not request.user.is_superuser:
  44. current_user_groups = request.user.groups.all()
  45. if not any(group in current_user_groups for group in self.groups.all()):
  46. return False
  47. return True
  48. def mark_as_passed(self, request):
  49. """
  50. Update the session data in the request to mark the user as having passed this
  51. view restriction
  52. """
  53. has_existing_session = settings.SESSION_COOKIE_NAME in request.COOKIES
  54. passed_restrictions = request.session.setdefault(
  55. self.passed_view_restrictions_session_key, []
  56. )
  57. if self.id not in passed_restrictions:
  58. passed_restrictions.append(self.id)
  59. request.session[
  60. self.passed_view_restrictions_session_key
  61. ] = passed_restrictions
  62. if not has_existing_session:
  63. # if this is a session we've created, set it to expire at the end
  64. # of the browser session
  65. request.session.set_expiry(0)
  66. class Meta:
  67. abstract = True
  68. verbose_name = _("view restriction")
  69. verbose_name_plural = _("view restrictions")