applied the "Black" formatter to all files and added pre-commit hook to check

pull/293/head
Chris Veilleux 2022-03-11 13:22:33 -06:00
parent bbad8e2f3b
commit 26ed641b48
109 changed files with 1457 additions and 1543 deletions

11
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,11 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 19.10b0
hooks:
- id: black

View File

@ -1,41 +1,43 @@
[![License](https://img.shields.io/badge/License-GNU_AGPL%203.0-blue.svg)](LICENSE)
[![CLA](https://img.shields.io/badge/CLA%3F-Required-blue.svg)](https://mycroft.ai/cla)
[![Team](https://img.shields.io/badge/Team-Mycroft_Backend-violetblue.svg)](https://github.com/MycroftAI/contributors/blob/master/team/Mycroft%20Backend.md)
[![License](https://img.shields.io/badge/License-GNU_AGPL%203.0-blue.svg)](LICENSE)
[![CLA](https://img.shields.io/badge/CLA%3F-Required-blue.svg)](https://mycroft.ai/cla)
[![Team](https://img.shields.io/badge/Team-Mycroft_Backend-violetblue.svg)](https://github.com/MycroftAI/contributors/blob/master/team/Mycroft%20Backend.md)
![Status](https://img.shields.io/badge/-Production_ready-green.svg)
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](http://makeapullrequest.com)
[![Join chat](https://img.shields.io/badge/Mattermost-join_chat-brightgreen.svg)](https://chat.mycroft.ai)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
Selene -- Mycroft's Server Backend
==========
Selene provides the services used by [Mycroft Core](https://github.com/mycroftai/mycroft-core) to manage devices, skills
and settings. It consists of two repositories. This one contains Python and SQL representing the database definition,
data access layer, APIs and scripts. The second repository, [Selene UI](https://github.com/mycroftai/selene-ui),
and settings. It consists of two repositories. This one contains Python and SQL representing the database definition,
data access layer, APIs and scripts. The second repository, [Selene UI](https://github.com/mycroftai/selene-ui),
contains Angular web applications that use the APIs defined in this repository.
There are four APIs defined in this repository, account management, single sign on, skill marketplace and device.
The first three support account.mycroft.ai (aka home.mycroft.ai), sso.mycroft.ai, and market.mycroft.ai, respectively.
The first three support account.mycroft.ai (aka home.mycroft.ai), sso.mycroft.ai, and market.mycroft.ai, respectively.
The device API is how devices running Mycroft Core communicate with the server. Also included in this repository is
a package containing batch scripts for maintenance and the definition of the database schema.
Each API is designed to run independently of the others. Code common to each of the APIs, such as the Data Access Layer,
can be found in the "shared" directory. The shared code is an independent Python package required by each of the APIs.
Each API has its own Pipfile so that it can be run in its own virtual environment.
Each API is designed to run independently of the others. Code common to each of the APIs, such as the Data Access Layer,
can be found in the "shared" directory. The shared code is an independent Python package required by each of the APIs.
Each API has its own Pipfile so that it can be run in its own virtual environment.
## Installation
The Python code utilizes features introduced in Python 3.7, such as data classes.
The Python code utilizes features introduced in Python 3.7, such as data classes.
[Pipenv](https://pipenv.readthedocs.io/en/latest/) is used for virtual environment and package management.
If you prefer to use pip and pyenv (or virtualenv), you can find the required libraries in the files named "Pipfile".
These instructions will use pipenv commands.
If the Selene applications will be servicing a large number of devices (enterprise usage, for example), it is
If the Selene applications will be servicing a large number of devices (enterprise usage, for example), it is
recommended that each of the applications run on their own server or virtual machine. This configuration makes it
easier to scale and monitor each application independently. However, all applications can be run on a single server.
This configuration could be more practical for a household running a handful of devices.
easier to scale and monitor each application independently. However, all applications can be run on a single server.
This configuration could be more practical for a household running a handful of devices.
These instructions will assume a multi-server setup for several thousand devices. To run on a single server servicing a
These instructions will assume a multi-server setup for several thousand devices. To run on a single server servicing a
small number of devices, the recommended system requirements are 4 CPU, 8GB RAM and 100GB of disk. There are a lot of
manual steps in this section that will eventually be replaced with an installation script.
@ -45,14 +47,14 @@ It is recommended to create an application specific user. In these instructions
### Postgres DB
* Recommended server configuration: [Ubuntu 18.04 LTS (server install)](https://releases.ubuntu.com/bionic/), 2 CPU, 4GB RAM, 50GB disk.
* Use the package management system to install Python 3.7, Python 3 pip and PostgreSQL 10
* Use the package management system to install Python 3.7, Python 3 pip and PostgreSQL 10
```
sudo apt-get install postgresql python3.7 python python3-pip
```
* Set Postgres to start on boot
```
* Set Postgres to start on boot
```
sudo systemctl enable postgresql
```
```
* Clone the selene-backend and documentation repositories
```
sudo mkdir -p /opt/selene
@ -97,7 +99,7 @@ pipenv run python bootstrap_mycroft_db.py
# IPv4 local connections:
host all all 127.0.0.1/32 trust
```
* By default, Postgres only listens on localhost. This will not do for a multi-server setup. Change the
* By default, Postgres only listens on localhost. This will not do for a multi-server setup. Change the
`listen_addresses` value in the `posgresql.conf` file to the private IP of the database server. This file is owned by
the `postgres` user so use the following command to edit it (substituting vi for your favorite editor)
```
@ -109,7 +111,7 @@ the `postgres` user so use the following command to edit it (substituting vi for
```
sudo -u postgres vi /etc/postgres/10/main/pg_hba.conf
```
* Instructions on how to update the `pg_hba.conf` file can be found in
* Instructions on how to update the `pg_hba.conf` file can be found in
[Postgres' documentation](https://www.postgresql.org/docs/10/auth-pg-hba-conf.html). Below is an example for reference.
```
# IPv4 Selene connections
@ -123,7 +125,7 @@ sudo systemctl restart postgresql
### Redis DB
* Recommended server configuration: Ubuntu 18.04 LTS, 1 CPU, 1GB RAM, 5GB disk.
So as to not reinvent the wheel, here are some easy-to-follow instructions for
So as to not reinvent the wheel, here are some easy-to-follow instructions for
[installing Redis on Ubuntu 18.04](https://www.digitalocean.com/community/tutorials/how-to-install-and-secure-redis-on-ubuntu-18-04).
* By default, Redis only listens on local host. For multi-server setups, one additional step is to change the "bind" variable in `/etc/redis/redis.conf` to be the private IP of the Redis host.
@ -133,9 +135,9 @@ The majority of the setup for each API is the same. This section defines the st
to each API will be defined in their respective sections.
* Add an application user to the VM. Either give this user sudo privileges or execute the sudo commands below as a user
with sudo privileges. These instructions will assume a user name of "mycroft"
* Use the package management system to install Python 3.7, Python 3 pip and Python 3.7 Developer Tools
* Use the package management system to install Python 3.7, Python 3 pip and Python 3.7 Developer Tools
```
sudo apt install python3.7 python3-pip python3.7-dev
sudo apt install python3.7 python3-pip python3.7-dev
sudo python3.7 -m pip install pipenv
```
* Setup the Backend Application Directory
@ -196,7 +198,7 @@ pipenv install
```
### Running the APIs
Each API is configured to run on port 5000. This is not a problem if each is running in its own VM but will be an
issue if all APIs are running on the same server, or if port 5000 is already in use. To address these scenarios,
issue if all APIs are running on the same server, or if port 5000 is already in use. To address these scenarios,
change the port numbering in the uwsgi.ini file for each API.
#### Single Sign On API
@ -206,7 +208,7 @@ reset key, which is used in a password reset scenario. Generate a secret key fo
* Any data that can identify a user is encrypted. Generate a salt that will be used with the encryption algorithm.
* Access to the Github API is required to support logging in with your Github account. Details can be found
[here](https://developer.github.com/v3/guides/basics-of-authentication/).
* The password reset functionality sends an email to the user with a link to reset their password. Selene uses
* The password reset functionality sends an email to the user with a link to reset their password. Selene uses
SendGrid to send these emails so a SendGrid account and API key are required.
* Define a systemd service to run the API. The service defines environment variables that use the secret and API keys
generated in previous steps.
@ -250,7 +252,7 @@ sudo systemctl enable sso_api.service
```
#### Account API
* The account API uses the same authentication mechanism as the single sign on API. The JWT_ACCESS_SECRET,
* The account API uses the same authentication mechanism as the single sign on API. The JWT_ACCESS_SECRET,
JWT_REFRESH_SECRET and SALT environment variables must be the same values as those on the single sign on API.
* This application uses the Redis database so the service needs to know where it resides.
* Define a systemd service to run the API. The service defines environment variables that use the secret and API keys
@ -293,7 +295,7 @@ sudo systemctl enable account_api.service
```
#### Marketplace API
* The marketplace API uses the same authentication mechanism as the single sign on API. The JWT_ACCESS_SECRET,
* The marketplace API uses the same authentication mechanism as the single sign on API. The JWT_ACCESS_SECRET,
JWT_REFRESH_SECRET and SALT environment variables must be the same values as those on the single sign on API.
* This application uses the Redis database so the service needs to know where it resides.
* Define a systemd service to run the API. The service defines environment variables that use the secret and API keys
@ -345,7 +347,7 @@ pipenv run python load_skill_display_data.py --core-version <specify core versio
```
#### Device API
* The device API uses the same authentication mechanism as the single sign on API. The JWT_ACCESS_SECRET,
* The device API uses the same authentication mechanism as the single sign on API. The JWT_ACCESS_SECRET,
JWT_REFRESH_SECRET and SALT environment variables must be the same values as those on the single sign on API.
* This application uses the Redis database so the service needs to know where it resides.
* The weather skill requires a key to the Open Weather Map API
@ -433,7 +435,7 @@ Before we continue, let's make sure that your endpoints are operational - for th
## Other Considerations
### DNS
There are multiple ways to setup DNS. This document will not dictate how to do so for Selene. However, here is an
There are multiple ways to setup DNS. This document will not dictate how to do so for Selene. However, here is an
example, based on how DNS is setup at Mycroft AI...
Each application runs on its own sub-domain. Assuming a top level domain of "mycroft.ai" the subdomains are:
@ -442,25 +444,25 @@ Each application runs on its own sub-domain. Assuming a top level domain of "my
* market.mycroft.ai
* sso.mycroft.ai
The APIs that support the web applications are directories within the sub-domain (e.g. account.mycroft.ai/api). Since
The APIs that support the web applications are directories within the sub-domain (e.g. account.mycroft.ai/api). Since
the device API is externally facing, it is versioned. It's subdirectory must be "v1".
### Reverse Proxy
There are multiple tools available for setting up a reverse proxy that will point your DNS entries to your APIs. As such, the decision on how to set this up will be left to the user.
### SSL
It is recommended that Selene applications be run using HTTPS. To do this an SSL certificate is necessary.
It is recommended that Selene applications be run using HTTPS. To do this an SSL certificate is necessary.
[Let's Encrypt](https://letsencrypt.org) is a great way to easily set up SSL certificates for free.
## What About the GUI???
Once the database and API setup is complete, the next step is to setup the GUI, The README file for the
[Selene UI](https://github.com/mycroftai/selene-ui) repository contains the instructions for setting up the web
Once the database and API setup is complete, the next step is to setup the GUI, The README file for the
[Selene UI](https://github.com/mycroftai/selene-ui) repository contains the instructions for setting up the web
applications.
## Getting Involved
This is an open source project and we would love your help. We have prepared a [contributing](.github/CONTRIBUTING.md)
This is an open source project and we would love your help. We have prepared a [contributing](.github/CONTRIBUTING.md)
guide to help you get started.
If this is your first PR or you're not sure where to get started,

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -31,8 +31,6 @@ class DeviceCountEndpoint(SeleneEndpoint):
def _get_devices(self):
device_repository = DeviceRepository(self.db)
device_count = device_repository.get_account_device_count(
self.account.id
)
device_count = device_repository.get_account_device_count(self.account.id)
return device_count

View File

@ -25,7 +25,7 @@ from selene.api import SeleneEndpoint
class PairingCodeEndpoint(SeleneEndpoint):
def __init__(self):
super(PairingCodeEndpoint, self).__init__()
self.cache = self.config['SELENE_CACHE']
self.cache = self.config["SELENE_CACHE"]
def get(self, pairing_code):
self._authenticate()
@ -36,7 +36,7 @@ class PairingCodeEndpoint(SeleneEndpoint):
def _get_pairing_data(self, pairing_code: str) -> bool:
"""Checking if there's one pairing session for the pairing code."""
pairing_code_is_valid = False
cache_key = 'pairing.code:' + pairing_code
cache_key = "pairing.code:" + pairing_code
pairing_cache = self.cache.get(cache_key)
if pairing_cache is not None:
pairing_code_is_valid = True

View File

@ -29,29 +29,23 @@ from selene.data.device import AccountPreferences, PreferenceRepository
class PreferencesRequest(Model):
date_format = StringType(
required=True,
choices=['DD/MM/YYYY', 'MM/DD/YYYY']
)
measurement_system = StringType(
required=True,
choices=['Imperial', 'Metric']
)
time_format = StringType(required=True, choices=['12 Hour', '24 Hour'])
date_format = StringType(required=True, choices=["DD/MM/YYYY", "MM/DD/YYYY"])
measurement_system = StringType(required=True, choices=["Imperial", "Metric"])
time_format = StringType(required=True, choices=["12 Hour", "24 Hour"])
class PreferencesEndpoint(SeleneEndpoint):
def __init__(self):
super(PreferencesEndpoint, self).__init__()
self.preferences = None
self.cache = self.config['SELENE_CACHE']
self.cache = self.config["SELENE_CACHE"]
self.etag_manager: ETagManager = ETagManager(self.cache, self.config)
def get(self):
self._authenticate()
self._get_preferences()
if self.preferences is None:
response_data = ''
response_data = ""
response_code = HTTPStatus.NO_CONTENT
else:
response_data = asdict(self.preferences)
@ -68,22 +62,20 @@ class PreferencesEndpoint(SeleneEndpoint):
self._validate_request()
self._upsert_preferences()
self.etag_manager.expire_device_setting_etag_by_account_id(self.account.id)
return '', HTTPStatus.NO_CONTENT
return "", HTTPStatus.NO_CONTENT
def patch(self):
self._authenticate()
self._validate_request()
self._upsert_preferences()
self.etag_manager.expire_device_setting_etag_by_account_id(self.account.id)
return '', HTTPStatus.NO_CONTENT
return "", HTTPStatus.NO_CONTENT
def _validate_request(self):
self.preferences = PreferencesRequest()
self.preferences.date_format = self.request.json['dateFormat']
self.preferences.measurement_system = (
self.request.json['measurementSystem']
)
self.preferences.time_format = self.request.json['timeFormat']
self.preferences.date_format = self.request.json["dateFormat"]
self.preferences.measurement_system = self.request.json["measurementSystem"]
self.preferences.time_format = self.request.json["timeFormat"]
self.preferences.validate()
def _upsert_preferences(self):

View File

@ -25,7 +25,7 @@ from selene.data.geography import RegionRepository
class RegionEndpoint(SeleneEndpoint):
def get(self):
country_id = self.request.args['country']
country_id = self.request.args["country"]
region_repository = RegionRepository(self.db)
regions = region_repository.get_regions_by_country(country_id)

View File

@ -27,17 +27,15 @@ from selene.api import SeleneEndpoint
class SkillOauthEndpoint(SeleneEndpoint):
def __init__(self):
super(SkillOauthEndpoint, self).__init__()
self.oauth_base_url = os.environ['OAUTH_BASE_URL']
self.oauth_base_url = os.environ["OAUTH_BASE_URL"]
def get(self, oauth_id):
self._authenticate()
return self._get_oauth_url(oauth_id)
def _get_oauth_url(self, oauth_id):
url = '{base_url}/auth/{oauth_id}/auth_url?uuid={account_id}'.format(
base_url=self.oauth_base_url,
oauth_id=oauth_id,
account_id=self.account.id
url = "{base_url}/auth/{oauth_id}/auth_url?uuid={account_id}".format(
base_url=self.oauth_base_url, oauth_id=oauth_id, account_id=self.account.id
)
response = requests.get(url)
return response.text, response.status_code

View File

@ -35,8 +35,7 @@ class SkillSettingsEndpoint(SeleneEndpoint):
self.account_skills = None
self.family_settings = None
self.etag_manager: ETagManager = ETagManager(
self.config['SELENE_CACHE'],
self.config
self.config["SELENE_CACHE"], self.config
)
@property
@ -51,8 +50,7 @@ class SkillSettingsEndpoint(SeleneEndpoint):
"""Process an HTTP GET request"""
self._authenticate()
self.family_settings = self.setting_repository.get_family_settings(
self.account.id,
skill_family_name
self.account.id, skill_family_name
)
self._parse_selection_options()
response_data = self._build_response_data()
@ -62,7 +60,7 @@ class SkillSettingsEndpoint(SeleneEndpoint):
return Response(
response=json.dumps(response_data),
status=HTTPStatus.OK,
content_type='application/json'
content_type="application/json",
)
def _parse_selection_options(self):
@ -75,19 +73,16 @@ class SkillSettingsEndpoint(SeleneEndpoint):
"""
for skill_settings in self.family_settings:
if skill_settings.settings_definition is not None:
for section in skill_settings.settings_definition['sections']:
for field in section['fields']:
if field['type'] == 'select':
for section in skill_settings.settings_definition["sections"]:
for field in section["fields"]:
if field["type"] == "select":
parsed_options = []
for option in field['options'].split(';'):
option_display, option_value = option.split('|')
for option in field["options"].split(";"):
option_display, option_value = option.split("|")
parsed_options.append(
dict(
display=option_display,
value=option_value
)
dict(display=option_display, value=option_value)
)
field['options'] = parsed_options
field["options"] = parsed_options
def _build_response_data(self):
"""Build the object to return to the UI."""
@ -100,7 +95,7 @@ class SkillSettingsEndpoint(SeleneEndpoint):
response_skill = dict(
settingsDisplay=skill_settings.settings_definition,
settingsValues=skill_settings.settings_values,
deviceNames=skill_settings.device_names
deviceNames=skill_settings.device_names,
)
response_data.append(response_skill)
@ -111,19 +106,17 @@ class SkillSettingsEndpoint(SeleneEndpoint):
self._authenticate()
self._update_settings_values()
return '', HTTPStatus.OK
return "", HTTPStatus.OK
def _update_settings_values(self):
"""Update the value of the settings column on the device_skill table,"""
for new_skill_settings in self.request.json['skillSettings']:
for new_skill_settings in self.request.json["skillSettings"]:
account_skill_settings = AccountSkillSetting(
settings_definition=new_skill_settings['settingsDisplay'],
settings_values=new_skill_settings['settingsValues'],
device_names=new_skill_settings['deviceNames']
settings_definition=new_skill_settings["settingsDisplay"],
settings_values=new_skill_settings["settingsValues"],
device_names=new_skill_settings["deviceNames"],
)
self.setting_repository.update_skill_settings(
self.account.id,
account_skill_settings,
self.request.json['skillIds']
self.account.id, account_skill_settings, self.request.json["skillIds"]
)
self.etag_manager.expire_skill_etag_by_account_id(self.account.id)

View File

@ -44,11 +44,11 @@ class SkillsEndpoint(SeleneEndpoint):
market_id=skill.market_id,
name=skill.display_name or skill.family_name,
has_settings=skill.has_settings,
skill_ids=skill.skill_ids
skill_ids=skill.skill_ids,
)
else:
response_skill['skill_ids'].extend(skill.skill_ids)
if response_skill['market_id'] is None:
response_skill['market_id'] = skill.market_id
response_skill["skill_ids"].extend(skill.skill_ids)
if response_skill["market_id"] is None:
response_skill["market_id"] = skill.market_id
return sorted(response_data.values(), key=lambda x: x['name'])
return sorted(response_data.values(), key=lambda x: x["name"])

View File

@ -25,7 +25,7 @@ from selene.data.geography import TimezoneRepository
class TimezoneEndpoint(SeleneEndpoint):
def get(self):
country_id = self.request.args['country']
country_id = self.request.args["country"]
timezone_repository = TimezoneRepository(self.db)
timezones = timezone_repository.get_timezones_by_country(country_id)

View File

@ -26,19 +26,19 @@ from hamcrest import assert_that, equal_to
from selene.data.account import PRIVACY_POLICY, TERMS_OF_USE
@when('API request for {agreement} is made')
@when("API request for {agreement} is made")
def call_agreement_endpoint(context, agreement):
if agreement == PRIVACY_POLICY:
url = '/api/agreement/privacy-policy'
url = "/api/agreement/privacy-policy"
elif agreement == TERMS_OF_USE:
url = '/api/agreement/terms-of-use'
url = "/api/agreement/terms-of-use"
else:
raise ValueError('invalid agreement type')
raise ValueError("invalid agreement type")
context.response = context.client.get(url)
@then('{agreement} version {version} is returned')
@then("{agreement} version {version} is returned")
def validate_response(context, agreement, version):
response_data = json.loads(context.response.data)
if agreement == PRIVACY_POLICY:
@ -46,7 +46,7 @@ def validate_response(context, agreement, version):
elif agreement == TERMS_OF_USE:
expected_response = asdict(context.terms_of_use)
else:
raise ValueError('invalid agreement type')
raise ValueError("invalid agreement type")
del(expected_response['effective_date'])
del expected_response["effective_date"]
assert_that(response_data, equal_to(expected_response))

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -27,7 +27,8 @@ from selene.data.skill import SkillDisplay, SkillDisplayRepository
class SkillDetailEndpoint(SeleneEndpoint):
""""Supply the data that will populate the skill detail page."""
""" "Supply the data that will populate the skill detail page."""
authentication_required = False
def __init__(self):
@ -57,39 +58,37 @@ class SkillDetailEndpoint(SeleneEndpoint):
def _build_response_data(self, skill_display: SkillDisplay):
"""Make some modifications to the response skill for the marketplace"""
self.response_skill = dict(
categories=skill_display.display_data.get('categories'),
credits=skill_display.display_data.get('credits'),
categories=skill_display.display_data.get("categories"),
credits=skill_display.display_data.get("credits"),
description=markdown(
skill_display.display_data.get('description'),
output_format='html5'
skill_display.display_data.get("description"), output_format="html5"
),
display_name=skill_display.display_data['display_name'],
icon=skill_display.display_data.get('icon'),
iconImage=skill_display.display_data.get('icon_img'),
display_name=skill_display.display_data["display_name"],
icon=skill_display.display_data.get("icon"),
iconImage=skill_display.display_data.get("icon_img"),
isSystemSkill=False,
worksOnMarkOne=(
'all' in skill_display.display_data['platforms'] or
'platform_mark1' in skill_display.display_data['platforms']
"all" in skill_display.display_data["platforms"]
or "platform_mark1" in skill_display.display_data["platforms"]
),
worksOnMarkTwo=(
'all' in skill_display.display_data['platforms'] or
'platform_mark2' in skill_display.display_data['platforms']
"all" in skill_display.display_data["platforms"]
or "platform_mark2" in skill_display.display_data["platforms"]
),
worksOnPicroft=(
'all' in skill_display.display_data['platforms'] or
'platform_picroft' in skill_display.display_data['platforms']
"all" in skill_display.display_data["platforms"]
or "platform_picroft" in skill_display.display_data["platforms"]
),
worksOnKDE=(
'all' in skill_display.display_data['platforms'] or
'platform_plasmoid' in skill_display.display_data['platforms']
"all" in skill_display.display_data["platforms"]
or "platform_plasmoid" in skill_display.display_data["platforms"]
),
repositoryUrl=skill_display.display_data.get('repo'),
repositoryUrl=skill_display.display_data.get("repo"),
summary=markdown(
skill_display.display_data['short_desc'],
output_format='html5'
skill_display.display_data["short_desc"], output_format="html5"
),
triggers=skill_display.display_data['examples']
triggers=skill_display.display_data["examples"],
)
if skill_display.display_data['tags'] is not None:
if 'system' in skill_display.display_data['tags']:
self.response_skill['isSystemSkill'] = True
if skill_display.display_data["tags"] is not None:
if "system" in skill_display.display_data["tags"]:
self.response_skill["isSystemSkill"] = True

View File

@ -26,12 +26,7 @@ from selene.api import SeleneEndpoint
from selene.data.device import DeviceSkillRepository, ManifestSkill
from selene.util.auth import AuthenticationError
VALID_STATUS_VALUES = (
'failed',
'installed',
'installing',
'uninstalling'
)
VALID_STATUS_VALUES = ("failed", "installed", "installing", "uninstalling")
class SkillInstallStatusEndpoint(SeleneEndpoint):
@ -45,7 +40,7 @@ class SkillInstallStatusEndpoint(SeleneEndpoint):
try:
self._authenticate()
except AuthenticationError:
self.response = ('', HTTPStatus.NO_CONTENT)
self.response = ("", HTTPStatus.NO_CONTENT)
else:
self._get_installed_skills()
response_data = self._build_response_data()
@ -55,9 +50,7 @@ class SkillInstallStatusEndpoint(SeleneEndpoint):
def _get_installed_skills(self):
skill_repo = DeviceSkillRepository(self.db)
installed_skills = skill_repo.get_skill_manifest_for_account(
self.account.id
)
installed_skills = skill_repo.get_skill_manifest_for_account(self.account.id)
for skill in installed_skills:
self.installed_skills[skill.skill_id].append(skill)
@ -67,18 +60,13 @@ class SkillInstallStatusEndpoint(SeleneEndpoint):
for skill_id, skills in self.installed_skills.items():
skill_aggregator = SkillManifestAggregator(skills)
skill_aggregator.aggregate_skill_status()
if skill_aggregator.aggregate_skill.install_status == 'failed':
failure_reasons[skill_id] = (
skill_aggregator.aggregate_skill.install_failure_reason
)
install_statuses[skill_id] = (
skill_aggregator.aggregate_skill.install_status
)
if skill_aggregator.aggregate_skill.install_status == "failed":
failure_reasons[
skill_id
] = skill_aggregator.aggregate_skill.install_failure_reason
install_statuses[skill_id] = skill_aggregator.aggregate_skill.install_status
return dict(
installStatuses=install_statuses,
failureReasons=failure_reasons
)
return dict(installStatuses=install_statuses, failureReasons=failure_reasons)
class SkillManifestAggregator(object):
@ -96,7 +84,7 @@ class SkillManifestAggregator(object):
"""
self._validate_install_status()
self._determine_install_status()
if self.aggregate_skill.install_status == 'failed':
if self.aggregate_skill.install_status == "failed":
self._determine_failure_reason()
def _validate_install_status(self):
@ -104,7 +92,7 @@ class SkillManifestAggregator(object):
if skill.install_status not in VALID_STATUS_VALUES:
raise ValueError(
'"{install_status}" is not a supported value of the '
'installation field in the skill manifest'.format(
"installation field in the skill manifest".format(
install_status=skill.install_status
)
)
@ -120,33 +108,24 @@ class SkillManifestAggregator(object):
If the install fails on any device, the install will be flagged as a
failed install in the Marketplace.
"""
failed = [
skill.install_status == 'failed' for skill in self.installed_skills
]
installing = [
s.install_status == 'installing' for s in self.installed_skills
]
failed = [skill.install_status == "failed" for skill in self.installed_skills]
installing = [s.install_status == "installing" for s in self.installed_skills]
uninstalling = [
skill.install_status == 'uninstalling' for skill in
self.installed_skills
]
installed = [
s.install_status == 'installed' for s in self.installed_skills
skill.install_status == "uninstalling" for skill in self.installed_skills
]
installed = [s.install_status == "installed" for s in self.installed_skills]
if any(failed):
self.aggregate_skill.install_status = 'failed'
self.aggregate_skill.install_status = "failed"
elif any(installing):
self.aggregate_skill.install_status = 'installing'
self.aggregate_skill.install_status = "installing"
elif any(uninstalling):
self.aggregate_skill.install_status = 'uninstalling'
self.aggregate_skill.install_status = "uninstalling"
elif all(installed):
self.aggregate_skill.install_status = 'installed'
self.aggregate_skill.install_status = "installed"
def _determine_failure_reason(self):
"""When a skill fails to install, determine the reason"""
for skill in self.installed_skills:
if skill.install_status == 'failed':
self.aggregate_skill.failure_reason = (
skill.install_failure_reason
)
if skill.install_status == "failed":
self.aggregate_skill.failure_reason = skill.install_failure_reason
break

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -29,14 +29,15 @@ from selene.data.device import DeviceRepository
class UpdateDevice(Model):
coreVersion = StringType(default='unknown')
platform = StringType(default='unknown')
coreVersion = StringType(default="unknown")
platform = StringType(default="unknown")
platform_build = StringType()
enclosureVersion = StringType(default='unknown')
enclosureVersion = StringType(default="unknown")
class DeviceEndpoint(PublicEndpoint):
"""Return the device entity using the device_id"""
def __init__(self):
super(DeviceEndpoint, self).__init__()
@ -53,13 +54,13 @@ class DeviceEndpoint(PublicEndpoint):
coreVersion=device.core_version,
enclosureVersion=device.enclosure_version,
platform=device.platform,
user=dict(uuid=device.account_id)
user=dict(uuid=device.account_id),
)
response = response_data, HTTPStatus.OK
self._add_etag(device_etag_key(device_id))
else:
response = '', HTTPStatus.NO_CONTENT
response = "", HTTPStatus.NO_CONTENT
return response
@ -69,10 +70,10 @@ class DeviceEndpoint(PublicEndpoint):
update_device = UpdateDevice(payload)
update_device.validate()
updates = dict(
platform=payload.get('platform') or 'unknown',
enclosure_version=payload.get('enclosureVersion') or 'unknown',
core_version=payload.get('coreVersion') or 'unknown'
platform=payload.get("platform") or "unknown",
enclosure_version=payload.get("enclosureVersion") or "unknown",
core_version=payload.get("coreVersion") or "unknown",
)
DeviceRepository(self.db).update_device_from_core(device_id, updates)
return '', HTTPStatus.OK
return "", HTTPStatus.OK

View File

@ -25,17 +25,18 @@ from selene.data.device import GeographyRepository
class DeviceLocationEndpoint(PublicEndpoint):
def __init__(self):
super(DeviceLocationEndpoint, self).__init__()
def get(self, device_id):
self._authenticate(device_id)
self._validate_etag(device_location_etag_key(device_id))
location = GeographyRepository(self.db, None).get_location_by_device_id(device_id)
location = GeographyRepository(self.db, None).get_location_by_device_id(
device_id
)
if location:
response = (location, HTTPStatus.OK)
self._add_etag(device_location_etag_key(device_id))
else:
response = ('', HTTPStatus.NOT_FOUND)
response = ("", HTTPStatus.NOT_FOUND)
return response

View File

@ -26,18 +26,15 @@ from selene.data.account import AccountRepository
class OauthServiceEndpoint(PublicEndpoint):
def __init__(self):
super(OauthServiceEndpoint, self).__init__()
self.oauth_service_host = os.environ['OAUTH_BASE_URL']
self.oauth_service_host = os.environ["OAUTH_BASE_URL"]
def get(self, device_id, credentials, oauth_path):
account = AccountRepository(self.db).get_account_by_device_id(device_id)
uuid = account.id
url = '{host}/auth/{credentials}/{oauth_path}'.format(
host=self.oauth_service_host,
credentials=credentials,
oauth_path=oauth_path
url = "{host}/auth/{credentials}/{oauth_path}".format(
host=self.oauth_service_host, credentials=credentials, oauth_path=oauth_path
)
params = dict(uuid=uuid)
response = requests.get(url, params=params)

View File

@ -35,44 +35,44 @@ class DeviceRefreshTokenEndpoint(PublicEndpoint):
def get(self):
headers = self.request.headers
if 'Authorization' not in headers:
raise AuthenticationError('Oauth token not found')
token_header = self.request.headers['Authorization']
if token_header.startswith('Bearer '):
refresh = token_header[len('Bearer '):]
if "Authorization" not in headers:
raise AuthenticationError("Oauth token not found")
token_header = self.request.headers["Authorization"]
if token_header.startswith("Bearer "):
refresh = token_header[len("Bearer ") :]
session = self._refresh_session_token(refresh)
# Trying to fetch a session using the refresh token
if session:
response = session, HTTPStatus.OK
else:
device = self.request.headers.get('Device')
device = self.request.headers.get("Device")
if device:
# trying to fetch a session using the device uuid
session = self._refresh_session_token_device(device)
if session:
response = session, HTTPStatus.OK
else:
response = '', HTTPStatus.UNAUTHORIZED
response = "", HTTPStatus.UNAUTHORIZED
else:
response = '', HTTPStatus.UNAUTHORIZED
response = "", HTTPStatus.UNAUTHORIZED
else:
response = '', HTTPStatus.UNAUTHORIZED
response = "", HTTPStatus.UNAUTHORIZED
return response
def _refresh_session_token(self, refresh: str):
refresh_key = 'device.token.refresh:{}'.format(refresh)
refresh_key = "device.token.refresh:{}".format(refresh)
session = self.cache.get(refresh_key)
if session:
old_login = json.loads(session)
device_id = old_login['uuid']
device_id = old_login["uuid"]
self.cache.delete(refresh_key)
return generate_device_login(device_id, self.cache)
def _refresh_session_token_device(self, device: str):
refresh_key = 'device.session:{}'.format(device)
refresh_key = "device.session:{}".format(device)
session = self.cache.get(refresh_key)
if session:
old_login = json.loads(session)
device_id = old_login['uuid']
device_id = old_login["uuid"]
self.cache.delete(refresh_key)
return generate_device_login(device_id, self.cache)

View File

@ -25,6 +25,7 @@ from selene.data.device import SettingRepository
class DeviceSettingEndpoint(PublicEndpoint):
"""Return the device's settings for the API v1 model"""
def __init__(self):
super(DeviceSettingEndpoint, self).__init__()
@ -36,5 +37,5 @@ class DeviceSettingEndpoint(PublicEndpoint):
response = (setting, HTTPStatus.OK)
self._add_etag(device_setting_etag_key(device_id))
else:
response = ('', HTTPStatus.NO_CONTENT)
response = ("", HTTPStatus.NO_CONTENT)
return response

View File

@ -28,7 +28,7 @@ from schematics.types import (
ListType,
IntType,
BooleanType,
TimestampType
TimestampType,
)
from selene.api import PublicEndpoint
@ -43,9 +43,7 @@ class SkillManifestReconciler(object):
self.skill_repo = SkillRepository(self.db)
self.device_manifest = {sm.skill_gid: sm for sm in device_manifest}
self.db_manifest = {ds.skill_gid: ds for ds in db_manifest}
self.device_manifest_global_ids = {
gid for gid in self.device_manifest.keys()
}
self.device_manifest_global_ids = {gid for gid in self.device_manifest.keys()}
self.db_manifest_global_ids = {gid for gid in self.db_manifest}
def reconcile(self):
@ -82,16 +80,14 @@ class SkillManifestReconciler(object):
for gid in skills_to_add:
skill_id = self.skill_repo.ensure_skill_exists(gid)
self.device_manifest[gid].skill_id = skill_id
self.skill_manifest_repo.add_manifest_skill(
self.device_manifest[gid]
)
self.skill_manifest_repo.add_manifest_skill(self.device_manifest[gid])
class RequestManifestSkill(Model):
name = StringType(required=True)
origin = StringType(required=True)
installation = StringType(required=True)
failure_message = StringType(default='')
failure_message = StringType(default="")
status = StringType(required=True)
beta = BooleanType(required=True)
installed = TimestampType(required=True)
@ -126,7 +122,7 @@ class DeviceSkillManifestEndpoint(PublicEndpoint):
self._validate_put_request()
self._update_skill_manifest(device_id)
return '', HTTPStatus.OK
return "", HTTPStatus.OK
def _validate_put_request(self):
request_data = SkillManifestRequest(self.request.json)
@ -137,29 +133,27 @@ class DeviceSkillManifestEndpoint(PublicEndpoint):
device_id
)
device_skill_manifest = []
for manifest_skill in self.request.json['skills']:
for manifest_skill in self.request.json["skills"]:
self._convert_manifest_timestamps(manifest_skill)
device_skill_manifest.append(
ManifestSkill(
device_id=device_id,
install_method=manifest_skill['origin'],
install_status=manifest_skill['installation'],
install_failure_reason=manifest_skill.get('failure_message'),
install_ts=manifest_skill['installed'],
skill_gid=manifest_skill['skill_gid'],
update_ts=manifest_skill['updated']
install_method=manifest_skill["origin"],
install_status=manifest_skill["installation"],
install_failure_reason=manifest_skill.get("failure_message"),
install_ts=manifest_skill["installed"],
skill_gid=manifest_skill["skill_gid"],
update_ts=manifest_skill["updated"],
)
)
reconciler = SkillManifestReconciler(
self.db,
device_skill_manifest,
db_skill_manifest
self.db, device_skill_manifest, db_skill_manifest
)
reconciler.reconcile()
@staticmethod
def _convert_manifest_timestamps(manifest_skill):
for key in ('installed', 'updated'):
for key in ("installed", "updated"):
value = manifest_skill[key]
if value:
manifest_skill[key] = datetime.fromtimestamp(value)

View File

@ -33,39 +33,37 @@ from selene.data.skill import (
SettingsDisplayRepository,
Skill,
SkillRepository,
SkillSettingRepository
SkillSettingRepository,
)
from selene.util.cache import DEVICE_SKILL_ETAG_KEY
# matches <submodule_name>|<branch>
GLOBAL_ID_PATTERN = '^([^\|@]+)\|([^\|]+$)'
GLOBAL_ID_PATTERN = "^([^\|@]+)\|([^\|]+$)"
# matches @<device_id>|<submodule_name>|<branch>
GLOBAL_ID_DIRTY_PATTERN = '^@(.*)\|(.*)\|(.*)$'
GLOBAL_ID_DIRTY_PATTERN = "^@(.*)\|(.*)\|(.*)$"
# matches @<device_id>|<folder_name>
GLOBAL_ID_NON_MSM_PATTERN = '^@([^\|]+)\|([^\|]+$)'
GLOBAL_ID_ANY_PATTERN = '(?:{})|(?:{})|(?:{})'.format(
GLOBAL_ID_PATTERN,
GLOBAL_ID_DIRTY_PATTERN,
GLOBAL_ID_NON_MSM_PATTERN
GLOBAL_ID_NON_MSM_PATTERN = "^@([^\|]+)\|([^\|]+$)"
GLOBAL_ID_ANY_PATTERN = "(?:{})|(?:{})|(?:{})".format(
GLOBAL_ID_PATTERN, GLOBAL_ID_DIRTY_PATTERN, GLOBAL_ID_NON_MSM_PATTERN
)
def _normalize_field_value(field):
"""The field values in skillMetadata are all strings, convert to native."""
normalized_value = field.get('value')
if field['type'].lower() == 'checkbox':
if field['value'] in ('false', 'False', '0'):
normalized_value = field.get("value")
if field["type"].lower() == "checkbox":
if field["value"] in ("false", "False", "0"):
normalized_value = False
elif field['value'] in ('true', 'True', '1'):
elif field["value"] in ("true", "True", "1"):
normalized_value = True
elif field['type'].lower() == 'number' and isinstance(field['value'], str):
if field['value']:
normalized_value = float(field['value'])
elif field["type"].lower() == "number" and isinstance(field["value"], str):
if field["value"]:
normalized_value = float(field["value"])
if not normalized_value % 1:
normalized_value = int(field['value'])
normalized_value = int(field["value"])
else:
normalized_value = 0
elif field['value'] == "[]":
elif field["value"] == "[]":
normalized_value = []
return normalized_value
@ -78,6 +76,7 @@ class SkillSettingUpdater(object):
request specifies a single device to update, all devices with
the same skill must be updated as well.
"""
_device_skill_repo = None
_settings_display_repo = None
@ -115,28 +114,27 @@ class SkillSettingUpdater(object):
settings_meta.json file before sending the result to this API. The
settings values are stored separately from the metadata in the database.
"""
settings_definition = self.display_data.get('skillMetadata')
settings_definition = self.display_data.get("skillMetadata")
if settings_definition is not None:
self.settings_values = dict()
sections_without_values = []
for section in settings_definition['sections']:
for section in settings_definition["sections"]:
section_without_values = dict(**section)
for field in section_without_values['fields']:
field_name = field.get('name')
field_value = field.get('value')
for field in section_without_values["fields"]:
field_name = field.get("name")
field_value = field.get("value")
if field_name is not None:
if field_value is not None:
field_value = _normalize_field_value(field)
del(field['value'])
del field["value"]
self.settings_values[field_name] = field_value
sections_without_values.append(section_without_values)
settings_definition['sections'] = sections_without_values
settings_definition["sections"] = sections_without_values
def _get_skill_id(self):
"""Get the id of the skill in the request"""
skill_global_id = (
self.display_data.get('skill_gid') or
self.display_data.get('identifier')
skill_global_id = self.display_data.get("skill_gid") or self.display_data.get(
"identifier"
)
skill_repo = SkillRepository(self.db)
skill_id = skill_repo.ensure_skill_exists(skill_global_id)
@ -145,14 +143,9 @@ class SkillSettingUpdater(object):
def _ensure_settings_display_exists(self) -> bool:
"""If the settings display changed, a new row needs to be added."""
new_settings_display = False
self.settings_display = SettingsDisplay(
self.skill.id,
self.display_data
)
self.settings_display.id = (
self.settings_display_repo.get_settings_display_id(
self.settings_display
)
self.settings_display = SettingsDisplay(self.skill.id, self.display_data)
self.settings_display.id = self.settings_display_repo.get_settings_display_id(
self.settings_display
)
if self.settings_display.id is None:
self.settings_display.id = self.settings_display_repo.add(
@ -173,11 +166,8 @@ class SkillSettingUpdater(object):
"""Get all the permutations of settings for a skill"""
account_repo = AccountRepository(self.db)
account = account_repo.get_account_by_device_id(self.device_id)
skill_settings = (
self.device_skill_repo.get_skill_settings_for_account(
account.id,
self.skill.id
)
skill_settings = self.device_skill_repo.get_skill_settings_for_account(
account.id, self.skill.id
)
return skill_settings
@ -187,14 +177,14 @@ class SkillSettingUpdater(object):
for skill_setting in skill_settings:
if self.device_id in skill_setting.device_ids:
device_skill_found = True
if skill_setting.install_method in ('voice', 'cli'):
if skill_setting.install_method in ("voice", "cli"):
devices_to_update = [self.device_id]
else:
devices_to_update = skill_setting.device_ids
self.device_skill_repo.upsert_device_skill_settings(
devices_to_update,
self.settings_display,
self._merge_settings_values(skill_setting.settings_values)
self._merge_settings_values(skill_setting.settings_values),
)
break
@ -225,9 +215,7 @@ class SkillSettingUpdater(object):
manifest endpoint in some cases.
"""
self.device_skill_repo.upsert_device_skill_settings(
[self.device_id],
self.settings_display,
self._merge_settings_values()
[self.device_id], self.settings_display, self._merge_settings_values()
)
@ -267,15 +255,16 @@ class RequestSkill(Model):
identifier = StringType()
def validate_skill_gid(self, data, value):
if data['skill_gid'] is None and data['identifier'] is None:
if data["skill_gid"] is None and data["identifier"] is None:
raise ValidationError(
'skill should have either skill_gid or identifier defined'
"skill should have either skill_gid or identifier defined"
)
return value
class DeviceSkillSettingsEndpoint(PublicEndpoint):
"""Fetch all skills associated with a device using the API v1 format"""
_device_skill_repo = None
_skill_repo = None
_skill_setting_repo = None
@ -317,23 +306,19 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint):
"""
self._authenticate(device_id)
self._validate_etag(DEVICE_SKILL_ETAG_KEY.format(device_id=device_id))
device_skills = self.skill_setting_repo.get_skill_settings_for_device(
device_id
)
device_skills = self.skill_setting_repo.get_skill_settings_for_device(device_id)
if device_skills:
response_data = self._build_response_data(device_skills)
response = Response(
json.dumps(response_data),
status=HTTPStatus.OK,
content_type='application/json'
content_type="application/json",
)
self._add_etag(DEVICE_SKILL_ETAG_KEY.format(device_id=device_id))
else:
response = Response(
'',
status=HTTPStatus.NO_CONTENT,
content_type='application/json'
"", status=HTTPStatus.NO_CONTENT, content_type="application/json"
)
return response
@ -341,7 +326,7 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint):
response_data = []
for skill in device_skills:
response_skill = dict(uuid=skill.skill_id)
settings_definition = skill.settings_display.get('skillMetadata')
settings_definition = skill.settings_display.get("skillMetadata")
if settings_definition:
settings_sections = self._apply_settings_values(
settings_definition, skill.settings_values
@ -350,10 +335,10 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint):
response_skill.update(
skillMetadata=dict(sections=settings_sections)
)
skill_gid = skill.settings_display.get('skill_gid')
skill_gid = skill.settings_display.get("skill_gid")
if skill_gid is not None:
response_skill.update(skill_gid=skill_gid)
identifier = skill.settings_display.get('identifier')
identifier = skill.settings_display.get("identifier")
if identifier is None:
response_skill.update(identifier=skill_gid)
else:
@ -366,10 +351,10 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint):
def _apply_settings_values(settings_definition, settings_values):
"""Build a copy of the settings sections populated with values."""
sections_with_values = []
for section in settings_definition['sections']:
for section in settings_definition["sections"]:
section_with_values = dict(**section)
for field in section_with_values['fields']:
field_name = field.get('name')
for field in section_with_values["fields"]:
field_name = field.get("name")
if field_name is not None and field_name in settings_values:
field.update(value=str(settings_values[field_name]))
sections_with_values.append(section_with_values)
@ -380,9 +365,7 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint):
self._authenticate(device_id)
self._validate_put_request()
skill_id = self._update_skill_settings(device_id)
self.etag_manager.expire(
DEVICE_SKILL_ETAG_KEY.format(device_id=device_id)
)
self.etag_manager.expire(DEVICE_SKILL_ETAG_KEY.format(device_id=device_id))
return dict(uuid=skill_id), HTTPStatus.OK
@ -392,9 +375,7 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint):
def _update_skill_settings(self, device_id):
skill_setting_updater = SkillSettingUpdater(
self.db,
device_id,
self.request.json
self.db, device_id, self.request.json
)
skill_setting_updater.update()
self._delete_orphaned_settings_display(
@ -418,6 +399,7 @@ class DeviceSkillSettingsEndpointV2(PublicEndpoint):
with pre 19.08 versions of mycroft-core. Once those versions are no
longer supported, the older class can be deprecated.
"""
def get(self, device_id):
"""
Retrieve skills installed on device from the database.
@ -433,9 +415,7 @@ class DeviceSkillSettingsEndpointV2(PublicEndpoint):
def _build_response_data(self, device_id):
device_skill_repo = DeviceSkillRepository(self.db)
device_skills = device_skill_repo.get_skill_settings_for_device(
device_id
)
device_skills = device_skill_repo.get_skill_settings_for_device(device_id)
if device_skills is not None:
response_data = {}
for skill in device_skills:
@ -446,15 +426,13 @@ class DeviceSkillSettingsEndpointV2(PublicEndpoint):
def _build_response(self, device_id, response_data):
if response_data is None:
response = Response(
'',
status=HTTPStatus.NO_CONTENT,
content_type='application/json'
"", status=HTTPStatus.NO_CONTENT, content_type="application/json"
)
else:
response = Response(
json.dumps(response_data),
status=HTTPStatus.OK,
content_type='application/json'
content_type="application/json",
)
self._add_etag(DEVICE_SKILL_ETAG_KEY.format(device_id=device_id))

View File

@ -33,10 +33,10 @@ class DeviceSubscriptionEndpoint(PublicEndpoint):
if account:
membership = account.membership
response = (
{'@type': membership.type if membership is not None else 'free'},
HTTPStatus.OK
{"@type": membership.type if membership is not None else "free"},
HTTPStatus.OK,
)
else:
response = '', HTTPStatus.NO_CONTENT
response = "", HTTPStatus.NO_CONTENT
return response

View File

@ -25,13 +25,12 @@ from selene.api import PublicEndpoint
class OauthCallbackEndpoint(PublicEndpoint):
def __init__(self):
super(OauthCallbackEndpoint, self).__init__()
self.oauth_service_host = os.environ['OAUTH_BASE_URL']
self.oauth_service_host = os.environ["OAUTH_BASE_URL"]
def get(self):
params = dict(self.request.args)
url = self.oauth_service_host + '/auth/callback'
url = self.oauth_service_host + "/auth/callback"
response = requests.get(url, params=params)
return response.text, response.status_code

View File

@ -30,20 +30,20 @@ class PremiumVoiceEndpoint(PublicEndpoint):
def get(self, device_id):
self._authenticate(device_id)
arch = self.request.args.get('arch')
arch = self.request.args.get("arch")
account = AccountRepository(self.db).get_account_by_device_id(device_id)
if account and account.membership:
link = self._get_premium_voice_link(arch)
response = {'link': link}, HTTPStatus.OK
response = {"link": link}, HTTPStatus.OK
else:
response = '', HTTPStatus.NO_CONTENT
response = "", HTTPStatus.NO_CONTENT
return response
def _get_premium_voice_link(self, arch):
if arch == 'arm':
response = os.environ['URL_VOICE_ARM']
elif arch == 'x86_64':
response = os.environ['URL_VOICE_X86_64']
if arch == "arm":
response = os.environ["URL_VOICE_ARM"]
elif arch == "x86_64":
response = os.environ["URL_VOICE_X86_64"]
else:
response = ''
response = ""
return response

View File

@ -25,14 +25,13 @@ from selene.data.account import AccountRepository
class StripeWebHookEndpoint(PublicEndpoint):
def __init__(self):
super(StripeWebHookEndpoint, self).__init__()
def post(self):
event = json.loads(self.request.data)
type = event.get('type')
if type == 'customer.subscription.deleted':
customer = event['data']['object']['customer']
type = event.get("type")
if type == "customer.subscription.deleted":
customer = event["data"]["object"]["customer"]
AccountRepository(self.db).end_active_membership(customer)
return '', HTTPStatus.OK
return "", HTTPStatus.OK

View File

@ -34,14 +34,14 @@ class WolframAlphaSimpleEndpoint(PublicEndpoint):
def __init__(self):
super(WolframAlphaSimpleEndpoint, self).__init__()
self.wolfram_alpha_key = os.environ['WOLFRAM_ALPHA_KEY']
self.wolfram_alpha_url = os.environ['WOLFRAM_ALPHA_URL']
self.wolfram_alpha_key = os.environ["WOLFRAM_ALPHA_KEY"]
self.wolfram_alpha_url = os.environ["WOLFRAM_ALPHA_URL"]
def get(self):
self._authenticate()
params = dict(self.request.args)
params['appid'] = self.wolfram_alpha_key
response = requests.get(self.wolfram_alpha_url + '/v1/simple', params=params)
params["appid"] = self.wolfram_alpha_key
response = requests.get(self.wolfram_alpha_url + "/v1/simple", params=params)
code = response.status_code
response = (response.text, code) if code == HTTPStatus.OK else ('', code)
response = (response.text, code) if code == HTTPStatus.OK else ("", code)
return response

View File

@ -30,14 +30,14 @@ class WolframAlphaSpokenEndpoint(PublicEndpoint):
def __init__(self):
super(WolframAlphaSpokenEndpoint, self).__init__()
self.wolfram_alpha_key = os.environ['WOLFRAM_ALPHA_KEY']
self.wolfram_alpha_url = os.environ['WOLFRAM_ALPHA_URL']
self.wolfram_alpha_key = os.environ["WOLFRAM_ALPHA_KEY"]
self.wolfram_alpha_url = os.environ["WOLFRAM_ALPHA_URL"]
def get(self):
self._authenticate()
params = dict(self.request.args)
params['appid'] = self.wolfram_alpha_key
response = requests.get(self.wolfram_alpha_url + '/v1/spoken', params=params)
params["appid"] = self.wolfram_alpha_key
response = requests.get(self.wolfram_alpha_url + "/v1/spoken", params=params)
code = response.status_code
response = (response.text, code) if code == HTTPStatus.OK else ('', code)
response = (response.text, code) if code == HTTPStatus.OK else ("", code)
return response

View File

@ -26,105 +26,104 @@ from hamcrest import assert_that, equal_to, has_key, not_none, is_not
from selene.api.etag import ETagManager, device_location_etag_key
@when('a api call to get the location is done')
@when("a api call to get the location is done")
def get_device_location(context):
login = context.device_login
device_id = login['uuid']
access_token = login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
device_id = login["uuid"]
access_token = login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
context.get_location_response = context.client.get(
'/v1/device/{uuid}/location'.format(uuid=device_id),
headers=headers
"/v1/device/{uuid}/location".format(uuid=device_id), headers=headers
)
@then('the location should be retrieved')
@then("the location should be retrieved")
def validate_location(context):
response = context.get_location_response
assert_that(response.status_code, equal_to(HTTPStatus.OK))
location = json.loads(response.data)
assert_that(location, has_key('coordinate'))
assert_that(location, has_key('timezone'))
assert_that(location, has_key('city'))
assert_that(location, has_key("coordinate"))
assert_that(location, has_key("timezone"))
assert_that(location, has_key("city"))
coordinate = location['coordinate']
assert_that(coordinate, has_key('latitude'))
assert_that(coordinate, has_key('longitude'))
coordinate = location["coordinate"]
assert_that(coordinate, has_key("latitude"))
assert_that(coordinate, has_key("longitude"))
timezone = location['timezone']
assert_that(timezone, has_key('name'))
assert_that(timezone, has_key('code'))
assert_that(timezone, has_key('offset'))
assert_that(timezone, has_key('dstOffset'))
timezone = location["timezone"]
assert_that(timezone, has_key("name"))
assert_that(timezone, has_key("code"))
assert_that(timezone, has_key("offset"))
assert_that(timezone, has_key("dstOffset"))
city = location['city']
assert_that(city, has_key('name'))
assert_that(city, has_key('state'))
city = location["city"]
assert_that(city, has_key("name"))
assert_that(city, has_key("state"))
state = city['state']
assert_that(state, has_key('name'))
assert_that(state, has_key('country'))
assert_that(state, has_key('code'))
state = city["state"]
assert_that(state, has_key("name"))
assert_that(state, has_key("country"))
assert_that(state, has_key("code"))
country = state['country']
assert_that(country, has_key('name'))
assert_that(country, has_key('code'))
country = state["country"]
assert_that(country, has_key("name"))
assert_that(country, has_key("code"))
@given('an expired etag from a location entity')
@given("an expired etag from a location entity")
def expire_location_etag(context):
etag_manager: ETagManager = context.etag_manager
device_id = context.device_login['uuid']
context.expired_location_etag = etag_manager.get(device_location_etag_key(device_id))
device_id = context.device_login["uuid"]
context.expired_location_etag = etag_manager.get(
device_location_etag_key(device_id)
)
etag_manager.expire_device_location_etag_by_device_id(device_id)
@when('try to get the location using the expired etag')
@when("try to get the location using the expired etag")
def get_using_expired_etag(context):
login = context.device_login
device_id = login['uuid']
access_token = login['accessToken']
device_id = login["uuid"]
access_token = login["accessToken"]
headers = {
'Authorization': 'Bearer {token}'.format(token=access_token),
'If-None-Match': context.expired_location_etag
"Authorization": "Bearer {token}".format(token=access_token),
"If-None-Match": context.expired_location_etag,
}
context.get_location_response = context.client.get(
'/v1/device/{uuid}/location'.format(uuid=device_id),
headers=headers
"/v1/device/{uuid}/location".format(uuid=device_id), headers=headers
)
@then('an etag associated with the location should be created')
@then("an etag associated with the location should be created")
def validate_etag(context):
response = context.get_location_response
new_location_etag = response.headers.get('ETag')
new_location_etag = response.headers.get("ETag")
assert_that(new_location_etag, not_none())
assert_that(new_location_etag, is_not(context.expired_location_etag))
@given('a valid etag from a location entity')
@given("a valid etag from a location entity")
def valid_etag(context):
etag_manager = context.etag_manager
device_id = context.device_login['uuid']
device_id = context.device_login["uuid"]
context.valid_location_etag = etag_manager.get(device_location_etag_key(device_id))
@when('try to get the location using a valid etag')
@when("try to get the location using a valid etag")
def get_using_valid_etag(context):
login = context.device_login
device_id = login['uuid']
access_token = login['accessToken']
device_id = login["uuid"]
access_token = login["accessToken"]
headers = {
'Authorization': 'Bearer {token}'.format(token=access_token),
'If-None-Match': context.valid_location_etag
"Authorization": "Bearer {token}".format(token=access_token),
"If-None-Match": context.valid_location_etag,
}
context.get_location_response = context.client.get(
'/v1/device/{uuid}/location'.format(uuid=device_id),
headers=headers
"/v1/device/{uuid}/location".format(uuid=device_id), headers=headers
)
@then('the location endpoint should return 304')
@then("the location endpoint should return 304")
def validate_response_valid_etag(context):
response = context.get_location_response
assert_that(response.status_code, equal_to(HTTPStatus.NOT_MODIFIED))

View File

@ -24,42 +24,42 @@ from behave import when, then
from hamcrest import assert_that, equal_to, has_key, is_not
@when('the session token is refreshed')
@when("the session token is refreshed")
def refresh_token(context):
login = json.loads(context.activate_device_response.data)
refresh = login['refreshToken']
refresh = login["refreshToken"]
context.refresh_token_response = context.client.get(
'/v1/auth/token',
headers={'Authorization': 'Bearer {token}'.format(token=refresh)}
"/v1/auth/token",
headers={"Authorization": "Bearer {token}".format(token=refresh)},
)
@then('a valid new session entity should be returned')
@then("a valid new session entity should be returned")
def validate_refresh_token(context):
response = context.refresh_token_response
assert_that(response.status_code, equal_to(HTTPStatus.OK))
new_login = json.loads(response.data)
assert_that(new_login, has_key(equal_to('uuid')))
assert_that(new_login, has_key(equal_to('accessToken')))
assert_that(new_login, has_key(equal_to('refreshToken')))
assert_that(new_login, has_key(equal_to('expiration')))
assert_that(new_login, has_key(equal_to("uuid")))
assert_that(new_login, has_key(equal_to("accessToken")))
assert_that(new_login, has_key(equal_to("refreshToken")))
assert_that(new_login, has_key(equal_to("expiration")))
old_login = json.loads(context.activate_device_response.data)
assert_that(new_login['uuid']), equal_to(old_login['uuid'])
assert_that(new_login['accessToken'], is_not(equal_to(old_login['accessToken'])))
assert_that(new_login['refreshToken'], is_not(equal_to(old_login['refreshToken'])))
assert_that(new_login["uuid"]), equal_to(old_login["uuid"])
assert_that(new_login["accessToken"], is_not(equal_to(old_login["accessToken"])))
assert_that(new_login["refreshToken"], is_not(equal_to(old_login["refreshToken"])))
@when('try to refresh an invalid refresh token')
@when("try to refresh an invalid refresh token")
def refresh_invalid_token(context):
context.refresh_invalid_token_response = context.client.get(
'/v1/auth/token',
headers={'Authorization': 'Bearer {token}'.format(token='123')}
"/v1/auth/token",
headers={"Authorization": "Bearer {token}".format(token="123")},
)
@then('401 status code should be returned')
@then("401 status code should be returned")
def validate_refresh_invalid_token(context):
response = context.refresh_invalid_token_response
assert_that(response.status_code, equal_to(HTTPStatus.UNAUTHORIZED))

View File

@ -29,193 +29,177 @@ from selene.testing.skill import add_skill, build_label_field, build_text_field
from selene.util.cache import DEVICE_SKILL_ETAG_KEY
@given('skill settings with a new value')
@given("skill settings with a new value")
def change_skill_setting_value(context):
_, bar_settings_display = context.skills['bar']
section = bar_settings_display.display_data['skillMetadata']['sections'][0]
field_with_value = section['fields'][1]
field_with_value['value'] = 'New device text value'
_, bar_settings_display = context.skills["bar"]
section = bar_settings_display.display_data["skillMetadata"]["sections"][0]
field_with_value = section["fields"][1]
field_with_value["value"] = "New device text value"
@given('skill settings with a deleted field')
@given("skill settings with a deleted field")
def delete_field_from_settings(context):
_, bar_settings_display = context.skills['bar']
section = bar_settings_display.display_data['skillMetadata']['sections'][0]
context.removed_field = section['fields'].pop(1)
context.remaining_field = section['fields'][1]
_, bar_settings_display = context.skills["bar"]
section = bar_settings_display.display_data["skillMetadata"]["sections"][0]
context.removed_field = section["fields"].pop(1)
context.remaining_field = section["fields"][1]
@given('a valid device skill E-tag')
@given("a valid device skill E-tag")
def set_skill_setting_etag(context):
context.device_skill_etag = context.etag_manager.get(
DEVICE_SKILL_ETAG_KEY.format(device_id=context.device_id)
)
@given('an expired device skill E-tag')
@given("an expired device skill E-tag")
def expire_skill_setting_etag(context):
valid_device_skill_etag = context.etag_manager.get(
DEVICE_SKILL_ETAG_KEY.format(device_id=context.device_id)
)
context.device_skill_etag = context.etag_manager.expire(
valid_device_skill_etag
)
context.device_skill_etag = context.etag_manager.expire(valid_device_skill_etag)
@given('settings for a skill not assigned to the device')
@given("settings for a skill not assigned to the device")
def add_skill_not_assigned_to_device(context):
foobar_skill, foobar_settings_display = add_skill(
context.db,
skill_global_id='foobar-skill|19.02',
settings_fields=[build_label_field(), build_text_field()]
skill_global_id="foobar-skill|19.02",
settings_fields=[build_label_field(), build_text_field()],
)
section = foobar_settings_display.display_data['skillMetadata']['sections'][0]
field_with_value = section['fields'][1]
field_with_value['value'] = 'New skill text value'
section = foobar_settings_display.display_data["skillMetadata"]["sections"][0]
field_with_value = section["fields"][1]
field_with_value["value"] = "New skill text value"
context.skills.update(foobar=(foobar_skill, foobar_settings_display))
@when('a device requests the settings for its skills')
@when("a device requests the settings for its skills")
def get_device_skill_settings(context):
if hasattr(context, 'device_skill_etag'):
context.request_header[ETAG_REQUEST_HEADER_KEY] = (
context.device_skill_etag
)
if hasattr(context, "device_skill_etag"):
context.request_header[ETAG_REQUEST_HEADER_KEY] = context.device_skill_etag
context.response = context.client.get(
'/v1/device/{device_id}/skill'.format(device_id=context.device_id),
content_type='application/json',
headers=context.request_header
"/v1/device/{device_id}/skill".format(device_id=context.device_id),
content_type="application/json",
headers=context.request_header,
)
@when('the device sends a request to update the {skill} skill settings')
@when("the device sends a request to update the {skill} skill settings")
def update_skill_settings(context, skill):
_, settings_display = context.skills[skill]
context.response = context.client.put(
'/v1/device/{device_id}/skill'.format(device_id=context.device_id),
"/v1/device/{device_id}/skill".format(device_id=context.device_id),
data=json.dumps(settings_display.display_data),
content_type='application/json',
headers=context.request_header
content_type="application/json",
headers=context.request_header,
)
@when('the device requests a skill to be deleted')
@when("the device requests a skill to be deleted")
def delete_skill(context):
foo_skill, _ = context.skills['foo']
foo_skill, _ = context.skills["foo"]
context.response = context.client.delete(
'/v1/device/{device_id}/skill/{skill_gid}'.format(
device_id=context.device_id,
skill_gid=foo_skill.skill_gid
"/v1/device/{device_id}/skill/{skill_gid}".format(
device_id=context.device_id, skill_gid=foo_skill.skill_gid
),
headers=context.request_header
headers=context.request_header,
)
@then('the settings are returned')
@then("the settings are returned")
def validate_response(context):
response = context.response.json
assert_that(len(response), equal_to(2))
foo_skill, foo_settings_display = context.skills['foo']
foo_skill, foo_settings_display = context.skills["foo"]
foo_skill_expected_result = dict(
uuid=foo_skill.id,
skill_gid=foo_skill.skill_gid,
identifier=foo_settings_display.display_data['identifier']
identifier=foo_settings_display.display_data["identifier"],
)
assert_that(foo_skill_expected_result, is_in(response))
bar_skill, bar_settings_display = context.skills['bar']
section = bar_settings_display.display_data['skillMetadata']['sections'][0]
text_field = section['fields'][1]
text_field['value'] = 'Device text value'
checkbox_field = section['fields'][2]
checkbox_field['value'] = 'false'
bar_skill, bar_settings_display = context.skills["bar"]
section = bar_settings_display.display_data["skillMetadata"]["sections"][0]
text_field = section["fields"][1]
text_field["value"] = "Device text value"
checkbox_field = section["fields"][2]
checkbox_field["value"] = "false"
bar_skill_expected_result = dict(
uuid=bar_skill.id,
skill_gid=bar_skill.skill_gid,
identifier=bar_settings_display.display_data['identifier'],
skillMetadata=bar_settings_display.display_data['skillMetadata']
identifier=bar_settings_display.display_data["identifier"],
skillMetadata=bar_settings_display.display_data["skillMetadata"],
)
assert_that(bar_skill_expected_result, is_in(response))
@then('the device skill E-tag is expired')
@then("the device skill E-tag is expired")
def check_for_expired_etag(context):
"""An E-tag is expired by changing its value."""
expired_device_skill_etag = context.etag_manager.get(
DEVICE_SKILL_ETAG_KEY.format(device_id=context.device_id)
)
assert_that(
expired_device_skill_etag.decode(),
is_not(equal_to(context.device_skill_etag))
expired_device_skill_etag.decode(), is_not(equal_to(context.device_skill_etag))
)
def _get_device_skill_settings(context):
"""Minimize DB hits and code duplication by getting these values once."""
if not hasattr(context, 'device_skill_settings'):
if not hasattr(context, "device_skill_settings"):
settings_repo = SkillSettingRepository(context.db)
context.device_skill_settings = (
settings_repo.get_skill_settings_for_device(context.device_id)
context.device_skill_settings = settings_repo.get_skill_settings_for_device(
context.device_id
)
context.device_settings_values = [
dss.settings_values for dss in context.device_skill_settings
]
@then('the skill settings are updated with the new value')
@then("the skill settings are updated with the new value")
def validate_updated_skill_setting_value(context):
_get_device_skill_settings(context)
assert_that(len(context.device_skill_settings), equal_to(2))
expected_settings_values = dict(
textfield='New device text value',
checkboxfield='false'
)
assert_that(
expected_settings_values,
is_in(context.device_settings_values)
textfield="New device text value", checkboxfield="false"
)
assert_that(expected_settings_values, is_in(context.device_settings_values))
@then('the skill is assigned to the device with the settings populated')
@then("the skill is assigned to the device with the settings populated")
def validate_updated_skill_setting_value(context):
_get_device_skill_settings(context)
assert_that(len(context.device_skill_settings), equal_to(3))
expected_settings_values = dict(textfield='New skill text value')
assert_that(
expected_settings_values,
is_in(context.device_settings_values)
)
expected_settings_values = dict(textfield="New skill text value")
assert_that(expected_settings_values, is_in(context.device_settings_values))
@then('an E-tag is generated for these settings')
@then("an E-tag is generated for these settings")
def get_skills_etag(context):
response_headers = context.response.headers
response_etag = response_headers['ETag']
response_etag = response_headers["ETag"]
skill_etag = context.etag_manager.get(
DEVICE_SKILL_ETAG_KEY.format(device_id=context.device_id)
)
assert_that(skill_etag.decode(), equal_to(response_etag))
@then('the field is no longer in the skill settings')
@then("the field is no longer in the skill settings")
def validate_skill_setting_field_removed(context):
_get_device_skill_settings(context)
assert_that(len(context.device_skill_settings), equal_to(2))
# The removed field should no longer be in the settings values but the
# value of the field that was not deleted should remain
assert_that(
dict(checkboxfield='false'),
is_in(context.device_settings_values)
)
assert_that(dict(checkboxfield="false"), is_in(context.device_settings_values))
new_section = dict(fields=None)
for device_skill_setting in context.device_skill_settings:
skill_gid = device_skill_setting.settings_display['skill_gid']
if skill_gid.startswith('bar'):
skill_gid = device_skill_setting.settings_display["skill_gid"]
if skill_gid.startswith("bar"):
new_settings_display = device_skill_setting.settings_display
new_skill_definition = new_settings_display['skillMetadata']
new_section = new_skill_definition['sections'][0]
new_skill_definition = new_settings_display["skillMetadata"]
new_section = new_skill_definition["sections"][0]
# The removed field should no longer be in the settings values but the
# value of the field that was not deleted should remain
assert_that(context.removed_field, not is_in(new_section['fields']))
assert_that(context.remaining_field, is_in(new_section['fields']))
assert_that(context.removed_field, not is_in(new_section["fields"]))
assert_that(context.remaining_field, is_in(new_section["fields"]))

View File

@ -27,77 +27,75 @@ from hamcrest import assert_that, equal_to, has_key, not_none, is_not
from selene.api.etag import ETagManager, device_etag_key
new_fields = dict(
platform='mycroft_mark_1',
coreVersion='19.2.0',
enclosureVersion='1.4.0'
platform="mycroft_mark_1", coreVersion="19.2.0", enclosureVersion="1.4.0"
)
@when('device is retrieved')
@when("device is retrieved")
def get_device(context):
access_token = context.device_login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
device_id = context.device_login['uuid']
access_token = context.device_login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
device_id = context.device_login["uuid"]
context.get_device_response = context.client.get(
'/v1/device/{uuid}'.format(uuid=device_id),
headers=headers
"/v1/device/{uuid}".format(uuid=device_id), headers=headers
)
context.device_etag = context.get_device_response.headers.get('ETag')
context.device_etag = context.get_device_response.headers.get("ETag")
@then('a valid device should be returned')
@then("a valid device should be returned")
def validate_response(context):
response = context.get_device_response
assert_that(response.status_code, equal_to(HTTPStatus.OK))
device = json.loads(response.data)
assert_that(device, has_key('uuid'))
assert_that(device, has_key('name'))
assert_that(device, has_key('description'))
assert_that(device, has_key('coreVersion'))
assert_that(device, has_key('enclosureVersion'))
assert_that(device, has_key('platform'))
assert_that(device, has_key('user'))
assert_that(device['user'], has_key('uuid'))
assert_that(device['user']['uuid'], equal_to(context.account.id))
assert_that(device, has_key("uuid"))
assert_that(device, has_key("name"))
assert_that(device, has_key("description"))
assert_that(device, has_key("coreVersion"))
assert_that(device, has_key("enclosureVersion"))
assert_that(device, has_key("platform"))
assert_that(device, has_key("user"))
assert_that(device["user"], has_key("uuid"))
assert_that(device["user"]["uuid"], equal_to(context.account.id))
@when('try to fetch a device without the authorization header')
@when("try to fetch a device without the authorization header")
def get_invalid_device(context):
context.get_invalid_device_response = context.client.get('/v1/device/{uuid}'.format(uuid=str(uuid.uuid4())))
@when('try to fetch a not allowed device')
def get_not_allowed_device(context):
access_token = context.device_login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
context.get_invalid_device_response = context.client.get(
'/v1/device/{uuid}'.format(uuid=str(uuid.uuid4())),
headers=headers
"/v1/device/{uuid}".format(uuid=str(uuid.uuid4()))
)
@then('a 401 status code should be returned')
@when("try to fetch a not allowed device")
def get_not_allowed_device(context):
access_token = context.device_login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
context.get_invalid_device_response = context.client.get(
"/v1/device/{uuid}".format(uuid=str(uuid.uuid4())), headers=headers
)
@then("a 401 status code should be returned")
def validate_invalid_response(context):
response = context.get_invalid_device_response
assert_that(response.status_code, equal_to(HTTPStatus.UNAUTHORIZED))
@when('the device is updated')
@when("the device is updated")
def update_device(context):
login = context.device_login
access_token = login['accessToken']
device_id = login['uuid']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
access_token = login["accessToken"]
device_id = login["uuid"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
context.update_device_response = context.client.patch(
'/v1/device/{uuid}'.format(uuid=device_id),
"/v1/device/{uuid}".format(uuid=device_id),
data=json.dumps(new_fields),
content_type='application_json',
headers=headers
content_type="application_json",
headers=headers,
)
@then('the information should be updated')
@then("the information should be updated")
def validate_update(context):
response = context.update_device_response
assert_that(response.status_code, equal_to(HTTPStatus.OK))
@ -105,74 +103,72 @@ def validate_update(context):
response = context.get_device_response
assert_that(response.status_code, equal_to(HTTPStatus.OK))
device = json.loads(response.data)
assert_that(device, has_key('name'))
assert_that(device['coreVersion'], equal_to(new_fields['coreVersion']))
assert_that(device['enclosureVersion'], equal_to(new_fields['enclosureVersion']))
assert_that(device['platform'], equal_to(new_fields['platform']))
assert_that(device, has_key("name"))
assert_that(device["coreVersion"], equal_to(new_fields["coreVersion"]))
assert_that(device["enclosureVersion"], equal_to(new_fields["enclosureVersion"]))
assert_that(device["platform"], equal_to(new_fields["platform"]))
@given('a device with a valid etag')
@given("a device with a valid etag")
def get_device_etag(context):
etag_manager: ETagManager = context.etag_manager
device_id = context.device_login['uuid']
device_id = context.device_login["uuid"]
context.device_etag = etag_manager.get(device_etag_key(device_id))
@when('try to fetch a device using a valid etag')
@when("try to fetch a device using a valid etag")
def get_device_using_etag(context):
etag = context.device_etag
assert_that(etag, not_none())
access_token = context.device_login['accessToken']
device_uuid = context.device_login['uuid']
access_token = context.device_login["accessToken"]
device_uuid = context.device_login["uuid"]
headers = {
'Authorization': 'Bearer {token}'.format(token=access_token),
'If-None-Match': etag
"Authorization": "Bearer {token}".format(token=access_token),
"If-None-Match": etag,
}
context.response_using_etag = context.client.get(
'/v1/device/{uuid}'.format(uuid=device_uuid),
headers=headers
"/v1/device/{uuid}".format(uuid=device_uuid), headers=headers
)
@then('304 status code should be returned by the device endpoint')
@then("304 status code should be returned by the device endpoint")
def validate_etag(context):
response = context.response_using_etag
assert_that(response.status_code, equal_to(HTTPStatus.NOT_MODIFIED))
@given('a device\'s etag expired by the web ui')
@given("a device's etag expired by the web ui")
def expire_etag(context):
etag_manager: ETagManager = context.etag_manager
device_id = context.device_login['uuid']
device_id = context.device_login["uuid"]
context.device_etag = etag_manager.get(device_etag_key(device_id))
etag_manager.expire_device_etag_by_device_id(device_id)
@when('try to fetch a device using an expired etag')
@when("try to fetch a device using an expired etag")
def fetch_device_expired_etag(context):
etag = context.device_etag
assert_that(etag, not_none())
access_token = context.device_login['accessToken']
device_uuid = context.device_login['uuid']
access_token = context.device_login["accessToken"]
device_uuid = context.device_login["uuid"]
headers = {
'Authorization': 'Bearer {token}'.format(token=access_token),
'If-None-Match': etag
"Authorization": "Bearer {token}".format(token=access_token),
"If-None-Match": etag,
}
context.response_using_invalid_etag = context.client.get(
'/v1/device/{uuid}'.format(uuid=device_uuid),
headers=headers
"/v1/device/{uuid}".format(uuid=device_uuid), headers=headers
)
@then('should return status 200')
@then("should return status 200")
def validate_status_code(context):
response = context.response_using_invalid_etag
assert_that(response.status_code, equal_to(HTTPStatus.OK))
@then('a new etag')
@then("a new etag")
def validate_new_etag(context):
etag = context.device_etag
response = context.response_using_invalid_etag
etag_from_response = response.headers.get('ETag')
etag_from_response = response.headers.get("ETag")
assert_that(etag, is_not(etag_from_response))

View File

@ -27,115 +27,113 @@ from hamcrest import assert_that, equal_to, has_key, is_not
from selene.api.etag import ETagManager, device_setting_etag_key
@when('try to fetch device\'s setting')
@when("try to fetch device's setting")
def get_device_settings(context):
login = context.device_login
device_id = login['uuid']
access_token = login['accessToken']
headers=dict(Authorization='Bearer {token}'.format(token=access_token))
device_id = login["uuid"]
access_token = login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
context.response_setting = context.client.get(
'/v1/device/{uuid}/setting'.format(uuid=device_id),
headers=headers
"/v1/device/{uuid}/setting".format(uuid=device_id), headers=headers
)
@then('a valid setting should be returned')
@then("a valid setting should be returned")
def validate_response_setting(context):
response = context.response_setting
assert_that(response.status_code, equal_to(HTTPStatus.OK))
setting = json.loads(response.data)
assert_that(response.status_code, equal_to(HTTPStatus.OK))
assert_that(setting, has_key('uuid'))
assert_that(setting, has_key('systemUnit'))
assert_that(setting['systemUnit'], equal_to('imperial'))
assert_that(setting, has_key('timeFormat'))
assert_that(setting, has_key('dateFormat'))
assert_that(setting, has_key('optIn'))
assert_that(setting['optIn'], equal_to(True))
assert_that(setting, has_key('ttsSettings'))
tts = setting['ttsSettings']
assert_that(tts, has_key('module'))
assert_that(setting, has_key("uuid"))
assert_that(setting, has_key("systemUnit"))
assert_that(setting["systemUnit"], equal_to("imperial"))
assert_that(setting, has_key("timeFormat"))
assert_that(setting, has_key("dateFormat"))
assert_that(setting, has_key("optIn"))
assert_that(setting["optIn"], equal_to(True))
assert_that(setting, has_key("ttsSettings"))
tts = setting["ttsSettings"]
assert_that(tts, has_key("module"))
@when('the settings endpoint is a called to a not allowed device')
@when("the settings endpoint is a called to a not allowed device")
def get_device_settings(context):
access_token = context.device_login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
access_token = context.device_login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
context.get_invalid_setting_response = context.client.get(
'/v1/device/{uuid}/setting'.format(uuid=str(uuid.uuid4())),
headers=headers
"/v1/device/{uuid}/setting".format(uuid=str(uuid.uuid4())), headers=headers
)
@then('a 401 status code should be returned for the setting')
@then("a 401 status code should be returned for the setting")
def validate_response(context):
response = context.get_invalid_setting_response
assert_that(response.status_code, equal_to(HTTPStatus.UNAUTHORIZED))
@given('a device\'s setting with a valid etag')
@given("a device's setting with a valid etag")
def get_device_setting_etag(context):
device_id = context.device_login['uuid']
device_id = context.device_login["uuid"]
etag_manager: ETagManager = context.etag_manager
context.device_etag = etag_manager.get(device_setting_etag_key(device_id))
@when('try to fetch the device\'s settings using a valid etag')
@when("try to fetch the device's settings using a valid etag")
def get_device_settings_using_etag(context):
etag = context.device_etag
access_token = context.device_login['accessToken']
device_id = context.device_login['uuid']
access_token = context.device_login["accessToken"]
device_id = context.device_login["uuid"]
headers = {
'Authorization': 'Bearer {token}'.format(token=access_token),
'If-None-Match': etag
"Authorization": "Bearer {token}".format(token=access_token),
"If-None-Match": etag,
}
context.get_setting_etag_response = context.client.get(
'/v1/device/{uuid}/setting'.format(uuid=str(device_id)),
headers=headers
"/v1/device/{uuid}/setting".format(uuid=str(device_id)), headers=headers
)
@then('304 status code should be returned by the device\'s settings endpoint')
@then("304 status code should be returned by the device's settings endpoint")
def validate_etag_response(context):
response = context.get_setting_etag_response
assert_that(response.status_code, equal_to(HTTPStatus.NOT_MODIFIED))
@given('a device\'s setting etag expired by the web ui at device level')
@given("a device's setting etag expired by the web ui at device level")
def expire_etag_device_level(context):
device_id = context.device_login['uuid']
device_id = context.device_login["uuid"]
etag_manager: ETagManager = context.etag_manager
context.device_etag = etag_manager.get(device_setting_etag_key(device_id))
etag_manager.expire_device_setting_etag_by_device_id(device_id)
@given('a device\'s setting etag expired by the web ui at account level')
@given("a device's setting etag expired by the web ui at account level")
def expire_etag_account_level(context):
account_id = context.account.id
device_id = context.device_login['uuid']
device_id = context.device_login["uuid"]
etag_manager: ETagManager = context.etag_manager
context.device_etag = etag_manager.get(device_setting_etag_key(device_id))
etag_manager.expire_device_setting_etag_by_account_id(account_id)
@when('try to fetch the device\'s settings using an expired etag')
@when("try to fetch the device's settings using an expired etag")
def get_device_settings_using_etag(context):
etag = context.device_etag
access_token = context.device_login['accessToken']
device_id = context.device_login['uuid']
access_token = context.device_login["accessToken"]
device_id = context.device_login["uuid"]
headers = {
'Authorization': 'Bearer {token}'.format(token=access_token),
'If-None-Match': etag
"Authorization": "Bearer {token}".format(token=access_token),
"If-None-Match": etag,
}
context.get_setting_invalid_etag_response = context.client.get(
'/v1/device/{uuid}/setting'.format(uuid=str(device_id)),
headers=headers
"/v1/device/{uuid}/setting".format(uuid=str(device_id)), headers=headers
)
@then('200 status code should be returned by the device\'s setting endpoint and a new etag')
@then(
"200 status code should be returned by the device's setting endpoint and a new etag"
)
def validate_new_etag(context):
etag = context.device_etag
response = context.get_setting_invalid_etag_response
etag_from_response = response.headers.get('ETag')
etag_from_response = response.headers.get("ETag")
assert_that(etag, is_not(etag_from_response))

View File

@ -29,66 +29,63 @@ from selene.data.account import AccountRepository, AccountMembership
from selene.util.db import connect_to_db
@when('the subscription endpoint is called')
@when("the subscription endpoint is called")
def get_device_subscription(context):
login = context.device_login
device_id = login['uuid']
access_token = login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
device_id = login["uuid"]
access_token = login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
context.subscription_response = context.client.get(
'/v1/device/{uuid}/subscription'.format(uuid=device_id),
headers=headers
"/v1/device/{uuid}/subscription".format(uuid=device_id), headers=headers
)
@then('free type should be returned')
@then("free type should be returned")
def validate_response(context):
response = context.subscription_response
assert_that(response.status_code, HTTPStatus.OK)
subscription = json.loads(response.data)
assert_that(subscription, has_entry('@type', 'free'))
assert_that(subscription, has_entry("@type", "free"))
@when('the subscription endpoint is called for a monthly account')
@when("the subscription endpoint is called for a monthly account")
def get_device_subscription(context):
membership = AccountMembership(
start_date=date.today(),
type='Monthly Membership',
payment_method='Stripe',
payment_account_id='test_monthly',
payment_id='stripe_id'
type="Monthly Membership",
payment_method="Stripe",
payment_account_id="test_monthly",
payment_id="stripe_id",
)
login = context.device_login
device_id = login['uuid']
access_token = login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG'])
device_id = login["uuid"]
access_token = login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
db = connect_to_db(context.client_config["DB_CONNECTION_CONFIG"])
AccountRepository(db).add_membership(context.account.id, membership)
context.subscription_response = context.client.get(
'/v1/device/{uuid}/subscription'.format(uuid=device_id),
headers=headers
"/v1/device/{uuid}/subscription".format(uuid=device_id), headers=headers
)
@then('monthly type should be returned')
@then("monthly type should be returned")
def validate_response_monthly(context):
response = context.subscription_response
assert_that(response.status_code, HTTPStatus.OK)
subscription = json.loads(response.data)
assert_that(subscription, has_entry('@type', 'Monthly Membership'))
assert_that(subscription, has_entry("@type", "Monthly Membership"))
@when('try to get the subscription for a nonexistent device')
@when("try to get the subscription for a nonexistent device")
def get_subscription_nonexistent_device(context):
access_token = context.device_login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
access_token = context.device_login["accessToken"]
headers = dict(Authorization="Bearer {token}".format(token=access_token))
context.invalid_subscription_response = context.client.get(
'/v1/device/{uuid}/subscription'.format(uuid=str(uuid.uuid4())),
headers=headers
"/v1/device/{uuid}/subscription".format(uuid=str(uuid.uuid4())), headers=headers
)
@then('401 status code should be returned for the subscription endpoint')
@then("401 status code should be returned for the subscription endpoint")
def validate_nonexistent_device(context):
response = context.invalid_subscription_response
assert_that(response.status_code, equal_to(HTTPStatus.UNAUTHORIZED))

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -34,6 +34,7 @@ from selene.util.auth import AuthenticationError
class AuthenticateInternalEndpoint(SeleneEndpoint):
"""Sign in a user with an email address and password."""
def __init__(self):
super(AuthenticateInternalEndpoint, self).__init__()
self.account: Account = None
@ -44,22 +45,21 @@ class AuthenticateInternalEndpoint(SeleneEndpoint):
self._generate_tokens()
self._set_token_cookies()
return '', HTTPStatus.NO_CONTENT
return "", HTTPStatus.NO_CONTENT
def _authenticate_credentials(self):
"""Compare credentials in request to credentials in database.
:raises AuthenticationError when no match found on database
"""
basic_credentials = self.request.headers['authorization']
basic_credentials = self.request.headers["authorization"]
binary_credentials = a2b_base64(basic_credentials[6:])
email_address, password = binary_credentials.decode().split(':||:')
email_address, password = binary_credentials.decode().split(":||:")
acct_repository = AccountRepository(self.db)
self.account = acct_repository.get_account_from_credentials(
email_address,
password
email_address, password
)
if self.account is None:
raise AuthenticationError('provided credentials not found')
raise AuthenticationError("provided credentials not found")
self.access_token.account_id = self.account.id
self.refresh_token.account_id = self.account.id

View File

@ -26,8 +26,7 @@ from selene.util.auth import get_github_authentication_token
class GithubTokenEndpoint(SeleneEndpoint):
def get(self):
token = get_github_authentication_token(
self.request.args['code'],
self.request.args['state']
self.request.args["code"], self.request.args["state"]
)
return dict(token=token), HTTPStatus.OK

View File

@ -26,11 +26,11 @@ from selene.data.account import AccountRepository
class PasswordChangeEndpoint(SeleneEndpoint):
def put(self):
account_id = self.request.json['accountId']
coded_password = self.request.json['password']
account_id = self.request.json["accountId"]
coded_password = self.request.json["password"]
binary_password = a2b_base64(coded_password)
password = binary_password.decode()
acct_repository = AccountRepository(self.db)
acct_repository.change_password(account_id, password)
return '', HTTPStatus.NO_CONTENT
return "", HTTPStatus.NO_CONTENT

View File

@ -37,44 +37,40 @@ class PasswordResetEndpoint(SeleneEndpoint):
reset_token = self._generate_reset_token()
self._send_reset_email(reset_token)
return '', HTTPStatus.OK
return "", HTTPStatus.OK
def _get_account_from_email(self):
acct_repository = AccountRepository(self.db)
self.account = acct_repository.get_account_by_email(
self.request.json['emailAddress']
self.request.json["emailAddress"]
)
def _generate_reset_token(self):
reset_token = AuthenticationToken(
self.config['RESET_SECRET'],
ONE_HOUR
)
reset_token = AuthenticationToken(self.config["RESET_SECRET"], ONE_HOUR)
reset_token.generate(self.account.id)
return reset_token.jwt
def _send_reset_email(self, reset_token):
url = '{base_url}/change-password?token={reset_token}'.format(
base_url=os.environ['SSO_BASE_URL'],
reset_token=reset_token
url = "{base_url}/change-password?token={reset_token}".format(
base_url=os.environ["SSO_BASE_URL"], reset_token=reset_token
)
email = EmailMessage(
recipient=self.request.json['emailAddress'],
sender='Mycroft AI<no-reply@mycroft.ai>',
subject='Password Reset Request',
template_file_name='reset_password.html',
template_variables=dict(reset_password_url=url)
recipient=self.request.json["emailAddress"],
sender="Mycroft AI<no-reply@mycroft.ai>",
subject="Password Reset Request",
template_file_name="reset_password.html",
template_variables=dict(reset_password_url=url),
)
mailer = SeleneMailer(email)
mailer.send()
def _send_account_not_found_email(self):
email = EmailMessage(
recipient=self.request.json['emailAddress'],
sender='Mycroft AI<no-reply@mycroft.ai>',
subject='Password Reset Request',
template_file_name='account_not_found.html'
recipient=self.request.json["emailAddress"],
sender="Mycroft AI<no-reply@mycroft.ai>",
subject="Password Reset Request",
template_file_name="account_not_found.html",
)
mailer = SeleneMailer(email)
mailer.send()

View File

@ -25,16 +25,16 @@ from selene.data.account import AccountRepository
from selene.util.auth import (
get_facebook_account_email,
get_github_account_email,
get_google_account_email
get_google_account_email,
)
class ValidateEmailEndpoint(SeleneEndpoint):
def get(self):
return_data = dict(accountExists=False, noFederatedEmail=False)
if self.request.args['token']:
if self.request.args["token"]:
email_address = self._get_email_address()
if self.request.args['platform'] != 'Internal' and not email_address:
if self.request.args["platform"] != "Internal" and not email_address:
return_data.update(noFederatedEmail=True)
account_repository = AccountRepository(self.db)
account = account_repository.get_account_by_email(email_address)
@ -44,20 +44,14 @@ class ValidateEmailEndpoint(SeleneEndpoint):
return return_data, HTTPStatus.OK
def _get_email_address(self):
if self.request.args['platform'] == 'Google':
email_address = get_google_account_email(
self.request.args['token']
)
elif self.request.args['platform'] == 'Facebook':
email_address = get_facebook_account_email(
self.request.args['token']
)
elif self.request.args['platform'] == 'GitHub':
email_address = get_github_account_email(
self.request.args['token']
)
if self.request.args["platform"] == "Google":
email_address = get_google_account_email(self.request.args["token"])
elif self.request.args["platform"] == "Facebook":
email_address = get_facebook_account_email(self.request.args["token"])
elif self.request.args["platform"] == "GitHub":
email_address = get_github_account_email(self.request.args["token"])
else:
coded_email = self.request.args['token']
coded_email = self.request.args["token"]
email_address = a2b_base64(coded_email).decode()
return email_address

View File

@ -28,15 +28,12 @@ class ValidateTokenEndpoint(SeleneEndpoint):
return response_data, HTTPStatus.OK
def _validate_token(self):
auth_token = AuthenticationToken(
self.config['RESET_SECRET'],
duration=0
)
auth_token.jwt = self.request.json['token']
auth_token = AuthenticationToken(self.config["RESET_SECRET"], duration=0)
auth_token.jwt = self.request.json["token"]
auth_token.validate()
return dict(
account_id=auth_token.account_id,
token_expired=auth_token.is_expired,
token_invalid=not auth_token.is_valid
token_invalid=not auth_token.is_valid,
)

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -30,13 +30,14 @@ from datetime import date, timedelta
import schedule
from selene.util.log import configure_logger
from selene.util.log import configure_selene_logger
_log = configure_logger('selene_job_scheduler')
_log = configure_selene_logger("job_scheduler")
class JobRunner(object):
"""Build the command to run a batch job and run it via subprocess."""
def __init__(self, script_name: str):
self.script_name = script_name
self.job_args: str = None
@ -55,17 +56,14 @@ class JobRunner(object):
date argument only needs to be specified when it is not current date.
"""
if self.job_args is None:
self.job_args = ''
date_arg = ' --date ' + str(self.job_date)
self.job_args = ""
date_arg = " --date " + str(self.job_date)
self.job_args += date_arg
def _build_command(self):
"""Build the command to run the script."""
command = ['pipenv', 'run', 'python']
script_path = os.path.join(
os.environ['SELENE_SCRIPT_DIR'],
self.script_name
)
command = ["pipenv", "run", "python"]
script_path = os.path.join(os.environ["SELENE_SCRIPT_DIR"], self.script_name)
command.append(script_path)
if self.job_args is not None:
command.extend(self.job_args.split())
@ -78,31 +76,31 @@ class JobRunner(object):
result = subprocess.run(command, capture_output=True)
if result.returncode:
_log.error(
'Job {job_name} failed\n'
'\tSTDOUT - {stdout}'
'\tSTDERR - {stderr}'.format(
"Job {job_name} failed\n"
"\tSTDOUT - {stdout}"
"\tSTDERR - {stderr}".format(
job_name=self.script_name[:-3],
stdout=result.stdout.decode(),
stderr=result.stderr.decode()
stderr=result.stderr.decode(),
)
)
else:
log_msg = 'Job {job_name} completed successfully'
log_msg = "Job {job_name} completed successfully"
_log.info(log_msg.format(job_name=self.script_name[:-3]))
def test_scheduler():
"""Run in non-production environments to test scheduler functionality."""
job_runner = JobRunner('test_scheduler.py')
job_runner = JobRunner("test_scheduler.py")
job_runner.job_date = date.today() - timedelta(days=1)
job_runner.job_args = '--arg-with-value test --arg-no-value'
job_runner.job_args = "--arg-with-value test --arg-no-value"
job_runner.run_job()
def load_skills(version):
"""Load the json file from the mycroft-skills-data repository to the DB"""
job_runner = JobRunner('load_skill_display_data.py')
job_runner.job_args = '--core-version {}'.format(version)
job_runner = JobRunner("load_skill_display_data.py")
job_runner.job_args = "--core-version {}".format(version)
job_runner.job_date = date.today() - timedelta(days=1)
job_runner.run_job()
@ -112,7 +110,7 @@ def parse_core_metrics():
Build a de-normalized table that will make latency research easier.
"""
job_runner = JobRunner('parse_core_metrics.py')
job_runner = JobRunner("parse_core_metrics.py")
job_runner.job_date = date.today() - timedelta(days=1)
job_runner.run_job()
@ -123,7 +121,7 @@ def partition_api_metrics():
Build a partition on the metric.api_history table for yesterday's date.
Copy yesterday's metric.api table rows to the partition.
"""
job_runner = JobRunner('partition_api_metrics.py')
job_runner = JobRunner("partition_api_metrics.py")
job_runner.job_date = date.today() - timedelta(days=1)
job_runner.run_job()
@ -135,22 +133,22 @@ def update_device_last_contact():
to associate the time of the call with the device. Dump the contents of
the Redis data to the device.device table on the Postgres database.
"""
job_runner = JobRunner('update_device_last_contact.py')
job_runner = JobRunner("update_device_last_contact.py")
job_runner.run_job()
# Define the schedule
if os.environ['SELENE_ENVIRONMENT'] != 'prod':
if os.environ["SELENE_ENVIRONMENT"] != "prod":
schedule.every(5).minutes.do(test_scheduler)
schedule.every().day.at('00:00').do(partition_api_metrics)
schedule.every().day.at('00:05').do(update_device_last_contact)
schedule.every().day.at('00:10').do(parse_core_metrics)
schedule.every().day.at('00:15').do(load_skills, version='19.02')
schedule.every().day.at('00:20').do(load_skills, version='19.08')
schedule.every().day.at('00:25').do(load_skills, version='20.02')
schedule.every().day.at('00:25').do(load_skills, version='20.08')
schedule.every().day.at('00:25').do(load_skills, version='21.02')
schedule.every().day.at("00:00").do(partition_api_metrics)
schedule.every().day.at("00:05").do(update_device_last_contact)
schedule.every().day.at("00:10").do(parse_core_metrics)
schedule.every().day.at("00:15").do(load_skills, version="19.02")
schedule.every().day.at("00:20").do(load_skills, version="19.08")
schedule.every().day.at("00:25").do(load_skills, version="20.02")
schedule.every().day.at("00:25").do(load_skills, version="20.08")
schedule.every().day.at("00:25").do(load_skills, version="21.02")
# Run the schedule
while True:

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -30,12 +30,12 @@ from selene.util.db import DatabaseConnectionConfig
from selene.util.email import EmailMessage, SeleneMailer
mycroft_db = DatabaseConnectionConfig(
host=environ['DB_HOST'],
db_name=environ['DB_NAME'],
user=environ['DB_USER'],
password=environ['DB_PASSWORD'],
port=environ['DB_PORT'],
sslmode=environ['DB_SSL_MODE']
host=environ["DB_HOST"],
db_name=environ["DB_NAME"],
user=environ["DB_USER"],
password=environ["DB_PASSWORD"],
port=int(environ["DB_PORT"]),
sslmode=environ["DB_SSL_MODE"],
)
@ -43,16 +43,16 @@ class DailyReport(SeleneScript):
def __init__(self):
super(DailyReport, self).__init__(__file__)
self._arg_parser.add_argument(
'--run-mode',
help='If the script should run as a job or just once',
choices=['job', 'once'],
"--run-mode",
help="If the script should run as a job or just once",
choices=["job", "once"],
type=str,
default='job'
default="job",
)
def _run(self):
if self.args.run_mode == 'job':
schedule.every().day.at('00:00').do(self._build_report)
if self.args.run_mode == "job":
schedule.every().day.at("00:00").do(self._build_report)
while True:
schedule.run_pending()
time.sleep(1)
@ -65,11 +65,11 @@ class DailyReport(SeleneScript):
user_metrics = AccountRepository(self.db).daily_report(date)
email = EmailMessage(
sender='reports@mycroft.ai',
recipient=os.environ['REPORT_RECIPIENT'],
subject='Mycroft Daily Report - {}'.format(date.strftime('%Y-%m-%d')),
template_file_name='metrics.html',
template_variables=dict(user_metrics=user_metrics)
sender="reports@mycroft.ai",
recipient=os.environ["REPORT_RECIPIENT"],
subject="Mycroft Daily Report - {}".format(date.strftime("%Y-%m-%d")),
template_file_name="metrics.html",
template_variables=dict(user_metrics=user_metrics),
)
mailer = SeleneMailer(email)

View File

@ -30,17 +30,13 @@ import json
from os import environ
from selene.batch import SeleneScript
from selene.data.skill import (
SkillDisplay,
SkillDisplayRepository,
SkillRepository
)
from selene.data.skill import SkillDisplay, SkillDisplayRepository, SkillRepository
from selene.util.github import download_repository_file, log_into_github
GITHUB_USER = environ['GITHUB_USER']
GITHUB_PASSWORD = environ['GITHUB_PASSWORD']
SKILL_DATA_GITHUB_REPO = 'mycroft-skills-data'
SKILL_DATA_FILE_NAME = 'skill-metadata.json'
GITHUB_USER = environ["GITHUB_USER"]
GITHUB_PASSWORD = environ["GITHUB_PASSWORD"]
SKILL_DATA_GITHUB_REPO = "mycroft-skills-data"
SKILL_DATA_FILE_NAME = "skill-metadata.json"
class SkillDisplayUpdater(SeleneScript):
@ -52,16 +48,15 @@ class SkillDisplayUpdater(SeleneScript):
super(SkillDisplayUpdater, self)._define_args()
self._arg_parser.add_argument(
"--core-version",
help='Version of Mycroft Core related to skill display data',
help="Version of Mycroft Core related to skill display data",
required=True,
type=str
type=str,
)
def _run(self):
"""Make it so."""
self.log.info(
"Updating skill display data for core version " +
self.args.core_version
"Updating skill display data for core version " + self.args.core_version
)
self._get_skill_display_data()
self._update_skill_display_table()
@ -73,7 +68,7 @@ class SkillDisplayUpdater(SeleneScript):
github_api,
SKILL_DATA_GITHUB_REPO,
self.args.core_version,
SKILL_DATA_FILE_NAME
SKILL_DATA_FILE_NAME,
)
self.skill_display_data = json.loads(file_contents)
@ -83,20 +78,18 @@ class SkillDisplayUpdater(SeleneScript):
display_repository = SkillDisplayRepository(self.db)
for skill_name, skill_metadata in self.skill_display_data.items():
skill_count += 1
skill_id = skill_repository.ensure_skill_exists(
skill_metadata['skill_gid']
)
skill_id = skill_repository.ensure_skill_exists(skill_metadata["skill_gid"])
# add the skill display row
display_data = SkillDisplay(
skill_id=skill_id,
core_version=self.args.core_version,
display_data=json.dumps(skill_metadata)
display_data=json.dumps(skill_metadata),
)
display_repository.upsert(display_data)
self.log.info("updated {} skills".format(skill_count))
if __name__ == '__main__':
if __name__ == "__main__":
SkillDisplayUpdater().run()

View File

@ -32,7 +32,7 @@ from decimal import Decimal
from selene.batch import SeleneScript
from selene.data.metric import CoreMetricRepository, CoreInteraction
SKILL_HANDLERS_TO_SKIP = ('reset', 'notify', 'prime', 'stop_laugh')
SKILL_HANDLERS_TO_SKIP = ("reset", "notify", "prime", "stop_laugh")
class CoreMetricsParser(SeleneScript):
@ -47,7 +47,7 @@ class CoreMetricsParser(SeleneScript):
def _run(self):
last_interaction_id = None
for metric in self.core_metric_repo.get_metrics_by_date(self.args.date):
if metric.metric_value['id'] != last_interaction_id:
if metric.metric_value["id"] != last_interaction_id:
self._add_interaction_to_db()
self._start_new_interaction(metric)
last_interaction_id = self.interaction.core_id
@ -57,48 +57,44 @@ class CoreMetricsParser(SeleneScript):
def _start_new_interaction(self, metric):
"""Initialize the interaction object"""
self.interaction = CoreInteraction(
core_id=metric.metric_value['id'],
core_id=metric.metric_value["id"],
device_id=metric.device_id,
start_ts=datetime.utcfromtimestamp(
metric.metric_value['start_time']
),
start_ts=datetime.utcfromtimestamp(metric.metric_value["start_time"]),
)
self.stt_start_ts = None
self.playback_start_ts = None
def _add_metric_to_interaction(self, metric_value):
"""Combine all the steps of an interaction into a single record"""
duration = Decimal(str(metric_value['time']))
duration = duration.quantize(Decimal('0.000001'))
if metric_value['system'] == 'stt':
self.interaction.stt_engine = metric_value['stt']
self.interaction.stt_transcription = metric_value['transcription']
duration = Decimal(str(metric_value["time"]))
duration = duration.quantize(Decimal("0.000001"))
if metric_value["system"] == "stt":
self.interaction.stt_engine = metric_value["stt"]
self.interaction.stt_transcription = metric_value["transcription"]
self.interaction.stt_duration = duration
self.stt_start_ts = metric_value['start_time']
elif metric_value['system'] == 'intent_service':
self.interaction.intent_type = metric_value['intent_type']
self.stt_start_ts = metric_value["start_time"]
elif metric_value["system"] == "intent_service":
self.interaction.intent_type = metric_value["intent_type"]
self.interaction.intent_duration = duration
elif metric_value['system'] == 'fallback_handler':
elif metric_value["system"] == "fallback_handler":
self.interaction.fallback_handler_duration = duration
elif metric_value['system'] == 'skill_handler':
if metric_value['handler'] not in SKILL_HANDLERS_TO_SKIP:
self.interaction.skill_handler = metric_value['handler']
elif metric_value["system"] == "skill_handler":
if metric_value["handler"] not in SKILL_HANDLERS_TO_SKIP:
self.interaction.skill_handler = metric_value["handler"]
self.interaction.skill_duration = duration
elif metric_value['system'] == 'speech':
self.interaction.tts_engine = metric_value['tts']
self.interaction.tts_utterance = metric_value['utterance']
elif metric_value["system"] == "speech":
self.interaction.tts_engine = metric_value["tts"]
self.interaction.tts_utterance = metric_value["utterance"]
self.interaction.tts_duration = duration
elif metric_value['system'] == 'speech_playback':
elif metric_value["system"] == "speech_playback":
self.interaction.speech_playback_duration = duration
self.playback_start_ts = metric_value['start_time']
self.playback_start_ts = metric_value["start_time"]
# The user-experienced latency is the time between when the user
# finishes speaking their intent and when the device provides a voice
# response.
if self.stt_start_ts is not None and self.playback_start_ts is not None:
self.interaction.user_latency = (
self.playback_start_ts - self.stt_start_ts
)
self.interaction.user_latency = self.playback_start_ts - self.stt_start_ts
def _add_interaction_to_db(self):
if self.interaction is not None:
@ -107,5 +103,5 @@ class CoreMetricsParser(SeleneScript):
self.core_metric_repo.add_interaction(self.interaction)
if __name__ == '__main__':
if __name__ == "__main__":
CoreMetricsParser().run()

View File

@ -40,5 +40,5 @@ class PartitionApiMetrics(SeleneScript):
api_metrics_repo.remove_by_date(self.args.date)
if __name__ == '__main__':
if __name__ == "__main__":
PartitionApiMetrics().run()

View File

@ -42,24 +42,24 @@ class TestScheduler(SeleneScript):
super(TestScheduler, self)._define_args()
self._arg_parser.add_argument(
"--arg-with-value",
help='Argument to test passing a value with an argument',
help="Argument to test passing a value with an argument",
required=True,
type=str
type=str,
)
self._arg_parser.add_argument(
"--arg-no-value",
help='Argument to test passing a value with an argument',
action="store_true"
help="Argument to test passing a value with an argument",
action="store_true",
)
def _run(self):
self.log.info('Running the scheduler test job')
self.log.info("Running the scheduler test job")
assert self.args.arg_no_value
assert self.args.arg_with_value == 'test'
assert self.args.arg_with_value == "test"
# Tests the logic that overrides the default date in the scheduler.
assert self.args.date == date.today() - timedelta(days=1)
if __name__ == '__main__':
if __name__ == "__main__":
TestScheduler().run()

View File

@ -47,21 +47,18 @@ class UpdateDeviceLastContact(SeleneScript):
devices_updated += 1
device_repo.update_last_contact_ts(device.id, last_contact_ts)
self.log.info(str(devices_updated) + ' devices were active today')
self.log.info(str(devices_updated) + " devices were active today")
def _get_ts_from_cache(self, device_id):
last_contact_ts = None
cache_key = DEVICE_LAST_CONTACT_KEY.format(device_id=device_id)
value = self.cache.get(cache_key)
if value is not None:
last_contact_ts = datetime.strptime(
value.decode(),
'%Y-%m-%d %H:%M:%S.%f'
)
last_contact_ts = datetime.strptime(value.decode(), "%Y-%m-%d %H:%M:%S.%f")
self.cache.delete(cache_key)
return last_contact_ts
if __name__ == '__main__':
if __name__ == "__main__":
UpdateDeviceLastContact().run()

View File

@ -108,7 +108,7 @@ class PostgresDB(object):
sslmode=db_ssl_mode,
)
self.db.autocommit = True
self.db.set_client_encoding('UTF8')
self.db.set_client_encoding("UTF8")
def close_db(self):
self.db.close()

File diff suppressed because it is too large Load Diff

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -48,32 +48,33 @@ class APIConfigError(Exception):
class BaseConfig(object):
"""Base configuration."""
ACCESS_SECRET = os.environ['JWT_ACCESS_SECRET']
ACCESS_SECRET = os.environ["JWT_ACCESS_SECRET"]
DB_CONNECTION_POOL = None
DEBUG = False
ENV = os.environ['SELENE_ENVIRONMENT']
REFRESH_SECRET = os.environ['JWT_REFRESH_SECRET']
ENV = os.environ["SELENE_ENVIRONMENT"]
REFRESH_SECRET = os.environ["JWT_REFRESH_SECRET"]
DB_CONNECTION_CONFIG = DatabaseConnectionConfig(
host=os.environ['DB_HOST'],
db_name=os.environ['DB_NAME'],
password=os.environ['DB_PASSWORD'],
port=os.environ.get('DB_PORT', 5432),
user=os.environ['DB_USER'],
sslmode=os.environ.get('DB_SSLMODE')
host=os.environ["DB_HOST"],
db_name=os.environ["DB_NAME"],
password=os.environ["DB_PASSWORD"],
port=os.environ.get("DB_PORT", 5432),
user=os.environ["DB_USER"],
sslmode=os.environ.get("DB_SSLMODE"),
)
class DevelopmentConfig(BaseConfig):
DEBUG = True
DOMAIN = '.mycroft.test'
DOMAIN = ".mycroft.test"
class TestConfig(BaseConfig):
DOMAIN = '.mycroft-test.net'
DOMAIN = ".mycroft-test.net"
class ProdConfig(BaseConfig):
DOMAIN = '.mycroft.ai'
DOMAIN = ".mycroft.ai"
def get_base_config():
@ -81,16 +82,12 @@ def get_base_config():
:return: an object containing the configs for the API.
"""
environment_configs = dict(
dev=DevelopmentConfig,
test=TestConfig,
prod=ProdConfig
)
environment_configs = dict(dev=DevelopmentConfig, test=TestConfig, prod=ProdConfig)
try:
environment_name = os.environ['SELENE_ENVIRONMENT']
environment_name = os.environ["SELENE_ENVIRONMENT"]
except KeyError:
raise APIConfigError('the SELENE_ENVIRONMENT variable is not set')
raise APIConfigError("the SELENE_ENVIRONMENT variable is not set")
try:
app_config = environment_configs[environment_name]

View File

@ -31,7 +31,7 @@ from selene.util.cache import DEVICE_LAST_CONTACT_KEY
from selene.util.db import connect_to_db
from selene.util.exceptions import NotModifiedException
selene_api = Blueprint('selene_api', __name__)
selene_api = Blueprint("selene_api", __name__)
@selene_api.app_errorhandler(DataError)
@ -46,7 +46,7 @@ def handle_data_error(error):
@selene_api.app_errorhandler(NotModifiedException)
def handle_not_modified(_):
return '', HTTPStatus.NOT_MODIFIED
return "", HTTPStatus.NOT_MODIFIED
@selene_api.before_app_request
@ -67,26 +67,26 @@ def add_api_metric(http_status):
api = None
# We are not logging metric for the public API until after the socket
# implementation to avoid putting millions of rows a day on the table
for api_name in ('account', 'sso', 'market', 'public'):
for api_name in ("account", "sso", "market", "public"):
if api_name in current_app.name:
api = api_name
if api is not None and int(http_status) != 304:
if 'db' not in global_context:
if "db" not in global_context:
global_context.db = connect_to_db(
current_app.config['DB_CONNECTION_CONFIG']
current_app.config["DB_CONNECTION_CONFIG"]
)
if 'account_id' in global_context:
if "account_id" in global_context:
account_id = global_context.account_id
else:
account_id = None
if 'device_id' in global_context:
if "device_id" in global_context:
device_id = global_context.device_id
else:
device_id = None
duration = (datetime.utcnow() - global_context.start_ts)
duration = datetime.utcnow() - global_context.start_ts
api_metric = ApiMetric(
access_ts=datetime.utcnow(),
account_id=account_id,
@ -95,7 +95,7 @@ def add_api_metric(http_status):
duration=Decimal(str(duration.total_seconds())),
http_method=request.method,
http_status=int(http_status),
url=global_context.url
url=global_context.url,
)
metric_repository = ApiMetricsRepository(global_context.db)
metric_repository.add(api_metric)
@ -107,6 +107,6 @@ def update_device_last_contact():
This should only be done on public API calls because we are tracking
device activity only.
"""
if 'public' in current_app.name and 'device_id' in global_context:
if "public" in current_app.name and "device_id" in global_context:
key = DEVICE_LAST_CONTACT_KEY.format(device_id=global_context.device_id)
global_context.cache.set(key, str(datetime.utcnow()))

View File

@ -29,20 +29,20 @@ from ..base_endpoint import SeleneEndpoint
class AgreementsEndpoint(SeleneEndpoint):
authentication_required = False
agreement_types = {
'terms-of-use': 'Terms of Use',
'privacy-policy': 'Privacy Policy'
"terms-of-use": "Terms of Use",
"privacy-policy": "Privacy Policy",
}
def get(self, agreement_type):
"""Process HTTP GET request for an agreement."""
db = connect_to_db(self.config['DB_CONNECTION_CONFIG'])
db = connect_to_db(self.config["DB_CONNECTION_CONFIG"])
agreement_repository = AgreementRepository(db)
agreement = agreement_repository.get_active_for_type(
self.agreement_types[agreement_type]
)
if agreement is not None:
agreement = asdict(agreement)
del(agreement['effective_date'])
del agreement["effective_date"]
self.response = agreement, HTTPStatus.OK
return self.response

View File

@ -24,19 +24,19 @@ from selene.data.device import DeviceRepository
from selene.util.cache import SeleneCache, DEVICE_SKILL_ETAG_KEY
from selene.util.db import connect_to_db
ETAG_REQUEST_HEADER_KEY = 'If-None-Match'
ETAG_REQUEST_HEADER_KEY = "If-None-Match"
def device_etag_key(device_id: str):
return 'device.etag:{uuid}'.format(uuid=device_id)
return "device.etag:{uuid}".format(uuid=device_id)
def device_setting_etag_key(device_id: str):
return 'device.setting.etag:{uuid}'.format(uuid=device_id)
return "device.setting.etag:{uuid}".format(uuid=device_id)
def device_location_etag_key(device_id: str):
return 'device.location.etag:{uuid}'.format(uuid=device_id)
return "device.location.etag:{uuid}".format(uuid=device_id)
class ETagManager(object):
@ -46,7 +46,7 @@ class ETagManager(object):
def __init__(self, cache: SeleneCache, config: dict):
self.cache: SeleneCache = cache
self.db_connection_config = config['DB_CONNECTION_CONFIG']
self.db_connection_config = config["DB_CONNECTION_CONFIG"]
def get(self, key: str) -> str:
"""Generate a etag with 32 random chars and store it into a given key
@ -54,14 +54,14 @@ class ETagManager(object):
:return etag"""
etag = self.cache.get(key)
if etag is None:
etag = ''.join(random.choice(self.etag_chars) for _ in range(32))
etag = "".join(random.choice(self.etag_chars) for _ in range(32))
self.cache.set(key, etag)
return etag
def expire(self, key):
"""Expires an existent etag
:param key: key where the etag is stored"""
etag = ''.join(random.choice(self.etag_chars) for _ in range(32))
etag = "".join(random.choice(self.etag_chars) for _ in range(32))
self.cache.set(key, etag)
def expire_device_etag_by_device_id(self, device_id: str):

View File

@ -22,7 +22,7 @@ import re
from flask import jsonify, Response
snake_pattern = re.compile(r'_([a-z])')
snake_pattern = re.compile(r"_([a-z])")
def snake_to_camel(name):

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -18,12 +18,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from .entity.account import Account, AccountAgreement, AccountMembership
from .entity.agreement import (
Agreement,
PRIVACY_POLICY,
TERMS_OF_USE,
OPEN_DATASET
)
from .entity.agreement import Agreement, PRIVACY_POLICY, TERMS_OF_USE, OPEN_DATASET
from .entity.membership import Membership
from .entity.skill import AccountSkill
from .repository.account import AccountRepository
@ -31,6 +26,6 @@ from .repository.agreement import AgreementRepository
from .repository.membership import (
MembershipRepository,
MONTHLY_MEMBERSHIP,
YEARLY_MEMBERSHIP
YEARLY_MEMBERSHIP,
)
from .repository.skill import AccountSkillRepository

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -20,9 +20,9 @@
from dataclasses import dataclass
from datetime import date
TERMS_OF_USE = 'Terms of Use'
PRIVACY_POLICY = 'Privacy Policy'
OPEN_DATASET = 'Open Dataset'
TERMS_OF_USE = "Terms of Use"
PRIVACY_POLICY = "Privacy Policy"
OPEN_DATASET = "Open Dataset"
@dataclass

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -21,8 +21,8 @@ from selene.data.account import AccountMembership
from ..entity.membership import Membership
from ...repository_base import RepositoryBase
MONTHLY_MEMBERSHIP = 'Monthly Membership'
YEARLY_MEMBERSHIP = 'Yearly Membership'
MONTHLY_MEMBERSHIP = "Monthly Membership"
YEARLY_MEMBERSHIP = "Yearly Membership"
class MembershipRepository(RepositoryBase):
@ -30,37 +30,34 @@ class MembershipRepository(RepositoryBase):
super(MembershipRepository, self).__init__(db, __file__)
def get_membership_types(self):
db_request = self._build_db_request(
sql_file_name='get_membership_types.sql'
)
db_request = self._build_db_request(sql_file_name="get_membership_types.sql")
db_result = self.cursor.select_all(db_request)
return [Membership(**row) for row in db_result]
def get_membership_by_type(self, membership_type: str):
db_request = self._build_db_request(
sql_file_name='get_membership_by_type.sql',
args=dict(type=membership_type)
sql_file_name="get_membership_by_type.sql", args=dict(type=membership_type)
)
db_result = self.cursor.select_one(db_request)
return Membership(**db_result)
def add(self, membership: Membership):
db_request = self._build_db_request(
'add_membership.sql',
"add_membership.sql",
args=dict(
membership_type=membership.type,
rate=membership.rate,
rate_period=membership.rate_period
)
rate_period=membership.rate_period,
),
)
result = self.cursor.insert_returning(db_request)
return result['id']
return result["id"]
def remove(self, membership: Membership):
db_request = self._build_db_request(
sql_file_name='delete_membership.sql',
args=dict(membership_id=membership.id)
sql_file_name="delete_membership.sql",
args=dict(membership_id=membership.id),
)
self.cursor.delete(db_request)

View File

@ -30,8 +30,8 @@ class AccountSkillRepository(RepositoryBase):
def get_skills_for_account(self) -> List[AccountSkill]:
db_request = self._build_db_request(
sql_file_name='get_account_skills.sql',
args=dict(account_id=self.account_id)
sql_file_name="get_account_skills.sql",
args=dict(account_id=self.account_id),
)
db_result = self.cursor.select_all(db_request)

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -17,4 +17,4 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from .device import Device
from .device import Device

View File

@ -26,7 +26,7 @@ from selene.data.skill import SettingsDisplay
from ..entity.device_skill import (
AccountSkillSettings,
DeviceSkillSettings,
ManifestSkill
ManifestSkill,
)
from ...repository_base import RepositoryBase
@ -36,19 +36,19 @@ class DeviceSkillRepository(RepositoryBase):
super(DeviceSkillRepository, self).__init__(db, __file__)
def get_skill_settings_for_account(
self, account_id: str, skill_id: str
self, account_id: str, skill_id: str
) -> List[AccountSkillSettings]:
return self._select_all_into_dataclass(
AccountSkillSettings,
sql_file_name='get_skill_settings_for_account.sql',
args=dict(account_id=account_id, skill_id=skill_id)
sql_file_name="get_skill_settings_for_account.sql",
args=dict(account_id=account_id, skill_id=skill_id),
)
def get_skill_settings_for_device(self, device_id, skill_id=None):
device_skills = self._select_all_into_dataclass(
DeviceSkillSettings,
sql_file_name='get_skill_settings_for_device.sql',
args=dict(device_id=device_id)
sql_file_name="get_skill_settings_for_device.sql",
args=dict(device_id=device_id),
)
if skill_id is None:
skill_settings = device_skills
@ -62,23 +62,21 @@ class DeviceSkillRepository(RepositoryBase):
return skill_settings
def update_skill_settings(
self, account_id: str, device_names: tuple, skill_name: str
self, account_id: str, device_names: tuple, skill_name: str
):
db_request = self._build_db_request(
sql_file_name='update_skill_settings.sql',
sql_file_name="update_skill_settings.sql",
args=dict(
account_id=account_id,
device_names=device_names,
skill_name=skill_name
)
account_id=account_id, device_names=device_names, skill_name=skill_name
),
)
self.cursor.update(db_request)
def upsert_device_skill_settings(
self,
device_ids: List[str],
settings_display: SettingsDisplay,
settings_values: dict,
self,
device_ids: List[str],
settings_display: SettingsDisplay,
settings_values: dict,
):
for device_id in device_ids:
if settings_values is None:
@ -86,13 +84,13 @@ class DeviceSkillRepository(RepositoryBase):
else:
db_settings_values = json.dumps(settings_values)
db_request = self._build_db_request(
sql_file_name='upsert_device_skill_settings.sql',
sql_file_name="upsert_device_skill_settings.sql",
args=dict(
device_id=device_id,
skill_id=settings_display.skill_id,
settings_values=db_settings_values,
settings_display_id=settings_display.id
)
settings_display_id=settings_display.id,
),
)
self.cursor.insert(db_request)
@ -103,76 +101,66 @@ class DeviceSkillRepository(RepositoryBase):
else:
db_settings_values = json.dumps(device_skill.settings_values)
db_request = self._build_db_request(
sql_file_name='update_device_skill_settings.sql',
sql_file_name="update_device_skill_settings.sql",
args=dict(
device_id=device_id,
skill_id=device_skill.skill_id,
settings_display_id=device_skill.settings_display_id,
settings_values=db_settings_values
)
settings_values=db_settings_values,
),
)
self.cursor.update(db_request)
def get_skill_manifest_for_device(
self, device_id: str
) -> List[ManifestSkill]:
def get_skill_manifest_for_device(self, device_id: str) -> List[ManifestSkill]:
return self._select_all_into_dataclass(
dataclass=ManifestSkill,
sql_file_name='get_device_skill_manifest.sql',
args=dict(device_id=device_id)
sql_file_name="get_device_skill_manifest.sql",
args=dict(device_id=device_id),
)
def get_skill_manifest_for_account(
self, account_id: str
) -> List[ManifestSkill]:
def get_skill_manifest_for_account(self, account_id: str) -> List[ManifestSkill]:
return self._select_all_into_dataclass(
dataclass=ManifestSkill,
sql_file_name='get_skill_manifest_for_account.sql',
args=dict(account_id=account_id)
sql_file_name="get_skill_manifest_for_account.sql",
args=dict(account_id=account_id),
)
def update_manifest_skill(self, manifest_skill: ManifestSkill):
db_request = self._build_db_request(
sql_file_name='update_skill_manifest.sql',
args=asdict(manifest_skill)
sql_file_name="update_skill_manifest.sql", args=asdict(manifest_skill)
)
self.cursor.update(db_request)
def add_manifest_skill(self, manifest_skill: ManifestSkill):
db_request = self._build_db_request(
sql_file_name='add_manifest_skill.sql',
args=asdict(manifest_skill)
sql_file_name="add_manifest_skill.sql", args=asdict(manifest_skill)
)
db_result = self.cursor.insert_returning(db_request)
return db_result['id']
return db_result["id"]
def remove_manifest_skill(self, manifest_skill: ManifestSkill):
db_request = self._build_db_request(
sql_file_name='remove_manifest_skill.sql',
sql_file_name="remove_manifest_skill.sql",
args=dict(
device_id=manifest_skill.device_id,
skill_gid=manifest_skill.skill_gid
)
device_id=manifest_skill.device_id, skill_gid=manifest_skill.skill_gid
),
)
self.cursor.delete(db_request)
def get_settings_display_usage(self, settings_display_id: str) -> int:
db_request = self._build_db_request(
sql_file_name='get_settings_display_usage.sql',
args=dict(settings_display_id=settings_display_id)
sql_file_name="get_settings_display_usage.sql",
args=dict(settings_display_id=settings_display_id),
)
db_result = self.cursor.select_one(db_request)
return db_result['usage']
return db_result["usage"]
def remove(self, device_id, skill_id):
db_request = self._build_db_request(
sql_file_name='delete_device_skill.sql',
args=dict(
device_id=device_id,
skill_id=skill_id
)
sql_file_name="delete_device_skill.sql",
args=dict(device_id=device_id, skill_id=skill_id),
)
self.cursor.delete(db_request)

View File

@ -28,8 +28,8 @@ class GeographyRepository(RepositoryBase):
def get_account_geographies(self):
db_request = self._build_db_request(
sql_file_name='get_account_geographies.sql',
args=dict(account_id=self.account_id)
sql_file_name="get_account_geographies.sql",
args=dict(account_id=self.account_id),
)
db_response = self.cursor.select_all(db_request)
@ -40,10 +40,10 @@ class GeographyRepository(RepositoryBase):
acct_geographies = self.get_account_geographies()
for acct_geography in acct_geographies:
match = (
acct_geography.city == geography.city and
acct_geography.country == geography.country and
acct_geography.region == geography.region and
acct_geography.time_zone == geography.time_zone
acct_geography.city == geography.city
and acct_geography.country == geography.country
and acct_geography.region == geography.region
and acct_geography.time_zone == geography.time_zone
)
if match:
geography_id = acct_geography.id
@ -53,22 +53,22 @@ class GeographyRepository(RepositoryBase):
def add(self, geography: Geography):
db_request = self._build_db_request(
sql_file_name='add_geography.sql',
sql_file_name="add_geography.sql",
args=dict(
account_id=self.account_id,
city=geography.city,
country=geography.country,
region=geography.region,
timezone=geography.time_zone
)
timezone=geography.time_zone,
),
)
db_result = self.cursor.insert_returning(db_request)
return db_result['id']
return db_result["id"]
def get_location_by_device_id(self, device_id):
db_request = self._build_db_request(
sql_file_name='get_location_by_device_id.sql',
args=dict(device_id=device_id)
sql_file_name="get_location_by_device_id.sql",
args=dict(device_id=device_id),
)
return self.cursor.select_one(db_request)

View File

@ -17,7 +17,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from dataclasses import asdict
from dataclasses import asdict
from ..entity.preference import AccountPreferences
from ...repository_base import RepositoryBase
@ -30,8 +30,8 @@ class PreferenceRepository(RepositoryBase):
def get_account_preferences(self) -> AccountPreferences:
db_request = self._build_db_request(
sql_file_name='get_account_preferences.sql',
args=dict(account_id=self.account_id)
sql_file_name="get_account_preferences.sql",
args=dict(account_id=self.account_id),
)
db_result = self.cursor.select_one(db_request)
@ -46,7 +46,6 @@ class PreferenceRepository(RepositoryBase):
db_request_args = dict(account_id=self.account_id)
db_request_args.update(asdict(preferences))
db_request = self._build_db_request(
sql_file_name='upsert_preferences.sql',
args=db_request_args
sql_file_name="upsert_preferences.sql", args=db_request_args
)
self.cursor.insert(db_request)

View File

@ -21,7 +21,7 @@ from os import path
from selene.util.db import get_sql_from_file, Cursor, DatabaseRequest
SQL_DIR = path.join(path.dirname(__file__), 'sql')
SQL_DIR = path.join(path.dirname(__file__), "sql")
class SettingRepository(object):
@ -30,36 +30,38 @@ class SettingRepository(object):
def get_device_settings_by_device_id(self, device_id):
query = DatabaseRequest(
sql=get_sql_from_file(path.join(SQL_DIR, 'get_device_settings_by_device_id.sql')),
args=dict(device_id=device_id)
sql=get_sql_from_file(
path.join(SQL_DIR, "get_device_settings_by_device_id.sql")
),
args=dict(device_id=device_id),
)
return self.cursor.select_one(query)
def convert_text_to_speech_setting(self, setting_name, engine) -> (str, str):
"""Convert the selene representation of TTS into the tartarus representation, for backward compatibility
with the API v1"""
if engine == 'mimic':
if setting_name == 'trinity':
return 'mimic', 'trinity'
elif setting_name == 'kusal':
return 'mimic2', 'kusal'
if engine == "mimic":
if setting_name == "trinity":
return "mimic", "trinity"
elif setting_name == "kusal":
return "mimic2", "kusal"
else:
return 'mimic', 'ap'
return "mimic", "ap"
else:
return 'google', ''
return "google", ""
def _format_date_v1(self, date: str):
if date == 'DD/MM/YYYY':
result = 'DMY'
if date == "DD/MM/YYYY":
result = "DMY"
else:
result = 'MDY'
result = "MDY"
return result
def _format_time_v1(self, time: str):
if time == '24 Hour':
result = 'full'
if time == "24 Hour":
result = "full"
else:
result = 'half'
result = "half"
return result
def get_device_settings(self, device_id):
@ -69,22 +71,29 @@ class SettingRepository(object):
:return setting entity using the legacy format from the API v1"""
response = self.get_device_settings_by_device_id(device_id)
if response:
if response['listener_setting']['uuid'] is None:
del response['listener_setting']
tts_setting = response['tts_settings']
tts_setting = self.convert_text_to_speech_setting(tts_setting['setting_name'], tts_setting['engine'])
tts_setting = {'module': tts_setting[0], tts_setting[0]: {'voice': tts_setting[1]}}
response['tts_settings'] = tts_setting
response['date_format'] = self._format_date_v1(response['date_format'])
response['time_format'] = self._format_time_v1(response['time_format'])
response['system_unit'] = response['system_unit'].lower()
if response["listener_setting"]["uuid"] is None:
del response["listener_setting"]
tts_setting = response["tts_settings"]
tts_setting = self.convert_text_to_speech_setting(
tts_setting["setting_name"], tts_setting["engine"]
)
tts_setting = {
"module": tts_setting[0],
tts_setting[0]: {"voice": tts_setting[1]},
}
response["tts_settings"] = tts_setting
response["date_format"] = self._format_date_v1(response["date_format"])
response["time_format"] = self._format_time_v1(response["time_format"])
response["system_unit"] = response["system_unit"].lower()
open_dataset = self._get_open_dataset_agreement_by_device_id(device_id)
response['optIn'] = open_dataset is not None
response["optIn"] = open_dataset is not None
return response
def _get_open_dataset_agreement_by_device_id(self, device_id: str):
query = DatabaseRequest(
sql=get_sql_from_file(path.join(SQL_DIR, 'get_open_dataset_agreement_by_device_id.sql')),
args=dict(device_id=device_id)
sql=get_sql_from_file(
path.join(SQL_DIR, "get_open_dataset_agreement_by_device_id.sql")
),
args=dict(device_id=device_id),
)
return self.cursor.select_one(query)

View File

@ -26,7 +26,7 @@ class TextToSpeechRepository(RepositoryBase):
super(TextToSpeechRepository, self).__init__(db, __file__)
def get_voices(self):
db_request = self._build_db_request(sql_file_name='get_voices.sql')
db_request = self._build_db_request(sql_file_name="get_voices.sql")
db_result = self.cursor.select_all(db_request)
return [TextToSpeech(**row) for row in db_result]
@ -37,16 +37,16 @@ class TextToSpeechRepository(RepositoryBase):
:return wake word id
"""
db_request = self._build_db_request(
sql_file_name='add_text_to_speech.sql',
sql_file_name="add_text_to_speech.sql",
args=dict(
wake_word=text_to_speech.setting_name,
account_id=text_to_speech.display_name,
engine=text_to_speech.engine
)
engine=text_to_speech.engine,
),
)
result = self.cursor.insert_returning(db_request)
return result['id']
return result["id"]
# def remove(self, wake_word: WakeWord):
# """Delete a wake word from the wake_word table."""

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -27,8 +27,7 @@ class CityRepository(RepositoryBase):
def get_cities_by_region(self, region_id):
db_request = self._build_db_request(
sql_file_name='get_cities_by_region.sql',
args=dict(region_id=region_id)
sql_file_name="get_cities_by_region.sql", args=dict(region_id=region_id)
)
db_result = self.cursor.select_all(db_request)
@ -39,23 +38,22 @@ class CityRepository(RepositoryBase):
city_names = [nm.lower() for nm in possible_city_names]
return self._select_all_into_dataclass(
GeographicLocation,
sql_file_name='get_geographic_location_by_city.sql',
args=dict(possible_city_names=tuple(city_names))
sql_file_name="get_geographic_location_by_city.sql",
args=dict(possible_city_names=tuple(city_names)),
)
def get_biggest_city_in_region(self, region_name):
"""Return the geolocation of the most populous city in a region."""
return self._select_one_into_dataclass(
GeographicLocation,
sql_file_name='get_biggest_city_in_region.sql',
args=dict(region=region_name.lower())
sql_file_name="get_biggest_city_in_region.sql",
args=dict(region=region_name.lower()),
)
def get_biggest_city_in_country(self, country_name):
"""Return the geolocation of the most populous city in a country."""
return self._select_one_into_dataclass(
GeographicLocation,
sql_file_name='get_biggest_city_in_country.sql',
args=dict(country=country_name.lower())
sql_file_name="get_biggest_city_in_country.sql",
args=dict(country=country_name.lower()),
)

View File

@ -26,7 +26,7 @@ class CountryRepository(RepositoryBase):
super(CountryRepository, self).__init__(db, __file__)
def get_countries(self):
db_request = self._build_db_request(sql_file_name='get_countries.sql')
db_request = self._build_db_request(sql_file_name="get_countries.sql")
db_result = self.cursor.select_all(db_request)
return [Country(**row) for row in db_result]

View File

@ -27,8 +27,7 @@ class RegionRepository(RepositoryBase):
def get_regions_by_country(self, country_id):
db_request = self._build_db_request(
sql_file_name='get_regions_by_country.sql',
args=dict(country_id=country_id)
sql_file_name="get_regions_by_country.sql", args=dict(country_id=country_id)
)
db_result = self.cursor.select_all(db_request)

View File

@ -27,8 +27,8 @@ class TimezoneRepository(RepositoryBase):
def get_timezones_by_country(self, country_id):
db_request = self._build_db_request(
sql_file_name='get_timezones_by_country.sql',
args=dict(country_id=country_id)
sql_file_name="get_timezones_by_country.sql",
args=dict(country_id=country_id),
)
db_result = self.cursor.select_all(db_request)

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -34,7 +34,7 @@ from datetime import date, datetime, time
from ..entity.api import ApiMetric
from ...repository_base import RepositoryBase
DUMP_FILE_DIR = '/opt/selene/dump'
DUMP_FILE_DIR = "/opt/selene/dump"
class ApiMetricsRepository(RepositoryBase):
@ -43,8 +43,7 @@ class ApiMetricsRepository(RepositoryBase):
def add(self, metric: ApiMetric):
db_request = self._build_db_request(
sql_file_name='add_api_metric.sql',
args=asdict(metric)
sql_file_name="add_api_metric.sql", args=asdict(metric)
)
self.cursor.insert(db_request)
@ -53,27 +52,27 @@ class ApiMetricsRepository(RepositoryBase):
start_ts = datetime.combine(partition_date, time.min)
end_ts = datetime.combine(partition_date, time.max)
db_request = self._build_db_request(
sql_file_name='create_api_metric_partition.sql',
sql_file_name="create_api_metric_partition.sql",
args=dict(start_ts=str(start_ts), end_ts=str(end_ts)),
sql_vars=dict(partition=partition_date.strftime('%Y%m%d'))
sql_vars=dict(partition=partition_date.strftime("%Y%m%d")),
)
self.cursor.execute(db_request)
db_request = self._build_db_request(
sql_file_name='create_api_metric_partition_index.sql',
sql_vars=dict(partition=partition_date.strftime('%Y%m%d'))
sql_file_name="create_api_metric_partition_index.sql",
sql_vars=dict(partition=partition_date.strftime("%Y%m%d")),
)
self.cursor.execute(db_request)
def copy_to_partition(self, partition_date: date):
"""Copy rows from metric.api table to metric.api_history."""
dump_file_name = 'api_metrics_' + str(partition_date)
dump_file_name = "api_metrics_" + str(partition_date)
dump_file_path = os.path.join(DUMP_FILE_DIR, dump_file_name)
db_request = self._build_db_request(
sql_file_name='get_api_metrics_for_date.sql',
args=dict(metrics_date=partition_date)
sql_file_name="get_api_metrics_for_date.sql",
args=dict(metrics_date=partition_date),
)
table_name = 'metric.api_history_' + partition_date.strftime('%Y%m%d')
table_name = "metric.api_history_" + partition_date.strftime("%Y%m%d")
self.cursor.dump_query_result_to_file(db_request, dump_file_path)
self.cursor.load_dump_file_to_table(table_name, dump_file_path)
os.remove(dump_file_path)
@ -81,7 +80,7 @@ class ApiMetricsRepository(RepositoryBase):
def remove_by_date(self, partition_date: date):
"""Delete from metric.api table after copying to metric.api_history"""
db_request = self._build_db_request(
sql_file_name='delete_api_metrics_by_date.sql',
args=dict(delete_date=partition_date)
sql_file_name="delete_api_metrics_by_date.sql",
args=dict(delete_date=partition_date),
)
self.cursor.delete(db_request)

View File

@ -31,33 +31,29 @@ class CoreMetricRepository(RepositoryBase):
def add(self, metric: CoreMetric):
db_request_args = asdict(metric)
db_request_args['metric_value'] = json.dumps(
db_request_args['metric_value']
)
db_request_args["metric_value"] = json.dumps(db_request_args["metric_value"])
db_request = self._build_db_request(
sql_file_name='add_core_metric.sql',
args=db_request_args
sql_file_name="add_core_metric.sql", args=db_request_args
)
self.cursor.insert(db_request)
def get_metrics_by_device(self, device_id):
return self._select_all_into_dataclass(
CoreMetric,
sql_file_name='get_core_metric_by_device.sql',
args=dict(device_id=device_id)
sql_file_name="get_core_metric_by_device.sql",
args=dict(device_id=device_id),
)
def get_metrics_by_date(self, metric_date: date) -> List[CoreMetric]:
return self._select_all_into_dataclass(
CoreMetric,
sql_file_name='get_core_timing_metrics_by_date.sql',
args=dict(metric_date=metric_date)
sql_file_name="get_core_timing_metrics_by_date.sql",
args=dict(metric_date=metric_date),
)
def add_interaction(self, interaction: CoreInteraction) -> str:
db_request = self._build_db_request(
sql_file_name='add_core_interaction.sql',
args=asdict(interaction)
sql_file_name="add_core_interaction.sql", args=asdict(interaction)
)
db_result = self.cursor.insert_returning(db_request)

View File

@ -29,7 +29,7 @@ from selene.util.db import (
Cursor,
DatabaseRequest,
DatabaseBatchRequest,
get_sql_from_file
get_sql_from_file,
)
@ -52,10 +52,10 @@ class RepositoryBase(object):
def __init__(self, db, repository_path):
self.db = db
self.cursor = Cursor(db)
self.sql_dir = path.join(path.dirname(repository_path), 'sql')
self.sql_dir = path.join(path.dirname(repository_path), "sql")
def _build_db_request(
self, sql_file_name: str, args: dict = None, sql_vars: dict = None
self, sql_file_name: str, args: dict = None, sql_vars: dict = None
):
"""Build a DatabaseRequest object containing a query and args"""
sql = get_sql_from_file(path.join(self.sql_dir, sql_file_name))
@ -67,8 +67,7 @@ class RepositoryBase(object):
def _build_db_batch_request(self, sql_file_name: str, args: List[dict]):
"""Build a DatabaseBatchRequest object containing a query and args"""
return DatabaseBatchRequest(
sql=get_sql_from_file(path.join(self.sql_dir, sql_file_name)),
args=args
sql=get_sql_from_file(path.join(self.sql_dir, sql_file_name)), args=args
)
def _select_one_into_dataclass(self, dataclass, sql_file_name, args=None):

View File

@ -22,7 +22,7 @@ from .entity.skill import Skill
from .entity.skill_setting import (
AccountSkillSetting,
DeviceSkillSetting,
SettingsDisplay
SettingsDisplay,
)
from .repository.display import SkillDisplayRepository
from .repository.setting import SkillSettingRepository

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -26,30 +26,30 @@ class SkillDisplayRepository(RepositoryBase):
super(SkillDisplayRepository, self).__init__(db, __file__)
# TODO: Change this to a value that can be passed in
self.core_version = '21.02'
self.core_version = "21.02"
def get_display_data_for_skills(self):
return self._select_all_into_dataclass(
dataclass=SkillDisplay,
sql_file_name='get_display_data_for_skills.sql',
args=dict(core_version=self.core_version)
sql_file_name="get_display_data_for_skills.sql",
args=dict(core_version=self.core_version),
)
def get_display_data_for_skill(self, skill_display_id) -> SkillDisplay:
return self._select_one_into_dataclass(
dataclass=SkillDisplay,
sql_file_name='get_display_data_for_skill.sql',
args=dict(skill_display_id=skill_display_id)
sql_file_name="get_display_data_for_skill.sql",
args=dict(skill_display_id=skill_display_id),
)
def upsert(self, skill_display: SkillDisplay):
db_request = self._build_db_request(
sql_file_name='upsert_skill_display_data.sql',
sql_file_name="upsert_skill_display_data.sql",
args=dict(
skill_id=skill_display.skill_id,
core_version=skill_display.core_version,
display_data=skill_display.display_data,
)
),
)
self.cursor.insert(db_request)

View File

@ -32,14 +32,12 @@ class SkillSettingRepository(RepositoryBase):
self.db = db
def get_family_settings(
self,
account_id: str,
family_name: str
self, account_id: str, family_name: str
) -> List[AccountSkillSetting]:
return self._select_all_into_dataclass(
AccountSkillSetting,
sql_file_name='get_settings_for_skill_family.sql',
args=dict(family_name=family_name, account_id=account_id)
sql_file_name="get_settings_for_skill_family.sql",
args=dict(family_name=family_name, account_id=account_id),
)
def get_installer_settings(self, account_id) -> List[AccountSkillSetting]:
@ -47,39 +45,31 @@ class SkillSettingRepository(RepositoryBase):
skills = skill_repo.get_skills_for_account(account_id)
installer_skill_id = None
for skill in skills:
if skill.display_name == 'Installer':
if skill.display_name == "Installer":
installer_skill_id = skill.id
skill_settings = None
if installer_skill_id is not None:
skill_settings = self.get_family_settings(
account_id,
installer_skill_id
)
skill_settings = self.get_family_settings(account_id, installer_skill_id)
return skill_settings
@use_transaction
def update_skill_settings(
self,
account_id,
new_skill_settings: AccountSkillSetting,
skill_ids: List[str]
self, account_id, new_skill_settings: AccountSkillSetting, skill_ids: List[str]
):
if new_skill_settings.settings_values is None:
serialized_settings_values = None
else:
serialized_settings_values = json.dumps(
new_skill_settings.settings_values
)
serialized_settings_values = json.dumps(new_skill_settings.settings_values)
db_request = self._build_db_request(
'update_device_skill_settings.sql',
"update_device_skill_settings.sql",
args=dict(
account_id=account_id,
settings_values=serialized_settings_values,
skill_id=tuple(skill_ids),
device_names=tuple(new_skill_settings.device_names)
)
device_names=tuple(new_skill_settings.device_names),
),
)
self.cursor.update(db_request)
@ -87,6 +77,6 @@ class SkillSettingRepository(RepositoryBase):
"""Return all skills and their settings for a given device id"""
return self._select_all_into_dataclass(
DeviceSkillSetting,
sql_file_name='get_skill_setting_by_device.sql',
args=dict(device_id=device_id)
sql_file_name="get_skill_setting_by_device.sql",
args=dict(device_id=device_id),
)

View File

@ -24,35 +24,34 @@ from ..entity.skill_setting import SettingsDisplay
class SettingsDisplayRepository(RepositoryBase):
def __init__(self, db):
super(SettingsDisplayRepository, self).__init__(db, __file__)
def add(self, settings_display: SettingsDisplay) -> str:
"""Add a new row to the skill.settings_display table."""
db_request = self._build_db_request(
sql_file_name='add_settings_display.sql',
sql_file_name="add_settings_display.sql",
args=dict(
skill_id=settings_display.skill_id,
display_data=json.dumps(settings_display.display_data)
)
display_data=json.dumps(settings_display.display_data),
),
)
result = self.cursor.insert_returning(db_request)
return result['id']
return result["id"]
def get_settings_display_id(self, settings_display: SettingsDisplay):
"""Get the ID of a skill's settings definition."""
db_request = self._build_db_request(
sql_file_name='get_settings_display_id.sql',
sql_file_name="get_settings_display_id.sql",
args=dict(
skill_id=settings_display.skill_id,
display_data=json.dumps(settings_display.display_data)
)
display_data=json.dumps(settings_display.display_data),
),
)
result = self.cursor.select_one(db_request)
return None if result is None else result['id']
return None if result is None else result["id"]
def get_settings_definitions_by_gid(self, global_id):
"""Get all matching settings definitions for a global skill ID.
@ -63,14 +62,14 @@ class SettingsDisplayRepository(RepositoryBase):
"""
return self._select_all_into_dataclass(
SettingsDisplay,
sql_file_name='get_settings_definition_by_gid.sql',
args=dict(global_id=global_id)
sql_file_name="get_settings_definition_by_gid.sql",
args=dict(global_id=global_id),
)
def remove(self, settings_display_id: str):
"""Delete a settings definition that is no longer used by any device"""
db_request = self._build_db_request(
sql_file_name='delete_settings_display.sql',
args=dict(settings_display_id=settings_display_id)
sql_file_name="delete_settings_display.sql",
args=dict(settings_display_id=settings_display_id),
)
self.cursor.delete(db_request)

View File

@ -24,8 +24,8 @@ from ...repository_base import RepositoryBase
def extract_family_from_global_id(skill_gid):
id_parts = skill_gid.split('|')
if id_parts[0].startswith('@'):
id_parts = skill_gid.split("|")
if id_parts[0].startswith("@"):
family_name = id_parts[1]
else:
family_name = id_parts[0]
@ -41,8 +41,7 @@ class SkillRepository(RepositoryBase):
def get_skills_for_account(self, account_id) -> List[SkillFamily]:
skills = []
db_request = self._build_db_request(
'get_skills_for_account.sql',
args=dict(account_id=account_id)
"get_skills_for_account.sql", args=dict(account_id=account_id)
)
db_result = self.cursor.select_all(db_request)
if db_result is not None:
@ -54,23 +53,23 @@ class SkillRepository(RepositoryBase):
def get_skill_by_global_id(self, skill_global_id) -> Skill:
return self._select_one_into_dataclass(
dataclass=Skill,
sql_file_name='get_skill_by_global_id.sql',
args=dict(skill_global_id=skill_global_id)
sql_file_name="get_skill_by_global_id.sql",
args=dict(skill_global_id=skill_global_id),
)
@staticmethod
def _extract_settings(skill):
settings = {}
skill_metadata = skill.get('skillMetadata')
skill_metadata = skill.get("skillMetadata")
if skill_metadata:
for section in skill_metadata['sections']:
for field in section['fields']:
if 'name' in field and 'value' in field:
settings[field['name']] = field['value']
field.pop('value', None)
for section in skill_metadata["sections"]:
for field in section["fields"]:
if "name" in field and "value" in field:
settings[field["name"]] = field["value"]
field.pop("value", None)
result = settings, skill
else:
result = '', ''
result = "", ""
return result
def ensure_skill_exists(self, skill_global_id: str) -> str:
@ -85,14 +84,14 @@ class SkillRepository(RepositoryBase):
def _add_skill(self, skill_gid: str, name: str) -> str:
db_request = self._build_db_request(
sql_file_name='add_skill.sql',
args=dict(skill_gid=skill_gid, family_name=name)
sql_file_name="add_skill.sql",
args=dict(skill_gid=skill_gid, family_name=name),
)
db_result = self.cursor.insert_returning(db_request)
# handle both dictionary cursors and namedtuple cursors
try:
skill_id = db_result['id']
skill_id = db_result["id"]
except TypeError:
skill_id = db_result.id
@ -100,7 +99,6 @@ class SkillRepository(RepositoryBase):
def remove_by_gid(self, skill_gid):
db_request = self._build_db_request(
sql_file_name='remove_skill_by_gid.sql',
args=dict(skill_gid=skill_gid)
sql_file_name="remove_skill_by_gid.sql", args=dict(skill_gid=skill_gid)
)
self.cursor.delete(db_request)

View File

@ -16,4 +16,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

View File

@ -25,7 +25,7 @@ from selene.data.account import (
AccountRepository,
OPEN_DATASET,
PRIVACY_POLICY,
TERMS_OF_USE
TERMS_OF_USE,
)
@ -33,19 +33,19 @@ def build_test_account(**overrides):
test_agreements = [
AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today()),
AccountAgreement(type=TERMS_OF_USE, accept_date=date.today()),
AccountAgreement(type=OPEN_DATASET, accept_date=date.today())
AccountAgreement(type=OPEN_DATASET, accept_date=date.today()),
]
return Account(
email_address=overrides.get('email_address') or 'foo@mycroft.ai',
username=overrides.get('username') or 'foobar',
agreements=overrides.get('agreements') or test_agreements
email_address=overrides.get("email_address") or "foo@mycroft.ai",
username=overrides.get("username") or "foobar",
agreements=overrides.get("agreements") or test_agreements,
)
def add_account(db, **overrides):
acct_repository = AccountRepository(db)
account = build_test_account(**overrides)
password = overrides.get('password') or 'test_password'
password = overrides.get("password") or "test_password"
account.id = acct_repository.add(account, password)
if account.membership is not None:
acct_repository.add_membership(account.id, account.membership)
@ -59,13 +59,13 @@ def remove_account(db, account):
def build_test_membership(**overrides):
stripe_acct = 'test_stripe_acct_id'
stripe_acct = "test_stripe_acct_id"
return AccountMembership(
type=overrides.get('type') or 'Monthly Membership',
start_date=overrides.get('start_date') or date.today(),
payment_method=overrides.get('payment_method') or 'Stripe',
payment_account_id=overrides.get('payment_account_id') or stripe_acct,
payment_id=overrides.get('payment_id') or 'test_stripe_payment_id'
type=overrides.get("type") or "Monthly Membership",
start_date=overrides.get("start_date") or date.today(),
payment_method=overrides.get("payment_method") or "Stripe",
payment_account_id=overrides.get("payment_account_id") or stripe_acct,
payment_id=overrides.get("payment_id") or "test_stripe_payment_id",
)

View File

@ -22,10 +22,10 @@ from selene.data.device import Geography, GeographyRepository
def add_account_geography(db, account, **overrides):
geography = Geography(
country=overrides.get('country') or 'United States',
region=overrides.get('region') or 'Missouri',
city=overrides.get('city') or 'Kansas City',
time_zone=overrides.get('time_zone') or 'America/Chicago'
country=overrides.get("country") or "United States",
region=overrides.get("region") or "Missouri",
city=overrides.get("city") or "Kansas City",
time_zone=overrides.get("time_zone") or "America/Chicago",
)
geo_repository = GeographyRepository(db, account.id)
account_geography_id = geo_repository.add(geography)

View File

@ -22,9 +22,7 @@ from selene.data.device import AccountPreferences, PreferenceRepository
def add_account_preference(db, account_id):
account_preferences = AccountPreferences(
date_format='MM/DD/YYYY',
time_format='12 Hour',
measurement_system='Imperial'
date_format="MM/DD/YYYY", time_format="12 Hour", measurement_system="Imperial"
)
preference_repo = PreferenceRepository(db, account_id)
preference_repo.upsert(account_preferences)

View File

@ -29,44 +29,44 @@ from selene.data.account import (
AgreementRepository,
OPEN_DATASET,
PRIVACY_POLICY,
TERMS_OF_USE
TERMS_OF_USE,
)
def _build_test_terms_of_use():
return Agreement(
type=TERMS_OF_USE,
version='Holy Grail',
content='I agree that all the tests I write for this application will '
'be in the theme of Monty Python and the Holy Grail. If you '
'do not agree with these terms, I will be forced to say "Ni!" '
'until such time as you agree',
effective_date=date.today() - timedelta(days=1)
)
version="Holy Grail",
content="I agree that all the tests I write for this application will "
"be in the theme of Monty Python and the Holy Grail. If you "
'do not agree with these terms, I will be forced to say "Ni!" '
"until such time as you agree",
effective_date=date.today() - timedelta(days=1),
)
def _build_test_privacy_policy():
return Agreement(
type=PRIVACY_POLICY,
version='Holy Grail',
content='First, shalt thou take out the Holy Pin. Then shalt thou '
'count to three. No more. No less. Three shalt be the '
'number thou shalt count and the number of the counting shall '
'be three. Four shalt thou not count, nor either count thou '
'two, excepting that thou then proceed to three. Five is '
'right out. Once the number three, being the third number, '
'be reached, then lobbest thou Holy Hand Grenade of Antioch '
'towards thy foe, who, being naughty in My sight, '
'shall snuff it.',
effective_date=date.today() - timedelta(days=1)
version="Holy Grail",
content="First, shalt thou take out the Holy Pin. Then shalt thou "
"count to three. No more. No less. Three shalt be the "
"number thou shalt count and the number of the counting shall "
"be three. Four shalt thou not count, nor either count thou "
"two, excepting that thou then proceed to three. Five is "
"right out. Once the number three, being the third number, "
"be reached, then lobbest thou Holy Hand Grenade of Antioch "
"towards thy foe, who, being naughty in My sight, "
"shall snuff it.",
effective_date=date.today() - timedelta(days=1),
)
def _build_open_dataset():
return Agreement(
type=OPEN_DATASET,
version='Holy Grail',
effective_date=date.today() - timedelta(days=1)
version="Holy Grail",
effective_date=date.today() - timedelta(days=1),
)
@ -93,11 +93,11 @@ def remove_agreements(db, agreements: List[Agreement]):
def get_agreements_from_api(context, agreement):
"""Abstracted so both account and single sign on APIs use in their tests"""
if agreement == PRIVACY_POLICY:
url = '/api/agreement/privacy-policy'
url = "/api/agreement/privacy-policy"
elif agreement == TERMS_OF_USE:
url = '/api/agreement/terms-of-use'
url = "/api/agreement/terms-of-use"
else:
raise ValueError('invalid agreement type')
raise ValueError("invalid agreement type")
context.response = context.client.get(url)
@ -109,7 +109,7 @@ def validate_agreement_response(context, agreement):
elif agreement == TERMS_OF_USE:
expected_response = asdict(context.terms_of_use)
else:
raise ValueError('invalid agreement type')
raise ValueError("invalid agreement type")
del(expected_response['effective_date'])
del expected_response["effective_date"]
assert_that(response_data, equal_to(expected_response))

View File

@ -24,17 +24,14 @@ from selene.data.account import Account, AccountRepository
from selene.util.auth import AuthenticationToken
from selene.util.db import connect_to_db
ACCESS_TOKEN_COOKIE_KEY = 'seleneAccess'
ACCESS_TOKEN_COOKIE_KEY = "seleneAccess"
ONE_MINUTE = 60
TWO_MINUTES = 120
REFRESH_TOKEN_COOKIE_KEY = 'seleneRefresh'
REFRESH_TOKEN_COOKIE_KEY = "seleneRefresh"
def generate_access_token(context, duration=ONE_MINUTE):
access_token = AuthenticationToken(
context.client_config['ACCESS_SECRET'],
duration
)
access_token = AuthenticationToken(context.client_config["ACCESS_SECRET"], duration)
account = context.accounts[context.username]
access_token.generate(account.id)
@ -43,17 +40,16 @@ def generate_access_token(context, duration=ONE_MINUTE):
def set_access_token_cookie(context, duration=ONE_MINUTE):
context.client.set_cookie(
context.client_config['DOMAIN'],
context.client_config["DOMAIN"],
ACCESS_TOKEN_COOKIE_KEY,
context.access_token.jwt,
max_age=duration
max_age=duration,
)
def generate_refresh_token(context, duration=TWO_MINUTES):
refresh_token = AuthenticationToken(
context.client_config['REFRESH_SECRET'],
duration
context.client_config["REFRESH_SECRET"], duration
)
account = context.accounts[context.username]
refresh_token.generate(account.id)
@ -63,38 +59,38 @@ def generate_refresh_token(context, duration=TWO_MINUTES):
def set_refresh_token_cookie(context, duration=TWO_MINUTES):
context.client.set_cookie(
context.client_config['DOMAIN'],
context.client_config["DOMAIN"],
REFRESH_TOKEN_COOKIE_KEY,
context.refresh_token.jwt,
max_age=duration
max_age=duration,
)
def validate_token_cookies(context, expired=False):
for cookie in context.response.headers.getlist('Set-Cookie'):
for cookie in context.response.headers.getlist("Set-Cookie"):
ingredients = _parse_cookie(cookie)
ingredient_names = list(ingredients.keys())
if ACCESS_TOKEN_COOKIE_KEY in ingredient_names:
context.access_token = ingredients[ACCESS_TOKEN_COOKIE_KEY]
elif REFRESH_TOKEN_COOKIE_KEY in ingredient_names:
context.refresh_token = ingredients[REFRESH_TOKEN_COOKIE_KEY]
for ingredient_name in ('Domain', 'Expires', 'Max-Age'):
for ingredient_name in ("Domain", "Expires", "Max-Age"):
assert_that(ingredient_names, has_item(ingredient_name))
if expired:
assert_that(ingredients['Max-Age'], equal_to('0'))
assert_that(ingredients["Max-Age"], equal_to("0"))
assert hasattr(context, 'access_token'), 'no access token in response'
assert hasattr(context, 'refresh_token'), 'no refresh token in response'
assert hasattr(context, "access_token"), "no access token in response"
assert hasattr(context, "refresh_token"), "no refresh token in response"
if expired:
assert_that(context.access_token, equal_to(''))
assert_that(context.refresh_token, equal_to(''))
assert_that(context.access_token, equal_to(""))
assert_that(context.refresh_token, equal_to(""))
def _parse_cookie(cookie: str) -> dict:
ingredients = {}
for ingredient in cookie.split('; '):
if '=' in ingredient:
key, value = ingredient.split('=')
for ingredient in cookie.split("; "):
if "=" in ingredient:
key, value = ingredient.split("=")
ingredients[key] = value
else:
ingredients[ingredient] = None
@ -103,7 +99,7 @@ def _parse_cookie(cookie: str) -> dict:
def get_account(context) -> Account:
db = connect_to_db(context.client['DB_CONNECTION_CONFIG'])
db = connect_to_db(context.client["DB_CONNECTION_CONFIG"])
acct_repository = AccountRepository(db)
account = acct_repository.get_account_by_id(context.account.id)
@ -112,21 +108,14 @@ def get_account(context) -> Account:
def check_http_success(context):
assert_that(
context.response.status_code,
is_in([HTTPStatus.OK, HTTPStatus.NO_CONTENT])
context.response.status_code, is_in([HTTPStatus.OK, HTTPStatus.NO_CONTENT])
)
def check_http_error(context, error_type):
if error_type == 'a bad request':
assert_that(
context.response.status_code,
equal_to(HTTPStatus.BAD_REQUEST)
)
elif error_type == 'an unauthorized':
assert_that(
context.response.status_code,
equal_to(HTTPStatus.UNAUTHORIZED)
)
if error_type == "a bad request":
assert_that(context.response.status_code, equal_to(HTTPStatus.BAD_REQUEST))
elif error_type == "an unauthorized":
assert_that(context.response.status_code, equal_to(HTTPStatus.UNAUTHORIZED))
else:
raise ValueError('unsupported error_type')
raise ValueError("unsupported error_type")

View File

@ -26,12 +26,12 @@ from selene.data.device import DeviceSkillRepository, ManifestSkill
def add_device_skill(db, device_id, skill):
manifest_skill = ManifestSkill(
device_id=device_id,
install_method='test_install_method',
install_status='test_install_status',
install_method="test_install_method",
install_status="test_install_status",
skill_id=skill.id,
skill_gid=skill.skill_gid,
install_ts=datetime.utcnow(),
update_ts=datetime.utcnow()
update_ts=datetime.utcnow(),
)
device_skill_repo = DeviceSkillRepository(db)
manifest_skill.id = device_skill_repo.add_manifest_skill(manifest_skill)
@ -42,9 +42,7 @@ def add_device_skill(db, device_id, skill):
def add_device_skill_settings(db, device_id, settings_display, settings_values):
device_skill_repo = DeviceSkillRepository(db)
device_skill_repo.upsert_device_skill_settings(
[device_id],
settings_display,
settings_values
[device_id], settings_display, settings_values
)

View File

@ -23,19 +23,15 @@ from selene.data.account import (
Membership,
MembershipRepository,
MONTHLY_MEMBERSHIP,
YEARLY_MEMBERSHIP
YEARLY_MEMBERSHIP,
)
monthly_membership = dict(
type=MONTHLY_MEMBERSHIP,
rate=Decimal('1.99'),
rate_period='monthly'
type=MONTHLY_MEMBERSHIP, rate=Decimal("1.99"), rate_period="monthly"
)
yearly_membership = dict(
type=YEARLY_MEMBERSHIP,
rate=Decimal('19.99'),
rate_period='yearly'
type=YEARLY_MEMBERSHIP, rate=Decimal("19.99"), rate_period="yearly"
)

Some files were not shown because too many files have changed in this diff Show More