Merge pull request #54 from MycroftAI/account-api

fixed a bug with new account logic when a user opts out of membership
pull/49/head
Chris Veilleux 2019-02-20 17:37:03 -06:00 committed by GitHub
commit cf39ba2aca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 166 additions and 101 deletions

View File

@ -1,11 +1,22 @@
Feature: Add a new account
Test the API call to add an account to the database.
Scenario: Successful account addition
When a valid new account request is submitted
Scenario: Successful account addition with membership
Given a user completes on-boarding
And user opts into a membership
When the new account request is submitted
Then the request will be successful
And the account will be added to the system
And the account will be added to the system with a membership
Scenario: Successful account addition without membership
Given a user completes on-boarding
And user opts out of membership
When the new account request is submitted
Then the account will be added to the system without a membership
Scenario: Request missing a required field
When a request is sent without an email address
Given a user completes on-boarding
And user does not specify an email address
When the new account request is submitted
Then the request will fail with a bad request error

View File

@ -1,11 +1,14 @@
from behave import given, then
from hamcrest import assert_that, equal_to, is_not
from hamcrest import assert_that, equal_to, has_item, is_not
from selene.api.testing import (
generate_access_token,
generate_refresh_token,
validate_token_cookies
)
from selene.data.account import AccountRepository
from selene.util.auth import AuthenticationToken
from selene.util.db import get_db_connection
@given('an authenticated user with an expired access token')
@ -34,3 +37,18 @@ def check_for_new_cookies(context):
context.refresh_token,
is_not(equal_to(context.old_refresh_token))
)
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_id(context.account.id)
assert_that(account.refresh_tokens, has_item(context.refresh_token))
refresh_token = AuthenticationToken(
context.client_config['REFRESH_SECRET'],
0
)
refresh_token.jwt = context.refresh_token
account_id = refresh_token.validate()
assert_that(refresh_token.is_valid, equal_to(True))
assert_that(refresh_token.is_expired, equal_to(False))
assert_that(account_id, equal_to(account.id))

View File

@ -1,9 +1,8 @@
from datetime import date
from http import HTTPStatus
from behave import then, when
from behave import given, then, when
from flask import json
from hamcrest import assert_that, equal_to, is_in, not_none
from hamcrest import assert_that, equal_to, is_in, none, not_none
from selene.data.account import AccountRepository, PRIVACY_POLICY, TERMS_OF_USE
from selene.util.db import get_db_connection
@ -17,17 +16,38 @@ new_account_request = dict(
userEnteredEmail='bar@mycroft.ai',
password='bar'
),
support=dict(
openDataset=True,
membership='MONTHLY SUPPORTER',
stripeCustomerId='barstripe'
)
support=dict(openDataset=True)
)
@when('a valid new account request is submitted')
@given('a user completes on-boarding')
def build_new_account_request(context):
context.new_account_request = new_account_request
@given('user opts out of membership')
def add_maybe_later_membership(context):
context.new_account_request['support'].update(
membership='Maybe Later',
stripeCustomerId=None
)
@given('user opts into a membership')
def change_membership_option(context):
context.new_account_request['support'].update(
membership='Monthly Supporter',
stripeCustomerId='barstripe'
)
@given('user does not specify an email address')
def remove_email_from_request(context):
del(context.new_account_request['login']['userEnteredEmail'])
@when('the new account request is submitted')
def call_add_account_endpoint(context):
context.new_account_request = new_account_request
context.client.content_type = 'application/json'
context.response = context.client.post(
'/api/account',
@ -36,22 +56,8 @@ def call_add_account_endpoint(context):
)
@when('a request is sent without an email address')
def create_account_without_email(context):
context.new_account_request = new_account_request
login_data = context.new_account_request['login']
del(login_data['userEnteredEmail'])
context.new_account_request['login'] = login_data
context.client.content_type = 'application/json'
context.response = context.client.post(
'/api/account',
data=json.dumps(context.new_account_request),
content_type='application_json'
)
@then('the account will be added to the system')
def check_db_for_account(context):
@then('the account will be added to the system {membership_option}')
def check_db_for_account(context, membership_option):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_email('bar@mycroft.ai')
@ -60,11 +66,15 @@ def check_db_for_account(context):
account.email_address, equal_to('bar@mycroft.ai')
)
assert_that(account.username, equal_to('barfoo'))
if membership_option == 'with a membership':
assert_that(account.subscription.type, equal_to('Monthly Supporter'))
assert_that(
account.subscription.stripe_customer_id,
equal_to('barstripe')
)
elif membership_option == 'without a membership':
assert_that(account.subscription, none())
assert_that(len(account.agreements), equal_to(2))
for agreement in account.agreements:
assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE)))

View File

