convert from flask-restful to flask
parent
517177078c
commit
2b996a99ad
|
@ -1,11 +1,9 @@
|
|||
"""Define the API that will support Mycroft single sign on (SSO)."""
|
||||
|
||||
from logging import getLogger
|
||||
|
||||
from flask import Flask, request
|
||||
from flask_restful import Api
|
||||
|
||||
from selene.api.base_config import get_base_config
|
||||
from selene.api import get_base_config, selene_api, SeleneResponse
|
||||
from selene.util.log import configure_logger
|
||||
|
||||
from .endpoints import (
|
||||
AuthenticateInternalEndpoint,
|
||||
|
@ -13,18 +11,30 @@ from .endpoints import (
|
|||
ValidateFederatedEndpoint
|
||||
)
|
||||
|
||||
_log = getLogger('sso_api')
|
||||
_log = configure_logger('sso_api')
|
||||
|
||||
# Initialize the Flask application and the Flask Restful API
|
||||
# Define the Flask application
|
||||
sso = Flask(__name__)
|
||||
sso.config.from_object(get_base_config())
|
||||
sso.response_class = SeleneResponse
|
||||
sso.register_blueprint(selene_api)
|
||||
|
||||
# Initialize the REST API and define the endpoints
|
||||
sso_api = Api(sso, catch_all_404s=True)
|
||||
sso_api.add_resource(AuthenticateInternalEndpoint, '/api/internal-login')
|
||||
sso_api.add_resource(ValidateFederatedEndpoint, '/api/validate-federated')
|
||||
|
||||
sso_api.add_resource(LogoutEndpoint, '/api/logout')
|
||||
# Define the endpoints
|
||||
sso.add_url_rule(
|
||||
'/api/internal-login',
|
||||
view_func=AuthenticateInternalEndpoint.as_view('internal_login'),
|
||||
methods=['GET']
|
||||
)
|
||||
sso.add_url_rule(
|
||||
'/api/validate-federated',
|
||||
view_func=ValidateFederatedEndpoint.as_view('federated_login'),
|
||||
methods=['POST']
|
||||
)
|
||||
sso.add_url_rule(
|
||||
'/api/logout',
|
||||
view_func=LogoutEndpoint.as_view('logout'),
|
||||
methods=['GET']
|
||||
)
|
||||
|
||||
|
||||
def add_cors_headers(response):
|
||||
|
|
|
@ -18,20 +18,16 @@ class AuthenticateInternalEndpoint(SeleneEndpoint):
|
|||
"""Sign in a user with an email address and password."""
|
||||
def __init__(self):
|
||||
super(AuthenticateInternalEndpoint, self).__init__()
|
||||
self.response_status_code = HTTPStatus.OK
|
||||
self.account: Account = None
|
||||
|
||||
def get(self):
|
||||
"""Process HTTP GET request."""
|
||||
try:
|
||||
self._authenticate_credentials()
|
||||
access_token, refresh_token = self._generate_tokens()
|
||||
self._add_refresh_token_to_db(refresh_token)
|
||||
self._set_token_cookies(access_token, refresh_token)
|
||||
except AuthenticationError as ae:
|
||||
self.response = (str(ae), HTTPStatus.UNAUTHORIZED)
|
||||
else:
|
||||
self.response = ({}, HTTPStatus.OK)
|
||||
self._authenticate_credentials()
|
||||
self._generate_tokens()
|
||||
self._add_refresh_token_to_db()
|
||||
self._set_token_cookies()
|
||||
|
||||
self.response = dict(result='user authenticated'), HTTPStatus.OK
|
||||
|
||||
return self.response
|
||||
|
||||
|
@ -52,15 +48,15 @@ class AuthenticateInternalEndpoint(SeleneEndpoint):
|
|||
)
|
||||
if self.account is None:
|
||||
raise AuthenticationError('provided credentials not found')
|
||||
self.access_token.account_id = self.account.id
|
||||
self.refresh_token.account_id = self.account.id
|
||||
|
||||
def _add_refresh_token_to_db(self, refresh_token: str):
|
||||
def _add_refresh_token_to_db(self):
|
||||
"""Track refresh tokens in the database.
|
||||
|
||||
We need to store the value of the refresh token in the database so
|
||||
that we can validate it when it is used to request new tokens.
|
||||
|
||||
:param refresh_token: the token to install into the database.
|
||||
"""
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
token_repo = RefreshTokenRepository(db, self.account.id)
|
||||
token_repo.add_refresh_token(refresh_token)
|
||||
token_repo.add_refresh_token(self.refresh_token.jwt)
|
||||
|
|
|
@ -13,8 +13,7 @@ _log = getLogger(__package__)
|
|||
class LogoutEndpoint(SeleneEndpoint):
|
||||
def get(self):
|
||||
self._authenticate()
|
||||
if self.authenticated or self.refresh_token_expired:
|
||||
self._logout()
|
||||
self._logout()
|
||||
|
||||
return self.response
|
||||
|
||||
|
@ -26,9 +25,9 @@ class LogoutEndpoint(SeleneEndpoint):
|
|||
"""
|
||||
request_refresh_token = self.request.cookies['seleneRefresh']
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
token_repository = RefreshTokenRepository(db, self.account)
|
||||
token_repository = RefreshTokenRepository(db, self.account.id)
|
||||
token_repository.delete_refresh_token(request_refresh_token)
|
||||
access_token, refresh_token = self._generate_tokens()
|
||||
self._set_token_cookies(access_token, refresh_token, expire=True)
|
||||
self._generate_tokens()
|
||||
self._set_token_cookies(expire=True)
|
||||
|
||||
self.response = ('logged out', HTTPStatus.OK)
|
||||
|
|
|
@ -16,19 +16,15 @@ from selene.util.db import get_db_connection
|
|||
class ValidateFederatedEndpoint(SeleneEndpoint):
|
||||
def post(self):
|
||||
"""Process a HTTP POST request."""
|
||||
try:
|
||||
self._get_account()
|
||||
except AuthenticationError as ae:
|
||||
self.response = str(ae), HTTPStatus.UNAUTHORIZED
|
||||
else:
|
||||
access_token, refresh_token = self._generate_tokens()
|
||||
self._set_token_cookies(access_token, refresh_token)
|
||||
self._add_refresh_token_to_db(refresh_token)
|
||||
self.response = 'account validated', HTTPStatus.OK
|
||||
self._get_account_by_email()
|
||||
self._generate_tokens()
|
||||
self._set_token_cookies()
|
||||
self._add_refresh_token_to_db()
|
||||
self.response = dict(result='account validated'), HTTPStatus.OK
|
||||
|
||||
return self.response
|
||||
|
||||
def _get_account(self):
|
||||
def _get_account_by_email(self):
|
||||
"""Use email returned by the authentication platform for validation"""
|
||||
email_address = self.request.form['email']
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
|
@ -38,14 +34,13 @@ class ValidateFederatedEndpoint(SeleneEndpoint):
|
|||
if self.account is None:
|
||||
raise AuthenticationError('account not found')
|
||||
|
||||
def _add_refresh_token_to_db(self, refresh_token):
|
||||
def _add_refresh_token_to_db(self):
|
||||
"""Track refresh tokens in the database.
|
||||
|
||||
We need to store the value of the refresh token in the database so
|
||||
that we can validate it when it is used to request new tokens.
|
||||
|
||||
:param refresh_token: the token to install into the database.
|
||||
"""
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
token_repo = RefreshTokenRepository(db, self.account)
|
||||
token_repo.add_refresh_token(refresh_token)
|
||||
token_repo = RefreshTokenRepository(db, self.account.id)
|
||||
token_repo.add_refresh_token(self.refresh_token.jwt)
|
||||
|
|
|
@ -38,29 +38,28 @@ def before_scenario(context, _):
|
|||
|
||||
|
||||
def _add_agreement(context, db):
|
||||
context.agreement = Agreement(
|
||||
agreement = Agreement(
|
||||
type='Privacy Policy',
|
||||
version='1',
|
||||
content='this is Privacy Policy version 1',
|
||||
version='999',
|
||||
content='this is Privacy Policy version 999',
|
||||
effective_date=date.today() - timedelta(days=5)
|
||||
)
|
||||
agreement_repository = AgreementRepository(db)
|
||||
agreement_repository.add(context.agreement)
|
||||
agreement_repository.add(agreement)
|
||||
context.agreement = agreement_repository.get_active_for_type(PRIVACY_POLICY)
|
||||
|
||||
|
||||
def _add_account(context, db):
|
||||
test_account = Account(
|
||||
id=None,
|
||||
email_address='foo@mycroft.ai',
|
||||
username='foobar',
|
||||
refresh_tokens=None,
|
||||
display_name='foobar',
|
||||
subscription=AccountSubscription(
|
||||
type='monthly supporter',
|
||||
start_date=None,
|
||||
type='Monthly Supporter',
|
||||
start_date=date.today(),
|
||||
stripe_customer_id='foo'
|
||||
),
|
||||
agreements=[
|
||||
AccountAgreement(name=PRIVACY_POLICY, accept_date=None)
|
||||
AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today())
|
||||
]
|
||||
)
|
||||
acct_repository = AccountRepository(db)
|
||||
|
|
|
@ -62,4 +62,5 @@ def check_for_login_fail(context, error_message):
|
|||
equal_to('*')
|
||||
)
|
||||
assert_that(context.response.is_json, equal_to(True))
|
||||
assert_that(context.response.get_json(), equal_to(error_message))
|
||||
response_json = context.response.get_json()
|
||||
assert_that(response_json['error'], equal_to(error_message))
|
||||
|
|
|
@ -3,7 +3,8 @@ from behave import given, then, when
|
|||
from hamcrest import assert_that, equal_to, has_item, is_not
|
||||
|
||||
from selene.api.testing import (
|
||||
generate_auth_tokens,
|
||||
generate_access_token,
|
||||
generate_refresh_token,
|
||||
get_account,
|
||||
validate_token_cookies
|
||||
)
|
||||
|
@ -16,7 +17,8 @@ def save_email(context, email):
|
|||
|
||||
@when('user attempts to logout')
|
||||
def call_logout_endpoint(context):
|
||||
generate_auth_tokens(context)
|
||||
generate_access_token(context)
|
||||
generate_refresh_token(context)
|
||||
context.response = context.client.get('/api/logout')
|
||||
|
||||
|
||||
|
@ -39,7 +41,7 @@ def check_refresh_token_removed(context):
|
|||
account = get_account(context)
|
||||
assert_that(
|
||||
account.refresh_tokens,
|
||||
is_not(has_item(context.request_refresh_token))
|
||||
is_not(has_item(context.refresh_token))
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue