convert from flask-restful to flask

pull/53/head
Chris Veilleux 2019-02-19 14:41:36 -06:00
parent 517177078c
commit 2b996a99ad
7 changed files with 61 additions and 59 deletions

View File

@ -1,11 +1,9 @@
"""Define the API that will support Mycroft single sign on (SSO)."""
from logging import getLogger
from flask import Flask, request
from flask_restful import Api
from selene.api.base_config import get_base_config
from selene.api import get_base_config, selene_api, SeleneResponse
from selene.util.log import configure_logger
from .endpoints import (
AuthenticateInternalEndpoint,
@ -13,18 +11,30 @@ from .endpoints import (
ValidateFederatedEndpoint
)
_log = getLogger('sso_api')
_log = configure_logger('sso_api')
# Initialize the Flask application and the Flask Restful API
# Define the Flask application
sso = Flask(__name__)
sso.config.from_object(get_base_config())
sso.response_class = SeleneResponse
sso.register_blueprint(selene_api)
# Initialize the REST API and define the endpoints
sso_api = Api(sso, catch_all_404s=True)
sso_api.add_resource(AuthenticateInternalEndpoint, '/api/internal-login')
sso_api.add_resource(ValidateFederatedEndpoint, '/api/validate-federated')
sso_api.add_resource(LogoutEndpoint, '/api/logout')
# Define the endpoints
sso.add_url_rule(
'/api/internal-login',
view_func=AuthenticateInternalEndpoint.as_view('internal_login'),
methods=['GET']
)
sso.add_url_rule(
'/api/validate-federated',
view_func=ValidateFederatedEndpoint.as_view('federated_login'),
methods=['POST']
)
sso.add_url_rule(
'/api/logout',
view_func=LogoutEndpoint.as_view('logout'),
methods=['GET']
)
def add_cors_headers(response):

View File

@ -18,20 +18,16 @@ class AuthenticateInternalEndpoint(SeleneEndpoint):
"""Sign in a user with an email address and password."""
def __init__(self):
super(AuthenticateInternalEndpoint, self).__init__()
self.response_status_code = HTTPStatus.OK
self.account: Account = None
def get(self):
"""Process HTTP GET request."""
try:
self._authenticate_credentials()
access_token, refresh_token = self._generate_tokens()
self._add_refresh_token_to_db(refresh_token)
self._set_token_cookies(access_token, refresh_token)
except AuthenticationError as ae:
self.response = (str(ae), HTTPStatus.UNAUTHORIZED)
else:
self.response = ({}, HTTPStatus.OK)
self._authenticate_credentials()
self._generate_tokens()
self._add_refresh_token_to_db()
self._set_token_cookies()
self.response = dict(result='user authenticated'), HTTPStatus.OK
return self.response
@ -52,15 +48,15 @@ class AuthenticateInternalEndpoint(SeleneEndpoint):
)
if self.account is None:
raise AuthenticationError('provided credentials not found')
self.access_token.account_id = self.account.id
self.refresh_token.account_id = self.account.id
def _add_refresh_token_to_db(self, refresh_token: str):
def _add_refresh_token_to_db(self):
"""Track refresh tokens in the database.
We need to store the value of the refresh token in the database so
that we can validate it when it is used to request new tokens.
:param refresh_token: the token to install into the database.
"""
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
token_repo = RefreshTokenRepository(db, self.account.id)
token_repo.add_refresh_token(refresh_token)
token_repo.add_refresh_token(self.refresh_token.jwt)

View File

