applied the "Black" formatter to all files and added pre-commit hook to check

pull/293/head
Chris Veilleux 2022-03-11 13:22:33 -06:00
parent bbad8e2f3b
commit 26ed641b48
109 changed files with 1457 additions and 1543 deletions

11
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,11 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 19.10b0
hooks:
- id: black

View File

@ -5,6 +5,8 @@
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](http://makeapullrequest.com)
[![Join chat](https://img.shields.io/badge/Mattermost-join_chat-brightgreen.svg)](https://chat.mycroft.ai)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
Selene -- Mycroft's Server Backend

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -31,8 +31,6 @@ class DeviceCountEndpoint(SeleneEndpoint):
def _get_devices(self):
device_repository = DeviceRepository(self.db)
device_count = device_repository.get_account_device_count(
self.account.id
)
device_count = device_repository.get_account_device_count(self.account.id)
return device_count

View File

@ -25,7 +25,7 @@ from selene.api import SeleneEndpoint
class PairingCodeEndpoint(SeleneEndpoint):
def __init__(self):
super(PairingCodeEndpoint, self).__init__()
self.cache = self.config['SELENE_CACHE']
self.cache = self.config["SELENE_CACHE"]
def get(self, pairing_code):
self._authenticate()
@ -36,7 +36,7 @@ class PairingCodeEndpoint(SeleneEndpoint):
def _get_pairing_data(self, pairing_code: str) -> bool:
"""Checking if there's one pairing session for the pairing code."""
pairing_code_is_valid = False
cache_key = 'pairing.code:' + pairing_code
cache_key = "pairing.code:" + pairing_code
pairing_cache = self.cache.get(cache_key)
if pairing_cache is not None:
pairing_code_is_valid = True

View File

@ -29,29 +29,23 @@ from selene.data.device import AccountPreferences, PreferenceRepository
class PreferencesRequest(Model):
date_format = StringType(
required=True,
choices=['DD/MM/YYYY', 'MM/DD/YYYY']
)
measurement_system = StringType(
required=True,
choices=['Imperial', 'Metric']
)
time_format = StringType(required=True, choices=['12 Hour', '24 Hour'])
date_format = StringType(required=True, choices=["DD/MM/YYYY", "MM/DD/YYYY"])
measurement_system = StringType(required=True, choices=["Imperial", "Metric"])
time_format = StringType(required=True, choices=["12 Hour", "24 Hour"])
class PreferencesEndpoint(SeleneEndpoint):
def __init__(self):
super(PreferencesEndpoint, self).__init__()
self.preferences = None
self.cache = self.config['SELENE_CACHE']
self.cache = self.config["SELENE_CACHE"]
self.etag_manager: ETagManager = ETagManager(self.cache, self.config)
def get(self):
self._authenticate()
self._get_preferences()
if self.preferences is None:
response_data = ''
response_data = ""
response_code = HTTPStatus.NO_CONTENT
else:
response_data = asdict(self.preferences)
@ -68,22 +62,20 @@ class PreferencesEndpoint(SeleneEndpoint):
self._validate_request()
self._upsert_preferences()
self.etag_manager.expire_device_setting_etag_by_account_id(self.account.id)
return '', HTTPStatus.NO_CONTENT
return "", HTTPStatus.NO_CONTENT
def patch(self):
self._authenticate()
self._validate_request()
self._upsert_preferences()
self.etag_manager.expire_device_setting_etag_by_account_id(self.account.id)
return '', HTTPStatus.NO_CONTENT
return "", HTTPStatus.NO_CONTENT
def _validate_request(self):
self.preferences = PreferencesRequest()
self.preferences.date_format = self.request.json['dateFormat']
self.preferences.measurement_system = (
self.request.json['measurementSystem']
)
self.preferences.time_format = self.request.json['timeFormat']
self.preferences.date_format = self.request.json["dateFormat"]
self.preferences.measurement_system = self.request.json["measurementSystem"]
self.preferences.time_format = self.request.json["timeFormat"]
self.preferences.validate()
def _upsert_preferences(self):

View File

@ -25,7 +25,7 @@ from selene.data.geography import RegionRepository
class RegionEndpoint(SeleneEndpoint):
def get(self):
country_id = self.request.args['country']
country_id = self.request.args["country"]
region_repository = RegionRepository(self.db)
regions = region_repository.get_regions_by_country(country_id)

View File

@ -27,17 +27,15 @@ from selene.api import SeleneEndpoint
class SkillOauthEndpoint(SeleneEndpoint):
def __init__(self):
super(SkillOauthEndpoint, self).__init__()
self.oauth_base_url = os.environ['OAUTH_BASE_URL']
self.oauth_base_url = os.environ["OAUTH_BASE_URL"]
def get(self, oauth_id):
self._authenticate()
return self._get_oauth_url(oauth_id)
def _get_oauth_url(self, oauth_id):
url = '{base_url}/auth/{oauth_id}/auth_url?uuid={account_id}'.format(
base_url=self.oauth_base_url,
oauth_id=oauth_id,
account_id=self.account.id
url = "{base_url}/auth/{oauth_id}/auth_url?uuid={account_id}".format(
base_url=self.oauth_base_url, oauth_id=oauth_id, account_id=self.account.id
)
response = requests.get(url)
return response.text, response.status_code

View File

@ -35,8 +35,7 @@ class SkillSettingsEndpoint(SeleneEndpoint):
self.account_skills = None
self.family_settings = None
self.etag_manager: ETagManager = ETagManager(
self.config['SELENE_CACHE'],
self.config
self.config["SELENE_CACHE"], self.config
)
@property
@ -51,8 +50,7 @@ class SkillSettingsEndpoint(SeleneEndpoint):
"""Process an HTTP GET request"""
self._authenticate()
self.family_settings = self.setting_repository.get_family_settings(
self.account.id,
skill_family_name
self.account.id, skill_family_name
)
self._parse_selection_options()
response_data = self._build_response_data()
@ -62,7 +60,7 @@ class SkillSettingsEndpoint(SeleneEndpoint):
return Response(
response=json.dumps(response_data),
status=HTTPStatus.OK,
content_type='application/json'
content_type="application/json",
)
def _parse_selection_options(self):
@ -75,19 +73,16 @@ class SkillSettingsEndpoint(SeleneEndpoint):
"""
for skill_settings in self.family_settings:
if skill_settings.settings_definition is not None:
for section in skill_settings.settings_definition['sections']:
for field in section['fields']:
if field['type'] == 'select':
for section in skill_settings.settings_definition["sections"]:
for field in section["fields"]:
if field["type"] == "select":
parsed_options = []
for option in field['options'].split(';'):
option_display, option_value = option.split('|')
for option in field["options"].split(";"):
option_display, option_value = option.split("|")
parsed_options.append(
dict(
display=option_display,
value=option_value
dict(display=option_display, value=option_value)
)
)
field['options'] = parsed_options
field["options"] = parsed_options
def _build_response_data(self):
"""Build the object to return to the UI."""
@ -100,7 +95,7 @@ class SkillSettingsEndpoint(SeleneEndpoint):
response_skill = dict(
settingsDisplay=skill_settings.settings_definition,
settingsValues=skill_settings.settings_values,
deviceNames=skill_settings.device_names
deviceNames=skill_settings.device_names,
)
response_data.append(response_skill)
@ -111,19 +106,17 @@ class SkillSettingsEndpoint(SeleneEndpoint):
self._authenticate()
self._update_settings_values()
return '', HTTPStatus.OK
return "", HTTPStatus.OK
def _update_settings_values(self):
"""Update the value of the settings column on the device_skill table,"""
for new_skill_settings in self.request.json['skillSettings']:
for new_skill_settings in self.request.json["skillSettings"]:
account_skill_settings = AccountSkillSetting(
settings_definition=new_skill_settings['settingsDisplay'],
settings_values=new_skill_settings['settingsValues'],
device_names=new_skill_settings['deviceNames']
settings_definition=new_skill_settings["settingsDisplay"],
settings_values=new_skill_settings["settingsValues"],
device_names=new_skill_settings["deviceNames"],
)
self.setting_repository.update_skill_settings(
self.account.id,
account_skill_settings,
self.request.json['skillIds']
self.account.id, account_skill_settings, self.request.json["skillIds"]
)
self.etag_manager.expire_skill_etag_by_account_id(self.account.id)

View File

@ -44,11 +44,11 @@ class SkillsEndpoint(SeleneEndpoint):
market_id=skill.market_id,
name=skill.display_name or skill.family_name,
has_settings=skill.has_settings,
skill_ids=skill.skill_ids
skill_ids=skill.skill_ids,
)
else:
response_skill['skill_ids'].extend(skill.skill_ids)
if response_skill['market_id'] is None:
response_skill['market_id'] = skill.market_id
response_skill["skill_ids"].extend(skill.skill_ids)
if response_skill["market_id"] is None:
response_skill["market_id"] = skill.market_id
return sorted(response_data.values(), key=lambda x: x['name'])
return sorted(response_data.values(), key=lambda x: x["name"])

View File

@ -25,7 +25,7 @@ from selene.data.geography import TimezoneRepository
class TimezoneEndpoint(SeleneEndpoint):
def get(self):
country_id = self.request.args['country']
country_id = self.request.args["country"]
timezone_repository = TimezoneRepository(self.db)
timezones = timezone_repository.get_timezones_by_country(country_id)

View File

@ -26,19 +26,19 @@ from hamcrest import assert_that, equal_to
from selene.data.account import PRIVACY_POLICY, TERMS_OF_USE
@when('API request for {agreement} is made')
@when("API request for {agreement} is made")
def call_agreement_endpoint(context, agreement):
if agreement == PRIVACY_POLICY:
url = '/api/agreement/privacy-policy'
url = "/api/agreement/privacy-policy"
elif agreement == TERMS_OF_USE:
url = '/api/agreement/terms-of-use'
url = "/api/agreement/terms-of-use"
else:
raise ValueError('invalid agreement type')
raise ValueError("invalid agreement type")
context.response = context.client.get(url)
@then('{agreement} version {version} is returned')
@then("{agreement} version {version} is returned")
def validate_response(context, agreement, version):
response_data = json.loads(context.response.data)
if agreement == PRIVACY_POLICY:
@ -46,7 +46,7 @@ def validate_response(context, agreement, version):
elif agreement == TERMS_OF_USE:
expected_response = asdict(context.terms_of_use)
else:
raise ValueError('invalid agreement type')
raise ValueError("invalid agreement type")
del(expected_response['effective_date'])
del expected_response["effective_date"]
assert_that(response_data, equal_to(expected_response))

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -27,7 +27,8 @@ from selene.data.skill import SkillDisplay, SkillDisplayRepository
class SkillDetailEndpoint(SeleneEndpoint):
""""Supply the data that will populate the skill detail page."""
""" "Supply the data that will populate the skill detail page."""
authentication_required = False
def __init__(self):
@ -57,39 +58,37 @@ class SkillDetailEndpoint(SeleneEndpoint):
def _build_response_data(self, skill_display: SkillDisplay):
"""Make some modifications to the response skill for the marketplace"""
self.response_skill = dict(
categories=skill_display.display_data.get('categories'),
credits=skill_display.display_data.get('credits'),
categories=skill_display.display_data.get("categories"),
credits=skill_display.display_data.get("credits"),
description=markdown(
skill_display.display_data.get('description'),
output_format='html5'
skill_display.display_data.get("description"), output_format="html5"
),
display_name=skill_display.display_data['display_name'],
icon=skill_display.display_data.get('icon'),
iconImage=skill_display.display_data.get('icon_img'),
display_name=skill_display.display_data["display_name"],
icon=skill_display.display_data.get("icon"),
iconImage=skill_display.display_data.get("icon_img"),
isSystemSkill=False,
worksOnMarkOne=(
'all' in skill_display.display_data['platforms'] or
'platform_mark1' in skill_display.display_data['platforms']
"all" in skill_display.display_data["platforms"]
or "platform_mark1" in skill_display.display_data["platforms"]
),
worksOnMarkTwo=(
'all' in skill_display.display_data['platforms'] or
'platform_mark2' in skill_display.display_data['platforms']
"all" in skill_display.display_data["platforms"]
or "platform_mark2" in skill_display.display_data["platforms"]
),
worksOnPicroft=(
'all' in skill_display.display_data['platforms'] or
'platform_picroft' in skill_display.display_data['platforms']
"all" in skill_display.display_data["platforms"]
or "platform_picroft" in skill_display.display_data["platforms"]
),
worksOnKDE=(
'all' in skill_display.display_data['platforms'] or
'platform_plasmoid' in skill_display.display_data['platforms']
"all" in skill_display.display_data["platforms"]
or "platform_plasmoid" in skill_display.display_data["platforms"]
),
repositoryUrl=skill_display.display_data.get('repo'),
repositoryUrl=skill_display.display_data.get("repo"),
summary=markdown(
skill_display.display_data['short_desc'],
output_format='html5'
skill_display.display_data["short_desc"], output_format="html5"
),
triggers=skill_display.display_data['examples']
triggers=skill_display.display_data["examples"],
)
if skill_display.display_data['tags'] is not None:
if 'system' in skill_display.display_data['tags']:
self.response_skill['isSystemSkill'] = True
if skill_display.display_data["tags"] is not None:
if "system" in skill_display.display_data["tags"]:
self.response_skill["isSystemSkill"] = True

View File

@ -26,12 +26,7 @@ from selene.api import SeleneEndpoint
from selene.data.device import DeviceSkillRepository, ManifestSkill
from selene.util.auth import AuthenticationError
VALID_STATUS_VALUES = (
'failed',
'installed',
'installing',
'uninstalling'
)
VALID_STATUS_VALUES = ("failed", "installed", "installing", "uninstalling")
class SkillInstallStatusEndpoint(SeleneEndpoint):
@ -45,7 +40,7 @@ class SkillInstallStatusEndpoint(SeleneEndpoint):
try:
self._authenticate()
except AuthenticationError:
self.response = ('', HTTPStatus.NO_CONTENT)
self.response = ("", HTTPStatus.NO_CONTENT)
else:
self._get_installed_skills()
response_data = self._build_response_data()
@ -55,9 +50,7 @@ class SkillInstallStatusEndpoint(SeleneEndpoint):
def _get_installed_skills(self):
skill_repo = DeviceSkillRepository(self.db)
installed_skills = skill_repo.get_skill_manifest_for_account(
self.account.id
)
installed_skills = skill_repo.get_skill_manifest_for_account(self.account.id)
for skill in installed_skills:
self.installed_skills[skill.skill_id].append(skill)
@ -67,18 +60,13 @@ class SkillInstallStatusEndpoint(SeleneEndpoint):
for skill_id, skills in self.installed_skills.items():
skill_aggregator = SkillManifestAggregator(skills)
skill_aggregator.aggregate_skill_status()
if skill_aggregator.aggregate_skill.install_status == 'failed':
failure_reasons[skill_id] = (
skill_aggregator.aggregate_skill.install_failure_reason
)
install_statuses[skill_id] = (
skill_aggregator.aggregate_skill.install_status
)
if skill_aggregator.aggregate_skill.install_status == "failed":
failure_reasons[
skill_id
] = skill_aggregator.aggregate_skill.install_failure_reason
install_statuses[skill_id] = skill_aggregator.aggregate_skill.install_status
return dict(
installStatuses=install_statuses,
failureReasons=failure_reasons
)
return dict(installStatuses=install_statuses, failureReasons=failure_reasons)
class SkillManifestAggregator(object):
@ -96,7 +84,7 @@ class SkillManifestAggregator(object):
"""
self._validate_install_status()
self._determine_install_status()
if self.aggregate_skill.install_status == 'failed':
if self.aggregate_skill.install_status == "failed":
self._determine_failure_reason()
def _validate_install_status(self):
@ -104,7 +92,7 @@ class SkillManifestAggregator(object):
if skill.install_status not in VALID_STATUS_VALUES:
raise ValueError(
'"{install_status}" is not a supported value of the '
'installation field in the skill manifest'.format(
"installation field in the skill manifest".format(
install_status=skill.install_status
)
)
@ -120,33 +108,24 @@ class SkillManifestAggregator(object):
If the install fails on any device, the install will be flagged as a
failed install in the Marketplace.
"""
failed = [
skill.install_status == 'failed' for skill in self.installed_skills
]
installing = [
s.install_status == 'installing' for s in self.installed_skills
]
failed = [skill.install_status == "failed" for skill in self.installed_skills]
installing = [s.install_status == "installing" for s in self.installed_skills]
uninstalling = [
skill.install_status == 'uninstalling' for skill in
self.installed_skills
]
installed = [
s.install_status == 'installed' for s in self.installed_skills
skill.install_status == "uninstalling" for skill in self.installed_skills
]
installed = [s.install_status == "installed" for s in self.installed_skills]
if any(failed):
self.aggregate_skill.install_status = 'failed'
self.aggregate_skill.install_status = "failed"
elif any(installing):
self.aggregate_skill.install_status = 'installing'
self.aggregate_skill.install_status = "installing"
elif any(uninstalling):
self.aggregate_skill.install_status = 'uninstalling'
self.aggregate_skill.install_status = "uninstalling"
elif all(installed):
self.aggregate_skill.install_status = 'installed'
self.aggregate_skill.install_status = "installed"
def _determine_failure_reason(self):
"""When a skill fails to install, determine the reason"""
for skill in self.installed_skills:
if skill.install_status == 'failed':
self.aggregate_skill.failure_reason = (
skill.install_failure_reason
)
if skill.install_status == "failed":
self.aggregate_skill.failure_reason = skill.install_failure_reason
break

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -29,14 +29,15 @@ from selene.data.device import DeviceRepository
class UpdateDevice(Model):
coreVersion = StringType(default='unknown')
platform = StringType(default='unknown')
coreVersion = StringType(default="unknown")
platform = StringType(default="unknown")
platform_build = StringType()
enclosureVersion = StringType(default='unknown')
enclosureVersion = StringType(default="unknown")
class DeviceEndpoint(PublicEndpoint):
"""Return the device entity using the device_id"""
def __init__(self):
super(DeviceEndpoint, self).__init__()
@ -53,13 +54,13 @@ class DeviceEndpoint(PublicEndpoint):
coreVersion=device.core_version,
enclosureVersion=device.enclosure_version,
platform=device.platform,
user=dict(uuid=device.account_id)
user=dict(uuid=device.account_id),
)
response = response_data, HTTPStatus.OK
self._add_etag(device_etag_key(device_id))
else:
response = '', HTTPStatus.NO_CONTENT
response = "", HTTPStatus.NO_CONTENT
return response
@ -69,10 +70,10 @@ class DeviceEndpoint(PublicEndpoint):
update_device = UpdateDevice(payload)
update_device.validate()
updates = dict(
platform=payload.get('platform') or 'unknown',
enclosure_version=payload.get('enclosureVersion') or 'unknown',
core_version=payload.get('coreVersion') or 'unknown'
platform=payload.get("platform") or "unknown",
enclosure_version=payload.get("enclosureVersion") or "unknown",
core_version=payload.get("coreVersion") or "unknown",
)
DeviceRepository(self.db).update_device_from_core(device_id, updates)
return '', HTTPStatus.OK
return "", HTTPStatus.OK

View File

@ -25,17 +25,18 @@ from selene.data.device import GeographyRepository
class DeviceLocationEndpoint(PublicEndpoint):
def __init__(self):
super(DeviceLocationEndpoint, self).__init__()
def get(self, device_id):
self._authenticate(device_id)
self._validate_etag(device_location_etag_key(device_id))
location = GeographyRepository(self.db, None).get_location_by_device_id(device_id)
location = GeographyRepository(self.db, None).get_location_by_device_id(
device_id
)
if location:
response = (location, HTTPStatus.OK)
self._add_etag(device_location_etag_key(device_id))
else:
response = ('', HTTPStatus.NOT_FOUND)
response = ("", HTTPStatus.NOT_FOUND)
return response

View File

@ -26,18 +26,15 @@ from selene.data.account import AccountRepository
class OauthServiceEndpoint(PublicEndpoint):
def __init__(self):
super(OauthServiceEndpoint, self).__init__()
self.oauth_service_host = os.environ['OAUTH_BASE_URL']
self.oauth_service_host = os.environ["OAUTH_BASE_URL"]
def get(self, device_id, credentials, oauth_path):
account = AccountRepository(self.db).get_account_by_device_id(device_id)
uuid = account.id
url = '{host}/auth/{credentials}/{oauth_path}'.format(
host=self.oauth_service_host,
credentials=credentials,
oauth_path=oauth_path
url = "{host}/auth/{credentials}/{oauth_path}".format(
host=self.oauth_service_host, credentials=credentials, oauth_path=oauth_path
)
params = dict(uuid=uuid)
response = requests.get(url, params=params)

View File

@ -35,44 +35,44 @@ class DeviceRefreshTokenEndpoint(PublicEndpoint):
def get(self):
headers = self.request.headers
if 'Authorization' not in headers:
raise AuthenticationError('Oauth token not found')
token_header = self.request.headers['Authorization']
if token_header.startswith('Bearer '):
refresh = token_header[len('Bearer '):]
if "Authorization" not in headers:
raise AuthenticationError("Oauth token not found")
token_header = self.request.headers["Authorization"]
if token_header.startswith("Bearer "):
refresh = token_header[len("Bearer ") :]
session = self._refresh_session_token(refresh)
# Trying to fetch a session using the refresh token
if session:
response = session, HTTPStatus.OK
else:
device = self.request.headers.get('Device')
device = self.request.headers.get("Device")
if device:
# trying to fetch a session using the device uuid
session = self._refresh_session_token_device(device)
if session:
response = session, HTTPStatus.OK
else:
response = '', HTTPStatus.UNAUTHORIZED
response = "", HTTPStatus.UNAUTHORIZED
else:
response = '', HTTPStatus.UNAUTHORIZED
response = "", HTTPStatus.UNAUTHORIZED
else:
response = '', HTTPStatus.UNAUTHORIZED
response = "", HTTPStatus.UNAUTHORIZED
return response
def _refresh_session_token(self, refresh: str):
refresh_key = 'device.token.refresh:{}'.format(refresh)
refresh_key = "device.token.refresh:{}".format(refresh)
session = self.cache.get(refresh_key)
if session:
old_login = json.loads(session)
device_id = old_login['uuid']
device_id = old_login["uuid"]
self.cache.delete(refresh_key)
return generate_device_login(device_id, self.cache)
def _refresh_session_token_device(self, device: str):
refresh_key = 'device.session:{}'.format(device)
refresh_key = "device.session:{}".format(device)
session = self.cache.get(refresh_key)
if session:
old_login = json.loads(session)
device_id = old_login['uuid']
device_id = old_login["uuid"]
self.cache.delete(refresh_key)
return generate_device_login(device_id, self.cache)

View File

@ -25,6 +25,7 @@ from selene.data.device import SettingRepository
class DeviceSettingEndpoint(PublicEndpoint):
"""Return the device's settings for the API v1 model"""
def __init__(self):
super(DeviceSettingEndpoint, self).__init__()
@ -36,5 +37,5 @@ class DeviceSettingEndpoint(PublicEndpoint):
response = (setting, HTTPStatus.OK)
self._add_etag(device_setting_etag_key(device_id))
else:
response = ('', HTTPStatus.NO_CONTENT)
response = ("", HTTPStatus.NO_CONTENT)
return response

View File

@ -28,7 +28,7 @@ from schematics.types import (
ListType,
IntType,
BooleanType,
TimestampType
TimestampType,
)
from selene.api import PublicEndpoint
@ -43,9 +43,7 @@ class SkillManifestReconciler(object):
self.skill_repo = SkillRepository(self.db)
self.device_manifest = {sm.skill_gid: sm for sm in device_manifest}
self.db_manifest = {ds.skill_gid: ds for ds in db_manifest}
self.device_manifest_global_ids = {
gid for gid in self.device_manifest.keys()
}
self.device_manifest_global_ids = {gid for gid in self.device_manifest.keys()}
self.db_manifest_global_ids = {gid for gid in self.db_manifest}
def reconcile(self):
@ -82,16 +80,14 @@ class SkillManifestReconciler(object):
for gid in skills_to_add:
skill_id = self.skill_repo.ensure_skill_exists(gid)
self.device_manifest[gid].skill_id = skill_id
self.skill_manifest_repo.add_manifest_skill(
self.device_manifest[gid]
)
self.skill_manifest_repo.add_manifest_skill(self.device_manifest[gid])
class RequestManifestSkill(Model):
name = StringType(required=True)
origin = StringType(required=True)
installation = StringType(required=True)
failure_message = StringType(default='')
failure_message = StringType(default="")
status = StringType(required=True)
beta = BooleanType(required=True)
installed = TimestampType(required=True)
@ -126,7 +122,7 @@ class DeviceSkillManifestEndpoint(PublicEndpoint):
self._validate_put_request()
self._update_skill_manifest(device_id)
return '', HTTPStatus.OK
return "", HTTPStatus.OK
def _validate_put_request(self):
request_data = SkillManifestRequest(self.request.json)
@ -137,29 +133,27 @@ class DeviceSkillManifestEndpoint(PublicEndpoint):
device_id
)
device_skill_manifest = []
for manifest_skill in self.request.json['skills']:
for manifest_skill in self.request.json["skills"]:
self._convert_manifest_timestamps(manifest_skill)
device_skill_manifest.append(
ManifestSkill(
device_id=device_id,
install_method=manifest_skill['origin'],
install_status=manifest_skill['installation'],
install_failure_reason=manifest_skill.get('failure_message'),
install_ts=manifest_skill['installed'],
skill_gid=manifest_skill['skill_gid'],
update_ts=manifest_skill['updated']
install_method=manifest_skill["origin"],
install_status=manifest_skill["installation"],
install_failure_reason=manifest_skill.get("failure_message"),
install_ts=manifest_skill["installed"],
skill_gid=manifest_skill["skill_gid"],
update_ts=manifest_skill["updated"],
)
)
reconciler = SkillManifestReconciler(
self.db,
device_skill_manifest,
db_skill_manifest
self.db, device_skill_manifest, db_skill_manifest
)
reconciler.reconcile()
@staticmethod
def _convert_manifest_timestamps(manifest_skill):
for key in ('installed', 'updated'):
for key in ("installed", "updated"):
value = manifest_skill[key]
if value:
manifest_skill[key] = datetime.fromtimestamp(value)

View File

@ -33,39 +33,37 @@ from selene.data.skill import (
SettingsDisplayRepository,
Skill,
SkillRepository,
SkillSettingRepository
SkillSettingRepository,
)
from selene.util.cache import DEVICE_SKILL_ETAG_KEY
# matches <submodule_name>|<branch>
GLOBAL_ID_PATTERN = '^([^\|@]+)\|([^\|]+$)'
GLOBAL_ID_PATTERN = "^([^\|@]+)\|([^\|]+$)"
# matches @<device_id>|<submodule_name>|<branch>
GLOBAL_ID_DIRTY_PATTERN = '^@(.*)\|(.*)\|(.*)$'
GLOBAL_ID_DIRTY_PATTERN = "^@(.*)\|(.*)\|(.*)$"
# matches @<device_id>|<folder_name>
GLOBAL_ID_NON_MSM_PATTERN = '^@([^\|]+)\|([^\|]+$)'
GLOBAL_ID_ANY_PATTERN = '(?:{})|(?:{})|(?:{})'.format(
GLOBAL_ID_PATTERN,
GLOBAL_ID_DIRTY_PATTERN,
GLOBAL_ID_NON_MSM_PATTERN
GLOBAL_ID_NON_MSM_PATTERN = "^@([^\|]+)\|([^\|]+$)"
GLOBAL_ID_ANY_PATTERN = "(?:{})|(?:{})|(?:{})".format(
GLOBAL_ID_PATTERN, GLOBAL_ID_DIRTY_PATTERN, GLOBAL_ID_NON_MSM_PATTERN
)
def _normalize_field_value(field):
"""The field values in skillMetadata are all strings, convert to native."""
normalized_value = field.get('value')
if field['type'].lower() == 'checkbox':
if field['value'] in ('false', 'False', '0'):
normalized_value = field.get("value")
if field["type"].lower() == "checkbox":
if field["value"] in ("false", "False", "0"):
normalized_value = False
elif field['value'] in ('true', 'True', '1'):
elif field["value"] in ("true", "True", "1"):
normalized_value = True
elif field['type'].lower() == 'number' and isinstance(field['value'], str):
if field['value']:
normalized_value = float(field['value'])
elif field["type"].lower() == "number" and isinstance(field["value"], str):
if field["value"]:
normalized_value = float(field["value"])
if not normalized_value % 1:
normalized_value = int(field['value'])
normalized_value = int(field["value"])
else:
normalized_value = 0
elif field['value'] == "[]":
elif field["value"] == "[]":
normalized_value = []
return normalized_value
@ -78,6 +76,7 @@ class SkillSettingUpdater(object):
request specifies a single device to update, all devices with
the same skill must be updated as well.
"""
_device_skill_repo = None
_settings_display_repo = None
@ -115,28 +114,27 @@ class SkillSettingUpdater(object):
settings_meta.json file before sending the result to this API. The
settings values are stored separately from the metadata in the database.
"""
settings_definition = self.display_data.get('skillMetadata')
settings_definition = self.display_data.get("skillMetadata")
if settings_definition is not None:
self.settings_values = dict()
sections_without_values = []
for section in settings_definition['sections']:
for section in settings_definition["sections"]:
section_without_values = dict(**section)
for field in section_without_values['fields']:
field_name = field.get('name')
field_value = field.get('value')
for field in section_without_values["fields"]:
field_name = field.get("name")
field_value = field.get("value")
if field_name is not None:
if field_value is not None:
field_value = _normalize_field_value(field)
del(field['value'])
del field["value"]
self.settings_values[field_name] = field_value
sections_without_values.append(section_without_values)
settings_definition['sections'] = sections_without_values
settings_definition["sections"] = sections_without_values
def _get_skill_id(self):
"""Get the id of the skill in the request"""
skill_global_id = (
self.display_data.get('skill_gid') or
self.display_data.get('identifier')
skill_global_id = self.display_data.get("skill_gid") or self.display_data.get(
"identifier"
)
skill_repo = SkillRepository(self.db)
skill_id = skill_repo.ensure_skill_exists(skill_global_id)
@ -145,15 +143,10 @@ class SkillSettingUpdater(object):
def _ensure_settings_display_exists(self) -> bool:
"""If the settings display changed, a new row needs to be added."""
new_settings_display = False
self.settings_display = SettingsDisplay(
self.skill.id,
self.display_data
)
self.settings_display.id = (
self.settings_display_repo.get_settings_display_id(
self.settings_display = SettingsDisplay(self.skill.id, self.display_data)
self.settings_display.id = self.settings_display_repo.get_settings_display_id(
self.settings_display
)
)
if self.settings_display.id is None:
self.settings_display.id = self.settings_display_repo.add(
self.settings_display
@ -173,11 +166,8 @@ class SkillSettingUpdater(object):
"""Get all the permutations of settings for a skill"""
account_repo = AccountRepository(self.db)
account = account_repo.get_account_by_device_id(self.device_id)
skill_settings = (
self.device_skill_repo.get_skill_settings_for_account(
account.id,
self.skill.id
)
skill_settings = self.device_skill_repo.get_skill_settings_for_account(
account.id, self.skill.id
)
return skill_settings
@ -187,14 +177,14 @@ class SkillSettingUpdater(object):
for skill_setting in skill_settings:
if self.device_id in skill_setting.device_ids:
device_skill_found = True
if skill_setting.install_method in ('voice', 'cli'):
if skill_setting.install_method in ("voice", "cli"):
devices_to_update = [self.device_id]
else:
devices_to_update = skill_setting.device_ids
self.device_skill_repo.upsert_device_skill_settings(
devices_to_update,
self.settings_display,
self._merge_settings_values(skill_setting.settings_values)
self._merge_settings_values(skill_setting.settings_values),
)
break
@ -225,9 +215,7 @@ class SkillSettingUpdater(object):
manifest endpoint in some cases.
"""
self.device_skill_repo.upsert_device_skill_settings(
[self.device_id],
self.settings_display,
self._merge_settings_values()
[self.device_id], self.settings_display, self._merge_settings_values()
)
@ -267,15 +255,16 @@ class RequestSkill(Model):
identifier = StringType()
def validate_skill_gid(self, data, value):
if data['skill_gid'] is None and data['identifier'] is None:
if data["skill_gid"] is None and data["identifier"] is None:
raise ValidationError(
'skill should have either skill_gid or identifier defined'
"skill should have either skill_gid or identifier defined"
)
return value
class DeviceSkillSettingsEndpoint(PublicEndpoint):
"""Fetch all skills associated with a device using the API v1 format"""
_device_skill_repo = None
_skill_repo = None
_skill_setting_repo = None
@ -317,23 +306,19 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint):
"""
self._authenticate(device_id)
self._validate_etag(DEVICE_SKILL_ETAG_KEY.format(device_id=device_id))
device_skills = self.skill_setting_repo.get_skill_settings_for_device(
device_id
)
device_skills = self.skill_setting_repo.get_skill_settings_for_device(device_id)
if device_skills:
response_data = self._build_response_data(device_skills)
response = Response(
json.dumps(response_data),
status=HTTPStatus.OK,
content_type='application/json'
content_type="application/json",
)
self._add_etag(DEVICE_SKILL_ETAG_KEY.format(device_id=device_id))
else:
response = Response(
'',
status=HTTPStatus.NO_CONTENT,
content_type='application/json'
"", status=HTTPStatus.NO_CONTENT, content_type="application/json"
)
return response
@ -341,7 +326,7 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint):
response_data = []
for skill in device_skills:
response_skill = dict(uuid=skill.skill_id)
settings_definition = skill.settings_display.get('skillMetadata')
settings_definition = skill.settings_display.get("skillMetadata")
if settings_definition:
settings_sections = self._apply_settings_values(
settings_definition, skill.settings_values
@ -350,10 +335,10 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint):
response_skill.update(
skillMetadata=dict(sections=settings_sections)
)
skill_gid = skill.settings_display.get('skill_gid')
skill_gid = skill.settings_display.get("skill_gid")
if skill_gid is not None:
response_skill.update(skill_gid=skill_gid)
identifier = skill.settings_display.get('identifier')
identifier = skill.settings_display.get("identifier")
if identifier is None:
response_skill.update(identifier=skill_gid)
else:
@ -366,10 +351,10 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint):
def _apply_settings_values(settings_definition, settings_values):
"""Build a copy of the settings sections populated with values."""
sections_with_values = []
for section in settings_definition['sections']:
for section in settings_definition["sections"]:
section_with_values = dict(**section)
for field in section_with_values['fields']:
field_name = field.get('name')
for field in section_with_values["fields"]:
field_name = field.get("name")
if field_name is not None and field_name in settings_values:
field.update(value=str(settings_values[field_name]))
sections_with_values.append(section_with_values)
@ -380,9 +365,7 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint):
self._authenticate(device_id)
self._validate_put_request()
skill_id = self._update_skill_settings(device_id)
self.etag_manager.expire(
DEVICE_SKILL_ETAG_KEY.format(device_id=device_id)
)
self.etag_manager.expire(DEVICE_SKILL_ETAG_KEY.format(device_id=device_id))
return dict(uuid=skill_id), HTTPStatus.OK
@ -392,9 +375,7 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint):
def _update_skill_settings(self, device_id):
skill_setting_updater = SkillSettingUpdater(
self.db,
device_id,
self.request.json
self.db, device_id, self.request.json
)
skill_setting_updater.update()
self._delete_orphaned_settings_display(
@ -418,6 +399,7 @@ class DeviceSkillSettingsEndpointV2(PublicEndpoint):
with pre 19.08 versions of mycroft-core. Once those versions are no
longer supported, the older class can be deprecated.
"""
def get(self, device_id):
"""
Retrieve skills installed on device from the database.
@ -433,9 +415,7 @@ class DeviceSkillSettingsEndpointV2(PublicEndpoint):
def _build_response_data(self, device_id):
device_skill_repo = DeviceSkillRepository(self.db)
device_skills = device_skill_repo.get_skill_settings_for_device(
device_id
)
device_skills = device_skill_repo.get_skill_settings_for_device(device_id)
if device_skills is not None:
response_data = {}
for skill in device_skills:
@ -446,15 +426,13 @@ class DeviceSkillSettingsEndpointV2(PublicEndpoint):
def _build_response(self, device_id, response_data):
if response_data is None:
response = Response(
'',
status=HTTPStatus.NO_CONTENT,
content_type='application/json'
"", status=HTTPStatus.NO_CONTENT, content_type="application/json"
)
else:
response = Response(
json.dumps(response_data),
status=HTTPStatus.OK,
content_type='application/json'
content_type="application/json",
)
self._add_etag(DEVICE_SKILL_ETAG_KEY.format(device_id=device_id))

View File

@ -33,10 +33,10 @@ class DeviceSubscriptionEndpoint(PublicEndpoint):
if account:
membership = account.membership
response = (
{'@type': membership.type if membership is not None else 'free'},
HTTPStatus.OK
{"@type": membership.type if membership is not None else "free"},
HTTPStatus.OK,
)
else:
response = '', HTTPStatus.NO_CONTENT
response = "", HTTPStatus.NO_CONTENT
return response

View File

@ -25,13 +25,12 @@ from selene.api import PublicEndpoint
class OauthCallbackEndpoint(PublicEndpoint):
def __init__(self):
super(OauthCallbackEndpoint, self).__init__()
self.oauth_service_host = os.environ['OAUTH_BASE_URL']
self.oauth_service_host = os.environ["OAUTH_BASE_URL"]
def get(self):
params = dict(self.request.args)
url = self.oauth_service_host + '/auth/callback'
url = self.oauth_service_host + "/auth/callback"
response = requests.get(url, params=params)
return response.text, response.status_code

View File

@ -30,20 +30,20 @@ class PremiumVoiceEndpoint(PublicEndpoint):
def get(self, device_id):
self._authenticate(device_id)
arch = self.request.args.get('arch')
arch = self.request.args.get("arch")
account = AccountRepository(self.db).get_account_by_device_id(device_id)
if account and account.membership:
link = self._get_premium_voice_link(arch)
response = {'link': link}, HTTPStatus.OK
response = {"link": link}, HTTPStatus.OK
else:
response = '', HTTPStatus.NO_CONTENT
response = "", HTTPStatus.NO_CONTENT
return response
def _get_premium_voice_link(self, arch):
if arch == 'arm':
response = os.environ['URL_VOICE_ARM']
elif arch == 'x86_64':
response = os.environ['URL_VOICE_X86_64']
if arch == "arm":
response = os.environ["URL_VOICE_ARM"]
elif arch == "x86_64":
response = os.environ["URL_VOICE_X86_64"]
else:
response = ''
response = ""
return response

View File

@ -25,14 +25,13 @@ from selene.data.account import AccountRepository
class StripeWebHookEndpoint(PublicEndpoint):
def __init__(self):
super(StripeWebHookEndpoint, self).__init__()
def post(self):
event = json.loads(self.request.data)
type = event.get('type')
if type == 'customer.subscription.deleted':
customer = event['data']['object']['customer']
type = event.get("type")
if type == "customer.subscription.deleted":
customer = event["data"]["object"]["customer"]
AccountRepository(self.db).end_active_membership(customer)
return '', HTTPStatus.OK
return "", HTTPStatus.OK

View File

@ -34,14 +34,14 @@ class WolframAlphaSimpleEndpoint(PublicEndpoint):
def __init__(self):
super(WolframAlphaSimpleEndpoint, self).__init__()
self.wolfram_alpha_key = os.environ['WOLFRAM_ALPHA_KEY']
self.wolfram_alpha_url = os.environ['WOLFRAM_ALPHA_URL']
self.wolfram_alpha_key = os.environ["WOLFRAM_ALPHA_KEY"]
self.wolfram_alpha_url = os.environ["WOLFRAM_ALPHA_URL"]
def get(self):
self._authenticate()
params = dict(self.request.args)
params['appid'] = self.wolfram_alpha_key
response = requests.get(self.wolfram_alpha_url + '/v1/simple', params=params)
params["appid"] = self.wolfram_alpha_key
response = requests.get(self.wolfram_alpha_url + "/v1/simple", params=params)
code = response.status_code
response = (response.text, code) if code == HTTPStatus.OK else ('', code)
response = (response.text, code) if code == HTTPStatus.OK else ("", code)
return response

View File

@ -30,14 +30,14 @@ class WolframAlphaSpokenEndpoint(PublicEndpoint):
def __init__(self):
super(WolframAlphaSpokenEndpoint, self).__init__()
self.wolfram_alpha_key = os.environ['WOLFRAM_ALPHA_KEY']
self.wolfram_alpha_url = os.environ['WOLFRAM_ALPHA_URL']
self.wolfram_alpha_key = os.environ["WOLFRAM_ALPHA_KEY"]
self.wolfram_alpha_url = os.environ["WOLFRAM_ALPHA_URL"]
def get(self):
self._authenticate()
params = dict(self.request.args)
params['appid'] = self.wolfram_alpha_key
response = requests.get(self.wolfram_alpha_url + '/v1/spoken', params=params)
params["appid"] = self.wolfram_alpha_key
response = requests.get(self.wolfram_alpha_url + "/v1/spoken", params=params)
code = response.status_code
response = (response.text, code) if code == HTTPStatus.OK else ('', code)
response = (response.text, code) if code == HTTPStatus.OK else ("", code)
return response

View File

@ -26,105 +26,104 @@ from hamcrest import assert_that, equal_to, has_key, not_none, is_not
from selene.api.etag import ETagManager, device_location_etag_key
@when('a api call to get the location is done')
@when("a api call to get the location is done")
def get_device_location(context):
login = context.device_login
device_id = login['uuid']
access_token = login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
device_id = login["uuid"]
access_token = login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
context.get_location_response = context.client.get(
'/v1/device/{uuid}/location'.format(uuid=device_id),
headers=headers
"/v1/device/{uuid}/location".format(uuid=device_id), headers=headers
)
@then('the location should be retrieved')
@then("the location should be retrieved")
def validate_location(context):
response = context.get_location_response
assert_that(response.status_code, equal_to(HTTPStatus.OK))
location = json.loads(response.data)
assert_that(location, has_key('coordinate'))
assert_that(location, has_key('timezone'))
assert_that(location, has_key('city'))
assert_that(location, has_key("coordinate"))
assert_that(location, has_key("timezone"))
assert_that(location, has_key("city"))
coordinate = location['coordinate']
assert_that(coordinate, has_key('latitude'))
assert_that(coordinate, has_key('longitude'))
coordinate = location["coordinate"]
assert_that(coordinate, has_key("latitude"))
assert_that(coordinate, has_key("longitude"))
timezone = location['timezone']
assert_that(timezone, has_key('name'))
assert_that(timezone, has_key('code'))
assert_that(timezone, has_key('offset'))
assert_that(timezone, has_key('dstOffset'))
timezone = location["timezone"]
assert_that(timezone, has_key("name"))
assert_that(timezone, has_key("code"))
assert_that(timezone, has_key("offset"))
assert_that(timezone, has_key("dstOffset"))
city = location['city']
assert_that(city, has_key('name'))
assert_that(city, has_key('state'))
city = location["city"]
assert_that(city, has_key("name"))
assert_that(city, has_key("state"))
state = city['state']
assert_that(state, has_key('name'))
assert_that(state, has_key('country'))
assert_that(state, has_key('code'))
state = city["state"]
assert_that(state, has_key("name"))
assert_that(state, has_key("country"))
assert_that(state, has_key("code"))
country = state['country']
assert_that(country, has_key('name'))
assert_that(country, has_key('code'))
country = state["country"]
assert_that(country, has_key("name"))
assert_that(country, has_key("code"))
@given('an expired etag from a location entity')
@given("an expired etag from a location entity")
def expire_location_etag(context):
etag_manager: ETagManager = context.etag_manager
device_id = context.device_login['uuid']
context.expired_location_etag = etag_manager.get(device_location_etag_key(device_id))
device_id = context.device_login["uuid"]
context.expired_location_etag = etag_manager.get(
device_location_etag_key(device_id)
)
etag_manager.expire_device_location_etag_by_device_id(device_id)
@when('try to get the location using the expired etag')
@when("try to get the location using the expired etag")
def get_using_expired_etag(context):
login = context.device_login
device_id = login['uuid']
access_token = login['accessToken']
device_id = login["uuid"]
access_token = login["accessToken"]
headers = {
'Authorization': 'Bearer {token}'.format(token=access_token),
'If-None-Match': context.expired_location_etag
"Authorization": "Bearer {token}".format(token=access_token),
"If-None-Match": context.expired_location_etag,
}
context.get_location_response = context.client.get(
'/v1/device/{uuid}/location'.format(uuid=device_id),
headers=headers
"/v1/device/{uuid}/location".format(uuid=device_id), headers=headers
)
@then('an etag associated with the location should be created')
@then("an etag associated with the location should be created")
def validate_etag(context):
response = context.get_location_response
new_location_etag = response.headers.get('ETag')
new_location_etag = response.headers.get("ETag")
assert_that(new_location_etag, not_none())
assert_that(new_location_etag, is_not(context.expired_location_etag))
@given('a valid etag from a location entity')
@given("a valid etag from a location entity")
def valid_etag(context):
etag_manager = context.etag_manager
device_id = context.device_login['uuid']
device_id = context.device_login["uuid"]
context.valid_location_etag = etag_manager.get(device_location_etag_key(device_id))
@when('try to get the location using a valid etag')
@when("try to get the location using a valid etag")
def get_using_valid_etag(context):
login = context.device_login
device_id = login['uuid']
access_token = login['accessToken']
device_id = login["uuid"]
access_token = login["accessToken"]
headers = {
'Authorization': 'Bearer {token}'.format(token=access_token),
'If-None-Match': context.valid_location_etag
"Authorization": "Bearer {token}".format(token=access_token),
"If-None-Match": context.valid_location_etag,
}
context.get_location_response = context.client.get(
'/v1/device/{uuid}/location'.format(uuid=device_id),
headers=headers
"/v1/device/{uuid}/location".format(uuid=device_id), headers=headers
)
@then('the location endpoint should return 304')
@then("the location endpoint should return 304")
def validate_response_valid_etag(context):
response = context.get_location_response
assert_that(response.status_code, equal_to(HTTPStatus.NOT_MODIFIED))

View File

@ -24,42 +24,42 @@ from behave import when, then
from hamcrest import assert_that, equal_to, has_key, is_not
@when('the session token is refreshed')
@when("the session token is refreshed")
def refresh_token(context):
login = json.loads(context.activate_device_response.data)
refresh = login['refreshToken']
refresh = login["refreshToken"]
context.refresh_token_response = context.client.get(
'/v1/auth/token',
headers={'Authorization': 'Bearer {token}'.format(token=refresh)}
"/v1/auth/token",
headers={"Authorization": "Bearer {token}".format(token=refresh)},
)
@then('a valid new session entity should be returned')
@then("a valid new session entity should be returned")
def validate_refresh_token(context):
response = context.refresh_token_response
assert_that(response.status_code, equal_to(HTTPStatus.OK))
new_login = json.loads(response.data)
assert_that(new_login, has_key(equal_to('uuid')))
assert_that(new_login, has_key(equal_to('accessToken')))
assert_that(new_login, has_key(equal_to('refreshToken')))
assert_that(new_login, has_key(equal_to('expiration')))
assert_that(new_login, has_key(equal_to("uuid")))
assert_that(new_login, has_key(equal_to("accessToken")))
assert_that(new_login, has_key(equal_to("refreshToken")))
assert_that(new_login, has_key(equal_to("expiration")))
old_login = json.loads(context.activate_device_response.data)
assert_that(new_login['uuid']), equal_to(old_login['uuid'])
assert_that(new_login['accessToken'], is_not(equal_to(old_login['accessToken'])))
assert_that(new_login['refreshToken'], is_not(equal_to(old_login['refreshToken'])))
assert_that(new_login["uuid"]), equal_to(old_login["uuid"])
assert_that(new_login["accessToken"], is_not(equal_to(old_login["accessToken"])))
assert_that(new_login["refreshToken"], is_not(equal_to(old_login["refreshToken"])))
@when('try to refresh an invalid refresh token')
@when("try to refresh an invalid refresh token")
def refresh_invalid_token(context):
context.refresh_invalid_token_response = context.client.get(
'/v1/auth/token',
headers={'Authorization': 'Bearer {token}'.format(token='123')}
"/v1/auth/token",
headers={"Authorization": "Bearer {token}".format(token="123")},
)
@then('401 status code should be returned')
@then("401 status code should be returned")
def validate_refresh_invalid_token(context):
response = context.refresh_invalid_token_response
assert_that(response.status_code, equal_to(HTTPStatus.UNAUTHORIZED))

View File

@ -29,193 +29,177 @@ from selene.testing.skill import add_skill, build_label_field, build_text_field
from selene.util.cache import DEVICE_SKILL_ETAG_KEY
@given('skill settings with a new value')
@given("skill settings with a new value")
def change_skill_setting_value(context):
_, bar_settings_display = context.skills['bar']
section = bar_settings_display.display_data['skillMetadata']['sections'][0]
field_with_value = section['fields'][1]
field_with_value['value'] = 'New device text value'
_, bar_settings_display = context.skills["bar"]
section = bar_settings_display.display_data["skillMetadata"]["sections"][0]
field_with_value = section["fields"][1]
field_with_value["value"] = "New device text value"
@given('skill settings with a deleted field')
@given("skill settings with a deleted field")
def delete_field_from_settings(context):
_, bar_settings_display = context.skills['bar']
section = bar_settings_display.display_data['skillMetadata']['sections'][0]
context.removed_field = section['fields'].pop(1)
context.remaining_field = section['fields'][1]
_, bar_settings_display = context.skills["bar"]
section = bar_settings_display.display_data["skillMetadata"]["sections"][0]
context.removed_field = section["fields"].pop(1)
context.remaining_field = section["fields"][1]
@given('a valid device skill E-tag')
@given("a valid device skill E-tag")
def set_skill_setting_etag(context):
context.device_skill_etag = context.etag_manager.get(
DEVICE_SKILL_ETAG_KEY.format(device_id=context.device_id)
)
@given('an expired device skill E-tag')
@given("an expired device skill E-tag")
def expire_skill_setting_etag(context):
valid_device_skill_etag = context.etag_manager.get(
DEVICE_SKILL_ETAG_KEY.format(device_id=context.device_id)
)
context.device_skill_etag = context.etag_manager.expire(
valid_device_skill_etag
)
context.device_skill_etag = context.etag_manager.expire(valid_device_skill_etag)
@given('settings for a skill not assigned to the device')
@given("settings for a skill not assigned to the device")
def add_skill_not_assigned_to_device(context):
foobar_skill, foobar_settings_display = add_skill(
context.db,
skill_global_id='foobar-skill|19.02',
settings_fields=[build_label_field(), build_text_field()]
skill_global_id="foobar-skill|19.02",
settings_fields=[build_label_field(), build_text_field()],
)
section = foobar_settings_display.display_data['skillMetadata']['sections'][0]
field_with_value = section['fields'][1]
field_with_value['value'] = 'New skill text value'
section = foobar_settings_display.display_data["skillMetadata"]["sections"][0]
field_with_value = section["fields"][1]
field_with_value["value"] = "New skill text value"
context.skills.update(foobar=(foobar_skill, foobar_settings_display))
@when('a device requests the settings for its skills')
@when("a device requests the settings for its skills")
def get_device_skill_settings(context):
if hasattr(context, 'device_skill_etag'):
context.request_header[ETAG_REQUEST_HEADER_KEY] = (
context.device_skill_etag
)
if hasattr(context, "device_skill_etag"):
context.request_header[ETAG_REQUEST_HEADER_KEY] = context.device_skill_etag
context.response = context.client.get(
'/v1/device/{device_id}/skill'.format(device_id=context.device_id),
content_type='application/json',
headers=context.request_header
"/v1/device/{device_id}/skill".format(device_id=context.device_id),
content_type="application/json",
headers=context.request_header,
)
@when('the device sends a request to update the {skill} skill settings')
@when("the device sends a request to update the {skill} skill settings")
def update_skill_settings(context, skill):
_, settings_display = context.skills[skill]
context.response = context.client.put(
'/v1/device/{device_id}/skill'.format(device_id=context.device_id),
"/v1/device/{device_id}/skill".format(device_id=context.device_id),
data=json.dumps(settings_display.display_data),
content_type='application/json',
headers=context.request_header
content_type="application/json",
headers=context.request_header,
)
@when('the device requests a skill to be deleted')
@when("the device requests a skill to be deleted")
def delete_skill(context):
foo_skill, _ = context.skills['foo']
foo_skill, _ = context.skills["foo"]
context.response = context.client.delete(
'/v1/device/{device_id}/skill/{skill_gid}'.format(
device_id=context.device_id,
skill_gid=foo_skill.skill_gid
"/v1/device/{device_id}/skill/{skill_gid}".format(
device_id=context.device_id, skill_gid=foo_skill.skill_gid
),
headers=context.request_header
headers=context.request_header,
)
@then('the settings are returned')
@then("the settings are returned")
def validate_response(context):
response = context.response.json
assert_that(len(response), equal_to(2))
foo_skill, foo_settings_display = context.skills['foo']
foo_skill, foo_settings_display = context.skills["foo"]
foo_skill_expected_result = dict(
uuid=foo_skill.id,
skill_gid=foo_skill.skill_gid,
identifier=foo_settings_display.display_data['identifier']
identifier=foo_settings_display.display_data["identifier"],
)
assert_that(foo_skill_expected_result, is_in(response))
bar_skill, bar_settings_display = context.skills['bar']
section = bar_settings_display.display_data['skillMetadata']['sections'][0]
text_field = section['fields'][1]
text_field['value'] = 'Device text value'
checkbox_field = section['fields'][2]
checkbox_field['value'] = 'false'
bar_skill, bar_settings_display = context.skills["bar"]
section = bar_settings_display.display_data["skillMetadata"]["sections"][0]
text_field = section["fields"][1]
text_field["value"] = "Device text value"
checkbox_field = section["fields"][2]
checkbox_field["value"] = "false"
bar_skill_expected_result = dict(
uuid=bar_skill.id,
skill_gid=bar_skill.skill_gid,
identifier=bar_settings_display.display_data['identifier'],
skillMetadata=bar_settings_display.display_data['skillMetadata']
identifier=bar_settings_display.display_data["identifier"],
skillMetadata=bar_settings_display.display_data["skillMetadata"],
)
assert_that(bar_skill_expected_result, is_in(response))
@then('the device skill E-tag is expired')
@then("the device skill E-tag is expired")
def check_for_expired_etag(context):
"""An E-tag is expired by changing its value."""
expired_device_skill_etag = context.etag_manager.get(
DEVICE_SKILL_ETAG_KEY.format(device_id=context.device_id)
)
assert_that(
expired_device_skill_etag.decode(),
is_not(equal_to(context.device_skill_etag))
expired_device_skill_etag.decode(), is_not(equal_to(context.device_skill_etag))
)
def _get_device_skill_settings(context):
"""Minimize DB hits and code duplication by getting these values once."""
if not hasattr(context, 'device_skill_settings'):
if not hasattr(context, "device_skill_settings"):
settings_repo = SkillSettingRepository(context.db)
context.device_skill_settings = (
settings_repo.get_skill_settings_for_device(context.device_id)
context.device_skill_settings = settings_repo.get_skill_settings_for_device(
context.device_id
)
context.device_settings_values = [
dss.settings_values for dss in context.device_skill_settings
]
@then('the skill settings are updated with the new value')
@then("the skill settings are updated with the new value")
def validate_updated_skill_setting_value(context):
_get_device_skill_settings(context)
assert_that(len(context.device_skill_settings), equal_to(2))
expected_settings_values = dict(
textfield='New device text value',
checkboxfield='false'
)
assert_that(
expected_settings_values,
is_in(context.device_settings_values)
textfield="New device text value", checkboxfield="false"
)
assert_that(expected_settings_values, is_in(context.device_settings_values))
@then('the skill is assigned to the device with the settings populated')
@then("the skill is assigned to the device with the settings populated")
def validate_updated_skill_setting_value(context):
_get_device_skill_settings(context)
assert_that(len(context.device_skill_settings), equal_to(3))
expected_settings_values = dict(textfield='New skill text value')
assert_that(
expected_settings_values,
is_in(context.device_settings_values)
)
expected_settings_values = dict(textfield="New skill text value")
assert_that(expected_settings_values, is_in(context.device_settings_values))
@then('an E-tag is generated for these settings')
@then("an E-tag is generated for these settings")
def get_skills_etag(context):
response_headers = context.response.headers
response_etag = response_headers['ETag']
response_etag = response_headers["ETag"]
skill_etag = context.etag_manager.get(
DEVICE_SKILL_ETAG_KEY.format(device_id=context.device_id)
)
assert_that(skill_etag.decode(), equal_to(response_etag))
@then('the field is no longer in the skill settings')
@then("the field is no longer in the skill settings")
def validate_skill_setting_field_removed(context):
_get_device_skill_settings(context)
assert_that(len(context.device_skill_settings), equal_to(2))
# The removed field should no longer be in the settings values but the
# value of the field that was not deleted should remain
assert_that(
dict(checkboxfield='false'),
is_in(context.device_settings_values)
)
assert_that(dict(checkboxfield="false"), is_in(context.device_settings_values))
new_section = dict(fields=None)
for device_skill_setting in context.device_skill_settings:
skill_gid = device_skill_setting.settings_display['skill_gid']
if skill_gid.startswith('bar'):
skill_gid = device_skill_setting.settings_display["skill_gid"]
if skill_gid.startswith("bar"):
new_settings_display = device_skill_setting.settings_display
new_skill_definition = new_settings_display['skillMetadata']
new_section = new_skill_definition['sections'][0]
new_skill_definition = new_settings_display["skillMetadata"]
new_section = new_skill_definition["sections"][0]
# The removed field should no longer be in the settings values but the
# value of the field that was not deleted should remain
assert_that(context.removed_field, not is_in(new_section['fields']))
assert_that(context.remaining_field, is_in(new_section['fields']))
assert_that(context.removed_field, not is_in(new_section["fields"]))
assert_that(context.remaining_field, is_in(new_section["fields"]))

View File

@ -27,77 +27,75 @@ from hamcrest import assert_that, equal_to, has_key, not_none, is_not
from selene.api.etag import ETagManager, device_etag_key
new_fields = dict(
platform='mycroft_mark_1',
coreVersion='19.2.0',
enclosureVersion='1.4.0'
platform="mycroft_mark_1", coreVersion="19.2.0", enclosureVersion="1.4.0"
)
@when('device is retrieved')
@when("device is retrieved")
def get_device(context):
access_token = context.device_login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
device_id = context.device_login['uuid']
access_token = context.device_login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
device_id = context.device_login["uuid"]
context.get_device_response = context.client.get(
'/v1/device/{uuid}'.format(uuid=device_id),
headers=headers
"/v1/device/{uuid}".format(uuid=device_id), headers=headers
)
context.device_etag = context.get_device_response.headers.get('ETag')
context.device_etag = context.get_device_response.headers.get("ETag")
@then('a valid device should be returned')
@then("a valid device should be returned")
def validate_response(context):
response = context.get_device_response
assert_that(response.status_code, equal_to(HTTPStatus.OK))
device = json.loads(response.data)
assert_that(device, has_key('uuid'))
assert_that(device, has_key('name'))
assert_that(device, has_key('description'))
assert_that(device, has_key('coreVersion'))
assert_that(device, has_key('enclosureVersion'))
assert_that(device, has_key('platform'))
assert_that(device, has_key('user'))
assert_that(device['user'], has_key('uuid'))
assert_that(device['user']['uuid'], equal_to(context.account.id))
assert_that(device, has_key("uuid"))
assert_that(device, has_key("name"))
assert_that(device, has_key("description"))
assert_that(device, has_key("coreVersion"))
assert_that(device, has_key("enclosureVersion"))
assert_that(device, has_key("platform"))
assert_that(device, has_key("user"))
assert_that(device["user"], has_key("uuid"))
assert_that(device["user"]["uuid"], equal_to(context.account.id))
@when('try to fetch a device without the authorization header')
@when("try to fetch a device without the authorization header")
def get_invalid_device(context):
context.get_invalid_device_response = context.client.get('/v1/device/{uuid}'.format(uuid=str(uuid.uuid4())))
@when('try to fetch a not allowed device')
def get_not_allowed_device(context):
access_token = context.device_login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
context.get_invalid_device_response = context.client.get(
'/v1/device/{uuid}'.format(uuid=str(uuid.uuid4())),
headers=headers
"/v1/device/{uuid}".format(uuid=str(uuid.uuid4()))
)
@then('a 401 status code should be returned')
@when("try to fetch a not allowed device")
def get_not_allowed_device(context):
access_token = context.device_login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
context.get_invalid_device_response = context.client.get(
"/v1/device/{uuid}".format(uuid=str(uuid.uuid4())), headers=headers
)
@then("a 401 status code should be returned")
def validate_invalid_response(context):
response = context.get_invalid_device_response
assert_that(response.status_code, equal_to(HTTPStatus.UNAUTHORIZED))
@when('the device is updated')
@when("the device is updated")
def update_device(context):
login = context.device_login
access_token = login['accessToken']
device_id = login['uuid']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
access_token = login["accessToken"]
device_id = login["uuid"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
context.update_device_response = context.client.patch(
'/v1/device/{uuid}'.format(uuid=device_id),
"/v1/device/{uuid}".format(uuid=device_id),
data=json.dumps(new_fields),
content_type='application_json',
headers=headers
content_type="application_json",
headers=headers,
)
@then('the information should be updated')
@then("the information should be updated")
def validate_update(context):
response = context.update_device_response
assert_that(response.status_code, equal_to(HTTPStatus.OK))
@ -105,74 +103,72 @@ def validate_update(context):
response = context.get_device_response
assert_that(response.status_code, equal_to(HTTPStatus.OK))
device = json.loads(response.data)
assert_that(device, has_key('name'))
assert_that(device['coreVersion'], equal_to(new_fields['coreVersion']))
assert_that(device['enclosureVersion'], equal_to(new_fields['enclosureVersion']))
assert_that(device['platform'], equal_to(new_fields['platform']))
assert_that(device, has_key("name"))
assert_that(device["coreVersion"], equal_to(new_fields["coreVersion"]))
assert_that(device["enclosureVersion"], equal_to(new_fields["enclosureVersion"]))
assert_that(device["platform"], equal_to(new_fields["platform"]))
@given('a device with a valid etag')
@given("a device with a valid etag")
def get_device_etag(context):
etag_manager: ETagManager = context.etag_manager
device_id = context.device_login['uuid']
device_id = context.device_login["uuid"]
context.device_etag = etag_manager.get(device_etag_key(device_id))
@when('try to fetch a device using a valid etag')
@when("try to fetch a device using a valid etag")
def get_device_using_etag(context):
etag = context.device_etag
assert_that(etag, not_none())
access_token = context.device_login['accessToken']
device_uuid = context.device_login['uuid']
access_token = context.device_login["accessToken"]
device_uuid = context.device_login["uuid"]
headers = {
'Authorization': 'Bearer {token}'.format(token=access_token),
'If-None-Match': etag
"Authorization": "Bearer {token}".format(token=access_token),
"If-None-Match": etag,
}
context.response_using_etag = context.client.get(
'/v1/device/{uuid}'.format(uuid=device_uuid),
headers=headers
"/v1/device/{uuid}".format(uuid=device_uuid), headers=headers
)
@then('304 status code should be returned by the device endpoint')
@then("304 status code should be returned by the device endpoint")
def validate_etag(context):
response = context.response_using_etag
assert_that(response.status_code, equal_to(HTTPStatus.NOT_MODIFIED))
@given('a device\'s etag expired by the web ui')
@given("a device's etag expired by the web ui")
def expire_etag(context):
etag_manager: ETagManager = context.etag_manager
device_id = context.device_login['uuid']
device_id = context.device_login["uuid"]
context.device_etag = etag_manager.get(device_etag_key(device_id))
etag_manager.expire_device_etag_by_device_id(device_id)
@when('try to fetch a device using an expired etag')
@when("try to fetch a device using an expired etag")
def fetch_device_expired_etag(context):
etag = context.device_etag
assert_that(etag, not_none())
access_token = context.device_login['accessToken']
device_uuid = context.device_login['uuid']
access_token = context.device_login["accessToken"]
device_uuid = context.device_login["uuid"]
headers = {
'Authorization': 'Bearer {token}'.format(token=access_token),
'If-None-Match': etag
"Authorization": "Bearer {token}".format(token=access_token),
"If-None-Match": etag,
}
context.response_using_invalid_etag = context.client.get(
'/v1/device/{uuid}'.format(uuid=device_uuid),
headers=headers
"/v1/device/{uuid}".format(uuid=device_uuid), headers=headers
)
@then('should return status 200')
@then("should return status 200")
def validate_status_code(context):
response = context.response_using_invalid_etag
assert_that(response.status_code, equal_to(HTTPStatus.OK))
@then('a new etag')
@then("a new etag")
def validate_new_etag(context):
etag = context.device_etag
response = context.response_using_invalid_etag
etag_from_response = response.headers.get('ETag')
etag_from_response = response.headers.get("ETag")
assert_that(etag, is_not(etag_from_response))

View File

@ -27,115 +27,113 @@ from hamcrest import assert_that, equal_to, has_key, is_not
from selene.api.etag import ETagManager, device_setting_etag_key
@when('try to fetch device\'s setting')
@when("try to fetch device's setting")
def get_device_settings(context):
login = context.device_login
device_id = login['uuid']
access_token = login['accessToken']
headers=dict(Authorization='Bearer {token}'.format(token=access_token))
device_id = login["uuid"]
access_token = login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
context.response_setting = context.client.get(
'/v1/device/{uuid}/setting'.format(uuid=device_id),
headers=headers
"/v1/device/{uuid}/setting".format(uuid=device_id), headers=headers
)
@then('a valid setting should be returned')
@then("a valid setting should be returned")
def validate_response_setting(context):
response = context.response_setting
assert_that(response.status_code, equal_to(HTTPStatus.OK))
setting = json.loads(response.data)
assert_that(response.status_code, equal_to(HTTPStatus.OK))
assert_that(setting, has_key('uuid'))
assert_that(setting, has_key('systemUnit'))
assert_that(setting['systemUnit'], equal_to('imperial'))
assert_that(setting, has_key('timeFormat'))
assert_that(setting, has_key('dateFormat'))
assert_that(setting, has_key('optIn'))
assert_that(setting['optIn'], equal_to(True))
assert_that(setting, has_key('ttsSettings'))
tts = setting['ttsSettings']
assert_that(tts, has_key('module'))
assert_that(setting, has_key("uuid"))
assert_that(setting, has_key("systemUnit"))
assert_that(setting["systemUnit"], equal_to("imperial"))
assert_that(setting, has_key("timeFormat"))
assert_that(setting, has_key("dateFormat"))
assert_that(setting, has_key("optIn"))
assert_that(setting["optIn"], equal_to(True))
assert_that(setting, has_key("ttsSettings"))
tts = setting["ttsSettings"]
assert_that(tts, has_key("module"))
@when('the settings endpoint is a called to a not allowed device')
@when("the settings endpoint is a called to a not allowed device")
def get_device_settings(context):
access_token = context.device_login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
access_token = context.device_login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
context.get_invalid_setting_response = context.client.get(
'/v1/device/{uuid}/setting'.format(uuid=str(uuid.uuid4())),
headers=headers
"/v1/device/{uuid}/setting".format(uuid=str(uuid.uuid4())), headers=headers
)
@then('a 401 status code should be returned for the setting')
@then("a 401 status code should be returned for the setting")
def validate_response(context):
response = context.get_invalid_setting_response
assert_that(response.status_code, equal_to(HTTPStatus.UNAUTHORIZED))
@given('a device\'s setting with a valid etag')
@given("a device's setting with a valid etag")
def get_device_setting_etag(context):
device_id = context.device_login['uuid']
device_id = context.device_login["uuid"]
etag_manager: ETagManager = context.etag_manager
context.device_etag = etag_manager.get(device_setting_etag_key(device_id))
@when('try to fetch the device\'s settings using a valid etag')
@when("try to fetch the device's settings using a valid etag")
def get_device_settings_using_etag(context):
etag = context.device_etag
access_token = context.device_login['accessToken']
device_id = context.device_login['uuid']
access_token = context.device_login["accessToken"]
device_id = context.device_login["uuid"]
headers = {
'Authorization': 'Bearer {token}'.format(token=access_token),
'If-None-Match': etag
"Authorization": "Bearer {token}".format(token=access_token),
"If-None-Match": etag,
}
context.get_setting_etag_response = context.client.get(
'/v1/device/{uuid}/setting'.format(uuid=str(device_id)),
headers=headers
"/v1/device/{uuid}/setting".format(uuid=str(device_id)), headers=headers
)
@then('304 status code should be returned by the device\'s settings endpoint')
@then("304 status code should be returned by the device's settings endpoint")
def validate_etag_response(context):
response = context.get_setting_etag_response
assert_that(response.status_code, equal_to(HTTPStatus.NOT_MODIFIED))
@given('a device\'s setting etag expired by the web ui at device level')
@given("a device's setting etag expired by the web ui at device level")
def expire_etag_device_level(context):
device_id = context.device_login['uuid']
device_id = context.device_login["uuid"]
etag_manager: ETagManager = context.etag_manager
context.device_etag = etag_manager.get(device_setting_etag_key(device_id))
etag_manager.expire_device_setting_etag_by_device_id(device_id)
@given('a device\'s setting etag expired by the web ui at account level')
@given("a device's setting etag expired by the web ui at account level")
def expire_etag_account_level(context):
account_id = context.account.id
device_id = context.device_login['uuid']
device_id = context.device_login["uuid"]
etag_manager: ETagManager = context.etag_manager
context.device_etag = etag_manager.get(device_setting_etag_key(device_id))
etag_manager.expire_device_setting_etag_by_account_id(account_id)
@when('try to fetch the device\'s settings using an expired etag')
@when("try to fetch the device's settings using an expired etag")
def get_device_settings_using_etag(context):
etag = context.device_etag
access_token = context.device_login['accessToken']
device_id = context.device_login['uuid']
access_token = context.device_login["accessToken"]
device_id = context.device_login["uuid"]
headers = {
'Authorization': 'Bearer {token}'.format(token=access_token),
'If-None-Match': etag
"Authorization": "Bearer {token}".format(token=access_token),
"If-None-Match": etag,
}
context.get_setting_invalid_etag_response = context.client.get(
'/v1/device/{uuid}/setting'.format(uuid=str(device_id)),
headers=headers
"/v1/device/{uuid}/setting".format(uuid=str(device_id)), headers=headers
)
@then('200 status code should be returned by the device\'s setting endpoint and a new etag')
@then(
"200 status code should be returned by the device's setting endpoint and a new etag"
)
def validate_new_etag(context):
etag = context.device_etag
response = context.get_setting_invalid_etag_response
etag_from_response = response.headers.get('ETag')
etag_from_response = response.headers.get("ETag")
assert_that(etag, is_not(etag_from_response))

View File

@ -29,66 +29,63 @@ from selene.data.account import AccountRepository, AccountMembership
from selene.util.db import connect_to_db
@when('the subscription endpoint is called')
@when("the subscription endpoint is called")
def get_device_subscription(context):
login = context.device_login
device_id = login['uuid']
access_token = login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
device_id = login["uuid"]
access_token = login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
context.subscription_response = context.client.get(
'/v1/device/{uuid}/subscription'.format(uuid=device_id),
headers=headers
"/v1/device/{uuid}/subscription".format(uuid=device_id), headers=headers
)
@then('free type should be returned')
@then("free type should be returned")
def validate_response(context):
response = context.subscription_response
assert_that(response.status_code, HTTPStatus.OK)
subscription = json.loads(response.data)
assert_that(subscription, has_entry('@type', 'free'))
assert_that(subscription, has_entry("@type", "free"))
@when('the subscription endpoint is called for a monthly account')
@when("the subscription endpoint is called for a monthly account")
def get_device_subscription(context):
membership = AccountMembership(
start_date=date.today(),
type='Monthly Membership',
payment_method='Stripe',
payment_account_id='test_monthly',
payment_id='stripe_id'
type="Monthly Membership",
payment_method="Stripe",
payment_account_id="test_monthly",
payment_id="stripe_id",
)
login = context.device_login
device_id = login['uuid']
access_token = login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
device_id = login["uuid"]
access_token = login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
db = connect_to_db(context.client_config["DB_CONNECTION_CONFIG"])
AccountRepository(db).add_membership(context.account.id, membership)
context.subscription_response = context.client.get(
'/v1/device/{uuid}/subscription'.format(uuid=device_id),
headers=headers
"/v1/device/{uuid}/subscription".format(uuid=device_id), headers=headers
)
@then('monthly type should be returned')
@then("monthly type should be returned")
def validate_response_monthly(context):
response = context.subscription_response
assert_that(response.status_code, HTTPStatus.OK)
subscription = json.loads(response.data)
assert_that(subscription, has_entry('@type', 'Monthly Membership'))
assert_that(subscription, has_entry("@type", "Monthly Membership"))
@when('try to get the subscription for a nonexistent device')
@when("try to get the subscription for a nonexistent device")
def get_subscription_nonexistent_device(context):
access_token = context.device_login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
access_token = context.device_login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
context.invalid_subscription_response = context.client.get(
'/v1/device/{uuid}/subscription'.format(uuid=str(uuid.uuid4())),
headers=headers
"/v1/device/{uuid}/subscription".format(uuid=str(uuid.uuid4())), headers=headers
)
@then('401 status code should be returned for the subscription endpoint')
@then("401 status code should be returned for the subscription endpoint")
def validate_nonexistent_device(context):
response = context.invalid_subscription_response
assert_that(response.status_code, equal_to(HTTPStatus.UNAUTHORIZED))

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -34,6 +34,7 @@ from selene.util.auth import AuthenticationError
class AuthenticateInternalEndpoint(SeleneEndpoint):
"""Sign in a user with an email address and password."""
def __init__(self):
super(AuthenticateInternalEndpoint, self).__init__()
self.account: Account = None
@ -44,22 +45,21 @@ class AuthenticateInternalEndpoint(SeleneEndpoint):
self._generate_tokens()
self._set_token_cookies()
return '', HTTPStatus.NO_CONTENT
return "", HTTPStatus.NO_CONTENT
def _authenticate_credentials(self):
"""Compare credentials in request to credentials in database.
:raises AuthenticationError when no match found on database
"""
basic_credentials = self.request.headers['authorization']
basic_credentials = self.request.headers["authorization"]
binary_credentials = a2b_base64(basic_credentials[6:])
email_address, password = binary_credentials.decode().split(':||:')
email_address, password = binary_credentials.decode().split(":||:")
acct_repository = AccountRepository(self.db)
self.account = acct_repository.get_account_from_credentials(
email_address,
password
email_address, password
)
if self.account is None:
raise AuthenticationError('provided credentials not found')
raise AuthenticationError("provided credentials not found")
self.access_token.account_id = self.account.id
self.refresh_token.account_id = self.account.id

View File

@ -26,8 +26,7 @@ from selene.util.auth import get_github_authentication_token
class GithubTokenEndpoint(SeleneEndpoint):
def get(self):
token = get_github_authentication_token(
self.request.args['code'],
self.request.args['state']
self.request.args["code"], self.request.args["state"]
)
return dict(token=token), HTTPStatus.OK

View File

@ -26,11 +26,11 @@ from selene.data.account import AccountRepository
class PasswordChangeEndpoint(SeleneEndpoint):
def put(self):
account_id = self.request.json['accountId']
coded_password = self.request.json['password']
account_id = self.request.json["accountId"]
coded_password = self.request.json["password"]
binary_password = a2b_base64(coded_password)
password = binary_password.decode()
acct_repository = AccountRepository(self.db)
acct_repository.change_password(account_id, password)
return '', HTTPStatus.NO_CONTENT
return "", HTTPStatus.NO_CONTENT

View File

@ -37,44 +37,40 @@ class PasswordResetEndpoint(SeleneEndpoint):
reset_token = self._generate_reset_token()
self._send_reset_email(reset_token)
return '', HTTPStatus.OK
return "", HTTPStatus.OK
def _get_account_from_email(self):
acct_repository = AccountRepository(self.db)
self.account = acct_repository.get_account_by_email(
self.request.json['emailAddress']
self.request.json["emailAddress"]
)
def _generate_reset_token(self):
reset_token = AuthenticationToken(
self.config['RESET_SECRET'],
ONE_HOUR
)
reset_token = AuthenticationToken(self.config["RESET_SECRET"], ONE_HOUR)
reset_token.generate(self.account.id)
return reset_token.jwt
def _send_reset_email(self, reset_token):
url = '{base_url}/change-password?token={reset_token}'.format(
base_url=os.environ['SSO_BASE_URL'],
reset_token=reset_token
url = "{base_url}/change-password?token={reset_token}".format(
base_url=os.environ["SSO_BASE_URL"], reset_token=reset_token
)
email = EmailMessage(
recipient=self.request.json['emailAddress'],
sender='Mycroft AI<no-reply@mycroft.ai>',
subject='Password Reset Request',
template_file_name='reset_password.html',
template_variables=dict(reset_password_url=url)
recipient=self.request.json["emailAddress"],
sender="Mycroft AI<no-reply@mycroft.ai>",
subject="Password Reset Request",
template_file_name="reset_password.html",
template_variables=dict(reset_password_url=url),
)
mailer = SeleneMailer(email)
mailer.send()
def _send_account_not_found_email(self):
email = EmailMessage(
recipient=self.request.json['emailAddress'],
sender='Mycroft AI<no-reply@mycroft.ai>',
subject='Password Reset Request',
template_file_name='account_not_found.html'
recipient=self.request.json["emailAddress"],
sender="Mycroft AI<no-reply@mycroft.ai>",
subject="Password Reset Request",
template_file_name="account_not_found.html",
)
mailer = SeleneMailer(email)
mailer.send()

View File

@ -25,16 +25,16 @@ from selene.data.account import AccountRepository
from selene.util.auth import (
get_facebook_account_email,
get_github_account_email,
get_google_account_email
get_google_account_email,
)
class ValidateEmailEndpoint(SeleneEndpoint):
def get(self):
return_data = dict(accountExists=False, noFederatedEmail=False)
if self.request.args['token']:
if self.request.args["token"]:
email_address = self._get_email_address()
if self.request.args['platform'] != 'Internal' and not email_address:
if self.request.args["platform"] != "Internal" and not email_address:
return_data.update(noFederatedEmail=True)
account_repository = AccountRepository(self.db)
account = account_repository.get_account_by_email(email_address)
@ -44,20 +44,14 @@ class ValidateEmailEndpoint(SeleneEndpoint):
return return_data, HTTPStatus.OK
def _get_email_address(self):
if self.request.args['platform'] == 'Google':
email_address = get_google_account_email(
self.request.args['token']
)
elif self.request.args['platform'] == 'Facebook':
email_address = get_facebook_account_email(
self.request.args['token']
)
elif self.request.args['platform'] == 'GitHub':
email_address = get_github_account_email(
self.request.args['token']
)
if self.request.args["platform"] == "Google":
email_address = get_google_account_email(self.request.args["token"])
elif self.request.args["platform"] == "Facebook":
email_address = get_facebook_account_email(self.request.args["token"])
elif self.request.args["platform"] == "GitHub":
email_address = get_github_account_email(self.request.args["token"])
else:
coded_email = self.request.args['token']
coded_email = self.request.args["token"]
email_address = a2b_base64(coded_email).decode()
return email_address

View File

@ -28,15 +28,12 @@ class ValidateTokenEndpoint(SeleneEndpoint):
return response_data, HTTPStatus.OK
def _validate_token(self):
auth_token = AuthenticationToken(
self.config['RESET_SECRET'],
duration=0
)
auth_token.jwt = self.request.json['token']
auth_token = AuthenticationToken(self.config["RESET_SECRET"], duration=0)
auth_token.jwt = self.request.json["token"]
auth_token.validate()
return dict(
account_id=auth_token.account_id,
token_expired=auth_token.is_expired,
token_invalid=not auth_token.is_valid
token_invalid=not auth_token.is_valid,
)

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -30,13 +30,14 @@ from datetime import date, timedelta
import schedule
from selene.util.log import configure_logger
from selene.util.log import configure_selene_logger
_log = configure_logger('selene_job_scheduler')
_log = configure_selene_logger("job_scheduler")
class JobRunner(object):
"""Build the command to run a batch job and run it via subprocess."""
def __init__(self, script_name: str):
self.script_name = script_name
self.job_args: str = None
@ -55,17 +56,14 @@ class JobRunner(object):
date argument only needs to be specified when it is not current date.
"""
if self.job_args is None:
self.job_args = ''
date_arg = ' --date ' + str(self.job_date)
self.job_args = ""
date_arg = " --date " + str(self.job_date)
self.job_args += date_arg
def _build_command(self):
"""Build the command to run the script."""
command = ['pipenv', 'run', 'python']
script_path = os.path.join(
os.environ['SELENE_SCRIPT_DIR'],
self.script_name
)
command = ["pipenv", "run", "python"]
script_path = os.path.join(os.environ["SELENE_SCRIPT_DIR"], self.script_name)
command.append(script_path)
if self.job_args is not None:
command.extend(self.job_args.split())
@ -78,31 +76,31 @@ class JobRunner(object):
result = subprocess.run(command, capture_output=True)
if result.returncode:
_log.error(
'Job {job_name} failed\n'
'\tSTDOUT - {stdout}'
'\tSTDERR - {stderr}'.format(
"Job {job_name} failed\n"
"\tSTDOUT - {stdout}"
"\tSTDERR - {stderr}".format(
job_name=self.script_name[:-3],
stdout=result.stdout.decode(),
stderr=result.stderr.decode()
stderr=result.stderr.decode(),
)
)
else:
log_msg = 'Job {job_name} completed successfully'
log_msg = "Job {job_name} completed successfully"
_log.info(log_msg.format(job_name=self.script_name[:-3]))
def test_scheduler():
"""Run in non-production environments to test scheduler functionality."""
job_runner = JobRunner('test_scheduler.py')
job_runner = JobRunner("test_scheduler.py")
job_runner.job_date = date.today() - timedelta(days=1)
job_runner.job_args = '--arg-with-value test --arg-no-value'
job_runner.job_args = "--arg-with-value test --arg-no-value"
job_runner.run_job()
def load_skills(version):
"""Load the json file from the mycroft-skills-data repository to the DB"""
job_runner = JobRunner('load_skill_display_data.py')
job_runner.job_args = '--core-version {}'.format(version)
job_runner = JobRunner("load_skill_display_data.py")
job_runner.job_args = "--core-version {}".format(version)
job_runner.job_date = date.today() - timedelta(days=1)
job_runner.run_job()
@ -112,7 +110,7 @@ def parse_core_metrics():
Build a de-normalized table that will make latency research easier.
"""
job_runner = JobRunner('parse_core_metrics.py')
job_runner = JobRunner("parse_core_metrics.py")
job_runner.job_date = date.today() - timedelta(days=1)
job_runner.run_job()
@ -123,7 +121,7 @@ def partition_api_metrics():
Build a partition on the metric.api_history table for yesterday's date.
Copy yesterday's metric.api table rows to the partition.
"""
job_runner = JobRunner('partition_api_metrics.py')
job_runner = JobRunner("partition_api_metrics.py")
job_runner.job_date = date.today() - timedelta(days=1)
job_runner.run_job()
@ -135,22 +133,22 @@ def update_device_last_contact():
to associate the time of the call with the device. Dump the contents of
the Redis data to the device.device table on the Postgres database.
"""
job_runner = JobRunner('update_device_last_contact.py')
job_runner = JobRunner("update_device_last_contact.py")
job_runner.run_job()
# Define the schedule
if os.environ['SELENE_ENVIRONMENT'] != 'prod':
if os.environ["SELENE_ENVIRONMENT"] != "prod":
schedule.every(5).minutes.do(test_scheduler)
schedule.every().day.at('00:00').do(partition_api_metrics)
schedule.every().day.at('00:05').do(update_device_last_contact)
schedule.every().day.at('00:10').do(parse_core_metrics)
schedule.every().day.at('00:15').do(load_skills, version='19.02')
schedule.every().day.at('00:20').do(load_skills, version='19.08')
schedule.every().day.at('00:25').do(load_skills, version='20.02')
schedule.every().day.at('00:25').do(load_skills, version='20.08')
schedule.every().day.at('00:25').do(load_skills, version='21.02')
schedule.every().day.at("00:00").do(partition_api_metrics)
schedule.every().day.at("00:05").do(update_device_last_contact)
schedule.every().day.at("00:10").do(parse_core_metrics)
schedule.every().day.at("00:15").do(load_skills, version="19.02")
schedule.every().day.at("00:20").do(load_skills, version="19.08")
schedule.every().day.at("00:25").do(load_skills, version="20.02")
schedule.every().day.at("00:25").do(load_skills, version="20.08")
schedule.every().day.at("00:25").do(load_skills, version="21.02")
# Run the schedule
while True:

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -30,12 +30,12 @@ from selene.util.db import DatabaseConnectionConfig
from selene.util.email import EmailMessage, SeleneMailer
mycroft_db = DatabaseConnectionConfig(
host=environ['DB_HOST'],
db_name=environ['DB_NAME'],
user=environ['DB_USER'],
password=environ['DB_PASSWORD'],
port=environ['DB_PORT'],
sslmode=environ['DB_SSL_MODE']
host=environ["DB_HOST"],
db_name=environ["DB_NAME"],
user=environ["DB_USER"],
password=environ["DB_PASSWORD"],
port=int(environ["DB_PORT"]),
sslmode=environ["DB_SSL_MODE"],
)
@ -43,16 +43,16 @@ class DailyReport(SeleneScript):
def __init__(self):
super(DailyReport, self).__init__(__file__)
self._arg_parser.add_argument(
'--run-mode',
help='If the script should run as a job or just once',
choices=['job', 'once'],
"--run-mode",
help="If the script should run as a job or just once",
choices=["job", "once"],
type=str,
default='job'
default="job",
)
def _run(self):
if self.args.run_mode == 'job':
schedule.every().day.at('00:00').do(self._build_report)
if self.args.run_mode == "job":
schedule.every().day.at("00:00").do(self._build_report)
while True:
schedule.run_pending()
time.sleep(1)
@ -65,11 +65,11 @@ class DailyReport(SeleneScript):
user_metrics = AccountRepository(self.db).daily_report(date)
email = EmailMessage(
sender='reports@mycroft.ai',
recipient=os.environ['REPORT_RECIPIENT'],
subject='Mycroft Daily Report - {}'.format(date.strftime('%Y-%m-%d')),
template_file_name='metrics.html',
template_variables=dict(user_metrics=user_metrics)
sender="reports@mycroft.ai",
recipient=os.environ["REPORT_RECIPIENT"],
subject="Mycroft Daily Report - {}".format(date.strftime("%Y-%m-%d")),
template_file_name="metrics.html",
template_variables=dict(user_metrics=user_metrics),
)
mailer = SeleneMailer(email)

View File

@ -30,17 +30,13 @@ import json
from os import environ
from selene.batch import SeleneScript
from selene.data.skill import (
SkillDisplay,
SkillDisplayRepository,
SkillRepository
)
from selene.data.skill import SkillDisplay, SkillDisplayRepository, SkillRepository
from selene.util.github import download_repository_file, log_into_github
GITHUB_USER = environ['GITHUB_USER']
GITHUB_PASSWORD = environ['GITHUB_PASSWORD']
SKILL_DATA_GITHUB_REPO = 'mycroft-skills-data'
SKILL_DATA_FILE_NAME = 'skill-metadata.json'
GITHUB_USER = environ["GITHUB_USER"]
GITHUB_PASSWORD = environ["GITHUB_PASSWORD"]
SKILL_DATA_GITHUB_REPO = "mycroft-skills-data"
SKILL_DATA_FILE_NAME = "skill-metadata.json"
class SkillDisplayUpdater(SeleneScript):
@ -52,16 +48,15 @@ class SkillDisplayUpdater(SeleneScript):
super(SkillDisplayUpdater, self)._define_args()
self._arg_parser.add_argument(
"--core-version",
help='Version of Mycroft Core related to skill display data',
help="Version of Mycroft Core related to skill display data",
required=True,
type=str
type=str,
)
def _run(self):
"""Make it so."""
self.log.info(
"Updating skill display data for core version " +
self.args.core_version
"Updating skill display data for core version " + self.args.core_version
)
self._get_skill_display_data()
self._update_skill_display_table()
@ -73,7 +68,7 @@ class SkillDisplayUpdater(SeleneScript):
github_api,
SKILL_DATA_GITHUB_REPO,
self.args.core_version,
SKILL_DATA_FILE_NAME
SKILL_DATA_FILE_NAME,
)
self.skill_display_data = json.loads(file_contents)
@ -83,20 +78,18 @@ class SkillDisplayUpdater(SeleneScript):
display_repository = SkillDisplayRepository(self.db)
for skill_name, skill_metadata in self.skill_display_data.items():
skill_count += 1
skill_id = skill_repository.ensure_skill_exists(
skill_metadata['skill_gid']
)
skill_id = skill_repository.ensure_skill_exists(skill_metadata["skill_gid"])
# add the skill display row
display_data = SkillDisplay(
skill_id=skill_id,
core_version=self.args.core_version,
display_data=json.dumps(skill_metadata)
display_data=json.dumps(skill_metadata),
)
display_repository.upsert(display_data)
self.log.info("updated {} skills".format(skill_count))
if __name__ == '__main__':
if __name__ == "__main__":
SkillDisplayUpdater().run()

View File

@ -32,7 +32,7 @@ from decimal import Decimal
from selene.batch import SeleneScript
from selene.data.metric import CoreMetricRepository, CoreInteraction
SKILL_HANDLERS_TO_SKIP = ('reset', 'notify', 'prime', 'stop_laugh')
SKILL_HANDLERS_TO_SKIP = ("reset", "notify", "prime", "stop_laugh")
class CoreMetricsParser(SeleneScript):
@ -47,7 +47,7 @@ class CoreMetricsParser(SeleneScript):
def _run(self):
last_interaction_id = None
for metric in self.core_metric_repo.get_metrics_by_date(self.args.date):
if metric.metric_value['id'] != last_interaction_id:
if metric.metric_value["id"] != last_interaction_id:
self._add_interaction_to_db()
self._start_new_interaction(metric)
last_interaction_id = self.interaction.core_id
@ -57,48 +57,44 @@ class CoreMetricsParser(SeleneScript):
def _start_new_interaction(self, metric):
"""Initialize the interaction object"""
self.interaction = CoreInteraction(
core_id=metric.metric_value['id'],
core_id=metric.metric_value["id"],
device_id=metric.device_id,
start_ts=datetime.utcfromtimestamp(
metric.metric_value['start_time']
),
start_ts=datetime.utcfromtimestamp(metric.metric_value["start_time"]),
)
self.stt_start_ts = None
self.playback_start_ts = None
def _add_metric_to_interaction(self, metric_value):
"""Combine all the steps of an interaction into a single record"""
duration = Decimal(str(metric_value['time']))
duration = duration.quantize(Decimal('0.000001'))
if metric_value['system'] == 'stt':
self.interaction.stt_engine = metric_value['stt']
self.interaction.stt_transcription = metric_value['transcription']
duration = Decimal(str(metric_value["time"]))
duration = duration.quantize(Decimal("0.000001"))
if metric_value["system"] == "stt":
self.interaction.stt_engine = metric_value["stt"]
self.interaction.stt_transcription = metric_value["transcription"]
self.interaction.stt_duration = duration
self.stt_start_ts = metric_value['start_time']
elif metric_value['system'] == 'intent_service':
self.interaction.intent_type = metric_value['intent_type']
self.stt_start_ts = metric_value["start_time"]
elif metric_value["system"] == "intent_service":
self.interaction.intent_type = metric_value["intent_type"]
self.interaction.intent_duration = duration
elif metric_value['system'] == 'fallback_handler':
elif metric_value["system"] == "fallback_handler":
self.interaction.fallback_handler_duration = duration
elif metric_value['system'] == 'skill_handler':
if metric_value['handler'] not in SKILL_HANDLERS_TO_SKIP:
self.interaction.skill_handler = metric_value['handler']
elif metric_value["system"] == "skill_handler":
if metric_value["handler"] not in SKILL_HANDLERS_TO_SKIP:
self.interaction.skill_handler = metric_value["handler"]
self.interaction.skill_duration = duration
elif metric_value['system'] == 'speech':
self.interaction.tts_engine = metric_value['tts']
self.interaction.tts_utterance = metric_value['utterance']
elif metric_value["system"] == "speech":
self.interaction.tts_engine = metric_value["tts"]
self.interaction.tts_utterance = metric_value["utterance"]
self.interaction.tts_duration = duration
elif metric_value['system'] == 'speech_playback':
elif metric_value["system"] == "speech_playback":
self.interaction.speech_playback_duration = duration
self.playback_start_ts = metric_value['start_time']
self.playback_start_ts = metric_value["start_time"]
# The user-experienced latency is the time between when the user
# finishes speaking their intent and when the device provides a voice
# response.
if self.stt_start_ts is not None and self.playback_start_ts is not None:
self.interaction.user_latency = (
self.playback_start_ts - self.stt_start_ts
)
self.interaction.user_latency = self.playback_start_ts - self.stt_start_ts
def _add_interaction_to_db(self):
if self.interaction is not None:
@ -107,5 +103,5 @@ class CoreMetricsParser(SeleneScript):
self.core_metric_repo.add_interaction(self.interaction)
if __name__ == '__main__':
if __name__ == "__main__":
CoreMetricsParser().run()

View File

@ -40,5 +40,5 @@ class PartitionApiMetrics(SeleneScript):
api_metrics_repo.remove_by_date(self.args.date)
if __name__ == '__main__':
if __name__ == "__main__":
PartitionApiMetrics().run()

View File

@ -42,24 +42,24 @@ class TestScheduler(SeleneScript):
super(TestScheduler, self)._define_args()
self._arg_parser.add_argument(
"--arg-with-value",
help='Argument to test passing a value with an argument',
help="Argument to test passing a value with an argument",
required=True,
type=str
type=str,
)
self._arg_parser.add_argument(
"--arg-no-value",
help='Argument to test passing a value with an argument',
action="store_true"
help="Argument to test passing a value with an argument",
action="store_true",
)
def _run(self):
self.log.info('Running the scheduler test job')
self.log.info("Running the scheduler test job")
assert self.args.arg_no_value
assert self.args.arg_with_value == 'test'
assert self.args.arg_with_value == "test"
# Tests the logic that overrides the default date in the scheduler.
assert self.args.date == date.today() - timedelta(days=1)
if __name__ == '__main__':
if __name__ == "__main__":
TestScheduler().run()

View File

@ -47,21 +47,18 @@ class UpdateDeviceLastContact(SeleneScript):
devices_updated += 1
device_repo.update_last_contact_ts(device.id, last_contact_ts)
self.log.info(str(devices_updated) + ' devices were active today')
self.log.info(str(devices_updated) + " devices were active today")
def _get_ts_from_cache(self, device_id):
last_contact_ts = None
cache_key = DEVICE_LAST_CONTACT_KEY.format(device_id=device_id)
value = self.cache.get(cache_key)
if value is not None:
last_contact_ts = datetime.strptime(
value.decode(),
'%Y-%m-%d %H:%M:%S.%f'
)
last_contact_ts = datetime.strptime(value.decode(), "%Y-%m-%d %H:%M:%S.%f")
self.cache.delete(cache_key)
return last_contact_ts
if __name__ == '__main__':
if __name__ == "__main__":
UpdateDeviceLastContact().run()

View File

@ -108,7 +108,7 @@ class PostgresDB(object):
sslmode=db_ssl_mode,
)
self.db.autocommit = True
self.db.set_client_encoding('UTF8')
self.db.set_client_encoding("UTF8")
def close_db(self):
self.db.close()

File diff suppressed because it is too large Load Diff

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -48,32 +48,33 @@ class APIConfigError(Exception):
class BaseConfig(object):
"""Base configuration."""
ACCESS_SECRET = os.environ['JWT_ACCESS_SECRET']
ACCESS_SECRET = os.environ["JWT_ACCESS_SECRET"]
DB_CONNECTION_POOL = None
DEBUG = False
ENV = os.environ['SELENE_ENVIRONMENT']
REFRESH_SECRET = os.environ['JWT_REFRESH_SECRET']
ENV = os.environ["SELENE_ENVIRONMENT"]
REFRESH_SECRET = os.environ["JWT_REFRESH_SECRET"]
DB_CONNECTION_CONFIG = DatabaseConnectionConfig(
host=os.environ['DB_HOST'],
db_name=os.environ['DB_NAME'],
password=os.environ['DB_PASSWORD'],
port=os.environ.get('DB_PORT', 5432),
user=os.environ['DB_USER'],
sslmode=os.environ.get('DB_SSLMODE')
host=os.environ["DB_HOST"],
db_name=os.environ["DB_NAME"],
password=os.environ["DB_PASSWORD"],
port=os.environ.get("DB_PORT", 5432),
user=os.environ["DB_USER"],
sslmode=os.environ.get("DB_SSLMODE"),
)
class DevelopmentConfig(BaseConfig):
DEBUG = True
DOMAIN = '.mycroft.test'
DOMAIN = ".mycroft.test"
class TestConfig(BaseConfig):
DOMAIN = '.mycroft-test.net'
DOMAIN = ".mycroft-test.net"
class ProdConfig(BaseConfig):
DOMAIN = '.mycroft.ai'
DOMAIN = ".mycroft.ai"
def get_base_config():
@ -81,16 +82,12 @@ def get_base_config():
:return: an object containing the configs for the API.
"""
environment_configs = dict(
dev=DevelopmentConfig,
test=TestConfig,
prod=ProdConfig
)
environment_configs = dict(dev=DevelopmentConfig, test=TestConfig, prod=ProdConfig)
try:
environment_name = os.environ['SELENE_ENVIRONMENT']
environment_name = os.environ["SELENE_ENVIRONMENT"]
except KeyError:
raise APIConfigError('the SELENE_ENVIRONMENT variable is not set')
raise APIConfigError("the SELENE_ENVIRONMENT variable is not set")
try:
app_config = environment_configs[environment_name]

View File

@ -31,7 +31,7 @@ from selene.util.cache import DEVICE_LAST_CONTACT_KEY
from selene.util.db import connect_to_db
from selene.util.exceptions import NotModifiedException
selene_api = Blueprint('selene_api', __name__)
selene_api = Blueprint("selene_api", __name__)
@selene_api.app_errorhandler(DataError)
@ -46,7 +46,7 @@ def handle_data_error(error):
@selene_api.app_errorhandler(NotModifiedException)
def handle_not_modified(_):
return '', HTTPStatus.NOT_MODIFIED
return "", HTTPStatus.NOT_MODIFIED
@selene_api.before_app_request
@ -67,26 +67,26 @@ def add_api_metric(http_status):
api = None
# We are not logging metric for the public API until after the socket
# implementation to avoid putting millions of rows a day on the table
for api_name in ('account', 'sso', 'market', 'public'):
for api_name in ("account", "sso", "market", "public"):
if api_name in current_app.name:
api = api_name
if api is not None and int(http_status) != 304:
if 'db' not in global_context:
if "db" not in global_context:
global_context.db = connect_to_db(
current_app.config['DB_CONNECTION_CONFIG']
current_app.config["DB_CONNECTION_CONFIG"]
)
if 'account_id' in global_context:
if "account_id" in global_context:
account_id = global_context.account_id
else:
account_id = None
if 'device_id' in global_context:
if "device_id" in global_context:
device_id = global_context.device_id
else:
device_id = None
duration = (datetime.utcnow() - global_context.start_ts)
duration = datetime.utcnow() - global_context.start_ts
api_metric = ApiMetric(
access_ts=datetime.utcnow(),
account_id=account_id,
@ -95,7 +95,7 @@ def add_api_metric(http_status):
duration=Decimal(str(duration.total_seconds())),
http_method=request.method,
http_status=int(http_status),
url=global_context.url
url=global_context.url,
)
metric_repository = ApiMetricsRepository(global_context.db)
metric_repository.add(api_metric)
@ -107,6 +107,6 @@ def update_device_last_contact():
This should only be done on public API calls because we are tracking
device activity only.
"""
if 'public' in current_app.name and 'device_id' in global_context:
if "public" in current_app.name and "device_id" in global_context:
key = DEVICE_LAST_CONTACT_KEY.format(device_id=global_context.device_id)
global_context.cache.set(key, str(datetime.utcnow()))

View File

@ -29,20 +29,20 @@ from ..base_endpoint import SeleneEndpoint
class AgreementsEndpoint(SeleneEndpoint):
authentication_required = False
agreement_types = {
'terms-of-use': 'Terms of Use',
'privacy-policy': 'Privacy Policy'
"terms-of-use": "Terms of Use",
"privacy-policy": "Privacy Policy",
}
def get(self, agreement_type):
"""Process HTTP GET request for an agreement."""
db = connect_to_db(self.config['DB_CONNECTION_CONFIG'])
db = connect_to_db(self.config["DB_CONNECTION_CONFIG"])
agreement_repository = AgreementRepository(db)
agreement = agreement_repository.get_active_for_type(
self.agreement_types[agreement_type]
)
if agreement is not None:
agreement = asdict(agreement)
del(agreement['effective_date'])
del agreement["effective_date"]
self.response = agreement, HTTPStatus.OK
return self.response

View File

@ -24,19 +24,19 @@ from selene.data.device import DeviceRepository
from selene.util.cache import SeleneCache, DEVICE_SKILL_ETAG_KEY
from selene.util.db import connect_to_db
ETAG_REQUEST_HEADER_KEY = 'If-None-Match'
ETAG_REQUEST_HEADER_KEY = "If-None-Match"
def device_etag_key(device_id: str):
return 'device.etag:{uuid}'.format(uuid=device_id)
return "device.etag:{uuid}".format(uuid=device_id)
def device_setting_etag_key(device_id: str):
return 'device.setting.etag:{uuid}'.format(uuid=device_id)
return "device.setting.etag:{uuid}".format(uuid=device_id)
def device_location_etag_key(device_id: str):
return 'device.location.etag:{uuid}'.format(uuid=device_id)
return "device.location.etag:{uuid}".format(uuid=device_id)
class ETagManager(object):
@ -46,7 +46,7 @@ class ETagManager(object):
def __init__(self, cache: SeleneCache, config: dict):
self.cache: SeleneCache = cache
self.db_connection_config = config['DB_CONNECTION_CONFIG']
self.db_connection_config = config["DB_CONNECTION_CONFIG"]
def get(self, key: str) -> str:
"""Generate a etag with 32 random chars and store it into a given key
@ -54,14 +54,14 @@ class ETagManager(object):
:return etag"""
etag = self.cache.get(key)
if etag is None:
etag = ''.join(random.choice(self.etag_chars) for _ in range(32))
etag = "".join(random.choice(self.etag_chars) for _ in range(32))
self.cache.set(key, etag)
return etag
def expire(self, key):
"""Expires an existent etag
:param key: key where the etag is stored"""
etag = ''.join(random.choice(self.etag_chars) for _ in range(32))
etag = "".join(random.choice(self.etag_chars) for _ in range(32))
self.cache.set(key, etag)
def expire_device_etag_by_device_id(self, device_id: str):

View File

@ -22,7 +22,7 @@ import re
from flask import jsonify, Response
snake_pattern = re.compile(r'_([a-z])')
snake_pattern = re.compile(r"_([a-z])")
def snake_to_camel(name):

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -18,12 +18,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from .entity.account import Account, AccountAgreement, AccountMembership
from .entity.agreement import (
Agreement,
PRIVACY_POLICY,
TERMS_OF_USE,
OPEN_DATASET
)
from .entity.agreement import Agreement, PRIVACY_POLICY, TERMS_OF_USE, OPEN_DATASET
from .entity.membership import Membership
from .entity.skill import AccountSkill
from .repository.account import AccountRepository
@ -31,6 +26,6 @@ from .repository.agreement import AgreementRepository
from .repository.membership import (
MembershipRepository,
MONTHLY_MEMBERSHIP,
YEARLY_MEMBERSHIP
YEARLY_MEMBERSHIP,
)
from .repository.skill import AccountSkillRepository

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -20,9 +20,9 @@
from dataclasses import dataclass
from datetime import date
TERMS_OF_USE = 'Terms of Use'
PRIVACY_POLICY = 'Privacy Policy'
OPEN_DATASET = 'Open Dataset'
TERMS_OF_USE = "Terms of Use"
PRIVACY_POLICY = "Privacy Policy"
OPEN_DATASET = "Open Dataset"
@dataclass

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -21,8 +21,8 @@ from selene.data.account import AccountMembership
from ..entity.membership import Membership
from ...repository_base import RepositoryBase
MONTHLY_MEMBERSHIP = 'Monthly Membership'
YEARLY_MEMBERSHIP = 'Yearly Membership'
MONTHLY_MEMBERSHIP = "Monthly Membership"
YEARLY_MEMBERSHIP = "Yearly Membership"
class MembershipRepository(RepositoryBase):
@ -30,37 +30,34 @@ class MembershipRepository(RepositoryBase):
super(MembershipRepository, self).__init__(db, __file__)
def get_membership_types(self):
db_request = self._build_db_request(
sql_file_name='get_membership_types.sql'
)
db_request = self._build_db_request(sql_file_name="get_membership_types.sql")
db_result = self.cursor.select_all(db_request)
return [Membership(**row) for row in db_result]
def get_membership_by_type(self, membership_type: str):
db_request = self._build_db_request(
sql_file_name='get_membership_by_type.sql',
args=dict(type=membership_type)
sql_file_name="get_membership_by_type.sql", args=dict(type=membership_type)
)
db_result = self.cursor.select_one(db_request)
return Membership(**db_result)
def add(self, membership: Membership):
db_request = self._build_db_request(
'add_membership.sql',
"add_membership.sql",
args=dict(
membership_type=membership.type,
rate=membership.rate,
rate_period=membership.rate_period
)
rate_period=membership.rate_period,
),
)
result = self.cursor.insert_returning(db_request)
return result['id']
return result["id"]
def remove(self, membership: Membership):
db_request = self._build_db_request(
sql_file_name='delete_membership.sql',
args=dict(membership_id=membership.id)
sql_file_name="delete_membership.sql",
args=dict(membership_id=membership.id),
)
self.cursor.delete(db_request)

View File

@ -30,8 +30,8 @@ class AccountSkillRepository(RepositoryBase):
def get_skills_for_account(self) -> List[AccountSkill]:
db_request = self._build_db_request(
sql_file_name='get_account_skills.sql',
args=dict(account_id=self.account_id)
sql_file_name="get_account_skills.sql",
args=dict(account_id=self.account_id),
)
db_result = self.cursor.select_all(db_request)

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -26,7 +26,7 @@ from selene.data.skill import SettingsDisplay
from ..entity.device_skill import (
AccountSkillSettings,
DeviceSkillSettings,
ManifestSkill
ManifestSkill,
)
from ...repository_base import RepositoryBase
@ -40,15 +40,15 @@ class DeviceSkillRepository(RepositoryBase):
) -> List[AccountSkillSettings]:
return self._select_all_into_dataclass(
AccountSkillSettings,
sql_file_name='get_skill_settings_for_account.sql',
args=dict(account_id=account_id, skill_id=skill_id)
sql_file_name="get_skill_settings_for_account.sql",
args=dict(account_id=account_id, skill_id=skill_id),
)
def get_skill_settings_for_device(self, device_id, skill_id=None):
device_skills = self._select_all_into_dataclass(
DeviceSkillSettings,
sql_file_name='get_skill_settings_for_device.sql',
args=dict(device_id=device_id)
sql_file_name="get_skill_settings_for_device.sql",
args=dict(device_id=device_id),
)
if skill_id is None:
skill_settings = device_skills
@ -65,12 +65,10 @@ class DeviceSkillRepository(RepositoryBase):
self, account_id: str, device_names: tuple, skill_name: str
):
db_request = self._build_db_request(
sql_file_name='update_skill_settings.sql',
sql_file_name="update_skill_settings.sql",
args=dict(
account_id=account_id,
device_names=device_names,
skill_name=skill_name
)
account_id=account_id, device_names=device_names, skill_name=skill_name
),
)
self.cursor.update(db_request)
@ -86,13 +84,13 @@ class DeviceSkillRepository(RepositoryBase):
else:
db_settings_values = json.dumps(settings_values)
db_request = self._build_db_request(
sql_file_name='upsert_device_skill_settings.sql',
sql_file_name="upsert_device_skill_settings.sql",
args=dict(
device_id=device_id,
skill_id=settings_display.skill_id,
settings_values=db_settings_values,
settings_display_id=settings_display.id
)
settings_display_id=settings_display.id,
),
)
self.cursor.insert(db_request)
@ -103,76 +101,66 @@ class DeviceSkillRepository(RepositoryBase):
else:
db_settings_values = json.dumps(device_skill.settings_values)
db_request = self._build_db_request(
sql_file_name='update_device_skill_settings.sql',
sql_file_name="update_device_skill_settings.sql",
args=dict(
device_id=device_id,
skill_id=device_skill.skill_id,
settings_display_id=device_skill.settings_display_id,
settings_values=db_settings_values
)
settings_values=db_settings_values,
),
)
self.cursor.update(db_request)
def get_skill_manifest_for_device(
self, device_id: str
) -> List[ManifestSkill]:
def get_skill_manifest_for_device(self, device_id: str) -> List[ManifestSkill]:
return self._select_all_into_dataclass(
dataclass=ManifestSkill,
sql_file_name='get_device_skill_manifest.sql',
args=dict(device_id=device_id)
sql_file_name="get_device_skill_manifest.sql",
args=dict(device_id=device_id),
)
def get_skill_manifest_for_account(
self, account_id: str
) -> List[ManifestSkill]:
def get_skill_manifest_for_account(self, account_id: str) -> List[ManifestSkill]:
return self._select_all_into_dataclass(
dataclass=ManifestSkill,
sql_file_name='get_skill_manifest_for_account.sql',
args=dict(account_id=account_id)
sql_file_name="get_skill_manifest_for_account.sql",
args=dict(account_id=account_id),
)
def update_manifest_skill(self, manifest_skill: ManifestSkill):
db_request = self._build_db_request(
sql_file_name='update_skill_manifest.sql',
args=asdict(manifest_skill)
sql_file_name="update_skill_manifest.sql", args=asdict(manifest_skill)
)
self.cursor.update(db_request)
def add_manifest_skill(self, manifest_skill: ManifestSkill):
db_request = self._build_db_request(
sql_file_name='add_manifest_skill.sql',
args=asdict(manifest_skill)
sql_file_name="add_manifest_skill.sql", args=asdict(manifest_skill)
)
db_result = self.cursor.insert_returning(db_request)
return db_result['id']
return db_result["id"]
def remove_manifest_skill(self, manifest_skill: ManifestSkill):
db_request = self._build_db_request(
sql_file_name='remove_manifest_skill.sql',
sql_file_name="remove_manifest_skill.sql",
args=dict(
device_id=manifest_skill.device_id,
skill_gid=manifest_skill.skill_gid
)
device_id=manifest_skill.device_id, skill_gid=manifest_skill.skill_gid
),
)
self.cursor.delete(db_request)
def get_settings_display_usage(self, settings_display_id: str) -> int:
db_request = self._build_db_request(
sql_file_name='get_settings_display_usage.sql',
args=dict(settings_display_id=settings_display_id)
sql_file_name="get_settings_display_usage.sql",
args=dict(settings_display_id=settings_display_id),
)
db_result = self.cursor.select_one(db_request)
return db_result['usage']
return db_result["usage"]
def remove(self, device_id, skill_id):
db_request = self._build_db_request(
sql_file_name='delete_device_skill.sql',
args=dict(
device_id=device_id,
skill_id=skill_id
)
sql_file_name="delete_device_skill.sql",
args=dict(device_id=device_id, skill_id=skill_id),
)
self.cursor.delete(db_request)

