Merge remote-tracking branch 'origin/dev' into dev

pull/157/head
Chris Veilleux 2019-05-22 20:05:54 -05:00
commit 06b8bd4eb1
29 changed files with 453 additions and 255 deletions

View File

@ -15,7 +15,7 @@ from selene.data.account import (
) )
from selene.data.device import Geography, GeographyRepository from selene.data.device import Geography, GeographyRepository
from selene.util.cache import SeleneCache from selene.util.cache import SeleneCache
from selene.util.db import get_db_connection from selene.util.db import connect_to_db
@fixture @fixture
@ -32,10 +32,10 @@ def before_feature(context, _):
def before_scenario(context, _): def before_scenario(context, _):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
_add_agreements(context, db) _add_agreements(context, db)
_add_account(context, db) _add_account(context, db)
_add_geography(context, db) _add_geography(context, db)
def _add_agreements(context, db): def _add_agreements(context, db):
@ -91,9 +91,9 @@ def _add_geography(context, db):
def after_scenario(context, _): def after_scenario(context, _):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
_delete_account(context, db) _delete_account(context, db)
_delete_agreements(context, db) _delete_agreements(context, db)
_clean_cache() _clean_cache()

View File

@ -5,7 +5,7 @@ from hamcrest import assert_that, equal_to, has_key, none, not_none
from selene.data.device import DeviceRepository from selene.data.device import DeviceRepository
from selene.util.cache import SeleneCache from selene.util.cache import SeleneCache
from selene.util.db import get_db_connection from selene.util.db import connect_to_db
@given('a device pairing code') @given('a device pairing code')
@ -57,9 +57,9 @@ def validate_pairing_code_removal(context):
@then('the device is added to the database') @then('the device is added to the database')
def validate_response(context): def validate_response(context):
device_id = context.response.data.decode() device_id = context.response.data.decode()
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
device_repository = DeviceRepository(db) device_repository = DeviceRepository(db)
device = device_repository.get_device_by_id(device_id) device = device_repository.get_device_by_id(device_id)
assert_that(device, not_none()) assert_that(device, not_none())
assert_that(device.name, equal_to('home')) assert_that(device.name, equal_to('home'))

View File

@ -8,7 +8,7 @@ from selene.api.testing import (
) )
from selene.data.account import AccountRepository from selene.data.account import AccountRepository
from selene.util.auth import AuthenticationToken from selene.util.auth import AuthenticationToken
from selene.util.db import get_db_connection from selene.util.db import connect_to_db
@given('an authenticated user with an expired access token') @given('an authenticated user with an expired access token')
@ -37,9 +37,9 @@ def check_for_new_cookies(context):
context.refresh_token, context.refresh_token,
is_not(equal_to(context.old_refresh_token)) is_not(equal_to(context.old_refresh_token))
) )
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db) acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_id(context.account.id) account = acct_repository.get_account_by_id(context.account.id)
refresh_token = AuthenticationToken( refresh_token = AuthenticationToken(
context.client_config['REFRESH_SECRET'], context.client_config['REFRESH_SECRET'],

View File

@ -1,5 +1,6 @@
import os from binascii import b2a_base64
from datetime import date from datetime import date
import os
import stripe import stripe
from behave import given, then, when from behave import given, then, when
@ -8,7 +9,7 @@ from hamcrest import assert_that, equal_to, is_in, none, not_none, starts_with
from stripe.error import InvalidRequestError from stripe.error import InvalidRequestError
from selene.data.account import AccountRepository, PRIVACY_POLICY, TERMS_OF_USE from selene.data.account import AccountRepository, PRIVACY_POLICY, TERMS_OF_USE
from selene.util.db import get_db_connection from selene.util.db import connect_to_db
new_account_request = dict( new_account_request = dict(
username='barfoo', username='barfoo',
@ -17,8 +18,8 @@ new_account_request = dict(
login=dict( login=dict(
federatedPlatform=None, federatedPlatform=None,
federatedToken=None, federatedToken=None,
userEnteredEmail='bar@mycroft.ai', userEnteredEmail=b2a_base64(b'bar@mycroft.ai').decode(),
password='bar' password=b2a_base64(b'bar').decode()
), ),
support=dict(openDataset=True) support=dict(openDataset=True)
) )
@ -58,41 +59,41 @@ def call_add_account_endpoint(context):
context.response = context.client.post( context.response = context.client.post(
'/api/account', '/api/account',
data=json.dumps(context.new_account_request), data=json.dumps(context.new_account_request),
content_type='application_json' content_type='application/json'
) )
@then('the account will be added to the system {membership_option}') @then('the account will be added to the system {membership_option}')
def check_db_for_account(context, membership_option): def check_db_for_account(context, membership_option):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db) acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_email('bar@mycroft.ai') account = acct_repository.get_account_by_email('bar@mycroft.ai')
assert_that(account, not_none()) assert_that(account, not_none())
assert_that(
account.email_address, equal_to('bar@mycroft.ai')
)
assert_that(account.username, equal_to('barfoo'))
if membership_option == 'with a membership':
assert_that(account.membership.type, equal_to('Monthly Membership'))
assert_that( assert_that(
account.email_address, equal_to('bar@mycroft.ai') account.membership.payment_account_id,
starts_with('cus')
) )
assert_that(account.username, equal_to('barfoo')) elif membership_option == 'without a membership':
if membership_option == 'with a membership': assert_that(account.membership, none())
assert_that(account.membership.type, equal_to('Monthly Membership'))
assert_that(
account.membership.payment_account_id,
starts_with('cus')
)
elif membership_option == 'without a membership':
assert_that(account.membership, none())
assert_that(len(account.agreements), equal_to(2)) assert_that(len(account.agreements), equal_to(2))
for agreement in account.agreements: for agreement in account.agreements:
assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE))) assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE)))
assert_that(agreement.accept_date, equal_to(str(date.today()))) assert_that(agreement.accept_date, equal_to(str(date.today())))
@when('the account is deleted') @when('the account is deleted')
def account_deleted(context): def account_deleted(context):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db) acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_email('bar@mycroft.ai') account = acct_repository.get_account_by_email('bar@mycroft.ai')
context.stripe_id = account.membership.payment_id context.stripe_id = account.membership.payment_id
context.response = context.client.delete('/api/account') context.response = context.client.delete('/api/account')

View File