@ -13,8 +13,7 @@ _log = getLogger(__package__)
class LogoutEndpoint(SeleneEndpoint):
def get(self):
self._authenticate()
if self.authenticated or self.refresh_token_expired:
self._logout()
self._logout()
return self.response
@ -26,9 +25,9 @@ class LogoutEndpoint(SeleneEndpoint):
"""
request_refresh_token = self.request.cookies['seleneRefresh']
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
token_repository = RefreshTokenRepository(db, self.account)
token_repository = RefreshTokenRepository(db, self.account.id)
token_repository.delete_refresh_token(request_refresh_token)
access_token, refresh_token = self._generate_tokens()
self._set_token_cookies(access_token, refresh_token, expire=True)
self._generate_tokens()
self._set_token_cookies(expire=True)
self.response = ('logged out', HTTPStatus.OK)

View File

@ -16,19 +16,15 @@ from selene.util.db import get_db_connection
class ValidateFederatedEndpoint(SeleneEndpoint):
def post(self):
"""Process a HTTP POST request."""
try:
self._get_account()
except AuthenticationError as ae:
self.response = str(ae), HTTPStatus.UNAUTHORIZED
else:
access_token, refresh_token = self._generate_tokens()
self._set_token_cookies(access_token, refresh_token)
self._add_refresh_token_to_db(refresh_token)
self.response = 'account validated', HTTPStatus.OK
self._get_account_by_email()
self._generate_tokens()
self._set_token_cookies()
self._add_refresh_token_to_db()
self.response = dict(result='account validated'), HTTPStatus.OK
return self.response
def _get_account(self):
def _get_account_by_email(self):
"""Use email returned by the authentication platform for validation"""
email_address = self.request.form['email']
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
@ -38,14 +34,13 @@ class ValidateFederatedEndpoint(SeleneEndpoint):
if self.account is None:
raise AuthenticationError('account not found')
def _add_refresh_token_to_db(self, refresh_token):
def _add_refresh_token_to_db(self):
"""Track refresh tokens in the database.
We need to store the value of the refresh token in the database so
that we can validate it when it is used to request new tokens.
:param refresh_token: the token to install into the database.
"""
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
token_repo = RefreshTokenRepository(db, self.account)
token_repo.add_refresh_token(refresh_token)
token_repo = RefreshTokenRepository(db, self.account.id)
token_repo.add_refresh_token(self.refresh_token.jwt)

View File

@ -38,29 +38,28 @@ def before_scenario(context, _):
def _add_agreement(context, db):
context.agreement = Agreement(
agreement = Agreement(
type='Privacy Policy',
version='1',
content='this is Privacy Policy version 1',
version='999',
content='this is Privacy Policy version 999',
effective_date=date.today() - timedelta(days=5)
)
agreement_repository = AgreementRepository(db)
agreement_repository.add(context.agreement)
agreement_repository.add(agreement)
context.agreement = agreement_repository.get_active_for_type(PRIVACY_POLICY)
def _add_account(context, db):
test_account = Account(
id=None,
email_address='foo@mycroft.ai',
username='foobar',
refresh_tokens=None,
display_name='foobar',
subscription=AccountSubscription(
type='monthly supporter',
start_date=None,
type='Monthly Supporter',
start_date=date.today(),
stripe_customer_id='foo'
),
agreements=[
AccountAgreement(name=PRIVACY_POLICY, accept_date=None)
AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today())
]
)
acct_repository = AccountRepository(db)

View File

@ -62,4 +62,5 @@ def check_for_login_fail(context, error_message):
equal_to('*')
)
assert_that(context.response.is_json, equal_to(True))
assert_that(context.response.get_json(), equal_to(error_message))
response_json = context.response.get_json()
assert_that(response_json['error'], equal_to(error_message))

View File

@ -3,7 +3,8 @@ from behave import given, then, when
from hamcrest import assert_that, equal_to, has_item, is_not
from selene.api.testing import (
generate_auth_tokens,
generate_access_token,
generate_refresh_token,
get_account,
validate_token_cookies
)
@ -16,7 +17,8 @@ def save_email(context, email):
@when('user attempts to logout')
def call_logout_endpoint(context):
generate_auth_tokens(context)
generate_access_token(context)
generate_refresh_token(context)
context.response = context.client.get('/api/logout')
@ -39,7 +41,7 @@ def check_refresh_token_removed(context):
account = get_account(context)
assert_that(
account.refresh_tokens,
is_not(has_item(context.request_refresh_token))
is_not(has_item(context.refresh_token))
)