Merge remote-tracking branch 'origin/dev' into dev

pull/157/head
Chris Veilleux 2019-05-22 20:05:54 -05:00
commit 06b8bd4eb1
29 changed files with 453 additions and 255 deletions

View File

@ -15,7 +15,7 @@ from selene.data.account import (
)
from selene.data.device import Geography, GeographyRepository
from selene.util.cache import SeleneCache
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
@fixture
@ -32,10 +32,10 @@ def before_feature(context, _):
def before_scenario(context, _):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
_add_agreements(context, db)
_add_account(context, db)
_add_geography(context, db)
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
_add_agreements(context, db)
_add_account(context, db)
_add_geography(context, db)
def _add_agreements(context, db):
@ -91,9 +91,9 @@ def _add_geography(context, db):
def after_scenario(context, _):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
_delete_account(context, db)
_delete_agreements(context, db)
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
_delete_account(context, db)
_delete_agreements(context, db)
_clean_cache()

View File

@ -5,7 +5,7 @@ from hamcrest import assert_that, equal_to, has_key, none, not_none
from selene.data.device import DeviceRepository
from selene.util.cache import SeleneCache
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
@given('a device pairing code')
@ -57,9 +57,9 @@ def validate_pairing_code_removal(context):
@then('the device is added to the database')
def validate_response(context):
device_id = context.response.data.decode()
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
device_repository = DeviceRepository(db)
device = device_repository.get_device_by_id(device_id)
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
device_repository = DeviceRepository(db)
device = device_repository.get_device_by_id(device_id)
assert_that(device, not_none())
assert_that(device.name, equal_to('home'))

View File

@ -8,7 +8,7 @@ from selene.api.testing import (
)
from selene.data.account import AccountRepository
from selene.util.auth import AuthenticationToken
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
@given('an authenticated user with an expired access token')
@ -37,9 +37,9 @@ def check_for_new_cookies(context):
context.refresh_token,
is_not(equal_to(context.old_refresh_token))
)
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_id(context.account.id)
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_id(context.account.id)
refresh_token = AuthenticationToken(
context.client_config['REFRESH_SECRET'],

View File

@ -1,5 +1,6 @@
import os
from binascii import b2a_base64
from datetime import date
import os
import stripe
from behave import given, then, when
@ -8,7 +9,7 @@ from hamcrest import assert_that, equal_to, is_in, none, not_none, starts_with
from stripe.error import InvalidRequestError
from selene.data.account import AccountRepository, PRIVACY_POLICY, TERMS_OF_USE
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
new_account_request = dict(
username='barfoo',
@ -17,8 +18,8 @@ new_account_request = dict(
login=dict(
federatedPlatform=None,
federatedToken=None,
userEnteredEmail='bar@mycroft.ai',
password='bar'
userEnteredEmail=b2a_base64(b'bar@mycroft.ai').decode(),
password=b2a_base64(b'bar').decode()
),
support=dict(openDataset=True)
)
@ -58,41 +59,41 @@ def call_add_account_endpoint(context):
context.response = context.client.post(
'/api/account',
data=json.dumps(context.new_account_request),
content_type='application_json'
content_type='application/json'
)
@then('the account will be added to the system {membership_option}')
def check_db_for_account(context, membership_option):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_email('bar@mycroft.ai')
assert_that(account, not_none())
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_email('bar@mycroft.ai')
assert_that(account, not_none())
assert_that(
account.email_address, equal_to('bar@mycroft.ai')
)
assert_that(account.username, equal_to('barfoo'))
if membership_option == 'with a membership':
assert_that(account.membership.type, equal_to('Monthly Membership'))
assert_that(
account.email_address, equal_to('bar@mycroft.ai')
account.membership.payment_account_id,
starts_with('cus')
)
assert_that(account.username, equal_to('barfoo'))
if membership_option == 'with a membership':
assert_that(account.membership.type, equal_to('Monthly Membership'))
assert_that(
account.membership.payment_account_id,
starts_with('cus')
)
elif membership_option == 'without a membership':
assert_that(account.membership, none())
elif membership_option == 'without a membership':
assert_that(account.membership, none())
assert_that(len(account.agreements), equal_to(2))
for agreement in account.agreements:
assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE)))
assert_that(agreement.accept_date, equal_to(str(date.today())))
assert_that(len(account.agreements), equal_to(2))
for agreement in account.agreements:
assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE)))
assert_that(agreement.accept_date, equal_to(str(date.today())))
@when('the account is deleted')
def account_deleted(context):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_email('bar@mycroft.ai')
context.stripe_id = account.membership.payment_id
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_email('bar@mycroft.ai')
context.stripe_id = account.membership.payment_id
context.response = context.client.delete('/api/account')

View File

