diff --git a/api/account/account_api/endpoints/account_device.py b/api/account/account_api/endpoints/account_device.py index 247f3f66..649cb37a 100644 --- a/api/account/account_api/endpoints/account_device.py +++ b/api/account/account_api/endpoints/account_device.py @@ -42,8 +42,8 @@ class AccountDeviceEndpoint(SeleneEndpoint): def _pair(self, account_id: str, name: str, pairing: dict): """Creates a device and associate it to a pairing session""" with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: - device_id = DeviceRepository(db).add_device(account_id, name) - pairing['uuid'] = device_id + result = DeviceRepository(db).add_device(account_id, name) + pairing['uuid'] = result['id'] return self.cache.set_with_expiration(self._token_key(pairing['token']), json.dumps(pairing), self.device_pairing_time) diff --git a/api/public/public_api/api.py b/api/public/public_api/api.py index e8d1ca6d..13c90ddb 100644 --- a/api/public/public_api/api.py +++ b/api/public/public_api/api.py @@ -13,6 +13,8 @@ from .endpoints.device_subscription import DeviceSubscriptionEndpoint from .endpoints.open_weather_map import OpenWeatherMapEndpoint from .endpoints.wolfram_alpha import WolframAlphaEndpoint from .endpoints.google_stt import GoogleSTTEndpoint +from .endpoints.device_code import DeviceCodeEndpoint +from .endpoints.device_activate import DeviceActivateEndpoint public = Flask(__name__) public.config.from_object(get_base_config()) @@ -65,3 +67,13 @@ public.add_url_rule( view_func=GoogleSTTEndpoint.as_view('google_stt_api'), methods=['POST'] ) # TODO: change this path in the API v2 +public.add_url_rule( + '/device/code', + view_func=DeviceCodeEndpoint.as_view('device_code_api'), + methods=['GET'] +) +public.add_url_rule( + '/device/activate', + view_func=DeviceActivateEndpoint.as_view('device_activate_api'), + methods=['POST'] +) diff --git a/api/public/public_api/endpoints/device_activate.py b/api/public/public_api/endpoints/device_activate.py new file mode 100644 index 00000000..2c2b4dab --- /dev/null +++ b/api/public/public_api/endpoints/device_activate.py @@ -0,0 +1,47 @@ +import json + +from flask_restful import http_status_message +from selene.api import SeleneEndpoint +from selene.data.device import DeviceRepository +from selene.util.cache import SeleneCache +from selene.util.db import get_db_connection + + +class DeviceActivateEndpoint(SeleneEndpoint): + """Endpoint to activate a device and finish the pairing process""" + + def __init__(self): + super(DeviceActivateEndpoint, self).__init__() + self.cache: SeleneCache = self.config.get('SELENE_CACHE') + + def post(self): + device_activate = self.request.get_json() + if device_activate: + pairing = self._get_pairing_session(device_activate) + if pairing: + device_activate['uuid'] = pairing['uuid'] + self._activate(device_activate) + return http_status_message(200) + return http_status_message(204) + return http_status_message(204) + + def _get_pairing_session(self, device_activate: dict): + """Get the pairing session from the cache if device_activate has the same state that + the state stored in the pairing session""" + assert ('token' in device_activate and 'state' in device_activate) + token = device_activate['token'] + pairing = self.cache.get(self._token_key(token)) + if pairing: + pairing = json.loads(pairing) + if device_activate['state'] == pairing['state']: + self.cache.delete(self._token_key(token)) + return pairing + + def _activate(self, device: dict): + """Updates a device in the database with the core version, platform and enclosure_version fields""" + with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: + DeviceRepository(db).update_device(device) + + @staticmethod + def _token_key(token): + return 'pairing.token:{}'.format(token) diff --git a/api/public/endpoints/device_code.py b/api/public/public_api/endpoints/device_code.py similarity index 94% rename from api/public/endpoints/device_code.py rename to api/public/public_api/endpoints/device_code.py index 1ec286ce..26a6daf6 100644 --- a/api/public/endpoints/device_code.py +++ b/api/public/public_api/endpoints/device_code.py @@ -31,19 +31,20 @@ class DeviceCodeEndpoint(SeleneEndpoint): self.sha512.update(bytes(str(uuid.uuid4()), 'utf-8')) token = self.sha512.hexdigest() code = self._pairing_code() - pairing = json.dumps({ + pairing = { 'code': code, 'state': state, 'token': token, 'expiration': self.device_pairing_time - }) + } + pairing_json = json.dumps(pairing) # This is to deal with the case where we generate a pairing code that already exists in the # cache, meaning another device is trying to pairing using the same code. In this case, we should # call the method again to get another random pairing code if self.cache.set_if_not_exists_with_expiration(self._code_key(code), - value=pairing, + value=pairing_json, expiration=self.device_pairing_time): - return code + return pairing else: return self._create(state) diff --git a/shared/selene/data/device/repository/device.py b/shared/selene/data/device/repository/device.py index b5679847..228a6257 100644 --- a/shared/selene/data/device/repository/device.py +++ b/shared/selene/data/device/repository/device.py @@ -61,3 +61,16 @@ class DeviceRepository(object): args=dict(account_id=account_id, name=name) ) return self.cursor.insert_returning(query) + + def update_device(self, device): + """Updates a device in the database""" + query = DatabaseRequest( + sql=get_sql_from_file(path.join(SQL_DIR, 'update_device.sql')), + args=dict( + device_id=device['uuid'], + platform=device.get('platform', 'unknown'), + enclosure_version=device.get('enclosure_version', 'unknown'), + core_version=device.get('core_version', 'unknown') + ) + ) + return self.cursor.insert(query) diff --git a/shared/selene/data/device/repository/sql/update_device.sql b/shared/selene/data/device/repository/sql/update_device.sql new file mode 100644 index 00000000..c470bc84 --- /dev/null +++ b/shared/selene/data/device/repository/sql/update_device.sql @@ -0,0 +1,8 @@ +UPDATE + device.device +SET + platform = %(platform)s, + enclosure_version = %(enclosure_version)s, + core_version = %(core_version)s +WHERE + id = %(device_id)s \ No newline at end of file