View File

@ -28,8 +28,8 @@ class GeographyRepository(RepositoryBase):
def get_account_geographies(self):
db_request = self._build_db_request(
sql_file_name='get_account_geographies.sql',
args=dict(account_id=self.account_id)
sql_file_name="get_account_geographies.sql",
args=dict(account_id=self.account_id),
)
db_response = self.cursor.select_all(db_request)
@ -40,10 +40,10 @@ class GeographyRepository(RepositoryBase):
acct_geographies = self.get_account_geographies()
for acct_geography in acct_geographies:
match = (
acct_geography.city == geography.city and
acct_geography.country == geography.country and
acct_geography.region == geography.region and
acct_geography.time_zone == geography.time_zone
acct_geography.city == geography.city
and acct_geography.country == geography.country
and acct_geography.region == geography.region
and acct_geography.time_zone == geography.time_zone
)
if match:
geography_id = acct_geography.id
@ -53,22 +53,22 @@ class GeographyRepository(RepositoryBase):
def add(self, geography: Geography):
db_request = self._build_db_request(
sql_file_name='add_geography.sql',
sql_file_name="add_geography.sql",
args=dict(
account_id=self.account_id,
city=geography.city,
country=geography.country,
region=geography.region,
timezone=geography.time_zone
)
timezone=geography.time_zone,
),
)
db_result = self.cursor.insert_returning(db_request)
return db_result['id']
return db_result["id"]
def get_location_by_device_id(self, device_id):
db_request = self._build_db_request(
sql_file_name='get_location_by_device_id.sql',
args=dict(device_id=device_id)
sql_file_name="get_location_by_device_id.sql",
args=dict(device_id=device_id),
)
return self.cursor.select_one(db_request)

