diff --git a/api/account/account_api/api.py b/api/account/account_api/api.py index 9aef7864..f72cd9a7 100644 --- a/api/account/account_api/api.py +++ b/api/account/account_api/api.py @@ -9,8 +9,8 @@ from .endpoints.device import DeviceEndpoint from .endpoints.device_count import DeviceCountEndpoint from .endpoints.geography import GeographyEndpoint from .endpoints.membership import MembershipEndpoint -from .endpoints.skills import SkillsEndpoint from .endpoints.skill_settings import SkillSettingsEndpoint +from .endpoints.skills import SkillsEndpoint from .endpoints.voice_endpoint import VoiceEndpoint from .endpoints.wake_word_endpoint import WakeWordEndpoint @@ -26,7 +26,7 @@ acct.register_blueprint(selene_api) acct.add_url_rule( '/api/account', view_func=AccountEndpoint.as_view('account_api'), - methods=['GET', 'POST'] + methods=['GET', 'POST', 'PATCH'] ) acct.add_url_rule( '/api/agreement/', diff --git a/api/account/tests/features/update_membership.feature b/api/account/tests/features/update_membership.feature new file mode 100644 index 00000000..1d1704dd --- /dev/null +++ b/api/account/tests/features/update_membership.feature @@ -0,0 +1,5 @@ +Feature: Test the API call to update a membership + + Scenario: user with free account opts into a membership + Given a user with a free account + When a monthly membership is added \ No newline at end of file diff --git a/api/public/tests/features/steps/get_device_subscription.py b/api/public/tests/features/steps/get_device_subscription.py index 9249b7f2..4704346f 100644 --- a/api/public/tests/features/steps/get_device_subscription.py +++ b/api/public/tests/features/steps/get_device_subscription.py @@ -43,7 +43,7 @@ def get_device_subscription(context): access_token = login['accessToken'] headers=dict(Authorization='Bearer {token}'.format(token=access_token)) with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - AccountRepository(db)._add_membership(context.account.id, membership) + AccountRepository(db).add_membership(context.account.id, membership) context.subscription_response = context.client.get( '/v1/device/{uuid}/subscription'.format(uuid=device_id), headers=headers diff --git a/shared/selene/api/endpoints/account.py b/shared/selene/api/endpoints/account.py index 0b17d6ee..f1165ac6 100644 --- a/shared/selene/api/endpoints/account.py +++ b/shared/selene/api/endpoints/account.py @@ -63,13 +63,21 @@ class Support(Model): payment_method = StringType(choices=[STRIPE_PAYMENT]) payment_token = StringType() - def validate_payment_account_id(self, data, value): - if data['membership'] != NO_MEMBERSHIP: - if not data['payment_account_id']: - raise ValidationError( - 'Membership requires a payment account ID' - ) - return value + +class AddMembership(Model): + membership = StringType( + required=True, + choices=(MONTHLY_MEMBERSHIP, YEARLY_MEMBERSHIP, NO_MEMBERSHIP) + ) + payment_method = StringType(required=True, choices=[STRIPE_PAYMENT]) + payment_token = StringType(required=True) + + +class UpdateMembership(Model): + membership = StringType( + required=True, + choices=(MONTHLY_MEMBERSHIP, YEARLY_MEMBERSHIP, NO_MEMBERSHIP) + ) class AddAccountRequest(Model): @@ -143,6 +151,11 @@ class AccountEndpoint(SeleneEndpoint): return jsonify('Account added successfully'), HTTPStatus.OK + def patch(self): + self._authenticate() + self.request_data = json.loads(self.request.data) + self._update_support() + def _validate_request(self): add_request = AddAccountRequest(dict( username=self.request_data.get('username'), @@ -219,11 +232,58 @@ class AccountEndpoint(SeleneEndpoint): password=password ) - def _create_stripe_subscription(self, token, user_email, plan): - customer = stripe.Customer.create(source=token, email=user_email) - subscription = stripe.Subscription.create(customer=customer.id, items=[{'plan': plan}]) - return customer.id, subscription.current_period_start + def _create_stripe_subscription(self, customer_id, token, user_email, plan): + if customer_id is None: + customer = stripe.Customer.create(source=token, email=user_email) + customer_id = customer.id + subscription = stripe.Subscription.create(customer=customer_id, items=[{'plan': plan}]) + + # TODO: store subscription.id + start = subscription.current_period_start + start = date.fromtimestamp(start) + return customer_id, start def _get_plan(self, plan): with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: return MembershipRepository(db).get_membership_by_type(plan) + + def _update_support(self): + with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: + membership_repository = MembershipRepository(db) + active_membership = membership_repository.get_active_membership_by_account_id(self.account.id) + if active_membership: + active_membership.end_date = datetime.now() + # TODO: use the subscription id to delete the membership on stripe + membership_repository.finish_membership(active_membership) + add_membership = UpdateMembership(self.request_data.get('support')) + add_membership.validate() + support = self.request_data['support'] + membership = support['membership'] + stripe_plan = self._get_plan(membership) + stripe_id, start_date = self._create_stripe_subscription( + active_membership.payment_account_id, + None, + self.account.email_address, + stripe_plan + ) + else: + add_membership = AddMembership(self.request_data.get('support')) + add_membership.validate() + support = self.request_data['support'] + membership = support['membership'] + token = support['payment'] + stripe_plan = self._get_plan(membership) + stripe_id, start_date = self._create_stripe_subscription( + None, + token, + self.account.email_address, + stripe_plan + ) + + new_membership = AccountMembership( + start_date=start_date, + payment_method=STRIPE_PAYMENT, + payment_account_id=stripe_id, + type=membership + ) + AccountRepository(db).add_membership(self.account.id, new_membership) diff --git a/shared/selene/data/account/entity/account.py b/shared/selene/data/account/entity/account.py index 3017e576..6f0cd515 100644 --- a/shared/selene/data/account/entity/account.py +++ b/shared/selene/data/account/entity/account.py @@ -1,5 +1,5 @@ -from datetime import date from dataclasses import dataclass +from datetime import date from typing import List @@ -19,6 +19,8 @@ class AccountMembership(object): payment_method: str payment_account_id: str id: str = None + account_id: str = None + end_date: date = None @dataclass diff --git a/shared/selene/data/account/repository/account.py b/shared/selene/data/account/repository/account.py index db0e4adb..cbf9753b 100644 --- a/shared/selene/data/account/repository/account.py +++ b/shared/selene/data/account/repository/account.py @@ -35,7 +35,7 @@ class AccountRepository(object): account_id = self._add_account(account, password) self._add_agreements(account_id, account.agreements) if account.membership is not None: - self._add_membership(account_id, account.membership) + self.add_membership(account_id, account.membership) _log.info('Added account {}'.format(account.email_address)) @@ -73,7 +73,7 @@ class AccountRepository(object): ) self.cursor.insert(request) - def _add_membership(self, acct_id: str, membership: AccountMembership): + def add_membership(self, acct_id: str, membership: AccountMembership): """A membership is optional, add it if one was selected""" request = DatabaseRequest( sql=get_sql_from_file( diff --git a/shared/selene/data/account/repository/membership.py b/shared/selene/data/account/repository/membership.py index 92b23d5e..27f32e69 100644 --- a/shared/selene/data/account/repository/membership.py +++ b/shared/selene/data/account/repository/membership.py @@ -1,3 +1,4 @@ +from selene.data.account import AccountMembership from ..entity.membership import Membership from ...repository_base import RepositoryBase @@ -22,6 +23,15 @@ class MembershipRepository(RepositoryBase): db_result = self.cursor.select_one(db_request) return Membership(**db_result) + def get_active_membership_by_account_id(self, account_id) -> AccountMembership: + db_request = self._build_db_request( + sql_file_name='get_active_membership_by_account_id.sql', + args=dict(account_id=account_id) + ) + db_result = self.cursor.select_one(db_request) + if db_result: + return AccountMembership(**db_result) + def add(self, membership: Membership): db_request = self._build_db_request( 'add_membership.sql', @@ -37,7 +47,20 @@ class MembershipRepository(RepositoryBase): def remove(self, membership: Membership): db_request = self._build_db_request( - 'delete_membership.sql', + sql_file_name='delete_membership.sql', args=dict(membership_id=membership.id) ) self.cursor.delete(db_request) + + def finish_membership(self, membership: AccountMembership): + db_request = self._build_db_request( + sql_file_name='finish_membership.sql', + args=dict( + id=membership.id, + membership_ts_range='[{start},{end}]'.format( + start=membership.start_date, + end=membership.end_date + ) + ) + ) + self.cursor.update(db_request) diff --git a/shared/selene/data/account/repository/sql/finish_membership.sql b/shared/selene/data/account/repository/sql/finish_membership.sql new file mode 100644 index 00000000..9f374ffa --- /dev/null +++ b/shared/selene/data/account/repository/sql/finish_membership.sql @@ -0,0 +1,6 @@ +UPDATE + account.account_membership +SET + membership_ts_range = %(membership_ts_range)s +WHERE + id = %(id)s \ No newline at end of file diff --git a/shared/selene/data/account/repository/sql/get_active_membership_by_account_id.sql b/shared/selene/data/account/repository/sql/get_active_membership_by_account_id.sql new file mode 100644 index 00000000..44950199 --- /dev/null +++ b/shared/selene/data/account/repository/sql/get_active_membership_by_account_id.sql @@ -0,0 +1,6 @@ +SELECT + * +FROM + account.account_membership +WHERE + account_id = %(account_id)s and membership_ts_range @> '[now,)' \ No newline at end of file