Merge remote-tracking branch 'origin/dev' into dev
commit
06b8bd4eb1
|
@ -15,7 +15,7 @@ from selene.data.account import (
|
|||
)
|
||||
from selene.data.device import Geography, GeographyRepository
|
||||
from selene.util.cache import SeleneCache
|
||||
from selene.util.db import get_db_connection
|
||||
from selene.util.db import connect_to_db
|
||||
|
||||
|
||||
@fixture
|
||||
|
@ -32,10 +32,10 @@ def before_feature(context, _):
|
|||
|
||||
|
||||
def before_scenario(context, _):
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
_add_agreements(context, db)
|
||||
_add_account(context, db)
|
||||
_add_geography(context, db)
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
_add_agreements(context, db)
|
||||
_add_account(context, db)
|
||||
_add_geography(context, db)
|
||||
|
||||
|
||||
def _add_agreements(context, db):
|
||||
|
@ -91,9 +91,9 @@ def _add_geography(context, db):
|
|||
|
||||
|
||||
def after_scenario(context, _):
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
_delete_account(context, db)
|
||||
_delete_agreements(context, db)
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
_delete_account(context, db)
|
||||
_delete_agreements(context, db)
|
||||
_clean_cache()
|
||||
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from hamcrest import assert_that, equal_to, has_key, none, not_none
|
|||
|
||||
from selene.data.device import DeviceRepository
|
||||
from selene.util.cache import SeleneCache
|
||||
from selene.util.db import get_db_connection
|
||||
from selene.util.db import connect_to_db
|
||||
|
||||
|
||||
@given('a device pairing code')
|
||||
|
@ -57,9 +57,9 @@ def validate_pairing_code_removal(context):
|
|||
@then('the device is added to the database')
|
||||
def validate_response(context):
|
||||
device_id = context.response.data.decode()
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
device_repository = DeviceRepository(db)
|
||||
device = device_repository.get_device_by_id(device_id)
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
device_repository = DeviceRepository(db)
|
||||
device = device_repository.get_device_by_id(device_id)
|
||||
|
||||
assert_that(device, not_none())
|
||||
assert_that(device.name, equal_to('home'))
|
||||
|
|
|
@ -8,7 +8,7 @@ from selene.api.testing import (
|
|||
)
|
||||
from selene.data.account import AccountRepository
|
||||
from selene.util.auth import AuthenticationToken
|
||||
from selene.util.db import get_db_connection
|
||||
from selene.util.db import connect_to_db
|
||||
|
||||
|
||||
@given('an authenticated user with an expired access token')
|
||||
|
@ -37,9 +37,9 @@ def check_for_new_cookies(context):
|
|||
context.refresh_token,
|
||||
is_not(equal_to(context.old_refresh_token))
|
||||
)
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
acct_repository = AccountRepository(db)
|
||||
account = acct_repository.get_account_by_id(context.account.id)
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
acct_repository = AccountRepository(db)
|
||||
account = acct_repository.get_account_by_id(context.account.id)
|
||||
|
||||
refresh_token = AuthenticationToken(
|
||||
context.client_config['REFRESH_SECRET'],
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
from binascii import b2a_base64
|
||||
from datetime import date
|
||||
import os
|
||||
|
||||
import stripe
|
||||
from behave import given, then, when
|
||||
|
@ -8,7 +9,7 @@ from hamcrest import assert_that, equal_to, is_in, none, not_none, starts_with
|
|||
from stripe.error import InvalidRequestError
|
||||
|
||||
from selene.data.account import AccountRepository, PRIVACY_POLICY, TERMS_OF_USE
|
||||
from selene.util.db import get_db_connection
|
||||
from selene.util.db import connect_to_db
|
||||
|
||||
new_account_request = dict(
|
||||
username='barfoo',
|
||||
|
@ -17,8 +18,8 @@ new_account_request = dict(
|
|||
login=dict(
|
||||
federatedPlatform=None,
|
||||
federatedToken=None,
|
||||
userEnteredEmail='bar@mycroft.ai',
|
||||
password='bar'
|
||||
userEnteredEmail=b2a_base64(b'bar@mycroft.ai').decode(),
|
||||
password=b2a_base64(b'bar').decode()
|
||||
),
|
||||
support=dict(openDataset=True)
|
||||
)
|
||||
|
@ -58,41 +59,41 @@ def call_add_account_endpoint(context):
|
|||
context.response = context.client.post(
|
||||
'/api/account',
|
||||
data=json.dumps(context.new_account_request),
|
||||
content_type='application_json'
|
||||
content_type='application/json'
|
||||
)
|
||||
|
||||
|
||||
@then('the account will be added to the system {membership_option}')
|
||||
def check_db_for_account(context, membership_option):
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
acct_repository = AccountRepository(db)
|
||||
account = acct_repository.get_account_by_email('bar@mycroft.ai')
|
||||
assert_that(account, not_none())
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
acct_repository = AccountRepository(db)
|
||||
account = acct_repository.get_account_by_email('bar@mycroft.ai')
|
||||
assert_that(account, not_none())
|
||||
assert_that(
|
||||
account.email_address, equal_to('bar@mycroft.ai')
|
||||
)
|
||||
assert_that(account.username, equal_to('barfoo'))
|
||||
if membership_option == 'with a membership':
|
||||
assert_that(account.membership.type, equal_to('Monthly Membership'))
|
||||
assert_that(
|
||||
account.email_address, equal_to('bar@mycroft.ai')
|
||||
account.membership.payment_account_id,
|
||||
starts_with('cus')
|
||||
)
|
||||
assert_that(account.username, equal_to('barfoo'))
|
||||
if membership_option == 'with a membership':
|
||||
assert_that(account.membership.type, equal_to('Monthly Membership'))
|
||||
assert_that(
|
||||
account.membership.payment_account_id,
|
||||
starts_with('cus')
|
||||
)
|
||||
elif membership_option == 'without a membership':
|
||||
assert_that(account.membership, none())
|
||||
elif membership_option == 'without a membership':
|
||||
assert_that(account.membership, none())
|
||||
|
||||
assert_that(len(account.agreements), equal_to(2))
|
||||
for agreement in account.agreements:
|
||||
assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE)))
|
||||
assert_that(agreement.accept_date, equal_to(str(date.today())))
|
||||
assert_that(len(account.agreements), equal_to(2))
|
||||
for agreement in account.agreements:
|
||||
assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE)))
|
||||
assert_that(agreement.accept_date, equal_to(str(date.today())))
|
||||
|
||||
|
||||
@when('the account is deleted')
|
||||
def account_deleted(context):
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
acct_repository = AccountRepository(db)
|
||||
account = acct_repository.get_account_by_email('bar@mycroft.ai')
|
||||
context.stripe_id = account.membership.payment_id
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
acct_repository = AccountRepository(db)
|
||||
account = acct_repository.get_account_by_email('bar@mycroft.ai')
|
||||
context.stripe_id = account.membership.payment_id
|
||||
context.response = context.client.delete('/api/account')
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
from binascii import b2a_base64
|
||||
from datetime import date
|
||||
import json
|
||||
|
||||
from behave import given, when, then
|
||||
from hamcrest import assert_that, equal_to, starts_with, none
|
||||
|
@ -11,7 +12,9 @@ from selene.data.account import (
|
|||
AccountAgreement,
|
||||
PRIVACY_POLICY
|
||||
)
|
||||
from selene.util.db import get_db_connection
|
||||
from selene.util.db import connect_to_db
|
||||
|
||||
TEST_EMAIL_ADDRESS = 'test@mycroft.ai'
|
||||
|
||||
new_account_request = dict(
|
||||
username='test',
|
||||
|
@ -20,8 +23,8 @@ new_account_request = dict(
|
|||
login=dict(
|
||||
federatedPlatform=None,
|
||||
federatedToken=None,
|
||||
userEnteredEmail='test@mycroft.ai',
|
||||
password='12345678'
|
||||
email=b2a_base64(b'test@mycroft.ai').decode(),
|
||||
password=b2a_base64(b'12345678').decode()
|
||||
),
|
||||
support=dict(
|
||||
openDataset=True,
|
||||
|
@ -47,12 +50,12 @@ def create_account(context):
|
|||
AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today())
|
||||
]
|
||||
)
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
acct_repository = AccountRepository(db)
|
||||
account_id = acct_repository.add(context.account, 'foo')
|
||||
context.account.id = account_id
|
||||
generate_access_token(context)
|
||||
generate_refresh_token(context)
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
acct_repository = AccountRepository(db)
|
||||
account_id = acct_repository.add(context.account, 'foo')
|
||||
context.account.id = account_id
|
||||
generate_access_token(context)
|
||||
generate_refresh_token(context)
|
||||
|
||||
|
||||
@when('a monthly membership is added')
|
||||
|
@ -66,16 +69,16 @@ def update_membership(context):
|
|||
context.response = context.client.patch(
|
||||
'/api/account',
|
||||
data=json.dumps(dict(membership=membership_data)),
|
||||
content_type='application_json'
|
||||
content_type='application/json'
|
||||
)
|
||||
|
||||
|
||||
@when('the account is requested')
|
||||
def request_account(context):
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
context.response_account = AccountRepository(db).get_account_by_email(
|
||||
'test@mycroft.ai'
|
||||
)
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
context.response_account = AccountRepository(db).get_account_by_email(
|
||||
TEST_EMAIL_ADDRESS
|
||||
)
|
||||
|
||||
|
||||
@then('the account should have a monthly membership')
|
||||
|
@ -98,16 +101,14 @@ def create_monthly_account(context):
|
|||
context.client.post(
|
||||
'/api/account',
|
||||
data=json.dumps(new_account_request),
|
||||
content_type='application_json'
|
||||
content_type='application/json'
|
||||
)
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
account_repository = AccountRepository(db)
|
||||
account = account_repository.get_account_by_email(
|
||||
new_account_request['login']['userEnteredEmail']
|
||||
)
|
||||
context.account = account
|
||||
generate_access_token(context)
|
||||
generate_refresh_token(context)
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
account_repository = AccountRepository(db)
|
||||
account = account_repository.get_account_by_email(TEST_EMAIL_ADDRESS)
|
||||
context.account = account
|
||||
generate_access_token(context)
|
||||
generate_refresh_token(context)
|
||||
|
||||
|
||||
@when('the membership is cancelled')
|
||||
|
@ -119,7 +120,7 @@ def cancel_membership(context):
|
|||
context.client.patch(
|
||||
'/api/account',
|
||||
data=json.dumps(dict(membership=membership_data)),
|
||||
content_type='application_json'
|
||||
content_type='application/json'
|
||||
)
|
||||
|
||||
|
||||
|
@ -138,7 +139,7 @@ def change_to_yearly_account(context):
|
|||
context.client.patch(
|
||||
'/api/account',
|
||||
data=json.dumps(dict(membership=membership_data)),
|
||||
content_type='application_json'
|
||||
content_type='application/json'
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
"""
|
||||
Marketplace endpoint to add or remove a skill
|
||||
|
||||
This endpoint configures the install skill on a user's device(s) to add or
|
||||
remove the skill.
|
||||
"""
|
||||
from http import HTTPStatus
|
||||
from logging import getLogger
|
||||
from typing import List
|
||||
|
@ -6,7 +12,11 @@ from schematics import Model
|
|||
from schematics.types import StringType
|
||||
|
||||
from selene.api import SeleneEndpoint
|
||||
from selene.data.skill import AccountSkillSetting, SkillSettingRepository
|
||||
from selene.data.skill import (
|
||||
AccountSkillSetting,
|
||||
SkillDisplayRepository,
|
||||
SkillSettingRepository
|
||||
)
|
||||
|
||||
INSTALL_SECTION = 'to_install'
|
||||
UNINSTALL_SECTION = 'to_remove'
|
||||
|
@ -15,17 +25,16 @@ _log = getLogger(__package__)
|
|||
|
||||
|
||||
class InstallRequest(Model):
|
||||
"""Defines the expected state of the request JSON data"""
|
||||
setting_section = StringType(
|
||||
required=True,
|
||||
choices=[INSTALL_SECTION, UNINSTALL_SECTION]
|
||||
)
|
||||
skill_name = StringType(required=True)
|
||||
skill_display_id = StringType(required=True)
|
||||
|
||||
|
||||
class SkillInstallEndpoint(SeleneEndpoint):
|
||||
"""
|
||||
Install a skill on user device(s).
|
||||
"""
|
||||
"""Install a skill on user device(s)."""
|
||||
def __init__(self):
|
||||
super(SkillInstallEndpoint, self).__init__()
|
||||
self.device_uuid: str = None
|
||||
|
@ -34,43 +43,70 @@ class SkillInstallEndpoint(SeleneEndpoint):
|
|||
self.installer_update_response = None
|
||||
|
||||
def put(self):
|
||||
"""Handle an HTTP PUT request"""
|
||||
self._authenticate()
|
||||
self._validate_request()
|
||||
skill_install_name = self._get_install_name()
|
||||
self._get_installer_settings()
|
||||
self._apply_update()
|
||||
self._apply_update(skill_install_name)
|
||||
self.response = (self.installer_update_response, HTTPStatus.OK)
|
||||
|
||||
return self.response
|
||||
|
||||
def _validate_request(self):
|
||||
"""Ensure the data passed in the request is as expected.
|
||||
|
||||
:raises schematics.exceptions.ValidationError if the validation fails
|
||||
"""
|
||||
install_request = InstallRequest()
|
||||
install_request.setting_section = self.request.json['section']
|
||||
install_request.skill_name = self.request.json['skillName']
|
||||
install_request.skill_display_id = self.request.json['skillDisplayId']
|
||||
install_request.validate()
|
||||
|
||||
def _get_install_name(self) -> str:
|
||||
"""Get the skill name used by the installer skill from the DB
|
||||
|
||||
The installer skill expects the skill name found in the "name" field
|
||||
of the skill display JSON.
|
||||
"""
|
||||
display_repo = SkillDisplayRepository(self.db)
|
||||
skill_display = display_repo.get_display_data_for_skill(
|
||||
self.request.json['skillDisplayId']
|
||||
)
|
||||
|
||||
return skill_display.display_data['name']
|
||||
|
||||
def _get_installer_settings(self):
|
||||
"""Get the current value of the installer skill's settings"""
|
||||
settings_repo = SkillSettingRepository(self.db)
|
||||
self.installer_settings = settings_repo.get_installer_settings(
|
||||
self.account.id
|
||||
)
|
||||
|
||||
def _apply_update(self):
|
||||
def _apply_update(self, skill_install_name: str):
|
||||
"""Add the skill in the request to the installer skill settings.
|
||||
|
||||
This is designed to change the installer skill settings for all
|
||||
devices associated with an account. It will be updated in the
|
||||
future to target specific devices.
|
||||
"""
|
||||
for settings in self.installer_settings:
|
||||
if self.request.json['section'] == INSTALL_SECTION:
|
||||
to_install = settings.settings_values.get(INSTALL_SECTION, [])
|
||||
to_install.append(
|
||||
dict(name=self.request.json['skill_name'])
|
||||
dict(name=skill_install_name)
|
||||
)
|
||||
settings.settings_values[INSTALL_SECTION] = to_install
|
||||
else:
|
||||
to_remove = settings.settings_values.get(UNINSTALL_SECTION, [])
|
||||
to_remove.append(
|
||||
dict(name=self.request.json['skill_name'])
|
||||
dict(name=skill_install_name)
|
||||
)
|
||||
settings.settings_values[UNINSTALL_SECTION] = to_remove
|
||||
self._update_skill_settings(settings)
|
||||
|
||||
def _update_skill_settings(self, settings):
|
||||
"""Update the DB with the new installer skill settings."""
|
||||
settings_repo = SkillSettingRepository(self.db)
|
||||
settings_repo.update_device_skill_settings(
|
||||
self.account.id,
|
||||
|
|
|
@ -28,6 +28,7 @@ class DeviceEndpoint(PublicEndpoint):
|
|||
|
||||
if device is not None:
|
||||
response_data = dict(
|
||||
uuid=device.id,
|
||||
name=device.name,
|
||||
description=device.placement,
|
||||
coreVersion=device.core_version,
|
||||
|
|
|
@ -18,7 +18,7 @@ class MetricsService(object):
|
|||
deviceUuid=device_id,
|
||||
data=data
|
||||
)
|
||||
url = '{host}/{metric}'.format(host=self.metrics_service_host, metric=metric)
|
||||
url = '{host}/metric/{metric}'.format(host=self.metrics_service_host, metric=metric)
|
||||
requests.post(url, body)
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ import requests
|
|||
|
||||
from selene.api import PublicEndpoint
|
||||
from selene.data.account import AccountRepository
|
||||
from selene.util.db import get_db_connection
|
||||
|
||||
|
||||
class OauthServiceEndpoint(PublicEndpoint):
|
||||
|
@ -14,8 +13,7 @@ class OauthServiceEndpoint(PublicEndpoint):
|
|||
self.oauth_service_host = os.environ['OAUTH_BASE_URL']
|
||||
|
||||
def get(self, device_id, credentials, oauth_path):
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
account = AccountRepository(db).get_account_by_device_id(device_id)
|
||||
account = AccountRepository(self.db).get_account_by_device_id(device_id)
|
||||
uuid = account.id
|
||||
url = '{host}/auth/{credentials}/{oauth_path}'.format(
|
||||
host=self.oauth_service_host,
|
||||
|
|
|
@ -22,10 +22,20 @@ class DeviceRefreshTokenEndpoint(PublicEndpoint):
|
|||
if token_header.startswith('Bearer '):
|
||||
refresh = token_header[len('Bearer '):]
|
||||
session = self._refresh_session_token(refresh)
|
||||
# Trying to fetch a session using the refresh token
|
||||
if session:
|
||||
response = session, HTTPStatus.OK
|
||||
else:
|
||||
response = '', HTTPStatus.UNAUTHORIZED
|
||||
device = self.request.headers.get('Device')
|
||||
if device:
|
||||
# trying to fetch a session using the device uuid
|
||||
session = self._refresh_session_token_device(device)
|
||||
if session:
|
||||
response = session, HTTPStatus.OK
|
||||
else:
|
||||
response = '', HTTPStatus.UNAUTHORIZED
|
||||
else:
|
||||
response = '', HTTPStatus.UNAUTHORIZED
|
||||
else:
|
||||
response = '', HTTPStatus.UNAUTHORIZED
|
||||
return response
|
||||
|
@ -38,3 +48,12 @@ class DeviceRefreshTokenEndpoint(PublicEndpoint):
|
|||
device_id = old_login['uuid']
|
||||
self.cache.delete(refresh_key)
|
||||
return generate_device_login(device_id, self.cache)
|
||||
|
||||
def _refresh_session_token_device(self, device: str):
|
||||
refresh_key = 'device.session:{}'.format(device)
|
||||
session = self.cache.get(refresh_key)
|
||||
if session:
|
||||
old_login = json.loads(session)
|
||||
device_id = old_login['uuid']
|
||||
self.cache.delete(refresh_key)
|
||||
return generate_device_login(device_id, self.cache)
|
||||
|
|
|
@ -27,6 +27,7 @@ class SkillManifest(Model):
|
|||
installed = DateTimeType()
|
||||
updated = DateTimeType()
|
||||
update = DateTimeType()
|
||||
skill_gid = StringType()
|
||||
|
||||
|
||||
class SkillJson(Model):
|
||||
|
|
|
@ -21,7 +21,7 @@ from selene.data.device import (
|
|||
Geography)
|
||||
from selene.data.device.entity.text_to_speech import TextToSpeech
|
||||
from selene.data.device.entity.wake_word import WakeWord
|
||||
from selene.util.db import get_db_connection
|
||||
from selene.util.db import connect_to_db
|
||||
|
||||
account = Account(
|
||||
email_address='test@test.com',
|
||||
|
@ -63,22 +63,22 @@ def before_feature(context, _):
|
|||
def before_scenario(context, _):
|
||||
cache = context.client_config['SELENE_CACHE']
|
||||
context.etag_manager = ETagManager(cache, context.client_config)
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
try:
|
||||
_add_agreements(context, db)
|
||||
_add_account(context, db)
|
||||
_add_account_preference(context, db)
|
||||
_add_geography(context, db)
|
||||
_add_device(context, db)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(traceback.print_exc())
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
try:
|
||||
_add_agreements(context, db)
|
||||
_add_account(context, db)
|
||||
_add_account_preference(context, db)
|
||||
_add_geography(context, db)
|
||||
_add_device(context, db)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(traceback.print_exc())
|
||||
|
||||
|
||||
def after_scenario(context, _):
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
_remove_account(context, db)
|
||||
_remove_agreements(context, db)
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
_remove_account(context, db)
|
||||
_remove_agreements(context, db)
|
||||
|
||||
|
||||
def _add_agreements(context, db):
|
||||
|
|
|
@ -17,7 +17,8 @@ skill_manifest = {
|
|||
"installed": datetime.now().timestamp(),
|
||||
"updated": datetime.now().timestamp(),
|
||||
"installation": "installed",
|
||||
"update": 0
|
||||
"update": 0,
|
||||
"skill_gid": "fallback-wolfram-alpha|19.02",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ from hamcrest import assert_that, equal_to, not_none, is_not
|
|||
|
||||
from selene.api.etag import ETagManager, device_skill_etag_key
|
||||
from selene.data.skill import SkillSettingRepository
|
||||
from selene.util.db import get_db_connection
|
||||
from selene.util.db import connect_to_db
|
||||
|
||||
skill = {
|
||||
'skill_gid': 'wolfram-alpha|19.02',
|
||||
|
@ -106,8 +106,8 @@ def update_skill(context):
|
|||
}]
|
||||
response = json.loads(context.upload_device_response.data)
|
||||
skill_id = response['uuid']
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
SkillSettingRepository(db).update_device_skill_settings(skill_id, update_settings)
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
SkillSettingRepository(db).update_device_skill_settings(skill_id, update_settings)
|
||||
|
||||
|
||||
@when('the skill settings is fetched')
|
||||
|
|
|
@ -31,6 +31,7 @@ def validate_response(context):
|
|||
response = context.get_device_response
|
||||
assert_that(response.status_code, equal_to(HTTPStatus.OK))
|
||||
device = json.loads(response.data)
|
||||
assert_that(device, has_key('uuid'))
|
||||
assert_that(device, has_key('name'))
|
||||
assert_that(device, has_key('description'))
|
||||
assert_that(device, has_key('coreVersion'))
|
||||
|
|
|
@ -7,7 +7,7 @@ from behave import when, then
|
|||
from hamcrest import assert_that, has_entry, equal_to
|
||||
|
||||
from selene.data.account import AccountRepository, AccountMembership
|
||||
from selene.util.db import get_db_connection
|
||||
from selene.util.db import connect_to_db
|
||||
|
||||
|
||||
@when('the subscription endpoint is called')
|
||||
|
@ -42,9 +42,9 @@ def get_device_subscription(context):
|
|||
login = context.device_login
|
||||
device_id = login['uuid']
|
||||
access_token = login['accessToken']
|
||||
headers=dict(Authorization='Bearer {token}'.format(token=access_token))
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
AccountRepository(db).add_membership(context.account.id, membership)
|
||||
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
AccountRepository(db).add_membership(context.account.id, membership)
|
||||
context.subscription_response = context.client.get(
|
||||
'/v1/device/{uuid}/subscription'.format(uuid=device_id),
|
||||
headers=headers
|
||||
|
|
|
@ -13,7 +13,7 @@ from selene.data.account import (
|
|||
AgreementRepository,
|
||||
PRIVACY_POLICY
|
||||
)
|
||||
from selene.util.db import get_db_connection
|
||||
from selene.util.db import connect_to_db
|
||||
|
||||
|
||||
@fixture
|
||||
|
@ -32,9 +32,9 @@ def before_feature(context, _):
|
|||
|
||||
|
||||
def before_scenario(context, _):
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
_add_agreement(context, db)
|
||||
_add_account(context, db)
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
_add_agreement(context, db)
|
||||
_add_account(context, db)
|
||||
|
||||
|
||||
def _add_agreement(context, db):
|
||||
|
@ -72,8 +72,8 @@ def _add_account(context, db):
|
|||
|
||||
|
||||
def after_scenario(context, _):
|
||||
with get_db_connection(context.db_pool) as db:
|
||||
acct_repository = AccountRepository(db)
|
||||
acct_repository.remove(context.account)
|
||||
agreement_repository = AgreementRepository(db)
|
||||
agreement_repository.remove(context.agreement, testing=True)
|
||||
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
|
||||
acct_repository = AccountRepository(db)
|
||||
acct_repository.remove(context.account)
|
||||
agreement_repository = AgreementRepository(db)
|
||||
agreement_repository.remove(context.agreement, testing=True)
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
CREATE TABLE metrics.job (
|
||||
id uuid PRIMARY KEY
|
||||
DEFAULT gen_random_uuid(),
|
||||
job_name text NOT NULL,
|
||||
batch_date date NOT NULL,
|
||||
start_ts TIMESTAMP NOT NULL,
|
||||
end_ts TIMESTAMP NOT NULL,
|
||||
command text NOT NULL,
|
||||
success BOOLEAN NOT NULL,
|
||||
UNIQUE (job_name, start_ts)
|
||||
)
|
|
@ -1,6 +1,7 @@
|
|||
from glob import glob
|
||||
from os import path
|
||||
from os import environ, path, remove
|
||||
|
||||
from markdown import markdown
|
||||
from psycopg2 import connect
|
||||
|
||||
MYCROFT_DB_DIR = path.join(path.abspath('..'), 'mycroft')
|
||||
|
@ -8,10 +9,8 @@ SCHEMAS = ('account', 'skill', 'device', 'geography', 'metrics')
|
|||
DB_DESTROY_FILES = (
|
||||
'drop_mycroft_db.sql',
|
||||
'drop_template_db.sql',
|
||||
# 'drop_roles.sql'
|
||||
)
|
||||
DB_CREATE_FILES = (
|
||||
# 'create_roles.sql',
|
||||
'create_template_db.sql',
|
||||
)
|
||||
ACCOUNT_TABLE_ORDER = (
|
||||
|
@ -48,6 +47,7 @@ GEOGRAPHY_TABLE_ORDER = (
|
|||
|
||||
METRICS_TABLE_ORDER = (
|
||||
'api',
|
||||
'job'
|
||||
)
|
||||
|
||||
schema_directory = '{}_schema'
|
||||
|
@ -61,32 +61,40 @@ def get_sql_from_file(file_path: str) -> str:
|
|||
|
||||
|
||||
class PostgresDB(object):
|
||||
def __init__(self, dbname, user, password=None):
|
||||
self.db = connect(dbname=dbname, user=user, host='127.0.0.1')
|
||||
# self.db = connect(
|
||||
# dbname=dbname,
|
||||
# user=user,
|
||||
# password=password,
|
||||
# host='selene-test-db-do-user-1412453-0.db.ondigitalocean.com',
|
||||
# port=25060,
|
||||
# sslmode='require'
|
||||
# )
|
||||
def __init__(self, db_name, user=None):
|
||||
db_host = environ['DB_HOST']
|
||||
db_port = environ['DB_PORT']
|
||||
db_ssl_mode = environ.get('DB_SSL_MODE')
|
||||
if db_name in ('postgres', 'defaultdb'):
|
||||
db_user = environ['POSTGRES_DB_USER']
|
||||
db_password = environ.get('POSTGRES_DB_PASSWORD')
|
||||
else:
|
||||
db_user = environ['MYCROFT_DB_USER']
|
||||
db_password = environ['MYCROFT_DB_PASSWORD']
|
||||
|
||||
if user is not None:
|
||||
db_user = user
|
||||
|
||||
self.db = connect(
|
||||
dbname=db_name,
|
||||
user=db_user,
|
||||
password=db_password,
|
||||
host=db_host,
|
||||
port=db_port,
|
||||
sslmode=db_ssl_mode
|
||||
)
|
||||
self.db.autocommit = True
|
||||
|
||||
def close_db(self):
|
||||
self.db.close()
|
||||
|
||||
def execute_sql(self, sql: str):
|
||||
def execute_sql(self, sql: str, args=None):
|
||||
cursor = self.db.cursor()
|
||||
cursor.execute(sql)
|
||||
cursor.execute(sql, args)
|
||||
return cursor
|
||||
|
||||
|
||||
postgres_db = PostgresDB(dbname='postgres', user='postgres')
|
||||
# postgres_db = PostgresDB(
|
||||
# dbname='defaultdb',
|
||||
# user='doadmin',
|
||||
# password='l06tn0qi2bjhgcki'
|
||||
# )
|
||||
postgres_db = PostgresDB(db_name=environ['POSTGRES_DB_NAME'])
|
||||
|
||||
print('Destroying any objects we will be creating later.')
|
||||
for db_destroy_file in DB_DESTROY_FILES:
|
||||
|
@ -94,7 +102,7 @@ for db_destroy_file in DB_DESTROY_FILES:
|
|||
get_sql_from_file(db_destroy_file)
|
||||
)
|
||||
|
||||
print('Creating the extensions, mycroft database, and selene roles')
|
||||
print('Creating the mycroft database')
|
||||
for db_setup_file in DB_CREATE_FILES:
|
||||
postgres_db.execute_sql(
|
||||
get_sql_from_file(db_setup_file)
|
||||
|
@ -102,13 +110,10 @@ for db_setup_file in DB_CREATE_FILES:
|
|||
|
||||
postgres_db.close_db()
|
||||
|
||||
template_db = PostgresDB(dbname='mycroft_template', user='mycroft')
|
||||
# template_db = PostgresDB(
|
||||
# dbname='mycroft_template',
|
||||
# user='selene',
|
||||
# password='ubhemhx1dikmqc5f'
|
||||
# )
|
||||
|
||||
template_db = PostgresDB(db_name='mycroft_template')
|
||||
|
||||
print('Creating the extensions')
|
||||
template_db.execute_sql(
|
||||
get_sql_from_file(path.join('create_extensions.sql'))
|
||||
)
|
||||
|
@ -193,22 +198,14 @@ for schema in SCHEMAS:
|
|||
|
||||
template_db.close_db()
|
||||
|
||||
|
||||
print('Copying template to new database.')
|
||||
postgres_db = PostgresDB(dbname='postgres', user='mycroft')
|
||||
# postgres_db = PostgresDB(
|
||||
# dbname='defaultdb',
|
||||
# user='doadmin',
|
||||
# password='l06tn0qi2bjhgcki'
|
||||
# )
|
||||
postgres_db = PostgresDB(db_name=environ['POSTGRES_DB_NAME'])
|
||||
postgres_db.execute_sql(get_sql_from_file('create_mycroft_db.sql'))
|
||||
postgres_db.close_db()
|
||||
|
||||
mycroft_db = PostgresDB(dbname='mycroft', user='mycroft')
|
||||
# mycroft_db = PostgresDB(
|
||||
# dbname='mycroft_template',
|
||||
# user='selene',
|
||||
# password='ubhemhx1dikmqc5f'
|
||||
# )
|
||||
|
||||
mycroft_db = PostgresDB(db_name=environ['MYCROFT_DB_NAME'])
|
||||
insert_files = [
|
||||
dict(schema_dir='account_schema', file_name='membership.sql'),
|
||||
dict(schema_dir='device_schema', file_name='text_to_speech.sql'),
|
||||
|
@ -226,3 +223,162 @@ for insert_file in insert_files:
|
|||
)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
print('Building account.agreement table')
|
||||
mycroft_db.db.autocommit = False
|
||||
insert_sql = (
|
||||
"insert into account.agreement VALUES (default, '{}', '1', '[today,]', {})"
|
||||
)
|
||||
doc_dir = '/Users/chrisveilleux/Mycroft/github/documentation/_pages/'
|
||||
docs = {
|
||||
'Privacy Policy': doc_dir + 'embed-privacy-policy.md',
|
||||
'Terms of Use': doc_dir + 'embed-terms-of-use.md'
|
||||
}
|
||||
try:
|
||||
for agrmt_type, doc_path in docs.items():
|
||||
lobj = mycroft_db.db.lobject(0, 'b')
|
||||
with open(doc_path) as doc:
|
||||
header_delimiter_count = 0
|
||||
while True:
|
||||
rec = doc.readline()
|
||||
if rec == '---\n':
|
||||
header_delimiter_count += 1
|
||||
if header_delimiter_count == 2:
|
||||
break
|
||||
doc_html = markdown(
|
||||
doc.read(),
|
||||
output_format='html5'
|
||||
)
|
||||
lobj.write(doc_html)
|
||||
mycroft_db.execute_sql(
|
||||
insert_sql.format(agrmt_type, lobj.oid)
|
||||
)
|
||||
mycroft_db.execute_sql(
|
||||
"grant select on large object {} to selene".format(lobj.oid)
|
||||
)
|
||||
mycroft_db.execute_sql(
|
||||
insert_sql.format('Open Dataset', 'null')
|
||||
)
|
||||
except:
|
||||
mycroft_db.db.rollback()
|
||||
raise
|
||||
else:
|
||||
mycroft_db.db.commit()
|
||||
|
||||
mycroft_db.db.autocommit = True
|
||||
|
||||
reference_file_dir = '/Users/chrisveilleux/Mycroft'
|
||||
|
||||
print('Building geography.country table')
|
||||
country_file = 'country.txt'
|
||||
country_insert = """
|
||||
INSERT INTO
|
||||
geography.country (iso_code, name)
|
||||
VALUES
|
||||
('{iso_code}', '{country_name}')
|
||||
"""
|
||||
|
||||
with open(path.join(reference_file_dir, country_file)) as countries:
|
||||
while True:
|
||||
rec = countries.readline()
|
||||
if rec.startswith('#ISO'):
|
||||
break
|
||||
|
||||
for country in countries.readlines():
|
||||
country_fields = country.split('\t')
|
||||
insert_args = dict(
|
||||
iso_code=country_fields[0],
|
||||
country_name=country_fields[4]
|
||||
)
|
||||
mycroft_db.execute_sql(country_insert.format(**insert_args))
|
||||
|
||||
print('Building geography.region table')
|
||||
region_file = 'regions.txt'
|
||||
region_insert = """
|
||||
INSERT INTO
|
||||
geography.region (country_id, region_code, name)
|
||||
VALUES
|
||||
(
|
||||
(SELECT id FROM geography.country WHERE iso_code = %(iso_code)s),
|
||||
%(region_code)s,
|
||||
%(region_name)s)
|
||||
"""
|
||||
with open(path.join(reference_file_dir, region_file)) as regions:
|
||||
for region in regions.readlines():
|
||||
region_fields = region.split('\t')
|
||||
country_iso_code = region_fields[0][:2]
|
||||
insert_args = dict(
|
||||
iso_code=country_iso_code,
|
||||
region_code=region_fields[0],
|
||||
region_name=region_fields[1]
|
||||
)
|
||||
mycroft_db.execute_sql(region_insert, insert_args)
|
||||
|
||||
print('Building geography.timezone table')
|
||||
timezone_file = 'timezones.txt'
|
||||
timezone_insert = """
|
||||
INSERT INTO
|
||||
geography.timezone (country_id, name, gmt_offset, dst_offset)
|
||||
VALUES
|
||||
(
|
||||
(SELECT id FROM geography.country WHERE iso_code = %(iso_code)s),
|
||||
%(timezone_name)s,
|
||||
%(gmt_offset)s,
|
||||
%(dst_offset)s
|
||||
)
|
||||
"""
|
||||
with open(path.join(reference_file_dir, timezone_file)) as timezones:
|
||||
timezones.readline()
|
||||
for timezone in timezones.readlines():
|
||||
timezone_fields = timezone.split('\t')
|
||||
insert_args = dict(
|
||||
iso_code=timezone_fields[0],
|
||||
timezone_name=timezone_fields[1],
|
||||
gmt_offset=timezone_fields[2],
|
||||
dst_offset=timezone_fields[3]
|
||||
)
|
||||
mycroft_db.execute_sql(timezone_insert, insert_args)
|
||||
|
||||
print('Building geography.city table')
|
||||
cities_file = 'cities500.txt'
|
||||
region_query = "SELECT id, region_code FROM geography.region"
|
||||
query_result = mycroft_db.execute_sql(region_query)
|
||||
region_lookup = dict()
|
||||
for row in query_result.fetchall():
|
||||
region_lookup[row[1]] = row[0]
|
||||
|
||||
timezone_query = "SELECT id, name FROM geography.timezone"
|
||||
query_result = mycroft_db.execute_sql(timezone_query)
|
||||
timezone_lookup = dict()
|
||||
for row in query_result.fetchall():
|
||||
timezone_lookup[row[1]] = row[0]
|
||||
# city_insert = """
|
||||
# INSERT INTO
|
||||
# geography.city (region_id, timezone_id, name, latitude, longitude)
|
||||
# VALUES
|
||||
# (%(region_id)s, %(timezone_id)s, %(city_name)s, %(latitude)s, %(longitude)s)
|
||||
# """
|
||||
with open(path.join(reference_file_dir, cities_file)) as cities:
|
||||
with open(path.join(reference_file_dir, 'city.dump'), 'w') as dump_file:
|
||||
for city in cities.readlines():
|
||||
city_fields = city.split('\t')
|
||||
city_region = city_fields[8] + '.' + city_fields[10]
|
||||
region_id = region_lookup.get(city_region)
|
||||
timezone_id = timezone_lookup[city_fields[17]]
|
||||
if region_id is not None:
|
||||
dump_file.write('\t'.join([
|
||||
region_id,
|
||||
timezone_id,
|
||||
city_fields[1],
|
||||
city_fields[4],
|
||||
city_fields[5]
|
||||
]) + '\n')
|
||||
# mycroft_db.execute_sql(city_insert, insert_args)
|
||||
with open(path.join(reference_file_dir, 'city.dump')) as dump_file:
|
||||
cursor = mycroft_db.db.cursor()
|
||||
cursor.copy_from(dump_file, 'geography.city', columns=(
|
||||
'region_id', 'timezone_id', 'name', 'latitude', 'longitude')
|
||||
)
|
||||
remove(path.join(reference_file_dir, 'city.dump'))
|
||||
|
||||
mycroft_db.close_db()
|
||||
|
|
|
@ -22,15 +22,6 @@ import os
|
|||
|
||||
from selene.util.db import allocate_db_connection_pool, DatabaseConnectionConfig
|
||||
|
||||
db_connection_config = DatabaseConnectionConfig(
|
||||
host=os.environ['DB_HOST'],
|
||||
db_name=os.environ['DB_NAME'],
|
||||
password=os.environ['DB_PASSWORD'],
|
||||
port=os.environ.get('DB_PORT', 5432),
|
||||
user=os.environ['DB_USER'],
|
||||
sslmode=os.environ.get('DB_SSLMODE')
|
||||
)
|
||||
|
||||
|
||||
class APIConfigError(Exception):
|
||||
pass
|
||||
|
@ -43,6 +34,14 @@ class BaseConfig(object):
|
|||
DEBUG = False
|
||||
ENV = os.environ['SELENE_ENVIRONMENT']
|
||||
REFRESH_SECRET = os.environ['JWT_REFRESH_SECRET']
|
||||
DB_CONNECTION_CONFIG = DatabaseConnectionConfig(
|
||||
host=os.environ['DB_HOST'],
|
||||
db_name=os.environ['DB_NAME'],
|
||||
password=os.environ['DB_PASSWORD'],
|
||||
port=os.environ.get('DB_PORT', 5432),
|
||||
user=os.environ['DB_USER'],
|
||||
sslmode=os.environ.get('DB_SSLMODE')
|
||||
)
|
||||
|
||||
|
||||
class DevelopmentConfig(BaseConfig):
|
||||
|
@ -80,10 +79,4 @@ def get_base_config():
|
|||
error_msg = 'no configuration defined for the "{}" environment'
|
||||
raise APIConfigError(error_msg.format(environment_name))
|
||||
|
||||
max_db_connections = os.environ.get('MAX_DB_CONNECTIONS', 20)
|
||||
app_config.DB_CONNECTION_POOL = allocate_db_connection_pool(
|
||||
db_connection_config,
|
||||
max_db_connections
|
||||
)
|
||||
|
||||
return app_config
|
||||
|
|
|
@ -6,7 +6,7 @@ from flask.views import MethodView
|
|||
|
||||
from selene.data.account import Account, AccountRepository
|
||||
from selene.util.auth import AuthenticationError, AuthenticationToken
|
||||
from selene.util.db import get_db_connection_from_pool
|
||||
from selene.util.db import connect_to_db
|
||||
|
||||
ACCESS_TOKEN_COOKIE_NAME = 'seleneAccess'
|
||||
FIFTEEN_MINUTES = 900
|
||||
|
@ -42,8 +42,8 @@ class SeleneEndpoint(MethodView):
|
|||
@property
|
||||
def db(self):
|
||||
if 'db' not in global_context:
|
||||
global_context.db = get_db_connection_from_pool(
|
||||
current_app.config['DB_CONNECTION_POOL']
|
||||
global_context.db = connect_to_db(
|
||||
current_app.config['DB_CONNECTION_CONFIG']
|
||||
)
|
||||
|
||||
return global_context.db
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from datetime import datetime
|
||||
import json
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
|
||||
from flask import current_app, Blueprint, g as global_context
|
||||
|
@ -7,10 +7,7 @@ from schematics.exceptions import DataError
|
|||
|
||||
from selene.data.metrics import ApiMetric, ApiMetricsRepository
|
||||
from selene.util.auth import AuthenticationError
|
||||
from selene.util.db import (
|
||||
get_db_connection_from_pool,
|
||||
return_db_connection_to_pool
|
||||
)
|
||||
from selene.util.db import connect_to_db
|
||||
from selene.util.not_modified import NotModifiedError
|
||||
|
||||
selene_api = Blueprint('selene_api', __name__)
|
||||
|
@ -39,7 +36,6 @@ def setup_request():
|
|||
@selene_api.after_app_request
|
||||
def teardown_request(response):
|
||||
add_api_metric(response.status_code)
|
||||
release_db_connection()
|
||||
|
||||
return response
|
||||
|
||||
|
@ -54,8 +50,8 @@ def add_api_metric(http_status):
|
|||
|
||||
if api is not None and int(http_status) != 304:
|
||||
if 'db' not in global_context:
|
||||
global_context.db = get_db_connection_from_pool(
|
||||
current_app.config['DB_CONNECTION_POOL']
|
||||
global_context.db = connect_to_db(
|
||||
current_app.config['DB_CONNECTION_CONFIG']
|
||||
)
|
||||
if 'account_id' in global_context:
|
||||
account_id = global_context.account_id
|
||||
|
@ -78,12 +74,3 @@ def add_api_metric(http_status):
|
|||
)
|
||||
metric_repository = ApiMetricsRepository(global_context.db)
|
||||
metric_repository.add(api_metric)
|
||||
|
||||
|
||||
def release_db_connection():
|
||||
db = global_context.pop('db', None)
|
||||
if db is not None:
|
||||
return_db_connection_to_pool(
|
||||
current_app.config['DB_CONNECTION_POOL'],
|
||||
db
|
||||
)
|
||||
|
|
|
@ -3,7 +3,7 @@ from dataclasses import asdict
|
|||
from http import HTTPStatus
|
||||
|
||||
from selene.data.account import AgreementRepository
|
||||
from selene.util.db import get_db_connection
|
||||
from selene.util.db import connect_to_db
|
||||
from ..base_endpoint import SeleneEndpoint
|
||||
|
||||
|
||||
|
@ -16,14 +16,14 @@ class AgreementsEndpoint(SeleneEndpoint):
|
|||
|
||||
def get(self, agreement_type):
|
||||
"""Process HTTP GET request for an agreement."""
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
agreement_repository = AgreementRepository(db)
|
||||
agreement = agreement_repository.get_active_for_type(
|
||||
self.agreement_types[agreement_type]
|
||||
)
|
||||
if agreement is not None:
|
||||
agreement = asdict(agreement)
|
||||
del(agreement['effective_date'])
|
||||
self.response = agreement, HTTPStatus.OK
|
||||
db = connect_to_db(self.config['DB_CONNECTION_CONFIG'])
|
||||
agreement_repository = AgreementRepository(db)
|
||||
agreement = agreement_repository.get_active_for_type(
|
||||
self.agreement_types[agreement_type]
|
||||
)
|
||||
if agreement is not None:
|
||||
agreement = asdict(agreement)
|
||||
del(agreement['effective_date'])
|
||||
self.response = agreement, HTTPStatus.OK
|
||||
|
||||
return self.response
|
||||
|
|
|
@ -3,7 +3,7 @@ import string
|
|||
|
||||
from selene.data.device import DeviceRepository
|
||||
from selene.util.cache import SeleneCache
|
||||
from selene.util.db import get_db_connection
|
||||
from selene.util.db import connect_to_db
|
||||
|
||||
|
||||
def device_etag_key(device_id: str):
|
||||
|
@ -29,7 +29,7 @@ class ETagManager(object):
|
|||
|
||||
def __init__(self, cache: SeleneCache, config: dict):
|
||||
self.cache: SeleneCache = cache
|
||||
self.db_connection_pool = config['DB_CONNECTION_POOL']
|
||||
self.db_connection_config = config['DB_CONNECTION_CONFIG']
|
||||
|
||||
def get(self, key: str) -> str:
|
||||
"""Generate a etag with 32 random chars and store it into a given key
|
||||
|
@ -60,10 +60,10 @@ class ETagManager(object):
|
|||
def expire_device_setting_etag_by_account_id(self, account_id: str):
|
||||
"""Expire the settings' etags for all devices from a given account. Used when the settings are updated
|
||||
at account level"""
|
||||
with get_db_connection(self.db_connection_pool) as db:
|
||||
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
|
||||
for device in devices:
|
||||
self.expire_device_setting_etag_by_device_id(device.id)
|
||||
db = connect_to_db(self.db_connection_config)
|
||||
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
|
||||
for device in devices:
|
||||
self.expire_device_setting_etag_by_device_id(device.id)
|
||||
|
||||
def expire_device_location_etag_by_device_id(self, device_id: str):
|
||||
"""Expire the etag associate with the device's location entity
|
||||
|
@ -73,10 +73,10 @@ class ETagManager(object):
|
|||
def expire_device_location_etag_by_account_id(self, account_id: str):
|
||||
"""Expire the locations' etag fpr açç device for a given acccount
|
||||
:param account_id: account uuid"""
|
||||
with get_db_connection(self.db_connection_pool) as db:
|
||||
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
|
||||
for device in devices:
|
||||
self.expire_device_location_etag_by_device_id(device.id)
|
||||
db = connect_to_db(self.db_connection_config)
|
||||
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
|
||||
for device in devices:
|
||||
self.expire_device_location_etag_by_device_id(device.id)
|
||||
|
||||
def expire_skill_etag_by_device_id(self, device_id):
|
||||
"""Expire the locations' etag for a given device
|
||||
|
@ -84,7 +84,7 @@ class ETagManager(object):
|
|||
self._expire(device_skill_etag_key(device_id))
|
||||
|
||||
def expire_skill_etag_by_account_id(self, account_id):
|
||||
with get_db_connection(self.db_connection_pool) as db:
|
||||
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
|
||||
for device in devices:
|
||||
self.expire_skill_etag_by_device_id(device.id)
|
||||
db = connect_to_db(self.db_connection_config)
|
||||
devices = DeviceRepository(db).get_devices_by_account_id(account_id)
|
||||
for device in devices:
|
||||
self.expire_skill_etag_by_device_id(device.id)
|
||||
|
|
|
@ -13,7 +13,7 @@ from flask.views import MethodView
|
|||
|
||||
from selene.api.etag import ETagManager
|
||||
from selene.util.auth import AuthenticationError
|
||||
from selene.util.db import get_db_connection_from_pool
|
||||
from selene.util.db import connect_to_db
|
||||
from selene.util.not_modified import NotModifiedError
|
||||
from ..util.cache import SeleneCache
|
||||
|
||||
|
@ -91,8 +91,8 @@ class PublicEndpoint(MethodView):
|
|||
@property
|
||||
def db(self):
|
||||
if 'db' not in global_context:
|
||||
global_context.db = get_db_connection_from_pool(
|
||||
current_app.config['DB_CONNECTION_POOL']
|
||||
global_context.db = connect_to_db(
|
||||
current_app.config['DB_CONNECTION_CONFIG']
|
||||
)
|
||||
|
||||
return global_context.db
|
||||
|
|
|
@ -2,7 +2,7 @@ from hamcrest import assert_that, equal_to, has_item
|
|||
|
||||
from selene.data.account import Account, AccountRepository
|
||||
from selene.util.auth import AuthenticationToken
|
||||
from selene.util.db import get_db_connection
|
||||
from selene.util.db import connect_to_db
|
||||
|
||||
ACCESS_TOKEN_COOKIE_KEY = 'seleneAccess'
|
||||
ONE_MINUTE = 60
|
||||
|
@ -77,8 +77,8 @@ def _parse_cookie(cookie: str) -> dict:
|
|||
|
||||
|
||||
def get_account(context) -> Account:
|
||||
with get_db_connection(context.db_pool) as db:
|
||||
acct_repository = AccountRepository(db)
|
||||
account = acct_repository.get_account_by_id(context.account.id)
|
||||
db = connect_to_db(context.client['DB_CONNECTION_CONFIG'])
|
||||
acct_repository = AccountRepository(db)
|
||||
account = acct_repository.get_account_by_id(context.account.id)
|
||||
|
||||
return account
|
||||
|
|
|
@ -12,7 +12,7 @@ class SkillDisplayRepository(RepositoryBase):
|
|||
sql_file_name='get_display_data_for_skills.sql'
|
||||
)
|
||||
|
||||
def get_display_data_for_skill(self, skill_display_id):
|
||||
def get_display_data_for_skill(self, skill_display_id) -> SkillDisplay:
|
||||
return self._select_one_into_dataclass(
|
||||
dataclass=SkillDisplay,
|
||||
sql_file_name='get_display_data_for_skill.sql',
|
||||
|
|
|
@ -34,12 +34,12 @@ class SkillSettingRepository(RepositoryBase):
|
|||
|
||||
return skill_settings
|
||||
|
||||
def get_installer_settings(self, account_id: str):
|
||||
def get_installer_settings(self, account_id: str) -> List[AccountSkillSetting]:
|
||||
skill_repo = SkillRepository(self.db)
|
||||
skills = skill_repo.get_skills_for_account(account_id)
|
||||
installer_skill_id = None
|
||||
for skill in skills:
|
||||
if skill.name == 'mycroft_installer':
|
||||
if skill.display_name == 'Installer':
|
||||
installer_skill_id = skill.id
|
||||
|
||||
skill_settings = None
|
||||
|
|
|
@ -6,12 +6,12 @@ Example Usage:
|
|||
query_result = mycroft_db_ro.execute_sql(sql)
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, InitVar
|
||||
from logging import getLogger
|
||||
|
||||
from psycopg2 import connect
|
||||
from psycopg2.extras import RealDictCursor
|
||||
from psycopg2.extras import RealDictCursor, NamedTupleCursor
|
||||
from psycopg2.extensions import cursor
|
||||
|
||||
_log = getLogger(__package__)
|
||||
|
||||
|
@ -29,10 +29,16 @@ class DatabaseConnectionConfig(object):
|
|||
password: str
|
||||
port: int = field(default=5432)
|
||||
sslmode: str = None
|
||||
autocommit: str = True
|
||||
cursor_factory = RealDictCursor
|
||||
use_namedtuple_cursor: InitVar[bool] = False
|
||||
|
||||
def __post_init__(self, use_namedtuple_cursor: bool):
|
||||
if use_namedtuple_cursor:
|
||||
self.cursor_factory = NamedTupleCursor
|
||||
|
||||
|
||||
@contextmanager
|
||||
def connect_to_db(connection_config: DatabaseConnectionConfig, autocommit=True):
|
||||
def connect_to_db(connection_config: DatabaseConnectionConfig):
|
||||
"""
|
||||
Return a connection to the mycroft database for the specified user.
|
||||
|
||||
|
@ -41,33 +47,19 @@ def connect_to_db(connection_config: DatabaseConnectionConfig, autocommit=True):
|
|||
python notebook)
|
||||
|
||||
:param connection_config: data needed to establish a connection
|
||||
:param autocommit: indicated if transactions should commit automatically
|
||||
:return: database connection
|
||||
"""
|
||||
db = None
|
||||
log_msg = 'establishing connection to the {db_name} database'
|
||||
_log.info(log_msg.format(db_name=connection_config.db_name))
|
||||
try:
|
||||
if connection_config.sslmode is None:
|
||||
db = connect(
|
||||
host=connection_config.host,
|
||||
dbname=connection_config.db_name,
|
||||
user=connection_config.user,
|
||||
port=connection_config.port,
|
||||
cursor_factory=RealDictCursor,
|
||||
)
|
||||
else:
|
||||
db = connect(
|
||||
host=connection_config.host,
|
||||
dbname=connection_config.db_name,
|
||||
user=connection_config.user,
|
||||
password=connection_config.password,
|
||||
port=connection_config.port,
|
||||
cursor_factory=RealDictCursor,
|
||||
sslmode=connection_config.sslmode
|
||||
)
|
||||
db.autocommit = autocommit
|
||||
yield db
|
||||
finally:
|
||||
if db is not None:
|
||||
db.close()
|
||||
db = connect(
|
||||
host=connection_config.host,
|
||||
dbname=connection_config.db_name,
|
||||
user=connection_config.user,
|
||||
password=connection_config.password,
|
||||
port=connection_config.port,
|
||||
cursor_factory=connection_config.cursor_factory,
|
||||
sslmode=connection_config.sslmode
|
||||
)
|
||||
db.autocommit = connection_config.autocommit
|
||||
|
||||
return db
|
||||
|
|
Loading…
Reference in New Issue