From fe2e8bc7e799d4d11d927ec7681c0c3117357bbe Mon Sep 17 00:00:00 2001 From: Chris Veilleux Date: Tue, 21 May 2019 14:18:36 -0500 Subject: [PATCH 01/15] moved population of account.agreement table and geography tables from python notebooks into the bootstrap script --- db/mycroft/metrics_schema/tables/job.sql | 11 ++ db/scripts/bootstrap_mycroft_db.py | 236 +++++++++++++++++++---- 2 files changed, 207 insertions(+), 40 deletions(-) create mode 100644 db/mycroft/metrics_schema/tables/job.sql diff --git a/db/mycroft/metrics_schema/tables/job.sql b/db/mycroft/metrics_schema/tables/job.sql new file mode 100644 index 00000000..8dbd8654 --- /dev/null +++ b/db/mycroft/metrics_schema/tables/job.sql @@ -0,0 +1,11 @@ +CREATE TABLE metrics.job ( + id uuid PRIMARY KEY + DEFAULT gen_random_uuid(), + job_name text NOT NULL, + batch_date date NOT NULL, + start_ts TIMESTAMP NOT NULL, + end_ts TIMESTAMP NOT NULL, + command text NOT NULL, + success BOOLEAN NOT NULL, + UNIQUE (job_name, start_ts) +) diff --git a/db/scripts/bootstrap_mycroft_db.py b/db/scripts/bootstrap_mycroft_db.py index 9136dd53..0b94f5bb 100644 --- a/db/scripts/bootstrap_mycroft_db.py +++ b/db/scripts/bootstrap_mycroft_db.py @@ -1,6 +1,7 @@ from glob import glob -from os import path +from os import environ, path, remove +from markdown import markdown from psycopg2 import connect MYCROFT_DB_DIR = path.join(path.abspath('..'), 'mycroft') @@ -8,10 +9,8 @@ SCHEMAS = ('account', 'skill', 'device', 'geography', 'metrics') DB_DESTROY_FILES = ( 'drop_mycroft_db.sql', 'drop_template_db.sql', - # 'drop_roles.sql' ) DB_CREATE_FILES = ( - # 'create_roles.sql', 'create_template_db.sql', ) ACCOUNT_TABLE_ORDER = ( @@ -48,6 +47,7 @@ GEOGRAPHY_TABLE_ORDER = ( METRICS_TABLE_ORDER = ( 'api', + 'job' ) schema_directory = '{}_schema' @@ -61,32 +61,40 @@ def get_sql_from_file(file_path: str) -> str: class PostgresDB(object): - def __init__(self, dbname, user, password=None): - self.db = connect(dbname=dbname, user=user, host='127.0.0.1') - # self.db = connect( - # dbname=dbname, - # user=user, - # password=password, - # host='selene-test-db-do-user-1412453-0.db.ondigitalocean.com', - # port=25060, - # sslmode='require' - # ) + def __init__(self, db_name, user=None): + db_host = environ['DB_HOST'] + db_port = environ['DB_PORT'] + db_ssl_mode = environ.get('DB_SSL_MODE') + if db_name in ('postgres', 'defaultdb'): + db_user = environ['POSTGRES_DB_USER'] + db_password = environ.get('POSTGRES_DB_PASSWORD') + else: + db_user = environ['MYCROFT_DB_USER'] + db_password = environ['MYCROFT_DB_PASSWORD'] + + if user is not None: + db_user = user + + self.db = connect( + dbname=db_name, + user=db_user, + password=db_password, + host=db_host, + port=db_port, + sslmode=db_ssl_mode + ) self.db.autocommit = True def close_db(self): self.db.close() - def execute_sql(self, sql: str): + def execute_sql(self, sql: str, args=None): cursor = self.db.cursor() - cursor.execute(sql) + cursor.execute(sql, args) + return cursor -postgres_db = PostgresDB(dbname='postgres', user='postgres') -# postgres_db = PostgresDB( -# dbname='defaultdb', -# user='doadmin', -# password='l06tn0qi2bjhgcki' -# ) +postgres_db = PostgresDB(db_name=environ['POSTGRES_DB_NAME']) print('Destroying any objects we will be creating later.') for db_destroy_file in DB_DESTROY_FILES: @@ -94,7 +102,7 @@ for db_destroy_file in DB_DESTROY_FILES: get_sql_from_file(db_destroy_file) ) -print('Creating the extensions, mycroft database, and selene roles') +print('Creating the mycroft database') for db_setup_file in DB_CREATE_FILES: postgres_db.execute_sql( get_sql_from_file(db_setup_file) @@ -102,13 +110,10 @@ for db_setup_file in DB_CREATE_FILES: postgres_db.close_db() -template_db = PostgresDB(dbname='mycroft_template', user='mycroft') -# template_db = PostgresDB( -# dbname='mycroft_template', -# user='selene', -# password='ubhemhx1dikmqc5f' -# ) +template_db = PostgresDB(db_name='mycroft_template') + +print('Creating the extensions') template_db.execute_sql( get_sql_from_file(path.join('create_extensions.sql')) ) @@ -193,22 +198,14 @@ for schema in SCHEMAS: template_db.close_db() + print('Copying template to new database.') -postgres_db = PostgresDB(dbname='postgres', user='mycroft') -# postgres_db = PostgresDB( -# dbname='defaultdb', -# user='doadmin', -# password='l06tn0qi2bjhgcki' -# ) +postgres_db = PostgresDB(db_name=environ['POSTGRES_DB_NAME']) postgres_db.execute_sql(get_sql_from_file('create_mycroft_db.sql')) postgres_db.close_db() -mycroft_db = PostgresDB(dbname='mycroft', user='mycroft') -# mycroft_db = PostgresDB( -# dbname='mycroft_template', -# user='selene', -# password='ubhemhx1dikmqc5f' -# ) + +mycroft_db = PostgresDB(db_name=environ['MYCROFT_DB_NAME']) insert_files = [ dict(schema_dir='account_schema', file_name='membership.sql'), dict(schema_dir='device_schema', file_name='text_to_speech.sql'), @@ -226,3 +223,162 @@ for insert_file in insert_files: ) except FileNotFoundError: pass + +print('Building account.agreement table') +mycroft_db.db.autocommit = False +insert_sql = ( + "insert into account.agreement VALUES (default, '{}', '1', '[today,]', {})" +) +doc_dir = '/Users/chrisveilleux/Mycroft/github/documentation/_pages/' +docs = { + 'Privacy Policy': doc_dir + 'embed-privacy-policy.md', + 'Terms of Use': doc_dir + 'embed-terms-of-use.md' +} +try: + for agrmt_type, doc_path in docs.items(): + lobj = mycroft_db.db.lobject(0, 'b') + with open(doc_path) as doc: + header_delimiter_count = 0 + while True: + rec = doc.readline() + if rec == '---\n': + header_delimiter_count += 1 + if header_delimiter_count == 2: + break + doc_html = markdown( + doc.read(), + output_format='html5' + ) + lobj.write(doc_html) + mycroft_db.execute_sql( + insert_sql.format(agrmt_type, lobj.oid) + ) + mycroft_db.execute_sql( + "grant select on large object {} to selene".format(lobj.oid) + ) + mycroft_db.execute_sql( + insert_sql.format('Open Dataset', 'null') + ) +except: + mycroft_db.db.rollback() + raise +else: + mycroft_db.db.commit() + +mycroft_db.db.autocommit = True + +reference_file_dir = '/Users/chrisveilleux/Mycroft' + +print('Building geography.country table') +country_file = 'country.txt' +country_insert = """ +INSERT INTO + geography.country (iso_code, name) +VALUES + ('{iso_code}', '{country_name}') +""" + +with open(path.join(reference_file_dir, country_file)) as countries: + while True: + rec = countries.readline() + if rec.startswith('#ISO'): + break + + for country in countries.readlines(): + country_fields = country.split('\t') + insert_args = dict( + iso_code=country_fields[0], + country_name=country_fields[4] + ) + mycroft_db.execute_sql(country_insert.format(**insert_args)) + +print('Building geography.region table') +region_file = 'regions.txt' +region_insert = """ +INSERT INTO + geography.region (country_id, region_code, name) +VALUES + ( + (SELECT id FROM geography.country WHERE iso_code = %(iso_code)s), + %(region_code)s, + %(region_name)s) +""" +with open(path.join(reference_file_dir, region_file)) as regions: + for region in regions.readlines(): + region_fields = region.split('\t') + country_iso_code = region_fields[0][:2] + insert_args = dict( + iso_code=country_iso_code, + region_code=region_fields[0], + region_name=region_fields[1] + ) + mycroft_db.execute_sql(region_insert, insert_args) + +print('Building geography.timezone table') +timezone_file = 'timezones.txt' +timezone_insert = """ +INSERT INTO + geography.timezone (country_id, name, gmt_offset, dst_offset) +VALUES + ( + (SELECT id FROM geography.country WHERE iso_code = %(iso_code)s), + %(timezone_name)s, + %(gmt_offset)s, + %(dst_offset)s + ) +""" +with open(path.join(reference_file_dir, timezone_file)) as timezones: + timezones.readline() + for timezone in timezones.readlines(): + timezone_fields = timezone.split('\t') + insert_args = dict( + iso_code=timezone_fields[0], + timezone_name=timezone_fields[1], + gmt_offset=timezone_fields[2], + dst_offset=timezone_fields[3] + ) + mycroft_db.execute_sql(timezone_insert, insert_args) + +print('Building geography.city table') +cities_file = 'cities500.txt' +region_query = "SELECT id, region_code FROM geography.region" +query_result = mycroft_db.execute_sql(region_query) +region_lookup = dict() +for row in query_result.fetchall(): + region_lookup[row[1]] = row[0] + +timezone_query = "SELECT id, name FROM geography.timezone" +query_result = mycroft_db.execute_sql(timezone_query) +timezone_lookup = dict() +for row in query_result.fetchall(): + timezone_lookup[row[1]] = row[0] +# city_insert = """ +# INSERT INTO +# geography.city (region_id, timezone_id, name, latitude, longitude) +# VALUES +# (%(region_id)s, %(timezone_id)s, %(city_name)s, %(latitude)s, %(longitude)s) +# """ +with open(path.join(reference_file_dir, cities_file)) as cities: + with open(path.join(reference_file_dir, 'city.dump'), 'w') as dump_file: + for city in cities.readlines(): + city_fields = city.split('\t') + city_region = city_fields[8] + '.' + city_fields[10] + region_id = region_lookup.get(city_region) + timezone_id = timezone_lookup[city_fields[17]] + if region_id is not None: + dump_file.write('\t'.join([ + region_id, + timezone_id, + city_fields[1], + city_fields[4], + city_fields[5] + ]) + '\n') + # mycroft_db.execute_sql(city_insert, insert_args) +with open(path.join(reference_file_dir, 'city.dump')) as dump_file: + cursor = mycroft_db.db.cursor() + cursor.copy_from(dump_file, 'geography.city', columns=( + 'region_id', 'timezone_id', 'name', 'latitude', 'longitude') + ) +remove(path.join(reference_file_dir, 'city.dump')) + +mycroft_db.close_db() From c3a1d74fc3770c0570dc98feac9183c8d14e1408 Mon Sep 17 00:00:00 2001 From: Matheus Lima Date: Tue, 21 May 2019 20:39:03 -0300 Subject: [PATCH 02/15] Adding uuid in the response returned by the endpoint to get the device --- api/public/public_api/endpoints/device.py | 1 + api/public/tests/features/steps/get_device.py | 1 + 2 files changed, 2 insertions(+) diff --git a/api/public/public_api/endpoints/device.py b/api/public/public_api/endpoints/device.py index 4c26ea5f..2b1f087b 100644 --- a/api/public/public_api/endpoints/device.py +++ b/api/public/public_api/endpoints/device.py @@ -28,6 +28,7 @@ class DeviceEndpoint(PublicEndpoint): if device is not None: response_data = dict( + uuid=device.id, name=device.name, description=device.placement, coreVersion=device.core_version, diff --git a/api/public/tests/features/steps/get_device.py b/api/public/tests/features/steps/get_device.py index e44f1750..15cffda5 100644 --- a/api/public/tests/features/steps/get_device.py +++ b/api/public/tests/features/steps/get_device.py @@ -31,6 +31,7 @@ def validate_response(context): response = context.get_device_response assert_that(response.status_code, equal_to(HTTPStatus.OK)) device = json.loads(response.data) + assert_that(device, has_key('uuid')) assert_that(device, has_key('name')) assert_that(device, has_key('description')) assert_that(device, has_key('coreVersion')) From 1092f8f1eda31edb7a98f3935ec20056ae38c88d Mon Sep 17 00:00:00 2001 From: Matheus Lima Date: Wed, 22 May 2019 05:22:08 -0300 Subject: [PATCH 03/15] Allowing skill_gid in the skill manifest --- api/public/public_api/endpoints/device_skill_manifest.py | 1 + api/public/tests/features/steps/device_skill_manifest.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/api/public/public_api/endpoints/device_skill_manifest.py b/api/public/public_api/endpoints/device_skill_manifest.py index 866b899e..e6ef7133 100644 --- a/api/public/public_api/endpoints/device_skill_manifest.py +++ b/api/public/public_api/endpoints/device_skill_manifest.py @@ -27,6 +27,7 @@ class SkillManifest(Model): installed = DateTimeType() updated = DateTimeType() update = DateTimeType() + skill_gid = StringType() class SkillJson(Model): diff --git a/api/public/tests/features/steps/device_skill_manifest.py b/api/public/tests/features/steps/device_skill_manifest.py index 121fa73a..83a37aea 100644 --- a/api/public/tests/features/steps/device_skill_manifest.py +++ b/api/public/tests/features/steps/device_skill_manifest.py @@ -17,7 +17,8 @@ skill_manifest = { "installed": datetime.now().timestamp(), "updated": datetime.now().timestamp(), "installation": "installed", - "update": 0 + "update": 0, + "skill_gid": "fallback-wolfram-alpha|19.02", } ] } From de0af15ae72ae8578c90e4b0fc6e21278cf01f52 Mon Sep 17 00:00:00 2001 From: Matheus Lima Date: Wed, 22 May 2019 12:27:35 -0300 Subject: [PATCH 04/15] Releasing the db connection when an exception happens --- shared/selene/api/blueprint.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/shared/selene/api/blueprint.py b/shared/selene/api/blueprint.py index 37fd13d3..846684ef 100644 --- a/shared/selene/api/blueprint.py +++ b/shared/selene/api/blueprint.py @@ -1,5 +1,5 @@ -from datetime import datetime import json +from datetime import datetime from http import HTTPStatus from flask import current_app, Blueprint, g as global_context @@ -31,6 +31,11 @@ def handle_not_modified(error): return '', HTTPStatus.NOT_MODIFIED +@selene_api.app_errorhandler(Exception) +def release_connection_after_error(error): + release_db_connection() + + @selene_api.before_app_request def setup_request(): global_context.start_ts = datetime.utcnow() From 43ef249d009905ff5dfbe8b2389a7b77a955685a Mon Sep 17 00:00:00 2001 From: Chris Veilleux Date: Wed, 22 May 2019 13:12:23 -0500 Subject: [PATCH 05/15] change public API to use a different connection pooling mechanism --- .../public_api/endpoints/device_oauth.py | 4 +- api/public/tests/features/environment.py | 28 +++++----- shared/selene/api/base_config.py | 8 +++ shared/selene/api/blueprint.py | 24 ++------- shared/selene/api/public_endpoint.py | 6 +-- shared/selene/util/db/connection.py | 54 ++++++++----------- 6 files changed, 52 insertions(+), 72 deletions(-) diff --git a/api/public/public_api/endpoints/device_oauth.py b/api/public/public_api/endpoints/device_oauth.py index d7b7c73d..ab61055e 100644 --- a/api/public/public_api/endpoints/device_oauth.py +++ b/api/public/public_api/endpoints/device_oauth.py @@ -4,7 +4,6 @@ import requests from selene.api import PublicEndpoint from selene.data.account import AccountRepository -from selene.util.db import get_db_connection class OauthServiceEndpoint(PublicEndpoint): @@ -14,8 +13,7 @@ class OauthServiceEndpoint(PublicEndpoint): self.oauth_service_host = os.environ['OAUTH_BASE_URL'] def get(self, device_id, credentials, oauth_path): - with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: - account = AccountRepository(db).get_account_by_device_id(device_id) + account = AccountRepository(self.db).get_account_by_device_id(device_id) uuid = account.id url = '{host}/auth/{credentials}/{oauth_path}'.format( host=self.oauth_service_host, diff --git a/api/public/tests/features/environment.py b/api/public/tests/features/environment.py index 4a997aef..5b99de3a 100644 --- a/api/public/tests/features/environment.py +++ b/api/public/tests/features/environment.py @@ -21,7 +21,7 @@ from selene.data.device import ( Geography) from selene.data.device.entity.text_to_speech import TextToSpeech from selene.data.device.entity.wake_word import WakeWord -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db account = Account( email_address='test@test.com', @@ -63,22 +63,22 @@ def before_feature(context, _): def before_scenario(context, _): cache = context.client_config['SELENE_CACHE'] context.etag_manager = ETagManager(cache, context.client_config) - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - try: - _add_agreements(context, db) - _add_account(context, db) - _add_account_preference(context, db) - _add_geography(context, db) - _add_device(context, db) - except Exception as e: - import traceback - print(traceback.print_exc()) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + try: + _add_agreements(context, db) + _add_account(context, db) + _add_account_preference(context, db) + _add_geography(context, db) + _add_device(context, db) + except Exception as e: + import traceback + print(traceback.print_exc()) def after_scenario(context, _): - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - _remove_account(context, db) - _remove_agreements(context, db) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + _remove_account(context, db) + _remove_agreements(context, db) def _add_agreements(context, db): diff --git a/shared/selene/api/base_config.py b/shared/selene/api/base_config.py index 82484cd7..b5a99bd8 100644 --- a/shared/selene/api/base_config.py +++ b/shared/selene/api/base_config.py @@ -43,6 +43,14 @@ class BaseConfig(object): DEBUG = False ENV = os.environ['SELENE_ENVIRONMENT'] REFRESH_SECRET = os.environ['JWT_REFRESH_SECRET'] + DB_CONNECTION_CONFIG = DatabaseConnectionConfig( + host=os.environ['DB_HOST'], + db_name=os.environ['DB_NAME'], + password=os.environ['DB_PASSWORD'], + port=os.environ.get('DB_PORT', 5432), + user=os.environ['DB_USER'], + sslmode=os.environ.get('DB_SSLMODE') + ) class DevelopmentConfig(BaseConfig): diff --git a/shared/selene/api/blueprint.py b/shared/selene/api/blueprint.py index 846684ef..1aa25bfb 100644 --- a/shared/selene/api/blueprint.py +++ b/shared/selene/api/blueprint.py @@ -7,10 +7,7 @@ from schematics.exceptions import DataError from selene.data.metrics import ApiMetric, ApiMetricsRepository from selene.util.auth import AuthenticationError -from selene.util.db import ( - get_db_connection_from_pool, - return_db_connection_to_pool -) +from selene.util.db import connect_to_db from selene.util.not_modified import NotModifiedError selene_api = Blueprint('selene_api', __name__) @@ -31,11 +28,6 @@ def handle_not_modified(error): return '', HTTPStatus.NOT_MODIFIED -@selene_api.app_errorhandler(Exception) -def release_connection_after_error(error): - release_db_connection() - - @selene_api.before_app_request def setup_request(): global_context.start_ts = datetime.utcnow() @@ -44,7 +36,6 @@ def setup_request(): @selene_api.after_app_request def teardown_request(response): add_api_metric(response.status_code) - release_db_connection() return response @@ -59,8 +50,8 @@ def add_api_metric(http_status): if api is not None and int(http_status) != 304: if 'db' not in global_context: - global_context.db = get_db_connection_from_pool( - current_app.config['DB_CONNECTION_POOL'] + global_context.db = connect_to_db( + current_app.config['DB_CONNECTION_CONFIG'] ) if 'account_id' in global_context: account_id = global_context.account_id @@ -83,12 +74,3 @@ def add_api_metric(http_status): ) metric_repository = ApiMetricsRepository(global_context.db) metric_repository.add(api_metric) - - -def release_db_connection(): - db = global_context.pop('db', None) - if db is not None: - return_db_connection_to_pool( - current_app.config['DB_CONNECTION_POOL'], - db - ) diff --git a/shared/selene/api/public_endpoint.py b/shared/selene/api/public_endpoint.py index c3314384..8683fa0c 100644 --- a/shared/selene/api/public_endpoint.py +++ b/shared/selene/api/public_endpoint.py @@ -13,7 +13,7 @@ from flask.views import MethodView from selene.api.etag import ETagManager from selene.util.auth import AuthenticationError -from selene.util.db import get_db_connection_from_pool +from selene.util.db import connect_to_db from selene.util.not_modified import NotModifiedError from ..util.cache import SeleneCache @@ -91,8 +91,8 @@ class PublicEndpoint(MethodView): @property def db(self): if 'db' not in global_context: - global_context.db = get_db_connection_from_pool( - current_app.config['DB_CONNECTION_POOL'] + global_context.db = connect_to_db( + current_app.config['DB_CONNECTION_CONFIG'] ) return global_context.db diff --git a/shared/selene/util/db/connection.py b/shared/selene/util/db/connection.py index 031b00d6..37b06a19 100644 --- a/shared/selene/util/db/connection.py +++ b/shared/selene/util/db/connection.py @@ -6,12 +6,12 @@ Example Usage: query_result = mycroft_db_ro.execute_sql(sql) """ -from contextlib import contextmanager -from dataclasses import dataclass, field +from dataclasses import dataclass, field, InitVar from logging import getLogger from psycopg2 import connect -from psycopg2.extras import RealDictCursor +from psycopg2.extras import RealDictCursor, NamedTupleCursor +from psycopg2.extensions import cursor _log = getLogger(__package__) @@ -29,10 +29,16 @@ class DatabaseConnectionConfig(object): password: str port: int = field(default=5432) sslmode: str = None + autocommit: str = True + cursor_factory = RealDictCursor + use_namedtuple_cursor: InitVar[bool] = False + + def __post_init__(self, use_namedtuple_cursor: bool): + if use_namedtuple_cursor: + self.cursor_factory = NamedTupleCursor -@contextmanager -def connect_to_db(connection_config: DatabaseConnectionConfig, autocommit=True): +def connect_to_db(connection_config: DatabaseConnectionConfig): """ Return a connection to the mycroft database for the specified user. @@ -41,33 +47,19 @@ def connect_to_db(connection_config: DatabaseConnectionConfig, autocommit=True): python notebook) :param connection_config: data needed to establish a connection - :param autocommit: indicated if transactions should commit automatically :return: database connection """ - db = None log_msg = 'establishing connection to the {db_name} database' _log.info(log_msg.format(db_name=connection_config.db_name)) - try: - if connection_config.sslmode is None: - db = connect( - host=connection_config.host, - dbname=connection_config.db_name, - user=connection_config.user, - port=connection_config.port, - cursor_factory=RealDictCursor, - ) - else: - db = connect( - host=connection_config.host, - dbname=connection_config.db_name, - user=connection_config.user, - password=connection_config.password, - port=connection_config.port, - cursor_factory=RealDictCursor, - sslmode=connection_config.sslmode - ) - db.autocommit = autocommit - yield db - finally: - if db is not None: - db.close() + db = connect( + host=connection_config.host, + dbname=connection_config.db_name, + user=connection_config.user, + password=connection_config.password, + port=connection_config.port, + cursor_factory=connection_config.cursor_factory, + sslmode=connection_config.sslmode + ) + db.autocommit = connection_config.autocommit + + return db From 00be679ad4a7ae528b41fc1f9b9632f7b63ad714 Mon Sep 17 00:00:00 2001 From: Chris Veilleux Date: Wed, 22 May 2019 13:19:06 -0500 Subject: [PATCH 06/15] removed code that allocated a connection pool using the old method --- shared/selene/api/base_config.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/shared/selene/api/base_config.py b/shared/selene/api/base_config.py index b5a99bd8..c7ca06f2 100644 --- a/shared/selene/api/base_config.py +++ b/shared/selene/api/base_config.py @@ -88,10 +88,4 @@ def get_base_config(): error_msg = 'no configuration defined for the "{}" environment' raise APIConfigError(error_msg.format(environment_name)) - max_db_connections = os.environ.get('MAX_DB_CONNECTIONS', 20) - app_config.DB_CONNECTION_POOL = allocate_db_connection_pool( - db_connection_config, - max_db_connections - ) - return app_config From 7c96bde026c74b541ce6fc60dc2683c654fc18bd Mon Sep 17 00:00:00 2001 From: Chris Veilleux Date: Wed, 22 May 2019 13:22:45 -0500 Subject: [PATCH 07/15] fixed tests still using the old connection pool --- api/public/tests/features/steps/device_skills.py | 6 +++--- .../tests/features/steps/get_device_subscription.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/api/public/tests/features/steps/device_skills.py b/api/public/tests/features/steps/device_skills.py index 6a0c013e..63846c87 100644 --- a/api/public/tests/features/steps/device_skills.py +++ b/api/public/tests/features/steps/device_skills.py @@ -6,7 +6,7 @@ from hamcrest import assert_that, equal_to, not_none, is_not from selene.api.etag import ETagManager, device_skill_etag_key from selene.data.skill import SkillSettingRepository -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db skill = { 'skill_gid': 'wolfram-alpha|19.02', @@ -106,8 +106,8 @@ def update_skill(context): }] response = json.loads(context.upload_device_response.data) skill_id = response['uuid'] - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - SkillSettingRepository(db).update_device_skill_settings(skill_id, update_settings) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + SkillSettingRepository(db).update_device_skill_settings(skill_id, update_settings) @when('the skill settings is fetched') diff --git a/api/public/tests/features/steps/get_device_subscription.py b/api/public/tests/features/steps/get_device_subscription.py index 18c36f42..6543976f 100644 --- a/api/public/tests/features/steps/get_device_subscription.py +++ b/api/public/tests/features/steps/get_device_subscription.py @@ -7,7 +7,7 @@ from behave import when, then from hamcrest import assert_that, has_entry, equal_to from selene.data.account import AccountRepository, AccountMembership -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db @when('the subscription endpoint is called') @@ -42,9 +42,9 @@ def get_device_subscription(context): login = context.device_login device_id = login['uuid'] access_token = login['accessToken'] - headers=dict(Authorization='Bearer {token}'.format(token=access_token)) - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - AccountRepository(db).add_membership(context.account.id, membership) + headers = dict(Authorization='Bearer {token}'.format(token=access_token)) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + AccountRepository(db).add_membership(context.account.id, membership) context.subscription_response = context.client.get( '/v1/device/{uuid}/subscription'.format(uuid=device_id), headers=headers From 9d87bc0170f50bc533a769f43682dc3f37248b43 Mon Sep 17 00:00:00 2001 From: Matheus Lima Date: Wed, 22 May 2019 17:04:14 -0300 Subject: [PATCH 08/15] Fixing path in the metrics service --- api/public/public_api/endpoints/device_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/public/public_api/endpoints/device_metrics.py b/api/public/public_api/endpoints/device_metrics.py index a46425c8..69afe3b1 100644 --- a/api/public/public_api/endpoints/device_metrics.py +++ b/api/public/public_api/endpoints/device_metrics.py @@ -18,7 +18,7 @@ class MetricsService(object): deviceUuid=device_id, data=data ) - url = '{host}/{metric}'.format(host=self.metrics_service_host, metric=metric) + url = '{host}/metrics/{metric}'.format(host=self.metrics_service_host, metric=metric) requests.post(url, body) From dab7301556ff2f31fb876c3cf00fa39b78ee64d2 Mon Sep 17 00:00:00 2001 From: Matheus Lima Date: Wed, 22 May 2019 17:15:12 -0300 Subject: [PATCH 09/15] Fix typo --- api/public/public_api/endpoints/device_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/public/public_api/endpoints/device_metrics.py b/api/public/public_api/endpoints/device_metrics.py index 69afe3b1..ba7f38f5 100644 --- a/api/public/public_api/endpoints/device_metrics.py +++ b/api/public/public_api/endpoints/device_metrics.py @@ -18,7 +18,7 @@ class MetricsService(object): deviceUuid=device_id, data=data ) - url = '{host}/metrics/{metric}'.format(host=self.metrics_service_host, metric=metric) + url = '{host}/metric/{metric}'.format(host=self.metrics_service_host, metric=metric) requests.post(url, body) From 71df9c93130124af8d5e8ffcad4204708063e0f1 Mon Sep 17 00:00:00 2001 From: Chris Veilleux Date: Wed, 22 May 2019 15:41:30 -0500 Subject: [PATCH 10/15] modified base endpoint class and tests to use new connection pooling mechanism --- api/account/tests/features/environment.py | 16 +++--- .../tests/features/steps/authentication.py | 8 +-- .../tests/features/steps/new_account.py | 47 ++++++++-------- .../tests/features/steps/update_membership.py | 53 ++++++++++--------- shared/selene/api/base_config.py | 9 ---- shared/selene/api/base_endpoint.py | 6 +-- 6 files changed, 66 insertions(+), 73 deletions(-) diff --git a/api/account/tests/features/environment.py b/api/account/tests/features/environment.py index 5a36200f..155544a8 100644 --- a/api/account/tests/features/environment.py +++ b/api/account/tests/features/environment.py @@ -15,7 +15,7 @@ from selene.data.account import ( ) from selene.data.device import Geography, GeographyRepository from selene.util.cache import SeleneCache -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db @fixture @@ -32,10 +32,10 @@ def before_feature(context, _): def before_scenario(context, _): - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - _add_agreements(context, db) - _add_account(context, db) - _add_geography(context, db) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + _add_agreements(context, db) + _add_account(context, db) + _add_geography(context, db) def _add_agreements(context, db): @@ -91,9 +91,9 @@ def _add_geography(context, db): def after_scenario(context, _): - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - _delete_account(context, db) - _delete_agreements(context, db) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + _delete_account(context, db) + _delete_agreements(context, db) _clean_cache() diff --git a/api/account/tests/features/steps/authentication.py b/api/account/tests/features/steps/authentication.py index 3ca6e8df..db44ef1a 100644 --- a/api/account/tests/features/steps/authentication.py +++ b/api/account/tests/features/steps/authentication.py @@ -8,7 +8,7 @@ from selene.api.testing import ( ) from selene.data.account import AccountRepository from selene.util.auth import AuthenticationToken -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db @given('an authenticated user with an expired access token') @@ -37,9 +37,9 @@ def check_for_new_cookies(context): context.refresh_token, is_not(equal_to(context.old_refresh_token)) ) - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - acct_repository = AccountRepository(db) - account = acct_repository.get_account_by_id(context.account.id) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + acct_repository = AccountRepository(db) + account = acct_repository.get_account_by_id(context.account.id) refresh_token = AuthenticationToken( context.client_config['REFRESH_SECRET'], diff --git a/api/account/tests/features/steps/new_account.py b/api/account/tests/features/steps/new_account.py index 3b03f354..6fe07aa8 100644 --- a/api/account/tests/features/steps/new_account.py +++ b/api/account/tests/features/steps/new_account.py @@ -1,5 +1,6 @@ -import os +from binascii import b2a_base64 from datetime import date +import os import stripe from behave import given, then, when @@ -8,7 +9,7 @@ from hamcrest import assert_that, equal_to, is_in, none, not_none, starts_with from stripe.error import InvalidRequestError from selene.data.account import AccountRepository, PRIVACY_POLICY, TERMS_OF_USE -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db new_account_request = dict( username='barfoo', @@ -17,8 +18,8 @@ new_account_request = dict( login=dict( federatedPlatform=None, federatedToken=None, - userEnteredEmail='bar@mycroft.ai', - password='bar' + userEnteredEmail=b2a_base64(b'bar@mycroft.ai').decode(), + password=b2a_base64(b'bar').decode() ), support=dict(openDataset=True) ) @@ -58,33 +59,33 @@ def call_add_account_endpoint(context): context.response = context.client.post( '/api/account', data=json.dumps(context.new_account_request), - content_type='application_json' + content_type='application/json' ) @then('the account will be added to the system {membership_option}') def check_db_for_account(context, membership_option): - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - acct_repository = AccountRepository(db) - account = acct_repository.get_account_by_email('bar@mycroft.ai') - assert_that(account, not_none()) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + acct_repository = AccountRepository(db) + account = acct_repository.get_account_by_email('bar@mycroft.ai') + assert_that(account, not_none()) + assert_that( + account.email_address, equal_to('bar@mycroft.ai') + ) + assert_that(account.username, equal_to('barfoo')) + if membership_option == 'with a membership': + assert_that(account.membership.type, equal_to('Monthly Membership')) assert_that( - account.email_address, equal_to('bar@mycroft.ai') + account.membership.payment_account_id, + starts_with('cus') ) - assert_that(account.username, equal_to('barfoo')) - if membership_option == 'with a membership': - assert_that(account.membership.type, equal_to('Monthly Membership')) - assert_that( - account.membership.payment_account_id, - starts_with('cus') - ) - elif membership_option == 'without a membership': - assert_that(account.membership, none()) + elif membership_option == 'without a membership': + assert_that(account.membership, none()) - assert_that(len(account.agreements), equal_to(2)) - for agreement in account.agreements: - assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE))) - assert_that(agreement.accept_date, equal_to(str(date.today()))) + assert_that(len(account.agreements), equal_to(2)) + for agreement in account.agreements: + assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE))) + assert_that(agreement.accept_date, equal_to(str(date.today()))) @when('the account is deleted') diff --git a/api/account/tests/features/steps/update_membership.py b/api/account/tests/features/steps/update_membership.py index a719287e..a1ddd75d 100644 --- a/api/account/tests/features/steps/update_membership.py +++ b/api/account/tests/features/steps/update_membership.py @@ -1,5 +1,6 @@ -import json +from binascii import b2a_base64 from datetime import date +import json from behave import given, when, then from hamcrest import assert_that, equal_to, starts_with, none @@ -11,7 +12,9 @@ from selene.data.account import ( AccountAgreement, PRIVACY_POLICY ) -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db + +TEST_EMAIL_ADDRESS = 'test@mycroft.ai' new_account_request = dict( username='test', @@ -20,8 +23,8 @@ new_account_request = dict( login=dict( federatedPlatform=None, federatedToken=None, - userEnteredEmail='test@mycroft.ai', - password='12345678' + email=b2a_base64(b'test@mycroft.ai').decode(), + password=b2a_base64(b'12345678').decode() ), support=dict( openDataset=True, @@ -47,12 +50,12 @@ def create_account(context): AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today()) ] ) - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - acct_repository = AccountRepository(db) - account_id = acct_repository.add(context.account, 'foo') - context.account.id = account_id - generate_access_token(context) - generate_refresh_token(context) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + acct_repository = AccountRepository(db) + account_id = acct_repository.add(context.account, 'foo') + context.account.id = account_id + generate_access_token(context) + generate_refresh_token(context) @when('a monthly membership is added') @@ -66,16 +69,16 @@ def update_membership(context): context.response = context.client.patch( '/api/account', data=json.dumps(dict(membership=membership_data)), - content_type='application_json' + content_type='application/json' ) @when('the account is requested') def request_account(context): - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - context.response_account = AccountRepository(db).get_account_by_email( - 'test@mycroft.ai' - ) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + context.response_account = AccountRepository(db).get_account_by_email( + TEST_EMAIL_ADDRESS + ) @then('the account should have a monthly membership') @@ -98,16 +101,14 @@ def create_monthly_account(context): context.client.post( '/api/account', data=json.dumps(new_account_request), - content_type='application_json' + content_type='application/json' ) - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - account_repository = AccountRepository(db) - account = account_repository.get_account_by_email( - new_account_request['login']['userEnteredEmail'] - ) - context.account = account - generate_access_token(context) - generate_refresh_token(context) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + account_repository = AccountRepository(db) + account = account_repository.get_account_by_email(TEST_EMAIL_ADDRESS) + context.account = account + generate_access_token(context) + generate_refresh_token(context) @when('the membership is cancelled') @@ -119,7 +120,7 @@ def cancel_membership(context): context.client.patch( '/api/account', data=json.dumps(dict(membership=membership_data)), - content_type='application_json' + content_type='application/json' ) @@ -138,7 +139,7 @@ def change_to_yearly_account(context): context.client.patch( '/api/account', data=json.dumps(dict(membership=membership_data)), - content_type='application_json' + content_type='application/json' ) diff --git a/shared/selene/api/base_config.py b/shared/selene/api/base_config.py index c7ca06f2..125117fc 100644 --- a/shared/selene/api/base_config.py +++ b/shared/selene/api/base_config.py @@ -22,15 +22,6 @@ import os from selene.util.db import allocate_db_connection_pool, DatabaseConnectionConfig -db_connection_config = DatabaseConnectionConfig( - host=os.environ['DB_HOST'], - db_name=os.environ['DB_NAME'], - password=os.environ['DB_PASSWORD'], - port=os.environ.get('DB_PORT', 5432), - user=os.environ['DB_USER'], - sslmode=os.environ.get('DB_SSLMODE') -) - class APIConfigError(Exception): pass diff --git a/shared/selene/api/base_endpoint.py b/shared/selene/api/base_endpoint.py index 6b5a4e68..779e420a 100644 --- a/shared/selene/api/base_endpoint.py +++ b/shared/selene/api/base_endpoint.py @@ -6,7 +6,7 @@ from flask.views import MethodView from selene.data.account import Account, AccountRepository from selene.util.auth import AuthenticationError, AuthenticationToken -from selene.util.db import get_db_connection_from_pool +from selene.util.db import connect_to_db ACCESS_TOKEN_COOKIE_NAME = 'seleneAccess' FIFTEEN_MINUTES = 900 @@ -42,8 +42,8 @@ class SeleneEndpoint(MethodView): @property def db(self): if 'db' not in global_context: - global_context.db = get_db_connection_from_pool( - current_app.config['DB_CONNECTION_POOL'] + global_context.db = connect_to_db( + current_app.config['DB_CONNECTION_CONFIG'] ) return global_context.db From 2069df6332e630f8738a5798e63d5f9a619b200d Mon Sep 17 00:00:00 2001 From: Chris Veilleux Date: Wed, 22 May 2019 17:06:35 -0500 Subject: [PATCH 11/15] fixed a typo in one of the request field names --- api/market/market_api/endpoints/skill_install.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/market/market_api/endpoints/skill_install.py b/api/market/market_api/endpoints/skill_install.py index bdf37f5f..5f55e20c 100644 --- a/api/market/market_api/endpoints/skill_install.py +++ b/api/market/market_api/endpoints/skill_install.py @@ -45,7 +45,7 @@ class SkillInstallEndpoint(SeleneEndpoint): def _validate_request(self): install_request = InstallRequest() install_request.setting_section = self.request.json['section'] - install_request.skill_name = self.request.json['skillName'] + install_request.skill_name = self.request.json['skill_name'] install_request.validate() def _get_installer_settings(self): From 00ecba1166a591f96e407b7dcfc26323ec32a5c4 Mon Sep 17 00:00:00 2001 From: Matheus Lima Date: Wed, 22 May 2019 20:05:31 -0300 Subject: [PATCH 12/15] Refreshing device session using either refresh token or device id --- .../endpoints/device_refresh_token.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/api/public/public_api/endpoints/device_refresh_token.py b/api/public/public_api/endpoints/device_refresh_token.py index e42f9667..af76fb57 100644 --- a/api/public/public_api/endpoints/device_refresh_token.py +++ b/api/public/public_api/endpoints/device_refresh_token.py @@ -22,10 +22,20 @@ class DeviceRefreshTokenEndpoint(PublicEndpoint): if token_header.startswith('Bearer '): refresh = token_header[len('Bearer '):] session = self._refresh_session_token(refresh) + # Trying to fetch a session using the refresh token if session: response = session, HTTPStatus.OK else: - response = '', HTTPStatus.UNAUTHORIZED + device = self.request.headers.get('Device') + if device: + # trying to fetch a session using the device uuid + session = self._refresh_session_token_device(device) + if session: + response = session, HTTPStatus.OK + else: + response = '', HTTPStatus.UNAUTHORIZED + else: + response = '', HTTPStatus.UNAUTHORIZED else: response = '', HTTPStatus.UNAUTHORIZED return response @@ -38,3 +48,12 @@ class DeviceRefreshTokenEndpoint(PublicEndpoint): device_id = old_login['uuid'] self.cache.delete(refresh_key) return generate_device_login(device_id, self.cache) + + def _refresh_session_token_device(self, device: str): + refresh_key = 'device.session:{}'.format(device) + session = self.cache.get(refresh_key) + if session: + old_login = json.loads(session) + device_id = old_login['uuid'] + self.cache.delete(refresh_key) + return generate_device_login(device_id, self.cache) From 2fbdc1e9c88060fa1e178f086010c4bddf23fb16 Mon Sep 17 00:00:00 2001 From: Chris Veilleux Date: Wed, 22 May 2019 18:53:11 -0500 Subject: [PATCH 13/15] fixed a bug in how the installer skill settings were being populated and added some comments --- .../market_api/endpoints/skill_install.py | 56 +++++++++++++++---- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/api/market/market_api/endpoints/skill_install.py b/api/market/market_api/endpoints/skill_install.py index 5f55e20c..fb389518 100644 --- a/api/market/market_api/endpoints/skill_install.py +++ b/api/market/market_api/endpoints/skill_install.py @@ -1,3 +1,9 @@ +""" +Marketplace endpoint to add or remove a skill + +This endpoint configures the install skill on a user's device(s) to add or +remove the skill. +""" from http import HTTPStatus from logging import getLogger from typing import List @@ -6,7 +12,11 @@ from schematics import Model from schematics.types import StringType from selene.api import SeleneEndpoint -from selene.data.skill import AccountSkillSetting, SkillSettingRepository +from selene.data.skill import ( + AccountSkillSetting, + SkillDisplayRepository, + SkillSettingRepository +) INSTALL_SECTION = 'to_install' UNINSTALL_SECTION = 'to_remove' @@ -15,17 +25,16 @@ _log = getLogger(__package__) class InstallRequest(Model): + """Defines the expected state of the request JSON data""" setting_section = StringType( required=True, choices=[INSTALL_SECTION, UNINSTALL_SECTION] ) - skill_name = StringType(required=True) + skill_display_id = StringType(required=True) class SkillInstallEndpoint(SeleneEndpoint): - """ - Install a skill on user device(s). - """ + """Install a skill on user device(s).""" def __init__(self): super(SkillInstallEndpoint, self).__init__() self.device_uuid: str = None @@ -34,43 +43,70 @@ class SkillInstallEndpoint(SeleneEndpoint): self.installer_update_response = None def put(self): + """Handle an HTTP PUT request""" self._authenticate() self._validate_request() + skill_install_name = self._get_install_name() self._get_installer_settings() - self._apply_update() + self._apply_update(skill_install_name) self.response = (self.installer_update_response, HTTPStatus.OK) return self.response def _validate_request(self): + """Ensure the data passed in the request is as expected. + + :raises schematics.exceptions.ValidationError if the validation fails + """ install_request = InstallRequest() install_request.setting_section = self.request.json['section'] - install_request.skill_name = self.request.json['skill_name'] + install_request.skill_display_id = self.request.json['skillDisplayId'] install_request.validate() + def _get_install_name(self) -> str: + """Get the skill name used by the installer skill from the DB + + The installer skill expects the skill name found in the "name" field + of the skill display JSON. + """ + display_repo = SkillDisplayRepository(self.db) + skill_display = display_repo.get_display_data_for_skill( + self.request.json['skillDisplayId'] + ) + + return skill_display.display_data['name'] + def _get_installer_settings(self): + """Get the current value of the installer skill's settings""" settings_repo = SkillSettingRepository(self.db) self.installer_settings = settings_repo.get_installer_settings( self.account.id ) - def _apply_update(self): + def _apply_update(self, skill_install_name: str): + """Add the skill in the request to the installer skill settings. + + This is designed to change the installer skill settings for all + devices associated with an account. It will be updated in the + future to target specific devices. + """ for settings in self.installer_settings: if self.request.json['section'] == INSTALL_SECTION: to_install = settings.settings_values.get(INSTALL_SECTION, []) to_install.append( - dict(name=self.request.json['skill_name']) + dict(name=skill_install_name) ) settings.settings_values[INSTALL_SECTION] = to_install else: to_remove = settings.settings_values.get(UNINSTALL_SECTION, []) to_remove.append( - dict(name=self.request.json['skill_name']) + dict(name=skill_install_name) ) settings.settings_values[UNINSTALL_SECTION] = to_remove self._update_skill_settings(settings) def _update_skill_settings(self, settings): + """Update the DB with the new installer skill settings.""" settings_repo = SkillSettingRepository(self.db) settings_repo.update_device_skill_settings( self.account.id, From 26e7b083bb53b7333a20924b73f4070c17a225e3 Mon Sep 17 00:00:00 2001 From: Chris Veilleux Date: Wed, 22 May 2019 19:15:32 -0500 Subject: [PATCH 14/15] fixed a bug in the skill repository that was using the wrong field name for skill name --- shared/selene/data/skill/repository/display.py | 2 +- shared/selene/data/skill/repository/setting.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/shared/selene/data/skill/repository/display.py b/shared/selene/data/skill/repository/display.py index 598cb415..51cec92e 100644 --- a/shared/selene/data/skill/repository/display.py +++ b/shared/selene/data/skill/repository/display.py @@ -12,7 +12,7 @@ class SkillDisplayRepository(RepositoryBase): sql_file_name='get_display_data_for_skills.sql' ) - def get_display_data_for_skill(self, skill_display_id): + def get_display_data_for_skill(self, skill_display_id) -> SkillDisplay: return self._select_one_into_dataclass( dataclass=SkillDisplay, sql_file_name='get_display_data_for_skill.sql', diff --git a/shared/selene/data/skill/repository/setting.py b/shared/selene/data/skill/repository/setting.py index f0f08a22..23c10f5b 100644 --- a/shared/selene/data/skill/repository/setting.py +++ b/shared/selene/data/skill/repository/setting.py @@ -34,12 +34,12 @@ class SkillSettingRepository(RepositoryBase): return skill_settings - def get_installer_settings(self, account_id: str): + def get_installer_settings(self, account_id: str) -> List[AccountSkillSetting]: skill_repo = SkillRepository(self.db) skills = skill_repo.get_skills_for_account(account_id) installer_skill_id = None for skill in skills: - if skill.name == 'mycroft_installer': + if skill.display_name == 'Installer': installer_skill_id = skill.id skill_settings = None From ac6608ceca6cfc0b38541d0dd614268cd0a31123 Mon Sep 17 00:00:00 2001 From: Chris Veilleux Date: Wed, 22 May 2019 19:56:44 -0500 Subject: [PATCH 15/15] removed remaining remnants of the old connection pooling --- .../tests/features/steps/add_device.py | 8 +++--- .../tests/features/steps/new_account.py | 8 +++--- api/sso/tests/features/environment.py | 18 ++++++------ shared/selene/api/endpoints/agreements.py | 20 ++++++------- shared/selene/api/etag.py | 28 +++++++++---------- shared/selene/api/testing/authentication.py | 8 +++--- 6 files changed, 45 insertions(+), 45 deletions(-) diff --git a/api/account/tests/features/steps/add_device.py b/api/account/tests/features/steps/add_device.py index fef8b1c2..5d2f3641 100644 --- a/api/account/tests/features/steps/add_device.py +++ b/api/account/tests/features/steps/add_device.py @@ -5,7 +5,7 @@ from hamcrest import assert_that, equal_to, has_key, none, not_none from selene.data.device import DeviceRepository from selene.util.cache import SeleneCache -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db @given('a device pairing code') @@ -57,9 +57,9 @@ def validate_pairing_code_removal(context): @then('the device is added to the database') def validate_response(context): device_id = context.response.data.decode() - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - device_repository = DeviceRepository(db) - device = device_repository.get_device_by_id(device_id) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + device_repository = DeviceRepository(db) + device = device_repository.get_device_by_id(device_id) assert_that(device, not_none()) assert_that(device.name, equal_to('home')) diff --git a/api/account/tests/features/steps/new_account.py b/api/account/tests/features/steps/new_account.py index 6fe07aa8..12a2ecd9 100644 --- a/api/account/tests/features/steps/new_account.py +++ b/api/account/tests/features/steps/new_account.py @@ -90,10 +90,10 @@ def check_db_for_account(context, membership_option): @when('the account is deleted') def account_deleted(context): - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - acct_repository = AccountRepository(db) - account = acct_repository.get_account_by_email('bar@mycroft.ai') - context.stripe_id = account.membership.payment_id + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + acct_repository = AccountRepository(db) + account = acct_repository.get_account_by_email('bar@mycroft.ai') + context.stripe_id = account.membership.payment_id context.response = context.client.delete('/api/account') diff --git a/api/sso/tests/features/environment.py b/api/sso/tests/features/environment.py index 7350eaa5..eafbed7f 100644 --- a/api/sso/tests/features/environment.py +++ b/api/sso/tests/features/environment.py @@ -13,7 +13,7 @@ from selene.data.account import ( AgreementRepository, PRIVACY_POLICY ) -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db @fixture @@ -32,9 +32,9 @@ def before_feature(context, _): def before_scenario(context, _): - with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: - _add_agreement(context, db) - _add_account(context, db) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + _add_agreement(context, db) + _add_account(context, db) def _add_agreement(context, db): @@ -72,8 +72,8 @@ def _add_account(context, db): def after_scenario(context, _): - with get_db_connection(context.db_pool) as db: - acct_repository = AccountRepository(db) - acct_repository.remove(context.account) - agreement_repository = AgreementRepository(db) - agreement_repository.remove(context.agreement, testing=True) + db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + acct_repository = AccountRepository(db) + acct_repository.remove(context.account) + agreement_repository = AgreementRepository(db) + agreement_repository.remove(context.agreement, testing=True) diff --git a/shared/selene/api/endpoints/agreements.py b/shared/selene/api/endpoints/agreements.py index c85cb202..99960568 100644 --- a/shared/selene/api/endpoints/agreements.py +++ b/shared/selene/api/endpoints/agreements.py @@ -3,7 +3,7 @@ from dataclasses import asdict from http import HTTPStatus from selene.data.account import AgreementRepository -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db from ..base_endpoint import SeleneEndpoint @@ -16,14 +16,14 @@ class AgreementsEndpoint(SeleneEndpoint): def get(self, agreement_type): """Process HTTP GET request for an agreement.""" - with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: - agreement_repository = AgreementRepository(db) - agreement = agreement_repository.get_active_for_type( - self.agreement_types[agreement_type] - ) - if agreement is not None: - agreement = asdict(agreement) - del(agreement['effective_date']) - self.response = agreement, HTTPStatus.OK + db = connect_to_db(self.config['DB_CONNECTION_CONFIG']) + agreement_repository = AgreementRepository(db) + agreement = agreement_repository.get_active_for_type( + self.agreement_types[agreement_type] + ) + if agreement is not None: + agreement = asdict(agreement) + del(agreement['effective_date']) + self.response = agreement, HTTPStatus.OK return self.response diff --git a/shared/selene/api/etag.py b/shared/selene/api/etag.py index 9cf4c3e8..8ade0fb4 100644 --- a/shared/selene/api/etag.py +++ b/shared/selene/api/etag.py @@ -3,7 +3,7 @@ import string from selene.data.device import DeviceRepository from selene.util.cache import SeleneCache -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db def device_etag_key(device_id: str): @@ -29,7 +29,7 @@ class ETagManager(object): def __init__(self, cache: SeleneCache, config: dict): self.cache: SeleneCache = cache - self.db_connection_pool = config['DB_CONNECTION_POOL'] + self.db_connection_config = config['DB_CONNECTION_CONFIG'] def get(self, key: str) -> str: """Generate a etag with 32 random chars and store it into a given key @@ -60,10 +60,10 @@ class ETagManager(object): def expire_device_setting_etag_by_account_id(self, account_id: str): """Expire the settings' etags for all devices from a given account. Used when the settings are updated at account level""" - with get_db_connection(self.db_connection_pool) as db: - devices = DeviceRepository(db).get_devices_by_account_id(account_id) - for device in devices: - self.expire_device_setting_etag_by_device_id(device.id) + db = connect_to_db(self.db_connection_config) + devices = DeviceRepository(db).get_devices_by_account_id(account_id) + for device in devices: + self.expire_device_setting_etag_by_device_id(device.id) def expire_device_location_etag_by_device_id(self, device_id: str): """Expire the etag associate with the device's location entity @@ -73,10 +73,10 @@ class ETagManager(object): def expire_device_location_etag_by_account_id(self, account_id: str): """Expire the locations' etag fpr açç device for a given acccount :param account_id: account uuid""" - with get_db_connection(self.db_connection_pool) as db: - devices = DeviceRepository(db).get_devices_by_account_id(account_id) - for device in devices: - self.expire_device_location_etag_by_device_id(device.id) + db = connect_to_db(self.db_connection_config) + devices = DeviceRepository(db).get_devices_by_account_id(account_id) + for device in devices: + self.expire_device_location_etag_by_device_id(device.id) def expire_skill_etag_by_device_id(self, device_id): """Expire the locations' etag for a given device @@ -84,7 +84,7 @@ class ETagManager(object): self._expire(device_skill_etag_key(device_id)) def expire_skill_etag_by_account_id(self, account_id): - with get_db_connection(self.db_connection_pool) as db: - devices = DeviceRepository(db).get_devices_by_account_id(account_id) - for device in devices: - self.expire_skill_etag_by_device_id(device.id) + db = connect_to_db(self.db_connection_config) + devices = DeviceRepository(db).get_devices_by_account_id(account_id) + for device in devices: + self.expire_skill_etag_by_device_id(device.id) diff --git a/shared/selene/api/testing/authentication.py b/shared/selene/api/testing/authentication.py index 9be2283a..0b70f5db 100644 --- a/shared/selene/api/testing/authentication.py +++ b/shared/selene/api/testing/authentication.py @@ -2,7 +2,7 @@ from hamcrest import assert_that, equal_to, has_item from selene.data.account import Account, AccountRepository from selene.util.auth import AuthenticationToken -from selene.util.db import get_db_connection +from selene.util.db import connect_to_db ACCESS_TOKEN_COOKIE_KEY = 'seleneAccess' ONE_MINUTE = 60 @@ -77,8 +77,8 @@ def _parse_cookie(cookie: str) -> dict: def get_account(context) -> Account: - with get_db_connection(context.db_pool) as db: - acct_repository = AccountRepository(db) - account = acct_repository.get_account_by_id(context.account.id) + db = connect_to_db(context.client['DB_CONNECTION_CONFIG']) + acct_repository = AccountRepository(db) + account = acct_repository.get_account_by_id(context.account.id) return account