@ -1,5 +1,6 @@
import json from binascii import b2a_base64
from datetime import date from datetime import date
import json
from behave import given, when, then from behave import given, when, then
from hamcrest import assert_that, equal_to, starts_with, none from hamcrest import assert_that, equal_to, starts_with, none
@ -11,7 +12,9 @@ from selene.data.account import (
AccountAgreement, AccountAgreement,
PRIVACY_POLICY PRIVACY_POLICY
) )
from selene.util.db import get_db_connection from selene.util.db import connect_to_db
TEST_EMAIL_ADDRESS = 'test@mycroft.ai'
new_account_request = dict( new_account_request = dict(
username='test', username='test',
@ -20,8 +23,8 @@ new_account_request = dict(
login=dict( login=dict(
federatedPlatform=None, federatedPlatform=None,
federatedToken=None, federatedToken=None,
userEnteredEmail='test@mycroft.ai', email=b2a_base64(b'test@mycroft.ai').decode(),
password='12345678' password=b2a_base64(b'12345678').decode()
), ),
support=dict( support=dict(
openDataset=True, openDataset=True,
@ -47,12 +50,12 @@ def create_account(context):
AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today()) AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today())
] ]
) )
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db) acct_repository = AccountRepository(db)
account_id = acct_repository.add(context.account, 'foo') account_id = acct_repository.add(context.account, 'foo')
context.account.id = account_id context.account.id = account_id
generate_access_token(context) generate_access_token(context)
generate_refresh_token(context) generate_refresh_token(context)
@when('a monthly membership is added') @when('a monthly membership is added')
@ -66,16 +69,16 @@ def update_membership(context):
context.response = context.client.patch( context.response = context.client.patch(
'/api/account', '/api/account',
data=json.dumps(dict(membership=membership_data)), data=json.dumps(dict(membership=membership_data)),
content_type='application_json' content_type='application/json'
) )
@when('the account is requested') @when('the account is requested')
def request_account(context): def request_account(context):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
context.response_account = AccountRepository(db).get_account_by_email( context.response_account = AccountRepository(db).get_account_by_email(
'test@mycroft.ai' TEST_EMAIL_ADDRESS
) )
@then('the account should have a monthly membership') @then('the account should have a monthly membership')
@ -98,16 +101,14 @@ def create_monthly_account(context):
context.client.post( context.client.post(
'/api/account', '/api/account',
data=json.dumps(new_account_request), data=json.dumps(new_account_request),
content_type='application_json' content_type='application/json'
) )
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
account_repository = AccountRepository(db) account_repository = AccountRepository(db)
account = account_repository.get_account_by_email( account = account_repository.get_account_by_email(TEST_EMAIL_ADDRESS)
new_account_request['login']['userEnteredEmail'] context.account = account
) generate_access_token(context)
context.account = account generate_refresh_token(context)
generate_access_token(context)
generate_refresh_token(context)
@when('the membership is cancelled') @when('the membership is cancelled')
@ -119,7 +120,7 @@ def cancel_membership(context):
context.client.patch( context.client.patch(
'/api/account', '/api/account',
data=json.dumps(dict(membership=membership_data)), data=json.dumps(dict(membership=membership_data)),
content_type='application_json' content_type='application/json'
) )
@ -138,7 +139,7 @@ def change_to_yearly_account(context):
context.client.patch( context.client.patch(
'/api/account', '/api/account',
data=json.dumps(dict(membership=membership_data)), data=json.dumps(dict(membership=membership_data)),
content_type='application_json' content_type='application/json'
) )

View File

