removed remaining remnants of the old connection pooling

pull/156/head
Chris Veilleux 2019-05-22 19:56:44 -05:00
parent e6f546788c
commit ac6608ceca
6 changed files with 45 additions and 45 deletions

View File

@ -5,7 +5,7 @@ from hamcrest import assert_that, equal_to, has_key, none, not_none
from selene.data.device import DeviceRepository
from selene.util.cache import SeleneCache
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
@given('a device pairing code')
@ -57,9 +57,9 @@ def validate_pairing_code_removal(context):
@then('the device is added to the database')
def validate_response(context):
device_id = context.response.data.decode()
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
device_repository = DeviceRepository(db)
device = device_repository.get_device_by_id(device_id)
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
device_repository = DeviceRepository(db)
device = device_repository.get_device_by_id(device_id)
assert_that(device, not_none())
assert_that(device.name, equal_to('home'))

View File

@ -90,10 +90,10 @@ def check_db_for_account(context, membership_option):
@when('the account is deleted')
def account_deleted(context):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_email('bar@mycroft.ai')
context.stripe_id = account.membership.payment_id
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_email('bar@mycroft.ai')
context.stripe_id = account.membership.payment_id
context.response = context.client.delete('/api/account')

View File

@ -13,7 +13,7 @@ from selene.data.account import (
AgreementRepository,
PRIVACY_POLICY
)
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
@fixture
@ -32,9 +32,9 @@ def before_feature(context, _):
def before_scenario(context, _):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
_add_agreement(context, db)
_add_account(context, db)
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
_add_agreement(context, db)
_add_account(context, db)
def _add_agreement(context, db):
@ -72,8 +72,8 @@ def _add_account(context, db):
def after_scenario(context, _):
with get_db_connection(context.db_pool) as db:
acct_repository = AccountRepository(db)
acct_repository.remove(context.account)
agreement_repository = AgreementRepository(db)
agreement_repository.remove(context.agreement, testing=True)
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db)
acct_repository.remove(context.account)
agreement_repository = AgreementRepository(db)
agreement_repository.remove(context.agreement, testing=True)

View File

@ -3,7 +3,7 @@ from dataclasses import asdict
from http import HTTPStatus
from selene.data.account import AgreementRepository
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
from ..base_endpoint import SeleneEndpoint
@ -16,14 +16,14 @@ class AgreementsEndpoint(SeleneEndpoint):
def get(self, agreement_type):
"""Process HTTP GET request for an agreement."""
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
agreement_repository = AgreementRepository(db)
agreement = agreement_repository.get_active_for_type(
self.agreement_types[agreement_type]
)
if agreement is not None:
agreement = asdict(agreement)
del(agreement['effective_date'])
self.response = agreement, HTTPStatus.OK
db = connect_to_db(self.config['DB_CONNECTION_CONFIG'])
agreement_repository = AgreementRepository(db)
agreement = agreement_repository.get_active_for_type(
self.agreement_types[agreement_type]
)
if agreement is not None:
agreement = asdict(agreement)
del(agreement['effective_date'])
self.response = agreement, HTTPStatus.OK
return self.response

View File

@ -3,7 +3,7 @@ import string
from selene.data.device import DeviceRepository
from selene.util.cache import SeleneCache
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
def device_etag_key(device_id: str):
@ -29,7 +29,7 @@ class ETagManager(object):
def __init__(self, cache: SeleneCache, config: dict):
self.cache: SeleneCache = cache
self.db_connection_pool = config['DB_CONNECTION_POOL']
self.db_connection_config = config['DB_CONNECTION_CONFIG']
def get(self, key: str) -> str:
"""Generate a etag with 32 random chars and store it into a given key
@ -60,10 +60,10 @@ class ETagManager(object):
def expire_device_setting_etag_by_account_id(self, account_id: str):
"""Expire the settings' etags for all devices from a given account. Used when the settings are updated
at account level"""
with get_db_connection(self.db_connection_pool) as db:
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices:
self.expire_device_setting_etag_by_device_id(device.id)
db = connect_to_db(self.db_connection_config)
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices:
self.expire_device_setting_etag_by_device_id(device.id)
def expire_device_location_etag_by_device_id(self, device_id: str):
"""Expire the etag associate with the device's location entity
@ -73,10 +73,10 @@ class ETagManager(object):
def expire_device_location_etag_by_account_id(self, account_id: str):
"""Expire the locations' etag fpr açç device for a given acccount
:param account_id: account uuid"""
with get_db_connection(self.db_connection_pool) as db:
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices:
self.expire_device_location_etag_by_device_id(device.id)
db = connect_to_db(self.db_connection_config)
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices:
self.expire_device_location_etag_by_device_id(device.id)
def expire_skill_etag_by_device_id(self, device_id):
"""Expire the locations' etag for a given device
@ -84,7 +84,7 @@ class ETagManager(object):
self._expire(device_skill_etag_key(device_id))
def expire_skill_etag_by_account_id(self, account_id):
with get_db_connection(self.db_connection_pool) as db:
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices:
self.expire_skill_etag_by_device_id(device.id)
db = connect_to_db(self.db_connection_config)
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices:
self.expire_skill_etag_by_device_id(device.id)

View File

@ -2,7 +2,7 @@ from hamcrest import assert_that, equal_to, has_item
from selene.data.account import Account, AccountRepository
from selene.util.auth import AuthenticationToken
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
ACCESS_TOKEN_COOKIE_KEY = 'seleneAccess'
ONE_MINUTE = 60
@ -77,8 +77,8 @@ def _parse_cookie(cookie: str) -> dict:
def get_account(context) -> Account:
with get_db_connection(context.db_pool) as db:
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_id(context.account.id)
db = connect_to_db(context.client['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_id(context.account.id)
return account