Merge pull request #54 from MycroftAI/account-api
fixed a bug with new account logic when a user opts out of membershippull/49/head
commit
cf39ba2aca
|
@ -1,11 +1,22 @@
|
|||
Feature: Add a new account
|
||||
Test the API call to add an account to the database.
|
||||
|
||||
Scenario: Successful account addition
|
||||
When a valid new account request is submitted
|
||||
Then the request will be successful
|
||||
And the account will be added to the system
|
||||
Scenario: Successful account addition with membership
|
||||
Given a user completes on-boarding
|
||||
And user opts into a membership
|
||||
When the new account request is submitted
|
||||
Then the request will be successful
|
||||
And the account will be added to the system with a membership
|
||||
|
||||
Scenario: Successful account addition without membership
|
||||
Given a user completes on-boarding
|
||||
And user opts out of membership
|
||||
When the new account request is submitted
|
||||
Then the account will be added to the system without a membership
|
||||
|
||||
Scenario: Request missing a required field
|
||||
When a request is sent without an email address
|
||||
Given a user completes on-boarding
|
||||
And user does not specify an email address
|
||||
When the new account request is submitted
|
||||
Then the request will fail with a bad request error
|
||||
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
from behave import given, then
|
||||
from hamcrest import assert_that, equal_to, is_not
|
||||
from hamcrest import assert_that, equal_to, has_item, is_not
|
||||
|
||||
from selene.api.testing import (
|
||||
generate_access_token,
|
||||
generate_refresh_token,
|
||||
validate_token_cookies
|
||||
)
|
||||
from selene.data.account import AccountRepository
|
||||
from selene.util.auth import AuthenticationToken
|
||||
from selene.util.db import get_db_connection
|
||||
|
||||
|
||||
@given('an authenticated user with an expired access token')
|
||||
|
@ -34,3 +37,18 @@ def check_for_new_cookies(context):
|
|||
context.refresh_token,
|
||||
is_not(equal_to(context.old_refresh_token))
|
||||
)
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
acct_repository = AccountRepository(db)
|
||||
account = acct_repository.get_account_by_id(context.account.id)
|
||||
|
||||
assert_that(account.refresh_tokens, has_item(context.refresh_token))
|
||||
|
||||
refresh_token = AuthenticationToken(
|
||||
context.client_config['REFRESH_SECRET'],
|
||||
0
|
||||
)
|
||||
refresh_token.jwt = context.refresh_token
|
||||
account_id = refresh_token.validate()
|
||||
assert_that(refresh_token.is_valid, equal_to(True))
|
||||
assert_that(refresh_token.is_expired, equal_to(False))
|
||||
assert_that(account_id, equal_to(account.id))
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
from datetime import date
|
||||
from http import HTTPStatus
|
||||
|
||||
from behave import then, when
|
||||
from behave import given, then, when
|
||||
from flask import json
|
||||
from hamcrest import assert_that, equal_to, is_in, not_none
|
||||
from hamcrest import assert_that, equal_to, is_in, none, not_none
|
||||
|
||||
from selene.data.account import AccountRepository, PRIVACY_POLICY, TERMS_OF_USE
|
||||
from selene.util.db import get_db_connection
|
||||
|
@ -17,17 +16,38 @@ new_account_request = dict(
|
|||
userEnteredEmail='bar@mycroft.ai',
|
||||
password='bar'
|
||||
),
|
||||
support=dict(
|
||||
openDataset=True,
|
||||
membership='MONTHLY SUPPORTER',
|
||||
stripeCustomerId='barstripe'
|
||||
)
|
||||
support=dict(openDataset=True)
|
||||
)
|
||||
|
||||
|
||||
@when('a valid new account request is submitted')
|
||||
@given('a user completes on-boarding')
|
||||
def build_new_account_request(context):
|
||||
context.new_account_request = new_account_request
|
||||
|
||||
|
||||
@given('user opts out of membership')
|
||||
def add_maybe_later_membership(context):
|
||||
context.new_account_request['support'].update(
|
||||
membership='Maybe Later',
|
||||
stripeCustomerId=None
|
||||
)
|
||||
|
||||
|
||||
@given('user opts into a membership')
|
||||
def change_membership_option(context):
|
||||
context.new_account_request['support'].update(
|
||||
membership='Monthly Supporter',
|
||||
stripeCustomerId='barstripe'
|
||||
)
|
||||
|
||||
|
||||
@given('user does not specify an email address')
|
||||
def remove_email_from_request(context):
|
||||
del(context.new_account_request['login']['userEnteredEmail'])
|
||||
|
||||
|
||||
@when('the new account request is submitted')
|
||||
def call_add_account_endpoint(context):
|
||||
context.new_account_request = new_account_request
|
||||
context.client.content_type = 'application/json'
|
||||
context.response = context.client.post(
|
||||
'/api/account',
|
||||
|
@ -36,22 +56,8 @@ def call_add_account_endpoint(context):
|
|||
)
|
||||
|
||||
|
||||
@when('a request is sent without an email address')
|
||||
def create_account_without_email(context):
|
||||
context.new_account_request = new_account_request
|
||||
login_data = context.new_account_request['login']
|
||||
del(login_data['userEnteredEmail'])
|
||||
context.new_account_request['login'] = login_data
|
||||
context.client.content_type = 'application/json'
|
||||
context.response = context.client.post(
|
||||
'/api/account',
|
||||
data=json.dumps(context.new_account_request),
|
||||
content_type='application_json'
|
||||
)
|
||||
|
||||
|
||||
@then('the account will be added to the system')
|
||||
def check_db_for_account(context):
|
||||
@then('the account will be added to the system {membership_option}')
|
||||
def check_db_for_account(context, membership_option):
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
acct_repository = AccountRepository(db)
|
||||
account = acct_repository.get_account_by_email('bar@mycroft.ai')
|
||||
|
@ -60,11 +66,15 @@ def check_db_for_account(context):
|
|||
account.email_address, equal_to('bar@mycroft.ai')
|
||||
)
|
||||
assert_that(account.username, equal_to('barfoo'))
|
||||
assert_that(account.subscription.type, equal_to('Monthly Supporter'))
|
||||
assert_that(
|
||||
account.subscription.stripe_customer_id,
|
||||
equal_to('barstripe')
|
||||
)
|
||||
if membership_option == 'with a membership':
|
||||
assert_that(account.subscription.type, equal_to('Monthly Supporter'))
|
||||
assert_that(
|
||||
account.subscription.stripe_customer_id,
|
||||
equal_to('barstripe')
|
||||
)
|
||||
elif membership_option == 'without a membership':
|
||||
assert_that(account.subscription, none())
|
||||
|
||||
assert_that(len(account.agreements), equal_to(2))
|
||||
for agreement in account.agreements:
|
||||
assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE)))
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Base class for Flask API endpoints"""
|
||||
from copy import deepcopy
|
||||
from logging import getLogger
|
||||
from flask import after_this_request, current_app, request
|
||||
from flask.views import MethodView
|
||||
|
@ -16,7 +17,7 @@ FIFTEEN_MINUTES = 900
|
|||
ONE_MONTH = 2628000
|
||||
REFRESH_TOKEN_COOKIE_NAME = 'seleneRefresh'
|
||||
|
||||
_log = getLogger()
|
||||
_log = getLogger(__package__)
|
||||
|
||||
|
||||
class APIError(Exception):
|
||||
|
@ -38,11 +39,17 @@ class SeleneEndpoint(MethodView):
|
|||
self.request = request
|
||||
self.response: tuple = None
|
||||
self.account: Account = None
|
||||
self.access_token = AuthenticationToken(
|
||||
self.access_token = self._init_access_token()
|
||||
self.refresh_token = self._init_refresh_token()
|
||||
|
||||
def _init_access_token(self):
|
||||
return AuthenticationToken(
|
||||
self.config['ACCESS_SECRET'],
|
||||
FIFTEEN_MINUTES
|
||||
)
|
||||
self.refresh_token = AuthenticationToken(
|
||||
|
||||
def _init_refresh_token(self):
|
||||
return AuthenticationToken(
|
||||
self.config['REFRESH_SECRET'],
|
||||
ONE_MONTH
|
||||
)
|
||||
|
@ -53,36 +60,58 @@ class SeleneEndpoint(MethodView):
|
|||
|
||||
:raises: APIError()
|
||||
"""
|
||||
self._validate_auth_tokens()
|
||||
account_id = self._get_account_id_from_tokens()
|
||||
self._get_account(account_id)
|
||||
self._validate_account(account_id)
|
||||
if self.access_token.is_expired:
|
||||
self._refresh_auth_tokens()
|
||||
try:
|
||||
account_id = self._validate_auth_tokens()
|
||||
self._get_account(account_id)
|
||||
self._validate_account(account_id)
|
||||
if self.access_token.is_expired:
|
||||
self._refresh_auth_tokens()
|
||||
except Exception:
|
||||
_log.exception('an exception occurred during authentication')
|
||||
raise
|
||||
|
||||
def _validate_auth_tokens(self):
|
||||
"""Ensure the tokens are passed in request and are well formed."""
|
||||
self._get_auth_tokens()
|
||||
account_id = self._decode_access_token()
|
||||
if self.access_token.is_expired:
|
||||
account_id = self._decode_refresh_token()
|
||||
if account_id is None:
|
||||
raise AuthenticationError(
|
||||
'failed to retrieve account ID from authentication tokens'
|
||||
)
|
||||
|
||||
return account_id
|
||||
|
||||
def _get_auth_tokens(self):
|
||||
self.access_token.jwt = self.request.cookies.get(
|
||||
ACCESS_TOKEN_COOKIE_NAME
|
||||
)
|
||||
self.access_token.validate()
|
||||
self.refresh_token.jwt = self.request.cookies.get(
|
||||
REFRESH_TOKEN_COOKIE_NAME
|
||||
)
|
||||
self.refresh_token.validate()
|
||||
|
||||
if self.access_token.jwt is None and self.refresh_token.jwt is None:
|
||||
raise AuthenticationError('no authentication tokens found')
|
||||
|
||||
if self.access_token.is_expired and self.refresh_token.is_expired:
|
||||
raise AuthenticationError('authentication tokens expired')
|
||||
def _decode_access_token(self):
|
||||
"""Decode the JWT to get the account ID and check for errors."""
|
||||
account_id = self.access_token.validate()
|
||||
|
||||
def _get_account_id_from_tokens(self):
|
||||
"""Extract the account ID, which is encoded within the tokens"""
|
||||
if self.access_token.is_expired:
|
||||
account_id = self.refresh_token.account_id
|
||||
else:
|
||||
account_id = self.access_token.account_id
|
||||
if not self.access_token.is_valid:
|
||||
raise AuthenticationError('invalid access token')
|
||||
|
||||
return account_id
|
||||
|
||||
def _decode_refresh_token(self):
|
||||
"""Decode the JWT to get the account ID and check for errors."""
|
||||
account_id = self.refresh_token.validate()
|
||||
|
||||
if not self.access_token.is_valid:
|
||||
raise AuthenticationError('invalid refresh token')
|
||||
|
||||
if self.refresh_token.is_expired:
|
||||
raise AuthenticationError('authentication tokens expired')
|
||||
|
||||
return account_id
|
||||
|
||||
|
@ -110,15 +139,17 @@ class SeleneEndpoint(MethodView):
|
|||
|
||||
def _refresh_auth_tokens(self):
|
||||
"""Steps necessary to refresh the tokens used for authentication."""
|
||||
old_refresh_token = self.refresh_token
|
||||
old_refresh_token = deepcopy(self.refresh_token)
|
||||
self._generate_tokens()
|
||||
self._update_refresh_token_on_db(old_refresh_token)
|
||||
self._set_token_cookies()
|
||||
|
||||
def _generate_tokens(self):
|
||||
"""Generate an access token and refresh token."""
|
||||
self.access_token.generate()
|
||||
self.refresh_token.generate()
|
||||
self.access_token = self._init_access_token()
|
||||
self.refresh_token = self._init_refresh_token()
|
||||
self.access_token.generate(self.account.id)
|
||||
self.refresh_token.generate(self.account.id)
|
||||
|
||||
def _set_token_cookies(self, expire=False):
|
||||
"""Set the cookies that contain the authentication token.
|
||||
|
@ -159,11 +190,6 @@ class SeleneEndpoint(MethodView):
|
|||
"""Replace the refresh token on the request with the newly minted one"""
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
token_repository = RefreshTokenRepository(db, self.account.id)
|
||||
if old_refresh_token.is_expired:
|
||||
token_repository.delete_refresh_token(old_refresh_token)
|
||||
raise AuthenticationError('refresh token expired')
|
||||
else:
|
||||
token_repository.update_refresh_token(
|
||||
self.refresh_token.jwt,
|
||||
old_refresh_token.jwt
|
||||
)
|
||||
token_repository.delete_refresh_token(old_refresh_token.jwt)
|
||||
if not old_refresh_token.is_expired:
|
||||
token_repository.add_refresh_token(self.refresh_token.jwt)
|
||||
|
|
|
@ -19,11 +19,9 @@ from selene.data.account import (
|
|||
from selene.util.db import get_db_connection
|
||||
from ..base_endpoint import SeleneEndpoint
|
||||
|
||||
membeship_types = {
|
||||
'MONTHLY SUPPORTER': 'Monthly Supporter',
|
||||
'YEARLY SUPPORTER': 'Yearly Supporter',
|
||||
'MAYBE LATER': 'Maybe Later'
|
||||
}
|
||||
MONTHLY_MEMBERSHIP = 'Monthly Supporter'
|
||||
YEARLY_MEMBERSHIP = 'Yearly Supporter'
|
||||
NO_MEMBERSHIP = 'Maybe Later'
|
||||
|
||||
|
||||
def agreement_accepted(value):
|
||||
|
@ -55,15 +53,15 @@ class Support(Model):
|
|||
open_dataset = BooleanType(required=True)
|
||||
membership = StringType(
|
||||
required=True,
|
||||
choices=('MONTHLY SUPPORTER', 'YEARLY SUPPORTER', 'MAYBE LATER')
|
||||
choices=(MONTHLY_MEMBERSHIP, YEARLY_MEMBERSHIP, NO_MEMBERSHIP)
|
||||
)
|
||||
stripe_customer_id = StringType()
|
||||
|
||||
# def validate_stripe_customer_id(self, data, value):
|
||||
# if data['membership'] != 'Maybe Later':
|
||||
# if not data['stripe_customer_id']:
|
||||
# raise ValidationError('Membership requires a stripe ID')
|
||||
# return value
|
||||
def validate_stripe_customer_id(self, data, value):
|
||||
if data['membership'] != NO_MEMBERSHIP:
|
||||
if not data['stripe_customer_id']:
|
||||
raise ValidationError('Membership requires a stripe ID')
|
||||
return value
|
||||
|
||||
|
||||
class AddAccountRequest(Model):
|
||||
|
@ -143,9 +141,15 @@ class AccountEndpoint(SeleneEndpoint):
|
|||
return email_address, password
|
||||
|
||||
def _add_account(self, email_address, password):
|
||||
membership_type = membeship_types[
|
||||
self.request_data['support']['membership']
|
||||
]
|
||||
membership_type = self.request_data['support']['membership']
|
||||
subscription = None
|
||||
if membership_type != NO_MEMBERSHIP:
|
||||
stripe_id = self.request_data['support']['stripeCustomerId']
|
||||
subscription = AccountSubscription(
|
||||
type=membership_type,
|
||||
start_date=date.today(),
|
||||
stripe_customer_id=stripe_id
|
||||
)
|
||||
account = Account(
|
||||
email_address=email_address,
|
||||
username=self.request_data['username'],
|
||||
|
@ -153,11 +157,7 @@ class AccountEndpoint(SeleneEndpoint):
|
|||
AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today()),
|
||||
AccountAgreement(type=TERMS_OF_USE, accept_date=date.today())
|
||||
],
|
||||
subscription=AccountSubscription(
|
||||
type=membership_type,
|
||||
start_date=date.today(),
|
||||
stripe_customer_id=self.request_data['support']['stripeCustomerId']
|
||||
)
|
||||
subscription=subscription
|
||||
)
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
acct_repository = AccountRepository(db)
|
||||
|
|
|
@ -19,9 +19,8 @@ def generate_access_token(context, expire=False):
|
|||
context.client_config['ACCESS_SECRET'],
|
||||
ONE_MINUTE
|
||||
)
|
||||
access_token.account_id = context.account.id
|
||||
if not expire:
|
||||
access_token.generate()
|
||||
access_token.generate(context.account.id)
|
||||
context.access_token = access_token
|
||||
|
||||
context.client.set_cookie(
|
||||
|
@ -38,9 +37,8 @@ def generate_refresh_token(context, expire=False):
|
|||
context.client_config['REFRESH_SECRET'],
|
||||
TWO_MINUTES
|
||||
)
|
||||
refresh_token.account_id = account_id
|
||||
if not expire:
|
||||
refresh_token.generate()
|
||||
refresh_token.generate(context.account.id)
|
||||
context.refresh_token = refresh_token
|
||||
|
||||
context.client.set_cookie(
|
||||
|
|
|
@ -26,6 +26,6 @@ class Account(object):
|
|||
email_address: str
|
||||
username: str
|
||||
agreements: List[AccountAgreement]
|
||||
subscription: AccountSubscription
|
||||
subscription: AccountSubscription = None
|
||||
id: str = None
|
||||
refresh_tokens: List[str] = None
|
||||
|
|
|
@ -157,9 +157,10 @@ class AccountRepository(object):
|
|||
for agreement in result['account']['agreements']:
|
||||
account_agreements.append(AccountAgreement(**agreement))
|
||||
result['account']['agreements'] = account_agreements
|
||||
result['account']['subscription'] = AccountSubscription(
|
||||
**result['account']['subscription']
|
||||
)
|
||||
if result['account']['subscription'] is not None:
|
||||
result['account']['subscription'] = AccountSubscription(
|
||||
**result['account']['subscription']
|
||||
)
|
||||
account = Account(**result['account'])
|
||||
|
||||
return account
|
||||
|
|
|
@ -16,16 +16,15 @@ class AuthenticationToken(object):
|
|||
self.jwt: str = ''
|
||||
self.is_valid: bool = None
|
||||
self.is_expired: bool = None
|
||||
self.account_id: str = None
|
||||
|
||||
def generate(self):
|
||||
def generate(self, account_id):
|
||||
"""
|
||||
Generates a JWT token
|
||||
"""
|
||||
payload = dict(
|
||||
iat=datetime.utcnow(),
|
||||
exp=time() + self.duration,
|
||||
sub=self.account_id
|
||||
sub=account_id
|
||||
)
|
||||
token = jwt.encode(payload, self.secret, algorithm='HS256')
|
||||
|
||||
|
@ -37,14 +36,17 @@ class AuthenticationToken(object):
|
|||
"""Decodes the auth token and performs some preliminary validation."""
|
||||
self.is_expired = False
|
||||
self.is_valid = True
|
||||
account_id = None
|
||||
|
||||
if self.jwt is None:
|
||||
self.is_expired = True
|
||||
else:
|
||||
try:
|
||||
payload = jwt.decode(self.jwt, self.secret)
|
||||
self.account_id = payload['sub']
|
||||
account_id = payload['sub']
|
||||
except jwt.ExpiredSignatureError:
|
||||
self.is_expired = True
|
||||
except jwt.InvalidTokenError:
|
||||
self.is_valid = False
|
||||
|
||||
return account_id
|
||||
|
|
|
@ -16,8 +16,7 @@ setup(
|
|||
'pyhamcrest',
|
||||
'pyjwt',
|
||||
'psycopg2-binary',
|
||||
'schematics',
|
||||
'validator-collection',
|
||||
'redis'
|
||||
'redis',
|
||||
'schematics'
|
||||
]
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue