selene-backend/shared/selene/api/base_endpoint.py

185 lines
6.2 KiB
Python

"""Base class for Flask API endpoints"""
from logging import getLogger
from flask import after_this_request, current_app, request, g as global_context
from flask.views import MethodView
from selene.data.account import Account, AccountRepository
from selene.util.auth import AuthenticationError, AuthenticationToken
from selene.util.db import connect_to_db
# Names of the cookies that carry the authentication tokens.
ACCESS_TOKEN_COOKIE_NAME = 'seleneAccess'
REFRESH_TOKEN_COOKIE_NAME = 'seleneRefresh'

# Token/cookie lifetimes, expressed in seconds (used as cookie max_age).
FIFTEEN_MINUTES = 15 * 60
ONE_MONTH = 365 * 24 * 60 * 60 // 12  # one twelfth of a 365-day year

_log = getLogger(__package__)


class APIError(Exception):
    """Signal that an endpoint is building a non-successful response.

    Raise this exception whenever a non-successful response is built.
    """


class SeleneEndpoint(MethodView):
    """
    Abstract base class for Selene Flask Restful API calls.

    Subclasses must do the following:
        - override the allowed_methods class attribute to a list of all allowed
          HTTP methods.  Each list member must be a HTTPMethod enum
        - override the _build_response_data method

    Authentication uses two tokens sent as cookies: an access token that
    lives for FIFTEEN_MINUTES and a refresh token that lives for ONE_MONTH.
    _authenticate() validates them, loads the account and, when the access
    token has expired, issues a fresh pair of tokens.
    """
    def __init__(self):
        self.config: dict = current_app.config
        self.request = request
        self.response: tuple = None
        # Stored on flask.g so it is visible outside this object for the
        # rest of the request.
        global_context.url = request.url
        self.account: Account = None
        self.access_token = self._init_access_token()
        self.refresh_token = self._init_refresh_token()

    @property
    def db(self):
        """Return a database connection, opening it on first use.

        The connection is cached on flask.g, so later accesses while the
        same flask.g is active reuse it.
        """
        # NOTE(review): closing this connection is not handled here;
        # presumably a teardown handler elsewhere does it - confirm.
        if 'db' not in global_context:
            global_context.db = connect_to_db(
                current_app.config['DB_CONNECTION_CONFIG']
            )
        return global_context.db

    def _init_access_token(self):
        """Build an empty access token signed with ACCESS_SECRET."""
        return AuthenticationToken(
            self.config['ACCESS_SECRET'],
            FIFTEEN_MINUTES
        )

    def _init_refresh_token(self):
        """Build an empty refresh token signed with REFRESH_SECRET."""
        return AuthenticationToken(
            self.config['REFRESH_SECRET'],
            ONE_MONTH
        )

    def _authenticate(self):
        """Authenticate the user using the tokens passed via cookies.

        On success self.account holds the authenticated account.  If the
        access token had expired but the refresh token was still good, a new
        pair of tokens is generated and queued to be set as cookies on the
        response.

        :raises: AuthenticationError when the tokens are missing, invalid,
            expired, or refer to an account that is not on the database.
            Any exception raised here is logged and then re-raised.
        """
        try:
            account_id = self._validate_auth_tokens()
            self._get_account(account_id)
            self._validate_account(account_id)
            if self.access_token.is_expired:
                self._refresh_auth_tokens()
        except Exception:
            _log.exception('an exception occurred during authentication')
            raise

    def _validate_auth_tokens(self):
        """Ensure the tokens are passed in the request and are well formed.

        The refresh token is only decoded when the access token has expired.

        :return: the account ID from the access token, or, if that is empty,
            the account ID from the refresh token
        :raises: AuthenticationError
        """
        self._get_auth_tokens()
        self._decode_access_token()
        if self.access_token.is_expired:
            self._decode_refresh_token()
        account_not_found = (
            self.access_token.account_id is None and
            self.refresh_token.account_id is None
        )
        if account_not_found:
            raise AuthenticationError(
                'failed to retrieve account ID from authentication tokens'
            )
        return self.access_token.account_id or self.refresh_token.account_id

    def _get_auth_tokens(self):
        """Copy the token cookies from the request onto the token objects.

        Either cookie may be absent, but not both.

        :raises: AuthenticationError when neither cookie is present
        """
        self.access_token.jwt = self.request.cookies.get(
            ACCESS_TOKEN_COOKIE_NAME
        )
        self.refresh_token.jwt = self.request.cookies.get(
            REFRESH_TOKEN_COOKIE_NAME
        )
        if self.access_token.jwt is None and self.refresh_token.jwt is None:
            raise AuthenticationError('no authentication tokens found')

    def _decode_access_token(self):
        """Decode the access token to get the account ID and check for errors.

        :raises: AuthenticationError when the token is not valid
        """
        self.access_token.validate()
        if not self.access_token.is_valid:
            raise AuthenticationError('invalid access token')

    def _decode_refresh_token(self):
        """Decode the refresh token to get the account ID and check for errors.

        :raises: AuthenticationError when the token is invalid or expired
        """
        self.refresh_token.validate()
        if not self.refresh_token.is_valid:
            raise AuthenticationError('invalid refresh token')
        if self.refresh_token.is_expired:
            raise AuthenticationError('authentication tokens expired')

    def _get_account(self, account_id):
        """Use the account ID from a decoded token to load the account.

        Stores the result in self.account.  It is presumably None when no
        account matches; _validate_account checks for that.
        """
        account_repository = AccountRepository(self.db)
        self.account = account_repository.get_account_by_id(account_id)

    def _validate_account(self, account_id: str):
        """Ensure the account loaded by _get_account exists.

        On success the account ID is recorded on flask.g as account_id.

        :param account_id: ID taken from the authentication tokens, used
            for logging when no account was found
        :raises: AuthenticationError when self.account is None
        """
        if self.account is None:
            # Lazy %-style args: the message is only built if it is emitted.
            _log.error('account ID %s not on database', account_id)
            raise AuthenticationError('account not found')
        else:
            global_context.account_id = self.account.id

    def _refresh_auth_tokens(self):
        """Steps necessary to refresh the tokens used for authentication."""
        self._generate_tokens()
        self._set_token_cookies()

    def _generate_tokens(self):
        """Replace both tokens with freshly generated ones for self.account."""
        self.access_token = self._init_access_token()
        self.refresh_token = self._init_refresh_token()
        self.access_token.generate(self.account.id)
        self.refresh_token.generate(self.account.id)

    def _set_token_cookies(self, expire=False):
        """Set the cookies that contain the authentication tokens.

        This method should be called when a user logs in, logs out, or when
        their access token expires.  The cookies are not written immediately:
        a Flask after-this-request hook adds them to the response of the
        request currently being handled.

        :param expire: instead of sending the current tokens, send empty
            cookies with a max age of zero so the browser discards them,
            effectively logging the user out of the system.
        """
        access_token_cookie = dict(
            key=ACCESS_TOKEN_COOKIE_NAME,
            value=str(self.access_token.jwt),
            domain=self.config['DOMAIN'],
            max_age=FIFTEEN_MINUTES,
        )
        refresh_token_cookie = dict(
            key=REFRESH_TOKEN_COOKIE_NAME,
            value=str(self.refresh_token.jwt),
            domain=self.config['DOMAIN'],
            max_age=ONE_MONTH,
        )
        # NOTE(review): neither cookie sets httponly/secure/samesite even
        # though they carry authentication tokens - consider adding them.
        if expire:
            for cookie in (access_token_cookie, refresh_token_cookie):
                cookie.update(value='', max_age=0)

        @after_this_request
        def set_cookies(response):
            """Use Flask after request hook to reset token cookies"""
            response.set_cookie(**access_token_cookie)
            response.set_cookie(**refresh_token_cookie)

            return response