@ -1,5 +1,6 @@
import json
from binascii import b2a_base64
from datetime import date
import json
from behave import given, when, then
from hamcrest import assert_that, equal_to, starts_with, none
@ -11,7 +12,9 @@ from selene.data.account import (
AccountAgreement,
PRIVACY_POLICY
)
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
TEST_EMAIL_ADDRESS = 'test@mycroft.ai'
new_account_request = dict(
username='test',
@ -20,8 +23,8 @@ new_account_request = dict(
login=dict(
federatedPlatform=None,
federatedToken=None,
userEnteredEmail='test@mycroft.ai',
password='12345678'
email=b2a_base64(b'test@mycroft.ai').decode(),
password=b2a_base64(b'12345678').decode()
),
support=dict(
openDataset=True,
@ -47,12 +50,12 @@ def create_account(context):
AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today())
]
)
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
acct_repository = AccountRepository(db)
account_id = acct_repository.add(context.account, 'foo')
context.account.id = account_id
generate_access_token(context)
generate_refresh_token(context)
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db)
account_id = acct_repository.add(context.account, 'foo')
context.account.id = account_id
generate_access_token(context)
generate_refresh_token(context)
@when('a monthly membership is added')
@ -66,16 +69,16 @@ def update_membership(context):
context.response = context.client.patch(
'/api/account',
data=json.dumps(dict(membership=membership_data)),
content_type='application_json'
content_type='application/json'
)
@when('the account is requested')
def request_account(context):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
context.response_account = AccountRepository(db).get_account_by_email(
'test@mycroft.ai'
)
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
context.response_account = AccountRepository(db).get_account_by_email(
TEST_EMAIL_ADDRESS
)
@then('the account should have a monthly membership')
@ -98,16 +101,14 @@ def create_monthly_account(context):
context.client.post(
'/api/account',
data=json.dumps(new_account_request),
content_type='application_json'
content_type='application/json'
)
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
account_repository = AccountRepository(db)
account = account_repository.get_account_by_email(
new_account_request['login']['userEnteredEmail']
)
context.account = account
generate_access_token(context)
generate_refresh_token(context)
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
account_repository = AccountRepository(db)
account = account_repository.get_account_by_email(TEST_EMAIL_ADDRESS)
context.account = account
generate_access_token(context)
generate_refresh_token(context)
@when('the membership is cancelled')
@ -119,7 +120,7 @@ def cancel_membership(context):
context.client.patch(
'/api/account',
data=json.dumps(dict(membership=membership_data)),
content_type='application_json'
content_type='application/json'
)
@ -138,7 +139,7 @@ def change_to_yearly_account(context):
context.client.patch(
'/api/account',
data=json.dumps(dict(membership=membership_data)),
content_type='application_json'
content_type='application/json'
)

View File

