fixing token authentication bugs
parent
b98f667f30
commit
e8a0a161d3
|
@ -1,7 +1,5 @@
|
|||
"""Base class for Flask API endpoints"""
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
from logging import getLogger
|
||||
from flask import after_this_request, current_app, request
|
||||
from flask.views import MethodView
|
||||
|
||||
|
@ -10,15 +8,16 @@ from selene.data.account import (
|
|||
AccountRepository,
|
||||
RefreshTokenRepository
|
||||
)
|
||||
from selene.util.auth import (
|
||||
AuthenticationError,
|
||||
AuthenticationTokenGenerator,
|
||||
AuthenticationTokenValidator,
|
||||
FIFTEEN_MINUTES,
|
||||
ONE_MONTH
|
||||
)
|
||||
from selene.util.auth import AuthenticationError, AuthenticationToken
|
||||
from selene.util.db import get_db_connection
|
||||
|
||||
ACCESS_TOKEN_COOKIE_NAME = 'seleneAccess'
|
||||
FIFTEEN_MINUTES = 900
|
||||
ONE_MONTH = 2628000
|
||||
REFRESH_TOKEN_COOKIE_NAME = 'seleneRefresh'
|
||||
|
||||
_log = getLogger()
|
||||
|
||||
|
||||
class APIError(Exception):
|
||||
"""Raise this exception whenever a non-successful response is built"""
|
||||
|
@ -34,16 +33,19 @@ class SeleneEndpoint(MethodView):
|
|||
HTTP methods. Each list member must be a HTTPMethod enum
|
||||
- override the _build_response_data method
|
||||
"""
|
||||
authentication_required: bool = True
|
||||
|
||||
def __init__(self):
|
||||
self.config: dict = current_app.config
|
||||
self.authenticated = False
|
||||
self.request = request
|
||||
self.response: tuple = None
|
||||
self.access_token_expired: bool = False
|
||||
self.refresh_token_expired: bool = False
|
||||
self.account: Account = None
|
||||
self.access_token = AuthenticationToken(
|
||||
self.config['ACCESS_SECRET'],
|
||||
FIFTEEN_MINUTES
|
||||
)
|
||||
self.refresh_token = AuthenticationToken(
|
||||
self.config['REFRESH_SECRET'],
|
||||
ONE_MONTH
|
||||
)
|
||||
|
||||
def _authenticate(self):
|
||||
"""
|
||||
|
@ -51,90 +53,92 @@ class SeleneEndpoint(MethodView):
|
|||
|
||||
:raises: APIError()
|
||||
"""
|
||||
try:
|
||||
account_id = self._validate_auth_tokens()
|
||||
self._validate_account(account_id)
|
||||
except AuthenticationError as ae:
|
||||
if self.authentication_required:
|
||||
self.response = (str(ae), HTTPStatus.UNAUTHORIZED)
|
||||
else:
|
||||
self.authenticated = True
|
||||
self._validate_auth_tokens()
|
||||
account_id = self._get_account_id_from_tokens()
|
||||
self._get_account(account_id)
|
||||
self._validate_account(account_id)
|
||||
if self.access_token.is_expired:
|
||||
self._refresh_auth_tokens()
|
||||
|
||||
def _validate_auth_tokens(self) -> str:
|
||||
self.access_token_expired, account_id = self._validate_token(
|
||||
'seleneAccess',
|
||||
self.config['ACCESS_SECRET']
|
||||
def _validate_auth_tokens(self):
|
||||
"""Ensure the tokens are passed in request and are well formed."""
|
||||
self.access_token.jwt = self.request.cookies.get(
|
||||
ACCESS_TOKEN_COOKIE_NAME
|
||||
)
|
||||
if self.access_token_expired:
|
||||
self.refresh_token_expired, account_id = self._validate_token(
|
||||
'seleneRefresh',
|
||||
self.config['REFRESH_SECRET']
|
||||
)
|
||||
self.access_token.validate()
|
||||
self.refresh_token.jwt = self.request.cookies.get(
|
||||
REFRESH_TOKEN_COOKIE_NAME
|
||||
)
|
||||
self.refresh_token.validate()
|
||||
|
||||
if self.access_token.jwt is None and self.refresh_token.jwt is None:
|
||||
raise AuthenticationError('no authentication tokens found')
|
||||
|
||||
if self.access_token.is_expired and self.refresh_token.is_expired:
|
||||
raise AuthenticationError('authentication tokens expired')
|
||||
|
||||
def _get_account_id_from_tokens(self):
|
||||
"""Extract the account ID, which is encoded within the tokens"""
|
||||
if self.access_token.is_expired:
|
||||
account_id = self.refresh_token.account_id
|
||||
else:
|
||||
account_id = self.access_token.account_id
|
||||
|
||||
return account_id
|
||||
|
||||
def _validate_token(self, cookie_key, jwt_secret):
|
||||
"""Validate the access token is well-formed and not expired
|
||||
|
||||
:raises: AuthenticationError
|
||||
"""
|
||||
account_id = None
|
||||
token_expired = False
|
||||
|
||||
try:
|
||||
token = self.request.cookies[cookie_key]
|
||||
except KeyError:
|
||||
error_msg = 'no {} token found in request'
|
||||
raise AuthenticationError(error_msg.format(cookie_key))
|
||||
|
||||
validator = AuthenticationTokenValidator(token, jwt_secret)
|
||||
validator.validate_token()
|
||||
if validator.token_is_expired:
|
||||
token_expired = True
|
||||
elif validator.token_is_invalid:
|
||||
raise AuthenticationError('access token is invalid')
|
||||
else:
|
||||
account_id = validator.account_id
|
||||
|
||||
return token_expired, account_id
|
||||
|
||||
def _validate_account(self, account_id):
|
||||
"""The refresh token in the request must match the database value.
|
||||
|
||||
:raises: AuthenticationError
|
||||
"""
|
||||
def _get_account(self, account_id):
|
||||
"""Use account ID from decoded authentication token to get account."""
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
account_repository = AccountRepository(db)
|
||||
self.account = account_repository.get_account_by_id(account_id)
|
||||
|
||||
def _validate_account(self, account_id: str):
|
||||
"""Account must exist and contain have a refresh token matching request.
|
||||
|
||||
:raises: AuthenticationError
|
||||
"""
|
||||
if self.account is None:
|
||||
_log.error('account ID {} not on database'.format(account_id))
|
||||
raise AuthenticationError('account not found')
|
||||
|
||||
if self.access_token_expired:
|
||||
if self.refresh_token not in self.account.refresh_tokens:
|
||||
raise AuthenticationError('refresh token not found')
|
||||
if self.refresh_token.jwt not in self.account.refresh_tokens:
|
||||
log_msg = 'account ID {} does not have token {}'
|
||||
_log.error(log_msg.format(account_id, self.refresh_token.jwt))
|
||||
raise AuthenticationError(
|
||||
'refresh token does not exist for this account'
|
||||
)
|
||||
|
||||
def _refresh_auth_tokens(self):
|
||||
"""Steps necessary to refresh the tokens used for authentication."""
|
||||
old_refresh_token = self.refresh_token
|
||||
self._generate_tokens()
|
||||
self._update_refresh_token_on_db(old_refresh_token)
|
||||
self._set_token_cookies()
|
||||
|
||||
def _generate_tokens(self):
|
||||
token_generator = AuthenticationTokenGenerator(
|
||||
self.account.id,
|
||||
self.config['ACCESS_SECRET'],
|
||||
self.config['REFRESH_SECRET']
|
||||
)
|
||||
access_token = token_generator.access_token
|
||||
refresh_token = token_generator.refresh_token
|
||||
"""Generate an access token and refresh token."""
|
||||
self.access_token.generate()
|
||||
self.refresh_token.generate()
|
||||
|
||||
return access_token, refresh_token
|
||||
def _set_token_cookies(self, expire=False):
|
||||
"""Set the cookies that contain the authentication token.
|
||||
|
||||
def _set_token_cookies(self, access_token, refresh_token, expire=False):
|
||||
This method should be called when a user logs in, logs out, or when
|
||||
their access token expires.
|
||||
|
||||
:param expire: generate tokens that immediately expire, effectively
|
||||
logging a user out of the system.
|
||||
:return:
|
||||
"""
|
||||
access_token_cookie = dict(
|
||||
key='seleneAccess',
|
||||
value=str(access_token),
|
||||
value=str(self.access_token.jwt),
|
||||
domain=self.config['DOMAIN'],
|
||||
max_age=FIFTEEN_MINUTES,
|
||||
)
|
||||
refresh_token_cookie = dict(
|
||||
key='seleneRefresh',
|
||||
value=str(refresh_token),
|
||||
value=str(self.refresh_token.jwt),
|
||||
domain=self.config['DOMAIN'],
|
||||
max_age=ONE_MONTH,
|
||||
)
|
||||
|
@ -145,20 +149,21 @@ class SeleneEndpoint(MethodView):
|
|||
|
||||
@after_this_request
|
||||
def set_cookies(response):
|
||||
"""Use Flask after request hook to reset token cookies"""
|
||||
response.set_cookie(**access_token_cookie)
|
||||
response.set_cookie(**refresh_token_cookie)
|
||||
|
||||
return response
|
||||
|
||||
def _update_refresh_token_on_db(self, new_refresh_token):
|
||||
old_refresh_token = self.request.cookies['seleneRefresh']
|
||||
def _update_refresh_token_on_db(self, old_refresh_token):
|
||||
"""Replace the refresh token on the request with the newly minted one"""
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
token_repository = RefreshTokenRepository(db, self.account)
|
||||
if self.refresh_token_expired:
|
||||
token_repository = RefreshTokenRepository(db, self.account.id)
|
||||
if old_refresh_token.is_expired:
|
||||
token_repository.delete_refresh_token(old_refresh_token)
|
||||
raise AuthenticationError('refresh token expired')
|
||||
else:
|
||||
token_repository.update_refresh_token(
|
||||
new_refresh_token,
|
||||
old_refresh_token
|
||||
self.refresh_token.jwt,
|
||||
old_refresh_token.jwt
|
||||
)
|
||||
|
|
|
@ -83,10 +83,9 @@ class AccountEndpoint(SeleneEndpoint):
|
|||
def get(self):
|
||||
"""Process HTTP GET request for an account."""
|
||||
self._authenticate()
|
||||
if self.authenticated:
|
||||
response_data = asdict(self.account)
|
||||
del (response_data['refresh_tokens'])
|
||||
self.response = response_data, HTTPStatus.OK
|
||||
response_data = asdict(self.account)
|
||||
del (response_data['refresh_tokens'])
|
||||
self.response = response_data, HTTPStatus.OK
|
||||
|
||||
return self.response
|
||||
|
||||
|
|
|
@ -1,86 +1,50 @@
|
|||
"""Logic for generating and validating JWT authentication tokens."""
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
|
||||
import jwt
|
||||
|
||||
FIFTEEN_MINUTES = 900
|
||||
ONE_MONTH = 2628000
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationTokenGenerator(object):
|
||||
_access_token = None
|
||||
_refresh_token = None
|
||||
class AuthenticationToken(object):
|
||||
def __init__(self, secret: str, duration: int):
|
||||
self.secret = secret
|
||||
self.duration = duration
|
||||
self.jwt: str = ''
|
||||
self.is_valid: bool = None
|
||||
self.is_expired: bool = None
|
||||
self.account_id: str = None
|
||||
|
||||
def __init__(self, account_id: str, access_secret, refresh_secret):
|
||||
self.account_id = account_id
|
||||
self.access_secret = access_secret
|
||||
self.refresh_secret = refresh_secret
|
||||
|
||||
def _generate_token(self, token_duration: int):
|
||||
def generate(self):
|
||||
"""
|
||||
Generates a JWT token
|
||||
"""
|
||||
token_expiration = time() + token_duration
|
||||
payload = dict(
|
||||
iat=datetime.utcnow(),
|
||||
exp=token_expiration,
|
||||
exp=time() + self.duration,
|
||||
sub=self.account_id
|
||||
)
|
||||
|
||||
if token_duration == FIFTEEN_MINUTES:
|
||||
secret = self.access_secret
|
||||
else:
|
||||
secret = self.refresh_secret
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
secret,
|
||||
algorithm='HS256'
|
||||
)
|
||||
token = jwt.encode(payload, self.secret, algorithm='HS256')
|
||||
|
||||
# convert the token from byte-array to string so that
|
||||
# it can be included in a JSON response object
|
||||
return token.decode()
|
||||
self.jwt = token.decode()
|
||||
|
||||
@property
|
||||
def access_token(self):
|
||||
"""
|
||||
Generates a JWT access token
|
||||
"""
|
||||
if self._access_token is None:
|
||||
self._access_token = self._generate_token(FIFTEEN_MINUTES)
|
||||
def validate(self):
|
||||
"""Decodes the auth token and performs some preliminary validation."""
|
||||
self.is_expired = False
|
||||
self.is_valid = True
|
||||
|
||||
return self._access_token
|
||||
|
||||
@property
|
||||
def refresh_token(self):
|
||||
"""
|
||||
Generates a JWT access token
|
||||
"""
|
||||
if self._refresh_token is None:
|
||||
self._refresh_token = self._generate_token(ONE_MONTH)
|
||||
|
||||
return self._refresh_token
|
||||
|
||||
|
||||
class AuthenticationTokenValidator(object):
|
||||
def __init__(self, token: str, secret: str):
|
||||
self.token = token
|
||||
self.secret = secret
|
||||
self.account_id = None
|
||||
self.token_is_expired = False
|
||||
self.token_is_invalid = False
|
||||
|
||||
def validate_token(self):
|
||||
"""Decodes the auth token"""
|
||||
try:
|
||||
payload = jwt.decode(self.token, self.secret)
|
||||
self.account_id = payload['sub']
|
||||
except jwt.ExpiredSignatureError:
|
||||
self.token_is_expired = True
|
||||
except jwt.InvalidTokenError:
|
||||
self.token_is_invalid = True
|
||||
if self.jwt is None:
|
||||
self.is_expired = True
|
||||
else:
|
||||
try:
|
||||
payload = jwt.decode(self.jwt, self.secret)
|
||||
self.account_id = payload['sub']
|
||||
except jwt.ExpiredSignatureError:
|
||||
self.is_expired = True
|
||||
except jwt.InvalidTokenError:
|
||||
self.is_valid = False
|
||||
|
|
Loading…
Reference in New Issue