Merge pull request #54 from MycroftAI/account-api

fixed a bug with new account logic when a user opts out of membership
pull/49/head
Chris Veilleux 2019-02-20 17:37:03 -06:00 committed by GitHub
commit cf39ba2aca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 166 additions and 101 deletions

View File

@ -1,11 +1,22 @@
Feature: Add a new account
Test the API call to add an account to the database.
Scenario: Successful account addition
When a valid new account request is submitted
Then the request will be successful
And the account will be added to the system
Scenario: Successful account addition with membership
Given a user completes on-boarding
And user opts into a membership
When the new account request is submitted
Then the request will be successful
And the account will be added to the system with a membership
Scenario: Successful account addition without membership
Given a user completes on-boarding
And user opts out of membership
When the new account request is submitted
Then the account will be added to the system without a membership
Scenario: Request missing a required field
When a request is sent without an email address
Given a user completes on-boarding
And user does not specify an email address
When the new account request is submitted
Then the request will fail with a bad request error

View File

@ -1,11 +1,14 @@
from behave import given, then
from hamcrest import assert_that, equal_to, is_not
from hamcrest import assert_that, equal_to, has_item, is_not
from selene.api.testing import (
generate_access_token,
generate_refresh_token,
validate_token_cookies
)
from selene.data.account import AccountRepository
from selene.util.auth import AuthenticationToken
from selene.util.db import get_db_connection
@given('an authenticated user with an expired access token')
@ -34,3 +37,18 @@ def check_for_new_cookies(context):
context.refresh_token,
is_not(equal_to(context.old_refresh_token))
)
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_id(context.account.id)
assert_that(account.refresh_tokens, has_item(context.refresh_token))
refresh_token = AuthenticationToken(
context.client_config['REFRESH_SECRET'],
0
)
refresh_token.jwt = context.refresh_token
account_id = refresh_token.validate()
assert_that(refresh_token.is_valid, equal_to(True))
assert_that(refresh_token.is_expired, equal_to(False))
assert_that(account_id, equal_to(account.id))

View File

@ -1,9 +1,8 @@
from datetime import date
from http import HTTPStatus
from behave import then, when
from behave import given, then, when
from flask import json
from hamcrest import assert_that, equal_to, is_in, not_none
from hamcrest import assert_that, equal_to, is_in, none, not_none
from selene.data.account import AccountRepository, PRIVACY_POLICY, TERMS_OF_USE
from selene.util.db import get_db_connection
@ -17,17 +16,38 @@ new_account_request = dict(
userEnteredEmail='bar@mycroft.ai',
password='bar'
),
support=dict(
openDataset=True,
membership='MONTHLY SUPPORTER',
stripeCustomerId='barstripe'
)
support=dict(openDataset=True)
)
@when('a valid new account request is submitted')
@given('a user completes on-boarding')
def build_new_account_request(context):
context.new_account_request = new_account_request
@given('user opts out of membership')
def add_maybe_later_membership(context):
context.new_account_request['support'].update(
membership='Maybe Later',
stripeCustomerId=None
)
@given('user opts into a membership')
def change_membership_option(context):
context.new_account_request['support'].update(
membership='Monthly Supporter',
stripeCustomerId='barstripe'
)
@given('user does not specify an email address')
def remove_email_from_request(context):
del(context.new_account_request['login']['userEnteredEmail'])
@when('the new account request is submitted')
def call_add_account_endpoint(context):
context.new_account_request = new_account_request
context.client.content_type = 'application/json'
context.response = context.client.post(
'/api/account',
@ -36,22 +56,8 @@ def call_add_account_endpoint(context):
)
@when('a request is sent without an email address')
def create_account_without_email(context):
context.new_account_request = new_account_request
login_data = context.new_account_request['login']
del(login_data['userEnteredEmail'])
context.new_account_request['login'] = login_data
context.client.content_type = 'application/json'
context.response = context.client.post(
'/api/account',
data=json.dumps(context.new_account_request),
content_type='application_json'
)
@then('the account will be added to the system')
def check_db_for_account(context):
@then('the account will be added to the system {membership_option}')
def check_db_for_account(context, membership_option):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_email('bar@mycroft.ai')
@ -60,11 +66,15 @@ def check_db_for_account(context):
account.email_address, equal_to('bar@mycroft.ai')
)
assert_that(account.username, equal_to('barfoo'))
assert_that(account.subscription.type, equal_to('Monthly Supporter'))
assert_that(
account.subscription.stripe_customer_id,
equal_to('barstripe')
)
if membership_option == 'with a membership':
assert_that(account.subscription.type, equal_to('Monthly Supporter'))
assert_that(
account.subscription.stripe_customer_id,
equal_to('barstripe')
)
elif membership_option == 'without a membership':
assert_that(account.subscription, none())
assert_that(len(account.agreements), equal_to(2))
for agreement in account.agreements:
assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE)))