@ -1,3 +1,9 @@
"""
Marketplace endpoint to add or remove a skill
This endpoint configures the install skill on a user's device(s) to add or
remove the skill.
"""
from http import HTTPStatus
from logging import getLogger
from typing import List
@ -6,7 +12,11 @@ from schematics import Model
from schematics.types import StringType
from selene.api import SeleneEndpoint
from selene.data.skill import AccountSkillSetting, SkillSettingRepository
from selene.data.skill import (
AccountSkillSetting,
SkillDisplayRepository,
SkillSettingRepository
)
INSTALL_SECTION = 'to_install'
UNINSTALL_SECTION = 'to_remove'
@ -15,17 +25,16 @@ _log = getLogger(__package__)
class InstallRequest(Model):
"""Defines the expected state of the request JSON data"""
setting_section = StringType(
required=True,
choices=[INSTALL_SECTION, UNINSTALL_SECTION]
)
skill_name = StringType(required=True)
skill_display_id = StringType(required=True)
class SkillInstallEndpoint(SeleneEndpoint):
"""
Install a skill on user device(s).
"""
"""Install a skill on user device(s)."""
def __init__(self):
super(SkillInstallEndpoint, self).__init__()
self.device_uuid: str = None
@ -34,43 +43,70 @@ class SkillInstallEndpoint(SeleneEndpoint):
self.installer_update_response = None
def put(self):
"""Handle an HTTP PUT request"""
self._authenticate()
self._validate_request()
skill_install_name = self._get_install_name()
self._get_installer_settings()
self._apply_update()
self._apply_update(skill_install_name)
self.response = (self.installer_update_response, HTTPStatus.OK)
return self.response
def _validate_request(self):
"""Ensure the data passed in the request is as expected.
:raises schematics.exceptions.ValidationError if the validation fails
"""
install_request = InstallRequest()
install_request.setting_section = self.request.json['section']
install_request.skill_name = self.request.json['skillName']
install_request.skill_display_id = self.request.json['skillDisplayId']
install_request.validate()
def _get_install_name(self) -> str:
"""Get the skill name used by the installer skill from the DB
The installer skill expects the skill name found in the "name" field
of the skill display JSON.
"""
display_repo = SkillDisplayRepository(self.db)
skill_display = display_repo.get_display_data_for_skill(
self.request.json['skillDisplayId']
)
return skill_display.display_data['name']
def _get_installer_settings(self):
"""Get the current value of the installer skill's settings"""
settings_repo = SkillSettingRepository(self.db)
self.installer_settings = settings_repo.get_installer_settings(
self.account.id
)
def _apply_update(self):
def _apply_update(self, skill_install_name: str):
"""Add the skill in the request to the installer skill settings.
This is designed to change the installer skill settings for all
devices associated with an account. It will be updated in the
future to target specific devices.
"""
for settings in self.installer_settings:
if self.request.json['section'] == INSTALL_SECTION:
to_install = settings.settings_values.get(INSTALL_SECTION, [])
to_install.append(
dict(name=self.request.json['skill_name'])
dict(name=skill_install_name)
)
settings.settings_values[INSTALL_SECTION] = to_install
else:
to_remove = settings.settings_values.get(UNINSTALL_SECTION, [])
to_remove.append(
dict(name=self.request.json['skill_name'])
dict(name=skill_install_name)
)
settings.settings_values[UNINSTALL_SECTION] = to_remove
self._update_skill_settings(settings)
def _update_skill_settings(self, settings):
"""Update the DB with the new installer skill settings."""
settings_repo = SkillSettingRepository(self.db)
settings_repo.update_device_skill_settings(
self.account.id,

View File

@ -28,6 +28,7 @@ class DeviceEndpoint(PublicEndpoint):
if device is not None:
response_data = dict(
uuid=device.id,
name=device.name,
description=device.placement,
coreVersion=device.core_version,

View File

@ -18,7 +18,7 @@ class MetricsService(object):
deviceUuid=device_id,
data=data
)
url = '{host}/{metric}'.format(host=self.metrics_service_host, metric=metric)
url = '{host}/metric/{metric}'.format(host=self.metrics_service_host, metric=metric)
requests.post(url, body)

View File

@ -4,7 +4,6 @@ import requests
from selene.api import PublicEndpoint
from selene.data.account import AccountRepository
from selene.util.db import get_db_connection
class OauthServiceEndpoint(PublicEndpoint):
@ -14,8 +13,7 @@ class OauthServiceEndpoint(PublicEndpoint):
self.oauth_service_host = os.environ['OAUTH_BASE_URL']
def get(self, device_id, credentials, oauth_path):
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
account = AccountRepository(db).get_account_by_device_id(device_id)
account = AccountRepository(self.db).get_account_by_device_id(device_id)
uuid = account.id
url = '{host}/auth/{credentials}/{oauth_path}'.format(
host=self.oauth_service_host,

View File

@ -22,10 +22,20 @@ class DeviceRefreshTokenEndpoint(PublicEndpoint):
if token_header.startswith('Bearer '):
refresh = token_header[len('Bearer '):]
session = self._refresh_session_token(refresh)
# Trying to fetch a session using the refresh token
if session:
response = session, HTTPStatus.OK
else:
response = '', HTTPStatus.UNAUTHORIZED
device = self.request.headers.get('Device')
if device:
# trying to fetch a session using the device uuid
session = self._refresh_session_token_device(device)
if session:
response = session, HTTPStatus.OK
else:
response = '', HTTPStatus.UNAUTHORIZED
else:
response = '', HTTPStatus.UNAUTHORIZED
else:
response = '', HTTPStatus.UNAUTHORIZED
return response
@ -38,3 +48,12 @@ class DeviceRefreshTokenEndpoint(PublicEndpoint):
device_id = old_login['uuid']
self.cache.delete(refresh_key)
return generate_device_login(device_id, self.cache)
def _refresh_session_token_device(self, device: str):
refresh_key = 'device.session:{}'.format(device)
session = self.cache.get(refresh_key)
if session:
old_login = json.loads(session)
device_id = old_login['uuid']
self.cache.delete(refresh_key)
return generate_device_login(device_id, self.cache)

View File

@ -27,6 +27,7 @@ class SkillManifest(Model):
installed = DateTimeType()
updated = DateTimeType()
update = DateTimeType()
skill_gid = StringType()
class SkillJson(Model):

View File

@ -21,7 +21,7 @@ from selene.data.device import (
Geography)
from selene.data.device.entity.text_to_speech import TextToSpeech
from selene.data.device.entity.wake_word import WakeWord
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
account = Account(
email_address='test@test.com',
@ -63,22 +63,22 @@ def before_feature(context, _):
def before_scenario(context, _):
cache = context.client_config['SELENE_CACHE']
context.etag_manager = ETagManager(cache, context.client_config)
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
try:
_add_agreements(context, db)
_add_account(context, db)
_add_account_preference(context, db)
_add_geography(context, db)
_add_device(context, db)
except Exception as e:
import traceback
print(traceback.print_exc())
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
try:
_add_agreements(context, db)
_add_account(context, db)
_add_account_preference(context, db)
_add_geography(context, db)
_add_device(context, db)
except Exception as e:
import traceback
print(traceback.print_exc())
def after_scenario(context, _):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
_remove_account(context, db)
_remove_agreements(context, db)
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
_remove_account(context, db)
_remove_agreements(context, db)
def _add_agreements(context, db):

View File

@ -17,7 +17,8 @@ skill_manifest = {
"installed": datetime.now().timestamp(),
"updated": datetime.now().timestamp(),
"installation": "installed",
"update": 0
"update": 0,
"skill_gid": "fallback-wolfram-alpha|19.02",
}
]
}

View File

@ -6,7 +6,7 @@ from hamcrest import assert_that, equal_to, not_none, is_not
from selene.api.etag import ETagManager, device_skill_etag_key
from selene.data.skill import SkillSettingRepository
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
skill = {
'skill_gid': 'wolfram-alpha|19.02',
@ -106,8 +106,8 @@ def update_skill(context):
}]
response = json.loads(context.upload_device_response.data)
skill_id = response['uuid']
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
SkillSettingRepository(db).update_device_skill_settings(skill_id, update_settings)
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
SkillSettingRepository(db).update_device_skill_settings(skill_id, update_settings)
@when('the skill settings is fetched')