View File

@ -30,8 +30,8 @@ class PreferenceRepository(RepositoryBase):
def get_account_preferences(self) -> AccountPreferences:
db_request = self._build_db_request(
sql_file_name='get_account_preferences.sql',
args=dict(account_id=self.account_id)
sql_file_name="get_account_preferences.sql",
args=dict(account_id=self.account_id),
)
db_result = self.cursor.select_one(db_request)
@ -46,7 +46,6 @@ class PreferenceRepository(RepositoryBase):
db_request_args = dict(account_id=self.account_id)
db_request_args.update(asdict(preferences))
db_request = self._build_db_request(
sql_file_name='upsert_preferences.sql',
args=db_request_args
sql_file_name="upsert_preferences.sql", args=db_request_args
)
self.cursor.insert(db_request)

View File

@ -21,7 +21,7 @@ from os import path
from selene.util.db import get_sql_from_file, Cursor, DatabaseRequest
SQL_DIR = path.join(path.dirname(__file__), 'sql')
SQL_DIR = path.join(path.dirname(__file__), "sql")
class SettingRepository(object):
@ -30,36 +30,38 @@ class SettingRepository(object):
def get_device_settings_by_device_id(self, device_id):
query = DatabaseRequest(
sql=get_sql_from_file(path.join(SQL_DIR, 'get_device_settings_by_device_id.sql')),
args=dict(device_id=device_id)
sql=get_sql_from_file(
path.join(SQL_DIR, "get_device_settings_by_device_id.sql")
),
args=dict(device_id=device_id),
)
return self.cursor.select_one(query)
def convert_text_to_speech_setting(self, setting_name, engine) -> (str, str):
"""Convert the selene representation of TTS into the tartarus representation, for backward compatibility
with the API v1"""
if engine == 'mimic':
if setting_name == 'trinity':
return 'mimic', 'trinity'
elif setting_name == 'kusal':
return 'mimic2', 'kusal'
if engine == "mimic":
if setting_name == "trinity":
return "mimic", "trinity"
elif setting_name == "kusal":
return "mimic2", "kusal"
else:
return 'mimic', 'ap'
return "mimic", "ap"
else:
return 'google', ''
return "google", ""
def _format_date_v1(self, date: str):
if date == 'DD/MM/YYYY':
result = 'DMY'
if date == "DD/MM/YYYY":
result = "DMY"
else:
result = 'MDY'
result = "MDY"
return result
def _format_time_v1(self, time: str):
if time == '24 Hour':
result = 'full'
if time == "24 Hour":
result = "full"
else:
result = 'half'
result = "half"
return result
def get_device_settings(self, device_id):
@ -69,22 +71,29 @@ class SettingRepository(object):
:return setting entity using the legacy format from the API v1"""
response = self.get_device_settings_by_device_id(device_id)
if response:
if response['listener_setting']['uuid'] is None:
del response['listener_setting']
tts_setting = response['tts_settings']
tts_setting = self.convert_text_to_speech_setting(tts_setting['setting_name'], tts_setting['engine'])
tts_setting = {'module': tts_setting[0], tts_setting[0]: {'voice': tts_setting[1]}}
response['tts_settings'] = tts_setting
response['date_format'] = self._format_date_v1(response['date_format'])
response['time_format'] = self._format_time_v1(response['time_format'])
response['system_unit'] = response['system_unit'].lower()
if response["listener_setting"]["uuid"] is None:
del response["listener_setting"]
tts_setting = response["tts_settings"]
tts_setting = self.convert_text_to_speech_setting(
tts_setting["setting_name"], tts_setting["engine"]
)
tts_setting = {
"module": tts_setting[0],
tts_setting[0]: {"voice": tts_setting[1]},
}
response["tts_settings"] = tts_setting
response["date_format"] = self._format_date_v1(response["date_format"])
response["time_format"] = self._format_time_v1(response["time_format"])
response["system_unit"] = response["system_unit"].lower()
open_dataset = self._get_open_dataset_agreement_by_device_id(device_id)
response['optIn'] = open_dataset is not None
response["optIn"] = open_dataset is not None
return response
def _get_open_dataset_agreement_by_device_id(self, device_id: str):
query = DatabaseRequest(
sql=get_sql_from_file(path.join(SQL_DIR, 'get_open_dataset_agreement_by_device_id.sql')),
args=dict(device_id=device_id)
sql=get_sql_from_file(
path.join(SQL_DIR, "get_open_dataset_agreement_by_device_id.sql")
),
args=dict(device_id=device_id),
)
return self.cursor.select_one(query)

View File

@ -26,7 +26,7 @@ class TextToSpeechRepository(RepositoryBase):
super(TextToSpeechRepository, self).__init__(db, __file__)
def get_voices(self):
db_request = self._build_db_request(sql_file_name='get_voices.sql')
db_request = self._build_db_request(sql_file_name="get_voices.sql")
db_result = self.cursor.select_all(db_request)
return [TextToSpeech(**row) for row in db_result]
@ -37,16 +37,16 @@ class TextToSpeechRepository(RepositoryBase):
:return wake word id
"""
db_request = self._build_db_request(
sql_file_name='add_text_to_speech.sql',
sql_file_name="add_text_to_speech.sql",
args=dict(
wake_word=text_to_speech.setting_name,
account_id=text_to_speech.display_name,
engine=text_to_speech.engine
)
engine=text_to_speech.engine,
),
)
result = self.cursor.insert_returning(db_request)
return result['id']
return result["id"]
# def remove(self, wake_word: WakeWord):
# """Delete a wake word from the wake_word table."""

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -27,8 +27,7 @@ class CityRepository(RepositoryBase):
def get_cities_by_region(self, region_id):
db_request = self._build_db_request(
sql_file_name='get_cities_by_region.sql',
args=dict(region_id=region_id)
sql_file_name="get_cities_by_region.sql", args=dict(region_id=region_id)
)
db_result = self.cursor.select_all(db_request)
@ -39,23 +38,22 @@ class CityRepository(RepositoryBase):
city_names = [nm.lower() for nm in possible_city_names]
return self._select_all_into_dataclass(
GeographicLocation,
sql_file_name='get_geographic_location_by_city.sql',
args=dict(possible_city_names=tuple(city_names))
sql_file_name="get_geographic_location_by_city.sql",
args=dict(possible_city_names=tuple(city_names)),
)
def get_biggest_city_in_region(self, region_name):
"""Return the geolocation of the most populous city in a region."""
return self._select_one_into_dataclass(
GeographicLocation,
sql_file_name='get_biggest_city_in_region.sql',
args=dict(region=region_name.lower())
sql_file_name="get_biggest_city_in_region.sql",
args=dict(region=region_name.lower()),
)
def get_biggest_city_in_country(self, country_name):
"""Return the geolocation of the most populous city in a country."""
return self._select_one_into_dataclass(
GeographicLocation,
sql_file_name='get_biggest_city_in_country.sql',
args=dict(country=country_name.lower())
sql_file_name="get_biggest_city_in_country.sql",
args=dict(country=country_name.lower()),
)