@ -1,3 +1,9 @@
"""
Marketplace endpoint to add or remove a skill
This endpoint configures the install skill on a user's device(s) to add or
remove the skill.
"""
from http import HTTPStatus from http import HTTPStatus
from logging import getLogger from logging import getLogger
from typing import List from typing import List
@ -6,7 +12,11 @@ from schematics import Model
from schematics.types import StringType from schematics.types import StringType
from selene.api import SeleneEndpoint from selene.api import SeleneEndpoint
from selene.data.skill import AccountSkillSetting, SkillSettingRepository from selene.data.skill import (
AccountSkillSetting,
SkillDisplayRepository,
SkillSettingRepository
)
INSTALL_SECTION = 'to_install' INSTALL_SECTION = 'to_install'
UNINSTALL_SECTION = 'to_remove' UNINSTALL_SECTION = 'to_remove'
@ -15,17 +25,16 @@ _log = getLogger(__package__)
class InstallRequest(Model): class InstallRequest(Model):
"""Defines the expected state of the request JSON data"""
setting_section = StringType( setting_section = StringType(
required=True, required=True,
choices=[INSTALL_SECTION, UNINSTALL_SECTION] choices=[INSTALL_SECTION, UNINSTALL_SECTION]
) )
skill_name = StringType(required=True) skill_display_id = StringType(required=True)
class SkillInstallEndpoint(SeleneEndpoint): class SkillInstallEndpoint(SeleneEndpoint):
""" """Install a skill on user device(s)."""
Install a skill on user device(s).
"""
def __init__(self): def __init__(self):
super(SkillInstallEndpoint, self).__init__() super(SkillInstallEndpoint, self).__init__()
self.device_uuid: str = None self.device_uuid: str = None
@ -34,43 +43,70 @@ class SkillInstallEndpoint(SeleneEndpoint):
self.installer_update_response = None self.installer_update_response = None
def put(self): def put(self):
"""Handle an HTTP PUT request"""
self._authenticate() self._authenticate()
self._validate_request() self._validate_request()
skill_install_name = self._get_install_name()
self._get_installer_settings() self._get_installer_settings()
self._apply_update() self._apply_update(skill_install_name)
self.response = (self.installer_update_response, HTTPStatus.OK) self.response = (self.installer_update_response, HTTPStatus.OK)
return self.response return self.response
def _validate_request(self): def _validate_request(self):
"""Ensure the data passed in the request is as expected.
:raises schematics.exceptions.ValidationError if the validation fails
"""
install_request = InstallRequest() install_request = InstallRequest()
install_request.setting_section = self.request.json['section'] install_request.setting_section = self.request.json['section']
install_request.skill_name = self.request.json['skillName'] install_request.skill_display_id = self.request.json['skillDisplayId']
install_request.validate() install_request.validate()
def _get_install_name(self) -> str:
"""Get the skill name used by the installer skill from the DB
The installer skill expects the skill name found in the "name" field
of the skill display JSON.
"""
display_repo = SkillDisplayRepository(self.db)
skill_display = display_repo.get_display_data_for_skill(
self.request.json['skillDisplayId']
)
return skill_display.display_data['name']
def _get_installer_settings(self): def _get_installer_settings(self):
"""Get the current value of the installer skill's settings"""
settings_repo = SkillSettingRepository(self.db) settings_repo = SkillSettingRepository(self.db)
self.installer_settings = settings_repo.get_installer_settings( self.installer_settings = settings_repo.get_installer_settings(
self.account.id self.account.id
) )
def _apply_update(self): def _apply_update(self, skill_install_name: str):
"""Add the skill in the request to the installer skill settings.
This is designed to change the installer skill settings for all
devices associated with an account. It will be updated in the
future to target specific devices.
"""
for settings in self.installer_settings: for settings in self.installer_settings:
if self.request.json['section'] == INSTALL_SECTION: if self.request.json['section'] == INSTALL_SECTION:
to_install = settings.settings_values.get(INSTALL_SECTION, []) to_install = settings.settings_values.get(INSTALL_SECTION, [])
to_install.append( to_install.append(
dict(name=self.request.json['skill_name']) dict(name=skill_install_name)
) )
settings.settings_values[INSTALL_SECTION] = to_install settings.settings_values[INSTALL_SECTION] = to_install
else: else:
to_remove = settings.settings_values.get(UNINSTALL_SECTION, []) to_remove = settings.settings_values.get(UNINSTALL_SECTION, [])
to_remove.append( to_remove.append(
dict(name=self.request.json['skill_name']) dict(name=skill_install_name)
) )
settings.settings_values[UNINSTALL_SECTION] = to_remove settings.settings_values[UNINSTALL_SECTION] = to_remove
self._update_skill_settings(settings) self._update_skill_settings(settings)
def _update_skill_settings(self, settings): def _update_skill_settings(self, settings):
"""Update the DB with the new installer skill settings."""
settings_repo = SkillSettingRepository(self.db) settings_repo = SkillSettingRepository(self.db)
settings_repo.update_device_skill_settings( settings_repo.update_device_skill_settings(
self.account.id, self.account.id,

View File

@ -28,6 +28,7 @@ class DeviceEndpoint(PublicEndpoint):
if device is not None: if device is not None:
response_data = dict( response_data = dict(
uuid=device.id,
name=device.name, name=device.name,
description=device.placement, description=device.placement,
coreVersion=device.core_version, coreVersion=device.core_version,

View File

@ -18,7 +18,7 @@ class MetricsService(object):
deviceUuid=device_id, deviceUuid=device_id,
data=data data=data
) )
url = '{host}/{metric}'.format(host=self.metrics_service_host, metric=metric) url = '{host}/metric/{metric}'.format(host=self.metrics_service_host, metric=metric)
requests.post(url, body) requests.post(url, body)

View File

@ -4,7 +4,6 @@ import requests
from selene.api import PublicEndpoint from selene.api import PublicEndpoint
from selene.data.account import AccountRepository from selene.data.account import AccountRepository
from selene.util.db import get_db_connection
class OauthServiceEndpoint(PublicEndpoint): class OauthServiceEndpoint(PublicEndpoint):
@ -14,8 +13,7 @@ class OauthServiceEndpoint(PublicEndpoint):
self.oauth_service_host = os.environ['OAUTH_BASE_URL'] self.oauth_service_host = os.environ['OAUTH_BASE_URL']
def get(self, device_id, credentials, oauth_path): def get(self, device_id, credentials, oauth_path):
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: account = AccountRepository(self.db).get_account_by_device_id(device_id)
account = AccountRepository(db).get_account_by_device_id(device_id)
uuid = account.id uuid = account.id
url = '{host}/auth/{credentials}/{oauth_path}'.format( url = '{host}/auth/{credentials}/{oauth_path}'.format(
host=self.oauth_service_host, host=self.oauth_service_host,

View File

@ -22,10 +22,20 @@ class DeviceRefreshTokenEndpoint(PublicEndpoint):
if token_header.startswith('Bearer '): if token_header.startswith('Bearer '):
refresh = token_header[len('Bearer '):] refresh = token_header[len('Bearer '):]
session = self._refresh_session_token(refresh) session = self._refresh_session_token(refresh)
# Trying to fetch a session using the refresh token
if session: if session:
response = session, HTTPStatus.OK response = session, HTTPStatus.OK
else: else:
response = '', HTTPStatus.UNAUTHORIZED device = self.request.headers.get('Device')
if device:
# trying to fetch a session using the device uuid
session = self._refresh_session_token_device(device)
if session:
response = session, HTTPStatus.OK
else:
response = '', HTTPStatus.UNAUTHORIZED
else:
response = '', HTTPStatus.UNAUTHORIZED
else: else:
response = '', HTTPStatus.UNAUTHORIZED response = '', HTTPStatus.UNAUTHORIZED
return response return response
@ -38,3 +48,12 @@ class DeviceRefreshTokenEndpoint(PublicEndpoint):
device_id = old_login['uuid'] device_id = old_login['uuid']
self.cache.delete(refresh_key) self.cache.delete(refresh_key)
return generate_device_login(device_id, self.cache) return generate_device_login(device_id, self.cache)
def _refresh_session_token_device(self, device: str):
refresh_key = 'device.session:{}'.format(device)
session = self.cache.get(refresh_key)
if session:
old_login = json.loads(session)
device_id = old_login['uuid']
self.cache.delete(refresh_key)
return generate_device_login(device_id, self.cache)

View File

@ -27,6 +27,7 @@ class SkillManifest(Model):
installed = DateTimeType() installed = DateTimeType()
updated = DateTimeType() updated = DateTimeType()
update = DateTimeType() update = DateTimeType()
skill_gid = StringType()
class SkillJson(Model): class SkillJson(Model):

View File

@ -21,7 +21,7 @@ from selene.data.device import (
Geography) Geography)
from selene.data.device.entity.text_to_speech import TextToSpeech from selene.data.device.entity.text_to_speech import TextToSpeech
from selene.data.device.entity.wake_word import WakeWord from selene.data.device.entity.wake_word import WakeWord
from selene.util.db import get_db_connection from selene.util.db import connect_to_db
account = Account( account = Account(
email_address='test@test.com', email_address='test@test.com',
@ -63,22 +63,22 @@ def before_feature(context, _):
def before_scenario(context, _): def before_scenario(context, _):
cache = context.client_config['SELENE_CACHE'] cache = context.client_config['SELENE_CACHE']
context.etag_manager = ETagManager(cache, context.client_config) context.etag_manager = ETagManager(cache, context.client_config)
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
try: try:
_add_agreements(context, db) _add_agreements(context, db)
_add_account(context, db) _add_account(context, db)
_add_account_preference(context, db) _add_account_preference(context, db)
_add_geography(context, db) _add_geography(context, db)
_add_device(context, db) _add_device(context, db)
except Exception as e: except Exception as e:
import traceback import traceback
print(traceback.print_exc()) print(traceback.print_exc())
def after_scenario(context, _): def after_scenario(context, _):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
_remove_account(context, db) _remove_account(context, db)
_remove_agreements(context, db) _remove_agreements(context, db)
def _add_agreements(context, db): def _add_agreements(context, db):

View File

@ -17,7 +17,8 @@ skill_manifest = {
"installed": datetime.now().timestamp(), "installed": datetime.now().timestamp(),
"updated": datetime.now().timestamp(), "updated": datetime.now().timestamp(),
"installation": "installed", "installation": "installed",
"update": 0 "update": 0,
"skill_gid": "fallback-wolfram-alpha|19.02",
} }
] ]
} }