View File

@ -31,6 +31,7 @@ def validate_response(context):
response = context.get_device_response
assert_that(response.status_code, equal_to(HTTPStatus.OK))
device = json.loads(response.data)
assert_that(device, has_key('uuid'))
assert_that(device, has_key('name'))
assert_that(device, has_key('description'))
assert_that(device, has_key('coreVersion'))

View File

@ -7,7 +7,7 @@ from behave import when, then
from hamcrest import assert_that, has_entry, equal_to
from selene.data.account import AccountRepository, AccountMembership
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
@when('the subscription endpoint is called')
@ -42,9 +42,9 @@ def get_device_subscription(context):
login = context.device_login
device_id = login['uuid']
access_token = login['accessToken']
headers=dict(Authorization='Bearer {token}'.format(token=access_token))
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
AccountRepository(db).add_membership(context.account.id, membership)
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
AccountRepository(db).add_membership(context.account.id, membership)
context.subscription_response = context.client.get(
'/v1/device/{uuid}/subscription'.format(uuid=device_id),
headers=headers

View File

@ -13,7 +13,7 @@ from selene.data.account import (
AgreementRepository,
PRIVACY_POLICY
)
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
@fixture
@ -32,9 +32,9 @@ def before_feature(context, _):
def before_scenario(context, _):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
_add_agreement(context, db)
_add_account(context, db)
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
_add_agreement(context, db)
_add_account(context, db)
def _add_agreement(context, db):
@ -72,8 +72,8 @@ def _add_account(context, db):
def after_scenario(context, _):
with get_db_connection(context.db_pool) as db:
acct_repository = AccountRepository(db)
acct_repository.remove(context.account)
agreement_repository = AgreementRepository(db)
agreement_repository.remove(context.agreement, testing=True)
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db)
acct_repository.remove(context.account)
agreement_repository = AgreementRepository(db)
agreement_repository.remove(context.agreement, testing=True)

View File

@ -0,0 +1,11 @@
CREATE TABLE metrics.job (
id uuid PRIMARY KEY
DEFAULT gen_random_uuid(),
job_name text NOT NULL,
batch_date date NOT NULL,
start_ts TIMESTAMP NOT NULL,
end_ts TIMESTAMP NOT NULL,
command text NOT NULL,
success BOOLEAN NOT NULL,
UNIQUE (job_name, start_ts)
)

View File

