Created PATCH request to update membership
parent
f02d1f151e
commit
7fc2a54ab0
|
@ -9,8 +9,8 @@ from .endpoints.device import DeviceEndpoint
|
|||
from .endpoints.device_count import DeviceCountEndpoint
|
||||
from .endpoints.geography import GeographyEndpoint
|
||||
from .endpoints.membership import MembershipEndpoint
|
||||
from .endpoints.skills import SkillsEndpoint
|
||||
from .endpoints.skill_settings import SkillSettingsEndpoint
|
||||
from .endpoints.skills import SkillsEndpoint
|
||||
from .endpoints.voice_endpoint import VoiceEndpoint
|
||||
from .endpoints.wake_word_endpoint import WakeWordEndpoint
|
||||
|
||||
|
@ -26,7 +26,7 @@ acct.register_blueprint(selene_api)
|
|||
acct.add_url_rule(
|
||||
'/api/account',
|
||||
view_func=AccountEndpoint.as_view('account_api'),
|
||||
methods=['GET', 'POST']
|
||||
methods=['GET', 'POST', 'PATCH']
|
||||
)
|
||||
acct.add_url_rule(
|
||||
'/api/agreement/<string:agreement_type>',
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
Feature: Test the API call to update a membership
|
||||
|
||||
Scenario: user with free account opts into a membership
|
||||
Given a user with a free account
|
||||
When a monthly membership is added
|
|
@ -43,7 +43,7 @@ def get_device_subscription(context):
|
|||
access_token = login['accessToken']
|
||||
headers=dict(Authorization='Bearer {token}'.format(token=access_token))
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
AccountRepository(db)._add_membership(context.account.id, membership)
|
||||
AccountRepository(db).add_membership(context.account.id, membership)
|
||||
context.subscription_response = context.client.get(
|
||||
'/v1/device/{uuid}/subscription'.format(uuid=device_id),
|
||||
headers=headers
|
||||
|
|
|
@ -63,13 +63,21 @@ class Support(Model):
|
|||
payment_method = StringType(choices=[STRIPE_PAYMENT])
|
||||
payment_token = StringType()
|
||||
|
||||
def validate_payment_account_id(self, data, value):
|
||||
if data['membership'] != NO_MEMBERSHIP:
|
||||
if not data['payment_account_id']:
|
||||
raise ValidationError(
|
||||
'Membership requires a payment account ID'
|
||||
)
|
||||
return value
|
||||
|
||||
class AddMembership(Model):
|
||||
membership = StringType(
|
||||
required=True,
|
||||
choices=(MONTHLY_MEMBERSHIP, YEARLY_MEMBERSHIP, NO_MEMBERSHIP)
|
||||
)
|
||||
payment_method = StringType(required=True, choices=[STRIPE_PAYMENT])
|
||||
payment_token = StringType(required=True)
|
||||
|
||||
|
||||
class UpdateMembership(Model):
|
||||
membership = StringType(
|
||||
required=True,
|
||||
choices=(MONTHLY_MEMBERSHIP, YEARLY_MEMBERSHIP, NO_MEMBERSHIP)
|
||||
)
|
||||
|
||||
|
||||
class AddAccountRequest(Model):
|
||||
|
@ -143,6 +151,11 @@ class AccountEndpoint(SeleneEndpoint):
|
|||
|
||||
return jsonify('Account added successfully'), HTTPStatus.OK
|
||||
|
||||
def patch(self):
|
||||
self._authenticate()
|
||||
self.request_data = json.loads(self.request.data)
|
||||
self._update_support()
|
||||
|
||||
def _validate_request(self):
|
||||
add_request = AddAccountRequest(dict(
|
||||
username=self.request_data.get('username'),
|
||||
|
@ -219,11 +232,58 @@ class AccountEndpoint(SeleneEndpoint):
|
|||
password=password
|
||||
)
|
||||
|
||||
def _create_stripe_subscription(self, token, user_email, plan):
|
||||
customer = stripe.Customer.create(source=token, email=user_email)
|
||||
subscription = stripe.Subscription.create(customer=customer.id, items=[{'plan': plan}])
|
||||
return customer.id, subscription.current_period_start
|
||||
def _create_stripe_subscription(self, customer_id, token, user_email, plan):
|
||||
if customer_id is None:
|
||||
customer = stripe.Customer.create(source=token, email=user_email)
|
||||
customer_id = customer.id
|
||||
subscription = stripe.Subscription.create(customer=customer_id, items=[{'plan': plan}])
|
||||
|
||||
# TODO: store subscription.id
|
||||
start = subscription.current_period_start
|
||||
start = date.fromtimestamp(start)
|
||||
return customer_id, start
|
||||
|
||||
def _get_plan(self, plan):
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
return MembershipRepository(db).get_membership_by_type(plan)
|
||||
|
||||
def _update_support(self):
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
membership_repository = MembershipRepository(db)
|
||||
active_membership = membership_repository.get_active_membership_by_account_id(self.account.id)
|
||||
if active_membership:
|
||||
active_membership.end_date = datetime.now()
|
||||
# TODO: use the subscription id to delete the membership on stripe
|
||||
membership_repository.finish_membership(active_membership)
|
||||
add_membership = UpdateMembership(self.request_data.get('support'))
|
||||
add_membership.validate()
|
||||
support = self.request_data['support']
|
||||
membership = support['membership']
|
||||
stripe_plan = self._get_plan(membership)
|
||||
stripe_id, start_date = self._create_stripe_subscription(
|
||||
active_membership.payment_account_id,
|
||||
None,
|
||||
self.account.email_address,
|
||||
stripe_plan
|
||||
)
|
||||
else:
|
||||
add_membership = AddMembership(self.request_data.get('support'))
|
||||
add_membership.validate()
|
||||
support = self.request_data['support']
|
||||
membership = support['membership']
|
||||
token = support['payment']
|
||||
stripe_plan = self._get_plan(membership)
|
||||
stripe_id, start_date = self._create_stripe_subscription(
|
||||
None,
|
||||
token,
|
||||
self.account.email_address,
|
||||
stripe_plan
|
||||
)
|
||||
|
||||
new_membership = AccountMembership(
|
||||
start_date=start_date,
|
||||
payment_method=STRIPE_PAYMENT,
|
||||
payment_account_id=stripe_id,
|
||||
type=membership
|
||||
)
|
||||
AccountRepository(db).add_membership(self.account.id, new_membership)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from datetime import date
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from typing import List
|
||||
|
||||
|
||||
|
@ -19,6 +19,8 @@ class AccountMembership(object):
|
|||
payment_method: str
|
||||
payment_account_id: str
|
||||
id: str = None
|
||||
account_id: str = None
|
||||
end_date: date = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -35,7 +35,7 @@ class AccountRepository(object):
|
|||
account_id = self._add_account(account, password)
|
||||
self._add_agreements(account_id, account.agreements)
|
||||
if account.membership is not None:
|
||||
self._add_membership(account_id, account.membership)
|
||||
self.add_membership(account_id, account.membership)
|
||||
|
||||
_log.info('Added account {}'.format(account.email_address))
|
||||
|
||||
|
@ -73,7 +73,7 @@ class AccountRepository(object):
|
|||
)
|
||||
self.cursor.insert(request)
|
||||
|
||||
def _add_membership(self, acct_id: str, membership: AccountMembership):
|
||||
def add_membership(self, acct_id: str, membership: AccountMembership):
|
||||
"""A membership is optional, add it if one was selected"""
|
||||
request = DatabaseRequest(
|
||||
sql=get_sql_from_file(
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from selene.data.account import AccountMembership
|
||||
from ..entity.membership import Membership
|
||||
from ...repository_base import RepositoryBase
|
||||
|
||||
|
@ -22,6 +23,15 @@ class MembershipRepository(RepositoryBase):
|
|||
db_result = self.cursor.select_one(db_request)
|
||||
return Membership(**db_result)
|
||||
|
||||
def get_active_membership_by_account_id(self, account_id) -> AccountMembership:
|
||||
db_request = self._build_db_request(
|
||||
sql_file_name='get_active_membership_by_account_id.sql',
|
||||
args=dict(account_id=account_id)
|
||||
)
|
||||
db_result = self.cursor.select_one(db_request)
|
||||
if db_result:
|
||||
return AccountMembership(**db_result)
|
||||
|
||||
def add(self, membership: Membership):
|
||||
db_request = self._build_db_request(
|
||||
'add_membership.sql',
|
||||
|
@ -37,7 +47,20 @@ class MembershipRepository(RepositoryBase):
|
|||
|
||||
def remove(self, membership: Membership):
|
||||
db_request = self._build_db_request(
|
||||
'delete_membership.sql',
|
||||
sql_file_name='delete_membership.sql',
|
||||
args=dict(membership_id=membership.id)
|
||||
)
|
||||
self.cursor.delete(db_request)
|
||||
|
||||
def finish_membership(self, membership: AccountMembership):
|
||||
db_request = self._build_db_request(
|
||||
sql_file_name='finish_membership.sql',
|
||||
args=dict(
|
||||
id=membership.id,
|
||||
membership_ts_range='[{start},{end}]'.format(
|
||||
start=membership.start_date,
|
||||
end=membership.end_date
|
||||
)
|
||||
)
|
||||
)
|
||||
self.cursor.update(db_request)
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
UPDATE
|
||||
account.account_membership
|
||||
SET
|
||||
membership_ts_range = %(membership_ts_range)s
|
||||
WHERE
|
||||
id = %(id)s
|
|
@ -0,0 +1,6 @@
|
|||
SELECT
|
||||
*
|
||||
FROM
|
||||
account.account_membership
|
||||
WHERE
|
||||
account_id = %(account_id)s and membership_ts_range @> '[now,)'
|
Loading…
Reference in New Issue