View File

@ -26,7 +26,7 @@ class CountryRepository(RepositoryBase):
super(CountryRepository, self).__init__(db, __file__)
def get_countries(self):
db_request = self._build_db_request(sql_file_name='get_countries.sql')
db_request = self._build_db_request(sql_file_name="get_countries.sql")
db_result = self.cursor.select_all(db_request)
return [Country(**row) for row in db_result]

View File

@ -27,8 +27,7 @@ class RegionRepository(RepositoryBase):
def get_regions_by_country(self, country_id):
db_request = self._build_db_request(
sql_file_name='get_regions_by_country.sql',
args=dict(country_id=country_id)
sql_file_name="get_regions_by_country.sql", args=dict(country_id=country_id)
)
db_result = self.cursor.select_all(db_request)

View File

@ -27,8 +27,8 @@ class TimezoneRepository(RepositoryBase):
def get_timezones_by_country(self, country_id):
db_request = self._build_db_request(
sql_file_name='get_timezones_by_country.sql',
args=dict(country_id=country_id)
sql_file_name="get_timezones_by_country.sql",
args=dict(country_id=country_id),
)
db_result = self.cursor.select_all(db_request)

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -34,7 +34,7 @@ from datetime import date, datetime, time
from ..entity.api import ApiMetric
from ...repository_base import RepositoryBase
DUMP_FILE_DIR = '/opt/selene/dump'
DUMP_FILE_DIR = "/opt/selene/dump"
class ApiMetricsRepository(RepositoryBase):
@ -43,8 +43,7 @@ class ApiMetricsRepository(RepositoryBase):
def add(self, metric: ApiMetric):
db_request = self._build_db_request(
sql_file_name='add_api_metric.sql',
args=asdict(metric)
sql_file_name="add_api_metric.sql", args=asdict(metric)
)
self.cursor.insert(db_request)
@ -53,27 +52,27 @@ class ApiMetricsRepository(RepositoryBase):
start_ts = datetime.combine(partition_date, time.min)
end_ts = datetime.combine(partition_date, time.max)
db_request = self._build_db_request(
sql_file_name='create_api_metric_partition.sql',
sql_file_name="create_api_metric_partition.sql",
args=dict(start_ts=str(start_ts), end_ts=str(end_ts)),
sql_vars=dict(partition=partition_date.strftime('%Y%m%d'))
sql_vars=dict(partition=partition_date.strftime("%Y%m%d")),
)
self.cursor.execute(db_request)
db_request = self._build_db_request(
sql_file_name='create_api_metric_partition_index.sql',
sql_vars=dict(partition=partition_date.strftime('%Y%m%d'))
sql_file_name="create_api_metric_partition_index.sql",
sql_vars=dict(partition=partition_date.strftime("%Y%m%d")),
)
self.cursor.execute(db_request)
def copy_to_partition(self, partition_date: date):
"""Copy rows from metric.api table to metric.api_history."""
dump_file_name = 'api_metrics_' + str(partition_date)
dump_file_name = "api_metrics_" + str(partition_date)
dump_file_path = os.path.join(DUMP_FILE_DIR, dump_file_name)
db_request = self._build_db_request(
sql_file_name='get_api_metrics_for_date.sql',
args=dict(metrics_date=partition_date)
sql_file_name="get_api_metrics_for_date.sql",
args=dict(metrics_date=partition_date),
)
table_name = 'metric.api_history_' + partition_date.strftime('%Y%m%d')
table_name = "metric.api_history_" + partition_date.strftime("%Y%m%d")
self.cursor.dump_query_result_to_file(db_request, dump_file_path)
self.cursor.load_dump_file_to_table(table_name, dump_file_path)
os.remove(dump_file_path)
@ -81,7 +80,7 @@ class ApiMetricsRepository(RepositoryBase):
def remove_by_date(self, partition_date: date):
"""Delete from metric.api table after copying to metric.api_history"""
db_request = self._build_db_request(
sql_file_name='delete_api_metrics_by_date.sql',
args=dict(delete_date=partition_date)
sql_file_name="delete_api_metrics_by_date.sql",
args=dict(delete_date=partition_date),
)
self.cursor.delete(db_request)