View File

@ -6,7 +6,7 @@ from hamcrest import assert_that, equal_to, not_none, is_not
from selene.api.etag import ETagManager, device_skill_etag_key from selene.api.etag import ETagManager, device_skill_etag_key
from selene.data.skill import SkillSettingRepository from selene.data.skill import SkillSettingRepository
from selene.util.db import get_db_connection from selene.util.db import connect_to_db
skill = { skill = {
'skill_gid': 'wolfram-alpha|19.02', 'skill_gid': 'wolfram-alpha|19.02',
@ -106,8 +106,8 @@ def update_skill(context):
}] }]
response = json.loads(context.upload_device_response.data) response = json.loads(context.upload_device_response.data)
skill_id = response['uuid'] skill_id = response['uuid']
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
SkillSettingRepository(db).update_device_skill_settings(skill_id, update_settings) SkillSettingRepository(db).update_device_skill_settings(skill_id, update_settings)
@when('the skill settings is fetched') @when('the skill settings is fetched')

View File

@ -31,6 +31,7 @@ def validate_response(context):
response = context.get_device_response response = context.get_device_response
assert_that(response.status_code, equal_to(HTTPStatus.OK)) assert_that(response.status_code, equal_to(HTTPStatus.OK))
device = json.loads(response.data) device = json.loads(response.data)
assert_that(device, has_key('uuid'))
assert_that(device, has_key('name')) assert_that(device, has_key('name'))
assert_that(device, has_key('description')) assert_that(device, has_key('description'))
assert_that(device, has_key('coreVersion')) assert_that(device, has_key('coreVersion'))

View File

@ -7,7 +7,7 @@ from behave import when, then
from hamcrest import assert_that, has_entry, equal_to from hamcrest import assert_that, has_entry, equal_to
from selene.data.account import AccountRepository, AccountMembership from selene.data.account import AccountRepository, AccountMembership
from selene.util.db import get_db_connection from selene.util.db import connect_to_db
@when('the subscription endpoint is called') @when('the subscription endpoint is called')
@ -42,9 +42,9 @@ def get_device_subscription(context):
login = context.device_login login = context.device_login
device_id = login['uuid'] device_id = login['uuid']
access_token = login['accessToken'] access_token = login['accessToken']
headers=dict(Authorization='Bearer {token}'.format(token=access_token)) headers = dict(Authorization='Bearer {token}'.format(token=access_token))
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
AccountRepository(db).add_membership(context.account.id, membership) AccountRepository(db).add_membership(context.account.id, membership)
context.subscription_response = context.client.get( context.subscription_response = context.client.get(
'/v1/device/{uuid}/subscription'.format(uuid=device_id), '/v1/device/{uuid}/subscription'.format(uuid=device_id),
headers=headers headers=headers

View File

@ -13,7 +13,7 @@ from selene.data.account import (
AgreementRepository, AgreementRepository,
PRIVACY_POLICY PRIVACY_POLICY
) )
from selene.util.db import get_db_connection from selene.util.db import connect_to_db
@fixture @fixture
@ -32,9 +32,9 @@ def before_feature(context, _):
def before_scenario(context, _): def before_scenario(context, _):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
_add_agreement(context, db) _add_agreement(context, db)
_add_account(context, db) _add_account(context, db)
def _add_agreement(context, db): def _add_agreement(context, db):
@ -72,8 +72,8 @@ def _add_account(context, db):
def after_scenario(context, _): def after_scenario(context, _):
with get_db_connection(context.db_pool) as db: db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db) acct_repository = AccountRepository(db)
acct_repository.remove(context.account) acct_repository.remove(context.account)
agreement_repository = AgreementRepository(db) agreement_repository = AgreementRepository(db)
agreement_repository.remove(context.agreement, testing=True) agreement_repository.remove(context.agreement, testing=True)

View File

@ -0,0 +1,11 @@
CREATE TABLE metrics.job (
id uuid PRIMARY KEY
DEFAULT gen_random_uuid(),
job_name text NOT NULL,
batch_date date NOT NULL,
start_ts TIMESTAMP NOT NULL,
end_ts TIMESTAMP NOT NULL,
command text NOT NULL,
success BOOLEAN NOT NULL,
UNIQUE (job_name, start_ts)
)

View File

