diff --git a/shared/selene/api/base_endpoint.py b/shared/selene/api/base_endpoint.py index 608761ba..fa07fb60 100644 --- a/shared/selene/api/base_endpoint.py +++ b/shared/selene/api/base_endpoint.py @@ -1,7 +1,5 @@ """Base class for Flask API endpoints""" - -from http import HTTPStatus - +from logging import getLogger from flask import after_this_request, current_app, request from flask.views import MethodView @@ -10,15 +8,16 @@ from selene.data.account import ( AccountRepository, RefreshTokenRepository ) -from selene.util.auth import ( - AuthenticationError, - AuthenticationTokenGenerator, - AuthenticationTokenValidator, - FIFTEEN_MINUTES, - ONE_MONTH -) +from selene.util.auth import AuthenticationError, AuthenticationToken from selene.util.db import get_db_connection +ACCESS_TOKEN_COOKIE_NAME = 'seleneAccess' +FIFTEEN_MINUTES = 900 +ONE_MONTH = 2628000 +REFRESH_TOKEN_COOKIE_NAME = 'seleneRefresh' + +_log = getLogger() + class APIError(Exception): """Raise this exception whenever a non-successful response is built""" @@ -34,16 +33,19 @@ class SeleneEndpoint(MethodView): HTTP methods. Each list member must be a HTTPMethod enum - override the _build_response_data method """ - authentication_required: bool = True - def __init__(self): self.config: dict = current_app.config - self.authenticated = False self.request = request self.response: tuple = None - self.access_token_expired: bool = False - self.refresh_token_expired: bool = False self.account: Account = None + self.access_token = AuthenticationToken( + self.config['ACCESS_SECRET'], + FIFTEEN_MINUTES + ) + self.refresh_token = AuthenticationToken( + self.config['REFRESH_SECRET'], + ONE_MONTH + ) def _authenticate(self): """ @@ -51,90 +53,92 @@ class SeleneEndpoint(MethodView): :raises: APIError() """ - try: - account_id = self._validate_auth_tokens() - self._validate_account(account_id) - except AuthenticationError as ae: - if self.authentication_required: - self.response = (str(ae), HTTPStatus.UNAUTHORIZED) - else: - self.authenticated = True + self._validate_auth_tokens() + account_id = self._get_account_id_from_tokens() + self._get_account(account_id) + self._validate_account(account_id) + if self.access_token.is_expired: + self._refresh_auth_tokens() - def _validate_auth_tokens(self) -> str: - self.access_token_expired, account_id = self._validate_token( - 'seleneAccess', - self.config['ACCESS_SECRET'] + def _validate_auth_tokens(self): + """Ensure the tokens are passed in request and are well formed.""" + self.access_token.jwt = self.request.cookies.get( + ACCESS_TOKEN_COOKIE_NAME ) - if self.access_token_expired: - self.refresh_token_expired, account_id = self._validate_token( - 'seleneRefresh', - self.config['REFRESH_SECRET'] - ) + self.access_token.validate() + self.refresh_token.jwt = self.request.cookies.get( + REFRESH_TOKEN_COOKIE_NAME + ) + self.refresh_token.validate() + + if self.access_token.jwt is None and self.refresh_token.jwt is None: + raise AuthenticationError('no authentication tokens found') + + if self.access_token.is_expired and self.refresh_token.is_expired: + raise AuthenticationError('authentication tokens expired') + + def _get_account_id_from_tokens(self): + """Extract the account ID, which is encoded within the tokens""" + if self.access_token.is_expired: + account_id = self.refresh_token.account_id + else: + account_id = self.access_token.account_id return account_id - def _validate_token(self, cookie_key, jwt_secret): - """Validate the access token is well-formed and not expired - - :raises: AuthenticationError - """ - account_id = None - token_expired = False - - try: - token = self.request.cookies[cookie_key] - except KeyError: - error_msg = 'no {} token found in request' - raise AuthenticationError(error_msg.format(cookie_key)) - - validator = AuthenticationTokenValidator(token, jwt_secret) - validator.validate_token() - if validator.token_is_expired: - token_expired = True - elif validator.token_is_invalid: - raise AuthenticationError('access token is invalid') - else: - account_id = validator.account_id - - return token_expired, account_id - - def _validate_account(self, account_id): - """The refresh token in the request must match the database value. - - :raises: AuthenticationError - """ + def _get_account(self, account_id): + """Use account ID from decoded authentication token to get account.""" with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: account_repository = AccountRepository(db) self.account = account_repository.get_account_by_id(account_id) + def _validate_account(self, account_id: str): + """Account must exist and contain have a refresh token matching request. + + :raises: AuthenticationError + """ if self.account is None: + _log.error('account ID {} not on database'.format(account_id)) raise AuthenticationError('account not found') - if self.access_token_expired: - if self.refresh_token not in self.account.refresh_tokens: - raise AuthenticationError('refresh token not found') + if self.refresh_token.jwt not in self.account.refresh_tokens: + log_msg = 'account ID {} does not have token {}' + _log.error(log_msg.format(account_id, self.refresh_token.jwt)) + raise AuthenticationError( + 'refresh token does not exist for this account' + ) + + def _refresh_auth_tokens(self): + """Steps necessary to refresh the tokens used for authentication.""" + old_refresh_token = self.refresh_token + self._generate_tokens() + self._update_refresh_token_on_db(old_refresh_token) + self._set_token_cookies() def _generate_tokens(self): - token_generator = AuthenticationTokenGenerator( - self.account.id, - self.config['ACCESS_SECRET'], - self.config['REFRESH_SECRET'] - ) - access_token = token_generator.access_token - refresh_token = token_generator.refresh_token + """Generate an access token and refresh token.""" + self.access_token.generate() + self.refresh_token.generate() - return access_token, refresh_token + def _set_token_cookies(self, expire=False): + """Set the cookies that contain the authentication token. - def _set_token_cookies(self, access_token, refresh_token, expire=False): + This method should be called when a user logs in, logs out, or when + their access token expires. + + :param expire: generate tokens that immediately expire, effectively + logging a user out of the system. + :return: + """ access_token_cookie = dict( key='seleneAccess', - value=str(access_token), + value=str(self.access_token.jwt), domain=self.config['DOMAIN'], max_age=FIFTEEN_MINUTES, ) refresh_token_cookie = dict( key='seleneRefresh', - value=str(refresh_token), + value=str(self.refresh_token.jwt), domain=self.config['DOMAIN'], max_age=ONE_MONTH, ) @@ -145,20 +149,21 @@ class SeleneEndpoint(MethodView): @after_this_request def set_cookies(response): + """Use Flask after request hook to reset token cookies""" response.set_cookie(**access_token_cookie) response.set_cookie(**refresh_token_cookie) return response - def _update_refresh_token_on_db(self, new_refresh_token): - old_refresh_token = self.request.cookies['seleneRefresh'] + def _update_refresh_token_on_db(self, old_refresh_token): + """Replace the refresh token on the request with the newly minted one""" with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: - token_repository = RefreshTokenRepository(db, self.account) - if self.refresh_token_expired: + token_repository = RefreshTokenRepository(db, self.account.id) + if old_refresh_token.is_expired: token_repository.delete_refresh_token(old_refresh_token) raise AuthenticationError('refresh token expired') else: token_repository.update_refresh_token( - new_refresh_token, - old_refresh_token + self.refresh_token.jwt, + old_refresh_token.jwt ) diff --git a/shared/selene/api/endpoints/account.py b/shared/selene/api/endpoints/account.py index ecadf3d3..4b6b4921 100644 --- a/shared/selene/api/endpoints/account.py +++ b/shared/selene/api/endpoints/account.py @@ -83,10 +83,9 @@ class AccountEndpoint(SeleneEndpoint): def get(self): """Process HTTP GET request for an account.""" self._authenticate() - if self.authenticated: - response_data = asdict(self.account) - del (response_data['refresh_tokens']) - self.response = response_data, HTTPStatus.OK + response_data = asdict(self.account) + del (response_data['refresh_tokens']) + self.response = response_data, HTTPStatus.OK return self.response diff --git a/shared/selene/util/auth.py b/shared/selene/util/auth.py index ea14f394..e1748947 100644 --- a/shared/selene/util/auth.py +++ b/shared/selene/util/auth.py @@ -1,86 +1,50 @@ +"""Logic for generating and validating JWT authentication tokens.""" from datetime import datetime from time import time import jwt -FIFTEEN_MINUTES = 900 -ONE_MONTH = 2628000 - class AuthenticationError(Exception): pass -class AuthenticationTokenGenerator(object): - _access_token = None - _refresh_token = None +class AuthenticationToken(object): + def __init__(self, secret: str, duration: int): + self.secret = secret + self.duration = duration + self.jwt: str = '' + self.is_valid: bool = None + self.is_expired: bool = None + self.account_id: str = None - def __init__(self, account_id: str, access_secret, refresh_secret): - self.account_id = account_id - self.access_secret = access_secret - self.refresh_secret = refresh_secret - - def _generate_token(self, token_duration: int): + def generate(self): """ Generates a JWT token """ - token_expiration = time() + token_duration payload = dict( iat=datetime.utcnow(), - exp=token_expiration, + exp=time() + self.duration, sub=self.account_id ) - - if token_duration == FIFTEEN_MINUTES: - secret = self.access_secret - else: - secret = self.refresh_secret - - token = jwt.encode( - payload, - secret, - algorithm='HS256' - ) + token = jwt.encode(payload, self.secret, algorithm='HS256') # convert the token from byte-array to string so that # it can be included in a JSON response object - return token.decode() + self.jwt = token.decode() - @property - def access_token(self): - """ - Generates a JWT access token - """ - if self._access_token is None: - self._access_token = self._generate_token(FIFTEEN_MINUTES) + def validate(self): + """Decodes the auth token and performs some preliminary validation.""" + self.is_expired = False + self.is_valid = True - return self._access_token - - @property - def refresh_token(self): - """ - Generates a JWT access token - """ - if self._refresh_token is None: - self._refresh_token = self._generate_token(ONE_MONTH) - - return self._refresh_token - - -class AuthenticationTokenValidator(object): - def __init__(self, token: str, secret: str): - self.token = token - self.secret = secret - self.account_id = None - self.token_is_expired = False - self.token_is_invalid = False - - def validate_token(self): - """Decodes the auth token""" - try: - payload = jwt.decode(self.token, self.secret) - self.account_id = payload['sub'] - except jwt.ExpiredSignatureError: - self.token_is_expired = True - except jwt.InvalidTokenError: - self.token_is_invalid = True + if self.jwt is None: + self.is_expired = True + else: + try: + payload = jwt.decode(self.jwt, self.secret) + self.account_id = payload['sub'] + except jwt.ExpiredSignatureError: + self.is_expired = True + except jwt.InvalidTokenError: + self.is_valid = False