From 26ed641b48ec09276561c180bf78fb3498912eab Mon Sep 17 00:00:00 2001
From: Chris Veilleux
Date: Fri, 11 Mar 2022 13:22:33 -0600
Subject: [PATCH] applied the "Black" formatter to all files and added pre-commit hook to check

---
 .pre-commit-config.yaml | 11 +
 README.md | 70 +-
 api/account/account_api/__init__.py | 1 -
 .../account_api/endpoints/device_count.py | 4 +-
 .../account_api/endpoints/pairing_code.py | 4 +-
 .../account_api/endpoints/preferences.py | 28 +-
 api/account/account_api/endpoints/region.py | 2 +-
 .../account_api/endpoints/skill_oauth.py | 8 +-
 .../account_api/endpoints/skill_settings.py | 41 +-
 api/account/account_api/endpoints/skills.py | 10 +-
 api/account/account_api/endpoints/timezone.py | 2 +-
 .../tests/features/steps/agreements.py | 14 +-
 api/market/market_api/__init__.py | 1 -
 .../market_api/endpoints/skill_detail.py | 45 +-
 .../endpoints/skill_install_status.py | 63 +-
 api/public/public_api/__init__.py | 1 -
 api/public/public_api/endpoints/__init__.py | 1 -
 api/public/public_api/endpoints/device.py | 19 +-
 .../public_api/endpoints/device_location.py | 7 +-
 .../public_api/endpoints/device_oauth.py | 9 +-
 .../endpoints/device_refresh_token.py | 26 +-
 .../public_api/endpoints/device_setting.py | 3 +-
 .../endpoints/device_skill_manifest.py | 34 +-
 .../endpoints/device_skill_settings.py | 124 ++-
 .../endpoints/device_subscription.py | 6 +-
 .../public_api/endpoints/oauth_callback.py | 5 +-
 .../public_api/endpoints/premium_voice.py | 16 +-
 .../public_api/endpoints/stripe_webhook.py | 9 +-
 .../endpoints/wolfram_alpha_simple.py | 10 +-
 .../endpoints/wolfram_alpha_spoken.py | 10 +-
 .../tests/features/steps/device_location.py | 99 ++-
 .../features/steps/device_refresh_token.py | 32 +-
 .../features/steps/device_skill_settings.py | 148 ++--
 api/public/tests/features/steps/get_device.py | 126 ++-
 .../features/steps/get_device_settings.py | 90 +-
 .../features/steps/get_device_subscription.py | 51 +-
 api/sso/sso_api/__init__.py | 1 -
 .../endpoints/authenticate_internal.py | 12 +-
 api/sso/sso_api/endpoints/github_token.py | 3 +-
 api/sso/sso_api/endpoints/password_change.py | 6 +-
 api/sso/sso_api/endpoints/password_reset.py | 32 +-
 api/sso/sso_api/endpoints/validate_email.py | 26 +-
 api/sso/sso_api/endpoints/validate_token.py | 9 +-
 batch/job_scheduler/__init__.py | 1 -
 batch/job_scheduler/jobs.py | 58 +-
 batch/script/__init__.py | 1 -
 batch/script/daily_report.py | 34 +-
 batch/script/load_skill_display_data.py | 31 +-
 batch/script/parse_core_metrics.py | 50 +-
 batch/script/partition_api_metrics.py | 2 +-
 batch/script/test_scheduler.py | 14 +-
 batch/script/update_device_last_contact.py | 9 +-
 db/scripts/bootstrap_mycroft_db.py | 2 +-
 db/scripts/neo4j-postgres.py | 805 ++++++++++--------
 shared/selene/__init__.py | 1 -
 shared/selene/api/base_config.py | 35 +-
 shared/selene/api/blueprint.py | 20 +-
 shared/selene/api/endpoints/agreements.py | 8 +-
 shared/selene/api/etag.py | 14 +-
 shared/selene/api/response.py | 2 +-
 shared/selene/data/__init__.py | 1 -
 shared/selene/data/account/__init__.py | 9 +-
 shared/selene/data/account/entity/__init__.py | 1 -
 .../selene/data/account/entity/agreement.py | 6 +-
 .../data/account/repository/__init__.py | 1 -
 .../data/account/repository/membership.py | 23 +-
 .../selene/data/account/repository/skill.py | 4 +-
 shared/selene/data/device/entity/__init__.py | 1 -
 .../selene/data/device/repository/__init__.py | 2 +-
 .../data/device/repository/device_skill.py | 86 +-
 .../data/device/repository/geography.py | 24 +-
 .../data/device/repository/preference.py | 9 +-
 .../selene/data/device/repository/setting.py | 65 +-
 .../data/device/repository/text_to_speech.py | 10 +-
 .../selene/data/geography/entity/__init__.py | 1 -
 .../data/geography/repository/__init__.py | 1 -
 .../selene/data/geography/repository/city.py | 16 +-
 .../data/geography/repository/country.py | 2 +-
 .../data/geography/repository/region.py | 3 +-
 .../data/geography/repository/timezone.py | 4 +-
 shared/selene/data/metric/entity/__init__.py | 1 -
 .../selene/data/metric/repository/__init__.py | 1 -
 shared/selene/data/metric/repository/api.py | 25 +-
 shared/selene/data/metric/repository/core.py | 18 +-
 shared/selene/data/repository_base.py | 9 +-
 shared/selene/data/skill/__init__.py | 2 +-
 shared/selene/data/skill/entity/__init__.py | 1 -
 .../selene/data/skill/repository/__init__.py | 1 -
 .../selene/data/skill/repository/display.py | 14 +-
 .../selene/data/skill/repository/setting.py | 34 +-
 .../data/skill/repository/settings_display.py | 25 +-
 shared/selene/data/skill/repository/skill.py | 34 +-
 shared/selene/testing/__init__.py | 1 -
 shared/selene/testing/account.py | 24 +-
 shared/selene/testing/account_geography.py | 8 +-
 shared/selene/testing/account_preference.py | 4 +-
 shared/selene/testing/agreement.py | 52 +-
 shared/selene/testing/api.py | 61 +-
 shared/selene/testing/device_skill.py | 10 +-
 shared/selene/testing/membership.py | 10 +-
 shared/selene/testing/skill.py | 33 +-
 shared/selene/testing/test_db.py | 25 +-
 shared/selene/testing/text_to_speech.py | 6 +-
 shared/selene/util/__init__.py | 2 -
 shared/selene/util/db/__init__.py | 9 +-
 shared/selene/util/db/transaction.py | 1 +
 shared/selene/util/exceptions.py | 1 +
 shared/selene/util/payment/__init__.py | 2 +-
 shared/selene/util/payment/stripe.py | 11 +-
 109 files changed, 1457 insertions(+), 1543 deletions(-)
 create mode 100644 .pre-commit-config.yaml

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..976e9dce
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,11 @@
+repos:
+-   repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v2.3.0
+    hooks:
+    -   id: check-yaml
+    -   id: end-of-file-fixer
+    -   id: trailing-whitespace
+-   repo: https://github.com/psf/black
+    rev: 19.10b0
+    hooks:
+    -   id: black
diff --git a/README.md b/README.md
index 43d8e686..af5e611a 100644
--- a/README.md
+++ b/README.md
@@ -1,41 +1,43 @@
-[![License](https://img.shields.io/badge/License-GNU_AGPL%203.0-blue.svg)](LICENSE)
-[![CLA](https://img.shields.io/badge/CLA%3F-Required-blue.svg)](https://mycroft.ai/cla)
-[![Team](https://img.shields.io/badge/Team-Mycroft_Backend-violetblue.svg)](https://github.com/MycroftAI/contributors/blob/master/team/Mycroft%20Backend.md)
+[![License](https://img.shields.io/badge/License-GNU_AGPL%203.0-blue.svg)](LICENSE)
+[![CLA](https://img.shields.io/badge/CLA%3F-Required-blue.svg)](https://mycroft.ai/cla)
+[![Team](https://img.shields.io/badge/Team-Mycroft_Backend-violetblue.svg)](https://github.com/MycroftAI/contributors/blob/master/team/Mycroft%20Backend.md)
 ![Status](https://img.shields.io/badge/-Production_ready-green.svg)
 [![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](http://makeapullrequest.com)
 [![Join chat](https://img.shields.io/badge/Mattermost-join_chat-brightgreen.svg)](https://chat.mycroft.ai)
+[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+
 Selene -- Mycroft's Server Backend
 ==========
 
 Selene provides the services used by [Mycroft
Core](https://github.com/mycroftai/mycroft-core) to manage devices, skills -and settings. It consists of two repositories. This one contains Python and SQL representing the database definition, -data access layer, APIs and scripts. The second repository, [Selene UI](https://github.com/mycroftai/selene-ui), +and settings. It consists of two repositories. This one contains Python and SQL representing the database definition, +data access layer, APIs and scripts. The second repository, [Selene UI](https://github.com/mycroftai/selene-ui), contains Angular web applications that use the APIs defined in this repository. There are four APIs defined in this repository, account management, single sign on, skill marketplace and device. -The first three support account.mycroft.ai (aka home.mycroft.ai), sso.mycroft.ai, and market.mycroft.ai, respectively. +The first three support account.mycroft.ai (aka home.mycroft.ai), sso.mycroft.ai, and market.mycroft.ai, respectively. The device API is how devices running Mycroft Core communicate with the server. Also included in this repository is a package containing batch scripts for maintenance and the definition of the database schema. -Each API is designed to run independently of the others. Code common to each of the APIs, such as the Data Access Layer, -can be found in the "shared" directory. The shared code is an independent Python package required by each of the APIs. -Each API has its own Pipfile so that it can be run in its own virtual environment. +Each API is designed to run independently of the others. Code common to each of the APIs, such as the Data Access Layer, +can be found in the "shared" directory. The shared code is an independent Python package required by each of the APIs. +Each API has its own Pipfile so that it can be run in its own virtual environment. ## Installation -The Python code utilizes features introduced in Python 3.7, such as data classes. +The Python code utilizes features introduced in Python 3.7, such as data classes. [Pipenv](https://pipenv.readthedocs.io/en/latest/) is used for virtual environment and package management. If you prefer to use pip and pyenv (or virtualenv), you can find the required libraries in the files named "Pipfile". These instructions will use pipenv commands. -If the Selene applications will be servicing a large number of devices (enterprise usage, for example), it is +If the Selene applications will be servicing a large number of devices (enterprise usage, for example), it is recommended that each of the applications run on their own server or virtual machine. This configuration makes it -easier to scale and monitor each application independently. However, all applications can be run on a single server. -This configuration could be more practical for a household running a handful of devices. +easier to scale and monitor each application independently. However, all applications can be run on a single server. +This configuration could be more practical for a household running a handful of devices. -These instructions will assume a multi-server setup for several thousand devices. To run on a single server servicing a +These instructions will assume a multi-server setup for several thousand devices. To run on a single server servicing a small number of devices, the recommended system requirements are 4 CPU, 8GB RAM and 100GB of disk. There are a lot of manual steps in this section that will eventually be replaced with an installation script. 
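The `.pre-commit-config.yaml` introduced above only takes effect after `pre-commit` is installed and registered in a contributor's clone; the patch itself does not add `pre-commit` to any Pipfile. A minimal sketch of how a contributor might enable and exercise the new hooks locally (the `pip`/`pipenv` invocations are assumptions about the contributor's environment, not part of this change):
```
pip install pre-commit            # or: pipenv install --dev pre-commit
pre-commit install                # registers the git hook defined in .pre-commit-config.yaml
pre-commit run --all-files        # runs check-yaml, end-of-file-fixer, trailing-whitespace and black once over the repo
```
Once the hook is installed, `git commit` re-runs these checks against the staged files, which is what keeps the formatting applied in this patch from drifting.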
@@ -45,14 +47,14 @@ It is recommended to create an application specific user. In these instructions ### Postgres DB * Recommended server configuration: [Ubuntu 18.04 LTS (server install)](https://releases.ubuntu.com/bionic/), 2 CPU, 4GB RAM, 50GB disk. -* Use the package management system to install Python 3.7, Python 3 pip and PostgreSQL 10 +* Use the package management system to install Python 3.7, Python 3 pip and PostgreSQL 10 ``` sudo apt-get install postgresql python3.7 python python3-pip ``` -* Set Postgres to start on boot -``` +* Set Postgres to start on boot +``` sudo systemctl enable postgresql -``` +``` * Clone the selene-backend and documentation repositories ``` sudo mkdir -p /opt/selene @@ -97,7 +99,7 @@ pipenv run python bootstrap_mycroft_db.py # IPv4 local connections: host all all 127.0.0.1/32 trust ``` -* By default, Postgres only listens on localhost. This will not do for a multi-server setup. Change the +* By default, Postgres only listens on localhost. This will not do for a multi-server setup. Change the `listen_addresses` value in the `posgresql.conf` file to the private IP of the database server. This file is owned by the `postgres` user so use the following command to edit it (substituting vi for your favorite editor) ``` @@ -109,7 +111,7 @@ the `postgres` user so use the following command to edit it (substituting vi for ``` sudo -u postgres vi /etc/postgres/10/main/pg_hba.conf ``` -* Instructions on how to update the `pg_hba.conf` file can be found in +* Instructions on how to update the `pg_hba.conf` file can be found in [Postgres' documentation](https://www.postgresql.org/docs/10/auth-pg-hba-conf.html). Below is an example for reference. ``` # IPv4 Selene connections @@ -123,7 +125,7 @@ sudo systemctl restart postgresql ### Redis DB * Recommended server configuration: Ubuntu 18.04 LTS, 1 CPU, 1GB RAM, 5GB disk. -So as to not reinvent the wheel, here are some easy-to-follow instructions for +So as to not reinvent the wheel, here are some easy-to-follow instructions for [installing Redis on Ubuntu 18.04](https://www.digitalocean.com/community/tutorials/how-to-install-and-secure-redis-on-ubuntu-18-04). * By default, Redis only listens on local host. For multi-server setups, one additional step is to change the "bind" variable in `/etc/redis/redis.conf` to be the private IP of the Redis host. @@ -133,9 +135,9 @@ The majority of the setup for each API is the same. This section defines the st to each API will be defined in their respective sections. * Add an application user to the VM. Either give this user sudo privileges or execute the sudo commands below as a user with sudo privileges. These instructions will assume a user name of "mycroft" -* Use the package management system to install Python 3.7, Python 3 pip and Python 3.7 Developer Tools +* Use the package management system to install Python 3.7, Python 3 pip and Python 3.7 Developer Tools ``` -sudo apt install python3.7 python3-pip python3.7-dev +sudo apt install python3.7 python3-pip python3.7-dev sudo python3.7 -m pip install pipenv ``` * Setup the Backend Application Directory @@ -196,7 +198,7 @@ pipenv install ``` ### Running the APIs Each API is configured to run on port 5000. This is not a problem if each is running in its own VM but will be an -issue if all APIs are running on the same server, or if port 5000 is already in use. To address these scenarios, +issue if all APIs are running on the same server, or if port 5000 is already in use. 
To address these scenarios, change the port numbering in the uwsgi.ini file for each API. #### Single Sign On API @@ -206,7 +208,7 @@ reset key, which is used in a password reset scenario. Generate a secret key fo * Any data that can identify a user is encrypted. Generate a salt that will be used with the encryption algorithm. * Access to the Github API is required to support logging in with your Github account. Details can be found [here](https://developer.github.com/v3/guides/basics-of-authentication/). -* The password reset functionality sends an email to the user with a link to reset their password. Selene uses +* The password reset functionality sends an email to the user with a link to reset their password. Selene uses SendGrid to send these emails so a SendGrid account and API key are required. * Define a systemd service to run the API. The service defines environment variables that use the secret and API keys generated in previous steps. @@ -250,7 +252,7 @@ sudo systemctl enable sso_api.service ``` #### Account API -* The account API uses the same authentication mechanism as the single sign on API. The JWT_ACCESS_SECRET, +* The account API uses the same authentication mechanism as the single sign on API. The JWT_ACCESS_SECRET, JWT_REFRESH_SECRET and SALT environment variables must be the same values as those on the single sign on API. * This application uses the Redis database so the service needs to know where it resides. * Define a systemd service to run the API. The service defines environment variables that use the secret and API keys @@ -293,7 +295,7 @@ sudo systemctl enable account_api.service ``` #### Marketplace API -* The marketplace API uses the same authentication mechanism as the single sign on API. The JWT_ACCESS_SECRET, +* The marketplace API uses the same authentication mechanism as the single sign on API. The JWT_ACCESS_SECRET, JWT_REFRESH_SECRET and SALT environment variables must be the same values as those on the single sign on API. * This application uses the Redis database so the service needs to know where it resides. * Define a systemd service to run the API. The service defines environment variables that use the secret and API keys @@ -345,7 +347,7 @@ pipenv run python load_skill_display_data.py --core-version . 
- diff --git a/api/account/account_api/endpoints/device_count.py b/api/account/account_api/endpoints/device_count.py index 6cbb0fcc..b72fec40 100644 --- a/api/account/account_api/endpoints/device_count.py +++ b/api/account/account_api/endpoints/device_count.py @@ -31,8 +31,6 @@ class DeviceCountEndpoint(SeleneEndpoint): def _get_devices(self): device_repository = DeviceRepository(self.db) - device_count = device_repository.get_account_device_count( - self.account.id - ) + device_count = device_repository.get_account_device_count(self.account.id) return device_count diff --git a/api/account/account_api/endpoints/pairing_code.py b/api/account/account_api/endpoints/pairing_code.py index db0b2892..e7b9be0e 100644 --- a/api/account/account_api/endpoints/pairing_code.py +++ b/api/account/account_api/endpoints/pairing_code.py @@ -25,7 +25,7 @@ from selene.api import SeleneEndpoint class PairingCodeEndpoint(SeleneEndpoint): def __init__(self): super(PairingCodeEndpoint, self).__init__() - self.cache = self.config['SELENE_CACHE'] + self.cache = self.config["SELENE_CACHE"] def get(self, pairing_code): self._authenticate() @@ -36,7 +36,7 @@ class PairingCodeEndpoint(SeleneEndpoint): def _get_pairing_data(self, pairing_code: str) -> bool: """Checking if there's one pairing session for the pairing code.""" pairing_code_is_valid = False - cache_key = 'pairing.code:' + pairing_code + cache_key = "pairing.code:" + pairing_code pairing_cache = self.cache.get(cache_key) if pairing_cache is not None: pairing_code_is_valid = True diff --git a/api/account/account_api/endpoints/preferences.py b/api/account/account_api/endpoints/preferences.py index cbd06cb3..03ef20d3 100644 --- a/api/account/account_api/endpoints/preferences.py +++ b/api/account/account_api/endpoints/preferences.py @@ -29,29 +29,23 @@ from selene.data.device import AccountPreferences, PreferenceRepository class PreferencesRequest(Model): - date_format = StringType( - required=True, - choices=['DD/MM/YYYY', 'MM/DD/YYYY'] - ) - measurement_system = StringType( - required=True, - choices=['Imperial', 'Metric'] - ) - time_format = StringType(required=True, choices=['12 Hour', '24 Hour']) + date_format = StringType(required=True, choices=["DD/MM/YYYY", "MM/DD/YYYY"]) + measurement_system = StringType(required=True, choices=["Imperial", "Metric"]) + time_format = StringType(required=True, choices=["12 Hour", "24 Hour"]) class PreferencesEndpoint(SeleneEndpoint): def __init__(self): super(PreferencesEndpoint, self).__init__() self.preferences = None - self.cache = self.config['SELENE_CACHE'] + self.cache = self.config["SELENE_CACHE"] self.etag_manager: ETagManager = ETagManager(self.cache, self.config) def get(self): self._authenticate() self._get_preferences() if self.preferences is None: - response_data = '' + response_data = "" response_code = HTTPStatus.NO_CONTENT else: response_data = asdict(self.preferences) @@ -68,22 +62,20 @@ class PreferencesEndpoint(SeleneEndpoint): self._validate_request() self._upsert_preferences() self.etag_manager.expire_device_setting_etag_by_account_id(self.account.id) - return '', HTTPStatus.NO_CONTENT + return "", HTTPStatus.NO_CONTENT def patch(self): self._authenticate() self._validate_request() self._upsert_preferences() self.etag_manager.expire_device_setting_etag_by_account_id(self.account.id) - return '', HTTPStatus.NO_CONTENT + return "", HTTPStatus.NO_CONTENT def _validate_request(self): self.preferences = PreferencesRequest() - self.preferences.date_format = self.request.json['dateFormat'] - 
self.preferences.measurement_system = ( - self.request.json['measurementSystem'] - ) - self.preferences.time_format = self.request.json['timeFormat'] + self.preferences.date_format = self.request.json["dateFormat"] + self.preferences.measurement_system = self.request.json["measurementSystem"] + self.preferences.time_format = self.request.json["timeFormat"] self.preferences.validate() def _upsert_preferences(self): diff --git a/api/account/account_api/endpoints/region.py b/api/account/account_api/endpoints/region.py index 8cb07cca..12b1f4e4 100644 --- a/api/account/account_api/endpoints/region.py +++ b/api/account/account_api/endpoints/region.py @@ -25,7 +25,7 @@ from selene.data.geography import RegionRepository class RegionEndpoint(SeleneEndpoint): def get(self): - country_id = self.request.args['country'] + country_id = self.request.args["country"] region_repository = RegionRepository(self.db) regions = region_repository.get_regions_by_country(country_id) diff --git a/api/account/account_api/endpoints/skill_oauth.py b/api/account/account_api/endpoints/skill_oauth.py index 7ba1b340..89b3d026 100644 --- a/api/account/account_api/endpoints/skill_oauth.py +++ b/api/account/account_api/endpoints/skill_oauth.py @@ -27,17 +27,15 @@ from selene.api import SeleneEndpoint class SkillOauthEndpoint(SeleneEndpoint): def __init__(self): super(SkillOauthEndpoint, self).__init__() - self.oauth_base_url = os.environ['OAUTH_BASE_URL'] + self.oauth_base_url = os.environ["OAUTH_BASE_URL"] def get(self, oauth_id): self._authenticate() return self._get_oauth_url(oauth_id) def _get_oauth_url(self, oauth_id): - url = '{base_url}/auth/{oauth_id}/auth_url?uuid={account_id}'.format( - base_url=self.oauth_base_url, - oauth_id=oauth_id, - account_id=self.account.id + url = "{base_url}/auth/{oauth_id}/auth_url?uuid={account_id}".format( + base_url=self.oauth_base_url, oauth_id=oauth_id, account_id=self.account.id ) response = requests.get(url) return response.text, response.status_code diff --git a/api/account/account_api/endpoints/skill_settings.py b/api/account/account_api/endpoints/skill_settings.py index d13b055b..0e22b9c3 100644 --- a/api/account/account_api/endpoints/skill_settings.py +++ b/api/account/account_api/endpoints/skill_settings.py @@ -35,8 +35,7 @@ class SkillSettingsEndpoint(SeleneEndpoint): self.account_skills = None self.family_settings = None self.etag_manager: ETagManager = ETagManager( - self.config['SELENE_CACHE'], - self.config + self.config["SELENE_CACHE"], self.config ) @property @@ -51,8 +50,7 @@ class SkillSettingsEndpoint(SeleneEndpoint): """Process an HTTP GET request""" self._authenticate() self.family_settings = self.setting_repository.get_family_settings( - self.account.id, - skill_family_name + self.account.id, skill_family_name ) self._parse_selection_options() response_data = self._build_response_data() @@ -62,7 +60,7 @@ class SkillSettingsEndpoint(SeleneEndpoint): return Response( response=json.dumps(response_data), status=HTTPStatus.OK, - content_type='application/json' + content_type="application/json", ) def _parse_selection_options(self): @@ -75,19 +73,16 @@ class SkillSettingsEndpoint(SeleneEndpoint): """ for skill_settings in self.family_settings: if skill_settings.settings_definition is not None: - for section in skill_settings.settings_definition['sections']: - for field in section['fields']: - if field['type'] == 'select': + for section in skill_settings.settings_definition["sections"]: + for field in section["fields"]: + if field["type"] == "select": parsed_options = [] 
- for option in field['options'].split(';'): - option_display, option_value = option.split('|') + for option in field["options"].split(";"): + option_display, option_value = option.split("|") parsed_options.append( - dict( - display=option_display, - value=option_value - ) + dict(display=option_display, value=option_value) ) - field['options'] = parsed_options + field["options"] = parsed_options def _build_response_data(self): """Build the object to return to the UI.""" @@ -100,7 +95,7 @@ class SkillSettingsEndpoint(SeleneEndpoint): response_skill = dict( settingsDisplay=skill_settings.settings_definition, settingsValues=skill_settings.settings_values, - deviceNames=skill_settings.device_names + deviceNames=skill_settings.device_names, ) response_data.append(response_skill) @@ -111,19 +106,17 @@ class SkillSettingsEndpoint(SeleneEndpoint): self._authenticate() self._update_settings_values() - return '', HTTPStatus.OK + return "", HTTPStatus.OK def _update_settings_values(self): """Update the value of the settings column on the device_skill table,""" - for new_skill_settings in self.request.json['skillSettings']: + for new_skill_settings in self.request.json["skillSettings"]: account_skill_settings = AccountSkillSetting( - settings_definition=new_skill_settings['settingsDisplay'], - settings_values=new_skill_settings['settingsValues'], - device_names=new_skill_settings['deviceNames'] + settings_definition=new_skill_settings["settingsDisplay"], + settings_values=new_skill_settings["settingsValues"], + device_names=new_skill_settings["deviceNames"], ) self.setting_repository.update_skill_settings( - self.account.id, - account_skill_settings, - self.request.json['skillIds'] + self.account.id, account_skill_settings, self.request.json["skillIds"] ) self.etag_manager.expire_skill_etag_by_account_id(self.account.id) diff --git a/api/account/account_api/endpoints/skills.py b/api/account/account_api/endpoints/skills.py index 280daa14..6b75bccd 100644 --- a/api/account/account_api/endpoints/skills.py +++ b/api/account/account_api/endpoints/skills.py @@ -44,11 +44,11 @@ class SkillsEndpoint(SeleneEndpoint): market_id=skill.market_id, name=skill.display_name or skill.family_name, has_settings=skill.has_settings, - skill_ids=skill.skill_ids + skill_ids=skill.skill_ids, ) else: - response_skill['skill_ids'].extend(skill.skill_ids) - if response_skill['market_id'] is None: - response_skill['market_id'] = skill.market_id + response_skill["skill_ids"].extend(skill.skill_ids) + if response_skill["market_id"] is None: + response_skill["market_id"] = skill.market_id - return sorted(response_data.values(), key=lambda x: x['name']) + return sorted(response_data.values(), key=lambda x: x["name"]) diff --git a/api/account/account_api/endpoints/timezone.py b/api/account/account_api/endpoints/timezone.py index 036ed836..42a67ef3 100644 --- a/api/account/account_api/endpoints/timezone.py +++ b/api/account/account_api/endpoints/timezone.py @@ -25,7 +25,7 @@ from selene.data.geography import TimezoneRepository class TimezoneEndpoint(SeleneEndpoint): def get(self): - country_id = self.request.args['country'] + country_id = self.request.args["country"] timezone_repository = TimezoneRepository(self.db) timezones = timezone_repository.get_timezones_by_country(country_id) diff --git a/api/account/tests/features/steps/agreements.py b/api/account/tests/features/steps/agreements.py index 57dd4a3a..86a3bbcf 100644 --- a/api/account/tests/features/steps/agreements.py +++ b/api/account/tests/features/steps/agreements.py @@ 
-26,19 +26,19 @@ from hamcrest import assert_that, equal_to from selene.data.account import PRIVACY_POLICY, TERMS_OF_USE -@when('API request for {agreement} is made') +@when("API request for {agreement} is made") def call_agreement_endpoint(context, agreement): if agreement == PRIVACY_POLICY: - url = '/api/agreement/privacy-policy' + url = "/api/agreement/privacy-policy" elif agreement == TERMS_OF_USE: - url = '/api/agreement/terms-of-use' + url = "/api/agreement/terms-of-use" else: - raise ValueError('invalid agreement type') + raise ValueError("invalid agreement type") context.response = context.client.get(url) -@then('{agreement} version {version} is returned') +@then("{agreement} version {version} is returned") def validate_response(context, agreement, version): response_data = json.loads(context.response.data) if agreement == PRIVACY_POLICY: @@ -46,7 +46,7 @@ def validate_response(context, agreement, version): elif agreement == TERMS_OF_USE: expected_response = asdict(context.terms_of_use) else: - raise ValueError('invalid agreement type') + raise ValueError("invalid agreement type") - del(expected_response['effective_date']) + del expected_response["effective_date"] assert_that(response_data, equal_to(expected_response)) diff --git a/api/market/market_api/__init__.py b/api/market/market_api/__init__.py index 7665c10b..eabab81b 100644 --- a/api/market/market_api/__init__.py +++ b/api/market/market_api/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/api/market/market_api/endpoints/skill_detail.py b/api/market/market_api/endpoints/skill_detail.py index d086e424..f73f3c5a 100644 --- a/api/market/market_api/endpoints/skill_detail.py +++ b/api/market/market_api/endpoints/skill_detail.py @@ -27,7 +27,8 @@ from selene.data.skill import SkillDisplay, SkillDisplayRepository class SkillDetailEndpoint(SeleneEndpoint): - """"Supply the data that will populate the skill detail page.""" + """ "Supply the data that will populate the skill detail page.""" + authentication_required = False def __init__(self): @@ -57,39 +58,37 @@ class SkillDetailEndpoint(SeleneEndpoint): def _build_response_data(self, skill_display: SkillDisplay): """Make some modifications to the response skill for the marketplace""" self.response_skill = dict( - categories=skill_display.display_data.get('categories'), - credits=skill_display.display_data.get('credits'), + categories=skill_display.display_data.get("categories"), + credits=skill_display.display_data.get("credits"), description=markdown( - skill_display.display_data.get('description'), - output_format='html5' + skill_display.display_data.get("description"), output_format="html5" ), - display_name=skill_display.display_data['display_name'], - icon=skill_display.display_data.get('icon'), - iconImage=skill_display.display_data.get('icon_img'), + display_name=skill_display.display_data["display_name"], + icon=skill_display.display_data.get("icon"), + iconImage=skill_display.display_data.get("icon_img"), isSystemSkill=False, worksOnMarkOne=( - 'all' in skill_display.display_data['platforms'] or - 'platform_mark1' in skill_display.display_data['platforms'] + "all" in skill_display.display_data["platforms"] + or "platform_mark1" in skill_display.display_data["platforms"] ), worksOnMarkTwo=( - 'all' in skill_display.display_data['platforms'] or - 'platform_mark2' in skill_display.display_data['platforms'] + "all" in skill_display.display_data["platforms"] + or 
"platform_mark2" in skill_display.display_data["platforms"] ), worksOnPicroft=( - 'all' in skill_display.display_data['platforms'] or - 'platform_picroft' in skill_display.display_data['platforms'] + "all" in skill_display.display_data["platforms"] + or "platform_picroft" in skill_display.display_data["platforms"] ), worksOnKDE=( - 'all' in skill_display.display_data['platforms'] or - 'platform_plasmoid' in skill_display.display_data['platforms'] + "all" in skill_display.display_data["platforms"] + or "platform_plasmoid" in skill_display.display_data["platforms"] ), - repositoryUrl=skill_display.display_data.get('repo'), + repositoryUrl=skill_display.display_data.get("repo"), summary=markdown( - skill_display.display_data['short_desc'], - output_format='html5' + skill_display.display_data["short_desc"], output_format="html5" ), - triggers=skill_display.display_data['examples'] + triggers=skill_display.display_data["examples"], ) - if skill_display.display_data['tags'] is not None: - if 'system' in skill_display.display_data['tags']: - self.response_skill['isSystemSkill'] = True + if skill_display.display_data["tags"] is not None: + if "system" in skill_display.display_data["tags"]: + self.response_skill["isSystemSkill"] = True diff --git a/api/market/market_api/endpoints/skill_install_status.py b/api/market/market_api/endpoints/skill_install_status.py index b26477ac..c04c22f1 100644 --- a/api/market/market_api/endpoints/skill_install_status.py +++ b/api/market/market_api/endpoints/skill_install_status.py @@ -26,12 +26,7 @@ from selene.api import SeleneEndpoint from selene.data.device import DeviceSkillRepository, ManifestSkill from selene.util.auth import AuthenticationError -VALID_STATUS_VALUES = ( - 'failed', - 'installed', - 'installing', - 'uninstalling' -) +VALID_STATUS_VALUES = ("failed", "installed", "installing", "uninstalling") class SkillInstallStatusEndpoint(SeleneEndpoint): @@ -45,7 +40,7 @@ class SkillInstallStatusEndpoint(SeleneEndpoint): try: self._authenticate() except AuthenticationError: - self.response = ('', HTTPStatus.NO_CONTENT) + self.response = ("", HTTPStatus.NO_CONTENT) else: self._get_installed_skills() response_data = self._build_response_data() @@ -55,9 +50,7 @@ class SkillInstallStatusEndpoint(SeleneEndpoint): def _get_installed_skills(self): skill_repo = DeviceSkillRepository(self.db) - installed_skills = skill_repo.get_skill_manifest_for_account( - self.account.id - ) + installed_skills = skill_repo.get_skill_manifest_for_account(self.account.id) for skill in installed_skills: self.installed_skills[skill.skill_id].append(skill) @@ -67,18 +60,13 @@ class SkillInstallStatusEndpoint(SeleneEndpoint): for skill_id, skills in self.installed_skills.items(): skill_aggregator = SkillManifestAggregator(skills) skill_aggregator.aggregate_skill_status() - if skill_aggregator.aggregate_skill.install_status == 'failed': - failure_reasons[skill_id] = ( - skill_aggregator.aggregate_skill.install_failure_reason - ) - install_statuses[skill_id] = ( - skill_aggregator.aggregate_skill.install_status - ) + if skill_aggregator.aggregate_skill.install_status == "failed": + failure_reasons[ + skill_id + ] = skill_aggregator.aggregate_skill.install_failure_reason + install_statuses[skill_id] = skill_aggregator.aggregate_skill.install_status - return dict( - installStatuses=install_statuses, - failureReasons=failure_reasons - ) + return dict(installStatuses=install_statuses, failureReasons=failure_reasons) class SkillManifestAggregator(object): @@ -96,7 +84,7 @@ class 
SkillManifestAggregator(object): """ self._validate_install_status() self._determine_install_status() - if self.aggregate_skill.install_status == 'failed': + if self.aggregate_skill.install_status == "failed": self._determine_failure_reason() def _validate_install_status(self): @@ -104,7 +92,7 @@ class SkillManifestAggregator(object): if skill.install_status not in VALID_STATUS_VALUES: raise ValueError( '"{install_status}" is not a supported value of the ' - 'installation field in the skill manifest'.format( + "installation field in the skill manifest".format( install_status=skill.install_status ) ) @@ -120,33 +108,24 @@ class SkillManifestAggregator(object): If the install fails on any device, the install will be flagged as a failed install in the Marketplace. """ - failed = [ - skill.install_status == 'failed' for skill in self.installed_skills - ] - installing = [ - s.install_status == 'installing' for s in self.installed_skills - ] + failed = [skill.install_status == "failed" for skill in self.installed_skills] + installing = [s.install_status == "installing" for s in self.installed_skills] uninstalling = [ - skill.install_status == 'uninstalling' for skill in - self.installed_skills - ] - installed = [ - s.install_status == 'installed' for s in self.installed_skills + skill.install_status == "uninstalling" for skill in self.installed_skills ] + installed = [s.install_status == "installed" for s in self.installed_skills] if any(failed): - self.aggregate_skill.install_status = 'failed' + self.aggregate_skill.install_status = "failed" elif any(installing): - self.aggregate_skill.install_status = 'installing' + self.aggregate_skill.install_status = "installing" elif any(uninstalling): - self.aggregate_skill.install_status = 'uninstalling' + self.aggregate_skill.install_status = "uninstalling" elif all(installed): - self.aggregate_skill.install_status = 'installed' + self.aggregate_skill.install_status = "installed" def _determine_failure_reason(self): """When a skill fails to install, determine the reason""" for skill in self.installed_skills: - if skill.install_status == 'failed': - self.aggregate_skill.failure_reason = ( - skill.install_failure_reason - ) + if skill.install_status == "failed": + self.aggregate_skill.failure_reason = skill.install_failure_reason break diff --git a/api/public/public_api/__init__.py b/api/public/public_api/__init__.py index 7665c10b..eabab81b 100644 --- a/api/public/public_api/__init__.py +++ b/api/public/public_api/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/api/public/public_api/endpoints/__init__.py b/api/public/public_api/endpoints/__init__.py index 7665c10b..eabab81b 100644 --- a/api/public/public_api/endpoints/__init__.py +++ b/api/public/public_api/endpoints/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . 
- diff --git a/api/public/public_api/endpoints/device.py b/api/public/public_api/endpoints/device.py index 80c6959e..f7213f66 100644 --- a/api/public/public_api/endpoints/device.py +++ b/api/public/public_api/endpoints/device.py @@ -29,14 +29,15 @@ from selene.data.device import DeviceRepository class UpdateDevice(Model): - coreVersion = StringType(default='unknown') - platform = StringType(default='unknown') + coreVersion = StringType(default="unknown") + platform = StringType(default="unknown") platform_build = StringType() - enclosureVersion = StringType(default='unknown') + enclosureVersion = StringType(default="unknown") class DeviceEndpoint(PublicEndpoint): """Return the device entity using the device_id""" + def __init__(self): super(DeviceEndpoint, self).__init__() @@ -53,13 +54,13 @@ class DeviceEndpoint(PublicEndpoint): coreVersion=device.core_version, enclosureVersion=device.enclosure_version, platform=device.platform, - user=dict(uuid=device.account_id) + user=dict(uuid=device.account_id), ) response = response_data, HTTPStatus.OK self._add_etag(device_etag_key(device_id)) else: - response = '', HTTPStatus.NO_CONTENT + response = "", HTTPStatus.NO_CONTENT return response @@ -69,10 +70,10 @@ class DeviceEndpoint(PublicEndpoint): update_device = UpdateDevice(payload) update_device.validate() updates = dict( - platform=payload.get('platform') or 'unknown', - enclosure_version=payload.get('enclosureVersion') or 'unknown', - core_version=payload.get('coreVersion') or 'unknown' + platform=payload.get("platform") or "unknown", + enclosure_version=payload.get("enclosureVersion") or "unknown", + core_version=payload.get("coreVersion") or "unknown", ) DeviceRepository(self.db).update_device_from_core(device_id, updates) - return '', HTTPStatus.OK + return "", HTTPStatus.OK diff --git a/api/public/public_api/endpoints/device_location.py b/api/public/public_api/endpoints/device_location.py index cec36999..0c2e3a47 100644 --- a/api/public/public_api/endpoints/device_location.py +++ b/api/public/public_api/endpoints/device_location.py @@ -25,17 +25,18 @@ from selene.data.device import GeographyRepository class DeviceLocationEndpoint(PublicEndpoint): - def __init__(self): super(DeviceLocationEndpoint, self).__init__() def get(self, device_id): self._authenticate(device_id) self._validate_etag(device_location_etag_key(device_id)) - location = GeographyRepository(self.db, None).get_location_by_device_id(device_id) + location = GeographyRepository(self.db, None).get_location_by_device_id( + device_id + ) if location: response = (location, HTTPStatus.OK) self._add_etag(device_location_etag_key(device_id)) else: - response = ('', HTTPStatus.NOT_FOUND) + response = ("", HTTPStatus.NOT_FOUND) return response diff --git a/api/public/public_api/endpoints/device_oauth.py b/api/public/public_api/endpoints/device_oauth.py index ef6813a3..045963c6 100644 --- a/api/public/public_api/endpoints/device_oauth.py +++ b/api/public/public_api/endpoints/device_oauth.py @@ -26,18 +26,15 @@ from selene.data.account import AccountRepository class OauthServiceEndpoint(PublicEndpoint): - def __init__(self): super(OauthServiceEndpoint, self).__init__() - self.oauth_service_host = os.environ['OAUTH_BASE_URL'] + self.oauth_service_host = os.environ["OAUTH_BASE_URL"] def get(self, device_id, credentials, oauth_path): account = AccountRepository(self.db).get_account_by_device_id(device_id) uuid = account.id - url = '{host}/auth/{credentials}/{oauth_path}'.format( - host=self.oauth_service_host, - credentials=credentials, - 
oauth_path=oauth_path + url = "{host}/auth/{credentials}/{oauth_path}".format( + host=self.oauth_service_host, credentials=credentials, oauth_path=oauth_path ) params = dict(uuid=uuid) response = requests.get(url, params=params) diff --git a/api/public/public_api/endpoints/device_refresh_token.py b/api/public/public_api/endpoints/device_refresh_token.py index ee1aca7d..f5db854f 100644 --- a/api/public/public_api/endpoints/device_refresh_token.py +++ b/api/public/public_api/endpoints/device_refresh_token.py @@ -35,44 +35,44 @@ class DeviceRefreshTokenEndpoint(PublicEndpoint): def get(self): headers = self.request.headers - if 'Authorization' not in headers: - raise AuthenticationError('Oauth token not found') - token_header = self.request.headers['Authorization'] - if token_header.startswith('Bearer '): - refresh = token_header[len('Bearer '):] + if "Authorization" not in headers: + raise AuthenticationError("Oauth token not found") + token_header = self.request.headers["Authorization"] + if token_header.startswith("Bearer "): + refresh = token_header[len("Bearer ") :] session = self._refresh_session_token(refresh) # Trying to fetch a session using the refresh token if session: response = session, HTTPStatus.OK else: - device = self.request.headers.get('Device') + device = self.request.headers.get("Device") if device: # trying to fetch a session using the device uuid session = self._refresh_session_token_device(device) if session: response = session, HTTPStatus.OK else: - response = '', HTTPStatus.UNAUTHORIZED + response = "", HTTPStatus.UNAUTHORIZED else: - response = '', HTTPStatus.UNAUTHORIZED + response = "", HTTPStatus.UNAUTHORIZED else: - response = '', HTTPStatus.UNAUTHORIZED + response = "", HTTPStatus.UNAUTHORIZED return response def _refresh_session_token(self, refresh: str): - refresh_key = 'device.token.refresh:{}'.format(refresh) + refresh_key = "device.token.refresh:{}".format(refresh) session = self.cache.get(refresh_key) if session: old_login = json.loads(session) - device_id = old_login['uuid'] + device_id = old_login["uuid"] self.cache.delete(refresh_key) return generate_device_login(device_id, self.cache) def _refresh_session_token_device(self, device: str): - refresh_key = 'device.session:{}'.format(device) + refresh_key = "device.session:{}".format(device) session = self.cache.get(refresh_key) if session: old_login = json.loads(session) - device_id = old_login['uuid'] + device_id = old_login["uuid"] self.cache.delete(refresh_key) return generate_device_login(device_id, self.cache) diff --git a/api/public/public_api/endpoints/device_setting.py b/api/public/public_api/endpoints/device_setting.py index 91d470eb..9e67eb07 100644 --- a/api/public/public_api/endpoints/device_setting.py +++ b/api/public/public_api/endpoints/device_setting.py @@ -25,6 +25,7 @@ from selene.data.device import SettingRepository class DeviceSettingEndpoint(PublicEndpoint): """Return the device's settings for the API v1 model""" + def __init__(self): super(DeviceSettingEndpoint, self).__init__() @@ -36,5 +37,5 @@ class DeviceSettingEndpoint(PublicEndpoint): response = (setting, HTTPStatus.OK) self._add_etag(device_setting_etag_key(device_id)) else: - response = ('', HTTPStatus.NO_CONTENT) + response = ("", HTTPStatus.NO_CONTENT) return response diff --git a/api/public/public_api/endpoints/device_skill_manifest.py b/api/public/public_api/endpoints/device_skill_manifest.py index 01b00e16..c156e0f6 100644 --- a/api/public/public_api/endpoints/device_skill_manifest.py +++ 
b/api/public/public_api/endpoints/device_skill_manifest.py @@ -28,7 +28,7 @@ from schematics.types import ( ListType, IntType, BooleanType, - TimestampType + TimestampType, ) from selene.api import PublicEndpoint @@ -43,9 +43,7 @@ class SkillManifestReconciler(object): self.skill_repo = SkillRepository(self.db) self.device_manifest = {sm.skill_gid: sm for sm in device_manifest} self.db_manifest = {ds.skill_gid: ds for ds in db_manifest} - self.device_manifest_global_ids = { - gid for gid in self.device_manifest.keys() - } + self.device_manifest_global_ids = {gid for gid in self.device_manifest.keys()} self.db_manifest_global_ids = {gid for gid in self.db_manifest} def reconcile(self): @@ -82,16 +80,14 @@ class SkillManifestReconciler(object): for gid in skills_to_add: skill_id = self.skill_repo.ensure_skill_exists(gid) self.device_manifest[gid].skill_id = skill_id - self.skill_manifest_repo.add_manifest_skill( - self.device_manifest[gid] - ) + self.skill_manifest_repo.add_manifest_skill(self.device_manifest[gid]) class RequestManifestSkill(Model): name = StringType(required=True) origin = StringType(required=True) installation = StringType(required=True) - failure_message = StringType(default='') + failure_message = StringType(default="") status = StringType(required=True) beta = BooleanType(required=True) installed = TimestampType(required=True) @@ -126,7 +122,7 @@ class DeviceSkillManifestEndpoint(PublicEndpoint): self._validate_put_request() self._update_skill_manifest(device_id) - return '', HTTPStatus.OK + return "", HTTPStatus.OK def _validate_put_request(self): request_data = SkillManifestRequest(self.request.json) @@ -137,29 +133,27 @@ class DeviceSkillManifestEndpoint(PublicEndpoint): device_id ) device_skill_manifest = [] - for manifest_skill in self.request.json['skills']: + for manifest_skill in self.request.json["skills"]: self._convert_manifest_timestamps(manifest_skill) device_skill_manifest.append( ManifestSkill( device_id=device_id, - install_method=manifest_skill['origin'], - install_status=manifest_skill['installation'], - install_failure_reason=manifest_skill.get('failure_message'), - install_ts=manifest_skill['installed'], - skill_gid=manifest_skill['skill_gid'], - update_ts=manifest_skill['updated'] + install_method=manifest_skill["origin"], + install_status=manifest_skill["installation"], + install_failure_reason=manifest_skill.get("failure_message"), + install_ts=manifest_skill["installed"], + skill_gid=manifest_skill["skill_gid"], + update_ts=manifest_skill["updated"], ) ) reconciler = SkillManifestReconciler( - self.db, - device_skill_manifest, - db_skill_manifest + self.db, device_skill_manifest, db_skill_manifest ) reconciler.reconcile() @staticmethod def _convert_manifest_timestamps(manifest_skill): - for key in ('installed', 'updated'): + for key in ("installed", "updated"): value = manifest_skill[key] if value: manifest_skill[key] = datetime.fromtimestamp(value) diff --git a/api/public/public_api/endpoints/device_skill_settings.py b/api/public/public_api/endpoints/device_skill_settings.py index f07ba099..fc2d24a8 100644 --- a/api/public/public_api/endpoints/device_skill_settings.py +++ b/api/public/public_api/endpoints/device_skill_settings.py @@ -33,39 +33,37 @@ from selene.data.skill import ( SettingsDisplayRepository, Skill, SkillRepository, - SkillSettingRepository + SkillSettingRepository, ) from selene.util.cache import DEVICE_SKILL_ETAG_KEY # matches | -GLOBAL_ID_PATTERN = '^([^\|@]+)\|([^\|]+$)' +GLOBAL_ID_PATTERN = "^([^\|@]+)\|([^\|]+$)" # matches 
@|| -GLOBAL_ID_DIRTY_PATTERN = '^@(.*)\|(.*)\|(.*)$' +GLOBAL_ID_DIRTY_PATTERN = "^@(.*)\|(.*)\|(.*)$" # matches @| -GLOBAL_ID_NON_MSM_PATTERN = '^@([^\|]+)\|([^\|]+$)' -GLOBAL_ID_ANY_PATTERN = '(?:{})|(?:{})|(?:{})'.format( - GLOBAL_ID_PATTERN, - GLOBAL_ID_DIRTY_PATTERN, - GLOBAL_ID_NON_MSM_PATTERN +GLOBAL_ID_NON_MSM_PATTERN = "^@([^\|]+)\|([^\|]+$)" +GLOBAL_ID_ANY_PATTERN = "(?:{})|(?:{})|(?:{})".format( + GLOBAL_ID_PATTERN, GLOBAL_ID_DIRTY_PATTERN, GLOBAL_ID_NON_MSM_PATTERN ) def _normalize_field_value(field): """The field values in skillMetadata are all strings, convert to native.""" - normalized_value = field.get('value') - if field['type'].lower() == 'checkbox': - if field['value'] in ('false', 'False', '0'): + normalized_value = field.get("value") + if field["type"].lower() == "checkbox": + if field["value"] in ("false", "False", "0"): normalized_value = False - elif field['value'] in ('true', 'True', '1'): + elif field["value"] in ("true", "True", "1"): normalized_value = True - elif field['type'].lower() == 'number' and isinstance(field['value'], str): - if field['value']: - normalized_value = float(field['value']) + elif field["type"].lower() == "number" and isinstance(field["value"], str): + if field["value"]: + normalized_value = float(field["value"]) if not normalized_value % 1: - normalized_value = int(field['value']) + normalized_value = int(field["value"]) else: normalized_value = 0 - elif field['value'] == "[]": + elif field["value"] == "[]": normalized_value = [] return normalized_value @@ -78,6 +76,7 @@ class SkillSettingUpdater(object): request specifies a single device to update, all devices with the same skill must be updated as well. """ + _device_skill_repo = None _settings_display_repo = None @@ -115,28 +114,27 @@ class SkillSettingUpdater(object): settings_meta.json file before sending the result to this API. The settings values are stored separately from the metadata in the database. 
""" - settings_definition = self.display_data.get('skillMetadata') + settings_definition = self.display_data.get("skillMetadata") if settings_definition is not None: self.settings_values = dict() sections_without_values = [] - for section in settings_definition['sections']: + for section in settings_definition["sections"]: section_without_values = dict(**section) - for field in section_without_values['fields']: - field_name = field.get('name') - field_value = field.get('value') + for field in section_without_values["fields"]: + field_name = field.get("name") + field_value = field.get("value") if field_name is not None: if field_value is not None: field_value = _normalize_field_value(field) - del(field['value']) + del field["value"] self.settings_values[field_name] = field_value sections_without_values.append(section_without_values) - settings_definition['sections'] = sections_without_values + settings_definition["sections"] = sections_without_values def _get_skill_id(self): """Get the id of the skill in the request""" - skill_global_id = ( - self.display_data.get('skill_gid') or - self.display_data.get('identifier') + skill_global_id = self.display_data.get("skill_gid") or self.display_data.get( + "identifier" ) skill_repo = SkillRepository(self.db) skill_id = skill_repo.ensure_skill_exists(skill_global_id) @@ -145,14 +143,9 @@ class SkillSettingUpdater(object): def _ensure_settings_display_exists(self) -> bool: """If the settings display changed, a new row needs to be added.""" new_settings_display = False - self.settings_display = SettingsDisplay( - self.skill.id, - self.display_data - ) - self.settings_display.id = ( - self.settings_display_repo.get_settings_display_id( - self.settings_display - ) + self.settings_display = SettingsDisplay(self.skill.id, self.display_data) + self.settings_display.id = self.settings_display_repo.get_settings_display_id( + self.settings_display ) if self.settings_display.id is None: self.settings_display.id = self.settings_display_repo.add( @@ -173,11 +166,8 @@ class SkillSettingUpdater(object): """Get all the permutations of settings for a skill""" account_repo = AccountRepository(self.db) account = account_repo.get_account_by_device_id(self.device_id) - skill_settings = ( - self.device_skill_repo.get_skill_settings_for_account( - account.id, - self.skill.id - ) + skill_settings = self.device_skill_repo.get_skill_settings_for_account( + account.id, self.skill.id ) return skill_settings @@ -187,14 +177,14 @@ class SkillSettingUpdater(object): for skill_setting in skill_settings: if self.device_id in skill_setting.device_ids: device_skill_found = True - if skill_setting.install_method in ('voice', 'cli'): + if skill_setting.install_method in ("voice", "cli"): devices_to_update = [self.device_id] else: devices_to_update = skill_setting.device_ids self.device_skill_repo.upsert_device_skill_settings( devices_to_update, self.settings_display, - self._merge_settings_values(skill_setting.settings_values) + self._merge_settings_values(skill_setting.settings_values), ) break @@ -225,9 +215,7 @@ class SkillSettingUpdater(object): manifest endpoint in some cases. 
""" self.device_skill_repo.upsert_device_skill_settings( - [self.device_id], - self.settings_display, - self._merge_settings_values() + [self.device_id], self.settings_display, self._merge_settings_values() ) @@ -267,15 +255,16 @@ class RequestSkill(Model): identifier = StringType() def validate_skill_gid(self, data, value): - if data['skill_gid'] is None and data['identifier'] is None: + if data["skill_gid"] is None and data["identifier"] is None: raise ValidationError( - 'skill should have either skill_gid or identifier defined' + "skill should have either skill_gid or identifier defined" ) return value class DeviceSkillSettingsEndpoint(PublicEndpoint): """Fetch all skills associated with a device using the API v1 format""" + _device_skill_repo = None _skill_repo = None _skill_setting_repo = None @@ -317,23 +306,19 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint): """ self._authenticate(device_id) self._validate_etag(DEVICE_SKILL_ETAG_KEY.format(device_id=device_id)) - device_skills = self.skill_setting_repo.get_skill_settings_for_device( - device_id - ) + device_skills = self.skill_setting_repo.get_skill_settings_for_device(device_id) if device_skills: response_data = self._build_response_data(device_skills) response = Response( json.dumps(response_data), status=HTTPStatus.OK, - content_type='application/json' + content_type="application/json", ) self._add_etag(DEVICE_SKILL_ETAG_KEY.format(device_id=device_id)) else: response = Response( - '', - status=HTTPStatus.NO_CONTENT, - content_type='application/json' + "", status=HTTPStatus.NO_CONTENT, content_type="application/json" ) return response @@ -341,7 +326,7 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint): response_data = [] for skill in device_skills: response_skill = dict(uuid=skill.skill_id) - settings_definition = skill.settings_display.get('skillMetadata') + settings_definition = skill.settings_display.get("skillMetadata") if settings_definition: settings_sections = self._apply_settings_values( settings_definition, skill.settings_values @@ -350,10 +335,10 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint): response_skill.update( skillMetadata=dict(sections=settings_sections) ) - skill_gid = skill.settings_display.get('skill_gid') + skill_gid = skill.settings_display.get("skill_gid") if skill_gid is not None: response_skill.update(skill_gid=skill_gid) - identifier = skill.settings_display.get('identifier') + identifier = skill.settings_display.get("identifier") if identifier is None: response_skill.update(identifier=skill_gid) else: @@ -366,10 +351,10 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint): def _apply_settings_values(settings_definition, settings_values): """Build a copy of the settings sections populated with values.""" sections_with_values = [] - for section in settings_definition['sections']: + for section in settings_definition["sections"]: section_with_values = dict(**section) - for field in section_with_values['fields']: - field_name = field.get('name') + for field in section_with_values["fields"]: + field_name = field.get("name") if field_name is not None and field_name in settings_values: field.update(value=str(settings_values[field_name])) sections_with_values.append(section_with_values) @@ -380,9 +365,7 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint): self._authenticate(device_id) self._validate_put_request() skill_id = self._update_skill_settings(device_id) - self.etag_manager.expire( - DEVICE_SKILL_ETAG_KEY.format(device_id=device_id) - ) + 
self.etag_manager.expire(DEVICE_SKILL_ETAG_KEY.format(device_id=device_id)) return dict(uuid=skill_id), HTTPStatus.OK @@ -392,9 +375,7 @@ class DeviceSkillSettingsEndpoint(PublicEndpoint): def _update_skill_settings(self, device_id): skill_setting_updater = SkillSettingUpdater( - self.db, - device_id, - self.request.json + self.db, device_id, self.request.json ) skill_setting_updater.update() self._delete_orphaned_settings_display( @@ -418,6 +399,7 @@ class DeviceSkillSettingsEndpointV2(PublicEndpoint): with pre 19.08 versions of mycroft-core. Once those versions are no longer supported, the older class can be deprecated. """ + def get(self, device_id): """ Retrieve skills installed on device from the database. @@ -433,9 +415,7 @@ class DeviceSkillSettingsEndpointV2(PublicEndpoint): def _build_response_data(self, device_id): device_skill_repo = DeviceSkillRepository(self.db) - device_skills = device_skill_repo.get_skill_settings_for_device( - device_id - ) + device_skills = device_skill_repo.get_skill_settings_for_device(device_id) if device_skills is not None: response_data = {} for skill in device_skills: @@ -446,15 +426,13 @@ class DeviceSkillSettingsEndpointV2(PublicEndpoint): def _build_response(self, device_id, response_data): if response_data is None: response = Response( - '', - status=HTTPStatus.NO_CONTENT, - content_type='application/json' + "", status=HTTPStatus.NO_CONTENT, content_type="application/json" ) else: response = Response( json.dumps(response_data), status=HTTPStatus.OK, - content_type='application/json' + content_type="application/json", ) self._add_etag(DEVICE_SKILL_ETAG_KEY.format(device_id=device_id)) diff --git a/api/public/public_api/endpoints/device_subscription.py b/api/public/public_api/endpoints/device_subscription.py index 5518adfd..3fc3a11d 100644 --- a/api/public/public_api/endpoints/device_subscription.py +++ b/api/public/public_api/endpoints/device_subscription.py @@ -33,10 +33,10 @@ class DeviceSubscriptionEndpoint(PublicEndpoint): if account: membership = account.membership response = ( - {'@type': membership.type if membership is not None else 'free'}, - HTTPStatus.OK + {"@type": membership.type if membership is not None else "free"}, + HTTPStatus.OK, ) else: - response = '', HTTPStatus.NO_CONTENT + response = "", HTTPStatus.NO_CONTENT return response diff --git a/api/public/public_api/endpoints/oauth_callback.py b/api/public/public_api/endpoints/oauth_callback.py index e074ef64..7f5a8ce4 100644 --- a/api/public/public_api/endpoints/oauth_callback.py +++ b/api/public/public_api/endpoints/oauth_callback.py @@ -25,13 +25,12 @@ from selene.api import PublicEndpoint class OauthCallbackEndpoint(PublicEndpoint): - def __init__(self): super(OauthCallbackEndpoint, self).__init__() - self.oauth_service_host = os.environ['OAUTH_BASE_URL'] + self.oauth_service_host = os.environ["OAUTH_BASE_URL"] def get(self): params = dict(self.request.args) - url = self.oauth_service_host + '/auth/callback' + url = self.oauth_service_host + "/auth/callback" response = requests.get(url, params=params) return response.text, response.status_code diff --git a/api/public/public_api/endpoints/premium_voice.py b/api/public/public_api/endpoints/premium_voice.py index 07e610e8..cb436bea 100644 --- a/api/public/public_api/endpoints/premium_voice.py +++ b/api/public/public_api/endpoints/premium_voice.py @@ -30,20 +30,20 @@ class PremiumVoiceEndpoint(PublicEndpoint): def get(self, device_id): self._authenticate(device_id) - arch = self.request.args.get('arch') + arch = 
self.request.args.get("arch") account = AccountRepository(self.db).get_account_by_device_id(device_id) if account and account.membership: link = self._get_premium_voice_link(arch) - response = {'link': link}, HTTPStatus.OK + response = {"link": link}, HTTPStatus.OK else: - response = '', HTTPStatus.NO_CONTENT + response = "", HTTPStatus.NO_CONTENT return response def _get_premium_voice_link(self, arch): - if arch == 'arm': - response = os.environ['URL_VOICE_ARM'] - elif arch == 'x86_64': - response = os.environ['URL_VOICE_X86_64'] + if arch == "arm": + response = os.environ["URL_VOICE_ARM"] + elif arch == "x86_64": + response = os.environ["URL_VOICE_X86_64"] else: - response = '' + response = "" return response diff --git a/api/public/public_api/endpoints/stripe_webhook.py b/api/public/public_api/endpoints/stripe_webhook.py index fece3e31..885ed26e 100644 --- a/api/public/public_api/endpoints/stripe_webhook.py +++ b/api/public/public_api/endpoints/stripe_webhook.py @@ -25,14 +25,13 @@ from selene.data.account import AccountRepository class StripeWebHookEndpoint(PublicEndpoint): - def __init__(self): super(StripeWebHookEndpoint, self).__init__() def post(self): event = json.loads(self.request.data) - type = event.get('type') - if type == 'customer.subscription.deleted': - customer = event['data']['object']['customer'] + type = event.get("type") + if type == "customer.subscription.deleted": + customer = event["data"]["object"]["customer"] AccountRepository(self.db).end_active_membership(customer) - return '', HTTPStatus.OK + return "", HTTPStatus.OK diff --git a/api/public/public_api/endpoints/wolfram_alpha_simple.py b/api/public/public_api/endpoints/wolfram_alpha_simple.py index be30a3e6..42d16ed2 100644 --- a/api/public/public_api/endpoints/wolfram_alpha_simple.py +++ b/api/public/public_api/endpoints/wolfram_alpha_simple.py @@ -34,14 +34,14 @@ class WolframAlphaSimpleEndpoint(PublicEndpoint): def __init__(self): super(WolframAlphaSimpleEndpoint, self).__init__() - self.wolfram_alpha_key = os.environ['WOLFRAM_ALPHA_KEY'] - self.wolfram_alpha_url = os.environ['WOLFRAM_ALPHA_URL'] + self.wolfram_alpha_key = os.environ["WOLFRAM_ALPHA_KEY"] + self.wolfram_alpha_url = os.environ["WOLFRAM_ALPHA_URL"] def get(self): self._authenticate() params = dict(self.request.args) - params['appid'] = self.wolfram_alpha_key - response = requests.get(self.wolfram_alpha_url + '/v1/simple', params=params) + params["appid"] = self.wolfram_alpha_key + response = requests.get(self.wolfram_alpha_url + "/v1/simple", params=params) code = response.status_code - response = (response.text, code) if code == HTTPStatus.OK else ('', code) + response = (response.text, code) if code == HTTPStatus.OK else ("", code) return response diff --git a/api/public/public_api/endpoints/wolfram_alpha_spoken.py b/api/public/public_api/endpoints/wolfram_alpha_spoken.py index 977bb8fe..f6d42006 100644 --- a/api/public/public_api/endpoints/wolfram_alpha_spoken.py +++ b/api/public/public_api/endpoints/wolfram_alpha_spoken.py @@ -30,14 +30,14 @@ class WolframAlphaSpokenEndpoint(PublicEndpoint): def __init__(self): super(WolframAlphaSpokenEndpoint, self).__init__() - self.wolfram_alpha_key = os.environ['WOLFRAM_ALPHA_KEY'] - self.wolfram_alpha_url = os.environ['WOLFRAM_ALPHA_URL'] + self.wolfram_alpha_key = os.environ["WOLFRAM_ALPHA_KEY"] + self.wolfram_alpha_url = os.environ["WOLFRAM_ALPHA_URL"] def get(self): self._authenticate() params = dict(self.request.args) - params['appid'] = self.wolfram_alpha_key - response = 
requests.get(self.wolfram_alpha_url + '/v1/spoken', params=params) + params["appid"] = self.wolfram_alpha_key + response = requests.get(self.wolfram_alpha_url + "/v1/spoken", params=params) code = response.status_code - response = (response.text, code) if code == HTTPStatus.OK else ('', code) + response = (response.text, code) if code == HTTPStatus.OK else ("", code) return response diff --git a/api/public/tests/features/steps/device_location.py b/api/public/tests/features/steps/device_location.py index 95a9cc7b..a70d9de2 100644 --- a/api/public/tests/features/steps/device_location.py +++ b/api/public/tests/features/steps/device_location.py @@ -26,105 +26,104 @@ from hamcrest import assert_that, equal_to, has_key, not_none, is_not from selene.api.etag import ETagManager, device_location_etag_key -@when('a api call to get the location is done') +@when("a api call to get the location is done") def get_device_location(context): login = context.device_login - device_id = login['uuid'] - access_token = login['accessToken'] - headers = dict(Authorization='Bearer {token}'.format(token=access_token)) + device_id = login["uuid"] + access_token = login["accessToken"] + headers = dict(Authorization="Bearer {token}".format(token=access_token)) context.get_location_response = context.client.get( - '/v1/device/{uuid}/location'.format(uuid=device_id), - headers=headers + "/v1/device/{uuid}/location".format(uuid=device_id), headers=headers ) -@then('the location should be retrieved') +@then("the location should be retrieved") def validate_location(context): response = context.get_location_response assert_that(response.status_code, equal_to(HTTPStatus.OK)) location = json.loads(response.data) - assert_that(location, has_key('coordinate')) - assert_that(location, has_key('timezone')) - assert_that(location, has_key('city')) + assert_that(location, has_key("coordinate")) + assert_that(location, has_key("timezone")) + assert_that(location, has_key("city")) - coordinate = location['coordinate'] - assert_that(coordinate, has_key('latitude')) - assert_that(coordinate, has_key('longitude')) + coordinate = location["coordinate"] + assert_that(coordinate, has_key("latitude")) + assert_that(coordinate, has_key("longitude")) - timezone = location['timezone'] - assert_that(timezone, has_key('name')) - assert_that(timezone, has_key('code')) - assert_that(timezone, has_key('offset')) - assert_that(timezone, has_key('dstOffset')) + timezone = location["timezone"] + assert_that(timezone, has_key("name")) + assert_that(timezone, has_key("code")) + assert_that(timezone, has_key("offset")) + assert_that(timezone, has_key("dstOffset")) - city = location['city'] - assert_that(city, has_key('name')) - assert_that(city, has_key('state')) + city = location["city"] + assert_that(city, has_key("name")) + assert_that(city, has_key("state")) - state = city['state'] - assert_that(state, has_key('name')) - assert_that(state, has_key('country')) - assert_that(state, has_key('code')) + state = city["state"] + assert_that(state, has_key("name")) + assert_that(state, has_key("country")) + assert_that(state, has_key("code")) - country = state['country'] - assert_that(country, has_key('name')) - assert_that(country, has_key('code')) + country = state["country"] + assert_that(country, has_key("name")) + assert_that(country, has_key("code")) -@given('an expired etag from a location entity') +@given("an expired etag from a location entity") def expire_location_etag(context): etag_manager: ETagManager = context.etag_manager - device_id = 
context.device_login['uuid'] - context.expired_location_etag = etag_manager.get(device_location_etag_key(device_id)) + device_id = context.device_login["uuid"] + context.expired_location_etag = etag_manager.get( + device_location_etag_key(device_id) + ) etag_manager.expire_device_location_etag_by_device_id(device_id) -@when('try to get the location using the expired etag') +@when("try to get the location using the expired etag") def get_using_expired_etag(context): login = context.device_login - device_id = login['uuid'] - access_token = login['accessToken'] + device_id = login["uuid"] + access_token = login["accessToken"] headers = { - 'Authorization': 'Bearer {token}'.format(token=access_token), - 'If-None-Match': context.expired_location_etag + "Authorization": "Bearer {token}".format(token=access_token), + "If-None-Match": context.expired_location_etag, } context.get_location_response = context.client.get( - '/v1/device/{uuid}/location'.format(uuid=device_id), - headers=headers + "/v1/device/{uuid}/location".format(uuid=device_id), headers=headers ) -@then('an etag associated with the location should be created') +@then("an etag associated with the location should be created") def validate_etag(context): response = context.get_location_response - new_location_etag = response.headers.get('ETag') + new_location_etag = response.headers.get("ETag") assert_that(new_location_etag, not_none()) assert_that(new_location_etag, is_not(context.expired_location_etag)) -@given('a valid etag from a location entity') +@given("a valid etag from a location entity") def valid_etag(context): etag_manager = context.etag_manager - device_id = context.device_login['uuid'] + device_id = context.device_login["uuid"] context.valid_location_etag = etag_manager.get(device_location_etag_key(device_id)) -@when('try to get the location using a valid etag') +@when("try to get the location using a valid etag") def get_using_valid_etag(context): login = context.device_login - device_id = login['uuid'] - access_token = login['accessToken'] + device_id = login["uuid"] + access_token = login["accessToken"] headers = { - 'Authorization': 'Bearer {token}'.format(token=access_token), - 'If-None-Match': context.valid_location_etag + "Authorization": "Bearer {token}".format(token=access_token), + "If-None-Match": context.valid_location_etag, } context.get_location_response = context.client.get( - '/v1/device/{uuid}/location'.format(uuid=device_id), - headers=headers + "/v1/device/{uuid}/location".format(uuid=device_id), headers=headers ) -@then('the location endpoint should return 304') +@then("the location endpoint should return 304") def validate_response_valid_etag(context): response = context.get_location_response assert_that(response.status_code, equal_to(HTTPStatus.NOT_MODIFIED)) diff --git a/api/public/tests/features/steps/device_refresh_token.py b/api/public/tests/features/steps/device_refresh_token.py index 00da0f61..f4ced9ce 100644 --- a/api/public/tests/features/steps/device_refresh_token.py +++ b/api/public/tests/features/steps/device_refresh_token.py @@ -24,42 +24,42 @@ from behave import when, then from hamcrest import assert_that, equal_to, has_key, is_not -@when('the session token is refreshed') +@when("the session token is refreshed") def refresh_token(context): login = json.loads(context.activate_device_response.data) - refresh = login['refreshToken'] + refresh = login["refreshToken"] context.refresh_token_response = context.client.get( - '/v1/auth/token', - headers={'Authorization': 'Bearer 
{token}'.format(token=refresh)} + "/v1/auth/token", + headers={"Authorization": "Bearer {token}".format(token=refresh)}, ) -@then('a valid new session entity should be returned') +@then("a valid new session entity should be returned") def validate_refresh_token(context): response = context.refresh_token_response assert_that(response.status_code, equal_to(HTTPStatus.OK)) new_login = json.loads(response.data) - assert_that(new_login, has_key(equal_to('uuid'))) - assert_that(new_login, has_key(equal_to('accessToken'))) - assert_that(new_login, has_key(equal_to('refreshToken'))) - assert_that(new_login, has_key(equal_to('expiration'))) + assert_that(new_login, has_key(equal_to("uuid"))) + assert_that(new_login, has_key(equal_to("accessToken"))) + assert_that(new_login, has_key(equal_to("refreshToken"))) + assert_that(new_login, has_key(equal_to("expiration"))) old_login = json.loads(context.activate_device_response.data) - assert_that(new_login['uuid']), equal_to(old_login['uuid']) - assert_that(new_login['accessToken'], is_not(equal_to(old_login['accessToken']))) - assert_that(new_login['refreshToken'], is_not(equal_to(old_login['refreshToken']))) + assert_that(new_login["uuid"]), equal_to(old_login["uuid"]) + assert_that(new_login["accessToken"], is_not(equal_to(old_login["accessToken"]))) + assert_that(new_login["refreshToken"], is_not(equal_to(old_login["refreshToken"]))) -@when('try to refresh an invalid refresh token') +@when("try to refresh an invalid refresh token") def refresh_invalid_token(context): context.refresh_invalid_token_response = context.client.get( - '/v1/auth/token', - headers={'Authorization': 'Bearer {token}'.format(token='123')} + "/v1/auth/token", + headers={"Authorization": "Bearer {token}".format(token="123")}, ) -@then('401 status code should be returned') +@then("401 status code should be returned") def validate_refresh_invalid_token(context): response = context.refresh_invalid_token_response assert_that(response.status_code, equal_to(HTTPStatus.UNAUTHORIZED)) diff --git a/api/public/tests/features/steps/device_skill_settings.py b/api/public/tests/features/steps/device_skill_settings.py index ae7f205f..8fbadc2f 100644 --- a/api/public/tests/features/steps/device_skill_settings.py +++ b/api/public/tests/features/steps/device_skill_settings.py @@ -29,193 +29,177 @@ from selene.testing.skill import add_skill, build_label_field, build_text_field from selene.util.cache import DEVICE_SKILL_ETAG_KEY -@given('skill settings with a new value') +@given("skill settings with a new value") def change_skill_setting_value(context): - _, bar_settings_display = context.skills['bar'] - section = bar_settings_display.display_data['skillMetadata']['sections'][0] - field_with_value = section['fields'][1] - field_with_value['value'] = 'New device text value' + _, bar_settings_display = context.skills["bar"] + section = bar_settings_display.display_data["skillMetadata"]["sections"][0] + field_with_value = section["fields"][1] + field_with_value["value"] = "New device text value" -@given('skill settings with a deleted field') +@given("skill settings with a deleted field") def delete_field_from_settings(context): - _, bar_settings_display = context.skills['bar'] - section = bar_settings_display.display_data['skillMetadata']['sections'][0] - context.removed_field = section['fields'].pop(1) - context.remaining_field = section['fields'][1] + _, bar_settings_display = context.skills["bar"] + section = bar_settings_display.display_data["skillMetadata"]["sections"][0] + context.removed_field = 
section["fields"].pop(1) + context.remaining_field = section["fields"][1] -@given('a valid device skill E-tag') +@given("a valid device skill E-tag") def set_skill_setting_etag(context): context.device_skill_etag = context.etag_manager.get( DEVICE_SKILL_ETAG_KEY.format(device_id=context.device_id) ) -@given('an expired device skill E-tag') +@given("an expired device skill E-tag") def expire_skill_setting_etag(context): valid_device_skill_etag = context.etag_manager.get( DEVICE_SKILL_ETAG_KEY.format(device_id=context.device_id) ) - context.device_skill_etag = context.etag_manager.expire( - valid_device_skill_etag - ) + context.device_skill_etag = context.etag_manager.expire(valid_device_skill_etag) -@given('settings for a skill not assigned to the device') +@given("settings for a skill not assigned to the device") def add_skill_not_assigned_to_device(context): foobar_skill, foobar_settings_display = add_skill( context.db, - skill_global_id='foobar-skill|19.02', - settings_fields=[build_label_field(), build_text_field()] + skill_global_id="foobar-skill|19.02", + settings_fields=[build_label_field(), build_text_field()], ) - section = foobar_settings_display.display_data['skillMetadata']['sections'][0] - field_with_value = section['fields'][1] - field_with_value['value'] = 'New skill text value' + section = foobar_settings_display.display_data["skillMetadata"]["sections"][0] + field_with_value = section["fields"][1] + field_with_value["value"] = "New skill text value" context.skills.update(foobar=(foobar_skill, foobar_settings_display)) -@when('a device requests the settings for its skills') +@when("a device requests the settings for its skills") def get_device_skill_settings(context): - if hasattr(context, 'device_skill_etag'): - context.request_header[ETAG_REQUEST_HEADER_KEY] = ( - context.device_skill_etag - ) + if hasattr(context, "device_skill_etag"): + context.request_header[ETAG_REQUEST_HEADER_KEY] = context.device_skill_etag context.response = context.client.get( - '/v1/device/{device_id}/skill'.format(device_id=context.device_id), - content_type='application/json', - headers=context.request_header + "/v1/device/{device_id}/skill".format(device_id=context.device_id), + content_type="application/json", + headers=context.request_header, ) -@when('the device sends a request to update the {skill} skill settings') +@when("the device sends a request to update the {skill} skill settings") def update_skill_settings(context, skill): _, settings_display = context.skills[skill] context.response = context.client.put( - '/v1/device/{device_id}/skill'.format(device_id=context.device_id), + "/v1/device/{device_id}/skill".format(device_id=context.device_id), data=json.dumps(settings_display.display_data), - content_type='application/json', - headers=context.request_header + content_type="application/json", + headers=context.request_header, ) -@when('the device requests a skill to be deleted') +@when("the device requests a skill to be deleted") def delete_skill(context): - foo_skill, _ = context.skills['foo'] + foo_skill, _ = context.skills["foo"] context.response = context.client.delete( - '/v1/device/{device_id}/skill/{skill_gid}'.format( - device_id=context.device_id, - skill_gid=foo_skill.skill_gid + "/v1/device/{device_id}/skill/{skill_gid}".format( + device_id=context.device_id, skill_gid=foo_skill.skill_gid ), - headers=context.request_header + headers=context.request_header, ) -@then('the settings are returned') +@then("the settings are returned") def validate_response(context): response = 
context.response.json assert_that(len(response), equal_to(2)) - foo_skill, foo_settings_display = context.skills['foo'] + foo_skill, foo_settings_display = context.skills["foo"] foo_skill_expected_result = dict( uuid=foo_skill.id, skill_gid=foo_skill.skill_gid, - identifier=foo_settings_display.display_data['identifier'] + identifier=foo_settings_display.display_data["identifier"], ) assert_that(foo_skill_expected_result, is_in(response)) - bar_skill, bar_settings_display = context.skills['bar'] - section = bar_settings_display.display_data['skillMetadata']['sections'][0] - text_field = section['fields'][1] - text_field['value'] = 'Device text value' - checkbox_field = section['fields'][2] - checkbox_field['value'] = 'false' + bar_skill, bar_settings_display = context.skills["bar"] + section = bar_settings_display.display_data["skillMetadata"]["sections"][0] + text_field = section["fields"][1] + text_field["value"] = "Device text value" + checkbox_field = section["fields"][2] + checkbox_field["value"] = "false" bar_skill_expected_result = dict( uuid=bar_skill.id, skill_gid=bar_skill.skill_gid, - identifier=bar_settings_display.display_data['identifier'], - skillMetadata=bar_settings_display.display_data['skillMetadata'] + identifier=bar_settings_display.display_data["identifier"], + skillMetadata=bar_settings_display.display_data["skillMetadata"], ) assert_that(bar_skill_expected_result, is_in(response)) -@then('the device skill E-tag is expired') +@then("the device skill E-tag is expired") def check_for_expired_etag(context): """An E-tag is expired by changing its value.""" expired_device_skill_etag = context.etag_manager.get( DEVICE_SKILL_ETAG_KEY.format(device_id=context.device_id) ) assert_that( - expired_device_skill_etag.decode(), - is_not(equal_to(context.device_skill_etag)) + expired_device_skill_etag.decode(), is_not(equal_to(context.device_skill_etag)) ) def _get_device_skill_settings(context): """Minimize DB hits and code duplication by getting these values once.""" - if not hasattr(context, 'device_skill_settings'): + if not hasattr(context, "device_skill_settings"): settings_repo = SkillSettingRepository(context.db) - context.device_skill_settings = ( - settings_repo.get_skill_settings_for_device(context.device_id) + context.device_skill_settings = settings_repo.get_skill_settings_for_device( + context.device_id ) context.device_settings_values = [ dss.settings_values for dss in context.device_skill_settings ] -@then('the skill settings are updated with the new value') +@then("the skill settings are updated with the new value") def validate_updated_skill_setting_value(context): _get_device_skill_settings(context) assert_that(len(context.device_skill_settings), equal_to(2)) expected_settings_values = dict( - textfield='New device text value', - checkboxfield='false' - ) - assert_that( - expected_settings_values, - is_in(context.device_settings_values) + textfield="New device text value", checkboxfield="false" ) + assert_that(expected_settings_values, is_in(context.device_settings_values)) -@then('the skill is assigned to the device with the settings populated') +@then("the skill is assigned to the device with the settings populated") def validate_updated_skill_setting_value(context): _get_device_skill_settings(context) assert_that(len(context.device_skill_settings), equal_to(3)) - expected_settings_values = dict(textfield='New skill text value') - assert_that( - expected_settings_values, - is_in(context.device_settings_values) - ) + expected_settings_values = 
dict(textfield="New skill text value") + assert_that(expected_settings_values, is_in(context.device_settings_values)) -@then('an E-tag is generated for these settings') +@then("an E-tag is generated for these settings") def get_skills_etag(context): response_headers = context.response.headers - response_etag = response_headers['ETag'] + response_etag = response_headers["ETag"] skill_etag = context.etag_manager.get( DEVICE_SKILL_ETAG_KEY.format(device_id=context.device_id) ) assert_that(skill_etag.decode(), equal_to(response_etag)) -@then('the field is no longer in the skill settings') +@then("the field is no longer in the skill settings") def validate_skill_setting_field_removed(context): _get_device_skill_settings(context) assert_that(len(context.device_skill_settings), equal_to(2)) # The removed field should no longer be in the settings values but the # value of the field that was not deleted should remain - assert_that( - dict(checkboxfield='false'), - is_in(context.device_settings_values) - ) + assert_that(dict(checkboxfield="false"), is_in(context.device_settings_values)) new_section = dict(fields=None) for device_skill_setting in context.device_skill_settings: - skill_gid = device_skill_setting.settings_display['skill_gid'] - if skill_gid.startswith('bar'): + skill_gid = device_skill_setting.settings_display["skill_gid"] + if skill_gid.startswith("bar"): new_settings_display = device_skill_setting.settings_display - new_skill_definition = new_settings_display['skillMetadata'] - new_section = new_skill_definition['sections'][0] + new_skill_definition = new_settings_display["skillMetadata"] + new_section = new_skill_definition["sections"][0] # The removed field should no longer be in the settings values but the # value of the field that was not deleted should remain - assert_that(context.removed_field, not is_in(new_section['fields'])) - assert_that(context.remaining_field, is_in(new_section['fields'])) + assert_that(context.removed_field, not is_in(new_section["fields"])) + assert_that(context.remaining_field, is_in(new_section["fields"])) diff --git a/api/public/tests/features/steps/get_device.py b/api/public/tests/features/steps/get_device.py index a93989e6..86ad9231 100644 --- a/api/public/tests/features/steps/get_device.py +++ b/api/public/tests/features/steps/get_device.py @@ -27,77 +27,75 @@ from hamcrest import assert_that, equal_to, has_key, not_none, is_not from selene.api.etag import ETagManager, device_etag_key new_fields = dict( - platform='mycroft_mark_1', - coreVersion='19.2.0', - enclosureVersion='1.4.0' + platform="mycroft_mark_1", coreVersion="19.2.0", enclosureVersion="1.4.0" ) -@when('device is retrieved') +@when("device is retrieved") def get_device(context): - access_token = context.device_login['accessToken'] - headers = dict(Authorization='Bearer {token}'.format(token=access_token)) - device_id = context.device_login['uuid'] + access_token = context.device_login["accessToken"] + headers = dict(Authorization="Bearer {token}".format(token=access_token)) + device_id = context.device_login["uuid"] context.get_device_response = context.client.get( - '/v1/device/{uuid}'.format(uuid=device_id), - headers=headers + "/v1/device/{uuid}".format(uuid=device_id), headers=headers ) - context.device_etag = context.get_device_response.headers.get('ETag') + context.device_etag = context.get_device_response.headers.get("ETag") -@then('a valid device should be returned') +@then("a valid device should be returned") def validate_response(context): response = 
context.get_device_response assert_that(response.status_code, equal_to(HTTPStatus.OK)) device = json.loads(response.data) - assert_that(device, has_key('uuid')) - assert_that(device, has_key('name')) - assert_that(device, has_key('description')) - assert_that(device, has_key('coreVersion')) - assert_that(device, has_key('enclosureVersion')) - assert_that(device, has_key('platform')) - assert_that(device, has_key('user')) - assert_that(device['user'], has_key('uuid')) - assert_that(device['user']['uuid'], equal_to(context.account.id)) + assert_that(device, has_key("uuid")) + assert_that(device, has_key("name")) + assert_that(device, has_key("description")) + assert_that(device, has_key("coreVersion")) + assert_that(device, has_key("enclosureVersion")) + assert_that(device, has_key("platform")) + assert_that(device, has_key("user")) + assert_that(device["user"], has_key("uuid")) + assert_that(device["user"]["uuid"], equal_to(context.account.id)) -@when('try to fetch a device without the authorization header') +@when("try to fetch a device without the authorization header") def get_invalid_device(context): - context.get_invalid_device_response = context.client.get('/v1/device/{uuid}'.format(uuid=str(uuid.uuid4()))) - - -@when('try to fetch a not allowed device') -def get_not_allowed_device(context): - access_token = context.device_login['accessToken'] - headers = dict(Authorization='Bearer {token}'.format(token=access_token)) context.get_invalid_device_response = context.client.get( - '/v1/device/{uuid}'.format(uuid=str(uuid.uuid4())), - headers=headers + "/v1/device/{uuid}".format(uuid=str(uuid.uuid4())) ) -@then('a 401 status code should be returned') +@when("try to fetch a not allowed device") +def get_not_allowed_device(context): + access_token = context.device_login["accessToken"] + headers = dict(Authorization="Bearer {token}".format(token=access_token)) + context.get_invalid_device_response = context.client.get( + "/v1/device/{uuid}".format(uuid=str(uuid.uuid4())), headers=headers + ) + + +@then("a 401 status code should be returned") def validate_invalid_response(context): response = context.get_invalid_device_response assert_that(response.status_code, equal_to(HTTPStatus.UNAUTHORIZED)) -@when('the device is updated') +@when("the device is updated") def update_device(context): login = context.device_login - access_token = login['accessToken'] - device_id = login['uuid'] - headers = dict(Authorization='Bearer {token}'.format(token=access_token)) + access_token = login["accessToken"] + device_id = login["uuid"] + headers = dict(Authorization="Bearer {token}".format(token=access_token)) context.update_device_response = context.client.patch( - '/v1/device/{uuid}'.format(uuid=device_id), + "/v1/device/{uuid}".format(uuid=device_id), data=json.dumps(new_fields), - content_type='application_json', - headers=headers + content_type="application_json", + headers=headers, ) -@then('the information should be updated') +@then("the information should be updated") def validate_update(context): response = context.update_device_response assert_that(response.status_code, equal_to(HTTPStatus.OK)) @@ -105,74 +103,72 @@ def validate_update(context): response = context.get_device_response assert_that(response.status_code, equal_to(HTTPStatus.OK)) device = json.loads(response.data) - assert_that(device, has_key('name')) - assert_that(device['coreVersion'], equal_to(new_fields['coreVersion'])) - assert_that(device['enclosureVersion'], equal_to(new_fields['enclosureVersion'])) - 
assert_that(device['platform'], equal_to(new_fields['platform'])) + assert_that(device, has_key("name")) + assert_that(device["coreVersion"], equal_to(new_fields["coreVersion"])) + assert_that(device["enclosureVersion"], equal_to(new_fields["enclosureVersion"])) + assert_that(device["platform"], equal_to(new_fields["platform"])) -@given('a device with a valid etag') +@given("a device with a valid etag") def get_device_etag(context): etag_manager: ETagManager = context.etag_manager - device_id = context.device_login['uuid'] + device_id = context.device_login["uuid"] context.device_etag = etag_manager.get(device_etag_key(device_id)) -@when('try to fetch a device using a valid etag') +@when("try to fetch a device using a valid etag") def get_device_using_etag(context): etag = context.device_etag assert_that(etag, not_none()) - access_token = context.device_login['accessToken'] - device_uuid = context.device_login['uuid'] + access_token = context.device_login["accessToken"] + device_uuid = context.device_login["uuid"] headers = { - 'Authorization': 'Bearer {token}'.format(token=access_token), - 'If-None-Match': etag + "Authorization": "Bearer {token}".format(token=access_token), + "If-None-Match": etag, } context.response_using_etag = context.client.get( - '/v1/device/{uuid}'.format(uuid=device_uuid), - headers=headers + "/v1/device/{uuid}".format(uuid=device_uuid), headers=headers ) -@then('304 status code should be returned by the device endpoint') +@then("304 status code should be returned by the device endpoint") def validate_etag(context): response = context.response_using_etag assert_that(response.status_code, equal_to(HTTPStatus.NOT_MODIFIED)) -@given('a device\'s etag expired by the web ui') +@given("a device's etag expired by the web ui") def expire_etag(context): etag_manager: ETagManager = context.etag_manager - device_id = context.device_login['uuid'] + device_id = context.device_login["uuid"] context.device_etag = etag_manager.get(device_etag_key(device_id)) etag_manager.expire_device_etag_by_device_id(device_id) -@when('try to fetch a device using an expired etag') +@when("try to fetch a device using an expired etag") def fetch_device_expired_etag(context): etag = context.device_etag assert_that(etag, not_none()) - access_token = context.device_login['accessToken'] - device_uuid = context.device_login['uuid'] + access_token = context.device_login["accessToken"] + device_uuid = context.device_login["uuid"] headers = { - 'Authorization': 'Bearer {token}'.format(token=access_token), - 'If-None-Match': etag + "Authorization": "Bearer {token}".format(token=access_token), + "If-None-Match": etag, } context.response_using_invalid_etag = context.client.get( - '/v1/device/{uuid}'.format(uuid=device_uuid), - headers=headers + "/v1/device/{uuid}".format(uuid=device_uuid), headers=headers ) -@then('should return status 200') +@then("should return status 200") def validate_status_code(context): response = context.response_using_invalid_etag assert_that(response.status_code, equal_to(HTTPStatus.OK)) -@then('a new etag') +@then("a new etag") def validate_new_etag(context): etag = context.device_etag response = context.response_using_invalid_etag - etag_from_response = response.headers.get('ETag') + etag_from_response = response.headers.get("ETag") assert_that(etag, is_not(etag_from_response)) diff --git a/api/public/tests/features/steps/get_device_settings.py b/api/public/tests/features/steps/get_device_settings.py index 575314ce..0b3cb90a 100644 --- 
a/api/public/tests/features/steps/get_device_settings.py +++ b/api/public/tests/features/steps/get_device_settings.py @@ -27,115 +27,113 @@ from hamcrest import assert_that, equal_to, has_key, is_not from selene.api.etag import ETagManager, device_setting_etag_key -@when('try to fetch device\'s setting') +@when("try to fetch device's setting") def get_device_settings(context): login = context.device_login - device_id = login['uuid'] - access_token = login['accessToken'] - headers=dict(Authorization='Bearer {token}'.format(token=access_token)) + device_id = login["uuid"] + access_token = login["accessToken"] + headers = dict(Authorization="Bearer {token}".format(token=access_token)) context.response_setting = context.client.get( - '/v1/device/{uuid}/setting'.format(uuid=device_id), - headers=headers + "/v1/device/{uuid}/setting".format(uuid=device_id), headers=headers ) -@then('a valid setting should be returned') +@then("a valid setting should be returned") def validate_response_setting(context): response = context.response_setting assert_that(response.status_code, equal_to(HTTPStatus.OK)) setting = json.loads(response.data) assert_that(response.status_code, equal_to(HTTPStatus.OK)) - assert_that(setting, has_key('uuid')) - assert_that(setting, has_key('systemUnit')) - assert_that(setting['systemUnit'], equal_to('imperial')) - assert_that(setting, has_key('timeFormat')) - assert_that(setting, has_key('dateFormat')) - assert_that(setting, has_key('optIn')) - assert_that(setting['optIn'], equal_to(True)) - assert_that(setting, has_key('ttsSettings')) - tts = setting['ttsSettings'] - assert_that(tts, has_key('module')) + assert_that(setting, has_key("uuid")) + assert_that(setting, has_key("systemUnit")) + assert_that(setting["systemUnit"], equal_to("imperial")) + assert_that(setting, has_key("timeFormat")) + assert_that(setting, has_key("dateFormat")) + assert_that(setting, has_key("optIn")) + assert_that(setting["optIn"], equal_to(True)) + assert_that(setting, has_key("ttsSettings")) + tts = setting["ttsSettings"] + assert_that(tts, has_key("module")) -@when('the settings endpoint is a called to a not allowed device') +@when("the settings endpoint is a called to a not allowed device") def get_device_settings(context): - access_token = context.device_login['accessToken'] - headers = dict(Authorization='Bearer {token}'.format(token=access_token)) + access_token = context.device_login["accessToken"] + headers = dict(Authorization="Bearer {token}".format(token=access_token)) context.get_invalid_setting_response = context.client.get( - '/v1/device/{uuid}/setting'.format(uuid=str(uuid.uuid4())), - headers=headers + "/v1/device/{uuid}/setting".format(uuid=str(uuid.uuid4())), headers=headers ) -@then('a 401 status code should be returned for the setting') +@then("a 401 status code should be returned for the setting") def validate_response(context): response = context.get_invalid_setting_response assert_that(response.status_code, equal_to(HTTPStatus.UNAUTHORIZED)) -@given('a device\'s setting with a valid etag') +@given("a device's setting with a valid etag") def get_device_setting_etag(context): - device_id = context.device_login['uuid'] + device_id = context.device_login["uuid"] etag_manager: ETagManager = context.etag_manager context.device_etag = etag_manager.get(device_setting_etag_key(device_id)) -@when('try to fetch the device\'s settings using a valid etag') +@when("try to fetch the device's settings using a valid etag") def get_device_settings_using_etag(context): etag = context.device_etag 
- access_token = context.device_login['accessToken'] - device_id = context.device_login['uuid'] + access_token = context.device_login["accessToken"] + device_id = context.device_login["uuid"] headers = { - 'Authorization': 'Bearer {token}'.format(token=access_token), - 'If-None-Match': etag + "Authorization": "Bearer {token}".format(token=access_token), + "If-None-Match": etag, } context.get_setting_etag_response = context.client.get( - '/v1/device/{uuid}/setting'.format(uuid=str(device_id)), - headers=headers + "/v1/device/{uuid}/setting".format(uuid=str(device_id)), headers=headers ) -@then('304 status code should be returned by the device\'s settings endpoint') +@then("304 status code should be returned by the device's settings endpoint") def validate_etag_response(context): response = context.get_setting_etag_response assert_that(response.status_code, equal_to(HTTPStatus.NOT_MODIFIED)) -@given('a device\'s setting etag expired by the web ui at device level') +@given("a device's setting etag expired by the web ui at device level") def expire_etag_device_level(context): - device_id = context.device_login['uuid'] + device_id = context.device_login["uuid"] etag_manager: ETagManager = context.etag_manager context.device_etag = etag_manager.get(device_setting_etag_key(device_id)) etag_manager.expire_device_setting_etag_by_device_id(device_id) -@given('a device\'s setting etag expired by the web ui at account level') +@given("a device's setting etag expired by the web ui at account level") def expire_etag_account_level(context): account_id = context.account.id - device_id = context.device_login['uuid'] + device_id = context.device_login["uuid"] etag_manager: ETagManager = context.etag_manager context.device_etag = etag_manager.get(device_setting_etag_key(device_id)) etag_manager.expire_device_setting_etag_by_account_id(account_id) -@when('try to fetch the device\'s settings using an expired etag') +@when("try to fetch the device's settings using an expired etag") def get_device_settings_using_etag(context): etag = context.device_etag - access_token = context.device_login['accessToken'] - device_id = context.device_login['uuid'] + access_token = context.device_login["accessToken"] + device_id = context.device_login["uuid"] headers = { - 'Authorization': 'Bearer {token}'.format(token=access_token), - 'If-None-Match': etag + "Authorization": "Bearer {token}".format(token=access_token), + "If-None-Match": etag, } context.get_setting_invalid_etag_response = context.client.get( - '/v1/device/{uuid}/setting'.format(uuid=str(device_id)), - headers=headers + "/v1/device/{uuid}/setting".format(uuid=str(device_id)), headers=headers ) -@then('200 status code should be returned by the device\'s setting endpoint and a new etag') +@then( + "200 status code should be returned by the device's setting endpoint and a new etag" +) def validate_new_etag(context): etag = context.device_etag response = context.get_setting_invalid_etag_response - etag_from_response = response.headers.get('ETag') + etag_from_response = response.headers.get("ETag") assert_that(etag, is_not(etag_from_response)) diff --git a/api/public/tests/features/steps/get_device_subscription.py b/api/public/tests/features/steps/get_device_subscription.py index 76c4fed2..db9c1cc1 100644 --- a/api/public/tests/features/steps/get_device_subscription.py +++ b/api/public/tests/features/steps/get_device_subscription.py @@ -29,66 +29,63 @@ from selene.data.account import AccountRepository, AccountMembership from selene.util.db import connect_to_db 
-@when('the subscription endpoint is called') +@when("the subscription endpoint is called") def get_device_subscription(context): login = context.device_login - device_id = login['uuid'] - access_token = login['accessToken'] - headers = dict(Authorization='Bearer {token}'.format(token=access_token)) + device_id = login["uuid"] + access_token = login["accessToken"] + headers = dict(Authorization="Bearer {token}".format(token=access_token)) context.subscription_response = context.client.get( - '/v1/device/{uuid}/subscription'.format(uuid=device_id), - headers=headers + "/v1/device/{uuid}/subscription".format(uuid=device_id), headers=headers ) -@then('free type should be returned') +@then("free type should be returned") def validate_response(context): response = context.subscription_response assert_that(response.status_code, HTTPStatus.OK) subscription = json.loads(response.data) - assert_that(subscription, has_entry('@type', 'free')) + assert_that(subscription, has_entry("@type", "free")) -@when('the subscription endpoint is called for a monthly account') +@when("the subscription endpoint is called for a monthly account") def get_device_subscription(context): membership = AccountMembership( start_date=date.today(), - type='Monthly Membership', - payment_method='Stripe', - payment_account_id='test_monthly', - payment_id='stripe_id' + type="Monthly Membership", + payment_method="Stripe", + payment_account_id="test_monthly", + payment_id="stripe_id", ) login = context.device_login - device_id = login['uuid'] - access_token = login['accessToken'] - headers = dict(Authorization='Bearer {token}'.format(token=access_token)) - db = connect_to_db(context.client_config['DB_CONNECTION_CONFIG']) + device_id = login["uuid"] + access_token = login["accessToken"] + headers = dict(Authorization="Bearer {token}".format(token=access_token)) + db = connect_to_db(context.client_config["DB_CONNECTION_CONFIG"]) AccountRepository(db).add_membership(context.account.id, membership) context.subscription_response = context.client.get( - '/v1/device/{uuid}/subscription'.format(uuid=device_id), - headers=headers + "/v1/device/{uuid}/subscription".format(uuid=device_id), headers=headers ) -@then('monthly type should be returned') +@then("monthly type should be returned") def validate_response_monthly(context): response = context.subscription_response assert_that(response.status_code, HTTPStatus.OK) subscription = json.loads(response.data) - assert_that(subscription, has_entry('@type', 'Monthly Membership')) + assert_that(subscription, has_entry("@type", "Monthly Membership")) -@when('try to get the subscription for a nonexistent device') +@when("try to get the subscription for a nonexistent device") def get_subscription_nonexistent_device(context): - access_token = context.device_login['accessToken'] - headers = dict(Authorization='Bearer {token}'.format(token=access_token)) + access_token = context.device_login["accessToken"] + headers = dict(Authorization="Bearer {token}".format(token=access_token)) context.invalid_subscription_response = context.client.get( - '/v1/device/{uuid}/subscription'.format(uuid=str(uuid.uuid4())), - headers=headers + "/v1/device/{uuid}/subscription".format(uuid=str(uuid.uuid4())), headers=headers ) -@then('401 status code should be returned for the subscription endpoint') +@then("401 status code should be returned for the subscription endpoint") def validate_nonexistent_device(context): response = context.invalid_subscription_response assert_that(response.status_code, 
equal_to(HTTPStatus.UNAUTHORIZED)) diff --git a/api/sso/sso_api/__init__.py b/api/sso/sso_api/__init__.py index 7665c10b..eabab81b 100644 --- a/api/sso/sso_api/__init__.py +++ b/api/sso/sso_api/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/api/sso/sso_api/endpoints/authenticate_internal.py b/api/sso/sso_api/endpoints/authenticate_internal.py index 7f21f13f..657e8c81 100644 --- a/api/sso/sso_api/endpoints/authenticate_internal.py +++ b/api/sso/sso_api/endpoints/authenticate_internal.py @@ -34,6 +34,7 @@ from selene.util.auth import AuthenticationError class AuthenticateInternalEndpoint(SeleneEndpoint): """Sign in a user with an email address and password.""" + def __init__(self): super(AuthenticateInternalEndpoint, self).__init__() self.account: Account = None @@ -44,22 +45,21 @@ class AuthenticateInternalEndpoint(SeleneEndpoint): self._generate_tokens() self._set_token_cookies() - return '', HTTPStatus.NO_CONTENT + return "", HTTPStatus.NO_CONTENT def _authenticate_credentials(self): """Compare credentials in request to credentials in database. :raises AuthenticationError when no match found on database """ - basic_credentials = self.request.headers['authorization'] + basic_credentials = self.request.headers["authorization"] binary_credentials = a2b_base64(basic_credentials[6:]) - email_address, password = binary_credentials.decode().split(':||:') + email_address, password = binary_credentials.decode().split(":||:") acct_repository = AccountRepository(self.db) self.account = acct_repository.get_account_from_credentials( - email_address, - password + email_address, password ) if self.account is None: - raise AuthenticationError('provided credentials not found') + raise AuthenticationError("provided credentials not found") self.access_token.account_id = self.account.id self.refresh_token.account_id = self.account.id diff --git a/api/sso/sso_api/endpoints/github_token.py b/api/sso/sso_api/endpoints/github_token.py index f72e4685..8fd39662 100644 --- a/api/sso/sso_api/endpoints/github_token.py +++ b/api/sso/sso_api/endpoints/github_token.py @@ -26,8 +26,7 @@ from selene.util.auth import get_github_authentication_token class GithubTokenEndpoint(SeleneEndpoint): def get(self): token = get_github_authentication_token( - self.request.args['code'], - self.request.args['state'] + self.request.args["code"], self.request.args["state"] ) return dict(token=token), HTTPStatus.OK diff --git a/api/sso/sso_api/endpoints/password_change.py b/api/sso/sso_api/endpoints/password_change.py index 78710a90..78d1d6ad 100644 --- a/api/sso/sso_api/endpoints/password_change.py +++ b/api/sso/sso_api/endpoints/password_change.py @@ -26,11 +26,11 @@ from selene.data.account import AccountRepository class PasswordChangeEndpoint(SeleneEndpoint): def put(self): - account_id = self.request.json['accountId'] - coded_password = self.request.json['password'] + account_id = self.request.json["accountId"] + coded_password = self.request.json["password"] binary_password = a2b_base64(coded_password) password = binary_password.decode() acct_repository = AccountRepository(self.db) acct_repository.change_password(account_id, password) - return '', HTTPStatus.NO_CONTENT + return "", HTTPStatus.NO_CONTENT diff --git a/api/sso/sso_api/endpoints/password_reset.py b/api/sso/sso_api/endpoints/password_reset.py index 984e8694..45bb6e8b 100644 --- a/api/sso/sso_api/endpoints/password_reset.py +++ 
b/api/sso/sso_api/endpoints/password_reset.py @@ -37,44 +37,40 @@ class PasswordResetEndpoint(SeleneEndpoint): reset_token = self._generate_reset_token() self._send_reset_email(reset_token) - return '', HTTPStatus.OK + return "", HTTPStatus.OK def _get_account_from_email(self): acct_repository = AccountRepository(self.db) self.account = acct_repository.get_account_by_email( - self.request.json['emailAddress'] + self.request.json["emailAddress"] ) def _generate_reset_token(self): - reset_token = AuthenticationToken( - self.config['RESET_SECRET'], - ONE_HOUR - ) + reset_token = AuthenticationToken(self.config["RESET_SECRET"], ONE_HOUR) reset_token.generate(self.account.id) return reset_token.jwt def _send_reset_email(self, reset_token): - url = '{base_url}/change-password?token={reset_token}'.format( - base_url=os.environ['SSO_BASE_URL'], - reset_token=reset_token + url = "{base_url}/change-password?token={reset_token}".format( + base_url=os.environ["SSO_BASE_URL"], reset_token=reset_token ) email = EmailMessage( - recipient=self.request.json['emailAddress'], - sender='Mycroft AI', - subject='Password Reset Request', - template_file_name='reset_password.html', - template_variables=dict(reset_password_url=url) + recipient=self.request.json["emailAddress"], + sender="Mycroft AI", + subject="Password Reset Request", + template_file_name="reset_password.html", + template_variables=dict(reset_password_url=url), ) mailer = SeleneMailer(email) mailer.send() def _send_account_not_found_email(self): email = EmailMessage( - recipient=self.request.json['emailAddress'], - sender='Mycroft AI', - subject='Password Reset Request', - template_file_name='account_not_found.html' + recipient=self.request.json["emailAddress"], + sender="Mycroft AI", + subject="Password Reset Request", + template_file_name="account_not_found.html", ) mailer = SeleneMailer(email) mailer.send() diff --git a/api/sso/sso_api/endpoints/validate_email.py b/api/sso/sso_api/endpoints/validate_email.py index a2271c0b..a723d444 100644 --- a/api/sso/sso_api/endpoints/validate_email.py +++ b/api/sso/sso_api/endpoints/validate_email.py @@ -25,16 +25,16 @@ from selene.data.account import AccountRepository from selene.util.auth import ( get_facebook_account_email, get_github_account_email, - get_google_account_email + get_google_account_email, ) class ValidateEmailEndpoint(SeleneEndpoint): def get(self): return_data = dict(accountExists=False, noFederatedEmail=False) - if self.request.args['token']: + if self.request.args["token"]: email_address = self._get_email_address() - if self.request.args['platform'] != 'Internal' and not email_address: + if self.request.args["platform"] != "Internal" and not email_address: return_data.update(noFederatedEmail=True) account_repository = AccountRepository(self.db) account = account_repository.get_account_by_email(email_address) @@ -44,20 +44,14 @@ class ValidateEmailEndpoint(SeleneEndpoint): return return_data, HTTPStatus.OK def _get_email_address(self): - if self.request.args['platform'] == 'Google': - email_address = get_google_account_email( - self.request.args['token'] - ) - elif self.request.args['platform'] == 'Facebook': - email_address = get_facebook_account_email( - self.request.args['token'] - ) - elif self.request.args['platform'] == 'GitHub': - email_address = get_github_account_email( - self.request.args['token'] - ) + if self.request.args["platform"] == "Google": + email_address = get_google_account_email(self.request.args["token"]) + elif self.request.args["platform"] == "Facebook": + 
email_address = get_facebook_account_email(self.request.args["token"]) + elif self.request.args["platform"] == "GitHub": + email_address = get_github_account_email(self.request.args["token"]) else: - coded_email = self.request.args['token'] + coded_email = self.request.args["token"] email_address = a2b_base64(coded_email).decode() return email_address diff --git a/api/sso/sso_api/endpoints/validate_token.py b/api/sso/sso_api/endpoints/validate_token.py index 4330d0ee..0c588a51 100644 --- a/api/sso/sso_api/endpoints/validate_token.py +++ b/api/sso/sso_api/endpoints/validate_token.py @@ -28,15 +28,12 @@ class ValidateTokenEndpoint(SeleneEndpoint): return response_data, HTTPStatus.OK def _validate_token(self): - auth_token = AuthenticationToken( - self.config['RESET_SECRET'], - duration=0 - ) - auth_token.jwt = self.request.json['token'] + auth_token = AuthenticationToken(self.config["RESET_SECRET"], duration=0) + auth_token.jwt = self.request.json["token"] auth_token.validate() return dict( account_id=auth_token.account_id, token_expired=auth_token.is_expired, - token_invalid=not auth_token.is_valid + token_invalid=not auth_token.is_valid, ) diff --git a/batch/job_scheduler/__init__.py b/batch/job_scheduler/__init__.py index 7665c10b..eabab81b 100644 --- a/batch/job_scheduler/__init__.py +++ b/batch/job_scheduler/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/batch/job_scheduler/jobs.py b/batch/job_scheduler/jobs.py index 4a8003bb..7c25154f 100644 --- a/batch/job_scheduler/jobs.py +++ b/batch/job_scheduler/jobs.py @@ -30,13 +30,14 @@ from datetime import date, timedelta import schedule -from selene.util.log import configure_logger +from selene.util.log import configure_selene_logger -_log = configure_logger('selene_job_scheduler') +_log = configure_selene_logger("job_scheduler") class JobRunner(object): """Build the command to run a batch job and run it via subprocess.""" + def __init__(self, script_name: str): self.script_name = script_name self.job_args: str = None @@ -55,17 +56,14 @@ class JobRunner(object): date argument only needs to be specified when it is not current date. 
""" if self.job_args is None: - self.job_args = '' - date_arg = ' --date ' + str(self.job_date) + self.job_args = "" + date_arg = " --date " + str(self.job_date) self.job_args += date_arg def _build_command(self): """Build the command to run the script.""" - command = ['pipenv', 'run', 'python'] - script_path = os.path.join( - os.environ['SELENE_SCRIPT_DIR'], - self.script_name - ) + command = ["pipenv", "run", "python"] + script_path = os.path.join(os.environ["SELENE_SCRIPT_DIR"], self.script_name) command.append(script_path) if self.job_args is not None: command.extend(self.job_args.split()) @@ -78,31 +76,31 @@ class JobRunner(object): result = subprocess.run(command, capture_output=True) if result.returncode: _log.error( - 'Job {job_name} failed\n' - '\tSTDOUT - {stdout}' - '\tSTDERR - {stderr}'.format( + "Job {job_name} failed\n" + "\tSTDOUT - {stdout}" + "\tSTDERR - {stderr}".format( job_name=self.script_name[:-3], stdout=result.stdout.decode(), - stderr=result.stderr.decode() + stderr=result.stderr.decode(), ) ) else: - log_msg = 'Job {job_name} completed successfully' + log_msg = "Job {job_name} completed successfully" _log.info(log_msg.format(job_name=self.script_name[:-3])) def test_scheduler(): """Run in non-production environments to test scheduler functionality.""" - job_runner = JobRunner('test_scheduler.py') + job_runner = JobRunner("test_scheduler.py") job_runner.job_date = date.today() - timedelta(days=1) - job_runner.job_args = '--arg-with-value test --arg-no-value' + job_runner.job_args = "--arg-with-value test --arg-no-value" job_runner.run_job() def load_skills(version): """Load the json file from the mycroft-skills-data repository to the DB""" - job_runner = JobRunner('load_skill_display_data.py') - job_runner.job_args = '--core-version {}'.format(version) + job_runner = JobRunner("load_skill_display_data.py") + job_runner.job_args = "--core-version {}".format(version) job_runner.job_date = date.today() - timedelta(days=1) job_runner.run_job() @@ -112,7 +110,7 @@ def parse_core_metrics(): Build a de-normalized table that will make latency research easier. """ - job_runner = JobRunner('parse_core_metrics.py') + job_runner = JobRunner("parse_core_metrics.py") job_runner.job_date = date.today() - timedelta(days=1) job_runner.run_job() @@ -123,7 +121,7 @@ def partition_api_metrics(): Build a partition on the metric.api_history table for yesterday's date. Copy yesterday's metric.api table rows to the partition. """ - job_runner = JobRunner('partition_api_metrics.py') + job_runner = JobRunner("partition_api_metrics.py") job_runner.job_date = date.today() - timedelta(days=1) job_runner.run_job() @@ -135,22 +133,22 @@ def update_device_last_contact(): to associate the time of the call with the device. Dump the contents of the Redis data to the device.device table on the Postgres database. 
""" - job_runner = JobRunner('update_device_last_contact.py') + job_runner = JobRunner("update_device_last_contact.py") job_runner.run_job() # Define the schedule -if os.environ['SELENE_ENVIRONMENT'] != 'prod': +if os.environ["SELENE_ENVIRONMENT"] != "prod": schedule.every(5).minutes.do(test_scheduler) -schedule.every().day.at('00:00').do(partition_api_metrics) -schedule.every().day.at('00:05').do(update_device_last_contact) -schedule.every().day.at('00:10').do(parse_core_metrics) -schedule.every().day.at('00:15').do(load_skills, version='19.02') -schedule.every().day.at('00:20').do(load_skills, version='19.08') -schedule.every().day.at('00:25').do(load_skills, version='20.02') -schedule.every().day.at('00:25').do(load_skills, version='20.08') -schedule.every().day.at('00:25').do(load_skills, version='21.02') +schedule.every().day.at("00:00").do(partition_api_metrics) +schedule.every().day.at("00:05").do(update_device_last_contact) +schedule.every().day.at("00:10").do(parse_core_metrics) +schedule.every().day.at("00:15").do(load_skills, version="19.02") +schedule.every().day.at("00:20").do(load_skills, version="19.08") +schedule.every().day.at("00:25").do(load_skills, version="20.02") +schedule.every().day.at("00:25").do(load_skills, version="20.08") +schedule.every().day.at("00:25").do(load_skills, version="21.02") # Run the schedule while True: diff --git a/batch/script/__init__.py b/batch/script/__init__.py index 7665c10b..eabab81b 100644 --- a/batch/script/__init__.py +++ b/batch/script/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/batch/script/daily_report.py b/batch/script/daily_report.py index f0f386c6..4f16c611 100644 --- a/batch/script/daily_report.py +++ b/batch/script/daily_report.py @@ -30,12 +30,12 @@ from selene.util.db import DatabaseConnectionConfig from selene.util.email import EmailMessage, SeleneMailer mycroft_db = DatabaseConnectionConfig( - host=environ['DB_HOST'], - db_name=environ['DB_NAME'], - user=environ['DB_USER'], - password=environ['DB_PASSWORD'], - port=environ['DB_PORT'], - sslmode=environ['DB_SSL_MODE'] + host=environ["DB_HOST"], + db_name=environ["DB_NAME"], + user=environ["DB_USER"], + password=environ["DB_PASSWORD"], + port=int(environ["DB_PORT"]), + sslmode=environ["DB_SSL_MODE"], ) @@ -43,16 +43,16 @@ class DailyReport(SeleneScript): def __init__(self): super(DailyReport, self).__init__(__file__) self._arg_parser.add_argument( - '--run-mode', - help='If the script should run as a job or just once', - choices=['job', 'once'], + "--run-mode", + help="If the script should run as a job or just once", + choices=["job", "once"], type=str, - default='job' + default="job", ) def _run(self): - if self.args.run_mode == 'job': - schedule.every().day.at('00:00').do(self._build_report) + if self.args.run_mode == "job": + schedule.every().day.at("00:00").do(self._build_report) while True: schedule.run_pending() time.sleep(1) @@ -65,11 +65,11 @@ class DailyReport(SeleneScript): user_metrics = AccountRepository(self.db).daily_report(date) email = EmailMessage( - sender='reports@mycroft.ai', - recipient=os.environ['REPORT_RECIPIENT'], - subject='Mycroft Daily Report - {}'.format(date.strftime('%Y-%m-%d')), - template_file_name='metrics.html', - template_variables=dict(user_metrics=user_metrics) + sender="reports@mycroft.ai", + recipient=os.environ["REPORT_RECIPIENT"], + subject="Mycroft Daily Report - {}".format(date.strftime("%Y-%m-%d")), + 
template_file_name="metrics.html", + template_variables=dict(user_metrics=user_metrics), ) mailer = SeleneMailer(email) diff --git a/batch/script/load_skill_display_data.py b/batch/script/load_skill_display_data.py index 6564c348..a1a364a8 100644 --- a/batch/script/load_skill_display_data.py +++ b/batch/script/load_skill_display_data.py @@ -30,17 +30,13 @@ import json from os import environ from selene.batch import SeleneScript -from selene.data.skill import ( - SkillDisplay, - SkillDisplayRepository, - SkillRepository -) +from selene.data.skill import SkillDisplay, SkillDisplayRepository, SkillRepository from selene.util.github import download_repository_file, log_into_github -GITHUB_USER = environ['GITHUB_USER'] -GITHUB_PASSWORD = environ['GITHUB_PASSWORD'] -SKILL_DATA_GITHUB_REPO = 'mycroft-skills-data' -SKILL_DATA_FILE_NAME = 'skill-metadata.json' +GITHUB_USER = environ["GITHUB_USER"] +GITHUB_PASSWORD = environ["GITHUB_PASSWORD"] +SKILL_DATA_GITHUB_REPO = "mycroft-skills-data" +SKILL_DATA_FILE_NAME = "skill-metadata.json" class SkillDisplayUpdater(SeleneScript): @@ -52,16 +48,15 @@ class SkillDisplayUpdater(SeleneScript): super(SkillDisplayUpdater, self)._define_args() self._arg_parser.add_argument( "--core-version", - help='Version of Mycroft Core related to skill display data', + help="Version of Mycroft Core related to skill display data", required=True, - type=str + type=str, ) def _run(self): """Make it so.""" self.log.info( - "Updating skill display data for core version " + - self.args.core_version + "Updating skill display data for core version " + self.args.core_version ) self._get_skill_display_data() self._update_skill_display_table() @@ -73,7 +68,7 @@ class SkillDisplayUpdater(SeleneScript): github_api, SKILL_DATA_GITHUB_REPO, self.args.core_version, - SKILL_DATA_FILE_NAME + SKILL_DATA_FILE_NAME, ) self.skill_display_data = json.loads(file_contents) @@ -83,20 +78,18 @@ class SkillDisplayUpdater(SeleneScript): display_repository = SkillDisplayRepository(self.db) for skill_name, skill_metadata in self.skill_display_data.items(): skill_count += 1 - skill_id = skill_repository.ensure_skill_exists( - skill_metadata['skill_gid'] - ) + skill_id = skill_repository.ensure_skill_exists(skill_metadata["skill_gid"]) # add the skill display row display_data = SkillDisplay( skill_id=skill_id, core_version=self.args.core_version, - display_data=json.dumps(skill_metadata) + display_data=json.dumps(skill_metadata), ) display_repository.upsert(display_data) self.log.info("updated {} skills".format(skill_count)) -if __name__ == '__main__': +if __name__ == "__main__": SkillDisplayUpdater().run() diff --git a/batch/script/parse_core_metrics.py b/batch/script/parse_core_metrics.py index 1d48200c..2669aa34 100644 --- a/batch/script/parse_core_metrics.py +++ b/batch/script/parse_core_metrics.py @@ -32,7 +32,7 @@ from decimal import Decimal from selene.batch import SeleneScript from selene.data.metric import CoreMetricRepository, CoreInteraction -SKILL_HANDLERS_TO_SKIP = ('reset', 'notify', 'prime', 'stop_laugh') +SKILL_HANDLERS_TO_SKIP = ("reset", "notify", "prime", "stop_laugh") class CoreMetricsParser(SeleneScript): @@ -47,7 +47,7 @@ class CoreMetricsParser(SeleneScript): def _run(self): last_interaction_id = None for metric in self.core_metric_repo.get_metrics_by_date(self.args.date): - if metric.metric_value['id'] != last_interaction_id: + if metric.metric_value["id"] != last_interaction_id: self._add_interaction_to_db() self._start_new_interaction(metric) last_interaction_id = 
self.interaction.core_id @@ -57,48 +57,44 @@ class CoreMetricsParser(SeleneScript): def _start_new_interaction(self, metric): """Initialize the interaction object""" self.interaction = CoreInteraction( - core_id=metric.metric_value['id'], + core_id=metric.metric_value["id"], device_id=metric.device_id, - start_ts=datetime.utcfromtimestamp( - metric.metric_value['start_time'] - ), + start_ts=datetime.utcfromtimestamp(metric.metric_value["start_time"]), ) self.stt_start_ts = None self.playback_start_ts = None def _add_metric_to_interaction(self, metric_value): """Combine all the steps of an interaction into a single record""" - duration = Decimal(str(metric_value['time'])) - duration = duration.quantize(Decimal('0.000001')) - if metric_value['system'] == 'stt': - self.interaction.stt_engine = metric_value['stt'] - self.interaction.stt_transcription = metric_value['transcription'] + duration = Decimal(str(metric_value["time"])) + duration = duration.quantize(Decimal("0.000001")) + if metric_value["system"] == "stt": + self.interaction.stt_engine = metric_value["stt"] + self.interaction.stt_transcription = metric_value["transcription"] self.interaction.stt_duration = duration - self.stt_start_ts = metric_value['start_time'] - elif metric_value['system'] == 'intent_service': - self.interaction.intent_type = metric_value['intent_type'] + self.stt_start_ts = metric_value["start_time"] + elif metric_value["system"] == "intent_service": + self.interaction.intent_type = metric_value["intent_type"] self.interaction.intent_duration = duration - elif metric_value['system'] == 'fallback_handler': + elif metric_value["system"] == "fallback_handler": self.interaction.fallback_handler_duration = duration - elif metric_value['system'] == 'skill_handler': - if metric_value['handler'] not in SKILL_HANDLERS_TO_SKIP: - self.interaction.skill_handler = metric_value['handler'] + elif metric_value["system"] == "skill_handler": + if metric_value["handler"] not in SKILL_HANDLERS_TO_SKIP: + self.interaction.skill_handler = metric_value["handler"] self.interaction.skill_duration = duration - elif metric_value['system'] == 'speech': - self.interaction.tts_engine = metric_value['tts'] - self.interaction.tts_utterance = metric_value['utterance'] + elif metric_value["system"] == "speech": + self.interaction.tts_engine = metric_value["tts"] + self.interaction.tts_utterance = metric_value["utterance"] self.interaction.tts_duration = duration - elif metric_value['system'] == 'speech_playback': + elif metric_value["system"] == "speech_playback": self.interaction.speech_playback_duration = duration - self.playback_start_ts = metric_value['start_time'] + self.playback_start_ts = metric_value["start_time"] # The user-experienced latency is the time between when the user # finishes speaking their intent and when the device provides a voice # response. 
if self.stt_start_ts is not None and self.playback_start_ts is not None: - self.interaction.user_latency = ( - self.playback_start_ts - self.stt_start_ts - ) + self.interaction.user_latency = self.playback_start_ts - self.stt_start_ts def _add_interaction_to_db(self): if self.interaction is not None: @@ -107,5 +103,5 @@ class CoreMetricsParser(SeleneScript): self.core_metric_repo.add_interaction(self.interaction) -if __name__ == '__main__': +if __name__ == "__main__": CoreMetricsParser().run() diff --git a/batch/script/partition_api_metrics.py b/batch/script/partition_api_metrics.py index b375d97e..28558d36 100644 --- a/batch/script/partition_api_metrics.py +++ b/batch/script/partition_api_metrics.py @@ -40,5 +40,5 @@ class PartitionApiMetrics(SeleneScript): api_metrics_repo.remove_by_date(self.args.date) -if __name__ == '__main__': +if __name__ == "__main__": PartitionApiMetrics().run() diff --git a/batch/script/test_scheduler.py b/batch/script/test_scheduler.py index 481c89ca..1b1cda05 100644 --- a/batch/script/test_scheduler.py +++ b/batch/script/test_scheduler.py @@ -42,24 +42,24 @@ class TestScheduler(SeleneScript): super(TestScheduler, self)._define_args() self._arg_parser.add_argument( "--arg-with-value", - help='Argument to test passing a value with an argument', + help="Argument to test passing a value with an argument", required=True, - type=str + type=str, ) self._arg_parser.add_argument( "--arg-no-value", - help='Argument to test passing a value with an argument', - action="store_true" + help="Argument to test passing a value with an argument", + action="store_true", ) def _run(self): - self.log.info('Running the scheduler test job') + self.log.info("Running the scheduler test job") assert self.args.arg_no_value - assert self.args.arg_with_value == 'test' + assert self.args.arg_with_value == "test" # Tests the logic that overrides the default date in the scheduler. 
assert self.args.date == date.today() - timedelta(days=1) -if __name__ == '__main__': +if __name__ == "__main__": TestScheduler().run() diff --git a/batch/script/update_device_last_contact.py b/batch/script/update_device_last_contact.py index f061a2d4..c88d45ce 100644 --- a/batch/script/update_device_last_contact.py +++ b/batch/script/update_device_last_contact.py @@ -47,21 +47,18 @@ class UpdateDeviceLastContact(SeleneScript): devices_updated += 1 device_repo.update_last_contact_ts(device.id, last_contact_ts) - self.log.info(str(devices_updated) + ' devices were active today') + self.log.info(str(devices_updated) + " devices were active today") def _get_ts_from_cache(self, device_id): last_contact_ts = None cache_key = DEVICE_LAST_CONTACT_KEY.format(device_id=device_id) value = self.cache.get(cache_key) if value is not None: - last_contact_ts = datetime.strptime( - value.decode(), - '%Y-%m-%d %H:%M:%S.%f' - ) + last_contact_ts = datetime.strptime(value.decode(), "%Y-%m-%d %H:%M:%S.%f") self.cache.delete(cache_key) return last_contact_ts -if __name__ == '__main__': +if __name__ == "__main__": UpdateDeviceLastContact().run() diff --git a/db/scripts/bootstrap_mycroft_db.py b/db/scripts/bootstrap_mycroft_db.py index e641d811..f840f113 100644 --- a/db/scripts/bootstrap_mycroft_db.py +++ b/db/scripts/bootstrap_mycroft_db.py @@ -108,7 +108,7 @@ class PostgresDB(object): sslmode=db_ssl_mode, ) self.db.autocommit = True - self.db.set_client_encoding('UTF8') + self.db.set_client_encoding("UTF8") def close_db(self): self.db.close() diff --git a/db/scripts/neo4j-postgres.py b/db/scripts/neo4j-postgres.py index 09efbe41..22391f95 100644 --- a/db/scripts/neo4j-postgres.py +++ b/db/scripts/neo4j-postgres.py @@ -54,216 +54,215 @@ ezra = str(uuid.uuid4()) jarvis = str(uuid.uuid4()) default_wake_words = { - 'hey mycroft': hey_mycroft, - 'christopher': christopher, - 'hey ezra': ezra, - 'hey jarvis': jarvis + "hey mycroft": hey_mycroft, + "christopher": christopher, + "hey ezra": ezra, + "hey jarvis": jarvis, } def load_csv(): - with open('users.csv') as user_csv: + with open("users.csv") as user_csv: user_reader = csv.reader(user_csv) next(user_reader, None) for row in user_reader: # email, password users[row[0]] = {} - users[row[0]]['email'] = row[1] - users[row[0]]['password'] = row[2] - users[row[0]]['terms'] = row[3] - users[row[0]]['privacy'] = row[4] + users[row[0]]["email"] = row[1] + users[row[0]]["password"] = row[2] + users[row[0]]["terms"] = row[3] + users[row[0]]["privacy"] = row[4] - with open('user_settings.csv') as user_setting_csv: + with open("user_settings.csv") as user_setting_csv: user_setting_reader = csv.reader(user_setting_csv) next(user_setting_reader, None) for row in user_setting_reader: user_settings[row[0]] = {} - user_settings[row[0]]['date_format'] = row[1] - user_settings[row[0]]['time_format'] = row[2] - user_settings[row[0]]['measurement_system'] = row[3] - user_settings[row[0]]['tts_type'] = row[4] - user_settings[row[0]]['tts_voice'] = row[5] - user_settings[row[0]]['wake_word'] = row[6] - user_settings[row[0]]['sample_rate'] = row[7] - user_settings[row[0]]['channels'] = row[8] - user_settings[row[0]]['pronunciation'] = row[9] - user_settings[row[0]]['threshold'] = row[10] - user_settings[row[0]]['threshold_multiplier'] = row[11] - user_settings[row[0]]['dynamic_energy_ratio'] = row[12] + user_settings[row[0]]["date_format"] = row[1] + user_settings[row[0]]["time_format"] = row[2] + user_settings[row[0]]["measurement_system"] = row[3] + user_settings[row[0]]["tts_type"] 
= row[4] + user_settings[row[0]]["tts_voice"] = row[5] + user_settings[row[0]]["wake_word"] = row[6] + user_settings[row[0]]["sample_rate"] = row[7] + user_settings[row[0]]["channels"] = row[8] + user_settings[row[0]]["pronunciation"] = row[9] + user_settings[row[0]]["threshold"] = row[10] + user_settings[row[0]]["threshold_multiplier"] = row[11] + user_settings[row[0]]["dynamic_energy_ratio"] = row[12] - with open('subscription.csv') as subscription_csv: + with open("subscription.csv") as subscription_csv: subscription_reader = csv.reader(subscription_csv) next(subscription_reader, None) for row in subscription_reader: subscription[row[0]] = {} - subscription[row[0]]['stripe_customer_id'] = row[1] - subscription[row[0]]['last_payment_ts'] = row[2] - subscription[row[0]]['type'] = row[3] + subscription[row[0]]["stripe_customer_id"] = row[1] + subscription[row[0]]["last_payment_ts"] = row[2] + subscription[row[0]]["type"] = row[3] - with open('devices.csv') as devices_csv: + with open("devices.csv") as devices_csv: devices_reader = csv.reader(devices_csv) next(devices_reader, None) for row in devices_reader: devices[row[0]] = {} user_uuid = row[1] - devices[row[0]]['user_uuid'] = row[1] - devices[row[0]]['name'] = row[2] - devices[row[0]]['description'] = row[3], - devices[row[0]]['platform'] = row[4], - devices[row[0]]['enclosure_version'] = row[5] - devices[row[0]]['core_version'] = row[6] + devices[row[0]]["user_uuid"] = row[1] + devices[row[0]]["name"] = row[2] + devices[row[0]]["description"] = (row[3],) + devices[row[0]]["platform"] = (row[4],) + devices[row[0]]["enclosure_version"] = row[5] + devices[row[0]]["core_version"] = row[6] if user_uuid in user_devices: user_devices[user_uuid].append((row[0], row[2])) else: user_devices[user_uuid] = [(row[0], row[2])] - with open('skill.csv') as skill_csv: + with open("skill.csv") as skill_csv: skill_reader = csv.reader(skill_csv) next(skill_reader, None) for row in skill_reader: skill = row[0] skills[skill] = {} dev_uuid = row[1] - skills[skill]['device_uuid'] = dev_uuid - skills[skill]['name'] = row[2] - skills[skill]['description'] = row[3] - skills[skill]['identifier'] = row[4] + skills[skill]["device_uuid"] = dev_uuid + skills[skill]["name"] = row[2] + skills[skill]["description"] = row[3] + skills[skill]["identifier"] = row[4] if dev_uuid in device_to_skill: device_to_skill[dev_uuid].add(skill) else: device_to_skill[dev_uuid] = {skill} - with open('skill_section.csv') as skill_section_csv: + with open("skill_section.csv") as skill_section_csv: skill_section_reader = csv.reader(skill_section_csv) next(skill_section_reader, None) for row in skill_section_reader: section_uuid = row[0] skill_sections[section_uuid] = {} skill_uuid = row[1] - skill_sections[section_uuid]['skill_uuid'] = skill_uuid - skill_sections[section_uuid]['section'] = row[2] - skill_sections[section_uuid]['display_order'] = row[3] + skill_sections[section_uuid]["skill_uuid"] = skill_uuid + skill_sections[section_uuid]["section"] = row[2] + skill_sections[section_uuid]["display_order"] = row[3] if skill_uuid in skill_to_section: skill_to_section[skill_uuid].add(section_uuid) else: skill_to_section[skill_uuid] = {section_uuid} - with open('skill_fields.csv') as skill_fields_csv: + with open("skill_fields.csv") as skill_fields_csv: skill_fields_reader = csv.reader(skill_fields_csv) next(skill_fields_reader, None) for row in skill_fields_reader: field_uuid = row[0] skill_fields[field_uuid] = {} section_uuid = row[1] - #skill_fields[field_uuid]['section_uuid'] = 
section_uuid - skill_fields[field_uuid]['name'] = row[2] - skill_fields[field_uuid]['type'] = row[3] - skill_fields[field_uuid]['label'] = row[4] - skill_fields[field_uuid]['hint'] = row[5] - skill_fields[field_uuid]['placeholder'] = row[6] - skill_fields[field_uuid]['hide'] = row[7] - skill_fields[field_uuid]['options'] = row[8] - #skill_fields[field_uuid]['order'] = row[9] + # skill_fields[field_uuid]['section_uuid'] = section_uuid + skill_fields[field_uuid]["name"] = row[2] + skill_fields[field_uuid]["type"] = row[3] + skill_fields[field_uuid]["label"] = row[4] + skill_fields[field_uuid]["hint"] = row[5] + skill_fields[field_uuid]["placeholder"] = row[6] + skill_fields[field_uuid]["hide"] = row[7] + skill_fields[field_uuid]["options"] = row[8] + # skill_fields[field_uuid]['order'] = row[9] if section_uuid in section_to_field: section_to_field[section_uuid].add(field_uuid) else: section_to_field[section_uuid] = {field_uuid} - with open('skill_fields_values.csv') as skill_field_values_csv: + with open("skill_fields_values.csv") as skill_field_values_csv: skill_field_values_reader = csv.reader(skill_field_values_csv) next(skill_field_values_reader, None) for row in skill_field_values_reader: field_uuid = row[0] skill_field_values[field_uuid] = {} - skill_field_values[field_uuid]['skill_uuid'] = row[1] + skill_field_values[field_uuid]["skill_uuid"] = row[1] device_uuid = row[2] - skill_field_values[field_uuid]['device_uuid'] = device_uuid - skill_field_values[field_uuid]['field_value'] = row[3] + skill_field_values[field_uuid]["device_uuid"] = device_uuid + skill_field_values[field_uuid]["field_value"] = row[3] if device_uuid in device_to_field: device_to_field[device_uuid].add(field_uuid) else: device_to_field[device_uuid] = {field_uuid} - with open('location.csv') as location_csv: + with open("location.csv") as location_csv: location_reader = csv.reader(location_csv) next(location_reader, None) for row in location_reader: location_uuid = row[0] locations[location_uuid] = {} - locations[location_uuid]['timezone'] = row[1] - locations[location_uuid]['city'] = row[2] - locations[location_uuid]['coordinate'] = row[3] + locations[location_uuid]["timezone"] = row[1] + locations[location_uuid]["city"] = row[2] + locations[location_uuid]["coordinate"] = row[3] - with open('timezone.csv') as timezone_csv: + with open("timezone.csv") as timezone_csv: timezone_reader = csv.reader(timezone_csv) next(timezone_reader, None) for row in timezone_reader: timezone_uuid = row[0] timezones[timezone_uuid] = {} - timezones[timezone_uuid]['code'] = row[1] - timezones[timezone_uuid]['name'] = row[2] + timezones[timezone_uuid]["code"] = row[1] + timezones[timezone_uuid]["name"] = row[2] - with open('city.csv') as city_csv: + with open("city.csv") as city_csv: city_reader = csv.reader(city_csv) next(city_reader, None) for row in city_reader: city_uuid = row[0] cities[city_uuid] = {} - cities[city_uuid]['region'] = row[1] - cities[city_uuid]['name'] = row[2] + cities[city_uuid]["region"] = row[1] + cities[city_uuid]["name"] = row[2] - with open('region.csv') as region_csv: + with open("region.csv") as region_csv: region_reader = csv.reader(region_csv) next(region_reader, None) for row in region_reader: region_uuid = row[0] regions[region_uuid] = {} - regions[region_uuid]['country'] = row[1] - regions[region_uuid]['name'] = row[2] - regions[region_uuid]['code'] = row[3] + regions[region_uuid]["country"] = row[1] + regions[region_uuid]["name"] = row[2] + regions[region_uuid]["code"] = row[3] - with 
open('country.csv') as country_csv: + with open("country.csv") as country_csv: country_reader = csv.reader(country_csv) next(country_reader, None) for row in country_reader: country_uuid = row[0] countries[country_uuid] = {} - countries[country_uuid]['name'] = row[1] - countries[country_uuid]['code'] = row[2] + countries[country_uuid]["name"] = row[1] + countries[country_uuid]["code"] = row[2] - with open('coordinate.csv') as coordinate_csv: + with open("coordinate.csv") as coordinate_csv: coordinate_reader = csv.reader(coordinate_csv) next(coordinate_reader, None) for row in coordinate_reader: coordinate_uuid = row[0] coordinates[coordinate_uuid] = {} - coordinates[coordinate_uuid]['latitude'] = row[1] - coordinates[coordinate_uuid]['longitude'] = row[2] + coordinates[coordinate_uuid]["latitude"] = row[1] + coordinates[coordinate_uuid]["longitude"] = row[2] - with open('device_location.csv') as device_location_csv: + with open("device_location.csv") as device_location_csv: device_location_reader = csv.reader(device_location_csv, None) next(device_location_reader, None) for row in device_location_reader: device_uuid = row[0] if device_uuid in devices: - devices[device_uuid]['location'] = row[1] + devices[device_uuid]["location"] = row[1] def format_date(value): value = int(value) - value = datetime.datetime.fromtimestamp(value//1000) - return f'{value:%Y-%m-%d}' + value = datetime.datetime.fromtimestamp(value // 1000) + return f"{value:%Y-%m-%d}" def format_timestamp(value): value = int(value) - value = datetime.datetime.fromtimestamp(value//1000) - return f'{value:%Y-%m-%d %H:%M:%S}' + value = datetime.datetime.fromtimestamp(value // 1000) + return f"{value:%Y-%m-%d %H:%M:%S}" -db = connect(dbname='mycroft', user='postgres', host='127.0.0.1') - +db = connect(dbname="mycroft", user="postgres", host="127.0.0.1") db.autocommit = True @@ -276,7 +275,9 @@ def get_subscription_uuid(subs): return subscription_uuids[subs] else: cursor = db.cursor() - cursor.execute(f'select id from account.membership s where s.rate_period = \'{subs}\'') + cursor.execute( + f"select id from account.membership s where s.rate_period = '{subs}'" + ) result = cursor.fetchone() subscription_uuids[subs] = result return result @@ -290,60 +291,81 @@ def get_tts_uuid(tts): return tts_uuids[tts] else: cursor = db.cursor() - cursor.execute(f'select id from device.text_to_speech s where s.setting_name = \'{tts}\'') + cursor.execute( + f"select id from device.text_to_speech s where s.setting_name = '{tts}'" + ) result = cursor.fetchone() tts_uuids[tts] = result return result def fill_account_table(): - query = 'insert into account.account(' \ - 'id, ' \ - 'email_address, ' \ - 'password) ' \ - 'values (%s, %s, %s)' + query = ( + "insert into account.account(" + "id, " + "email_address, " + "password) " + "values (%s, %s, %s)" + ) with db.cursor() as cur: - accounts = ((uuid, account['email'], account['password']) for uuid, account in users.items()) + accounts = ( + (uuid, account["email"], account["password"]) + for uuid, account in users.items() + ) execute_batch(cur, query, accounts, page_size=1000) def fill_account_agreement_table(): - query = 'insert into account.account_agreement(account_id, agreement_id, accept_date)' \ - 'values (%s, (select id from account.agreement where agreement = %s), %s)' + query = ( + "insert into account.account_agreement(account_id, agreement_id, accept_date)" + "values (%s, (select id from account.agreement where agreement = %s), %s)" + ) with db.cursor() as cur: - terms = ((uuid, 'Terms of 
Use', format_timestamp(account['terms'])) for uuid, account in users.items() if account['terms'] != '') - privacy = ((uuid, 'Privacy Policy', format_timestamp(account['privacy'])) for uuid, account in users.items() if account['privacy'] != '') + terms = ( + (uuid, "Terms of Use", format_timestamp(account["terms"])) + for uuid, account in users.items() + if account["terms"] != "" + ) + privacy = ( + (uuid, "Privacy Policy", format_timestamp(account["privacy"])) + for uuid, account in users.items() + if account["privacy"] != "" + ) execute_batch(cur, query, terms, page_size=1000) execute_batch(cur, query, privacy, page_size=1000) def fill_default_wake_word(): - query1 = 'insert into device.wake_word (' \ - 'id,' \ - 'setting_name,' \ - 'display_name,' \ - 'engine)' \ - 'values (%s, %s, %s, %s)' - query2 = 'insert into device.wake_word_settings(' \ - 'wake_word_id,' \ - 'sample_rate,' \ - 'channels,' \ - 'pronunciation,' \ - 'threshold,' \ - 'threshold_multiplier,' \ - 'dynamic_energy_ratio)' \ - 'values (%s, %s, %s, %s, %s, %s, %s)' + query1 = ( + "insert into device.wake_word (" + "id," + "setting_name," + "display_name," + "engine)" + "values (%s, %s, %s, %s)" + ) + query2 = ( + "insert into device.wake_word_settings(" + "wake_word_id," + "sample_rate," + "channels," + "pronunciation," + "threshold," + "threshold_multiplier," + "dynamic_energy_ratio)" + "values (%s, %s, %s, %s, %s, %s, %s)" + ) wake_words = [ - (hey_mycroft, 'Hey Mycroft', 'Hey Mycroft', 'precise'), - (christopher, 'Christopher', 'Christopher', 'precise'), - (ezra, 'Hey Ezra', 'Hey Ezra', 'precise'), - (jarvis, 'Hey Jarvis', 'Hey Jarvis', 'precise') + (hey_mycroft, "Hey Mycroft", "Hey Mycroft", "precise"), + (christopher, "Christopher", "Christopher", "precise"), + (ezra, "Hey Ezra", "Hey Ezra", "precise"), + (jarvis, "Hey Jarvis", "Hey Jarvis", "precise"), ] wake_word_settings = [ - (hey_mycroft, '16000', '1', 'HH EY . M AY K R AO F T', '1e-90', '1', '1.5'), - (christopher, '16000', '1', 'K R IH S T AH F ER .', '1e-25', '1', '1.5'), - (ezra, '16000', '1', 'HH EY . EH Z R AH', '1e-10', '1', '2.5'), - (jarvis, '16000', '1', 'HH EY . JH AA R V AH S', '1e-25', '1', '1.5') + (hey_mycroft, "16000", "1", "HH EY . M AY K R AO F T", "1e-90", "1", "1.5"), + (christopher, "16000", "1", "K R IH S T AH F ER .", "1e-25", "1", "1.5"), + (ezra, "16000", "1", "HH EY . EH Z R AH", "1e-10", "1", "2.5"), + (jarvis, "16000", "1", "HH EY . 
JH AA R V AH S", "1e-25", "1", "1.5"), ] with db.cursor() as cur: execute_batch(cur, query1, wake_words) @@ -351,134 +373,179 @@ def fill_default_wake_word(): def fill_wake_word_table(): - query = 'insert into device.wake_word (' \ - 'id,' \ - 'setting_name,' \ - 'display_name,' \ - 'engine,' \ - 'account_id)' \ - 'values (%s, %s, %s, %s, %s)' + query = ( + "insert into device.wake_word (" + "id," + "setting_name," + "display_name," + "engine," + "account_id)" + "values (%s, %s, %s, %s, %s)" + ) def map_wake_word(user_id): wake_word_id = str(uuid.uuid4()) - wake_word = user_settings[user_id]['wake_word'].lower() if user_id in user_settings else 'hey mycroft' + wake_word = ( + user_settings[user_id]["wake_word"].lower() + if user_id in user_settings + else "hey mycroft" + ) mycroft_wake_word = default_wake_words.get(wake_word) if mycroft_wake_word is not None: wake_word_id = mycroft_wake_word - users[user_id]['wake_word_id'] = wake_word_id - return wake_word_id, wake_word, wake_word, 'precise', user_id + users[user_id]["wake_word_id"] = wake_word_id + return wake_word_id, wake_word, wake_word, "precise", user_id with db.cursor() as cur: wake_words = (map_wake_word(account_id) for account_id in users) - wake_words = (wk for wk in wake_words if wk[0] not in (hey_mycroft, christopher, ezra, jarvis)) + wake_words = ( + wk + for wk in wake_words + if wk[0] not in (hey_mycroft, christopher, ezra, jarvis) + ) execute_batch(cur, query, wake_words, page_size=1000) def fill_account_preferences_table(): - query = 'insert into device.account_preferences(' \ - 'account_id, ' \ - 'date_format, ' \ - 'time_format, ' \ - 'measurement_system)' \ - 'values (%s, %s, %s, %s)' + query = ( + "insert into device.account_preferences(" + "account_id, " + "date_format, " + "time_format, " + "measurement_system)" + "values (%s, %s, %s, %s)" + ) def map_account_preferences(user_uuid): if user_uuid in user_settings: user_setting = user_settings[user_uuid] - date_format = user_setting['date_format'] - if date_format == 'DMY': - date_format = 'DD/MM/YYYY' + date_format = user_setting["date_format"] + if date_format == "DMY": + date_format = "DD/MM/YYYY" else: - date_format = 'MM/DD/YYYY' - time_format = user_setting['time_format'] - if time_format == 'full': - time_format = '24 Hour' + date_format = "MM/DD/YYYY" + time_format = user_setting["time_format"] + if time_format == "full": + time_format = "24 Hour" else: - time_format = '12 Hour' - measurement_system = user_setting['measurement_system'] - if measurement_system == 'metric': - measurement_system = 'Metric' - elif measurement_system == 'imperial': - measurement_system = 'Imperial' - tts_type = user_setting['tts_type'] - tts_voice = user_setting['tts_voice'] - if tts_type == 'MimicSetting': - if tts_voice == 'ap': - tts = 'ap' - elif tts_voice == 'trinity': - tts = 'amy' + time_format = "12 Hour" + measurement_system = user_setting["measurement_system"] + if measurement_system == "metric": + measurement_system = "Metric" + elif measurement_system == "imperial": + measurement_system = "Imperial" + tts_type = user_setting["tts_type"] + tts_voice = user_setting["tts_voice"] + if tts_type == "MimicSetting": + if tts_voice == "ap": + tts = "ap" + elif tts_voice == "trinity": + tts = "amy" else: - tts = 'ap' - elif tts_type == 'Mimic2Setting': - tts = 'kusal' - elif tts_type == 'GoogleTTSSetting': - tts = 'google' + tts = "ap" + elif tts_type == "Mimic2Setting": + tts = "kusal" + elif tts_type == "GoogleTTSSetting": + tts = "google" else: - tts = 'ap' + tts = "ap" 
text_to_speech_id = get_tts_uuid(tts) - users[user_uuid]['text_to_speech_id'] = text_to_speech_id + users[user_uuid]["text_to_speech_id"] = text_to_speech_id return user_uuid, date_format, time_format, measurement_system else: - text_to_speech_id = get_tts_uuid('ap') - users[user_uuid]['text_to_speech_id'] = text_to_speech_id - return user_uuid, 'MM/DD/YYYY', '12 Hour', 'Imperial' + text_to_speech_id = get_tts_uuid("ap") + users[user_uuid]["text_to_speech_id"] = text_to_speech_id + return user_uuid, "MM/DD/YYYY", "12 Hour", "Imperial" with db.cursor() as cur: - account_preferences = (map_account_preferences(user_uuid) for user_uuid in users) + account_preferences = ( + map_account_preferences(user_uuid) for user_uuid in users + ) execute_batch(cur, query, account_preferences, page_size=1000) def fill_subscription_table(): - query = 'insert into account.account_membership(' \ - 'account_id, ' \ - 'membership_id, ' \ - 'membership_ts_range, ' \ - 'payment_account_id,' \ - 'payment_method,' \ - 'payment_id) ' \ - 'values (%s, %s, %s, %s, %s, %s)' + query = ( + "insert into account.account_membership(" + "account_id, " + "membership_id, " + "membership_ts_range, " + "payment_account_id," + "payment_method," + "payment_id) " + "values (%s, %s, %s, %s, %s, %s)" + ) def map_subscription(user_uuid): subscr = subscription[user_uuid] - stripe_customer_id = subscr['stripe_customer_id'] - start = format_timestamp(subscr['last_payment_ts']) - subscription_ts_range = '[{},)'.format(start) - subscription_type = subscr['type'] - if subscription_type == 'MonthlyAccount': - subscription_type = 'month' - elif subscription_type == 'YearlyAccount': - subscription_type = 'year' + stripe_customer_id = subscr["stripe_customer_id"] + start = format_timestamp(subscr["last_payment_ts"]) + subscription_ts_range = "[{},)".format(start) + subscription_type = subscr["type"] + if subscription_type == "MonthlyAccount": + subscription_type = "month" + elif subscription_type == "YearlyAccount": + subscription_type = "year" subscription_uuid = get_subscription_uuid(subscription_type) - return user_uuid, subscription_uuid, subscription_ts_range, stripe_customer_id, 'Stripe', 'subscription_id' + return ( + user_uuid, + subscription_uuid, + subscription_ts_range, + stripe_customer_id, + "Stripe", + "subscription_id", + ) + with db.cursor() as cur: - account_subscriptions = (map_subscription(user_uuid) for user_uuid in subscription) + account_subscriptions = ( + map_subscription(user_uuid) for user_uuid in subscription + ) execute_batch(cur, query, account_subscriptions, page_size=1000) def fill_wake_word_settings_table(): - query = 'insert into device.wake_word_settings(' \ - 'wake_word_id,' \ - 'sample_rate,' \ - 'channels,' \ - 'pronunciation,' \ - 'threshold,' \ - 'threshold_multiplier,' \ - 'dynamic_energy_ratio)' \ - 'values (%s, %s, %s, %s, %s, %s, %s)' + query = ( + "insert into device.wake_word_settings(" + "wake_word_id," + "sample_rate," + "channels," + "pronunciation," + "threshold," + "threshold_multiplier," + "dynamic_energy_ratio)" + "values (%s, %s, %s, %s, %s, %s, %s)" + ) def map_wake_word_settings(user_uuid): user_setting = user_settings[user_uuid] - wake_word_id = users[user_uuid]['wake_word_id'] - sample_rate = user_setting['sample_rate'] - channels = user_setting['channels'] - pronunciation = user_setting['pronunciation'] - threshold = user_setting['threshold'] - threshold_multiplier = user_setting['threshold_multiplier'] - dynamic_energy_ratio = user_setting['dynamic_energy_ratio'] - return wake_word_id, 
sample_rate, channels, pronunciation, threshold, threshold_multiplier, dynamic_energy_ratio + wake_word_id = users[user_uuid]["wake_word_id"] + sample_rate = user_setting["sample_rate"] + channels = user_setting["channels"] + pronunciation = user_setting["pronunciation"] + threshold = user_setting["threshold"] + threshold_multiplier = user_setting["threshold_multiplier"] + dynamic_energy_ratio = user_setting["dynamic_energy_ratio"] + return ( + wake_word_id, + sample_rate, + channels, + pronunciation, + threshold, + threshold_multiplier, + dynamic_energy_ratio, + ) + with db.cursor() as cur: - account_wake_word_settings = (map_wake_word_settings(user_uuid) for user_uuid in users if user_uuid in user_settings) - account_wake_word_settings = (wks for wks in account_wake_word_settings if wks[0] not in (hey_mycroft, christopher, ezra, jarvis)) + account_wake_word_settings = ( + map_wake_word_settings(user_uuid) + for user_uuid in users + if user_uuid in user_settings + ) + account_wake_word_settings = ( + wks + for wks in account_wake_word_settings + if wks[0] not in (hey_mycroft, christopher, ezra, jarvis) + ) execute_batch(cur, query, account_wake_word_settings, page_size=1000) @@ -493,37 +560,43 @@ def change_device_name(): if len(uuids) > 1: count = 1 for uuid in uuids: - devices[uuid]['name'] = '{name}-{uuid}'.format(name=name, uuid=uuid) + devices[uuid]["name"] = "{name}-{uuid}".format( + name=name, uuid=uuid + ) count += 1 def fill_device_table(): - query = 'insert into device.device(' \ - 'id, ' \ - 'account_id, ' \ - 'name, ' \ - 'placement,' \ - 'platform,' \ - 'enclosure_version,' \ - 'core_version,' \ - 'wake_word_id,' \ - 'geography_id,' \ - 'text_to_speech_id) ' \ - 'values (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)' - query2 = 'insert into device.geography(' \ - 'id,' \ - 'account_id,' \ - 'country_id,' \ - 'region_id,' \ - 'city_id,' \ - 'timezone_id) ' \ - 'values (%s, %s, %s, %s, %s, %s)' + query = ( + "insert into device.device(" + "id, " + "account_id, " + "name, " + "placement," + "platform," + "enclosure_version," + "core_version," + "wake_word_id," + "geography_id," + "text_to_speech_id) " + "values (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)" + ) + query2 = ( + "insert into device.geography(" + "id," + "account_id," + "country_id," + "region_id," + "city_id," + "timezone_id) " + "values (%s, %s, %s, %s, %s, %s)" + ) with db.cursor() as cur: query_geography = """ SELECT city.id, region.id, country.id, timezone.id - FROM + FROM geography.city city INNER JOIN geography.region region ON city.region_id = region.id @@ -534,7 +607,7 @@ def fill_device_table(): WHERE city.name = %s and region.name = %s and timezone.name = %s; """ - cur.execute(query_geography, ('Lawrence', 'Kansas', 'America/Chicago')) + cur.execute(query_geography, ("Lawrence", "Kansas", "America/Chicago")) city_default, region_default, country_default, timezone_default = cur.fetchone() def map_geography(account_id, device_id): @@ -543,7 +616,7 @@ def fill_device_table(): query = """ SELECT city.id, region.id, country.id, timezone.id - FROM + FROM geography.city city INNER JOIN geography.region region ON city.region_id = region.id @@ -554,54 +627,73 @@ def fill_device_table(): WHERE city.name = %s and region.name = %s and timezone.name = %s and country.name = %s; """ - location_uuid = devices[device_id].get('location') + location_uuid = devices[device_id].get("location") if location_uuid is not None: location = locations[location_uuid] - timezone_entity = timezones[location['timezone']] - timezone = 
timezone_entity['code'] - city_entity = cities[location['city']] - city = city_entity['name'] - region_entity = regions[city_entity['region']] - region = region_entity['name'] - country_entity = countries[region_entity['country']] - country = country_entity['name'] + timezone_entity = timezones[location["timezone"]] + timezone = timezone_entity["code"] + city_entity = cities[location["city"]] + city = city_entity["name"] + region_entity = regions[city_entity["region"]] + region = region_entity["name"] + country_entity = countries[region_entity["country"]] + country = country_entity["name"] cur.execute(query, (city, region, timezone, country)) result = cur.fetchone() if result is not None: city, region, country, timezone = result return geography_id, account_id, country, region, city, timezone - return geography_id, account_id, country_default, region_default, city_default, timezone_default + return ( + geography_id, + account_id, + country_default, + region_default, + city_default, + timezone_default, + ) def map_device(device_id): device = devices[device_id] - account_id = device['user_uuid'] - name = device['name'] - placement = device['description'] - platform = device['platform'] - enclosure_version = device['enclosure_version'] - core_version = device['core_version'] - wake_word_id = users[account_id]['wake_word_id'] - geography_id = device['geography_id'] + account_id = device["user_uuid"] + name = device["name"] + placement = device["description"] + platform = device["platform"] + enclosure_version = device["enclosure_version"] + core_version = device["core_version"] + wake_word_id = users[account_id]["wake_word_id"] + geography_id = device["geography_id"] user_setting = user_settings[account_id] - tts_type = user_setting['tts_type'] - tts_voice = user_setting['tts_voice'] - if tts_type == 'MimicSetting': - if tts_voice == 'ap': - tts = 'ap' - elif tts_voice == 'trinity': - tts = 'amy' + tts_type = user_setting["tts_type"] + tts_voice = user_setting["tts_voice"] + if tts_type == "MimicSetting": + if tts_voice == "ap": + tts = "ap" + elif tts_voice == "trinity": + tts = "amy" else: - tts = 'ap' - elif tts_type == 'Mimic2Setting': - tts = 'kusal' - elif tts_type == 'GoogleTTSSetting': - tts = 'google' + tts = "ap" + elif tts_type == "Mimic2Setting": + tts = "kusal" + elif tts_type == "GoogleTTSSetting": + tts = "google" else: - tts = 'ap' + tts = "ap" text_to_speech_id = get_tts_uuid(tts) - return device_id, account_id, name, placement, platform, enclosure_version, core_version, wake_word_id, geography_id, text_to_speech_id + return ( + device_id, + account_id, + name, + placement, + platform, + enclosure_version, + core_version, + wake_word_id, + geography_id, + text_to_speech_id, + ) + with db.cursor() as cur: geography_batch = [] for user in user_devices: @@ -611,9 +703,14 @@ def fill_device_table(): geography = map_geography(user, device_id) geography_batch.append(geography) for device_id, name in aux: - devices[device_id]['geography_id'] = geography[0] + devices[device_id]["geography_id"] = geography[0] execute_batch(cur, query2, geography_batch, page_size=1000) - devices_batch = (map_device(device_id) for user in user_devices if user in users and user in user_settings for device_id, name in user_devices[user]) + devices_batch = ( + map_device(device_id) + for user in user_devices + if user in users and user in user_settings + for device_id, name in user_devices[user] + ) execute_batch(cur, query, devices_batch, page_size=1000) @@ -627,38 +724,50 @@ def fill_skills_table(): if 
device_uuid in device_to_skill: for skill_uuid in device_to_skill[device_uuid]: skill = skills[skill_uuid] - skill_name = skill['name'] - identifier = skill['identifier'] + skill_name = skill["name"] + identifier = skill["identifier"] sections = [] settings = {} if skill_uuid in skill_to_section: for section_uuid in skill_to_section[skill_uuid]: section = skill_sections[section_uuid] - section_name = section['section'] + section_name = section["section"] fields = [] if section_uuid in section_to_field: for field_uuid in section_to_field[section_uuid]: fields.append(skill_fields[field_uuid]) if field_uuid in skill_field_values: - settings[skill_fields[field_uuid]['name']] = skill_field_values[field_uuid]['field_value'] - sections.append({'name': section_name, 'fields': fields}) + settings[ + skill_fields[field_uuid]["name"] + ] = skill_field_values[field_uuid][ + "field_value" + ] + sections.append( + {"name": section_name, "fields": fields} + ) skill_setting_display = { - 'name': skill_name, - 'identifier': identifier, - 'skillMetadata': {'sections': sections} + "name": skill_name, + "identifier": identifier, + "skillMetadata": {"sections": sections}, } skills_batch.append((skill_uuid, skill_name)) meta_id = str(uuid.uuid4()) - settings_display_batch.append((meta_id, skill_uuid, json.dumps(skill_setting_display))) - device_skill_batch.append((device_uuid, skill_uuid, meta_id, json.dumps(settings))) + settings_display_batch.append( + (meta_id, skill_uuid, json.dumps(skill_setting_display)) + ) + device_skill_batch.append( + (device_uuid, skill_uuid, meta_id, json.dumps(settings)) + ) with db.cursor() as curr: - query = 'insert into skill.skill(id, name) values (%s, %s)' + query = "insert into skill.skill(id, name) values (%s, %s)" execute_batch(curr, query, skills_batch, page_size=1000) - query = 'insert into skill.settings_display(id, skill_id, settings_display) values (%s, %s, %s)' + query = "insert into skill.settings_display(id, skill_id, settings_display) values (%s, %s, %s)" execute_batch(curr, query, settings_display_batch, page_size=1000) - query = 'insert into device.device_skill(device_id, skill_id, skill_settings_display_id, settings) ' \ - 'values (%s, %s, %s, %s)' + query = ( + "insert into device.device_skill(device_id, skill_id, skill_settings_display_id, settings) " + "values (%s, %s, %s, %s)" + ) execute_batch(curr, query, device_skill_batch, page_size=1000) @@ -667,29 +776,43 @@ def analyze_locations(): mismatches = 0 g_mismatches = defaultdict(lambda: defaultdict(list)) for city in cities.values(): - region = regions[city['region']] - country = countries[region['country']] - city_name = city['name'] - region_name = region['name'] - country_name = country['name'] - remove = ['District', 'Region', 'Development', 'Prefecture', 'Community', 'County', 'Province', 'Division', 'Voivodeship', 'State', 'of', 'Governorate'] + region = regions[city["region"]] + country = countries[region["country"]] + city_name = city["name"] + region_name = region["name"] + country_name = country["name"] + remove = [ + "District", + "Region", + "Development", + "Prefecture", + "Community", + "County", + "Province", + "Division", + "Voivodeship", + "State", + "of", + "Governorate", + ] with db.cursor() as curr: original_region_name = region_name - region_name = ' '.join(i for i in region_name.split() if i not in remove) - query = 'select city.name ' \ - 'from geography.city city ' \ - 'inner join geography.region region on city.region_id = region.id ' \ - 'inner join geography.country country on 
region.country_id = country.id ' \ - 'where ' \ - 'city.name = \'{}\' and ' \ - '(region.name = \'{}\' or region.name = \'{}\') and ' \ - 'country.name = \'{}\''\ - .format( - city_name.replace('\'', '\'\''), - original_region_name.replace('\'', '\'\''), - region_name.replace('\'', '\'\''), - country_name.replace('\'', '\'\'') - ) + region_name = " ".join(i for i in region_name.split() if i not in remove) + query = ( + "select city.name " + "from geography.city city " + "inner join geography.region region on city.region_id = region.id " + "inner join geography.country country on region.country_id = country.id " + "where " + "city.name = '{}' and " + "(region.name = '{}' or region.name = '{}') and " + "country.name = '{}'".format( + city_name.replace("'", "''"), + original_region_name.replace("'", "''"), + region_name.replace("'", "''"), + country_name.replace("'", "''"), + ) + ) curr.execute(query) result = curr.fetchone() if result is None: @@ -701,9 +824,9 @@ def analyze_locations(): for country2, regions2 in g_mismatches.items(): for region2, cities2 in regions2.items(): for city2 in cities2: - print('{} - {} - {}'.format(country2, region2, city2)) + print("{} - {} - {}".format(country2, region2, city2)) - print('Number os mismatches: {}'.format(mismatches)) + print("Number os mismatches: {}".format(mismatches)) def analyze_location_2(): @@ -711,30 +834,32 @@ def analyze_location_2(): locations_from_db = defaultdict(list) with db.cursor() as cur: - cur.execute('select ' - 'c1.id, ' - 'c1.name, ' - 'c1.latitude, ' - 'c1.longitude, ' - 'r.name, ' - 'c2.name, ' - 'c2.iso_code ' - 'from geography.city c1 ' - 'inner join geography.region r on c1.region_id = r.id ' - 'inner join geography.country c2 on r.country_id = c2.id') + cur.execute( + "select " + "c1.id, " + "c1.name, " + "c1.latitude, " + "c1.longitude, " + "r.name, " + "c2.name, " + "c2.iso_code " + "from geography.city c1 " + "inner join geography.region r on c1.region_id = r.id " + "inner join geography.country c2 on r.country_id = c2.id" + ) for c1_id, c1, latitude, longitude, r, c2_name, c2_code in cur: aux[c2_name][r][c1] = c1_id locations_from_db[c2_code].append((c1, latitude, longitude)) for location_uuid, location in locations.items(): coordinate = coordinates[location_uuid] - city = cities[location['city']] - city_name = city['name'] - region = regions[city['region']] - region_name = region['name'] - country = countries[region['country']] - country_code = country['code'] - country_name = country['name'] + city = cities[location["city"]] + city_name = city["name"] + region = regions[city["region"]] + region_name = region["name"] + country = countries[region["country"]] + country_code = country["code"] + country_name = country["name"] res = aux.get(country_name) if res is not None: @@ -742,45 +867,45 @@ def analyze_location_2(): if res is not None: res = res.get(city_name) if res is not None: - print('Match: {}'.format(city_name)) + print("Match: {}".format(city_name)) continue min_dist = None result_name = None for c1_name, latitude, longitude in locations_from_db[country_code]: point1 = (float(latitude), float(longitude)) - point2 = (float(coordinate['latitude']), float(coordinate['longitude'])) + point2 = (float(coordinate["latitude"]), float(coordinate["longitude"])) dist = distance(point1, point2).km if min_dist is None or dist < min_dist: min_dist = dist result_name = c1_name - print('Actual: {}, calculated: {}'.format(city_name, result_name)) + print("Actual: {}, calculated: {}".format(city_name, result_name)) start = 
time.time() load_csv() end = time.time() -print('Time to load CSVs {}'.format(end - start)) +print("Time to load CSVs {}".format(end - start)) start = time.time() -print('Importing account table') -#fill_account_table() -print('Importing agreements table') -#fill_account_agreement_table() -print('Importing account preferences table') -#fill_account_preferences_table() -print('Importing subscription table') -#fill_subscription_table() -print('Importing wake word table') -#fill_default_wake_word() -#fill_wake_word_table() -print('Importing wake word settings table') -#fill_wake_word_settings_table() -print('Importing device table') -#change_device_name() -#fill_device_table() -print('Importing skills table') -#fill_skills_table() +print("Importing account table") +# fill_account_table() +print("Importing agreements table") +# fill_account_agreement_table() +print("Importing account preferences table") +# fill_account_preferences_table() +print("Importing subscription table") +# fill_subscription_table() +print("Importing wake word table") +# fill_default_wake_word() +# fill_wake_word_table() +print("Importing wake word settings table") +# fill_wake_word_settings_table() +print("Importing device table") +# change_device_name() +# fill_device_table() +print("Importing skills table") +# fill_skills_table() analyze_location_2() end = time.time() -print('Time to import: {}'.format(end-start)) +print("Time to import: {}".format(end - start)) diff --git a/shared/selene/__init__.py b/shared/selene/__init__.py index 7665c10b..eabab81b 100644 --- a/shared/selene/__init__.py +++ b/shared/selene/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/shared/selene/api/base_config.py b/shared/selene/api/base_config.py index b7bf52f2..74704213 100644 --- a/shared/selene/api/base_config.py +++ b/shared/selene/api/base_config.py @@ -48,32 +48,33 @@ class APIConfigError(Exception): class BaseConfig(object): """Base configuration.""" - ACCESS_SECRET = os.environ['JWT_ACCESS_SECRET'] + + ACCESS_SECRET = os.environ["JWT_ACCESS_SECRET"] DB_CONNECTION_POOL = None DEBUG = False - ENV = os.environ['SELENE_ENVIRONMENT'] - REFRESH_SECRET = os.environ['JWT_REFRESH_SECRET'] + ENV = os.environ["SELENE_ENVIRONMENT"] + REFRESH_SECRET = os.environ["JWT_REFRESH_SECRET"] DB_CONNECTION_CONFIG = DatabaseConnectionConfig( - host=os.environ['DB_HOST'], - db_name=os.environ['DB_NAME'], - password=os.environ['DB_PASSWORD'], - port=os.environ.get('DB_PORT', 5432), - user=os.environ['DB_USER'], - sslmode=os.environ.get('DB_SSLMODE') + host=os.environ["DB_HOST"], + db_name=os.environ["DB_NAME"], + password=os.environ["DB_PASSWORD"], + port=os.environ.get("DB_PORT", 5432), + user=os.environ["DB_USER"], + sslmode=os.environ.get("DB_SSLMODE"), ) class DevelopmentConfig(BaseConfig): DEBUG = True - DOMAIN = '.mycroft.test' + DOMAIN = ".mycroft.test" class TestConfig(BaseConfig): - DOMAIN = '.mycroft-test.net' + DOMAIN = ".mycroft-test.net" class ProdConfig(BaseConfig): - DOMAIN = '.mycroft.ai' + DOMAIN = ".mycroft.ai" def get_base_config(): @@ -81,16 +82,12 @@ def get_base_config(): :return: an object containing the configs for the API. 
""" - environment_configs = dict( - dev=DevelopmentConfig, - test=TestConfig, - prod=ProdConfig - ) + environment_configs = dict(dev=DevelopmentConfig, test=TestConfig, prod=ProdConfig) try: - environment_name = os.environ['SELENE_ENVIRONMENT'] + environment_name = os.environ["SELENE_ENVIRONMENT"] except KeyError: - raise APIConfigError('the SELENE_ENVIRONMENT variable is not set') + raise APIConfigError("the SELENE_ENVIRONMENT variable is not set") try: app_config = environment_configs[environment_name] diff --git a/shared/selene/api/blueprint.py b/shared/selene/api/blueprint.py index c2d492f9..62af7493 100644 --- a/shared/selene/api/blueprint.py +++ b/shared/selene/api/blueprint.py @@ -31,7 +31,7 @@ from selene.util.cache import DEVICE_LAST_CONTACT_KEY from selene.util.db import connect_to_db from selene.util.exceptions import NotModifiedException -selene_api = Blueprint('selene_api', __name__) +selene_api = Blueprint("selene_api", __name__) @selene_api.app_errorhandler(DataError) @@ -46,7 +46,7 @@ def handle_data_error(error): @selene_api.app_errorhandler(NotModifiedException) def handle_not_modified(_): - return '', HTTPStatus.NOT_MODIFIED + return "", HTTPStatus.NOT_MODIFIED @selene_api.before_app_request @@ -67,26 +67,26 @@ def add_api_metric(http_status): api = None # We are not logging metric for the public API until after the socket # implementation to avoid putting millions of rows a day on the table - for api_name in ('account', 'sso', 'market', 'public'): + for api_name in ("account", "sso", "market", "public"): if api_name in current_app.name: api = api_name if api is not None and int(http_status) != 304: - if 'db' not in global_context: + if "db" not in global_context: global_context.db = connect_to_db( - current_app.config['DB_CONNECTION_CONFIG'] + current_app.config["DB_CONNECTION_CONFIG"] ) - if 'account_id' in global_context: + if "account_id" in global_context: account_id = global_context.account_id else: account_id = None - if 'device_id' in global_context: + if "device_id" in global_context: device_id = global_context.device_id else: device_id = None - duration = (datetime.utcnow() - global_context.start_ts) + duration = datetime.utcnow() - global_context.start_ts api_metric = ApiMetric( access_ts=datetime.utcnow(), account_id=account_id, @@ -95,7 +95,7 @@ def add_api_metric(http_status): duration=Decimal(str(duration.total_seconds())), http_method=request.method, http_status=int(http_status), - url=global_context.url + url=global_context.url, ) metric_repository = ApiMetricsRepository(global_context.db) metric_repository.add(api_metric) @@ -107,6 +107,6 @@ def update_device_last_contact(): This should only be done on public API calls because we are tracking device activity only. 
""" - if 'public' in current_app.name and 'device_id' in global_context: + if "public" in current_app.name and "device_id" in global_context: key = DEVICE_LAST_CONTACT_KEY.format(device_id=global_context.device_id) global_context.cache.set(key, str(datetime.utcnow())) diff --git a/shared/selene/api/endpoints/agreements.py b/shared/selene/api/endpoints/agreements.py index 9f84b6f5..34b3e4e1 100644 --- a/shared/selene/api/endpoints/agreements.py +++ b/shared/selene/api/endpoints/agreements.py @@ -29,20 +29,20 @@ from ..base_endpoint import SeleneEndpoint class AgreementsEndpoint(SeleneEndpoint): authentication_required = False agreement_types = { - 'terms-of-use': 'Terms of Use', - 'privacy-policy': 'Privacy Policy' + "terms-of-use": "Terms of Use", + "privacy-policy": "Privacy Policy", } def get(self, agreement_type): """Process HTTP GET request for an agreement.""" - db = connect_to_db(self.config['DB_CONNECTION_CONFIG']) + db = connect_to_db(self.config["DB_CONNECTION_CONFIG"]) agreement_repository = AgreementRepository(db) agreement = agreement_repository.get_active_for_type( self.agreement_types[agreement_type] ) if agreement is not None: agreement = asdict(agreement) - del(agreement['effective_date']) + del agreement["effective_date"] self.response = agreement, HTTPStatus.OK return self.response diff --git a/shared/selene/api/etag.py b/shared/selene/api/etag.py index 37d9d59a..116c4c22 100644 --- a/shared/selene/api/etag.py +++ b/shared/selene/api/etag.py @@ -24,19 +24,19 @@ from selene.data.device import DeviceRepository from selene.util.cache import SeleneCache, DEVICE_SKILL_ETAG_KEY from selene.util.db import connect_to_db -ETAG_REQUEST_HEADER_KEY = 'If-None-Match' +ETAG_REQUEST_HEADER_KEY = "If-None-Match" def device_etag_key(device_id: str): - return 'device.etag:{uuid}'.format(uuid=device_id) + return "device.etag:{uuid}".format(uuid=device_id) def device_setting_etag_key(device_id: str): - return 'device.setting.etag:{uuid}'.format(uuid=device_id) + return "device.setting.etag:{uuid}".format(uuid=device_id) def device_location_etag_key(device_id: str): - return 'device.location.etag:{uuid}'.format(uuid=device_id) + return "device.location.etag:{uuid}".format(uuid=device_id) class ETagManager(object): @@ -46,7 +46,7 @@ class ETagManager(object): def __init__(self, cache: SeleneCache, config: dict): self.cache: SeleneCache = cache - self.db_connection_config = config['DB_CONNECTION_CONFIG'] + self.db_connection_config = config["DB_CONNECTION_CONFIG"] def get(self, key: str) -> str: """Generate a etag with 32 random chars and store it into a given key @@ -54,14 +54,14 @@ class ETagManager(object): :return etag""" etag = self.cache.get(key) if etag is None: - etag = ''.join(random.choice(self.etag_chars) for _ in range(32)) + etag = "".join(random.choice(self.etag_chars) for _ in range(32)) self.cache.set(key, etag) return etag def expire(self, key): """Expires an existent etag :param key: key where the etag is stored""" - etag = ''.join(random.choice(self.etag_chars) for _ in range(32)) + etag = "".join(random.choice(self.etag_chars) for _ in range(32)) self.cache.set(key, etag) def expire_device_etag_by_device_id(self, device_id: str): diff --git a/shared/selene/api/response.py b/shared/selene/api/response.py index c66cd812..de2261d3 100644 --- a/shared/selene/api/response.py +++ b/shared/selene/api/response.py @@ -22,7 +22,7 @@ import re from flask import jsonify, Response -snake_pattern = re.compile(r'_([a-z])') +snake_pattern = re.compile(r"_([a-z])") def snake_to_camel(name): 
diff --git a/shared/selene/data/__init__.py b/shared/selene/data/__init__.py index 7665c10b..eabab81b 100644 --- a/shared/selene/data/__init__.py +++ b/shared/selene/data/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/shared/selene/data/account/__init__.py b/shared/selene/data/account/__init__.py index a9d3ab1d..a8aff17b 100644 --- a/shared/selene/data/account/__init__.py +++ b/shared/selene/data/account/__init__.py @@ -18,12 +18,7 @@ # along with this program. If not, see . from .entity.account import Account, AccountAgreement, AccountMembership -from .entity.agreement import ( - Agreement, - PRIVACY_POLICY, - TERMS_OF_USE, - OPEN_DATASET -) +from .entity.agreement import Agreement, PRIVACY_POLICY, TERMS_OF_USE, OPEN_DATASET from .entity.membership import Membership from .entity.skill import AccountSkill from .repository.account import AccountRepository @@ -31,6 +26,6 @@ from .repository.agreement import AgreementRepository from .repository.membership import ( MembershipRepository, MONTHLY_MEMBERSHIP, - YEARLY_MEMBERSHIP + YEARLY_MEMBERSHIP, ) from .repository.skill import AccountSkillRepository diff --git a/shared/selene/data/account/entity/__init__.py b/shared/selene/data/account/entity/__init__.py index 7665c10b..eabab81b 100644 --- a/shared/selene/data/account/entity/__init__.py +++ b/shared/selene/data/account/entity/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/shared/selene/data/account/entity/agreement.py b/shared/selene/data/account/entity/agreement.py index 82e783b2..1a83192d 100644 --- a/shared/selene/data/account/entity/agreement.py +++ b/shared/selene/data/account/entity/agreement.py @@ -20,9 +20,9 @@ from dataclasses import dataclass from datetime import date -TERMS_OF_USE = 'Terms of Use' -PRIVACY_POLICY = 'Privacy Policy' -OPEN_DATASET = 'Open Dataset' +TERMS_OF_USE = "Terms of Use" +PRIVACY_POLICY = "Privacy Policy" +OPEN_DATASET = "Open Dataset" @dataclass diff --git a/shared/selene/data/account/repository/__init__.py b/shared/selene/data/account/repository/__init__.py index 7665c10b..eabab81b 100644 --- a/shared/selene/data/account/repository/__init__.py +++ b/shared/selene/data/account/repository/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . 
- diff --git a/shared/selene/data/account/repository/membership.py b/shared/selene/data/account/repository/membership.py index 62261435..5a45d33d 100644 --- a/shared/selene/data/account/repository/membership.py +++ b/shared/selene/data/account/repository/membership.py @@ -21,8 +21,8 @@ from selene.data.account import AccountMembership from ..entity.membership import Membership from ...repository_base import RepositoryBase -MONTHLY_MEMBERSHIP = 'Monthly Membership' -YEARLY_MEMBERSHIP = 'Yearly Membership' +MONTHLY_MEMBERSHIP = "Monthly Membership" +YEARLY_MEMBERSHIP = "Yearly Membership" class MembershipRepository(RepositoryBase): @@ -30,37 +30,34 @@ class MembershipRepository(RepositoryBase): super(MembershipRepository, self).__init__(db, __file__) def get_membership_types(self): - db_request = self._build_db_request( - sql_file_name='get_membership_types.sql' - ) + db_request = self._build_db_request(sql_file_name="get_membership_types.sql") db_result = self.cursor.select_all(db_request) return [Membership(**row) for row in db_result] def get_membership_by_type(self, membership_type: str): db_request = self._build_db_request( - sql_file_name='get_membership_by_type.sql', - args=dict(type=membership_type) + sql_file_name="get_membership_by_type.sql", args=dict(type=membership_type) ) db_result = self.cursor.select_one(db_request) return Membership(**db_result) def add(self, membership: Membership): db_request = self._build_db_request( - 'add_membership.sql', + "add_membership.sql", args=dict( membership_type=membership.type, rate=membership.rate, - rate_period=membership.rate_period - ) + rate_period=membership.rate_period, + ), ) result = self.cursor.insert_returning(db_request) - return result['id'] + return result["id"] def remove(self, membership: Membership): db_request = self._build_db_request( - sql_file_name='delete_membership.sql', - args=dict(membership_id=membership.id) + sql_file_name="delete_membership.sql", + args=dict(membership_id=membership.id), ) self.cursor.delete(db_request) diff --git a/shared/selene/data/account/repository/skill.py b/shared/selene/data/account/repository/skill.py index c4529e59..6749290b 100644 --- a/shared/selene/data/account/repository/skill.py +++ b/shared/selene/data/account/repository/skill.py @@ -30,8 +30,8 @@ class AccountSkillRepository(RepositoryBase): def get_skills_for_account(self) -> List[AccountSkill]: db_request = self._build_db_request( - sql_file_name='get_account_skills.sql', - args=dict(account_id=self.account_id) + sql_file_name="get_account_skills.sql", + args=dict(account_id=self.account_id), ) db_result = self.cursor.select_all(db_request) diff --git a/shared/selene/data/device/entity/__init__.py b/shared/selene/data/device/entity/__init__.py index 7665c10b..eabab81b 100644 --- a/shared/selene/data/device/entity/__init__.py +++ b/shared/selene/data/device/entity/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/shared/selene/data/device/repository/__init__.py b/shared/selene/data/device/repository/__init__.py index 62e7737c..5da52a1a 100644 --- a/shared/selene/data/device/repository/__init__.py +++ b/shared/selene/data/device/repository/__init__.py @@ -17,4 +17,4 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . 
-from .device import Device \ No newline at end of file +from .device import Device diff --git a/shared/selene/data/device/repository/device_skill.py b/shared/selene/data/device/repository/device_skill.py index 8aecdae7..69ed7c6a 100644 --- a/shared/selene/data/device/repository/device_skill.py +++ b/shared/selene/data/device/repository/device_skill.py @@ -26,7 +26,7 @@ from selene.data.skill import SettingsDisplay from ..entity.device_skill import ( AccountSkillSettings, DeviceSkillSettings, - ManifestSkill + ManifestSkill, ) from ...repository_base import RepositoryBase @@ -36,19 +36,19 @@ class DeviceSkillRepository(RepositoryBase): super(DeviceSkillRepository, self).__init__(db, __file__) def get_skill_settings_for_account( - self, account_id: str, skill_id: str + self, account_id: str, skill_id: str ) -> List[AccountSkillSettings]: return self._select_all_into_dataclass( AccountSkillSettings, - sql_file_name='get_skill_settings_for_account.sql', - args=dict(account_id=account_id, skill_id=skill_id) + sql_file_name="get_skill_settings_for_account.sql", + args=dict(account_id=account_id, skill_id=skill_id), ) def get_skill_settings_for_device(self, device_id, skill_id=None): device_skills = self._select_all_into_dataclass( DeviceSkillSettings, - sql_file_name='get_skill_settings_for_device.sql', - args=dict(device_id=device_id) + sql_file_name="get_skill_settings_for_device.sql", + args=dict(device_id=device_id), ) if skill_id is None: skill_settings = device_skills @@ -62,23 +62,21 @@ class DeviceSkillRepository(RepositoryBase): return skill_settings def update_skill_settings( - self, account_id: str, device_names: tuple, skill_name: str + self, account_id: str, device_names: tuple, skill_name: str ): db_request = self._build_db_request( - sql_file_name='update_skill_settings.sql', + sql_file_name="update_skill_settings.sql", args=dict( - account_id=account_id, - device_names=device_names, - skill_name=skill_name - ) + account_id=account_id, device_names=device_names, skill_name=skill_name + ), ) self.cursor.update(db_request) def upsert_device_skill_settings( - self, - device_ids: List[str], - settings_display: SettingsDisplay, - settings_values: dict, + self, + device_ids: List[str], + settings_display: SettingsDisplay, + settings_values: dict, ): for device_id in device_ids: if settings_values is None: @@ -86,13 +84,13 @@ class DeviceSkillRepository(RepositoryBase): else: db_settings_values = json.dumps(settings_values) db_request = self._build_db_request( - sql_file_name='upsert_device_skill_settings.sql', + sql_file_name="upsert_device_skill_settings.sql", args=dict( device_id=device_id, skill_id=settings_display.skill_id, settings_values=db_settings_values, - settings_display_id=settings_display.id - ) + settings_display_id=settings_display.id, + ), ) self.cursor.insert(db_request) @@ -103,76 +101,66 @@ class DeviceSkillRepository(RepositoryBase): else: db_settings_values = json.dumps(device_skill.settings_values) db_request = self._build_db_request( - sql_file_name='update_device_skill_settings.sql', + sql_file_name="update_device_skill_settings.sql", args=dict( device_id=device_id, skill_id=device_skill.skill_id, settings_display_id=device_skill.settings_display_id, - settings_values=db_settings_values - ) + settings_values=db_settings_values, + ), ) self.cursor.update(db_request) - def get_skill_manifest_for_device( - self, device_id: str - ) -> List[ManifestSkill]: + def get_skill_manifest_for_device(self, device_id: str) -> List[ManifestSkill]: return 
self._select_all_into_dataclass( dataclass=ManifestSkill, - sql_file_name='get_device_skill_manifest.sql', - args=dict(device_id=device_id) + sql_file_name="get_device_skill_manifest.sql", + args=dict(device_id=device_id), ) - def get_skill_manifest_for_account( - self, account_id: str - ) -> List[ManifestSkill]: + def get_skill_manifest_for_account(self, account_id: str) -> List[ManifestSkill]: return self._select_all_into_dataclass( dataclass=ManifestSkill, - sql_file_name='get_skill_manifest_for_account.sql', - args=dict(account_id=account_id) + sql_file_name="get_skill_manifest_for_account.sql", + args=dict(account_id=account_id), ) def update_manifest_skill(self, manifest_skill: ManifestSkill): db_request = self._build_db_request( - sql_file_name='update_skill_manifest.sql', - args=asdict(manifest_skill) + sql_file_name="update_skill_manifest.sql", args=asdict(manifest_skill) ) self.cursor.update(db_request) def add_manifest_skill(self, manifest_skill: ManifestSkill): db_request = self._build_db_request( - sql_file_name='add_manifest_skill.sql', - args=asdict(manifest_skill) + sql_file_name="add_manifest_skill.sql", args=asdict(manifest_skill) ) db_result = self.cursor.insert_returning(db_request) - return db_result['id'] + return db_result["id"] def remove_manifest_skill(self, manifest_skill: ManifestSkill): db_request = self._build_db_request( - sql_file_name='remove_manifest_skill.sql', + sql_file_name="remove_manifest_skill.sql", args=dict( - device_id=manifest_skill.device_id, - skill_gid=manifest_skill.skill_gid - ) + device_id=manifest_skill.device_id, skill_gid=manifest_skill.skill_gid + ), ) self.cursor.delete(db_request) def get_settings_display_usage(self, settings_display_id: str) -> int: db_request = self._build_db_request( - sql_file_name='get_settings_display_usage.sql', - args=dict(settings_display_id=settings_display_id) + sql_file_name="get_settings_display_usage.sql", + args=dict(settings_display_id=settings_display_id), ) db_result = self.cursor.select_one(db_request) - return db_result['usage'] + return db_result["usage"] def remove(self, device_id, skill_id): db_request = self._build_db_request( - sql_file_name='delete_device_skill.sql', - args=dict( - device_id=device_id, - skill_id=skill_id - ) + sql_file_name="delete_device_skill.sql", + args=dict(device_id=device_id, skill_id=skill_id), ) self.cursor.delete(db_request) diff --git a/shared/selene/data/device/repository/geography.py b/shared/selene/data/device/repository/geography.py index 65832a32..e251513d 100644 --- a/shared/selene/data/device/repository/geography.py +++ b/shared/selene/data/device/repository/geography.py @@ -28,8 +28,8 @@ class GeographyRepository(RepositoryBase): def get_account_geographies(self): db_request = self._build_db_request( - sql_file_name='get_account_geographies.sql', - args=dict(account_id=self.account_id) + sql_file_name="get_account_geographies.sql", + args=dict(account_id=self.account_id), ) db_response = self.cursor.select_all(db_request) @@ -40,10 +40,10 @@ class GeographyRepository(RepositoryBase): acct_geographies = self.get_account_geographies() for acct_geography in acct_geographies: match = ( - acct_geography.city == geography.city and - acct_geography.country == geography.country and - acct_geography.region == geography.region and - acct_geography.time_zone == geography.time_zone + acct_geography.city == geography.city + and acct_geography.country == geography.country + and acct_geography.region == geography.region + and acct_geography.time_zone == 
geography.time_zone ) if match: geography_id = acct_geography.id @@ -53,22 +53,22 @@ class GeographyRepository(RepositoryBase): def add(self, geography: Geography): db_request = self._build_db_request( - sql_file_name='add_geography.sql', + sql_file_name="add_geography.sql", args=dict( account_id=self.account_id, city=geography.city, country=geography.country, region=geography.region, - timezone=geography.time_zone - ) + timezone=geography.time_zone, + ), ) db_result = self.cursor.insert_returning(db_request) - return db_result['id'] + return db_result["id"] def get_location_by_device_id(self, device_id): db_request = self._build_db_request( - sql_file_name='get_location_by_device_id.sql', - args=dict(device_id=device_id) + sql_file_name="get_location_by_device_id.sql", + args=dict(device_id=device_id), ) return self.cursor.select_one(db_request) diff --git a/shared/selene/data/device/repository/preference.py b/shared/selene/data/device/repository/preference.py index 45294e85..cd6218a2 100644 --- a/shared/selene/data/device/repository/preference.py +++ b/shared/selene/data/device/repository/preference.py @@ -17,7 +17,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from dataclasses import asdict +from dataclasses import asdict from ..entity.preference import AccountPreferences from ...repository_base import RepositoryBase @@ -30,8 +30,8 @@ class PreferenceRepository(RepositoryBase): def get_account_preferences(self) -> AccountPreferences: db_request = self._build_db_request( - sql_file_name='get_account_preferences.sql', - args=dict(account_id=self.account_id) + sql_file_name="get_account_preferences.sql", + args=dict(account_id=self.account_id), ) db_result = self.cursor.select_one(db_request) @@ -46,7 +46,6 @@ class PreferenceRepository(RepositoryBase): db_request_args = dict(account_id=self.account_id) db_request_args.update(asdict(preferences)) db_request = self._build_db_request( - sql_file_name='upsert_preferences.sql', - args=db_request_args + sql_file_name="upsert_preferences.sql", args=db_request_args ) self.cursor.insert(db_request) diff --git a/shared/selene/data/device/repository/setting.py b/shared/selene/data/device/repository/setting.py index 8d55720e..d0a24bbf 100644 --- a/shared/selene/data/device/repository/setting.py +++ b/shared/selene/data/device/repository/setting.py @@ -21,7 +21,7 @@ from os import path from selene.util.db import get_sql_from_file, Cursor, DatabaseRequest -SQL_DIR = path.join(path.dirname(__file__), 'sql') +SQL_DIR = path.join(path.dirname(__file__), "sql") class SettingRepository(object): @@ -30,36 +30,38 @@ class SettingRepository(object): def get_device_settings_by_device_id(self, device_id): query = DatabaseRequest( - sql=get_sql_from_file(path.join(SQL_DIR, 'get_device_settings_by_device_id.sql')), - args=dict(device_id=device_id) + sql=get_sql_from_file( + path.join(SQL_DIR, "get_device_settings_by_device_id.sql") + ), + args=dict(device_id=device_id), ) return self.cursor.select_one(query) def convert_text_to_speech_setting(self, setting_name, engine) -> (str, str): """Convert the selene representation of TTS into the tartarus representation, for backward compatibility with the API v1""" - if engine == 'mimic': - if setting_name == 'trinity': - return 'mimic', 'trinity' - elif setting_name == 'kusal': - return 'mimic2', 'kusal' + if engine == "mimic": + if setting_name == "trinity": + return "mimic", "trinity" + elif setting_name == "kusal": + return "mimic2", "kusal" 
else: - return 'mimic', 'ap' + return "mimic", "ap" else: - return 'google', '' + return "google", "" def _format_date_v1(self, date: str): - if date == 'DD/MM/YYYY': - result = 'DMY' + if date == "DD/MM/YYYY": + result = "DMY" else: - result = 'MDY' + result = "MDY" return result def _format_time_v1(self, time: str): - if time == '24 Hour': - result = 'full' + if time == "24 Hour": + result = "full" else: - result = 'half' + result = "half" return result def get_device_settings(self, device_id): @@ -69,22 +71,29 @@ class SettingRepository(object): :return setting entity using the legacy format from the API v1""" response = self.get_device_settings_by_device_id(device_id) if response: - if response['listener_setting']['uuid'] is None: - del response['listener_setting'] - tts_setting = response['tts_settings'] - tts_setting = self.convert_text_to_speech_setting(tts_setting['setting_name'], tts_setting['engine']) - tts_setting = {'module': tts_setting[0], tts_setting[0]: {'voice': tts_setting[1]}} - response['tts_settings'] = tts_setting - response['date_format'] = self._format_date_v1(response['date_format']) - response['time_format'] = self._format_time_v1(response['time_format']) - response['system_unit'] = response['system_unit'].lower() + if response["listener_setting"]["uuid"] is None: + del response["listener_setting"] + tts_setting = response["tts_settings"] + tts_setting = self.convert_text_to_speech_setting( + tts_setting["setting_name"], tts_setting["engine"] + ) + tts_setting = { + "module": tts_setting[0], + tts_setting[0]: {"voice": tts_setting[1]}, + } + response["tts_settings"] = tts_setting + response["date_format"] = self._format_date_v1(response["date_format"]) + response["time_format"] = self._format_time_v1(response["time_format"]) + response["system_unit"] = response["system_unit"].lower() open_dataset = self._get_open_dataset_agreement_by_device_id(device_id) - response['optIn'] = open_dataset is not None + response["optIn"] = open_dataset is not None return response def _get_open_dataset_agreement_by_device_id(self, device_id: str): query = DatabaseRequest( - sql=get_sql_from_file(path.join(SQL_DIR, 'get_open_dataset_agreement_by_device_id.sql')), - args=dict(device_id=device_id) + sql=get_sql_from_file( + path.join(SQL_DIR, "get_open_dataset_agreement_by_device_id.sql") + ), + args=dict(device_id=device_id), ) return self.cursor.select_one(query) diff --git a/shared/selene/data/device/repository/text_to_speech.py b/shared/selene/data/device/repository/text_to_speech.py index dcbfdb07..e3998ef2 100644 --- a/shared/selene/data/device/repository/text_to_speech.py +++ b/shared/selene/data/device/repository/text_to_speech.py @@ -26,7 +26,7 @@ class TextToSpeechRepository(RepositoryBase): super(TextToSpeechRepository, self).__init__(db, __file__) def get_voices(self): - db_request = self._build_db_request(sql_file_name='get_voices.sql') + db_request = self._build_db_request(sql_file_name="get_voices.sql") db_result = self.cursor.select_all(db_request) return [TextToSpeech(**row) for row in db_result] @@ -37,16 +37,16 @@ class TextToSpeechRepository(RepositoryBase): :return wake word id """ db_request = self._build_db_request( - sql_file_name='add_text_to_speech.sql', + sql_file_name="add_text_to_speech.sql", args=dict( wake_word=text_to_speech.setting_name, account_id=text_to_speech.display_name, - engine=text_to_speech.engine - ) + engine=text_to_speech.engine, + ), ) result = self.cursor.insert_returning(db_request) - return result['id'] + return result["id"] # def 
remove(self, wake_word: WakeWord): # """Delete a wake word from the wake_word table.""" diff --git a/shared/selene/data/geography/entity/__init__.py b/shared/selene/data/geography/entity/__init__.py index 7665c10b..eabab81b 100644 --- a/shared/selene/data/geography/entity/__init__.py +++ b/shared/selene/data/geography/entity/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/shared/selene/data/geography/repository/__init__.py b/shared/selene/data/geography/repository/__init__.py index 7665c10b..eabab81b 100644 --- a/shared/selene/data/geography/repository/__init__.py +++ b/shared/selene/data/geography/repository/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/shared/selene/data/geography/repository/city.py b/shared/selene/data/geography/repository/city.py index 92e7ef5d..45fc1a66 100644 --- a/shared/selene/data/geography/repository/city.py +++ b/shared/selene/data/geography/repository/city.py @@ -27,8 +27,7 @@ class CityRepository(RepositoryBase): def get_cities_by_region(self, region_id): db_request = self._build_db_request( - sql_file_name='get_cities_by_region.sql', - args=dict(region_id=region_id) + sql_file_name="get_cities_by_region.sql", args=dict(region_id=region_id) ) db_result = self.cursor.select_all(db_request) @@ -39,23 +38,22 @@ class CityRepository(RepositoryBase): city_names = [nm.lower() for nm in possible_city_names] return self._select_all_into_dataclass( GeographicLocation, - sql_file_name='get_geographic_location_by_city.sql', - args=dict(possible_city_names=tuple(city_names)) - + sql_file_name="get_geographic_location_by_city.sql", + args=dict(possible_city_names=tuple(city_names)), ) def get_biggest_city_in_region(self, region_name): """Return the geolocation of the most populous city in a region.""" return self._select_one_into_dataclass( GeographicLocation, - sql_file_name='get_biggest_city_in_region.sql', - args=dict(region=region_name.lower()) + sql_file_name="get_biggest_city_in_region.sql", + args=dict(region=region_name.lower()), ) def get_biggest_city_in_country(self, country_name): """Return the geolocation of the most populous city in a country.""" return self._select_one_into_dataclass( GeographicLocation, - sql_file_name='get_biggest_city_in_country.sql', - args=dict(country=country_name.lower()) + sql_file_name="get_biggest_city_in_country.sql", + args=dict(country=country_name.lower()), ) diff --git a/shared/selene/data/geography/repository/country.py b/shared/selene/data/geography/repository/country.py index e966d58c..87b5cdce 100644 --- a/shared/selene/data/geography/repository/country.py +++ b/shared/selene/data/geography/repository/country.py @@ -26,7 +26,7 @@ class CountryRepository(RepositoryBase): super(CountryRepository, self).__init__(db, __file__) def get_countries(self): - db_request = self._build_db_request(sql_file_name='get_countries.sql') + db_request = self._build_db_request(sql_file_name="get_countries.sql") db_result = self.cursor.select_all(db_request) return [Country(**row) for row in db_result] diff --git a/shared/selene/data/geography/repository/region.py b/shared/selene/data/geography/repository/region.py index 3bd7612a..c3e21571 100644 --- a/shared/selene/data/geography/repository/region.py +++ b/shared/selene/data/geography/repository/region.py @@ -27,8 +27,7 @@ class RegionRepository(RepositoryBase): def 
get_regions_by_country(self, country_id): db_request = self._build_db_request( - sql_file_name='get_regions_by_country.sql', - args=dict(country_id=country_id) + sql_file_name="get_regions_by_country.sql", args=dict(country_id=country_id) ) db_result = self.cursor.select_all(db_request) diff --git a/shared/selene/data/geography/repository/timezone.py b/shared/selene/data/geography/repository/timezone.py index 66727d72..0024a1d2 100644 --- a/shared/selene/data/geography/repository/timezone.py +++ b/shared/selene/data/geography/repository/timezone.py @@ -27,8 +27,8 @@ class TimezoneRepository(RepositoryBase): def get_timezones_by_country(self, country_id): db_request = self._build_db_request( - sql_file_name='get_timezones_by_country.sql', - args=dict(country_id=country_id) + sql_file_name="get_timezones_by_country.sql", + args=dict(country_id=country_id), ) db_result = self.cursor.select_all(db_request) diff --git a/shared/selene/data/metric/entity/__init__.py b/shared/selene/data/metric/entity/__init__.py index 7665c10b..eabab81b 100644 --- a/shared/selene/data/metric/entity/__init__.py +++ b/shared/selene/data/metric/entity/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/shared/selene/data/metric/repository/__init__.py b/shared/selene/data/metric/repository/__init__.py index 7665c10b..eabab81b 100644 --- a/shared/selene/data/metric/repository/__init__.py +++ b/shared/selene/data/metric/repository/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/shared/selene/data/metric/repository/api.py b/shared/selene/data/metric/repository/api.py index 6bec83a1..ec6a74e2 100644 --- a/shared/selene/data/metric/repository/api.py +++ b/shared/selene/data/metric/repository/api.py @@ -34,7 +34,7 @@ from datetime import date, datetime, time from ..entity.api import ApiMetric from ...repository_base import RepositoryBase -DUMP_FILE_DIR = '/opt/selene/dump' +DUMP_FILE_DIR = "/opt/selene/dump" class ApiMetricsRepository(RepositoryBase): @@ -43,8 +43,7 @@ class ApiMetricsRepository(RepositoryBase): def add(self, metric: ApiMetric): db_request = self._build_db_request( - sql_file_name='add_api_metric.sql', - args=asdict(metric) + sql_file_name="add_api_metric.sql", args=asdict(metric) ) self.cursor.insert(db_request) @@ -53,27 +52,27 @@ class ApiMetricsRepository(RepositoryBase): start_ts = datetime.combine(partition_date, time.min) end_ts = datetime.combine(partition_date, time.max) db_request = self._build_db_request( - sql_file_name='create_api_metric_partition.sql', + sql_file_name="create_api_metric_partition.sql", args=dict(start_ts=str(start_ts), end_ts=str(end_ts)), - sql_vars=dict(partition=partition_date.strftime('%Y%m%d')) + sql_vars=dict(partition=partition_date.strftime("%Y%m%d")), ) self.cursor.execute(db_request) db_request = self._build_db_request( - sql_file_name='create_api_metric_partition_index.sql', - sql_vars=dict(partition=partition_date.strftime('%Y%m%d')) + sql_file_name="create_api_metric_partition_index.sql", + sql_vars=dict(partition=partition_date.strftime("%Y%m%d")), ) self.cursor.execute(db_request) def copy_to_partition(self, partition_date: date): """Copy rows from metric.api table to metric.api_history.""" - dump_file_name = 'api_metrics_' + str(partition_date) + dump_file_name = "api_metrics_" + str(partition_date) dump_file_path = 
os.path.join(DUMP_FILE_DIR, dump_file_name) db_request = self._build_db_request( - sql_file_name='get_api_metrics_for_date.sql', - args=dict(metrics_date=partition_date) + sql_file_name="get_api_metrics_for_date.sql", + args=dict(metrics_date=partition_date), ) - table_name = 'metric.api_history_' + partition_date.strftime('%Y%m%d') + table_name = "metric.api_history_" + partition_date.strftime("%Y%m%d") self.cursor.dump_query_result_to_file(db_request, dump_file_path) self.cursor.load_dump_file_to_table(table_name, dump_file_path) os.remove(dump_file_path) @@ -81,7 +80,7 @@ class ApiMetricsRepository(RepositoryBase): def remove_by_date(self, partition_date: date): """Delete from metric.api table after copying to metric.api_history""" db_request = self._build_db_request( - sql_file_name='delete_api_metrics_by_date.sql', - args=dict(delete_date=partition_date) + sql_file_name="delete_api_metrics_by_date.sql", + args=dict(delete_date=partition_date), ) self.cursor.delete(db_request) diff --git a/shared/selene/data/metric/repository/core.py b/shared/selene/data/metric/repository/core.py index bbe6d323..28bf040e 100644 --- a/shared/selene/data/metric/repository/core.py +++ b/shared/selene/data/metric/repository/core.py @@ -31,33 +31,29 @@ class CoreMetricRepository(RepositoryBase): def add(self, metric: CoreMetric): db_request_args = asdict(metric) - db_request_args['metric_value'] = json.dumps( - db_request_args['metric_value'] - ) + db_request_args["metric_value"] = json.dumps(db_request_args["metric_value"]) db_request = self._build_db_request( - sql_file_name='add_core_metric.sql', - args=db_request_args + sql_file_name="add_core_metric.sql", args=db_request_args ) self.cursor.insert(db_request) def get_metrics_by_device(self, device_id): return self._select_all_into_dataclass( CoreMetric, - sql_file_name='get_core_metric_by_device.sql', - args=dict(device_id=device_id) + sql_file_name="get_core_metric_by_device.sql", + args=dict(device_id=device_id), ) def get_metrics_by_date(self, metric_date: date) -> List[CoreMetric]: return self._select_all_into_dataclass( CoreMetric, - sql_file_name='get_core_timing_metrics_by_date.sql', - args=dict(metric_date=metric_date) + sql_file_name="get_core_timing_metrics_by_date.sql", + args=dict(metric_date=metric_date), ) def add_interaction(self, interaction: CoreInteraction) -> str: db_request = self._build_db_request( - sql_file_name='add_core_interaction.sql', - args=asdict(interaction) + sql_file_name="add_core_interaction.sql", args=asdict(interaction) ) db_result = self.cursor.insert_returning(db_request) diff --git a/shared/selene/data/repository_base.py b/shared/selene/data/repository_base.py index c852ea8f..cf70cbcc 100644 --- a/shared/selene/data/repository_base.py +++ b/shared/selene/data/repository_base.py @@ -29,7 +29,7 @@ from selene.util.db import ( Cursor, DatabaseRequest, DatabaseBatchRequest, - get_sql_from_file + get_sql_from_file, ) @@ -52,10 +52,10 @@ class RepositoryBase(object): def __init__(self, db, repository_path): self.db = db self.cursor = Cursor(db) - self.sql_dir = path.join(path.dirname(repository_path), 'sql') + self.sql_dir = path.join(path.dirname(repository_path), "sql") def _build_db_request( - self, sql_file_name: str, args: dict = None, sql_vars: dict = None + self, sql_file_name: str, args: dict = None, sql_vars: dict = None ): """Build a DatabaseRequest object containing a query and args""" sql = get_sql_from_file(path.join(self.sql_dir, sql_file_name)) @@ -67,8 +67,7 @@ class RepositoryBase(object): def 
_build_db_batch_request(self, sql_file_name: str, args: List[dict]): """Build a DatabaseBatchRequest object containing a query and args""" return DatabaseBatchRequest( - sql=get_sql_from_file(path.join(self.sql_dir, sql_file_name)), - args=args + sql=get_sql_from_file(path.join(self.sql_dir, sql_file_name)), args=args ) def _select_one_into_dataclass(self, dataclass, sql_file_name, args=None): diff --git a/shared/selene/data/skill/__init__.py b/shared/selene/data/skill/__init__.py index 73ff17d5..d219fd76 100644 --- a/shared/selene/data/skill/__init__.py +++ b/shared/selene/data/skill/__init__.py @@ -22,7 +22,7 @@ from .entity.skill import Skill from .entity.skill_setting import ( AccountSkillSetting, DeviceSkillSetting, - SettingsDisplay + SettingsDisplay, ) from .repository.display import SkillDisplayRepository from .repository.setting import SkillSettingRepository diff --git a/shared/selene/data/skill/entity/__init__.py b/shared/selene/data/skill/entity/__init__.py index 7665c10b..eabab81b 100644 --- a/shared/selene/data/skill/entity/__init__.py +++ b/shared/selene/data/skill/entity/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/shared/selene/data/skill/repository/__init__.py b/shared/selene/data/skill/repository/__init__.py index 7665c10b..eabab81b 100644 --- a/shared/selene/data/skill/repository/__init__.py +++ b/shared/selene/data/skill/repository/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/shared/selene/data/skill/repository/display.py b/shared/selene/data/skill/repository/display.py index f620a8e2..91cd6832 100644 --- a/shared/selene/data/skill/repository/display.py +++ b/shared/selene/data/skill/repository/display.py @@ -26,30 +26,30 @@ class SkillDisplayRepository(RepositoryBase): super(SkillDisplayRepository, self).__init__(db, __file__) # TODO: Change this to a value that can be passed in - self.core_version = '21.02' + self.core_version = "21.02" def get_display_data_for_skills(self): return self._select_all_into_dataclass( dataclass=SkillDisplay, - sql_file_name='get_display_data_for_skills.sql', - args=dict(core_version=self.core_version) + sql_file_name="get_display_data_for_skills.sql", + args=dict(core_version=self.core_version), ) def get_display_data_for_skill(self, skill_display_id) -> SkillDisplay: return self._select_one_into_dataclass( dataclass=SkillDisplay, - sql_file_name='get_display_data_for_skill.sql', - args=dict(skill_display_id=skill_display_id) + sql_file_name="get_display_data_for_skill.sql", + args=dict(skill_display_id=skill_display_id), ) def upsert(self, skill_display: SkillDisplay): db_request = self._build_db_request( - sql_file_name='upsert_skill_display_data.sql', + sql_file_name="upsert_skill_display_data.sql", args=dict( skill_id=skill_display.skill_id, core_version=skill_display.core_version, display_data=skill_display.display_data, - ) + ), ) self.cursor.insert(db_request) diff --git a/shared/selene/data/skill/repository/setting.py b/shared/selene/data/skill/repository/setting.py index d5c8f900..c1e069ed 100644 --- a/shared/selene/data/skill/repository/setting.py +++ b/shared/selene/data/skill/repository/setting.py @@ -32,14 +32,12 @@ class SkillSettingRepository(RepositoryBase): self.db = db def get_family_settings( - self, - account_id: str, - family_name: str + self, account_id: str, family_name: str ) -> 
List[AccountSkillSetting]: return self._select_all_into_dataclass( AccountSkillSetting, - sql_file_name='get_settings_for_skill_family.sql', - args=dict(family_name=family_name, account_id=account_id) + sql_file_name="get_settings_for_skill_family.sql", + args=dict(family_name=family_name, account_id=account_id), ) def get_installer_settings(self, account_id) -> List[AccountSkillSetting]: @@ -47,39 +45,31 @@ class SkillSettingRepository(RepositoryBase): skills = skill_repo.get_skills_for_account(account_id) installer_skill_id = None for skill in skills: - if skill.display_name == 'Installer': + if skill.display_name == "Installer": installer_skill_id = skill.id skill_settings = None if installer_skill_id is not None: - skill_settings = self.get_family_settings( - account_id, - installer_skill_id - ) + skill_settings = self.get_family_settings(account_id, installer_skill_id) return skill_settings @use_transaction def update_skill_settings( - self, - account_id, - new_skill_settings: AccountSkillSetting, - skill_ids: List[str] + self, account_id, new_skill_settings: AccountSkillSetting, skill_ids: List[str] ): if new_skill_settings.settings_values is None: serialized_settings_values = None else: - serialized_settings_values = json.dumps( - new_skill_settings.settings_values - ) + serialized_settings_values = json.dumps(new_skill_settings.settings_values) db_request = self._build_db_request( - 'update_device_skill_settings.sql', + "update_device_skill_settings.sql", args=dict( account_id=account_id, settings_values=serialized_settings_values, skill_id=tuple(skill_ids), - device_names=tuple(new_skill_settings.device_names) - ) + device_names=tuple(new_skill_settings.device_names), + ), ) self.cursor.update(db_request) @@ -87,6 +77,6 @@ class SkillSettingRepository(RepositoryBase): """Return all skills and their settings for a given device id""" return self._select_all_into_dataclass( DeviceSkillSetting, - sql_file_name='get_skill_setting_by_device.sql', - args=dict(device_id=device_id) + sql_file_name="get_skill_setting_by_device.sql", + args=dict(device_id=device_id), ) diff --git a/shared/selene/data/skill/repository/settings_display.py b/shared/selene/data/skill/repository/settings_display.py index 0f3ee9a5..6b5f89e1 100644 --- a/shared/selene/data/skill/repository/settings_display.py +++ b/shared/selene/data/skill/repository/settings_display.py @@ -24,35 +24,34 @@ from ..entity.skill_setting import SettingsDisplay class SettingsDisplayRepository(RepositoryBase): - def __init__(self, db): super(SettingsDisplayRepository, self).__init__(db, __file__) def add(self, settings_display: SettingsDisplay) -> str: """Add a new row to the skill.settings_display table.""" db_request = self._build_db_request( - sql_file_name='add_settings_display.sql', + sql_file_name="add_settings_display.sql", args=dict( skill_id=settings_display.skill_id, - display_data=json.dumps(settings_display.display_data) - ) + display_data=json.dumps(settings_display.display_data), + ), ) result = self.cursor.insert_returning(db_request) - return result['id'] + return result["id"] def get_settings_display_id(self, settings_display: SettingsDisplay): """Get the ID of a skill's settings definition.""" db_request = self._build_db_request( - sql_file_name='get_settings_display_id.sql', + sql_file_name="get_settings_display_id.sql", args=dict( skill_id=settings_display.skill_id, - display_data=json.dumps(settings_display.display_data) - ) + display_data=json.dumps(settings_display.display_data), + ), ) result = 
self.cursor.select_one(db_request) - return None if result is None else result['id'] + return None if result is None else result["id"] def get_settings_definitions_by_gid(self, global_id): """Get all matching settings definitions for a global skill ID. @@ -63,14 +62,14 @@ class SettingsDisplayRepository(RepositoryBase): """ return self._select_all_into_dataclass( SettingsDisplay, - sql_file_name='get_settings_definition_by_gid.sql', - args=dict(global_id=global_id) + sql_file_name="get_settings_definition_by_gid.sql", + args=dict(global_id=global_id), ) def remove(self, settings_display_id: str): """Delete a settings definition that is no longer used by any device""" db_request = self._build_db_request( - sql_file_name='delete_settings_display.sql', - args=dict(settings_display_id=settings_display_id) + sql_file_name="delete_settings_display.sql", + args=dict(settings_display_id=settings_display_id), ) self.cursor.delete(db_request) diff --git a/shared/selene/data/skill/repository/skill.py b/shared/selene/data/skill/repository/skill.py index 76e60676..f66df9e6 100644 --- a/shared/selene/data/skill/repository/skill.py +++ b/shared/selene/data/skill/repository/skill.py @@ -24,8 +24,8 @@ from ...repository_base import RepositoryBase def extract_family_from_global_id(skill_gid): - id_parts = skill_gid.split('|') - if id_parts[0].startswith('@'): + id_parts = skill_gid.split("|") + if id_parts[0].startswith("@"): family_name = id_parts[1] else: family_name = id_parts[0] @@ -41,8 +41,7 @@ class SkillRepository(RepositoryBase): def get_skills_for_account(self, account_id) -> List[SkillFamily]: skills = [] db_request = self._build_db_request( - 'get_skills_for_account.sql', - args=dict(account_id=account_id) + "get_skills_for_account.sql", args=dict(account_id=account_id) ) db_result = self.cursor.select_all(db_request) if db_result is not None: @@ -54,23 +53,23 @@ class SkillRepository(RepositoryBase): def get_skill_by_global_id(self, skill_global_id) -> Skill: return self._select_one_into_dataclass( dataclass=Skill, - sql_file_name='get_skill_by_global_id.sql', - args=dict(skill_global_id=skill_global_id) + sql_file_name="get_skill_by_global_id.sql", + args=dict(skill_global_id=skill_global_id), ) @staticmethod def _extract_settings(skill): settings = {} - skill_metadata = skill.get('skillMetadata') + skill_metadata = skill.get("skillMetadata") if skill_metadata: - for section in skill_metadata['sections']: - for field in section['fields']: - if 'name' in field and 'value' in field: - settings[field['name']] = field['value'] - field.pop('value', None) + for section in skill_metadata["sections"]: + for field in section["fields"]: + if "name" in field and "value" in field: + settings[field["name"]] = field["value"] + field.pop("value", None) result = settings, skill else: - result = '', '' + result = "", "" return result def ensure_skill_exists(self, skill_global_id: str) -> str: @@ -85,14 +84,14 @@ class SkillRepository(RepositoryBase): def _add_skill(self, skill_gid: str, name: str) -> str: db_request = self._build_db_request( - sql_file_name='add_skill.sql', - args=dict(skill_gid=skill_gid, family_name=name) + sql_file_name="add_skill.sql", + args=dict(skill_gid=skill_gid, family_name=name), ) db_result = self.cursor.insert_returning(db_request) # handle both dictionary cursors and namedtuple cursors try: - skill_id = db_result['id'] + skill_id = db_result["id"] except TypeError: skill_id = db_result.id @@ -100,7 +99,6 @@ class SkillRepository(RepositoryBase): def remove_by_gid(self, 
skill_gid): db_request = self._build_db_request( - sql_file_name='remove_skill_by_gid.sql', - args=dict(skill_gid=skill_gid) + sql_file_name="remove_skill_by_gid.sql", args=dict(skill_gid=skill_gid) ) self.cursor.delete(db_request) diff --git a/shared/selene/testing/__init__.py b/shared/selene/testing/__init__.py index 7665c10b..eabab81b 100644 --- a/shared/selene/testing/__init__.py +++ b/shared/selene/testing/__init__.py @@ -16,4 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - diff --git a/shared/selene/testing/account.py b/shared/selene/testing/account.py index 94dde483..1bf815d8 100644 --- a/shared/selene/testing/account.py +++ b/shared/selene/testing/account.py @@ -25,7 +25,7 @@ from selene.data.account import ( AccountRepository, OPEN_DATASET, PRIVACY_POLICY, - TERMS_OF_USE + TERMS_OF_USE, ) @@ -33,19 +33,19 @@ def build_test_account(**overrides): test_agreements = [ AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today()), AccountAgreement(type=TERMS_OF_USE, accept_date=date.today()), - AccountAgreement(type=OPEN_DATASET, accept_date=date.today()) + AccountAgreement(type=OPEN_DATASET, accept_date=date.today()), ] return Account( - email_address=overrides.get('email_address') or 'foo@mycroft.ai', - username=overrides.get('username') or 'foobar', - agreements=overrides.get('agreements') or test_agreements + email_address=overrides.get("email_address") or "foo@mycroft.ai", + username=overrides.get("username") or "foobar", + agreements=overrides.get("agreements") or test_agreements, ) def add_account(db, **overrides): acct_repository = AccountRepository(db) account = build_test_account(**overrides) - password = overrides.get('password') or 'test_password' + password = overrides.get("password") or "test_password" account.id = acct_repository.add(account, password) if account.membership is not None: acct_repository.add_membership(account.id, account.membership) @@ -59,13 +59,13 @@ def remove_account(db, account): def build_test_membership(**overrides): - stripe_acct = 'test_stripe_acct_id' + stripe_acct = "test_stripe_acct_id" return AccountMembership( - type=overrides.get('type') or 'Monthly Membership', - start_date=overrides.get('start_date') or date.today(), - payment_method=overrides.get('payment_method') or 'Stripe', - payment_account_id=overrides.get('payment_account_id') or stripe_acct, - payment_id=overrides.get('payment_id') or 'test_stripe_payment_id' + type=overrides.get("type") or "Monthly Membership", + start_date=overrides.get("start_date") or date.today(), + payment_method=overrides.get("payment_method") or "Stripe", + payment_account_id=overrides.get("payment_account_id") or stripe_acct, + payment_id=overrides.get("payment_id") or "test_stripe_payment_id", ) diff --git a/shared/selene/testing/account_geography.py b/shared/selene/testing/account_geography.py index 8806c84d..8147f697 100644 --- a/shared/selene/testing/account_geography.py +++ b/shared/selene/testing/account_geography.py @@ -22,10 +22,10 @@ from selene.data.device import Geography, GeographyRepository def add_account_geography(db, account, **overrides): geography = Geography( - country=overrides.get('country') or 'United States', - region=overrides.get('region') or 'Missouri', - city=overrides.get('city') or 'Kansas City', - time_zone=overrides.get('time_zone') or 'America/Chicago' + country=overrides.get("country") or "United States", + region=overrides.get("region") or "Missouri", + city=overrides.get("city") or "Kansas 
City", + time_zone=overrides.get("time_zone") or "America/Chicago", ) geo_repository = GeographyRepository(db, account.id) account_geography_id = geo_repository.add(geography) diff --git a/shared/selene/testing/account_preference.py b/shared/selene/testing/account_preference.py index b66a4fda..fd37cde6 100644 --- a/shared/selene/testing/account_preference.py +++ b/shared/selene/testing/account_preference.py @@ -22,9 +22,7 @@ from selene.data.device import AccountPreferences, PreferenceRepository def add_account_preference(db, account_id): account_preferences = AccountPreferences( - date_format='MM/DD/YYYY', - time_format='12 Hour', - measurement_system='Imperial' + date_format="MM/DD/YYYY", time_format="12 Hour", measurement_system="Imperial" ) preference_repo = PreferenceRepository(db, account_id) preference_repo.upsert(account_preferences) diff --git a/shared/selene/testing/agreement.py b/shared/selene/testing/agreement.py index fe9cf8c4..7c6c4656 100644 --- a/shared/selene/testing/agreement.py +++ b/shared/selene/testing/agreement.py @@ -29,44 +29,44 @@ from selene.data.account import ( AgreementRepository, OPEN_DATASET, PRIVACY_POLICY, - TERMS_OF_USE + TERMS_OF_USE, ) def _build_test_terms_of_use(): return Agreement( type=TERMS_OF_USE, - version='Holy Grail', - content='I agree that all the tests I write for this application will ' - 'be in the theme of Monty Python and the Holy Grail. If you ' - 'do not agree with these terms, I will be forced to say "Ni!" ' - 'until such time as you agree', - effective_date=date.today() - timedelta(days=1) -) + version="Holy Grail", + content="I agree that all the tests I write for this application will " + "be in the theme of Monty Python and the Holy Grail. If you " + 'do not agree with these terms, I will be forced to say "Ni!" ' + "until such time as you agree", + effective_date=date.today() - timedelta(days=1), + ) def _build_test_privacy_policy(): return Agreement( type=PRIVACY_POLICY, - version='Holy Grail', - content='First, shalt thou take out the Holy Pin. Then shalt thou ' - 'count to three. No more. No less. Three shalt be the ' - 'number thou shalt count and the number of the counting shall ' - 'be three. Four shalt thou not count, nor either count thou ' - 'two, excepting that thou then proceed to three. Five is ' - 'right out. Once the number three, being the third number, ' - 'be reached, then lobbest thou Holy Hand Grenade of Antioch ' - 'towards thy foe, who, being naughty in My sight, ' - 'shall snuff it.', - effective_date=date.today() - timedelta(days=1) + version="Holy Grail", + content="First, shalt thou take out the Holy Pin. Then shalt thou " + "count to three. No more. No less. Three shalt be the " + "number thou shalt count and the number of the counting shall " + "be three. Four shalt thou not count, nor either count thou " + "two, excepting that thou then proceed to three. Five is " + "right out. 
Once the number three, being the third number, " + "be reached, then lobbest thou Holy Hand Grenade of Antioch " + "towards thy foe, who, being naughty in My sight, " + "shall snuff it.", + effective_date=date.today() - timedelta(days=1), ) def _build_open_dataset(): return Agreement( type=OPEN_DATASET, - version='Holy Grail', - effective_date=date.today() - timedelta(days=1) + version="Holy Grail", + effective_date=date.today() - timedelta(days=1), ) @@ -93,11 +93,11 @@ def remove_agreements(db, agreements: List[Agreement]): def get_agreements_from_api(context, agreement): """Abstracted so both account and single sign on APIs use in their tests""" if agreement == PRIVACY_POLICY: - url = '/api/agreement/privacy-policy' + url = "/api/agreement/privacy-policy" elif agreement == TERMS_OF_USE: - url = '/api/agreement/terms-of-use' + url = "/api/agreement/terms-of-use" else: - raise ValueError('invalid agreement type') + raise ValueError("invalid agreement type") context.response = context.client.get(url) @@ -109,7 +109,7 @@ def validate_agreement_response(context, agreement): elif agreement == TERMS_OF_USE: expected_response = asdict(context.terms_of_use) else: - raise ValueError('invalid agreement type') + raise ValueError("invalid agreement type") - del(expected_response['effective_date']) + del expected_response["effective_date"] assert_that(response_data, equal_to(expected_response)) diff --git a/shared/selene/testing/api.py b/shared/selene/testing/api.py index 86d01501..a856d130 100644 --- a/shared/selene/testing/api.py +++ b/shared/selene/testing/api.py @@ -24,17 +24,14 @@ from selene.data.account import Account, AccountRepository from selene.util.auth import AuthenticationToken from selene.util.db import connect_to_db -ACCESS_TOKEN_COOKIE_KEY = 'seleneAccess' +ACCESS_TOKEN_COOKIE_KEY = "seleneAccess" ONE_MINUTE = 60 TWO_MINUTES = 120 -REFRESH_TOKEN_COOKIE_KEY = 'seleneRefresh' +REFRESH_TOKEN_COOKIE_KEY = "seleneRefresh" def generate_access_token(context, duration=ONE_MINUTE): - access_token = AuthenticationToken( - context.client_config['ACCESS_SECRET'], - duration - ) + access_token = AuthenticationToken(context.client_config["ACCESS_SECRET"], duration) account = context.accounts[context.username] access_token.generate(account.id) @@ -43,17 +40,16 @@ def generate_access_token(context, duration=ONE_MINUTE): def set_access_token_cookie(context, duration=ONE_MINUTE): context.client.set_cookie( - context.client_config['DOMAIN'], + context.client_config["DOMAIN"], ACCESS_TOKEN_COOKIE_KEY, context.access_token.jwt, - max_age=duration + max_age=duration, ) def generate_refresh_token(context, duration=TWO_MINUTES): refresh_token = AuthenticationToken( - context.client_config['REFRESH_SECRET'], - duration + context.client_config["REFRESH_SECRET"], duration ) account = context.accounts[context.username] refresh_token.generate(account.id) @@ -63,38 +59,38 @@ def generate_refresh_token(context, duration=TWO_MINUTES): def set_refresh_token_cookie(context, duration=TWO_MINUTES): context.client.set_cookie( - context.client_config['DOMAIN'], + context.client_config["DOMAIN"], REFRESH_TOKEN_COOKIE_KEY, context.refresh_token.jwt, - max_age=duration + max_age=duration, ) def validate_token_cookies(context, expired=False): - for cookie in context.response.headers.getlist('Set-Cookie'): + for cookie in context.response.headers.getlist("Set-Cookie"): ingredients = _parse_cookie(cookie) ingredient_names = list(ingredients.keys()) if ACCESS_TOKEN_COOKIE_KEY in ingredient_names: context.access_token = 
ingredients[ACCESS_TOKEN_COOKIE_KEY] elif REFRESH_TOKEN_COOKIE_KEY in ingredient_names: context.refresh_token = ingredients[REFRESH_TOKEN_COOKIE_KEY] - for ingredient_name in ('Domain', 'Expires', 'Max-Age'): + for ingredient_name in ("Domain", "Expires", "Max-Age"): assert_that(ingredient_names, has_item(ingredient_name)) if expired: - assert_that(ingredients['Max-Age'], equal_to('0')) + assert_that(ingredients["Max-Age"], equal_to("0")) - assert hasattr(context, 'access_token'), 'no access token in response' - assert hasattr(context, 'refresh_token'), 'no refresh token in response' + assert hasattr(context, "access_token"), "no access token in response" + assert hasattr(context, "refresh_token"), "no refresh token in response" if expired: - assert_that(context.access_token, equal_to('')) - assert_that(context.refresh_token, equal_to('')) + assert_that(context.access_token, equal_to("")) + assert_that(context.refresh_token, equal_to("")) def _parse_cookie(cookie: str) -> dict: ingredients = {} - for ingredient in cookie.split('; '): - if '=' in ingredient: - key, value = ingredient.split('=') + for ingredient in cookie.split("; "): + if "=" in ingredient: + key, value = ingredient.split("=") ingredients[key] = value else: ingredients[ingredient] = None @@ -103,7 +99,7 @@ def _parse_cookie(cookie: str) -> dict: def get_account(context) -> Account: - db = connect_to_db(context.client['DB_CONNECTION_CONFIG']) + db = connect_to_db(context.client["DB_CONNECTION_CONFIG"]) acct_repository = AccountRepository(db) account = acct_repository.get_account_by_id(context.account.id) @@ -112,21 +108,14 @@ def get_account(context) -> Account: def check_http_success(context): assert_that( - context.response.status_code, - is_in([HTTPStatus.OK, HTTPStatus.NO_CONTENT]) + context.response.status_code, is_in([HTTPStatus.OK, HTTPStatus.NO_CONTENT]) ) def check_http_error(context, error_type): - if error_type == 'a bad request': - assert_that( - context.response.status_code, - equal_to(HTTPStatus.BAD_REQUEST) - ) - elif error_type == 'an unauthorized': - assert_that( - context.response.status_code, - equal_to(HTTPStatus.UNAUTHORIZED) - ) + if error_type == "a bad request": + assert_that(context.response.status_code, equal_to(HTTPStatus.BAD_REQUEST)) + elif error_type == "an unauthorized": + assert_that(context.response.status_code, equal_to(HTTPStatus.UNAUTHORIZED)) else: - raise ValueError('unsupported error_type') + raise ValueError("unsupported error_type") diff --git a/shared/selene/testing/device_skill.py b/shared/selene/testing/device_skill.py index d145e59b..6a5ae9f1 100644 --- a/shared/selene/testing/device_skill.py +++ b/shared/selene/testing/device_skill.py @@ -26,12 +26,12 @@ from selene.data.device import DeviceSkillRepository, ManifestSkill def add_device_skill(db, device_id, skill): manifest_skill = ManifestSkill( device_id=device_id, - install_method='test_install_method', - install_status='test_install_status', + install_method="test_install_method", + install_status="test_install_status", skill_id=skill.id, skill_gid=skill.skill_gid, install_ts=datetime.utcnow(), - update_ts=datetime.utcnow() + update_ts=datetime.utcnow(), ) device_skill_repo = DeviceSkillRepository(db) manifest_skill.id = device_skill_repo.add_manifest_skill(manifest_skill) @@ -42,9 +42,7 @@ def add_device_skill(db, device_id, skill): def add_device_skill_settings(db, device_id, settings_display, settings_values): device_skill_repo = DeviceSkillRepository(db) device_skill_repo.upsert_device_skill_settings( - [device_id], - 
settings_display, - settings_values + [device_id], settings_display, settings_values ) diff --git a/shared/selene/testing/membership.py b/shared/selene/testing/membership.py index 4946b6c4..89b63898 100644 --- a/shared/selene/testing/membership.py +++ b/shared/selene/testing/membership.py @@ -23,19 +23,15 @@ from selene.data.account import ( Membership, MembershipRepository, MONTHLY_MEMBERSHIP, - YEARLY_MEMBERSHIP + YEARLY_MEMBERSHIP, ) monthly_membership = dict( - type=MONTHLY_MEMBERSHIP, - rate=Decimal('1.99'), - rate_period='monthly' + type=MONTHLY_MEMBERSHIP, rate=Decimal("1.99"), rate_period="monthly" ) yearly_membership = dict( - type=YEARLY_MEMBERSHIP, - rate=Decimal('19.99'), - rate_period='yearly' + type=YEARLY_MEMBERSHIP, rate=Decimal("19.99"), rate_period="yearly" ) diff --git a/shared/selene/testing/skill.py b/shared/selene/testing/skill.py index be035345..7452b3ec 100644 --- a/shared/selene/testing/skill.py +++ b/shared/selene/testing/skill.py @@ -21,52 +21,41 @@ from selene.data.skill import ( SettingsDisplay, SettingsDisplayRepository, Skill, - SkillRepository + SkillRepository, ) def build_text_field(): return dict( - name='textfield', - type='text', - label='Text Field', - placeholder='Text Placeholder' + name="textfield", + type="text", + label="Text Field", + placeholder="Text Placeholder", ) def build_checkbox_field(): - return dict( - name='checkboxfield', - type='checkbox', - label='Checkbox Field' - ) + return dict(name="checkboxfield", type="checkbox", label="Checkbox Field") def build_label_field(): - return dict( - type='label', - label='This is a section label.' - ) + return dict(type="label", label="This is a section label.") def _build_display_data(skill_gid, fields): - gid_parts = skill_gid.split('|') + gid_parts = skill_gid.split("|") if len(gid_parts) == 3: skill_name = gid_parts[1] else: skill_name = gid_parts[0] - skill_identifier = skill_name + '-123456' + skill_identifier = skill_name + "-123456" settings_display = dict( - skill_gid=skill_gid, - identifier=skill_identifier, - display_name=skill_name, + skill_gid=skill_gid, identifier=skill_identifier, display_name=skill_name, ) if fields is not None: settings_display.update( - skillMetadata=dict( - sections=[dict(name='Section Name', fields=fields)] - ) + skillMetadata=dict(sections=[dict(name="Section Name", fields=fields)]) ) return settings_display diff --git a/shared/selene/testing/test_db.py b/shared/selene/testing/test_db.py index 04933c62..b65c5c6e 100644 --- a/shared/selene/testing/test_db.py +++ b/shared/selene/testing/test_db.py @@ -20,10 +20,7 @@ from psycopg2 import connect connection_config = dict( - host='127.0.0.1', - dbname='postgres', - user='mycroft', - password='holmes' + host="127.0.0.1", dbname="postgres", user="mycroft", password="holmes" ) @@ -32,12 +29,12 @@ def create_test_db(): db.autocommit = True cursor = db.cursor() cursor.execute( - 'CREATE DATABASE ' - ' mycroft_test ' - 'WITH TEMPLATE ' - ' mycroft_template ' - 'OWNER ' - ' mycroft;' + "CREATE DATABASE " + " mycroft_test " + "WITH TEMPLATE " + " mycroft_template " + "OWNER " + " mycroft;" ) @@ -46,8 +43,8 @@ def drop_test_db(): db.autocommit = True cursor = db.cursor() cursor.execute( - 'SELECT pg_terminate_backend(pid) ' - 'FROM pg_stat_activity ' - 'WHERE datname = \'mycroft_test\';' + "SELECT pg_terminate_backend(pid) " + "FROM pg_stat_activity " + "WHERE datname = 'mycroft_test';" ) - cursor.execute('DROP DATABASE mycroft_test') + cursor.execute("DROP DATABASE mycroft_test") diff --git 
a/shared/selene/testing/text_to_speech.py b/shared/selene/testing/text_to_speech.py index 73a772e5..46e74126 100644 --- a/shared/selene/testing/text_to_speech.py +++ b/shared/selene/testing/text_to_speech.py @@ -22,9 +22,9 @@ from selene.data.device import DeviceRepository, TextToSpeech def _build_voice(): return TextToSpeech( - setting_name='selene_test_voice', - display_name='Selene Test Voice', - engine='mimic' + setting_name="selene_test_voice", + display_name="Selene Test Voice", + engine="mimic", ) diff --git a/shared/selene/util/__init__.py b/shared/selene/util/__init__.py index bf39084d..eabab81b 100644 --- a/shared/selene/util/__init__.py +++ b/shared/selene/util/__init__.py @@ -16,5 +16,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - - diff --git a/shared/selene/util/db/__init__.py b/shared/selene/util/db/__init__.py index 1cbcebb8..17dc2e5d 100644 --- a/shared/selene/util/db/__init__.py +++ b/shared/selene/util/db/__init__.py @@ -22,12 +22,7 @@ from .connection_pool import ( allocate_db_connection_pool, get_db_connection, get_db_connection_from_pool, - return_db_connection_to_pool -) -from .cursor import ( - Cursor, - DatabaseRequest, - DatabaseBatchRequest, - get_sql_from_file + return_db_connection_to_pool, ) +from .cursor import Cursor, DatabaseRequest, DatabaseBatchRequest, get_sql_from_file from .transaction import use_transaction diff --git a/shared/selene/util/db/transaction.py b/shared/selene/util/db/transaction.py index 2550827b..6864cc75 100644 --- a/shared/selene/util/db/transaction.py +++ b/shared/selene/util/db/transaction.py @@ -30,6 +30,7 @@ def use_transaction(func): :param func: function being decorated :return: decorated function """ + @wraps(func) def execute_in_transaction(*args, **kwargs): instance = args[0] diff --git a/shared/selene/util/exceptions.py b/shared/selene/util/exceptions.py index 6a516801..22251db9 100644 --- a/shared/selene/util/exceptions.py +++ b/shared/selene/util/exceptions.py @@ -25,4 +25,5 @@ class NotModifiedException(Exception): The Flask blueprint will catch this exception and return a HTTP 304 code. 
""" + pass diff --git a/shared/selene/util/payment/__init__.py b/shared/selene/util/payment/__init__.py index b5f39300..bdf68f7b 100644 --- a/shared/selene/util/payment/__init__.py +++ b/shared/selene/util/payment/__init__.py @@ -20,5 +20,5 @@ from .stripe import ( cancel_stripe_subscription, create_stripe_account, - create_stripe_subscription + create_stripe_subscription, ) diff --git a/shared/selene/util/payment/stripe.py b/shared/selene/util/payment/stripe.py index 02e30bfe..7c8bd5ee 100644 --- a/shared/selene/util/payment/stripe.py +++ b/shared/selene/util/payment/stripe.py @@ -23,22 +23,19 @@ import stripe def create_stripe_account(token: str, email: str): - stripe.api_key = os.environ['STRIPE_PRIVATE_KEY'] + stripe.api_key = os.environ["STRIPE_PRIVATE_KEY"] customer = stripe.Customer.create(source=token, email=email) return customer.id def create_stripe_subscription(customer_id, plan): - stripe.api_key = os.environ['STRIPE_PRIVATE_KEY'] - request = stripe.Subscription.create( - customer=customer_id, - items=[{'plan': plan}] - ) + stripe.api_key = os.environ["STRIPE_PRIVATE_KEY"] + request = stripe.Subscription.create(customer=customer_id, items=[{"plan": plan}]) return request.id def cancel_stripe_subscription(subscription_id): - stripe.api_key = os.environ['STRIPE_PRIVATE_KEY'] + stripe.api_key = os.environ["STRIPE_PRIVATE_KEY"] active_stripe_subscription = stripe.Subscription.retrieve(subscription_id) active_stripe_subscription.delete()