View File

@ -1,4 +1,5 @@
"""Base class for Flask API endpoints"""
from copy import deepcopy
from logging import getLogger
from flask import after_this_request, current_app, request
from flask.views import MethodView
@ -16,7 +17,7 @@ FIFTEEN_MINUTES = 900
ONE_MONTH = 2628000
REFRESH_TOKEN_COOKIE_NAME = 'seleneRefresh'
_log = getLogger()
_log = getLogger(__package__)
class APIError(Exception):
@ -38,11 +39,17 @@ class SeleneEndpoint(MethodView):
self.request = request
self.response: tuple = None
self.account: Account = None
self.access_token = AuthenticationToken(
self.access_token = self._init_access_token()
self.refresh_token = self._init_refresh_token()
def _init_access_token(self):
return AuthenticationToken(
self.config['ACCESS_SECRET'],
FIFTEEN_MINUTES
)
self.refresh_token = AuthenticationToken(
def _init_refresh_token(self):
return AuthenticationToken(
self.config['REFRESH_SECRET'],
ONE_MONTH
)
@ -53,36 +60,58 @@ class SeleneEndpoint(MethodView):
:raises: APIError()
"""
self._validate_auth_tokens()
account_id = self._get_account_id_from_tokens()
self._get_account(account_id)
self._validate_account(account_id)
if self.access_token.is_expired:
self._refresh_auth_tokens()
try:
account_id = self._validate_auth_tokens()
self._get_account(account_id)
self._validate_account(account_id)
if self.access_token.is_expired:
self._refresh_auth_tokens()
except Exception:
_log.exception('an exception occurred during authentication')
raise
def _validate_auth_tokens(self):
"""Ensure the tokens are passed in request and are well formed."""
self._get_auth_tokens()
account_id = self._decode_access_token()
if self.access_token.is_expired:
account_id = self._decode_refresh_token()
if account_id is None:
raise AuthenticationError(
'failed to retrieve account ID from authentication tokens'
)
return account_id
def _get_auth_tokens(self):
self.access_token.jwt = self.request.cookies.get(
ACCESS_TOKEN_COOKIE_NAME
)
self.access_token.validate()
self.refresh_token.jwt = self.request.cookies.get(
REFRESH_TOKEN_COOKIE_NAME
)
self.refresh_token.validate()
if self.access_token.jwt is None and self.refresh_token.jwt is None:
raise AuthenticationError('no authentication tokens found')
if self.access_token.is_expired and self.refresh_token.is_expired:
raise AuthenticationError('authentication tokens expired')
def _decode_access_token(self):
"""Decode the JWT to get the account ID and check for errors."""
account_id = self.access_token.validate()
def _get_account_id_from_tokens(self):
"""Extract the account ID, which is encoded within the tokens"""
if self.access_token.is_expired:
account_id = self.refresh_token.account_id
else:
account_id = self.access_token.account_id
if not self.access_token.is_valid:
raise AuthenticationError('invalid access token')
return account_id
def _decode_refresh_token(self):
"""Decode the JWT to get the account ID and check for errors."""
account_id = self.refresh_token.validate()
if not self.access_token.is_valid:
raise AuthenticationError('invalid refresh token')
if self.refresh_token.is_expired:
raise AuthenticationError('authentication tokens expired')
return account_id
@ -110,15 +139,17 @@ class SeleneEndpoint(MethodView):
def _refresh_auth_tokens(self):
"""Steps necessary to refresh the tokens used for authentication."""
old_refresh_token = self.refresh_token
old_refresh_token = deepcopy(self.refresh_token)
self._generate_tokens()
self._update_refresh_token_on_db(old_refresh_token)
self._set_token_cookies()
def _generate_tokens(self):
"""Generate an access token and refresh token."""
self.access_token.generate()
self.refresh_token.generate()
self.access_token = self._init_access_token()
self.refresh_token = self._init_refresh_token()
self.access_token.generate(self.account.id)
self.refresh_token.generate(self.account.id)
def _set_token_cookies(self, expire=False):
"""Set the cookies that contain the authentication token.
@ -159,11 +190,6 @@ class SeleneEndpoint(MethodView):
"""Replace the refresh token on the request with the newly minted one"""
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
token_repository = RefreshTokenRepository(db, self.account.id)
if old_refresh_token.is_expired:
token_repository.delete_refresh_token(old_refresh_token)
raise AuthenticationError('refresh token expired')
else:
token_repository.update_refresh_token(
self.refresh_token.jwt,
old_refresh_token.jwt
)
token_repository.delete_refresh_token(old_refresh_token.jwt)
if not old_refresh_token.is_expired:
token_repository.add_refresh_token(self.refresh_token.jwt)

View File

@ -19,11 +19,9 @@ from selene.data.account import (
from selene.util.db import get_db_connection
from ..base_endpoint import SeleneEndpoint
membeship_types = {
'MONTHLY SUPPORTER': 'Monthly Supporter',
'YEARLY SUPPORTER': 'Yearly Supporter',
'MAYBE LATER': 'Maybe Later'
}
MONTHLY_MEMBERSHIP = 'Monthly Supporter'
YEARLY_MEMBERSHIP = 'Yearly Supporter'
NO_MEMBERSHIP = 'Maybe Later'
def agreement_accepted(value):
@ -55,15 +53,15 @@ class Support(Model):
open_dataset = BooleanType(required=True)
membership = StringType(
required=True,
choices=('MONTHLY SUPPORTER', 'YEARLY SUPPORTER', 'MAYBE LATER')
choices=(MONTHLY_MEMBERSHIP, YEARLY_MEMBERSHIP, NO_MEMBERSHIP)
)
stripe_customer_id = StringType()
# def validate_stripe_customer_id(self, data, value):
# if data['membership'] != 'Maybe Later':
# if not data['stripe_customer_id']:
# raise ValidationError('Membership requires a stripe ID')
# return value
def validate_stripe_customer_id(self, data, value):
if data['membership'] != NO_MEMBERSHIP:
if not data['stripe_customer_id']:
raise ValidationError('Membership requires a stripe ID')
return value
class AddAccountRequest(Model):
@ -143,9 +141,15 @@ class AccountEndpoint(SeleneEndpoint):
return email_address, password
def _add_account(self, email_address, password):
membership_type = membeship_types[
self.request_data['support']['membership']
]
membership_type = self.request_data['support']['membership']
subscription = None
if membership_type != NO_MEMBERSHIP:
stripe_id = self.request_data['support']['stripeCustomerId']
subscription = AccountSubscription(
type=membership_type,
start_date=date.today(),
stripe_customer_id=stripe_id
)
account = Account(
email_address=email_address,
username=self.request_data['username'],
@ -153,11 +157,7 @@ class AccountEndpoint(SeleneEndpoint):
AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today()),
AccountAgreement(type=TERMS_OF_USE, accept_date=date.today())
],
subscription=AccountSubscription(
type=membership_type,
start_date=date.today(),
stripe_customer_id=self.request_data['support']['stripeCustomerId']
)
subscription=subscription
)
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
acct_repository = AccountRepository(db)

View File

@ -19,9 +19,8 @@ def generate_access_token(context, expire=False):
context.client_config['ACCESS_SECRET'],
ONE_MINUTE
)
access_token.account_id = context.account.id
if not expire:
access_token.generate()
access_token.generate(context.account.id)
context.access_token = access_token
context.client.set_cookie(
@ -38,9 +37,8 @@ def generate_refresh_token(context, expire=False):
context.client_config['REFRESH_SECRET'],
TWO_MINUTES
)
refresh_token.account_id = account_id
if not expire:
refresh_token.generate()
refresh_token.generate(context.account.id)
context.refresh_token = refresh_token
context.client.set_cookie(

View File

@ -26,6 +26,6 @@ class Account(object):
email_address: str
username: str
agreements: List[AccountAgreement]
subscription: AccountSubscription
subscription: AccountSubscription = None
id: str = None
refresh_tokens: List[str] = None

View File

@ -157,9 +157,10 @@ class AccountRepository(object):
for agreement in result['account']['agreements']:
account_agreements.append(AccountAgreement(**agreement))
result['account']['agreements'] = account_agreements
result['account']['subscription'] = AccountSubscription(
**result['account']['subscription']
)
if result['account']['subscription'] is not None:
result['account']['subscription'] = AccountSubscription(
**result['account']['subscription']
)
account = Account(**result['account'])
return account

View File

@ -16,16 +16,15 @@ class AuthenticationToken(object):
self.jwt: str = ''
self.is_valid: bool = None
self.is_expired: bool = None
self.account_id: str = None
def generate(self):
def generate(self, account_id):
"""
Generates a JWT token
"""
payload = dict(
iat=datetime.utcnow(),
exp=time() + self.duration,
sub=self.account_id
sub=account_id
)
token = jwt.encode(payload, self.secret, algorithm='HS256')
@ -37,14 +36,17 @@ class AuthenticationToken(object):
"""Decodes the auth token and performs some preliminary validation."""
self.is_expired = False
self.is_valid = True
account_id = None
if self.jwt is None:
self.is_expired = True
else:
try:
payload = jwt.decode(self.jwt, self.secret)
self.account_id = payload['sub']
account_id = payload['sub']
except jwt.ExpiredSignatureError:
self.is_expired = True
except jwt.InvalidTokenError:
self.is_valid = False
return account_id

View File

@ -16,8 +16,7 @@ setup(
'pyhamcrest',
'pyjwt',
'psycopg2-binary',
'schematics',
'validator-collection',
'redis'
'redis',
'schematics'
]
)