View File

@ -31,33 +31,29 @@ class CoreMetricRepository(RepositoryBase):
def add(self, metric: CoreMetric):
db_request_args = asdict(metric)
db_request_args['metric_value'] = json.dumps(
db_request_args['metric_value']
)
db_request_args["metric_value"] = json.dumps(db_request_args["metric_value"])
db_request = self._build_db_request(
sql_file_name='add_core_metric.sql',
args=db_request_args
sql_file_name="add_core_metric.sql", args=db_request_args
)
self.cursor.insert(db_request)
def get_metrics_by_device(self, device_id):
return self._select_all_into_dataclass(
CoreMetric,
sql_file_name='get_core_metric_by_device.sql',
args=dict(device_id=device_id)
sql_file_name="get_core_metric_by_device.sql",
args=dict(device_id=device_id),
)
def get_metrics_by_date(self, metric_date: date) -> List[CoreMetric]:
return self._select_all_into_dataclass(
CoreMetric,
sql_file_name='get_core_timing_metrics_by_date.sql',
args=dict(metric_date=metric_date)
sql_file_name="get_core_timing_metrics_by_date.sql",
args=dict(metric_date=metric_date),
)
def add_interaction(self, interaction: CoreInteraction) -> str:
db_request = self._build_db_request(
sql_file_name='add_core_interaction.sql',
args=asdict(interaction)
sql_file_name="add_core_interaction.sql", args=asdict(interaction)
)
db_result = self.cursor.insert_returning(db_request)

