selene-backend/shared/selene/api/base_endpoint.py

165 lines
5.3 KiB
Python

"""Base class for Flask API endpoints"""
from http import HTTPStatus
from flask import after_this_request, current_app, request
from flask.views import MethodView
from selene.data.account import (
Account,
AccountRepository,
RefreshTokenRepository
)
from selene.util.auth import (
AuthenticationError,
AuthenticationTokenGenerator,
AuthenticationTokenValidator,
FIFTEEN_MINUTES,
ONE_MONTH
)
from selene.util.db import get_db_connection
class APIError(Exception):
    """Signal that an endpoint has built a non-successful response.

    Carries no state of its own; any message is passed straight through to
    ``Exception``.
    """
class SeleneEndpoint(MethodView):
    """
    Abstract base class for Selene Flask Restful API calls.

    Subclasses must do the following:
        -  override the allowed_methods class attribute to a list of all
           allowed HTTP methods.  Each list member must be a HTTPMethod enum
        -  override the _build_response_data method

    Authentication is cookie based: the ``seleneAccess`` cookie is checked
    first and, only when it has expired, the ``seleneRefresh`` cookie is
    checked as well.  The outcome is recorded on the instance
    (``authenticated``, ``access_token_expired``, ``refresh_token_expired``,
    ``account``, ``response``).
    """
    # When False, a failed authentication does not build a 401 response.
    authentication_required: bool = True

    def __init__(self):
        self.config: dict = current_app.config
        self.authenticated = False
        self.request = request
        self.response: tuple = None
        self.access_token_expired: bool = False
        self.refresh_token_expired: bool = False
        # Value of the seleneRefresh cookie; stays None until the refresh
        # token has actually been validated (see _validate_auth_tokens).
        self.refresh_token: str = None
        self.account: Account = None

    def _authenticate(self):
        """
        Authenticate the user using tokens passed via cookies.

        On success ``self.authenticated`` is set to True.  An
        AuthenticationError is not propagated: when the endpoint requires
        authentication, ``self.response`` is set to the error message with
        a 401 status; otherwise the failure is ignored and
        ``self.authenticated`` simply stays False.
        """
        try:
            account_id = self._validate_auth_tokens()
            self._validate_account(account_id)
        except AuthenticationError as ae:
            if self.authentication_required:
                self.response = (str(ae), HTTPStatus.UNAUTHORIZED)
        else:
            # Only reached when neither validation step raised.
            self.authenticated = True

    def _validate_auth_tokens(self) -> str:
        """Validate the access token, falling back to the refresh token.

        Sets ``access_token_expired`` and, when the access token has
        expired, ``refresh_token_expired`` and ``refresh_token``.

        :return: the account id carried by whichever token was accepted;
            None when both tokens have expired
        :raises: AuthenticationError
        """
        self.access_token_expired, account_id = self._validate_token(
            'seleneAccess',
            self.config['ACCESS_SECRET']
        )
        if self.access_token_expired:
            self.refresh_token_expired, account_id = self._validate_token(
                'seleneRefresh',
                self.config['REFRESH_SECRET']
            )
            # _validate_token raises when the cookie is absent, so the
            # lookup below cannot fail.  Remember the raw token so that
            # _validate_account can compare it with the stored tokens.
            self.refresh_token = self.request.cookies['seleneRefresh']

        return account_id

    def _validate_token(self, cookie_key, jwt_secret):
        """Validate that a token cookie is well-formed and not expired.

        Used for both the access and the refresh token.

        :param cookie_key: name of the cookie holding the token
        :param jwt_secret: secret handed to the token validator
        :return: tuple of (token_expired, account_id); account_id is None
            when the token has expired
        :raises: AuthenticationError when the cookie is missing or the
            token is invalid
        """
        account_id = None
        token_expired = False

        try:
            token = self.request.cookies[cookie_key]
        except KeyError:
            error_msg = 'no {} token found in request'
            raise AuthenticationError(error_msg.format(cookie_key))

        validator = AuthenticationTokenValidator(token, jwt_secret)
        validator.validate_token()
        if validator.token_is_expired:
            token_expired = True
        elif validator.token_is_invalid:
            # NOTE: the message says "access" even when the refresh cookie
            # is being validated; the text is kept unchanged for callers.
            raise AuthenticationError('access token is invalid')
        else:
            account_id = validator.account_id

        return token_expired, account_id

    def _validate_account(self, account_id):
        """The refresh token in the request must match the database value.

        Loads the account into ``self.account``.  The refresh token is only
        checked when the access token has expired, because that is the only
        case in which the refresh token was used to authenticate.

        :param account_id: id extracted from the validated token
        :raises: AuthenticationError when the account does not exist or the
            request's refresh token is not one of the account's tokens
        """
        with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
            account_repository = AccountRepository(db)
            self.account = account_repository.get_account_by_id(account_id)

        if self.account is None:
            raise AuthenticationError('account not found')

        if self.access_token_expired:
            if self.refresh_token not in self.account.refresh_tokens:
                raise AuthenticationError('refresh token not found')

    def _generate_tokens(self):
        """Build a new access/refresh token pair for ``self.account``.

        :return: tuple of (access_token, refresh_token)
        """
        token_generator = AuthenticationTokenGenerator(
            self.account.id,
            self.config['ACCESS_SECRET'],
            self.config['REFRESH_SECRET']
        )
        access_token = token_generator.access_token
        refresh_token = token_generator.refresh_token

        return access_token, refresh_token

    def _set_token_cookies(self, access_token, refresh_token, expire=False):
        """Schedule the token cookies to be written to the response.

        The cookies are attached through an ``after_this_request`` hook, so
        nothing is written until Flask has produced the response.

        :param access_token: value for the seleneAccess cookie
        :param refresh_token: value for the seleneRefresh cookie
        :param expire: when True, both cookies are sent with an empty value
            and a max_age of 0 instead of the supplied tokens
        """
        access_token_cookie = dict(
            key='seleneAccess',
            value=str(access_token),
            domain=self.config['DOMAIN'],
            max_age=FIFTEEN_MINUTES,
        )
        refresh_token_cookie = dict(
            key='seleneRefresh',
            value=str(refresh_token),
            domain=self.config['DOMAIN'],
            max_age=ONE_MONTH,
        )

        if expire:
            for cookie in (access_token_cookie, refresh_token_cookie):
                cookie.update(value='', max_age=0)

        @after_this_request
        def set_cookies(response):
            response.set_cookie(**access_token_cookie)
            response.set_cookie(**refresh_token_cookie)

            return response

    def _update_refresh_token_on_db(self, new_refresh_token):
        """Replace the request's refresh token with a new one on the database.

        When the request's refresh token has expired it is deleted instead
        and the request is rejected.

        :param new_refresh_token: token that replaces the one in the
            seleneRefresh cookie
        :raises: AuthenticationError when the refresh token has expired
        """
        old_refresh_token = self.request.cookies['seleneRefresh']
        with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
            token_repository = RefreshTokenRepository(db, self.account)
            if self.refresh_token_expired:
                token_repository.delete_refresh_token(old_refresh_token)
                raise AuthenticationError('refresh token expired')
            else:
                token_repository.update_refresh_token(
                    new_refresh_token,
                    old_refresh_token
                )