185 lines
6.2 KiB
Python
185 lines
6.2 KiB
Python
"""Base class for Flask API endpoints"""
|
|
from logging import getLogger
|
|
|
|
from flask import after_this_request, current_app, request, g as global_context
|
|
from flask.views import MethodView
|
|
|
|
from selene.data.account import Account, AccountRepository
|
|
from selene.util.auth import AuthenticationError, AuthenticationToken
|
|
from selene.util.db import connect_to_db
|
|
|
|
# Names of the browser cookies that carry the two authentication JWTs.
ACCESS_TOKEN_COOKIE_NAME = 'seleneAccess'
REFRESH_TOKEN_COOKIE_NAME = 'seleneRefresh'

# Token / cookie lifetimes, in seconds.
FIFTEEN_MINUTES = 15 * 60
ONE_MONTH = (365 * 24 * 60 * 60) // 12  # one twelfth of a 365-day year

_log = getLogger(__package__)
|
|
|
|
|
|
class APIError(Exception):
    """Signal that an endpoint could not build a successful response.

    Raise it from any code path that ends in a non-success response.
    """
|
|
|
|
|
|
class SeleneEndpoint(MethodView):
    """
    Abstract base class for Selene Flask Restful API calls.

    On construction the endpoint captures the Flask app config and request,
    stores the request URL on flask.g and prepares empty access/refresh
    authentication tokens. _authenticate() (not invoked in this class --
    presumably called by subclasses) validates the token cookies sent with
    the request, loads the matching account and re-issues tokens when the
    access token has expired.

    Subclasses must do the following:
    - override the allowed_methods class attribute to a list of all allowed
      HTTP methods. Each list member must be a HTTPMethod enum
    - override the _build_response_data method

    NOTE(review): neither allowed_methods nor _build_response_data is
    referenced in this class; confirm the requirements above still apply.
    """

    def __init__(self):
        self.config: dict = current_app.config
        self.request = request
        # Not assigned anywhere else in this class; presumably set by
        # subclasses when they build their response.
        self.response: tuple = None
        # Stored on flask.g; not read in this class (presumably used for
        # logging/diagnostics elsewhere -- confirm).
        global_context.url = request.url
        self.account: Account = None
        self.access_token = self._init_access_token()
        self.refresh_token = self._init_refresh_token()

    @property
    def db(self):
        """Return the database connection, opening it on first use.

        The connection is cached on flask.g, so later accesses within the
        same Flask context reuse it instead of reconnecting.
        """
        if 'db' not in global_context:
            global_context.db = connect_to_db(
                current_app.config['DB_CONNECTION_CONFIG']
            )

        return global_context.db

    def _init_access_token(self):
        """Return a new, empty access token signed with ACCESS_SECRET.

        FIFTEEN_MINUTES (900 s) is passed as the second argument --
        presumably the token lifetime; see AuthenticationToken to confirm.
        """
        return AuthenticationToken(
            self.config['ACCESS_SECRET'],
            FIFTEEN_MINUTES
        )

    def _init_refresh_token(self):
        """Return a new, empty refresh token signed with REFRESH_SECRET.

        ONE_MONTH (2628000 s) is passed as the second argument --
        presumably the token lifetime; see AuthenticationToken to confirm.
        """
        return AuthenticationToken(
            self.config['REFRESH_SECRET'],
            ONE_MONTH
        )

    def _authenticate(self):
        """
        Authenticate the user using tokens passed via cookies.

        On success self.account holds the authenticated account. If the
        access token had expired, fresh tokens are generated and scheduled
        to be written back as cookies on the outgoing response.

        :raises AuthenticationError: if no token cookie is present, a token
            is invalid, the refresh token has expired, no account ID can be
            read from the tokens, or the account is not on the database.
            Any other exception raised along the way is logged and re-raised
            unchanged.
        """
        try:
            account_id = self._validate_auth_tokens()
            self._get_account(account_id)
            self._validate_account(account_id)
            if self.access_token.is_expired:
                self._refresh_auth_tokens()
        except Exception:
            _log.exception('an exception occurred during authentication')
            raise

    def _validate_auth_tokens(self):
        """Ensure the tokens are passed in request and are well formed.

        The refresh token is only decoded when the access token is expired.

        :return: the account ID taken from the access token, or from the
            refresh token if the access token yields none
        :raises AuthenticationError: if neither token yields an account ID
            (or via the helper methods called here)
        """
        self._get_auth_tokens()
        self._decode_access_token()
        if self.access_token.is_expired:
            self._decode_refresh_token()
        account_not_found = (
            self.access_token.account_id is None and
            self.refresh_token.account_id is None
        )
        if account_not_found:
            raise AuthenticationError(
                'failed to retrieve account ID from authentication tokens'
            )

        return self.access_token.account_id or self.refresh_token.account_id

    def _get_auth_tokens(self):
        """Copy the JWTs from the request cookies onto the token objects.

        A missing cookie leaves the corresponding jwt attribute as None.

        :raises AuthenticationError: if neither cookie is present
        """
        self.access_token.jwt = self.request.cookies.get(
            ACCESS_TOKEN_COOKIE_NAME
        )
        self.refresh_token.jwt = self.request.cookies.get(
            REFRESH_TOKEN_COOKIE_NAME
        )
        if self.access_token.jwt is None and self.refresh_token.jwt is None:
            raise AuthenticationError('no authentication tokens found')

    def _decode_access_token(self):
        """Decode the access JWT to get the account ID and check for errors.

        :raises AuthenticationError: if the token is reported as invalid
        """
        self.access_token.validate()

        if not self.access_token.is_valid:
            raise AuthenticationError('invalid access token')

    def _decode_refresh_token(self):
        """Decode the refresh JWT to get the account ID and check for errors.

        :raises AuthenticationError: if the token is invalid or expired
        """
        self.refresh_token.validate()

        if not self.refresh_token.is_valid:
            raise AuthenticationError('invalid refresh token')

        if self.refresh_token.is_expired:
            raise AuthenticationError('authentication tokens expired')

    def _get_account(self, account_id):
        """Look up the account for account_id and store it on self.account.

        self.account may end up None; _validate_account() checks for that.
        """
        account_repository = AccountRepository(self.db)
        self.account = account_repository.get_account_by_id(account_id)

    def _validate_account(self, account_id: str):
        """Ensure the account looked up for account_id exists.

        On success the account's ID is recorded on flask.g as account_id.

        :param account_id: ID decoded from the tokens, used for logging
        :raises AuthenticationError: if no account was found
        """
        if self.account is None:
            _log.error('account ID %s not on database', account_id)
            raise AuthenticationError('account not found')

        global_context.account_id = self.account.id

    def _refresh_auth_tokens(self):
        """Steps necessary to refresh the tokens used for authentication."""
        self._generate_tokens()
        self._set_token_cookies()

    def _generate_tokens(self):
        """Replace both tokens with new ones generated for self.account."""
        self.access_token = self._init_access_token()
        self.refresh_token = self._init_refresh_token()
        self.access_token.generate(self.account.id)
        self.refresh_token.generate(self.account.id)

    def _set_token_cookies(self, expire=False):
        """Set the cookies that contain the authentication tokens.

        This method should be called when a user logs in, logs out, or when
        their access token expires. The cookies are not written immediately;
        they are applied to the response via Flask's after_this_request hook.

        :param expire: when True, overwrite both cookies with an empty value
            and max_age=0 so the browser discards them, effectively logging
            the user out of the system.
        """
        # NOTE(review): cookies are set without httponly/secure/samesite, so
        # the JWTs are readable by page scripts and may travel over plain
        # HTTP. Confirm whether the front end needs script access before
        # hardening these flags.
        access_token_cookie = dict(
            key=ACCESS_TOKEN_COOKIE_NAME,
            value=str(self.access_token.jwt),
            domain=self.config['DOMAIN'],
            max_age=FIFTEEN_MINUTES,
        )
        refresh_token_cookie = dict(
            key=REFRESH_TOKEN_COOKIE_NAME,
            value=str(self.refresh_token.jwt),
            domain=self.config['DOMAIN'],
            max_age=ONE_MONTH,
        )

        if expire:
            for cookie in (access_token_cookie, refresh_token_cookie):
                cookie.update(value='', max_age=0)

        @after_this_request
        def set_cookies(response):
            """Use Flask after request hook to reset token cookies"""
            response.set_cookie(**access_token_cookie)
            response.set_cookie(**refresh_token_cookie)

            # after_this_request callbacks must return the response.
            return response
|