View File

@ -29,7 +29,7 @@ from selene.util.db import (
Cursor,
DatabaseRequest,
DatabaseBatchRequest,
get_sql_from_file
get_sql_from_file,
)
@ -52,7 +52,7 @@ class RepositoryBase(object):
def __init__(self, db, repository_path):
self.db = db
self.cursor = Cursor(db)
self.sql_dir = path.join(path.dirname(repository_path), 'sql')
self.sql_dir = path.join(path.dirname(repository_path), "sql")
def _build_db_request(
self, sql_file_name: str, args: dict = None, sql_vars: dict = None
@ -67,8 +67,7 @@ class RepositoryBase(object):
def _build_db_batch_request(self, sql_file_name: str, args: List[dict]):
"""Build a DatabaseBatchRequest object containing a query and args"""
return DatabaseBatchRequest(
sql=get_sql_from_file(path.join(self.sql_dir, sql_file_name)),
args=args
sql=get_sql_from_file(path.join(self.sql_dir, sql_file_name)), args=args
)
def _select_one_into_dataclass(self, dataclass, sql_file_name, args=None):

View File

@ -22,7 +22,7 @@ from .entity.skill import Skill
from .entity.skill_setting import (
AccountSkillSetting,
DeviceSkillSetting,
SettingsDisplay
SettingsDisplay,
)
from .repository.display import SkillDisplayRepository
from .repository.setting import SkillSettingRepository

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -26,30 +26,30 @@ class SkillDisplayRepository(RepositoryBase):
super(SkillDisplayRepository, self).__init__(db, __file__)
# TODO: Change this to a value that can be passed in
self.core_version = '21.02'
self.core_version = "21.02"
def get_display_data_for_skills(self):
return self._select_all_into_dataclass(
dataclass=SkillDisplay,
sql_file_name='get_display_data_for_skills.sql',
args=dict(core_version=self.core_version)
sql_file_name="get_display_data_for_skills.sql",
args=dict(core_version=self.core_version),
)
def get_display_data_for_skill(self, skill_display_id) -> SkillDisplay:
return self._select_one_into_dataclass(
dataclass=SkillDisplay,
sql_file_name='get_display_data_for_skill.sql',
args=dict(skill_display_id=skill_display_id)
sql_file_name="get_display_data_for_skill.sql",
args=dict(skill_display_id=skill_display_id),
)
def upsert(self, skill_display: SkillDisplay):
db_request = self._build_db_request(
sql_file_name='upsert_skill_display_data.sql',
sql_file_name="upsert_skill_display_data.sql",
args=dict(
skill_id=skill_display.skill_id,
core_version=skill_display.core_version,
display_data=skill_display.display_data,
)
),
)
self.cursor.insert(db_request)

