From ac6608ceca6cfc0b38541d0dd614268cd0a31123 Mon Sep 17 00:00:00 2001 From: Chris Veilleux Date: Wed, 22 May 2019 19:56:44 -0500 Subject: [PATCH] removed remaining remnants of the old connection pooling --- .../tests/features/steps/add_device.py | 8 +++--- .../tests/features/steps/new_account.py | 8 +++--- api/sso/tests/features/environment.py | 18 ++++++------ shared/selene/api/endpoints/agreements.py | 20 ++++++------- shared/selene/api/etag.py | 28 +++++++++---------- shared/selene/api/testing/authentication.py | 8 +++--- 6 files changed, 45 insertions(+), 45 deletions(-) diff --git a/api/account/tests/features/steps/add_device.py b/api/account/tests/features/steps/add_device.py index fef8b1c2..5d2f3641 100644 --- a/api/account/tests/features/steps/add_device.py +++ b/api/account/tests/features/steps/add_device.py @@ -5,7 +5,7 @@ from hamcrest import assert_that, equal_to, has_key, none, not_none from selene.data.device import DeviceRepository from selene.util.cache import SeleneCache -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db @given('a device pairing code') @@ -57,9 +57,9 @@ def validate_pairing_code_removal(context): @then('the device is added to the database') def validate_response(context): device_id = context.response.data.decode() - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - device_repository = DeviceRepository(db) - device = device_repository.get_device_by_id(device_id) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + device_repository = DeviceRepository(db) + device = device_repository.get_device_by_id(device_id) assert_that(device, not_none()) assert_that(device.name, equal_to('home')) diff --git a/api/account/tests/features/steps/new_account.py b/api/account/tests/features/steps/new_account.py index 6fe07aa8..12a2ecd9 100644 --- a/api/account/tests/features/steps/new_account.py +++ b/api/account/tests/features/steps/new_account.py @@ -90,10 +90,10 @@ def check_db_for_account(context, membership_option): @when('the account is deleted') def account_deleted(context): - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - acct_repository = AccountRepository(db) - account = acct_repository.get_account_by_email('bar@mycroft.ai') - context.stripe_id = account.membership.payment_id + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + acct_repository = AccountRepository(db) + account = acct_repository.get_account_by_email('bar@mycroft.ai') + context.stripe_id = account.membership.payment_id context.response = context.client.delete('/api/account') diff --git a/api/sso/tests/features/environment.py b/api/sso/tests/features/environment.py index 7350eaa5..eafbed7f 100644 --- a/api/sso/tests/features/environment.py +++ b/api/sso/tests/features/environment.py @@ -13,7 +13,7 @@ from selene.data.account import ( AgreementRepository, PRIVACY_POLICY ) -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db @fixture @@ -32,9 +32,9 @@ def before_feature(context, _): def before_scenario(context, _): - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - _add_agreement(context, db) - _add_account(context, db) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + _add_agreement(context, db) + _add_account(context, db) def _add_agreement(context, db): @@ -72,8 +72,8 @@ def _add_account(context, db): def after_scenario(context, _): - with get_db_connection(context.db_pool) as db: - acct_repository = AccountRepository(db) - acct_repository.remove(context.account) - agreement_repository = AgreementRepository(db) - agreement_repository.remove(context.agreement, testing=True) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + acct_repository = AccountRepository(db) + acct_repository.remove(context.account) + agreement_repository = AgreementRepository(db) + agreement_repository.remove(context.agreement, testing=True) diff --git a/shared/selene/api/endpoints/agreements.py b/shared/selene/api/endpoints/agreements.py index c85cb202..99960568 100644 --- a/shared/selene/api/endpoints/agreements.py +++ b/shared/selene/api/endpoints/agreements.py @@ -3,7 +3,7 @@ from dataclasses import asdict from http import HTTPStatus from selene.data.account import AgreementRepository -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db from ..base_endpoint import SeleneEndpoint @@ -16,14 +16,14 @@ class AgreementsEndpoint(SeleneEndpoint): def get(self, agreement_type): """Process HTTP GET request for an agreement.""" - with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: - agreement_repository = AgreementRepository(db) - agreement = agreement_repository.get_active_for_type( - self.agreement_types[agreement_type] - ) - if agreement is not None: - agreement = asdict(agreement) - del(agreement['effective_date']) - self.response = agreement, HTTPStatus.OK + db = connect_to_db(self.config['DB_CONNECTION_CONFIG']) + agreement_repository = AgreementRepository(db) + agreement = agreement_repository.get_active_for_type( + self.agreement_types[agreement_type] + ) + if agreement is not None: + agreement = asdict(agreement) + del(agreement['effective_date']) + self.response = agreement, HTTPStatus.OK return self.response diff --git a/shared/selene/api/etag.py b/shared/selene/api/etag.py index 9cf4c3e8..8ade0fb4 100644 --- a/shared/selene/api/etag.py +++ b/shared/selene/api/etag.py @@ -3,7 +3,7 @@ import string from selene.data.device import DeviceRepository from selene.util.cache import SeleneCache -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db def device_etag_key(device_id: str): @@ -29,7 +29,7 @@ class ETagManager(object): def __init__(self, cache: SeleneCache, config: dict): self.cache: SeleneCache = cache - self.db_connection_pool = config['DB_CONNECTION_POOL'] + self.db_connection_config = config['DB_CONNECTION_CONFIG'] def get(self, key: str) -> str: """Generate a etag with 32 random chars and store it into a given key @@ -60,10 +60,10 @@ class ETagManager(object): def expire_device_setting_etag_by_account_id(self, account_id: str): """Expire the settings' etags for all devices from a given account. Used when the settings are updated at account level""" - with get_db_connection(self.db_connection_pool) as db: - devices = DeviceRepository(db).get_devices_by_account_id(account_id) - for device in devices: - self.expire_device_setting_etag_by_device_id(device.id) + db = connect_to_db(self.db_connection_config) + devices = DeviceRepository(db).get_devices_by_account_id(account_id) + for device in devices: + self.expire_device_setting_etag_by_device_id(device.id) def expire_device_location_etag_by_device_id(self, device_id: str): """Expire the etag associate with the device's location entity @@ -73,10 +73,10 @@ class ETagManager(object): def expire_device_location_etag_by_account_id(self, account_id: str): """Expire the locations' etag fpr açç device for a given acccount :param account_id: account uuid""" - with get_db_connection(self.db_connection_pool) as db: - devices = DeviceRepository(db).get_devices_by_account_id(account_id) - for device in devices: - self.expire_device_location_etag_by_device_id(device.id) + db = connect_to_db(self.db_connection_config) + devices = DeviceRepository(db).get_devices_by_account_id(account_id) + for device in devices: + self.expire_device_location_etag_by_device_id(device.id) def expire_skill_etag_by_device_id(self, device_id): """Expire the locations' etag for a given device @@ -84,7 +84,7 @@ class ETagManager(object): self._expire(device_skill_etag_key(device_id)) def expire_skill_etag_by_account_id(self, account_id): - with get_db_connection(self.db_connection_pool) as db: - devices = DeviceRepository(db).get_devices_by_account_id(account_id) - for device in devices: - self.expire_skill_etag_by_device_id(device.id) + db = connect_to_db(self.db_connection_config) + devices = DeviceRepository(db).get_devices_by_account_id(account_id) + for device in devices: + self.expire_skill_etag_by_device_id(device.id) diff --git a/shared/selene/api/testing/authentication.py b/shared/selene/api/testing/authentication.py index 9be2283a..0b70f5db 100644 --- a/shared/selene/api/testing/authentication.py +++ b/shared/selene/api/testing/authentication.py @@ -2,7 +2,7 @@ from hamcrest import assert_that, equal_to, has_item from selene.data.account import Account, AccountRepository from selene.util.auth import AuthenticationToken -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db ACCESS_TOKEN_COOKIE_KEY = 'seleneAccess' ONE_MINUTE = 60 @@ -77,8 +77,8 @@ def _parse_cookie(cookie: str) -> dict: def get_account(context) -> Account: - with get_db_connection(context.db_pool) as db: - acct_repository = AccountRepository(db) - account = acct_repository.get_account_by_id(context.account.id) + db = connect_to_db(context.client['DB_CONNECTION_CONFIG']) + acct_repository = AccountRepository(db) + account = acct_repository.get_account_by_id(context.account.id) return account