Merge remote-tracking branch 'origin/dev' into dev

commit 06b8bd4eb1
@@ -15,7 +15,7 @@ from selene.data.account import (
 )
 from selene.data.device import Geography, GeographyRepository
 from selene.util.cache import SeleneCache
-from selene.util.db import get_db_connection
+from selene.util.db import connect_to_db


 @fixture
@@ -32,10 +32,10 @@ def before_feature(context, _):


 def before_scenario(context, _):
-    with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     _add_agreements(context, db)
     _add_account(context, db)
     _add_geography(context, db)


 def _add_agreements(context, db):
@@ -91,9 +91,9 @@ def _add_geography(context, db):


 def after_scenario(context, _):
-    with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     _delete_account(context, db)
     _delete_agreements(context, db)
     _clean_cache()

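The hunks above and below all make the same substitution: the pooled get_db_connection(...) context manager is replaced by a direct connect_to_db(...) call driven by a DatabaseConnectionConfig (both helpers appear elsewhere in this merge). A minimal, illustrative sketch of the new pattern; the literal connection values are placeholders, not values from this commit:

    # Illustrative sketch only -- mirrors the connection pattern introduced in this merge.
    from selene.util.db import connect_to_db, DatabaseConnectionConfig

    db_connection_config = DatabaseConnectionConfig(
        host='localhost',        # placeholder host
        db_name='mycroft',       # placeholder database name
        password='example',      # placeholder password
        port=5432,
        user='selene',           # placeholder user
        sslmode=None
    )
    db = connect_to_db(db_connection_config)  # plain connection; no pool to return it to
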
@@ -5,7 +5,7 @@ from hamcrest import assert_that, equal_to, has_key, none, not_none

 from selene.data.device import DeviceRepository
 from selene.util.cache import SeleneCache
-from selene.util.db import get_db_connection
+from selene.util.db import connect_to_db


 @given('a device pairing code')
@@ -57,9 +57,9 @@ def validate_pairing_code_removal(context):
 @then('the device is added to the database')
 def validate_response(context):
     device_id = context.response.data.decode()
-    with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     device_repository = DeviceRepository(db)
     device = device_repository.get_device_by_id(device_id)

     assert_that(device, not_none())
     assert_that(device.name, equal_to('home'))

@@ -8,7 +8,7 @@ from selene.api.testing import (
 )
 from selene.data.account import AccountRepository
 from selene.util.auth import AuthenticationToken
-from selene.util.db import get_db_connection
+from selene.util.db import connect_to_db


 @given('an authenticated user with an expired access token')
@@ -37,9 +37,9 @@ def check_for_new_cookies(context):
         context.refresh_token,
         is_not(equal_to(context.old_refresh_token))
     )
-    with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     acct_repository = AccountRepository(db)
     account = acct_repository.get_account_by_id(context.account.id)

     refresh_token = AuthenticationToken(
         context.client_config['REFRESH_SECRET'],

@@ -1,5 +1,6 @@
-import os
+from binascii import b2a_base64
 from datetime import date
+import os

 import stripe
 from behave import given, then, when
@@ -8,7 +9,7 @@ from hamcrest import assert_that, equal_to, is_in, none, not_none, starts_with
 from stripe.error import InvalidRequestError

 from selene.data.account import AccountRepository, PRIVACY_POLICY, TERMS_OF_USE
-from selene.util.db import get_db_connection
+from selene.util.db import connect_to_db

 new_account_request = dict(
     username='barfoo',
@@ -17,8 +18,8 @@ new_account_request = dict(
     login=dict(
         federatedPlatform=None,
         federatedToken=None,
-        userEnteredEmail='bar@mycroft.ai',
-        password='bar'
+        userEnteredEmail=b2a_base64(b'bar@mycroft.ai').decode(),
+        password=b2a_base64(b'bar').decode()
     ),
     support=dict(openDataset=True)
 )
@@ -58,41 +59,41 @@ def call_add_account_endpoint(context):
     context.response = context.client.post(
         '/api/account',
         data=json.dumps(context.new_account_request),
-        content_type='application_json'
+        content_type='application/json'
     )


 @then('the account will be added to the system {membership_option}')
 def check_db_for_account(context, membership_option):
-    with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     acct_repository = AccountRepository(db)
     account = acct_repository.get_account_by_email('bar@mycroft.ai')
     assert_that(account, not_none())
-        assert_that(
-            account.email_address, equal_to('bar@mycroft.ai')
-        )
-        assert_that(account.username, equal_to('barfoo'))
-        if membership_option == 'with a membership':
-            assert_that(account.membership.type, equal_to('Monthly Membership'))
-            assert_that(
-                account.membership.payment_account_id,
-                starts_with('cus')
-            )
-        elif membership_option == 'without a membership':
-            assert_that(account.membership, none())
+    assert_that(
+        account.email_address, equal_to('bar@mycroft.ai')
+    )
+    assert_that(account.username, equal_to('barfoo'))
+    if membership_option == 'with a membership':
+        assert_that(account.membership.type, equal_to('Monthly Membership'))
+        assert_that(
+            account.membership.payment_account_id,
+            starts_with('cus')
+        )
+    elif membership_option == 'without a membership':
+        assert_that(account.membership, none())

     assert_that(len(account.agreements), equal_to(2))
     for agreement in account.agreements:
         assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE)))
         assert_that(agreement.accept_date, equal_to(str(date.today())))


 @when('the account is deleted')
 def account_deleted(context):
-    with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     acct_repository = AccountRepository(db)
     account = acct_repository.get_account_by_email('bar@mycroft.ai')
     context.stripe_id = account.membership.payment_id
     context.response = context.client.delete('/api/account')

@@ -1,5 +1,6 @@
-import json
+from binascii import b2a_base64
 from datetime import date
+import json

 from behave import given, when, then
 from hamcrest import assert_that, equal_to, starts_with, none
@@ -11,7 +12,9 @@ from selene.data.account import (
     AccountAgreement,
     PRIVACY_POLICY
 )
-from selene.util.db import get_db_connection
+from selene.util.db import connect_to_db

+TEST_EMAIL_ADDRESS = 'test@mycroft.ai'
+
 new_account_request = dict(
     username='test',
@@ -20,8 +23,8 @@ new_account_request = dict(
     login=dict(
         federatedPlatform=None,
         federatedToken=None,
-        userEnteredEmail='test@mycroft.ai',
-        password='12345678'
+        email=b2a_base64(b'test@mycroft.ai').decode(),
+        password=b2a_base64(b'12345678').decode()
     ),
     support=dict(
         openDataset=True,
@@ -47,12 +50,12 @@ def create_account(context):
             AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today())
         ]
     )
-    with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     acct_repository = AccountRepository(db)
     account_id = acct_repository.add(context.account, 'foo')
     context.account.id = account_id
     generate_access_token(context)
     generate_refresh_token(context)


 @when('a monthly membership is added')
@@ -66,16 +69,16 @@ def update_membership(context):
     context.response = context.client.patch(
         '/api/account',
         data=json.dumps(dict(membership=membership_data)),
-        content_type='application_json'
+        content_type='application/json'
     )


 @when('the account is requested')
 def request_account(context):
-    with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     context.response_account = AccountRepository(db).get_account_by_email(
-        'test@mycroft.ai'
+        TEST_EMAIL_ADDRESS
     )


 @then('the account should have a monthly membership')
@@ -98,16 +101,14 @@ def create_monthly_account(context):
     context.client.post(
         '/api/account',
         data=json.dumps(new_account_request),
-        content_type='application_json'
+        content_type='application/json'
     )
-    with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     account_repository = AccountRepository(db)
-    account = account_repository.get_account_by_email(
-        new_account_request['login']['userEnteredEmail']
-    )
-    context.account = account
-    generate_access_token(context)
-    generate_refresh_token(context)
+    account = account_repository.get_account_by_email(TEST_EMAIL_ADDRESS)
+    context.account = account
+    generate_access_token(context)
+    generate_refresh_token(context)


 @when('the membership is cancelled')
@@ -119,7 +120,7 @@ def cancel_membership(context):
     context.client.patch(
         '/api/account',
         data=json.dumps(dict(membership=membership_data)),
-        content_type='application_json'
+        content_type='application/json'
     )


@@ -138,7 +139,7 @@ def change_to_yearly_account(context):
     context.client.patch(
         '/api/account',
         data=json.dumps(dict(membership=membership_data)),
-        content_type='application_json'
+        content_type='application/json'
     )


@@ -1,3 +1,9 @@
+"""
+Marketplace endpoint to add or remove a skill
+
+This endpoint configures the install skill on a user's device(s) to add or
+remove the skill.
+"""
 from http import HTTPStatus
 from logging import getLogger
 from typing import List
@@ -6,7 +12,11 @@ from schematics import Model
 from schematics.types import StringType

 from selene.api import SeleneEndpoint
-from selene.data.skill import AccountSkillSetting, SkillSettingRepository
+from selene.data.skill import (
+    AccountSkillSetting,
+    SkillDisplayRepository,
+    SkillSettingRepository
+)

 INSTALL_SECTION = 'to_install'
 UNINSTALL_SECTION = 'to_remove'
@@ -15,17 +25,16 @@ _log = getLogger(__package__)


 class InstallRequest(Model):
+    """Defines the expected state of the request JSON data"""
     setting_section = StringType(
         required=True,
         choices=[INSTALL_SECTION, UNINSTALL_SECTION]
     )
-    skill_name = StringType(required=True)
+    skill_display_id = StringType(required=True)


 class SkillInstallEndpoint(SeleneEndpoint):
-    """
-    Install a skill on user device(s).
-    """
+    """Install a skill on user device(s)."""
     def __init__(self):
         super(SkillInstallEndpoint, self).__init__()
         self.device_uuid: str = None
@@ -34,43 +43,70 @@ class SkillInstallEndpoint(SeleneEndpoint):
         self.installer_update_response = None

     def put(self):
+        """Handle an HTTP PUT request"""
         self._authenticate()
         self._validate_request()
+        skill_install_name = self._get_install_name()
         self._get_installer_settings()
-        self._apply_update()
+        self._apply_update(skill_install_name)
         self.response = (self.installer_update_response, HTTPStatus.OK)

         return self.response

     def _validate_request(self):
+        """Ensure the data passed in the request is as expected.
+
+        :raises schematics.exceptions.ValidationError if the validation fails
+        """
         install_request = InstallRequest()
         install_request.setting_section = self.request.json['section']
-        install_request.skill_name = self.request.json['skillName']
+        install_request.skill_display_id = self.request.json['skillDisplayId']
         install_request.validate()

+    def _get_install_name(self) -> str:
+        """Get the skill name used by the installer skill from the DB
+
+        The installer skill expects the skill name found in the "name" field
+        of the skill display JSON.
+        """
+        display_repo = SkillDisplayRepository(self.db)
+        skill_display = display_repo.get_display_data_for_skill(
+            self.request.json['skillDisplayId']
+        )
+
+        return skill_display.display_data['name']
+
     def _get_installer_settings(self):
+        """Get the current value of the installer skill's settings"""
         settings_repo = SkillSettingRepository(self.db)
         self.installer_settings = settings_repo.get_installer_settings(
             self.account.id
         )

-    def _apply_update(self):
+    def _apply_update(self, skill_install_name: str):
+        """Add the skill in the request to the installer skill settings.
+
+        This is designed to change the installer skill settings for all
+        devices associated with an account. It will be updated in the
+        future to target specific devices.
+        """
         for settings in self.installer_settings:
             if self.request.json['section'] == INSTALL_SECTION:
                 to_install = settings.settings_values.get(INSTALL_SECTION, [])
                 to_install.append(
-                    dict(name=self.request.json['skill_name'])
+                    dict(name=skill_install_name)
                 )
                 settings.settings_values[INSTALL_SECTION] = to_install
             else:
                 to_remove = settings.settings_values.get(UNINSTALL_SECTION, [])
                 to_remove.append(
-                    dict(name=self.request.json['skill_name'])
+                    dict(name=skill_install_name)
                 )
                 settings.settings_values[UNINSTALL_SECTION] = to_remove
             self._update_skill_settings(settings)

     def _update_skill_settings(self, settings):
+        """Update the DB with the new installer skill settings."""
         settings_repo = SkillSettingRepository(self.db)
         settings_repo.update_device_skill_settings(
             self.account.id,

@@ -28,6 +28,7 @@ class DeviceEndpoint(PublicEndpoint):

         if device is not None:
             response_data = dict(
+                uuid=device.id,
                 name=device.name,
                 description=device.placement,
                 coreVersion=device.core_version,

@@ -18,7 +18,7 @@ class MetricsService(object):
             deviceUuid=device_id,
             data=data
         )
-        url = '{host}/{metric}'.format(host=self.metrics_service_host, metric=metric)
+        url = '{host}/metric/{metric}'.format(host=self.metrics_service_host, metric=metric)
         requests.post(url, body)


@@ -4,7 +4,6 @@ import requests

 from selene.api import PublicEndpoint
 from selene.data.account import AccountRepository
-from selene.util.db import get_db_connection


 class OauthServiceEndpoint(PublicEndpoint):
@@ -14,8 +13,7 @@ class OauthServiceEndpoint(PublicEndpoint):
         self.oauth_service_host = os.environ['OAUTH_BASE_URL']

     def get(self, device_id, credentials, oauth_path):
-        with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
-            account = AccountRepository(db).get_account_by_device_id(device_id)
+        account = AccountRepository(self.db).get_account_by_device_id(device_id)
         uuid = account.id
         url = '{host}/auth/{credentials}/{oauth_path}'.format(
             host=self.oauth_service_host,

@@ -22,10 +22,20 @@ class DeviceRefreshTokenEndpoint(PublicEndpoint):
        if token_header.startswith('Bearer '):
            refresh = token_header[len('Bearer '):]
            session = self._refresh_session_token(refresh)
+           # Trying to fetch a session using the refresh token
            if session:
                response = session, HTTPStatus.OK
            else:
-               response = '', HTTPStatus.UNAUTHORIZED
+               device = self.request.headers.get('Device')
+               if device:
+                   # trying to fetch a session using the device uuid
+                   session = self._refresh_session_token_device(device)
+                   if session:
+                       response = session, HTTPStatus.OK
+                   else:
+                       response = '', HTTPStatus.UNAUTHORIZED
+               else:
+                   response = '', HTTPStatus.UNAUTHORIZED
        else:
            response = '', HTTPStatus.UNAUTHORIZED
        return response
@@ -38,3 +48,12 @@ class DeviceRefreshTokenEndpoint(PublicEndpoint):
            device_id = old_login['uuid']
            self.cache.delete(refresh_key)
            return generate_device_login(device_id, self.cache)
+
+   def _refresh_session_token_device(self, device: str):
+       refresh_key = 'device.session:{}'.format(device)
+       session = self.cache.get(refresh_key)
+       if session:
+           old_login = json.loads(session)
+           device_id = old_login['uuid']
+           self.cache.delete(refresh_key)
+           return generate_device_login(device_id, self.cache)

@@ -27,6 +27,7 @@ class SkillManifest(Model):
     installed = DateTimeType()
     updated = DateTimeType()
     update = DateTimeType()
+    skill_gid = StringType()


 class SkillJson(Model):

@@ -21,7 +21,7 @@ from selene.data.device import (
     Geography)
 from selene.data.device.entity.text_to_speech import TextToSpeech
 from selene.data.device.entity.wake_word import WakeWord
-from selene.util.db import get_db_connection
+from selene.util.db import connect_to_db

 account = Account(
     email_address='test@test.com',
@@ -63,22 +63,22 @@ def before_feature(context, _):
 def before_scenario(context, _):
     cache = context.client_config['SELENE_CACHE']
     context.etag_manager = ETagManager(cache, context.client_config)
-    with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     try:
         _add_agreements(context, db)
         _add_account(context, db)
         _add_account_preference(context, db)
         _add_geography(context, db)
         _add_device(context, db)
     except Exception as e:
         import traceback
         print(traceback.print_exc())


 def after_scenario(context, _):
-    with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     _remove_account(context, db)
     _remove_agreements(context, db)


 def _add_agreements(context, db):

@@ -17,7 +17,8 @@ skill_manifest = {
             "installed": datetime.now().timestamp(),
             "updated": datetime.now().timestamp(),
             "installation": "installed",
-            "update": 0
+            "update": 0,
+            "skill_gid": "fallback-wolfram-alpha|19.02",
         }
     ]
 }

@@ -6,7 +6,7 @@ from hamcrest import assert_that, equal_to, not_none, is_not

 from selene.api.etag import ETagManager, device_skill_etag_key
 from selene.data.skill import SkillSettingRepository
-from selene.util.db import get_db_connection
+from selene.util.db import connect_to_db

 skill = {
     'skill_gid': 'wolfram-alpha|19.02',
@@ -106,8 +106,8 @@ def update_skill(context):
     }]
     response = json.loads(context.upload_device_response.data)
     skill_id = response['uuid']
-    with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     SkillSettingRepository(db).update_device_skill_settings(skill_id, update_settings)


 @when('the skill settings is fetched')

@@ -31,6 +31,7 @@ def validate_response(context):
     response = context.get_device_response
     assert_that(response.status_code, equal_to(HTTPStatus.OK))
     device = json.loads(response.data)
+    assert_that(device, has_key('uuid'))
     assert_that(device, has_key('name'))
     assert_that(device, has_key('description'))
     assert_that(device, has_key('coreVersion'))

@@ -7,7 +7,7 @@ from behave import when, then
 from hamcrest import assert_that, has_entry, equal_to

 from selene.data.account import AccountRepository, AccountMembership
-from selene.util.db import get_db_connection
+from selene.util.db import connect_to_db


 @when('the subscription endpoint is called')
@@ -42,9 +42,9 @@ def get_device_subscription(context):
     login = context.device_login
     device_id = login['uuid']
     access_token = login['accessToken']
-    headers=dict(Authorization='Bearer {token}'.format(token=access_token))
+    headers = dict(Authorization='Bearer {token}'.format(token=access_token))
-    with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     AccountRepository(db).add_membership(context.account.id, membership)
     context.subscription_response = context.client.get(
         '/v1/device/{uuid}/subscription'.format(uuid=device_id),
         headers=headers

@@ -13,7 +13,7 @@ from selene.data.account import (
     AgreementRepository,
     PRIVACY_POLICY
 )
-from selene.util.db import get_db_connection
+from selene.util.db import connect_to_db


 @fixture
@@ -32,9 +32,9 @@ def before_feature(context, _):


 def before_scenario(context, _):
-    with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     _add_agreement(context, db)
     _add_account(context, db)


 def _add_agreement(context, db):
@@ -72,8 +72,8 @@ def _add_account(context, db):


 def after_scenario(context, _):
-    with get_db_connection(context.db_pool) as db:
+    db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
     acct_repository = AccountRepository(db)
     acct_repository.remove(context.account)
     agreement_repository = AgreementRepository(db)
     agreement_repository.remove(context.agreement, testing=True)

@@ -0,0 +1,11 @@
+CREATE TABLE metrics.job (
+    id uuid PRIMARY KEY
+        DEFAULT gen_random_uuid(),
+    job_name text NOT NULL,
+    batch_date date NOT NULL,
+    start_ts TIMESTAMP NOT NULL,
+    end_ts TIMESTAMP NOT NULL,
+    command text NOT NULL,
+    success BOOLEAN NOT NULL,
+    UNIQUE (job_name, start_ts)
+)
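
The new metrics.job table above holds one row per batch-job run. A minimal, illustrative sketch of how a job runner might record such a row with psycopg2 (the job name, command, and connection values below are placeholders, not taken from this commit):

    # Illustrative sketch only -- writes one row into the metrics.job table defined above.
    from datetime import date, datetime
    from psycopg2 import connect

    db = connect(dbname='mycroft', user='selene', host='127.0.0.1')  # placeholder connection
    with db.cursor() as cursor:
        cursor.execute(
            "INSERT INTO metrics.job "
            "(job_name, batch_date, start_ts, end_ts, command, success) "
            "VALUES (%s, %s, %s, %s, %s, %s)",
            ('example_job', date.today(), datetime.utcnow(),  # hypothetical job name
             datetime.utcnow(), 'python example_job.py', True)
        )
    db.commit()
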
@@ -1,6 +1,7 @@
 from glob import glob
-from os import path
+from os import environ, path, remove

+from markdown import markdown
 from psycopg2 import connect

 MYCROFT_DB_DIR = path.join(path.abspath('..'), 'mycroft')
@@ -8,10 +9,8 @@ SCHEMAS = ('account', 'skill', 'device', 'geography', 'metrics')
 DB_DESTROY_FILES = (
     'drop_mycroft_db.sql',
     'drop_template_db.sql',
-    # 'drop_roles.sql'
 )
 DB_CREATE_FILES = (
-    # 'create_roles.sql',
     'create_template_db.sql',
 )
 ACCOUNT_TABLE_ORDER = (
@@ -48,6 +47,7 @@ GEOGRAPHY_TABLE_ORDER = (

 METRICS_TABLE_ORDER = (
     'api',
+    'job'
 )

 schema_directory = '{}_schema'
@@ -61,32 +61,40 @@ def get_sql_from_file(file_path: str) -> str:


 class PostgresDB(object):
-    def __init__(self, dbname, user, password=None):
-        self.db = connect(dbname=dbname, user=user, host='127.0.0.1')
-        # self.db = connect(
-        #     dbname=dbname,
-        #     user=user,
-        #     password=password,
-        #     host='selene-test-db-do-user-1412453-0.db.ondigitalocean.com',
-        #     port=25060,
-        #     sslmode='require'
-        # )
+    def __init__(self, db_name, user=None):
+        db_host = environ['DB_HOST']
+        db_port = environ['DB_PORT']
+        db_ssl_mode = environ.get('DB_SSL_MODE')
+        if db_name in ('postgres', 'defaultdb'):
+            db_user = environ['POSTGRES_DB_USER']
+            db_password = environ.get('POSTGRES_DB_PASSWORD')
+        else:
+            db_user = environ['MYCROFT_DB_USER']
+            db_password = environ['MYCROFT_DB_PASSWORD']
+
+        if user is not None:
+            db_user = user
+
+        self.db = connect(
+            dbname=db_name,
+            user=db_user,
+            password=db_password,
+            host=db_host,
+            port=db_port,
+            sslmode=db_ssl_mode
+        )
         self.db.autocommit = True

     def close_db(self):
         self.db.close()

-    def execute_sql(self, sql: str):
+    def execute_sql(self, sql: str, args=None):
         cursor = self.db.cursor()
-        cursor.execute(sql)
+        cursor.execute(sql, args)
+        return cursor


-postgres_db = PostgresDB(dbname='postgres', user='postgres')
-# postgres_db = PostgresDB(
-#     dbname='defaultdb',
-#     user='doadmin',
-#     password='l06tn0qi2bjhgcki'
-# )
+postgres_db = PostgresDB(db_name=environ['POSTGRES_DB_NAME'])

 print('Destroying any objects we will be creating later.')
 for db_destroy_file in DB_DESTROY_FILES:
@@ -94,7 +102,7 @@ for db_destroy_file in DB_DESTROY_FILES:
         get_sql_from_file(db_destroy_file)
     )

-print('Creating the extensions, mycroft database, and selene roles')
+print('Creating the mycroft database')
 for db_setup_file in DB_CREATE_FILES:
     postgres_db.execute_sql(
         get_sql_from_file(db_setup_file)
@@ -102,13 +110,10 @@ for db_setup_file in DB_CREATE_FILES:

 postgres_db.close_db()

-template_db = PostgresDB(dbname='mycroft_template', user='mycroft')
-# template_db = PostgresDB(
-#     dbname='mycroft_template',
-#     user='selene',
-#     password='ubhemhx1dikmqc5f'
-# )

+template_db = PostgresDB(db_name='mycroft_template')
+
+print('Creating the extensions')
 template_db.execute_sql(
     get_sql_from_file(path.join('create_extensions.sql'))
 )
@@ -193,22 +198,14 @@ for schema in SCHEMAS:

 template_db.close_db()


 print('Copying template to new database.')
-postgres_db = PostgresDB(dbname='postgres', user='mycroft')
-# postgres_db = PostgresDB(
-#     dbname='defaultdb',
-#     user='doadmin',
-#     password='l06tn0qi2bjhgcki'
-# )
+postgres_db = PostgresDB(db_name=environ['POSTGRES_DB_NAME'])
 postgres_db.execute_sql(get_sql_from_file('create_mycroft_db.sql'))
 postgres_db.close_db()

-mycroft_db = PostgresDB(dbname='mycroft', user='mycroft')
-# mycroft_db = PostgresDB(
-#     dbname='mycroft_template',
-#     user='selene',
-#     password='ubhemhx1dikmqc5f'
-# )
+mycroft_db = PostgresDB(db_name=environ['MYCROFT_DB_NAME'])
 insert_files = [
     dict(schema_dir='account_schema', file_name='membership.sql'),
     dict(schema_dir='device_schema', file_name='text_to_speech.sql'),
@@ -226,3 +223,162 @@ for insert_file in insert_files:
         )
     except FileNotFoundError:
         pass
+
+print('Building account.agreement table')
+mycroft_db.db.autocommit = False
+insert_sql = (
+    "insert into account.agreement VALUES (default, '{}', '1', '[today,]', {})"
+)
+doc_dir = '/Users/chrisveilleux/Mycroft/github/documentation/_pages/'
+docs = {
+    'Privacy Policy': doc_dir + 'embed-privacy-policy.md',
+    'Terms of Use': doc_dir + 'embed-terms-of-use.md'
+}
+try:
+    for agrmt_type, doc_path in docs.items():
+        lobj = mycroft_db.db.lobject(0, 'b')
+        with open(doc_path) as doc:
+            header_delimiter_count = 0
+            while True:
+                rec = doc.readline()
+                if rec == '---\n':
+                    header_delimiter_count += 1
+                    if header_delimiter_count == 2:
+                        break
+            doc_html = markdown(
+                doc.read(),
+                output_format='html5'
+            )
+            lobj.write(doc_html)
+            mycroft_db.execute_sql(
+                insert_sql.format(agrmt_type, lobj.oid)
+            )
+            mycroft_db.execute_sql(
+                "grant select on large object {} to selene".format(lobj.oid)
+            )
+    mycroft_db.execute_sql(
+        insert_sql.format('Open Dataset', 'null')
+    )
+except:
+    mycroft_db.db.rollback()
+    raise
+else:
+    mycroft_db.db.commit()
+
+mycroft_db.db.autocommit = True
+
+reference_file_dir = '/Users/chrisveilleux/Mycroft'
+
+print('Building geography.country table')
+country_file = 'country.txt'
+country_insert = """
+    INSERT INTO
+        geography.country (iso_code, name)
+    VALUES
+        ('{iso_code}', '{country_name}')
+"""
+
+with open(path.join(reference_file_dir, country_file)) as countries:
+    while True:
+        rec = countries.readline()
+        if rec.startswith('#ISO'):
+            break
+
+    for country in countries.readlines():
+        country_fields = country.split('\t')
+        insert_args = dict(
+            iso_code=country_fields[0],
+            country_name=country_fields[4]
+        )
+        mycroft_db.execute_sql(country_insert.format(**insert_args))
+
+print('Building geography.region table')
+region_file = 'regions.txt'
+region_insert = """
+    INSERT INTO
+        geography.region (country_id, region_code, name)
+    VALUES
+        (
+            (SELECT id FROM geography.country WHERE iso_code = %(iso_code)s),
+            %(region_code)s,
+            %(region_name)s)
+"""
+with open(path.join(reference_file_dir, region_file)) as regions:
+    for region in regions.readlines():
+        region_fields = region.split('\t')
+        country_iso_code = region_fields[0][:2]
+        insert_args = dict(
+            iso_code=country_iso_code,
+            region_code=region_fields[0],
+            region_name=region_fields[1]
+        )
+        mycroft_db.execute_sql(region_insert, insert_args)
+
+print('Building geography.timezone table')
+timezone_file = 'timezones.txt'
+timezone_insert = """
+    INSERT INTO
+        geography.timezone (country_id, name, gmt_offset, dst_offset)
+    VALUES
+        (
+            (SELECT id FROM geography.country WHERE iso_code = %(iso_code)s),
+            %(timezone_name)s,
+            %(gmt_offset)s,
+            %(dst_offset)s
+        )
+"""
+with open(path.join(reference_file_dir, timezone_file)) as timezones:
+    timezones.readline()
+    for timezone in timezones.readlines():
+        timezone_fields = timezone.split('\t')
+        insert_args = dict(
+            iso_code=timezone_fields[0],
+            timezone_name=timezone_fields[1],
+            gmt_offset=timezone_fields[2],
+            dst_offset=timezone_fields[3]
+        )
+        mycroft_db.execute_sql(timezone_insert, insert_args)
+
+print('Building geography.city table')
+cities_file = 'cities500.txt'
+region_query = "SELECT id, region_code FROM geography.region"
+query_result = mycroft_db.execute_sql(region_query)
+region_lookup = dict()
+for row in query_result.fetchall():
+    region_lookup[row[1]] = row[0]
+
+timezone_query = "SELECT id, name FROM geography.timezone"
+query_result = mycroft_db.execute_sql(timezone_query)
+timezone_lookup = dict()
+for row in query_result.fetchall():
+    timezone_lookup[row[1]] = row[0]
+# city_insert = """
+#     INSERT INTO
+#         geography.city (region_id, timezone_id, name, latitude, longitude)
+#     VALUES
+#         (%(region_id)s, %(timezone_id)s, %(city_name)s, %(latitude)s, %(longitude)s)
+# """
+with open(path.join(reference_file_dir, cities_file)) as cities:
+    with open(path.join(reference_file_dir, 'city.dump'), 'w') as dump_file:
+        for city in cities.readlines():
+            city_fields = city.split('\t')
+            city_region = city_fields[8] + '.' + city_fields[10]
+            region_id = region_lookup.get(city_region)
+            timezone_id = timezone_lookup[city_fields[17]]
+            if region_id is not None:
+                dump_file.write('\t'.join([
+                    region_id,
+                    timezone_id,
+                    city_fields[1],
+                    city_fields[4],
+                    city_fields[5]
+                ]) + '\n')
+                # mycroft_db.execute_sql(city_insert, insert_args)
+with open(path.join(reference_file_dir, 'city.dump')) as dump_file:
+    cursor = mycroft_db.db.cursor()
+    cursor.copy_from(dump_file, 'geography.city', columns=(
+        'region_id', 'timezone_id', 'name', 'latitude', 'longitude')
+    )
+remove(path.join(reference_file_dir, 'city.dump'))
+
+mycroft_db.close_db()

@@ -22,15 +22,6 @@ import os

 from selene.util.db import allocate_db_connection_pool, DatabaseConnectionConfig

-db_connection_config = DatabaseConnectionConfig(
-    host=os.environ['DB_HOST'],
-    db_name=os.environ['DB_NAME'],
-    password=os.environ['DB_PASSWORD'],
-    port=os.environ.get('DB_PORT', 5432),
-    user=os.environ['DB_USER'],
-    sslmode=os.environ.get('DB_SSLMODE')
-)
-

 class APIConfigError(Exception):
     pass
@@ -43,6 +34,14 @@ class BaseConfig(object):
     DEBUG = False
     ENV = os.environ['SELENE_ENVIRONMENT']
     REFRESH_SECRET = os.environ['JWT_REFRESH_SECRET']
+    DB_CONNECTION_CONFIG = DatabaseConnectionConfig(
+        host=os.environ['DB_HOST'],
+        db_name=os.environ['DB_NAME'],
+        password=os.environ['DB_PASSWORD'],
+        port=os.environ.get('DB_PORT', 5432),
+        user=os.environ['DB_USER'],
+        sslmode=os.environ.get('DB_SSLMODE')
+    )


 class DevelopmentConfig(BaseConfig):
@@ -80,10 +79,4 @@ def get_base_config():
         error_msg = 'no configuration defined for the "{}" environment'
         raise APIConfigError(error_msg.format(environment_name))

-    max_db_connections = os.environ.get('MAX_DB_CONNECTIONS', 20)
-    app_config.DB_CONNECTION_POOL = allocate_db_connection_pool(
-        db_connection_config,
-        max_db_connections
-    )
-
     return app_config

@@ -6,7 +6,7 @@ from flask.views import MethodView

 from selene.data.account import Account, AccountRepository
 from selene.util.auth import AuthenticationError, AuthenticationToken
-from selene.util.db import get_db_connection_from_pool
+from selene.util.db import connect_to_db

 ACCESS_TOKEN_COOKIE_NAME = 'seleneAccess'
 FIFTEEN_MINUTES = 900
@@ -42,8 +42,8 @@ class SeleneEndpoint(MethodView):
     @property
     def db(self):
         if 'db' not in global_context:
-            global_context.db = get_db_connection_from_pool(
-                current_app.config['DB_CONNECTION_POOL']
+            global_context.db = connect_to_db(
+                current_app.config['DB_CONNECTION_CONFIG']
             )

         return global_context.db

@@ -1,5 +1,5 @@
-from datetime import datetime
 import json
+from datetime import datetime
 from http import HTTPStatus

 from flask import current_app, Blueprint, g as global_context
@@ -7,10 +7,7 @@ from schematics.exceptions import DataError

 from selene.data.metrics import ApiMetric, ApiMetricsRepository
 from selene.util.auth import AuthenticationError
-from selene.util.db import (
-    get_db_connection_from_pool,
-    return_db_connection_to_pool
-)
+from selene.util.db import connect_to_db
 from selene.util.not_modified import NotModifiedError

 selene_api = Blueprint('selene_api', __name__)
@@ -39,7 +36,6 @@ def setup_request():
 @selene_api.after_app_request
 def teardown_request(response):
     add_api_metric(response.status_code)
-    release_db_connection()

     return response

@@ -54,8 +50,8 @@ def add_api_metric(http_status):

     if api is not None and int(http_status) != 304:
         if 'db' not in global_context:
-            global_context.db = get_db_connection_from_pool(
-                current_app.config['DB_CONNECTION_POOL']
+            global_context.db = connect_to_db(
+                current_app.config['DB_CONNECTION_CONFIG']
             )
         if 'account_id' in global_context:
             account_id = global_context.account_id
@@ -78,12 +74,3 @@ def add_api_metric(http_status):
         )
         metric_repository = ApiMetricsRepository(global_context.db)
         metric_repository.add(api_metric)
-
-
-def release_db_connection():
-    db = global_context.pop('db', None)
-    if db is not None:
-        return_db_connection_to_pool(
-            current_app.config['DB_CONNECTION_POOL'],
-            db
-        )

@@ -3,7 +3,7 @@ from dataclasses import asdict
 from http import HTTPStatus

 from selene.data.account import AgreementRepository
-from selene.util.db import get_db_connection
+from selene.util.db import connect_to_db
 from ..base_endpoint import SeleneEndpoint


@@ -16,14 +16,14 @@ class AgreementsEndpoint(SeleneEndpoint):

     def get(self, agreement_type):
         """Process HTTP GET request for an agreement."""
-        with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
+        db = connect_to_db(self.config['DB_CONNECTION_CONFIG'])
         agreement_repository = AgreementRepository(db)
         agreement = agreement_repository.get_active_for_type(
             self.agreement_types[agreement_type]
         )
         if agreement is not None:
             agreement = asdict(agreement)
             del(agreement['effective_date'])
             self.response = agreement, HTTPStatus.OK

         return self.response

@@ -3,7 +3,7 @@ import string

 from selene.data.device import DeviceRepository
 from selene.util.cache import SeleneCache
-from selene.util.db import get_db_connection
+from selene.util.db import connect_to_db


 def device_etag_key(device_id: str):
@@ -29,7 +29,7 @@ class ETagManager(object):

     def __init__(self, cache: SeleneCache, config: dict):
         self.cache: SeleneCache = cache
-        self.db_connection_pool = config['DB_CONNECTION_POOL']
+        self.db_connection_config = config['DB_CONNECTION_CONFIG']

     def get(self, key: str) -> str:
         """Generate a etag with 32 random chars and store it into a given key
@@ -60,10 +60,10 @@ class ETagManager(object):
     def expire_device_setting_etag_by_account_id(self, account_id: str):
         """Expire the settings' etags for all devices from a given account. Used when the settings are updated
         at account level"""
-        with get_db_connection(self.db_connection_pool) as db:
+        db = connect_to_db(self.db_connection_config)
         devices = DeviceRepository(db).get_devices_by_account_id(account_id)
         for device in devices:
             self.expire_device_setting_etag_by_device_id(device.id)

     def expire_device_location_etag_by_device_id(self, device_id: str):
         """Expire the etag associate with the device's location entity
@@ -73,10 +73,10 @@ class ETagManager(object):
     def expire_device_location_etag_by_account_id(self, account_id: str):
         """Expire the locations' etag fpr açç device for a given acccount
         :param account_id: account uuid"""
-        with get_db_connection(self.db_connection_pool) as db:
+        db = connect_to_db(self.db_connection_config)
         devices = DeviceRepository(db).get_devices_by_account_id(account_id)
         for device in devices:
             self.expire_device_location_etag_by_device_id(device.id)

     def expire_skill_etag_by_device_id(self, device_id):
         """Expire the locations' etag for a given device
@@ -84,7 +84,7 @@ class ETagManager(object):
         self._expire(device_skill_etag_key(device_id))

     def expire_skill_etag_by_account_id(self, account_id):
-        with get_db_connection(self.db_connection_pool) as db:
+        db = connect_to_db(self.db_connection_config)
         devices = DeviceRepository(db).get_devices_by_account_id(account_id)
         for device in devices:
             self.expire_skill_etag_by_device_id(device.id)

|
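Taken together, these hunks change ETagManager from borrowing a pooled connection inside a `with` block to opening a connection on demand from the stored connection config. A rough usage sketch follows; it assumes `DatabaseConnectionConfig` is exported from `selene.util.db` alongside `connect_to_db`, that `SeleneCache` can be constructed without arguments here, and that the connection values are placeholders rather than real settings:

```python
from selene.api.etag import ETagManager
from selene.util.cache import SeleneCache
from selene.util.db import DatabaseConnectionConfig  # assumed export

# Placeholder connection settings; real values come from the app's config.
db_config = DatabaseConnectionConfig(
    host='localhost',
    db_name='mycroft',
    user='selene',
    password='selene',
)

# The manager now stores the config (not a pool) and calls connect_to_db()
# itself each time an expiration method needs the database.
etag_manager = ETagManager(
    cache=SeleneCache(),
    config={'DB_CONNECTION_CONFIG': db_config},
)
etag_manager.expire_device_setting_etag_by_account_id('account-uuid-here')
```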
@@ -13,7 +13,7 @@ from flask.views import MethodView
 
 from selene.api.etag import ETagManager
 from selene.util.auth import AuthenticationError
-from selene.util.db import get_db_connection_from_pool
+from selene.util.db import connect_to_db
 from selene.util.not_modified import NotModifiedError
 from ..util.cache import SeleneCache
 
@@ -91,8 +91,8 @@ class PublicEndpoint(MethodView):
     @property
     def db(self):
         if 'db' not in global_context:
-            global_context.db = get_db_connection_from_pool(
-                current_app.config['DB_CONNECTION_POOL']
+            global_context.db = connect_to_db(
+                current_app.config['DB_CONNECTION_CONFIG']
             )
 
        return global_context.db
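The `db` property above memoizes the connection on the request-scoped global context, so each request opens at most one database connection and reuses it for every repository call. A stripped-down sketch of the same pattern, assuming `global_context` is Flask's request-scoped `g` object (as the `current_app` usage suggests); `LazyDbMixin` is a hypothetical name used only for illustration:

```python
from flask import current_app, g as global_context

from selene.util.db import connect_to_db


class LazyDbMixin:
    """Illustrative only: open one DB connection per request, on first use."""

    @property
    def db(self):
        if 'db' not in global_context:
            # connect_to_db() now takes the connection config, not a pool.
            global_context.db = connect_to_db(
                current_app.config['DB_CONNECTION_CONFIG']
            )

        return global_context.db
```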
@@ -2,7 +2,7 @@ from hamcrest import assert_that, equal_to, has_item
 
 from selene.data.account import Account, AccountRepository
 from selene.util.auth import AuthenticationToken
-from selene.util.db import get_db_connection
+from selene.util.db import connect_to_db
 
 ACCESS_TOKEN_COOKIE_KEY = 'seleneAccess'
 ONE_MINUTE = 60
@@ -77,8 +77,8 @@ def _parse_cookie(cookie: str) -> dict:
 
 
 def get_account(context) -> Account:
-    with get_db_connection(context.db_pool) as db:
+    db = connect_to_db(context.client['DB_CONNECTION_CONFIG'])
     acct_repository = AccountRepository(db)
     account = acct_repository.get_account_by_id(context.account.id)
 
     return account
@@ -12,7 +12,7 @@ class SkillDisplayRepository(RepositoryBase):
             sql_file_name='get_display_data_for_skills.sql'
         )
 
-    def get_display_data_for_skill(self, skill_display_id):
+    def get_display_data_for_skill(self, skill_display_id) -> SkillDisplay:
         return self._select_one_into_dataclass(
             dataclass=SkillDisplay,
             sql_file_name='get_display_data_for_skill.sql',
@@ -34,12 +34,12 @@ class SkillSettingRepository(RepositoryBase):
 
         return skill_settings
 
-    def get_installer_settings(self, account_id: str):
+    def get_installer_settings(self, account_id: str) -> List[AccountSkillSetting]:
         skill_repo = SkillRepository(self.db)
         skills = skill_repo.get_skills_for_account(account_id)
         installer_skill_id = None
         for skill in skills:
-            if skill.name == 'mycroft_installer':
+            if skill.display_name == 'Installer':
                 installer_skill_id = skill.id
 
         skill_settings = None
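Both repository methods gain explicit return annotations, so callers and type checkers can rely on the shape of the result without inspecting the SQL layer. A hypothetical call site, assuming `AccountSkillSetting` and `SkillSettingRepository` are importable from the `selene.data.skill` package and that the repository takes an open connection, as the other repositories in this diff do:

```python
from typing import List

# Assumed import path; adjust to wherever the skill repositories actually live.
from selene.data.skill import AccountSkillSetting, SkillSettingRepository


def installer_settings_for(db, account_id: str) -> List[AccountSkillSetting]:
    """Thin wrapper showing how the annotated return type flows to callers."""
    setting_repository = SkillSettingRepository(db)

    return setting_repository.get_installer_settings(account_id)
```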
@@ -6,12 +6,12 @@ Example Usage:
     query_result = mycroft_db_ro.execute_sql(sql)
 """
 
-from contextlib import contextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, InitVar
 from logging import getLogger
 
 from psycopg2 import connect
-from psycopg2.extras import RealDictCursor
+from psycopg2.extras import RealDictCursor, NamedTupleCursor
+from psycopg2.extensions import cursor
 
 _log = getLogger(__package__)
 
@@ -29,10 +29,16 @@ class DatabaseConnectionConfig(object):
     password: str
     port: int = field(default=5432)
     sslmode: str = None
+    autocommit: str = True
+    cursor_factory = RealDictCursor
+    use_namedtuple_cursor: InitVar[bool] = False
+
+    def __post_init__(self, use_namedtuple_cursor: bool):
+        if use_namedtuple_cursor:
+            self.cursor_factory = NamedTupleCursor
 
 
-@contextmanager
-def connect_to_db(connection_config: DatabaseConnectionConfig, autocommit=True):
+def connect_to_db(connection_config: DatabaseConnectionConfig):
     """
     Return a connection to the mycroft database for the specified user.
 
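The dataclass now owns the cursor choice: `use_namedtuple_cursor` is an `InitVar`, so the generated `__init__` accepts it and forwards it to `__post_init__` instead of storing it as a regular field, and `cursor_factory` is only swapped to `NamedTupleCursor` when the flag is set. A minimal sketch of constructing the config both ways, assuming `DatabaseConnectionConfig` is exported from `selene.util.db` and using placeholder connection values:

```python
from psycopg2.extras import NamedTupleCursor, RealDictCursor

from selene.util.db import DatabaseConnectionConfig  # assumed export

# Default behaviour: rows come back as dictionaries via RealDictCursor.
dict_config = DatabaseConnectionConfig(
    host='localhost',
    db_name='mycroft',
    user='selene',
    password='selene',
)
assert dict_config.cursor_factory is RealDictCursor

# Opting in to named tuples; the InitVar is consumed by __init__/__post_init__
# rather than stored as a dataclass field on the instance.
tuple_config = DatabaseConnectionConfig(
    host='localhost',
    db_name='mycroft',
    user='selene',
    password='selene',
    use_namedtuple_cursor=True,
)
assert tuple_config.cursor_factory is NamedTupleCursor
```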
@@ -41,33 +47,19 @@ def connect_to_db(connection_config: DatabaseConnectionConfig, autocommit=True):
     python notebook)
 
     :param connection_config: data needed to establish a connection
-    :param autocommit: indicated if transactions should commit automatically
     :return: database connection
     """
-    db = None
     log_msg = 'establishing connection to the {db_name} database'
     _log.info(log_msg.format(db_name=connection_config.db_name))
-    try:
-        if connection_config.sslmode is None:
-            db = connect(
-                host=connection_config.host,
-                dbname=connection_config.db_name,
-                user=connection_config.user,
-                port=connection_config.port,
-                cursor_factory=RealDictCursor,
-            )
-        else:
-            db = connect(
-                host=connection_config.host,
-                dbname=connection_config.db_name,
-                user=connection_config.user,
-                password=connection_config.password,
-                port=connection_config.port,
-                cursor_factory=RealDictCursor,
-                sslmode=connection_config.sslmode
-            )
-        db.autocommit = autocommit
-        yield db
-    finally:
-        if db is not None:
-            db.close()
+    db = connect(
+        host=connection_config.host,
+        dbname=connection_config.db_name,
+        user=connection_config.user,
+        password=connection_config.password,
+        port=connection_config.port,
+        cursor_factory=connection_config.cursor_factory,
+        sslmode=connection_config.sslmode
+    )
+    db.autocommit = connection_config.autocommit
+
+    return db
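With the context-manager wrapper gone, `connect_to_db` now returns a plain psycopg2 connection built from the dataclass, including its autocommit flag and cursor factory, and callers no longer manage a pool or a `with get_db_connection(...)` block. A minimal usage sketch, assuming `DatabaseConnectionConfig` is exported from `selene.util.db` and that the environment-variable defaults below are placeholders rather than Selene's actual configuration scheme:

```python
import os

from selene.util.db import DatabaseConnectionConfig, connect_to_db  # assumed exports

# Placeholder settings; in practice these come from the application's config.
connection_config = DatabaseConnectionConfig(
    host=os.environ.get('DB_HOST', 'localhost'),
    db_name=os.environ.get('DB_NAME', 'mycroft'),
    user=os.environ.get('DB_USER', 'selene'),
    password=os.environ.get('DB_PASSWORD', ''),
)

# connect_to_db() returns an autocommit-enabled psycopg2 connection;
# closing it is now the caller's responsibility.
db = connect_to_db(connection_config)

with db.cursor() as cursor:
    # RealDictCursor is the default cursor factory, so each row is dict-like.
    cursor.execute('SELECT 1 AS result')
    print(cursor.fetchone()['result'])

db.close()
```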