@ -1,6 +1,7 @@
from glob import glob from glob import glob
from os import path from os import environ, path, remove
from markdown import markdown
from psycopg2 import connect from psycopg2 import connect
MYCROFT_DB_DIR = path.join(path.abspath('..'), 'mycroft') MYCROFT_DB_DIR = path.join(path.abspath('..'), 'mycroft')
@ -8,10 +9,8 @@ SCHEMAS = ('account', 'skill', 'device', 'geography', 'metrics')
DB_DESTROY_FILES = ( DB_DESTROY_FILES = (
'drop_mycroft_db.sql', 'drop_mycroft_db.sql',
'drop_template_db.sql', 'drop_template_db.sql',
# 'drop_roles.sql'
) )
DB_CREATE_FILES = ( DB_CREATE_FILES = (
# 'create_roles.sql',
'create_template_db.sql', 'create_template_db.sql',
) )
ACCOUNT_TABLE_ORDER = ( ACCOUNT_TABLE_ORDER = (
@ -48,6 +47,7 @@ GEOGRAPHY_TABLE_ORDER = (
METRICS_TABLE_ORDER = ( METRICS_TABLE_ORDER = (
'api', 'api',
'job'
) )
schema_directory = '{}_schema' schema_directory = '{}_schema'
@ -61,32 +61,40 @@ def get_sql_from_file(file_path: str) -> str:
class PostgresDB(object): class PostgresDB(object):
def __init__(self, dbname, user, password=None): def __init__(self, db_name, user=None):
self.db = connect(dbname=dbname, user=user, host='127.0.0.1') db_host = environ['DB_HOST']
# self.db = connect( db_port = environ['DB_PORT']
# dbname=dbname, db_ssl_mode = environ.get('DB_SSL_MODE')
# user=user, if db_name in ('postgres', 'defaultdb'):
# password=password, db_user = environ['POSTGRES_DB_USER']
# host='selene-test-db-do-user-1412453-0.db.ondigitalocean.com', db_password = environ.get('POSTGRES_DB_PASSWORD')
# port=25060, else:
# sslmode='require' db_user = environ['MYCROFT_DB_USER']
# ) db_password = environ['MYCROFT_DB_PASSWORD']
if user is not None:
db_user = user
self.db = connect(
dbname=db_name,
user=db_user,
password=db_password,
host=db_host,
port=db_port,
sslmode=db_ssl_mode
)
self.db.autocommit = True self.db.autocommit = True
def close_db(self): def close_db(self):
self.db.close() self.db.close()
def execute_sql(self, sql: str): def execute_sql(self, sql: str, args=None):
cursor = self.db.cursor() cursor = self.db.cursor()
cursor.execute(sql) cursor.execute(sql, args)
return cursor
postgres_db = PostgresDB(dbname='postgres', user='postgres') postgres_db = PostgresDB(db_name=environ['POSTGRES_DB_NAME'])
# postgres_db = PostgresDB(
# dbname='defaultdb',
# user='doadmin',
# password='l06tn0qi2bjhgcki'
# )
print('Destroying any objects we will be creating later.') print('Destroying any objects we will be creating later.')
for db_destroy_file in DB_DESTROY_FILES: for db_destroy_file in DB_DESTROY_FILES:
@ -94,7 +102,7 @@ for db_destroy_file in DB_DESTROY_FILES:
get_sql_from_file(db_destroy_file) get_sql_from_file(db_destroy_file)
) )
print('Creating the extensions, mycroft database, and selene roles') print('Creating the mycroft database')
for db_setup_file in DB_CREATE_FILES: for db_setup_file in DB_CREATE_FILES:
postgres_db.execute_sql( postgres_db.execute_sql(
get_sql_from_file(db_setup_file) get_sql_from_file(db_setup_file)
@ -102,13 +110,10 @@ for db_setup_file in DB_CREATE_FILES:
postgres_db.close_db() postgres_db.close_db()
template_db = PostgresDB(dbname='mycroft_template', user='mycroft')
# template_db = PostgresDB(
# dbname='mycroft_template',
# user='selene',
# password='ubhemhx1dikmqc5f'
# )
template_db = PostgresDB(db_name='mycroft_template')
print('Creating the extensions')
template_db.execute_sql( template_db.execute_sql(
get_sql_from_file(path.join('create_extensions.sql')) get_sql_from_file(path.join('create_extensions.sql'))
) )
@ -193,22 +198,14 @@ for schema in SCHEMAS:
template_db.close_db() template_db.close_db()
print('Copying template to new database.') print('Copying template to new database.')
postgres_db = PostgresDB(dbname='postgres', user='mycroft') postgres_db = PostgresDB(db_name=environ['POSTGRES_DB_NAME'])
# postgres_db = PostgresDB(
# dbname='defaultdb',
# user='doadmin',
# password='l06tn0qi2bjhgcki'
# )
postgres_db.execute_sql(get_sql_from_file('create_mycroft_db.sql')) postgres_db.execute_sql(get_sql_from_file('create_mycroft_db.sql'))
postgres_db.close_db() postgres_db.close_db()
mycroft_db = PostgresDB(dbname='mycroft', user='mycroft')
# mycroft_db = PostgresDB( mycroft_db = PostgresDB(db_name=environ['MYCROFT_DB_NAME'])
# dbname='mycroft_template',
# user='selene',
# password='ubhemhx1dikmqc5f'
# )
insert_files = [ insert_files = [
dict(schema_dir='account_schema', file_name='membership.sql'), dict(schema_dir='account_schema', file_name='membership.sql'),
dict(schema_dir='device_schema', file_name='text_to_speech.sql'), dict(schema_dir='device_schema', file_name='text_to_speech.sql'),
@ -226,3 +223,162 @@ for insert_file in insert_files:
) )
except FileNotFoundError: except FileNotFoundError:
pass pass
print('Building account.agreement table')
mycroft_db.db.autocommit = False
insert_sql = (
"insert into account.agreement VALUES (default, '{}', '1', '[today,]', {})"
)
doc_dir = '/Users/chrisveilleux/Mycroft/github/documentation/_pages/'
docs = {
'Privacy Policy': doc_dir + 'embed-privacy-policy.md',
'Terms of Use': doc_dir + 'embed-terms-of-use.md'
}
try:
for agrmt_type, doc_path in docs.items():
lobj = mycroft_db.db.lobject(0, 'b')
with open(doc_path) as doc:
header_delimiter_count = 0
while True:
rec = doc.readline()
if rec == '---\n':
header_delimiter_count += 1
if header_delimiter_count == 2:
break
doc_html = markdown(
doc.read(),
output_format='html5'
)
lobj.write(doc_html)
mycroft_db.execute_sql(
insert_sql.format(agrmt_type, lobj.oid)
)
mycroft_db.execute_sql(
"grant select on large object {} to selene".format(lobj.oid)
)
mycroft_db.execute_sql(
insert_sql.format('Open Dataset', 'null')
)
except:
mycroft_db.db.rollback()
raise
else:
mycroft_db.db.commit()
mycroft_db.db.autocommit = True
reference_file_dir = '/Users/chrisveilleux/Mycroft'
print('Building geography.country table')
country_file = 'country.txt'
country_insert = """
INSERT INTO
geography.country (iso_code, name)
VALUES
('{iso_code}', '{country_name}')
"""
with open(path.join(reference_file_dir, country_file)) as countries:
while True:
rec = countries.readline()
if rec.startswith('#ISO'):
break
for country in countries.readlines():
country_fields = country.split('\t')
insert_args = dict(
iso_code=country_fields[0],
country_name=country_fields[4]
)
mycroft_db.execute_sql(country_insert.format(**insert_args))
print('Building geography.region table')
region_file = 'regions.txt'
region_insert = """
INSERT INTO
geography.region (country_id, region_code, name)
VALUES
(
(SELECT id FROM geography.country WHERE iso_code = %(iso_code)s),
%(region_code)s,
%(region_name)s)
"""
with open(path.join(reference_file_dir, region_file)) as regions:
for region in regions.readlines():
region_fields = region.split('\t')
country_iso_code = region_fields[0][:2]
insert_args = dict(
iso_code=country_iso_code,
region_code=region_fields[0],
region_name=region_fields[1]
)
mycroft_db.execute_sql(region_insert, insert_args)
print('Building geography.timezone table')
timezone_file = 'timezones.txt'
timezone_insert = """
INSERT INTO
geography.timezone (country_id, name, gmt_offset, dst_offset)
VALUES
(
(SELECT id FROM geography.country WHERE iso_code = %(iso_code)s),
%(timezone_name)s,
%(gmt_offset)s,
%(dst_offset)s
)
"""
with open(path.join(reference_file_dir, timezone_file)) as timezones:
timezones.readline()
for timezone in timezones.readlines():
timezone_fields = timezone.split('\t')
insert_args = dict(
iso_code=timezone_fields[0],
timezone_name=timezone_fields[1],
gmt_offset=timezone_fields[2],
dst_offset=timezone_fields[3]
)
mycroft_db.execute_sql(timezone_insert, insert_args)
print('Building geography.city table')
cities_file = 'cities500.txt'
region_query = "SELECT id, region_code FROM geography.region"
query_result = mycroft_db.execute_sql(region_query)
region_lookup = dict()
for row in query_result.fetchall():
region_lookup[row[1]] = row[0]
timezone_query = "SELECT id, name FROM geography.timezone"
query_result = mycroft_db.execute_sql(timezone_query)
timezone_lookup = dict()
for row in query_result.fetchall():
timezone_lookup[row[1]] = row[0]
# city_insert = """
# INSERT INTO
# geography.city (region_id, timezone_id, name, latitude, longitude)
# VALUES
# (%(region_id)s, %(timezone_id)s, %(city_name)s, %(latitude)s, %(longitude)s)
# """
with open(path.join(reference_file_dir, cities_file)) as cities:
with open(path.join(reference_file_dir, 'city.dump'), 'w') as dump_file:
for city in cities.readlines():
city_fields = city.split('\t')
city_region = city_fields[8] + '.' + city_fields[10]
region_id = region_lookup.get(city_region)
timezone_id = timezone_lookup[city_fields[17]]
if region_id is not None:
dump_file.write('\t'.join([
region_id,
timezone_id,
city_fields[1],
city_fields[4],
city_fields[5]
]) + '\n')
# mycroft_db.execute_sql(city_insert, insert_args)
with open(path.join(reference_file_dir, 'city.dump')) as dump_file:
cursor = mycroft_db.db.cursor()
cursor.copy_from(dump_file, 'geography.city', columns=(
'region_id', 'timezone_id', 'name', 'latitude', 'longitude')
)
remove(path.join(reference_file_dir, 'city.dump'))
mycroft_db.close_db()

View File

@ -22,15 +22,6 @@ import os
from selene.util.db import allocate_db_connection_pool, DatabaseConnectionConfig from selene.util.db import allocate_db_connection_pool, DatabaseConnectionConfig
db_connection_config = DatabaseConnectionConfig(
host=os.environ['DB_HOST'],
db_name=os.environ['DB_NAME'],
password=os.environ['DB_PASSWORD'],
port=os.environ.get('DB_PORT', 5432),
user=os.environ['DB_USER'],
sslmode=os.environ.get('DB_SSLMODE')
)
class APIConfigError(Exception): class APIConfigError(Exception):
pass pass
@ -43,6 +34,14 @@ class BaseConfig(object):
DEBUG = False DEBUG = False
ENV = os.environ['SELENE_ENVIRONMENT'] ENV = os.environ['SELENE_ENVIRONMENT']
REFRESH_SECRET = os.environ['JWT_REFRESH_SECRET'] REFRESH_SECRET = os.environ['JWT_REFRESH_SECRET']
DB_CONNECTION_CONFIG = DatabaseConnectionConfig(
host=os.environ['DB_HOST'],
db_name=os.environ['DB_NAME'],
password=os.environ['DB_PASSWORD'],
port=os.environ.get('DB_PORT', 5432),
user=os.environ['DB_USER'],
sslmode=os.environ.get('DB_SSLMODE')
)
class DevelopmentConfig(BaseConfig): class DevelopmentConfig(BaseConfig):
@ -80,10 +79,4 @@ def get_base_config():
error_msg = 'no configuration defined for the "{}" environment' error_msg = 'no configuration defined for the "{}" environment'
raise APIConfigError(error_msg.format(environment_name)) raise APIConfigError(error_msg.format(environment_name))
max_db_connections = os.environ.get('MAX_DB_CONNECTIONS', 20)
app_config.DB_CONNECTION_POOL = allocate_db_connection_pool(
db_connection_config,
max_db_connections
)
return app_config return app_config

View File

@ -6,7 +6,7 @@ from flask.views import MethodView
from selene.data.account import Account, AccountRepository from selene.data.account import Account, AccountRepository
from selene.util.auth import AuthenticationError, AuthenticationToken from selene.util.auth import AuthenticationError, AuthenticationToken
from selene.util.db import get_db_connection_from_pool from selene.util.db import connect_to_db
ACCESS_TOKEN_COOKIE_NAME = 'seleneAccess' ACCESS_TOKEN_COOKIE_NAME = 'seleneAccess'
FIFTEEN_MINUTES = 900 FIFTEEN_MINUTES = 900
@ -42,8 +42,8 @@ class SeleneEndpoint(MethodView):
@property @property
def db(self): def db(self):
if 'db' not in global_context: if 'db' not in global_context:
global_context.db = get_db_connection_from_pool( global_context.db = connect_to_db(
current_app.config['DB_CONNECTION_POOL'] current_app.config['DB_CONNECTION_CONFIG']
) )
return global_context.db return global_context.db

View File

@ -1,5 +1,5 @@
from datetime import datetime
import json import json
from datetime import datetime
from http import HTTPStatus from http import HTTPStatus
from flask import current_app, Blueprint, g as global_context from flask import current_app, Blueprint, g as global_context
@ -7,10 +7,7 @@ from schematics.exceptions import DataError
from selene.data.metrics import ApiMetric, ApiMetricsRepository from selene.data.metrics import ApiMetric, ApiMetricsRepository
from selene.util.auth import AuthenticationError from selene.util.auth import AuthenticationError
from selene.util.db import ( from selene.util.db import connect_to_db
get_db_connection_from_pool,
return_db_connection_to_pool
)
from selene.util.not_modified import NotModifiedError from selene.util.not_modified import NotModifiedError
selene_api = Blueprint('selene_api', __name__) selene_api = Blueprint('selene_api', __name__)
@ -39,7 +36,6 @@ def setup_request():
@selene_api.after_app_request @selene_api.after_app_request
def teardown_request(response): def teardown_request(response):
add_api_metric(response.status_code) add_api_metric(response.status_code)
release_db_connection()
return response return response
@ -54,8 +50,8 @@ def add_api_metric(http_status):
if api is not None and int(http_status) != 304: if api is not None and int(http_status) != 304:
if 'db' not in global_context: if 'db' not in global_context:
global_context.db = get_db_connection_from_pool( global_context.db = connect_to_db(
current_app.config['DB_CONNECTION_POOL'] current_app.config['DB_CONNECTION_CONFIG']
) )
if 'account_id' in global_context: if 'account_id' in global_context:
account_id = global_context.account_id account_id = global_context.account_id
@ -78,12 +74,3 @@ def add_api_metric(http_status):
) )
metric_repository = ApiMetricsRepository(global_context.db) metric_repository = ApiMetricsRepository(global_context.db)
metric_repository.add(api_metric) metric_repository.add(api_metric)
def release_db_connection():
db = global_context.pop('db', None)
if db is not None:
return_db_connection_to_pool(
current_app.config['DB_CONNECTION_POOL'],
db
)

View File

@ -3,7 +3,7 @@ from dataclasses import asdict
from http import HTTPStatus from http import HTTPStatus
from selene.data.account import AgreementRepository from selene.data.account import AgreementRepository
from selene.util.db import get_db_connection from selene.util.db import connect_to_db
from ..base_endpoint import SeleneEndpoint from ..base_endpoint import SeleneEndpoint
@ -16,14 +16,14 @@ class AgreementsEndpoint(SeleneEndpoint):
def get(self, agreement_type): def get(self, agreement_type):
"""Process HTTP GET request for an agreement.""" """Process HTTP GET request for an agreement."""
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: db = connect_to_db(self.config['DB_CONNECTION_CONFIG'])
agreement_repository = AgreementRepository(db) agreement_repository = AgreementRepository(db)
agreement = agreement_repository.get_active_for_type( agreement = agreement_repository.get_active_for_type(
self.agreement_types[agreement_type] self.agreement_types[agreement_type]
) )
if agreement is not None: if agreement is not None:
agreement = asdict(agreement) agreement = asdict(agreement)
del(agreement['effective_date']) del(agreement['effective_date'])
self.response = agreement, HTTPStatus.OK self.response = agreement, HTTPStatus.OK
return self.response return self.response

View File

@ -3,7 +3,7 @@ import string
from selene.data.device import DeviceRepository from selene.data.device import DeviceRepository
from selene.util.cache import SeleneCache from selene.util.cache import SeleneCache
from selene.util.db import get_db_connection from selene.util.db import connect_to_db
def device_etag_key(device_id: str): def device_etag_key(device_id: str):
@ -29,7 +29,7 @@ class ETagManager(object):
def __init__(self, cache: SeleneCache, config: dict): def __init__(self, cache: SeleneCache, config: dict):
self.cache: SeleneCache = cache self.cache: SeleneCache = cache
self.db_connection_pool = config['DB_CONNECTION_POOL'] self.db_connection_config = config['DB_CONNECTION_CONFIG']
def get(self, key: str) -> str: def get(self, key: str) -> str:
"""Generate a etag with 32 random chars and store it into a given key """Generate a etag with 32 random chars and store it into a given key
@ -60,10 +60,10 @@ class ETagManager(object):
def expire_device_setting_etag_by_account_id(self, account_id: str): def expire_device_setting_etag_by_account_id(self, account_id: str):
"""Expire the settings' etags for all devices from a given account. Used when the settings are updated """Expire the settings' etags for all devices from a given account. Used when the settings are updated
at account level""" at account level"""
with get_db_connection(self.db_connection_pool) as db: db = connect_to_db(self.db_connection_config)
devices = DeviceRepository(db).get_devices_by_account_id(account_id) devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices: for device in devices:
self.expire_device_setting_etag_by_device_id(device.id) self.expire_device_setting_etag_by_device_id(device.id)
def expire_device_location_etag_by_device_id(self, device_id: str): def expire_device_location_etag_by_device_id(self, device_id: str):
"""Expire the etag associate with the device's location entity """Expire the etag associate with the device's location entity
@ -73,10 +73,10 @@ class ETagManager(object):
def expire_device_location_etag_by_account_id(self, account_id: str): def expire_device_location_etag_by_account_id(self, account_id: str):
"""Expire the locations' etag fpr açç device for a given acccount """Expire the locations' etag fpr açç device for a given acccount
:param account_id: account uuid""" :param account_id: account uuid"""
with get_db_connection(self.db_connection_pool) as db: db = connect_to_db(self.db_connection_config)
devices = DeviceRepository(db).get_devices_by_account_id(account_id) devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices: for device in devices:
self.expire_device_location_etag_by_device_id(device.id) self.expire_device_location_etag_by_device_id(device.id)
def expire_skill_etag_by_device_id(self, device_id): def expire_skill_etag_by_device_id(self, device_id):
"""Expire the locations' etag for a given device """Expire the locations' etag for a given device
@ -84,7 +84,7 @@ class ETagManager(object):
self._expire(device_skill_etag_key(device_id)) self._expire(device_skill_etag_key(device_id))
def expire_skill_etag_by_account_id(self, account_id): def expire_skill_etag_by_account_id(self, account_id):
with get_db_connection(self.db_connection_pool) as db: db = connect_to_db(self.db_connection_config)
devices = DeviceRepository(db).get_devices_by_account_id(account_id) devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices: for device in devices:
self.expire_skill_etag_by_device_id(device.id) self.expire_skill_etag_by_device_id(device.id)

View File

@ -13,7 +13,7 @@ from flask.views import MethodView
from selene.api.etag import ETagManager from selene.api.etag import ETagManager
from selene.util.auth import AuthenticationError from selene.util.auth import AuthenticationError
from selene.util.db import get_db_connection_from_pool from selene.util.db import connect_to_db
from selene.util.not_modified import NotModifiedError from selene.util.not_modified import NotModifiedError
from ..util.cache import SeleneCache from ..util.cache import SeleneCache
@ -91,8 +91,8 @@ class PublicEndpoint(MethodView):
@property @property
def db(self): def db(self):
if 'db' not in global_context: if 'db' not in global_context:
global_context.db = get_db_connection_from_pool( global_context.db = connect_to_db(
current_app.config['DB_CONNECTION_POOL'] current_app.config['DB_CONNECTION_CONFIG']
) )
return global_context.db return global_context.db

View File

@ -2,7 +2,7 @@ from hamcrest import assert_that, equal_to, has_item
from selene.data.account import Account, AccountRepository from selene.data.account import Account, AccountRepository
from selene.util.auth import AuthenticationToken from selene.util.auth import AuthenticationToken
from selene.util.db import get_db_connection from selene.util.db import connect_to_db
ACCESS_TOKEN_COOKIE_KEY = 'seleneAccess' ACCESS_TOKEN_COOKIE_KEY = 'seleneAccess'
ONE_MINUTE = 60 ONE_MINUTE = 60
@ -77,8 +77,8 @@ def _parse_cookie(cookie: str) -> dict:
def get_account(context) -> Account: def get_account(context) -> Account:
with get_db_connection(context.db_pool) as db: db = connect_to_db(context.client['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db) acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_id(context.account.id) account = acct_repository.get_account_by_id(context.account.id)
return account return account

View File

@ -12,7 +12,7 @@ class SkillDisplayRepository(RepositoryBase):
sql_file_name='get_display_data_for_skills.sql' sql_file_name='get_display_data_for_skills.sql'
) )
def get_display_data_for_skill(self, skill_display_id): def get_display_data_for_skill(self, skill_display_id) -> SkillDisplay:
return self._select_one_into_dataclass( return self._select_one_into_dataclass(
dataclass=SkillDisplay, dataclass=SkillDisplay,
sql_file_name='get_display_data_for_skill.sql', sql_file_name='get_display_data_for_skill.sql',

View File

@ -34,12 +34,12 @@ class SkillSettingRepository(RepositoryBase):
return skill_settings return skill_settings
def get_installer_settings(self, account_id: str): def get_installer_settings(self, account_id: str) -> List[AccountSkillSetting]:
skill_repo = SkillRepository(self.db) skill_repo = SkillRepository(self.db)
skills = skill_repo.get_skills_for_account(account_id) skills = skill_repo.get_skills_for_account(account_id)
installer_skill_id = None installer_skill_id = None
for skill in skills: for skill in skills:
if skill.name == 'mycroft_installer': if skill.display_name == 'Installer':
installer_skill_id = skill.id installer_skill_id = skill.id
skill_settings = None skill_settings = None

View File

@ -6,12 +6,12 @@ Example Usage:
query_result = mycroft_db_ro.execute_sql(sql) query_result = mycroft_db_ro.execute_sql(sql)
""" """
from contextlib import contextmanager from dataclasses import dataclass, field, InitVar
from dataclasses import dataclass, field
from logging import getLogger from logging import getLogger
from psycopg2 import connect from psycopg2 import connect
from psycopg2.extras import RealDictCursor from psycopg2.extras import RealDictCursor, NamedTupleCursor
from psycopg2.extensions import cursor
_log = getLogger(__package__) _log = getLogger(__package__)
@ -29,10 +29,16 @@ class DatabaseConnectionConfig(object):
password: str password: str
port: int = field(default=5432) port: int = field(default=5432)
sslmode: str = None sslmode: str = None
autocommit: str = True
cursor_factory = RealDictCursor
use_namedtuple_cursor: InitVar[bool] = False
def __post_init__(self, use_namedtuple_cursor: bool):
if use_namedtuple_cursor:
self.cursor_factory = NamedTupleCursor
@contextmanager def connect_to_db(connection_config: DatabaseConnectionConfig):
def connect_to_db(connection_config: DatabaseConnectionConfig, autocommit=True):
""" """
Return a connection to the mycroft database for the specified user. Return a connection to the mycroft database for the specified user.
@ -41,33 +47,19 @@ def connect_to_db(connection_config: DatabaseConnectionConfig, autocommit=True):
python notebook) python notebook)
:param connection_config: data needed to establish a connection :param connection_config: data needed to establish a connection
:param autocommit: indicated if transactions should commit automatically
:return: database connection :return: database connection
""" """
db = None
log_msg = 'establishing connection to the {db_name} database' log_msg = 'establishing connection to the {db_name} database'
_log.info(log_msg.format(db_name=connection_config.db_name)) _log.info(log_msg.format(db_name=connection_config.db_name))
try: db = connect(
if connection_config.sslmode is None: host=connection_config.host,
db = connect( dbname=connection_config.db_name,
host=connection_config.host, user=connection_config.user,
dbname=connection_config.db_name, password=connection_config.password,
user=connection_config.user, port=connection_config.port,
port=connection_config.port, cursor_factory=connection_config.cursor_factory,
cursor_factory=RealDictCursor, sslmode=connection_config.sslmode
) )
else: db.autocommit = connection_config.autocommit
db = connect(
host=connection_config.host, return db
dbname=connection_config.db_name,
user=connection_config.user,
password=connection_config.password,
port=connection_config.port,
cursor_factory=RealDictCursor,
sslmode=connection_config.sslmode
)
db.autocommit = autocommit
yield db
finally:
if db is not None:
db.close()