fixing token authentication bugs

pull/53/head
Chris Veilleux 2019-02-19 13:33:06 -06:00
parent b98f667f30
commit e8a0a161d3
3 changed files with 117 additions and 149 deletions

View File

@ -1,7 +1,5 @@
"""Base class for Flask API endpoints"""
from http import HTTPStatus
from logging import getLogger
from flask import after_this_request, current_app, request
from flask.views import MethodView
@ -10,15 +8,16 @@ from selene.data.account import (
AccountRepository,
RefreshTokenRepository
)
from selene.util.auth import (
AuthenticationError,
AuthenticationTokenGenerator,
AuthenticationTokenValidator,
FIFTEEN_MINUTES,
ONE_MONTH
)
from selene.util.auth import AuthenticationError, AuthenticationToken
from selene.util.db import get_db_connection
ACCESS_TOKEN_COOKIE_NAME = 'seleneAccess'
FIFTEEN_MINUTES = 900
ONE_MONTH = 2628000
REFRESH_TOKEN_COOKIE_NAME = 'seleneRefresh'
_log = getLogger()
class APIError(Exception):
"""Raise this exception whenever a non-successful response is built"""
@ -34,16 +33,19 @@ class SeleneEndpoint(MethodView):
HTTP methods. Each list member must be a HTTPMethod enum
- override the _build_response_data method
"""
authentication_required: bool = True
def __init__(self):
self.config: dict = current_app.config
self.authenticated = False
self.request = request
self.response: tuple = None
self.access_token_expired: bool = False
self.refresh_token_expired: bool = False
self.account: Account = None
self.access_token = AuthenticationToken(
self.config['ACCESS_SECRET'],
FIFTEEN_MINUTES
)
self.refresh_token = AuthenticationToken(
self.config['REFRESH_SECRET'],
ONE_MONTH
)
def _authenticate(self):
"""
@ -51,90 +53,92 @@ class SeleneEndpoint(MethodView):
:raises: APIError()
"""
try:
account_id = self._validate_auth_tokens()
self._validate_account(account_id)
except AuthenticationError as ae:
if self.authentication_required:
self.response = (str(ae), HTTPStatus.UNAUTHORIZED)
else:
self.authenticated = True
self._validate_auth_tokens()
account_id = self._get_account_id_from_tokens()
self._get_account(account_id)
self._validate_account(account_id)
if self.access_token.is_expired:
self._refresh_auth_tokens()
def _validate_auth_tokens(self) -> str:
self.access_token_expired, account_id = self._validate_token(
'seleneAccess',
self.config['ACCESS_SECRET']
def _validate_auth_tokens(self):
"""Ensure the tokens are passed in request and are well formed."""
self.access_token.jwt = self.request.cookies.get(
ACCESS_TOKEN_COOKIE_NAME
)
if self.access_token_expired:
self.refresh_token_expired, account_id = self._validate_token(
'seleneRefresh',
self.config['REFRESH_SECRET']
)
self.access_token.validate()
self.refresh_token.jwt = self.request.cookies.get(
REFRESH_TOKEN_COOKIE_NAME
)
self.refresh_token.validate()
if self.access_token.jwt is None and self.refresh_token.jwt is None:
raise AuthenticationError('no authentication tokens found')
if self.access_token.is_expired and self.refresh_token.is_expired:
raise AuthenticationError('authentication tokens expired')
def _get_account_id_from_tokens(self):
"""Extract the account ID, which is encoded within the tokens"""
if self.access_token.is_expired:
account_id = self.refresh_token.account_id
else:
account_id = self.access_token.account_id
return account_id
def _validate_token(self, cookie_key, jwt_secret):
"""Validate the access token is well-formed and not expired
:raises: AuthenticationError
"""
account_id = None
token_expired = False
try:
token = self.request.cookies[cookie_key]
except KeyError:
error_msg = 'no {} token found in request'
raise AuthenticationError(error_msg.format(cookie_key))
validator = AuthenticationTokenValidator(token, jwt_secret)
validator.validate_token()
if validator.token_is_expired:
token_expired = True
elif validator.token_is_invalid:
raise AuthenticationError('access token is invalid')
else:
account_id = validator.account_id
return token_expired, account_id
def _validate_account(self, account_id):
"""The refresh token in the request must match the database value.
:raises: AuthenticationError
"""
def _get_account(self, account_id):
"""Use account ID from decoded authentication token to get account."""
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
account_repository = AccountRepository(db)
self.account = account_repository.get_account_by_id(account_id)
def _validate_account(self, account_id: str):
"""Account must exist and contain have a refresh token matching request.
:raises: AuthenticationError
"""
if self.account is None:
_log.error('account ID {} not on database'.format(account_id))
raise AuthenticationError('account not found')
if self.access_token_expired:
if self.refresh_token not in self.account.refresh_tokens:
raise AuthenticationError('refresh token not found')
if self.refresh_token.jwt not in self.account.refresh_tokens:
log_msg = 'account ID {} does not have token {}'
_log.error(log_msg.format(account_id, self.refresh_token.jwt))
raise AuthenticationError(
'refresh token does not exist for this account'
)
def _refresh_auth_tokens(self):
"""Steps necessary to refresh the tokens used for authentication."""
old_refresh_token = self.refresh_token
self._generate_tokens()
self._update_refresh_token_on_db(old_refresh_token)
self._set_token_cookies()
def _generate_tokens(self):
token_generator = AuthenticationTokenGenerator(
self.account.id,
self.config['ACCESS_SECRET'],
self.config['REFRESH_SECRET']
)
access_token = token_generator.access_token
refresh_token = token_generator.refresh_token
"""Generate an access token and refresh token."""
self.access_token.generate()
self.refresh_token.generate()
return access_token, refresh_token
def _set_token_cookies(self, expire=False):
"""Set the cookies that contain the authentication token.
def _set_token_cookies(self, access_token, refresh_token, expire=False):
This method should be called when a user logs in, logs out, or when
their access token expires.
:param expire: generate tokens that immediately expire, effectively
logging a user out of the system.
:return:
"""
access_token_cookie = dict(
key='seleneAccess',
value=str(access_token),
value=str(self.access_token.jwt),
domain=self.config['DOMAIN'],
max_age=FIFTEEN_MINUTES,
)
refresh_token_cookie = dict(
key='seleneRefresh',
value=str(refresh_token),
value=str(self.refresh_token.jwt),
domain=self.config['DOMAIN'],
max_age=ONE_MONTH,
)
@ -145,20 +149,21 @@ class SeleneEndpoint(MethodView):
@after_this_request
def set_cookies(response):
"""Use Flask after request hook to reset token cookies"""
response.set_cookie(**access_token_cookie)
response.set_cookie(**refresh_token_cookie)
return response
def _update_refresh_token_on_db(self, new_refresh_token):
old_refresh_token = self.request.cookies['seleneRefresh']
def _update_refresh_token_on_db(self, old_refresh_token):
"""Replace the refresh token on the request with the newly minted one"""
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
token_repository = RefreshTokenRepository(db, self.account)
if self.refresh_token_expired:
token_repository = RefreshTokenRepository(db, self.account.id)
if old_refresh_token.is_expired:
token_repository.delete_refresh_token(old_refresh_token)
raise AuthenticationError('refresh token expired')
else:
token_repository.update_refresh_token(
new_refresh_token,
old_refresh_token
self.refresh_token.jwt,
old_refresh_token.jwt
)

View File

@ -83,10 +83,9 @@ class AccountEndpoint(SeleneEndpoint):
def get(self):
"""Process HTTP GET request for an account."""
self._authenticate()
if self.authenticated:
response_data = asdict(self.account)
del (response_data['refresh_tokens'])
self.response = response_data, HTTPStatus.OK
response_data = asdict(self.account)
del (response_data['refresh_tokens'])
self.response = response_data, HTTPStatus.OK
return self.response

View File

@ -1,86 +1,50 @@
"""Logic for generating and validating JWT authentication tokens."""
from datetime import datetime
from time import time
import jwt
FIFTEEN_MINUTES = 900
ONE_MONTH = 2628000
class AuthenticationError(Exception):
pass
class AuthenticationTokenGenerator(object):
_access_token = None
_refresh_token = None
class AuthenticationToken(object):
def __init__(self, secret: str, duration: int):
self.secret = secret
self.duration = duration
self.jwt: str = ''
self.is_valid: bool = None
self.is_expired: bool = None
self.account_id: str = None
def __init__(self, account_id: str, access_secret, refresh_secret):
self.account_id = account_id
self.access_secret = access_secret
self.refresh_secret = refresh_secret
def _generate_token(self, token_duration: int):
def generate(self):
"""
Generates a JWT token
"""
token_expiration = time() + token_duration
payload = dict(
iat=datetime.utcnow(),
exp=token_expiration,
exp=time() + self.duration,
sub=self.account_id
)
if token_duration == FIFTEEN_MINUTES:
secret = self.access_secret
else:
secret = self.refresh_secret
token = jwt.encode(
payload,
secret,
algorithm='HS256'
)
token = jwt.encode(payload, self.secret, algorithm='HS256')
# convert the token from byte-array to string so that
# it can be included in a JSON response object
return token.decode()
self.jwt = token.decode()
@property
def access_token(self):
"""
Generates a JWT access token
"""
if self._access_token is None:
self._access_token = self._generate_token(FIFTEEN_MINUTES)
def validate(self):
"""Decodes the auth token and performs some preliminary validation."""
self.is_expired = False
self.is_valid = True
return self._access_token
@property
def refresh_token(self):
"""
Generates a JWT access token
"""
if self._refresh_token is None:
self._refresh_token = self._generate_token(ONE_MONTH)
return self._refresh_token
class AuthenticationTokenValidator(object):
def __init__(self, token: str, secret: str):
self.token = token
self.secret = secret
self.account_id = None
self.token_is_expired = False
self.token_is_invalid = False
def validate_token(self):
"""Decodes the auth token"""
try:
payload = jwt.decode(self.token, self.secret)
self.account_id = payload['sub']
except jwt.ExpiredSignatureError:
self.token_is_expired = True
except jwt.InvalidTokenError:
self.token_is_invalid = True
if self.jwt is None:
self.is_expired = True
else:
try:
payload = jwt.decode(self.jwt, self.secret)
self.account_id = payload['sub']
except jwt.ExpiredSignatureError:
self.is_expired = True
except jwt.InvalidTokenError:
self.is_valid = False