@ -1,4 +1,5 @@
"""Base class for Flask API endpoints"""
from copy import deepcopy
from logging import getLogger
from flask import after_this_request, current_app, request
from flask.views import MethodView
@ -16,7 +17,7 @@ FIFTEEN_MINUTES = 900
ONE_MONTH = 2628000
REFRESH_TOKEN_COOKIE_NAME = 'seleneRefresh'
_log = getLogger()
_log = getLogger(__package__)
class APIError(Exception):
@ -38,11 +39,17 @@ class SeleneEndpoint(MethodView):
self.request = request
self.response: tuple = None
self.account: Account = None
self.access_token = AuthenticationToken(
self.access_token = self._init_access_token()
self.refresh_token = self._init_refresh_token()
def _init_access_token(self):
return AuthenticationToken(
self.config['ACCESS_SECRET'],
FIFTEEN_MINUTES
)
self.refresh_token = AuthenticationToken(
def _init_refresh_token(self):
return AuthenticationToken(
self.config['REFRESH_SECRET'],
ONE_MONTH
)
@ -53,36 +60,58 @@ class SeleneEndpoint(MethodView):
:raises: APIError()
"""
self._validate_auth_tokens()
account_id = self._get_account_id_from_tokens()
try:
account_id = self._validate_auth_tokens()
self._get_account(account_id)
self._validate_account(account_id)
if self.access_token.is_expired:
self._refresh_auth_tokens()
except Exception:
_log.exception('an exception occurred during authentication')
raise
def _validate_auth_tokens(self):
"""Ensure the tokens are passed in request and are well formed."""
self._get_auth_tokens()
account_id = self._decode_access_token()
if self.access_token.is_expired:
account_id = self._decode_refresh_token()
if account_id is None:
raise AuthenticationError(
'failed to retrieve account ID from authentication tokens'
)
return account_id
def _get_auth_tokens(self):
self.access_token.jwt = self.request.cookies.get(
ACCESS_TOKEN_COOKIE_NAME
)
self.access_token.validate()
self.refresh_token.jwt = self.request.cookies.get(
REFRESH_TOKEN_COOKIE_NAME
)
self.refresh_token.validate()
if self.access_token.jwt is None and self.refresh_token.jwt is None:
raise AuthenticationError('no authentication tokens found')
if self.access_token.is_expired and self.refresh_token.is_expired:
raise AuthenticationError('authentication tokens expired')
def _decode_access_token(self):
"""Decode the JWT to get the account ID and check for errors."""
account_id = self.access_token.validate()
def _get_account_id_from_tokens(self):
"""Extract the account ID, which is encoded within the tokens"""
if self.access_token.is_expired:
account_id = self.refresh_token.account_id
else:
account_id = self.access_token.account_id
if not self.access_token.is_valid:
raise AuthenticationError('invalid access token')
return account_id
def _decode_refresh_token(self):
"""Decode the JWT to get the account ID and check for errors."""
account_id = self.refresh_token.validate()
if not self.access_token.is_valid:
raise AuthenticationError('invalid refresh token')
if self.refresh_token.is_expired:
raise AuthenticationError('authentication tokens expired')
return account_id
@ -110,15 +139,17 @@ class SeleneEndpoint(MethodView):
def _refresh_auth_tokens(self):
"""Steps necessary to refresh the tokens used for authentication."""
old_refresh_token = self.refresh_token
old_refresh_token = deepcopy(self.refresh_token)
self._generate_tokens()
self._update_refresh_token_on_db(old_refresh_token)
self._set_token_cookies()
def _generate_tokens(self):
"""Generate an access token and refresh token."""
self.access_token.generate()
self.refresh_token.generate()
self.access_token = self._init_access_token()
self.refresh_token = self._init_refresh_token()
self.access_token.generate(self.account.id)
self.refresh_token.generate(self.account.id)
def _set_token_cookies(self, expire=False):
"""Set the cookies that contain the authentication token.
@ -159,11 +190,6 @@ class SeleneEndpoint(MethodView):
"""Replace the refresh token on the request with the newly minted one"""
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
token_repository = RefreshTokenRepository(db, self.account.id)
if old_refresh_token.is_expired:
token_repository.delete_refresh_token(old_refresh_token)
raise AuthenticationError('refresh token expired')
else:
token_repository.update_refresh_token(
self.refresh_token.jwt,
old_refresh_token.jwt
)
token_repository.delete_refresh_token(old_refresh_token.jwt)
if not old_refresh_token.is_expired:
token_repository.add_refresh_token(self.refresh_token.jwt)

View File

@ -19,11 +19,9 @@ from selene.data.account import (
from selene.util.db import get_db_connection
from ..base_endpoint import SeleneEndpoint
membeship_types = {
'MONTHLY SUPPORTER': 'Monthly Supporter',
'YEARLY SUPPORTER': 'Yearly Supporter',
'MAYBE LATER': 'Maybe Later'
}
MONTHLY_MEMBERSHIP = 'Monthly Supporter'
YEARLY_MEMBERSHIP = 'Yearly Supporter'
NO_MEMBERSHIP = 'Maybe Later'
def agreement_accepted(value):
@ -55,15 +53,15 @@ class Support(Model):
open_dataset = BooleanType(required=True)
membership = StringType(
required=True,
choices=('MONTHLY SUPPORTER', 'YEARLY SUPPORTER', 'MAYBE LATER')
choices=(MONTHLY_MEMBERSHIP, YEARLY_MEMBERSHIP, NO_MEMBERSHIP)
)
stripe_customer_id = StringType()
# def validate_stripe_customer_id(self, data, value):
# if data['membership'] != 'Maybe Later':
# if not data['stripe_customer_id']:
# raise ValidationError('Membership requires a stripe ID')
# return value
def validate_stripe_customer_id(self, data, value):
if data['membership'] != NO_MEMBERSHIP:
if not data['stripe_customer_id']:
raise ValidationError('Membership requires a stripe ID')
return value
class AddAccountRequest(Model):
@ -143,9 +141,15 @@ class AccountEndpoint(SeleneEndpoint):
return email_address, password
def _add_account(self, email_address, password):
membership_type = membeship_types[
self.request_data['support']['membership']
]
membership_type = self.request_data['support']['membership']
subscription = None
if membership_type != NO_MEMBERSHIP:
stripe_id = self.request_data['support']['stripeCustomerId']
subscription = AccountSubscription(
type=membership_type,
start_date=date.today(),
stripe_customer_id=stripe_id
)
account = Account(
email_address=email_address,
username=self.request_data['username'],
@ -153,11 +157,7 @@ class AccountEndpoint(SeleneEndpoint):
AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today()),
AccountAgreement(type=TERMS_OF_USE, accept_date=date.today())
],
subscription=AccountSubscription(
type=membership_type,
start_date=date.today(),
stripe_customer_id=self.request_data['support']['stripeCustomerId']
)
subscription=subscription
)
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
acct_repository = AccountRepository(db)

View File

@ -19,9 +19,8 @@ def generate_access_token(context, expire=False):
context.client_config['ACCESS_SECRET'],
ONE_MINUTE
)
access_token.account_id = context.account.id
if not expire:
access_token.generate()
access_token.generate(context.account.id)
context.access_token = access_token
context.client.set_cookie(
@ -38,9 +37,8 @@ def generate_refresh_token(context, expire=False):
context.client_config['REFRESH_SECRET'],
TWO_MINUTES
)
refresh_token.account_id = account_id
if not expire:
refresh_token.generate()
refresh_token.generate(context.account.id)
context.refresh_token = refresh_token
context.client.set_cookie(

View File

@ -26,6 +26,6 @@ class Account(object):
email_address: str
username: str
agreements: List[AccountAgreement]
subscription: AccountSubscription
subscription: AccountSubscription = None
id: str = None
refresh_tokens: List[str] = None

View File

@ -157,6 +157,7 @@ class AccountRepository(object):
for agreement in result['account']['agreements']:
account_agreements.append(AccountAgreement(**agreement))
result['account']['agreements'] = account_agreements
if result['account']['subscription'] is not None:
result['account']['subscription'] = AccountSubscription(
**result['account']['subscription']
)

View File

@ -16,16 +16,15 @@ class AuthenticationToken(object):
self.jwt: str = ''
self.is_valid: bool = None
self.is_expired: bool = None
self.account_id: str = None
def generate(self):
def generate(self, account_id):
"""
Generates a JWT token
"""
payload = dict(
iat=datetime.utcnow(),
exp=time() + self.duration,
sub=self.account_id
sub=account_id
)
token = jwt.encode(payload, self.secret, algorithm='HS256')
@ -37,14 +36,17 @@ class AuthenticationToken(object):
"""Decodes the auth token and performs some preliminary validation."""
self.is_expired = False
self.is_valid = True
account_id = None
if self.jwt is None:
self.is_expired = True
else:
try:
payload = jwt.decode(self.jwt, self.secret)
self.account_id = payload['sub']
account_id = payload['sub']
except jwt.ExpiredSignatureError:
self.is_expired = True
except jwt.InvalidTokenError:
self.is_valid = False
return account_id

View File

@ -16,8 +16,7 @@ setup(
'pyhamcrest',
'pyjwt',
'psycopg2-binary',
'schematics',
'validator-collection',
'redis'
'redis',
'schematics'
]
)