@ -1,6 +1,7 @@
from glob import glob
from os import path
from os import environ, path, remove
from markdown import markdown
from psycopg2 import connect
MYCROFT_DB_DIR = path.join(path.abspath('..'), 'mycroft')
@ -8,10 +9,8 @@ SCHEMAS = ('account', 'skill', 'device', 'geography', 'metrics')
DB_DESTROY_FILES = (
'drop_mycroft_db.sql',
'drop_template_db.sql',
# 'drop_roles.sql'
)
DB_CREATE_FILES = (
# 'create_roles.sql',
'create_template_db.sql',
)
ACCOUNT_TABLE_ORDER = (
@ -48,6 +47,7 @@ GEOGRAPHY_TABLE_ORDER = (
METRICS_TABLE_ORDER = (
'api',
'job'
)
schema_directory = '{}_schema'
@ -61,32 +61,40 @@ def get_sql_from_file(file_path: str) -> str:
class PostgresDB(object):
def __init__(self, dbname, user, password=None):
self.db = connect(dbname=dbname, user=user, host='127.0.0.1')
# self.db = connect(
# dbname=dbname,
# user=user,
# password=password,
# host='selene-test-db-do-user-1412453-0.db.ondigitalocean.com',
# port=25060,
# sslmode='require'
# )
def __init__(self, db_name, user=None):
db_host = environ['DB_HOST']
db_port = environ['DB_PORT']
db_ssl_mode = environ.get('DB_SSL_MODE')
if db_name in ('postgres', 'defaultdb'):
db_user = environ['POSTGRES_DB_USER']
db_password = environ.get('POSTGRES_DB_PASSWORD')
else:
db_user = environ['MYCROFT_DB_USER']
db_password = environ['MYCROFT_DB_PASSWORD']
if user is not None:
db_user = user
self.db = connect(
dbname=db_name,
user=db_user,
password=db_password,
host=db_host,
port=db_port,
sslmode=db_ssl_mode
)
self.db.autocommit = True
def close_db(self):
self.db.close()
def execute_sql(self, sql: str):
def execute_sql(self, sql: str, args=None):
cursor = self.db.cursor()
cursor.execute(sql)
cursor.execute(sql, args)
return cursor
postgres_db = PostgresDB(dbname='postgres', user='postgres')
# postgres_db = PostgresDB(
# dbname='defaultdb',
# user='doadmin',
# password='l06tn0qi2bjhgcki'
# )
postgres_db = PostgresDB(db_name=environ['POSTGRES_DB_NAME'])
print('Destroying any objects we will be creating later.')
for db_destroy_file in DB_DESTROY_FILES:
@ -94,7 +102,7 @@ for db_destroy_file in DB_DESTROY_FILES:
get_sql_from_file(db_destroy_file)
)
print('Creating the extensions, mycroft database, and selene roles')
print('Creating the mycroft database')
for db_setup_file in DB_CREATE_FILES:
postgres_db.execute_sql(
get_sql_from_file(db_setup_file)
@ -102,13 +110,10 @@ for db_setup_file in DB_CREATE_FILES:
postgres_db.close_db()
template_db = PostgresDB(dbname='mycroft_template', user='mycroft')
# template_db = PostgresDB(
# dbname='mycroft_template',
# user='selene',
# password='ubhemhx1dikmqc5f'
# )
template_db = PostgresDB(db_name='mycroft_template')
print('Creating the extensions')
template_db.execute_sql(
get_sql_from_file(path.join('create_extensions.sql'))
)
@ -193,22 +198,14 @@ for schema in SCHEMAS:
template_db.close_db()
print('Copying template to new database.')
postgres_db = PostgresDB(dbname='postgres', user='mycroft')
# postgres_db = PostgresDB(
# dbname='defaultdb',
# user='doadmin',
# password='l06tn0qi2bjhgcki'
# )
postgres_db = PostgresDB(db_name=environ['POSTGRES_DB_NAME'])
postgres_db.execute_sql(get_sql_from_file('create_mycroft_db.sql'))
postgres_db.close_db()
mycroft_db = PostgresDB(dbname='mycroft', user='mycroft')
# mycroft_db = PostgresDB(
# dbname='mycroft_template',
# user='selene',
# password='ubhemhx1dikmqc5f'
# )
mycroft_db = PostgresDB(db_name=environ['MYCROFT_DB_NAME'])
insert_files = [
dict(schema_dir='account_schema', file_name='membership.sql'),
dict(schema_dir='device_schema', file_name='text_to_speech.sql'),
@ -226,3 +223,162 @@ for insert_file in insert_files:
)
except FileNotFoundError:
pass
print('Building account.agreement table')
mycroft_db.db.autocommit = False
insert_sql = (
"insert into account.agreement VALUES (default, '{}', '1', '[today,]', {})"
)
doc_dir = '/Users/chrisveilleux/Mycroft/github/documentation/_pages/'
docs = {
'Privacy Policy': doc_dir + 'embed-privacy-policy.md',
'Terms of Use': doc_dir + 'embed-terms-of-use.md'
}
try:
for agrmt_type, doc_path in docs.items():
lobj = mycroft_db.db.lobject(0, 'b')
with open(doc_path) as doc:
header_delimiter_count = 0
while True:
rec = doc.readline()
if rec == '---\n':
header_delimiter_count += 1
if header_delimiter_count == 2:
break
doc_html = markdown(
doc.read(),
output_format='html5'
)
lobj.write(doc_html)
mycroft_db.execute_sql(
insert_sql.format(agrmt_type, lobj.oid)
)
mycroft_db.execute_sql(
"grant select on large object {} to selene".format(lobj.oid)
)
mycroft_db.execute_sql(
insert_sql.format('Open Dataset', 'null')
)
except:
mycroft_db.db.rollback()
raise
else:
mycroft_db.db.commit()
mycroft_db.db.autocommit = True
reference_file_dir = '/Users/chrisveilleux/Mycroft'
print('Building geography.country table')
country_file = 'country.txt'
country_insert = """
INSERT INTO
geography.country (iso_code, name)
VALUES
('{iso_code}', '{country_name}')
"""
with open(path.join(reference_file_dir, country_file)) as countries:
while True:
rec = countries.readline()
if rec.startswith('#ISO'):
break
for country in countries.readlines():
country_fields = country.split('\t')
insert_args = dict(
iso_code=country_fields[0],
country_name=country_fields[4]
)
mycroft_db.execute_sql(country_insert.format(**insert_args))
print('Building geography.region table')
region_file = 'regions.txt'
region_insert = """
INSERT INTO
geography.region (country_id, region_code, name)
VALUES
(
(SELECT id FROM geography.country WHERE iso_code = %(iso_code)s),
%(region_code)s,
%(region_name)s)
"""
with open(path.join(reference_file_dir, region_file)) as regions:
for region in regions.readlines():
region_fields = region.split('\t')
country_iso_code = region_fields[0][:2]
insert_args = dict(
iso_code=country_iso_code,
region_code=region_fields[0],
region_name=region_fields[1]
)
mycroft_db.execute_sql(region_insert, insert_args)
print('Building geography.timezone table')
timezone_file = 'timezones.txt'
timezone_insert = """
INSERT INTO
geography.timezone (country_id, name, gmt_offset, dst_offset)
VALUES
(
(SELECT id FROM geography.country WHERE iso_code = %(iso_code)s),
%(timezone_name)s,
%(gmt_offset)s,
%(dst_offset)s
)
"""
with open(path.join(reference_file_dir, timezone_file)) as timezones:
timezones.readline()
for timezone in timezones.readlines():
timezone_fields = timezone.split('\t')
insert_args = dict(
iso_code=timezone_fields[0],
timezone_name=timezone_fields[1],
gmt_offset=timezone_fields[2],
dst_offset=timezone_fields[3]
)
mycroft_db.execute_sql(timezone_insert, insert_args)
print('Building geography.city table')
cities_file = 'cities500.txt'
region_query = "SELECT id, region_code FROM geography.region"
query_result = mycroft_db.execute_sql(region_query)
region_lookup = dict()
for row in query_result.fetchall():
region_lookup[row[1]] = row[0]
timezone_query = "SELECT id, name FROM geography.timezone"
query_result = mycroft_db.execute_sql(timezone_query)
timezone_lookup = dict()
for row in query_result.fetchall():
timezone_lookup[row[1]] = row[0]
# city_insert = """
# INSERT INTO
# geography.city (region_id, timezone_id, name, latitude, longitude)
# VALUES
# (%(region_id)s, %(timezone_id)s, %(city_name)s, %(latitude)s, %(longitude)s)
# """
with open(path.join(reference_file_dir, cities_file)) as cities:
with open(path.join(reference_file_dir, 'city.dump'), 'w') as dump_file:
for city in cities.readlines():
city_fields = city.split('\t')
city_region = city_fields[8] + '.' + city_fields[10]
region_id = region_lookup.get(city_region)
timezone_id = timezone_lookup[city_fields[17]]
if region_id is not None:
dump_file.write('\t'.join([
region_id,
timezone_id,
city_fields[1],
city_fields[4],
city_fields[5]
]) + '\n')
# mycroft_db.execute_sql(city_insert, insert_args)
with open(path.join(reference_file_dir, 'city.dump')) as dump_file:
cursor = mycroft_db.db.cursor()
cursor.copy_from(dump_file, 'geography.city', columns=(
'region_id', 'timezone_id', 'name', 'latitude', 'longitude')
)
remove(path.join(reference_file_dir, 'city.dump'))
mycroft_db.close_db()

View File

@ -22,15 +22,6 @@ import os
from selene.util.db import allocate_db_connection_pool, DatabaseConnectionConfig
db_connection_config = DatabaseConnectionConfig(
host=os.environ['DB_HOST'],
db_name=os.environ['DB_NAME'],
password=os.environ['DB_PASSWORD'],
port=os.environ.get('DB_PORT', 5432),
user=os.environ['DB_USER'],
sslmode=os.environ.get('DB_SSLMODE')
)
class APIConfigError(Exception):
pass
@ -43,6 +34,14 @@ class BaseConfig(object):
DEBUG = False
ENV = os.environ['SELENE_ENVIRONMENT']
REFRESH_SECRET = os.environ['JWT_REFRESH_SECRET']
DB_CONNECTION_CONFIG = DatabaseConnectionConfig(
host=os.environ['DB_HOST'],
db_name=os.environ['DB_NAME'],
password=os.environ['DB_PASSWORD'],
port=os.environ.get('DB_PORT', 5432),
user=os.environ['DB_USER'],
sslmode=os.environ.get('DB_SSLMODE')
)
class DevelopmentConfig(BaseConfig):
@ -80,10 +79,4 @@ def get_base_config():
error_msg = 'no configuration defined for the "{}" environment'
raise APIConfigError(error_msg.format(environment_name))
max_db_connections = os.environ.get('MAX_DB_CONNECTIONS', 20)
app_config.DB_CONNECTION_POOL = allocate_db_connection_pool(
db_connection_config,
max_db_connections
)
return app_config

View File

@ -6,7 +6,7 @@ from flask.views import MethodView
from selene.data.account import Account, AccountRepository
from selene.util.auth import AuthenticationError, AuthenticationToken
from selene.util.db import get_db_connection_from_pool
from selene.util.db import connect_to_db
ACCESS_TOKEN_COOKIE_NAME = 'seleneAccess'
FIFTEEN_MINUTES = 900
@ -42,8 +42,8 @@ class SeleneEndpoint(MethodView):
@property
def db(self):
if 'db' not in global_context:
global_context.db = get_db_connection_from_pool(
current_app.config['DB_CONNECTION_POOL']
global_context.db = connect_to_db(
current_app.config['DB_CONNECTION_CONFIG']
)
return global_context.db

View File

@ -1,5 +1,5 @@
from datetime import datetime
import json
from datetime import datetime
from http import HTTPStatus
from flask import current_app, Blueprint, g as global_context
@ -7,10 +7,7 @@ from schematics.exceptions import DataError
from selene.data.metrics import ApiMetric, ApiMetricsRepository
from selene.util.auth import AuthenticationError
from selene.util.db import (
get_db_connection_from_pool,
return_db_connection_to_pool
)
from selene.util.db import connect_to_db
from selene.util.not_modified import NotModifiedError
selene_api = Blueprint('selene_api', __name__)
@ -39,7 +36,6 @@ def setup_request():
@selene_api.after_app_request
def teardown_request(response):
add_api_metric(response.status_code)
release_db_connection()
return response
@ -54,8 +50,8 @@ def add_api_metric(http_status):
if api is not None and int(http_status) != 304:
if 'db' not in global_context:
global_context.db = get_db_connection_from_pool(
current_app.config['DB_CONNECTION_POOL']
global_context.db = connect_to_db(
current_app.config['DB_CONNECTION_CONFIG']
)
if 'account_id' in global_context:
account_id = global_context.account_id
@ -78,12 +74,3 @@ def add_api_metric(http_status):
)
metric_repository = ApiMetricsRepository(global_context.db)
metric_repository.add(api_metric)
def release_db_connection():
db = global_context.pop('db', None)
if db is not None:
return_db_connection_to_pool(
current_app.config['DB_CONNECTION_POOL'],
db
)

View File

@ -3,7 +3,7 @@ from dataclasses import asdict
from http import HTTPStatus
from selene.data.account import AgreementRepository
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
from ..base_endpoint import SeleneEndpoint
@ -16,14 +16,14 @@ class AgreementsEndpoint(SeleneEndpoint):
def get(self, agreement_type):
"""Process HTTP GET request for an agreement."""
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
agreement_repository = AgreementRepository(db)
agreement = agreement_repository.get_active_for_type(
self.agreement_types[agreement_type]
)
if agreement is not None:
agreement = asdict(agreement)
del(agreement['effective_date'])
self.response = agreement, HTTPStatus.OK
db = connect_to_db(self.config['DB_CONNECTION_CONFIG'])
agreement_repository = AgreementRepository(db)
agreement = agreement_repository.get_active_for_type(
self.agreement_types[agreement_type]
)
if agreement is not None:
agreement = asdict(agreement)
del(agreement['effective_date'])
self.response = agreement, HTTPStatus.OK
return self.response

View File

@ -3,7 +3,7 @@ import string
from selene.data.device import DeviceRepository
from selene.util.cache import SeleneCache
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
def device_etag_key(device_id: str):
@ -29,7 +29,7 @@ class ETagManager(object):
def __init__(self, cache: SeleneCache, config: dict):
self.cache: SeleneCache = cache
self.db_connection_pool = config['DB_CONNECTION_POOL']
self.db_connection_config = config['DB_CONNECTION_CONFIG']
def get(self, key: str) -> str:
"""Generate a etag with 32 random chars and store it into a given key
@ -60,10 +60,10 @@ class ETagManager(object):
def expire_device_setting_etag_by_account_id(self, account_id: str):
"""Expire the settings' etags for all devices from a given account. Used when the settings are updated
at account level"""
with get_db_connection(self.db_connection_pool) as db:
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices:
self.expire_device_setting_etag_by_device_id(device.id)
db = connect_to_db(self.db_connection_config)
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices:
self.expire_device_setting_etag_by_device_id(device.id)
def expire_device_location_etag_by_device_id(self, device_id: str):
"""Expire the etag associate with the device's location entity
@ -73,10 +73,10 @@ class ETagManager(object):
def expire_device_location_etag_by_account_id(self, account_id: str):
"""Expire the locations' etag fpr açç device for a given acccount
:param account_id: account uuid"""
with get_db_connection(self.db_connection_pool) as db:
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices:
self.expire_device_location_etag_by_device_id(device.id)
db = connect_to_db(self.db_connection_config)
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices:
self.expire_device_location_etag_by_device_id(device.id)
def expire_skill_etag_by_device_id(self, device_id):
"""Expire the locations' etag for a given device
@ -84,7 +84,7 @@ class ETagManager(object):
self._expire(device_skill_etag_key(device_id))
def expire_skill_etag_by_account_id(self, account_id):
with get_db_connection(self.db_connection_pool) as db:
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices:
self.expire_skill_etag_by_device_id(device.id)
db = connect_to_db(self.db_connection_config)
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
for device in devices:
self.expire_skill_etag_by_device_id(device.id)

View File

@ -13,7 +13,7 @@ from flask.views import MethodView
from selene.api.etag import ETagManager
from selene.util.auth import AuthenticationError
from selene.util.db import get_db_connection_from_pool
from selene.util.db import connect_to_db
from selene.util.not_modified import NotModifiedError
from ..util.cache import SeleneCache
@ -91,8 +91,8 @@ class PublicEndpoint(MethodView):
@property
def db(self):
if 'db' not in global_context:
global_context.db = get_db_connection_from_pool(
current_app.config['DB_CONNECTION_POOL']
global_context.db = connect_to_db(
current_app.config['DB_CONNECTION_CONFIG']
)
return global_context.db

View File

@ -2,7 +2,7 @@ from hamcrest import assert_that, equal_to, has_item
from selene.data.account import Account, AccountRepository
from selene.util.auth import AuthenticationToken
from selene.util.db import get_db_connection
from selene.util.db import connect_to_db
ACCESS_TOKEN_COOKIE_KEY = 'seleneAccess'
ONE_MINUTE = 60
@ -77,8 +77,8 @@ def _parse_cookie(cookie: str) -> dict:
def get_account(context) -> Account:
with get_db_connection(context.db_pool) as db:
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_id(context.account.id)
db = connect_to_db(context.client['DB_CONNECTION_CONFIG'])
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_id(context.account.id)
return account

View File

@ -12,7 +12,7 @@ class SkillDisplayRepository(RepositoryBase):
sql_file_name='get_display_data_for_skills.sql'
)
def get_display_data_for_skill(self, skill_display_id):
def get_display_data_for_skill(self, skill_display_id) -> SkillDisplay:
return self._select_one_into_dataclass(
dataclass=SkillDisplay,
sql_file_name='get_display_data_for_skill.sql',

View File

@ -34,12 +34,12 @@ class SkillSettingRepository(RepositoryBase):
return skill_settings
def get_installer_settings(self, account_id: str):
def get_installer_settings(self, account_id: str) -> List[AccountSkillSetting]:
skill_repo = SkillRepository(self.db)
skills = skill_repo.get_skills_for_account(account_id)
installer_skill_id = None
for skill in skills:
if skill.name == 'mycroft_installer':
if skill.display_name == 'Installer':
installer_skill_id = skill.id
skill_settings = None

View File

@ -6,12 +6,12 @@ Example Usage:
query_result = mycroft_db_ro.execute_sql(sql)
"""
from contextlib import contextmanager
from dataclasses import dataclass, field
from dataclasses import dataclass, field, InitVar
from logging import getLogger
from psycopg2 import connect
from psycopg2.extras import RealDictCursor
from psycopg2.extras import RealDictCursor, NamedTupleCursor
from psycopg2.extensions import cursor
_log = getLogger(__package__)
@ -29,10 +29,16 @@ class DatabaseConnectionConfig(object):
password: str
port: int = field(default=5432)
sslmode: str = None
autocommit: str = True
cursor_factory = RealDictCursor
use_namedtuple_cursor: InitVar[bool] = False
def __post_init__(self, use_namedtuple_cursor: bool):
if use_namedtuple_cursor:
self.cursor_factory = NamedTupleCursor
@contextmanager
def connect_to_db(connection_config: DatabaseConnectionConfig, autocommit=True):
def connect_to_db(connection_config: DatabaseConnectionConfig):
"""
Return a connection to the mycroft database for the specified user.
@ -41,33 +47,19 @@ def connect_to_db(connection_config: DatabaseConnectionConfig, autocommit=True):
python notebook)
:param connection_config: data needed to establish a connection
:param autocommit: indicated if transactions should commit automatically
:return: database connection
"""
db = None
log_msg = 'establishing connection to the {db_name} database'
_log.info(log_msg.format(db_name=connection_config.db_name))
try:
if connection_config.sslmode is None:
db = connect(
host=connection_config.host,
dbname=connection_config.db_name,
user=connection_config.user,
port=connection_config.port,
cursor_factory=RealDictCursor,
)
else:
db = connect(
host=connection_config.host,
dbname=connection_config.db_name,
user=connection_config.user,
password=connection_config.password,
port=connection_config.port,
cursor_factory=RealDictCursor,
sslmode=connection_config.sslmode
)
db.autocommit = autocommit
yield db
finally:
if db is not None:
db.close()
db = connect(
host=connection_config.host,
dbname=connection_config.db_name,
user=connection_config.user,
password=connection_config.password,
port=connection_config.port,
cursor_factory=connection_config.cursor_factory,
sslmode=connection_config.sslmode
)
db.autocommit = connection_config.autocommit
return db