View File

@ -32,14 +32,12 @@ class SkillSettingRepository(RepositoryBase):
self.db = db
def get_family_settings(
self,
account_id: str,
family_name: str
self, account_id: str, family_name: str
) -> List[AccountSkillSetting]:
return self._select_all_into_dataclass(
AccountSkillSetting,
sql_file_name='get_settings_for_skill_family.sql',
args=dict(family_name=family_name, account_id=account_id)
sql_file_name="get_settings_for_skill_family.sql",
args=dict(family_name=family_name, account_id=account_id),
)
def get_installer_settings(self, account_id) -> List[AccountSkillSetting]:
@ -47,39 +45,31 @@ class SkillSettingRepository(RepositoryBase):
skills = skill_repo.get_skills_for_account(account_id)
installer_skill_id = None
for skill in skills:
if skill.display_name == 'Installer':
if skill.display_name == "Installer":
installer_skill_id = skill.id
skill_settings = None
if installer_skill_id is not None:
skill_settings = self.get_family_settings(
account_id,
installer_skill_id
)
skill_settings = self.get_family_settings(account_id, installer_skill_id)
return skill_settings
@use_transaction
def update_skill_settings(
self,
account_id,
new_skill_settings: AccountSkillSetting,
skill_ids: List[str]
self, account_id, new_skill_settings: AccountSkillSetting, skill_ids: List[str]
):
if new_skill_settings.settings_values is None:
serialized_settings_values = None
else:
serialized_settings_values = json.dumps(
new_skill_settings.settings_values
)
serialized_settings_values = json.dumps(new_skill_settings.settings_values)
db_request = self._build_db_request(
'update_device_skill_settings.sql',
"update_device_skill_settings.sql",
args=dict(
account_id=account_id,
settings_values=serialized_settings_values,
skill_id=tuple(skill_ids),
device_names=tuple(new_skill_settings.device_names)
)
device_names=tuple(new_skill_settings.device_names),
),
)
self.cursor.update(db_request)
@ -87,6 +77,6 @@ class SkillSettingRepository(RepositoryBase):
"""Return all skills and their settings for a given device id"""
return self._select_all_into_dataclass(
DeviceSkillSetting,
sql_file_name='get_skill_setting_by_device.sql',
args=dict(device_id=device_id)
sql_file_name="get_skill_setting_by_device.sql",
args=dict(device_id=device_id),
)

View File

@ -24,35 +24,34 @@ from ..entity.skill_setting import SettingsDisplay
class SettingsDisplayRepository(RepositoryBase):
def __init__(self, db):
super(SettingsDisplayRepository, self).__init__(db, __file__)
def add(self, settings_display: SettingsDisplay) -> str:
"""Add a new row to the skill.settings_display table."""
db_request = self._build_db_request(
sql_file_name='add_settings_display.sql',
sql_file_name="add_settings_display.sql",
args=dict(
skill_id=settings_display.skill_id,
display_data=json.dumps(settings_display.display_data)
)
display_data=json.dumps(settings_display.display_data),
),
)
result = self.cursor.insert_returning(db_request)
return result['id']
return result["id"]
def get_settings_display_id(self, settings_display: SettingsDisplay):
"""Get the ID of a skill's settings definition."""
db_request = self._build_db_request(
sql_file_name='get_settings_display_id.sql',
sql_file_name="get_settings_display_id.sql",
args=dict(
skill_id=settings_display.skill_id,
display_data=json.dumps(settings_display.display_data)
)
display_data=json.dumps(settings_display.display_data),
),
)
result = self.cursor.select_one(db_request)
return None if result is None else result['id']
return None if result is None else result["id"]
def get_settings_definitions_by_gid(self, global_id):
"""Get all matching settings definitions for a global skill ID.
@ -63,14 +62,14 @@ class SettingsDisplayRepository(RepositoryBase):
"""
return self._select_all_into_dataclass(
SettingsDisplay,
sql_file_name='get_settings_definition_by_gid.sql',
args=dict(global_id=global_id)
sql_file_name="get_settings_definition_by_gid.sql",
args=dict(global_id=global_id),
)
def remove(self, settings_display_id: str):
"""Delete a settings definition that is no longer used by any device"""
db_request = self._build_db_request(
sql_file_name='delete_settings_display.sql',
args=dict(settings_display_id=settings_display_id)
sql_file_name="delete_settings_display.sql",
args=dict(settings_display_id=settings_display_id),
)
self.cursor.delete(db_request)

View File

@ -24,8 +24,8 @@ from ...repository_base import RepositoryBase
def extract_family_from_global_id(skill_gid):
id_parts = skill_gid.split('|')
if id_parts[0].startswith('@'):
id_parts = skill_gid.split("|")
if id_parts[0].startswith("@"):
family_name = id_parts[1]
else:
family_name = id_parts[0]
@ -41,8 +41,7 @@ class SkillRepository(RepositoryBase):
def get_skills_for_account(self, account_id) -> List[SkillFamily]:
skills = []
db_request = self._build_db_request(
'get_skills_for_account.sql',
args=dict(account_id=account_id)
"get_skills_for_account.sql", args=dict(account_id=account_id)
)
db_result = self.cursor.select_all(db_request)
if db_result is not None:
@ -54,23 +53,23 @@ class SkillRepository(RepositoryBase):
def get_skill_by_global_id(self, skill_global_id) -> Skill:
return self._select_one_into_dataclass(
dataclass=Skill,
sql_file_name='get_skill_by_global_id.sql',
args=dict(skill_global_id=skill_global_id)
sql_file_name="get_skill_by_global_id.sql",
args=dict(skill_global_id=skill_global_id),
)
@staticmethod
def _extract_settings(skill):
settings = {}
skill_metadata = skill.get('skillMetadata')
skill_metadata = skill.get("skillMetadata")
if skill_metadata:
for section in skill_metadata['sections']:
for field in section['fields']:
if 'name' in field and 'value' in field:
settings[field['name']] = field['value']
field.pop('value', None)
for section in skill_metadata["sections"]:
for field in section["fields"]:
if "name" in field and "value" in field:
settings[field["name"]] = field["value"]
field.pop("value", None)
result = settings, skill
else:
result = '', ''
result = "", ""
return result
def ensure_skill_exists(self, skill_global_id: str) -> str:
@ -85,14 +84,14 @@ class SkillRepository(RepositoryBase):
def _add_skill(self, skill_gid: str, name: str) -> str:
db_request = self._build_db_request(
sql_file_name='add_skill.sql',
args=dict(skill_gid=skill_gid, family_name=name)
sql_file_name="add_skill.sql",
args=dict(skill_gid=skill_gid, family_name=name),
)
db_result = self.cursor.insert_returning(db_request)
# handle both dictionary cursors and namedtuple cursors
try:
skill_id = db_result['id']
skill_id = db_result["id"]
except TypeError:
skill_id = db_result.id
@ -100,7 +99,6 @@ class SkillRepository(RepositoryBase):
def remove_by_gid(self, skill_gid):
db_request = self._build_db_request(
sql_file_name='remove_skill_by_gid.sql',
args=dict(skill_gid=skill_gid)
sql_file_name="remove_skill_by_gid.sql", args=dict(skill_gid=skill_gid)
)
self.cursor.delete(db_request)

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -25,7 +25,7 @@ from selene.data.account import (
AccountRepository,
OPEN_DATASET,
PRIVACY_POLICY,
TERMS_OF_USE
TERMS_OF_USE,
)
@ -33,19 +33,19 @@ def build_test_account(**overrides):
test_agreements = [
AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today()),
AccountAgreement(type=TERMS_OF_USE, accept_date=date.today()),
AccountAgreement(type=OPEN_DATASET, accept_date=date.today())
AccountAgreement(type=OPEN_DATASET, accept_date=date.today()),
]
return Account(
email_address=overrides.get('email_address') or 'foo@mycroft.ai',
username=overrides.get('username') or 'foobar',
agreements=overrides.get('agreements') or test_agreements
email_address=overrides.get("email_address") or "foo@mycroft.ai",
username=overrides.get("username") or "foobar",
agreements=overrides.get("agreements") or test_agreements,
)
def add_account(db, **overrides):
acct_repository = AccountRepository(db)
account = build_test_account(**overrides)
password = overrides.get('password') or 'test_password'
password = overrides.get("password") or "test_password"
account.id = acct_repository.add(account, password)
if account.membership is not None:
acct_repository.add_membership(account.id, account.membership)
@ -59,13 +59,13 @@ def remove_account(db, account):
def build_test_membership(**overrides):
stripe_acct = 'test_stripe_acct_id'
stripe_acct = "test_stripe_acct_id"
return AccountMembership(
type=overrides.get('type') or 'Monthly Membership',
start_date=overrides.get('start_date') or date.today(),
payment_method=overrides.get('payment_method') or 'Stripe',
payment_account_id=overrides.get('payment_account_id') or stripe_acct,
payment_id=overrides.get('payment_id') or 'test_stripe_payment_id'
type=overrides.get("type") or "Monthly Membership",
start_date=overrides.get("start_date") or date.today(),
payment_method=overrides.get("payment_method") or "Stripe",
payment_account_id=overrides.get("payment_account_id") or stripe_acct,
payment_id=overrides.get("payment_id") or "test_stripe_payment_id",
)

View File

@ -22,10 +22,10 @@ from selene.data.device import Geography, GeographyRepository
def add_account_geography(db, account, **overrides):
geography = Geography(
country=overrides.get('country') or 'United States',
region=overrides.get('region') or 'Missouri',
city=overrides.get('city') or 'Kansas City',
time_zone=overrides.get('time_zone') or 'America/Chicago'
country=overrides.get("country") or "United States",
region=overrides.get("region") or "Missouri",
city=overrides.get("city") or "Kansas City",
time_zone=overrides.get("time_zone") or "America/Chicago",
)
geo_repository = GeographyRepository(db, account.id)
account_geography_id = geo_repository.add(geography)

View File

@ -22,9 +22,7 @@ from selene.data.device import AccountPreferences, PreferenceRepository
def add_account_preference(db, account_id):
account_preferences = AccountPreferences(
date_format='MM/DD/YYYY',
time_format='12 Hour',
measurement_system='Imperial'
date_format="MM/DD/YYYY", time_format="12 Hour", measurement_system="Imperial"
)
preference_repo = PreferenceRepository(db, account_id)
preference_repo.upsert(account_preferences)

View File

@ -29,44 +29,44 @@ from selene.data.account import (
AgreementRepository,
OPEN_DATASET,
PRIVACY_POLICY,
TERMS_OF_USE
TERMS_OF_USE,
)
def _build_test_terms_of_use():
return Agreement(
type=TERMS_OF_USE,
version='Holy Grail',
content='I agree that all the tests I write for this application will '
'be in the theme of Monty Python and the Holy Grail. If you '
version="Holy Grail",
content="I agree that all the tests I write for this application will "
"be in the theme of Monty Python and the Holy Grail. If you "
'do not agree with these terms, I will be forced to say "Ni!" '
'until such time as you agree',
effective_date=date.today() - timedelta(days=1)
)
"until such time as you agree",
effective_date=date.today() - timedelta(days=1),
)
def _build_test_privacy_policy():
return Agreement(
type=PRIVACY_POLICY,
version='Holy Grail',
content='First, shalt thou take out the Holy Pin. Then shalt thou '
'count to three. No more. No less. Three shalt be the '
'number thou shalt count and the number of the counting shall '
'be three. Four shalt thou not count, nor either count thou '
'two, excepting that thou then proceed to three. Five is '
'right out. Once the number three, being the third number, '
'be reached, then lobbest thou Holy Hand Grenade of Antioch '
'towards thy foe, who, being naughty in My sight, '
'shall snuff it.',
effective_date=date.today() - timedelta(days=1)
version="Holy Grail",
content="First, shalt thou take out the Holy Pin. Then shalt thou "
"count to three. No more. No less. Three shalt be the "
"number thou shalt count and the number of the counting shall "
"be three. Four shalt thou not count, nor either count thou "
"two, excepting that thou then proceed to three. Five is "
"right out. Once the number three, being the third number, "
"be reached, then lobbest thou Holy Hand Grenade of Antioch "
"towards thy foe, who, being naughty in My sight, "
"shall snuff it.",
effective_date=date.today() - timedelta(days=1),
)
def _build_open_dataset():
return Agreement(
type=OPEN_DATASET,
version='Holy Grail',
effective_date=date.today() - timedelta(days=1)
version="Holy Grail",
effective_date=date.today() - timedelta(days=1),
)
@ -93,11 +93,11 @@ def remove_agreements(db, agreements: List[Agreement]):
def get_agreements_from_api(context, agreement):
"""Abstracted so both account and single sign on APIs use in their tests"""
if agreement == PRIVACY_POLICY:
url = '/api/agreement/privacy-policy'
url = "/api/agreement/privacy-policy"
elif agreement == TERMS_OF_USE:
url = '/api/agreement/terms-of-use'
url = "/api/agreement/terms-of-use"
else:
raise ValueError('invalid agreement type')
raise ValueError("invalid agreement type")
context.response = context.client.get(url)
@ -109,7 +109,7 @@ def validate_agreement_response(context, agreement):
elif agreement == TERMS_OF_USE:
expected_response = asdict(context.terms_of_use)
else:
raise ValueError('invalid agreement type')
raise ValueError("invalid agreement type")
del(expected_response['effective_date'])
del expected_response["effective_date"]
assert_that(response_data, equal_to(expected_response))

View File

@ -24,17 +24,14 @@ from selene.data.account import Account, AccountRepository
from selene.util.auth import AuthenticationToken
from selene.util.db import connect_to_db
ACCESS_TOKEN_COOKIE_KEY = 'seleneAccess'
ACCESS_TOKEN_COOKIE_KEY = "seleneAccess"
ONE_MINUTE = 60
TWO_MINUTES = 120
REFRESH_TOKEN_COOKIE_KEY = 'seleneRefresh'
REFRESH_TOKEN_COOKIE_KEY = "seleneRefresh"
def generate_access_token(context, duration=ONE_MINUTE):
access_token = AuthenticationToken(
context.client_config['ACCESS_SECRET'],
duration
)
access_token = AuthenticationToken(context.client_config["ACCESS_SECRET"], duration)
account = context.accounts[context.username]
access_token.generate(account.id)
@ -43,17 +40,16 @@ def generate_access_token(context, duration=ONE_MINUTE):
def set_access_token_cookie(context, duration=ONE_MINUTE):
context.client.set_cookie(
context.client_config['DOMAIN'],
context.client_config["DOMAIN"],
ACCESS_TOKEN_COOKIE_KEY,
context.access_token.jwt,
max_age=duration
max_age=duration,
)
def generate_refresh_token(context, duration=TWO_MINUTES):
refresh_token = AuthenticationToken(
context.client_config['REFRESH_SECRET'],
duration
context.client_config["REFRESH_SECRET"], duration
)
account = context.accounts[context.username]
refresh_token.generate(account.id)
@ -63,38 +59,38 @@ def generate_refresh_token(context, duration=TWO_MINUTES):
def set_refresh_token_cookie(context, duration=TWO_MINUTES):
context.client.set_cookie(
context.client_config['DOMAIN'],
context.client_config["DOMAIN"],
REFRESH_TOKEN_COOKIE_KEY,
context.refresh_token.jwt,
max_age=duration
max_age=duration,
)
def validate_token_cookies(context, expired=False):
for cookie in context.response.headers.getlist('Set-Cookie'):
for cookie in context.response.headers.getlist("Set-Cookie"):
ingredients = _parse_cookie(cookie)
ingredient_names = list(ingredients.keys())
if ACCESS_TOKEN_COOKIE_KEY in ingredient_names:
context.access_token = ingredients[ACCESS_TOKEN_COOKIE_KEY]
elif REFRESH_TOKEN_COOKIE_KEY in ingredient_names:
context.refresh_token = ingredients[REFRESH_TOKEN_COOKIE_KEY]
for ingredient_name in ('Domain', 'Expires', 'Max-Age'):
for ingredient_name in ("Domain", "Expires", "Max-Age"):
assert_that(ingredient_names, has_item(ingredient_name))
if expired:
assert_that(ingredients['Max-Age'], equal_to('0'))
assert_that(ingredients["Max-Age"], equal_to("0"))
assert hasattr(context, 'access_token'), 'no access token in response'
assert hasattr(context, 'refresh_token'), 'no refresh token in response'
assert hasattr(context, "access_token"), "no access token in response"
assert hasattr(context, "refresh_token"), "no refresh token in response"
if expired:
assert_that(context.access_token, equal_to(''))
assert_that(context.refresh_token, equal_to(''))
assert_that(context.access_token, equal_to(""))
assert_that(context.refresh_token, equal_to(""))
def _parse_cookie(cookie: str) -> dict:
ingredients = {}
for ingredient in cookie.split('; '):
if '=' in ingredient:
key, value = ingredient.split('=')
for ingredient in cookie.split("; "):
if "=" in ingredient:
key, value = ingredient.split("=")
ingredients[key] = value
else:
ingredients[ingredient] = None
@ -103,7 +99,7 @@ def _parse_cookie(cookie: str) -> dict:
def get_account(context) -> Account:
db = connect_to_db(context.client['DB_CONNECTION_CONFIG'])
db = connect_to_db(context.client["DB_CONNECTION_CONFIG"])
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_id(context.account.id)
@ -112,21 +108,14 @@ def get_account(context) -> Account:
def check_http_success(context):
assert_that(
context.response.status_code,
is_in([HTTPStatus.OK, HTTPStatus.NO_CONTENT])
context.response.status_code, is_in([HTTPStatus.OK, HTTPStatus.NO_CONTENT])
)
def check_http_error(context, error_type):
if error_type == 'a bad request':
assert_that(
context.response.status_code,
equal_to(HTTPStatus.BAD_REQUEST)
)
elif error_type == 'an unauthorized':
assert_that(
context.response.status_code,
equal_to(HTTPStatus.UNAUTHORIZED)
)
if error_type == "a bad request":
assert_that(context.response.status_code, equal_to(HTTPStatus.BAD_REQUEST))
elif error_type == "an unauthorized":
assert_that(context.response.status_code, equal_to(HTTPStatus.UNAUTHORIZED))
else:
raise ValueError('unsupported error_type')
raise ValueError("unsupported error_type")

View File

@ -26,12 +26,12 @@ from selene.data.device import DeviceSkillRepository, ManifestSkill
def add_device_skill(db, device_id, skill):
manifest_skill = ManifestSkill(
device_id=device_id,
install_method='test_install_method',
install_status='test_install_status',
install_method="test_install_method",
install_status="test_install_status",
skill_id=skill.id,
skill_gid=skill.skill_gid,
install_ts=datetime.utcnow(),
update_ts=datetime.utcnow()
update_ts=datetime.utcnow(),
)
device_skill_repo = DeviceSkillRepository(db)
manifest_skill.id = device_skill_repo.add_manifest_skill(manifest_skill)
@ -42,9 +42,7 @@ def add_device_skill(db, device_id, skill):
def add_device_skill_settings(db, device_id, settings_display, settings_values):
device_skill_repo = DeviceSkillRepository(db)
device_skill_repo.upsert_device_skill_settings(
[device_id],
settings_display,
settings_values
[device_id], settings_display, settings_values
)

View File

@ -23,19 +23,15 @@ from selene.data.account import (
Membership,
MembershipRepository,
MONTHLY_MEMBERSHIP,
YEARLY_MEMBERSHIP
YEARLY_MEMBERSHIP,
)
monthly_membership = dict(
type=MONTHLY_MEMBERSHIP,
rate=Decimal('1.99'),
rate_period='monthly'
type=MONTHLY_MEMBERSHIP, rate=Decimal("1.99"), rate_period="monthly"
)
yearly_membership = dict(
type=YEARLY_MEMBERSHIP,
rate=Decimal('19.99'),
rate_period='yearly'
type=YEARLY_MEMBERSHIP, rate=Decimal("19.99"), rate_period="yearly"
)

View File

@ -21,52 +21,41 @@ from selene.data.skill import (
SettingsDisplay,
SettingsDisplayRepository,
Skill,
SkillRepository
SkillRepository,
)
def build_text_field():
return dict(
name='textfield',
type='text',
label='Text Field',
placeholder='Text Placeholder'
name="textfield",
type="text",
label="Text Field",
placeholder="Text Placeholder",
)
def build_checkbox_field():
return dict(
name='checkboxfield',
type='checkbox',
label='Checkbox Field'
)
return dict(name="checkboxfield", type="checkbox", label="Checkbox Field")
def build_label_field():
return dict(
type='label',
label='This is a section label.'
)
return dict(type="label", label="This is a section label.")
def _build_display_data(skill_gid, fields):
gid_parts = skill_gid.split('|')
gid_parts = skill_gid.split("|")
if len(gid_parts) == 3:
skill_name = gid_parts[1]
else:
skill_name = gid_parts[0]
skill_identifier = skill_name + '-123456'
skill_identifier = skill_name + "-123456"
settings_display = dict(
skill_gid=skill_gid,
identifier=skill_identifier,
display_name=skill_name,
skill_gid=skill_gid, identifier=skill_identifier, display_name=skill_name,
)
if fields is not None:
settings_display.update(
skillMetadata=dict(
sections=[dict(name='Section Name', fields=fields)]
)
skillMetadata=dict(sections=[dict(name="Section Name", fields=fields)])
)
return settings_display

Some files were not shown because too many files have changed in this diff Show More