From 2e6346ca433aa308d1fb26c81152613faef9b2aa Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 1 Oct 2018 16:09:31 +0200 Subject: [PATCH] Break up websocket 2 (#17028) * Break up websocket 2 * Lint+Test * Lintttt * Rename --- homeassistant/components/auth/__init__.py | 10 +- .../components/auth/mfa_setup_flow.py | 10 +- homeassistant/components/camera/__init__.py | 4 +- homeassistant/components/cloud/http_api.py | 12 +- homeassistant/components/config/auth.py | 12 +- .../config/auth_provider_homeassistant.py | 205 +++++------ .../components/config/device_registry.py | 2 +- .../components/config/entity_registry.py | 12 +- homeassistant/components/frontend/__init__.py | 30 +- homeassistant/components/http/auth.py | 1 + homeassistant/components/lovelace/__init__.py | 2 +- .../components/media_player/__init__.py | 6 +- .../persistent_notification/__init__.py | 2 +- .../components/websocket_api/__init__.py | 328 +----------------- .../components/websocket_api/auth.py | 99 ++++++ .../components/websocket_api/commands.py | 18 +- .../components/websocket_api/connection.py | 78 +++++ .../components/websocket_api/const.py | 12 + .../components/websocket_api/decorators.py | 8 +- .../components/websocket_api/error.py | 8 + .../components/websocket_api/http.py | 189 ++++++++++ tests/components/conftest.py | 59 ++-- tests/components/test_panel_iframe.py | 8 +- tests/components/websocket_api/conftest.py | 7 +- tests/components/websocket_api/test_auth.py | 62 ++-- .../components/websocket_api/test_commands.py | 21 +- tests/components/websocket_api/test_init.py | 4 +- 27 files changed, 641 insertions(+), 568 deletions(-) create mode 100644 homeassistant/components/websocket_api/auth.py create mode 100644 homeassistant/components/websocket_api/connection.py create mode 100644 homeassistant/components/websocket_api/error.py create mode 100644 homeassistant/components/websocket_api/http.py diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index c0027fac820..58be53d4122 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -432,7 +432,7 @@ def websocket_current_user( """Get current user.""" enabled_modules = await hass.auth.async_get_enabled_mfa(user) - connection.send_message_outside( + connection.send_message( websocket_api.result_message(msg['id'], { 'id': user.id, 'name': user.name, @@ -467,7 +467,7 @@ def websocket_create_long_lived_access_token( access_token = hass.auth.async_create_access_token( refresh_token) - connection.send_message_outside( + connection.send_message( websocket_api.result_message(msg['id'], access_token)) hass.async_create_task( @@ -479,8 +479,8 @@ def websocket_create_long_lived_access_token( def websocket_refresh_tokens( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg): """Return metadata of users refresh tokens.""" - current_id = connection.request.get('refresh_token_id') - connection.to_write.put_nowait(websocket_api.result_message(msg['id'], [{ + current_id = connection.refresh_token_id + connection.send_message(websocket_api.result_message(msg['id'], [{ 'id': refresh.id, 'client_id': refresh.client_id, 'client_name': refresh.client_name, @@ -508,7 +508,7 @@ def websocket_delete_refresh_token( await hass.auth.async_remove_refresh_token(refresh_token) - connection.send_message_outside( + connection.send_message( websocket_api.result_message(msg['id'], {})) hass.async_create_task( diff --git a/homeassistant/components/auth/mfa_setup_flow.py b/homeassistant/components/auth/mfa_setup_flow.py index 82eb913d890..121d95aede3 100644 --- a/homeassistant/components/auth/mfa_setup_flow.py +++ b/homeassistant/components/auth/mfa_setup_flow.py @@ -64,7 +64,7 @@ def websocket_setup_mfa( if flow_id is not None: result = await flow_manager.async_configure( flow_id, msg.get('user_input')) - connection.send_message_outside( + connection.send_message( websocket_api.result_message( msg['id'], _prepare_result_json(result))) return @@ -72,7 +72,7 @@ def websocket_setup_mfa( mfa_module_id = msg.get('mfa_module_id') mfa_module = hass.auth.get_auth_mfa_module(mfa_module_id) if mfa_module is None: - connection.send_message_outside(websocket_api.error_message( + connection.send_message(websocket_api.error_message( msg['id'], 'no_module', 'MFA module {} is not found'.format(mfa_module_id))) return @@ -80,7 +80,7 @@ def websocket_setup_mfa( result = await flow_manager.async_init( mfa_module_id, data={'user_id': connection.user.id}) - connection.send_message_outside( + connection.send_message( websocket_api.result_message( msg['id'], _prepare_result_json(result))) @@ -99,13 +99,13 @@ def websocket_depose_mfa( await hass.auth.async_disable_user_mfa( connection.user, msg['mfa_module_id']) except ValueError as err: - connection.send_message_outside(websocket_api.error_message( + connection.send_message(websocket_api.error_message( msg['id'], 'disable_failed', 'Cannot disable MFA Module {}: {}'.format( mfa_module_id, err))) return - connection.send_message_outside( + connection.send_message( websocket_api.result_message( msg['id'], 'done')) diff --git a/homeassistant/components/camera/__init__.py b/homeassistant/components/camera/__init__.py index 95f0cddf320..2cf23e0d60c 100644 --- a/homeassistant/components/camera/__init__.py +++ b/homeassistant/components/camera/__init__.py @@ -460,14 +460,14 @@ async def websocket_camera_thumbnail(hass, connection, msg): """ try: image = await async_get_image(hass, msg['entity_id']) - connection.send_message_outside(websocket_api.result_message( + connection.send_message(websocket_api.result_message( msg['id'], { 'content_type': image.content_type, 'content': base64.b64encode(image.content).decode('utf-8') } )) except HomeAssistantError: - connection.send_message_outside(websocket_api.error_message( + connection.send_message(websocket_api.error_message( msg['id'], 'image_fetch_failed', 'Unable to fetch image')) diff --git a/homeassistant/components/cloud/http_api.py b/homeassistant/components/cloud/http_api.py index c81ec38bace..720ca00cf52 100644 --- a/homeassistant/components/cloud/http_api.py +++ b/homeassistant/components/cloud/http_api.py @@ -231,7 +231,7 @@ def websocket_cloud_status(hass, connection, msg): Async friendly. """ cloud = hass.data[DOMAIN] - connection.to_write.put_nowait( + connection.send_message( websocket_api.result_message(msg['id'], _account_data(cloud))) @@ -241,7 +241,7 @@ async def websocket_subscription(hass, connection, msg): cloud = hass.data[DOMAIN] if not cloud.is_logged_in: - connection.to_write.put_nowait(websocket_api.error_message( + connection.send_message(websocket_api.error_message( msg['id'], 'not_logged_in', 'You need to be logged in to the cloud.')) return @@ -250,10 +250,10 @@ async def websocket_subscription(hass, connection, msg): response = await cloud.fetch_subscription_info() if response.status == 200: - connection.send_message_outside(websocket_api.result_message( + connection.send_message(websocket_api.result_message( msg['id'], await response.json())) else: - connection.send_message_outside(websocket_api.error_message( + connection.send_message(websocket_api.error_message( msg['id'], 'request_failed', 'Failed to request subscription')) @@ -263,7 +263,7 @@ async def websocket_update_prefs(hass, connection, msg): cloud = hass.data[DOMAIN] if not cloud.is_logged_in: - connection.to_write.put_nowait(websocket_api.error_message( + connection.send_message(websocket_api.error_message( msg['id'], 'not_logged_in', 'You need to be logged in to the cloud.')) return @@ -273,7 +273,7 @@ async def websocket_update_prefs(hass, connection, msg): changes.pop('type') await cloud.update_preferences(**changes) - connection.send_message_outside(websocket_api.result_message( + connection.send_message(websocket_api.result_message( msg['id'], {'success': True})) diff --git a/homeassistant/components/config/auth.py b/homeassistant/components/config/auth.py index 17dd132d4b4..f2af6589f11 100644 --- a/homeassistant/components/config/auth.py +++ b/homeassistant/components/config/auth.py @@ -49,7 +49,7 @@ def websocket_list(hass, connection, msg): """Send users.""" result = [_user_info(u) for u in await hass.auth.async_get_users()] - connection.send_message_outside( + connection.send_message( websocket_api.result_message(msg['id'], result)) hass.async_add_job(send_users()) @@ -61,8 +61,8 @@ def websocket_delete(hass, connection, msg): """Delete a user.""" async def delete_user(): """Delete user.""" - if msg['user_id'] == connection.request.get('hass_user').id: - connection.send_message_outside(websocket_api.error_message( + if msg['user_id'] == connection.user.id: + connection.send_message(websocket_api.error_message( msg['id'], 'no_delete_self', 'Unable to delete your own account')) return @@ -70,13 +70,13 @@ def websocket_delete(hass, connection, msg): user = await hass.auth.async_get_user(msg['user_id']) if not user: - connection.send_message_outside(websocket_api.error_message( + connection.send_message(websocket_api.error_message( msg['id'], 'not_found', 'User not found')) return await hass.auth.async_remove_user(user) - connection.send_message_outside( + connection.send_message( websocket_api.result_message(msg['id'])) hass.async_add_job(delete_user()) @@ -90,7 +90,7 @@ def websocket_create(hass, connection, msg): """Create a user.""" user = await hass.auth.async_create_user(msg['name']) - connection.send_message_outside( + connection.send_message( websocket_api.result_message(msg['id'], { 'user': _user_info(user) })) diff --git a/homeassistant/components/config/auth_provider_homeassistant.py b/homeassistant/components/config/auth_provider_homeassistant.py index 8f0c969a808..3495a959f49 100644 --- a/homeassistant/components/config/auth_provider_homeassistant.py +++ b/homeassistant/components/config/auth_provider_homeassistant.py @@ -2,7 +2,6 @@ import voluptuous as vol from homeassistant.auth.providers import homeassistant as auth_ha -from homeassistant.core import callback from homeassistant.components import websocket_api from homeassistant.components.websocket_api.decorators import require_owner @@ -55,121 +54,109 @@ def _get_provider(hass): raise RuntimeError('Provider not found') -@callback @require_owner -def websocket_create(hass, connection, msg): +@websocket_api.async_response +async def websocket_create(hass, connection, msg): """Create credentials and attach to a user.""" - async def create_creds(): - """Create credentials.""" - provider = _get_provider(hass) - await provider.async_initialize() + provider = _get_provider(hass) + await provider.async_initialize() - user = await hass.auth.async_get_user(msg['user_id']) + user = await hass.auth.async_get_user(msg['user_id']) - if user is None: - connection.send_message_outside(websocket_api.error_message( - msg['id'], 'not_found', 'User not found')) - return + if user is None: + connection.send_message(websocket_api.error_message( + msg['id'], 'not_found', 'User not found')) + return - if user.system_generated: - connection.send_message_outside(websocket_api.error_message( - msg['id'], 'system_generated', - 'Cannot add credentials to a system generated user.')) - return - - try: - await hass.async_add_executor_job( - provider.data.add_auth, msg['username'], msg['password']) - except auth_ha.InvalidUser: - connection.send_message_outside(websocket_api.error_message( - msg['id'], 'username_exists', 'Username already exists')) - return - - credentials = await provider.async_get_or_create_credentials({ - 'username': msg['username'] - }) - await hass.auth.async_link_user(user, credentials) - - await provider.data.async_save() - connection.to_write.put_nowait(websocket_api.result_message(msg['id'])) - - hass.async_add_job(create_creds()) - - -@callback -@require_owner -def websocket_delete(hass, connection, msg): - """Delete username and related credential.""" - async def delete_creds(): - """Delete user credentials.""" - provider = _get_provider(hass) - await provider.async_initialize() - - credentials = await provider.async_get_or_create_credentials({ - 'username': msg['username'] - }) - - # if not new, an existing credential exists. - # Removing the credential will also remove the auth. - if not credentials.is_new: - await hass.auth.async_remove_credentials(credentials) - - connection.to_write.put_nowait( - websocket_api.result_message(msg['id'])) - return - - try: - provider.data.async_remove_auth(msg['username']) - await provider.data.async_save() - except auth_ha.InvalidUser: - connection.to_write.put_nowait(websocket_api.error_message( - msg['id'], 'auth_not_found', 'Given username was not found.')) - return - - connection.to_write.put_nowait( - websocket_api.result_message(msg['id'])) - - hass.async_add_job(delete_creds()) - - -@callback -def websocket_change_password(hass, connection, msg): - """Change user password.""" - async def change_password(): - """Change user password.""" - user = connection.request.get('hass_user') - if user is None: - connection.send_message_outside(websocket_api.error_message( - msg['id'], 'user_not_found', 'User not found')) - return - - provider = _get_provider(hass) - await provider.async_initialize() - - username = None - for credential in user.credentials: - if credential.auth_provider_type == provider.type: - username = credential.data['username'] - break - - if username is None: - connection.send_message_outside(websocket_api.error_message( - msg['id'], 'credentials_not_found', 'Credentials not found')) - return - - try: - await provider.async_validate_login( - username, msg['current_password']) - except auth_ha.InvalidAuth: - connection.send_message_outside(websocket_api.error_message( - msg['id'], 'invalid_password', 'Invalid password')) - return + if user.system_generated: + connection.send_message(websocket_api.error_message( + msg['id'], 'system_generated', + 'Cannot add credentials to a system generated user.')) + return + try: await hass.async_add_executor_job( - provider.data.change_password, username, msg['new_password']) - await provider.data.async_save() + provider.data.add_auth, msg['username'], msg['password']) + except auth_ha.InvalidUser: + connection.send_message(websocket_api.error_message( + msg['id'], 'username_exists', 'Username already exists')) + return - connection.send_message_outside( + credentials = await provider.async_get_or_create_credentials({ + 'username': msg['username'] + }) + await hass.auth.async_link_user(user, credentials) + + await provider.data.async_save() + connection.send_message(websocket_api.result_message(msg['id'])) + + +@require_owner +@websocket_api.async_response +async def websocket_delete(hass, connection, msg): + """Delete username and related credential.""" + provider = _get_provider(hass) + await provider.async_initialize() + + credentials = await provider.async_get_or_create_credentials({ + 'username': msg['username'] + }) + + # if not new, an existing credential exists. + # Removing the credential will also remove the auth. + if not credentials.is_new: + await hass.auth.async_remove_credentials(credentials) + + connection.send_message( websocket_api.result_message(msg['id'])) + return - hass.async_add_job(change_password()) + try: + provider.data.async_remove_auth(msg['username']) + await provider.data.async_save() + except auth_ha.InvalidUser: + connection.send_message(websocket_api.error_message( + msg['id'], 'auth_not_found', 'Given username was not found.')) + return + + connection.send_message( + websocket_api.result_message(msg['id'])) + + +@websocket_api.async_response +async def websocket_change_password(hass, connection, msg): + """Change user password.""" + user = connection.user + if user is None: + connection.send_message(websocket_api.error_message( + msg['id'], 'user_not_found', 'User not found')) + return + + provider = _get_provider(hass) + await provider.async_initialize() + + username = None + for credential in user.credentials: + if credential.auth_provider_type == provider.type: + username = credential.data['username'] + break + + if username is None: + connection.send_message(websocket_api.error_message( + msg['id'], 'credentials_not_found', 'Credentials not found')) + return + + try: + await provider.async_validate_login( + username, msg['current_password']) + except auth_ha.InvalidAuth: + connection.send_message(websocket_api.error_message( + msg['id'], 'invalid_password', 'Invalid password')) + return + + await hass.async_add_executor_job( + provider.data.change_password, username, msg['new_password']) + await provider.data.async_save() + + connection.send_message( + websocket_api.result_message(msg['id'])) diff --git a/homeassistant/components/config/device_registry.py b/homeassistant/components/config/device_registry.py index 88aa5727a97..54396f8956c 100644 --- a/homeassistant/components/config/device_registry.py +++ b/homeassistant/components/config/device_registry.py @@ -31,7 +31,7 @@ def websocket_list_devices(hass, connection, msg): async def retrieve_entities(): """Get devices from registry.""" registry = await async_get_registry(hass) - connection.send_message_outside(websocket_api.result_message( + connection.send_message(websocket_api.result_message( msg['id'], [{ 'config_entries': list(entry.config_entries), 'connections': list(entry.connections), diff --git a/homeassistant/components/config/entity_registry.py b/homeassistant/components/config/entity_registry.py index 18d66ec623a..1ede76d0fd8 100644 --- a/homeassistant/components/config/entity_registry.py +++ b/homeassistant/components/config/entity_registry.py @@ -55,7 +55,7 @@ async def websocket_list_entities(hass, connection, msg): Async friendly. """ registry = await async_get_registry(hass) - connection.send_message_outside(websocket_api.result_message( + connection.send_message(websocket_api.result_message( msg['id'], [{ 'config_entry_id': entry.config_entry_id, 'device_id': entry.device_id, @@ -77,11 +77,11 @@ async def websocket_get_entity(hass, connection, msg): entry = registry.entities.get(msg['entity_id']) if entry is None: - connection.send_message_outside(websocket_api.error_message( + connection.send_message(websocket_api.error_message( msg['id'], ERR_NOT_FOUND, 'Entity not found')) return - connection.send_message_outside(websocket_api.result_message( + connection.send_message(websocket_api.result_message( msg['id'], _entry_dict(entry) )) @@ -95,7 +95,7 @@ async def websocket_update_entity(hass, connection, msg): registry = await async_get_registry(hass) if msg['entity_id'] not in registry.entities: - connection.send_message_outside(websocket_api.error_message( + connection.send_message(websocket_api.error_message( msg['id'], ERR_NOT_FOUND, 'Entity not found')) return @@ -112,11 +112,11 @@ async def websocket_update_entity(hass, connection, msg): entry = registry.async_update_entity( msg['entity_id'], **changes) except ValueError as err: - connection.send_message_outside(websocket_api.error_message( + connection.send_message(websocket_api.error_message( msg['id'], 'invalid_info', str(err) )) else: - connection.send_message_outside(websocket_api.result_message( + connection.send_message(websocket_api.result_message( msg['id'], _entry_dict(entry) )) diff --git a/homeassistant/components/frontend/__init__.py b/homeassistant/components/frontend/__init__.py index 023e75aac85..083f1a5f0d5 100644 --- a/homeassistant/components/frontend/__init__.py +++ b/homeassistant/components/frontend/__init__.py @@ -145,7 +145,7 @@ class Panel: index_view.get) @callback - def to_response(self, hass, request): + def to_response(self): """Panel as dictionary.""" return { 'component_name': self.component_name, @@ -485,12 +485,10 @@ def websocket_get_panels(hass, connection, msg): Async friendly. """ panels = { - panel: - connection.hass.data[DATA_PANELS][panel].to_response( - connection.hass, connection.request) + panel: connection.hass.data[DATA_PANELS][panel].to_response() for panel in connection.hass.data[DATA_PANELS]} - connection.to_write.put_nowait(websocket_api.result_message( + connection.send_message(websocket_api.result_message( msg['id'], panels)) @@ -500,25 +498,21 @@ def websocket_get_themes(hass, connection, msg): Async friendly. """ - connection.to_write.put_nowait(websocket_api.result_message(msg['id'], { + connection.send_message(websocket_api.result_message(msg['id'], { 'themes': hass.data[DATA_THEMES], 'default_theme': hass.data[DATA_DEFAULT_THEME], })) -@callback -def websocket_get_translations(hass, connection, msg): +@websocket_api.async_response +async def websocket_get_translations(hass, connection, msg): """Handle get translations command. Async friendly. """ - async def send_translations(): - """Send a translation.""" - resources = await async_get_translations(hass, msg['language']) - connection.send_message_outside(websocket_api.result_message( - msg['id'], { - 'resources': resources, - } - )) - - hass.async_add_job(send_translations()) + resources = await async_get_translations(hass, msg['language']) + connection.send_message(websocket_api.result_message( + msg['id'], { + 'resources': resources, + } + )) diff --git a/homeassistant/components/http/auth.py b/homeassistant/components/http/auth.py index a18b4de7a10..bcc86b36dbe 100644 --- a/homeassistant/components/http/auth.py +++ b/homeassistant/components/http/auth.py @@ -112,6 +112,7 @@ async def async_validate_auth_header(request, api_password=None): if refresh_token is None: return False + request['hass_refresh_token'] = refresh_token request['hass_user'] = refresh_token.user return True diff --git a/homeassistant/components/lovelace/__init__.py b/homeassistant/components/lovelace/__init__.py index eba69159048..a24c8eb9e91 100644 --- a/homeassistant/components/lovelace/__init__.py +++ b/homeassistant/components/lovelace/__init__.py @@ -48,4 +48,4 @@ async def websocket_lovelace_config(hass, connection, msg): if error is not None: message = websocket_api.error_message(msg['id'], *error) - connection.send_message_outside(message) + connection.send_message(message) diff --git a/homeassistant/components/media_player/__init__.py b/homeassistant/components/media_player/__init__.py index 85016df7262..8530a01d3e6 100644 --- a/homeassistant/components/media_player/__init__.py +++ b/homeassistant/components/media_player/__init__.py @@ -874,19 +874,19 @@ async def websocket_handle_thumbnail(hass, connection, msg): player = component.get_entity(msg['entity_id']) if player is None: - connection.send_message_outside(websocket_api.error_message( + connection.send_message(websocket_api.error_message( msg['id'], 'entity_not_found', 'Entity not found')) return data, content_type = await player.async_get_media_image() if data is None: - connection.send_message_outside(websocket_api.error_message( + connection.send_message(websocket_api.error_message( msg['id'], 'thumbnail_fetch_failed', 'Failed to fetch thumbnail')) return - connection.send_message_outside(websocket_api.result_message( + connection.send_message(websocket_api.result_message( msg['id'], { 'content_type': content_type, 'content': base64.b64encode(data).decode('utf-8') diff --git a/homeassistant/components/persistent_notification/__init__.py b/homeassistant/components/persistent_notification/__init__.py index 066afe1fe22..a0f5cdae24d 100644 --- a/homeassistant/components/persistent_notification/__init__.py +++ b/homeassistant/components/persistent_notification/__init__.py @@ -199,7 +199,7 @@ async def async_setup(hass: HomeAssistant, config: dict) -> Awaitable[bool]: def websocket_get_notifications( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg): """Return a list of persistent_notifications.""" - connection.to_write.put_nowait( + connection.send_message( websocket_api.result_message(msg['id'], [ { key: data[key] for key in (ATTR_NOTIFICATION_ID, ATTR_MESSAGE, diff --git a/homeassistant/components/websocket_api/__init__.py b/homeassistant/components/websocket_api/__init__.py index 448256e31fd..41d0efaf3aa 100644 --- a/homeassistant/components/websocket_api/__init__.py +++ b/homeassistant/components/websocket_api/__init__.py @@ -4,49 +4,18 @@ Websocket based API for Home Assistant. For more details about this component, please refer to the documentation at https://developers.home-assistant.io/docs/external_api_websocket.html """ -import asyncio -from concurrent import futures -from contextlib import suppress -from functools import partial -import json -import logging - -from aiohttp import web -import voluptuous as vol -from voluptuous.humanize import humanize_error - -from homeassistant.const import EVENT_HOMEASSISTANT_STOP, __version__ -from homeassistant.core import Context, callback +from homeassistant.core import callback from homeassistant.loader import bind_hass -from homeassistant.helpers.json import JSONEncoder -from homeassistant.components.http import HomeAssistantView -from homeassistant.components.http.auth import validate_password -from homeassistant.components.http.const import KEY_AUTHENTICATED -from homeassistant.components.http.ban import process_wrong_login, \ - process_success_login -from . import commands, const, decorators, messages +from . import commands, connection, const, decorators, http, messages -DOMAIN = 'websocket_api' +DOMAIN = const.DOMAIN -URL = '/api/websocket' DEPENDENCIES = ('http',) -MAX_PENDING_MSG = 512 - - -_LOGGER = logging.getLogger(__name__) - -JSON_DUMP = partial(json.dumps, cls=JSONEncoder) - -TYPE_AUTH = 'auth' -TYPE_AUTH_INVALID = 'auth_invalid' -TYPE_AUTH_OK = 'auth_ok' -TYPE_AUTH_REQUIRED = 'auth_required' - - -# Backwards compat +# Backwards compat / Make it easier to integrate # pylint: disable=invalid-name +ActiveConnection = connection.ActiveConnection BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA error_message = messages.error_message result_message = messages.result_message @@ -54,42 +23,6 @@ async_response = decorators.async_response ws_require_user = decorators.ws_require_user # pylint: enable=invalid-name -AUTH_MESSAGE_SCHEMA = vol.Schema({ - vol.Required('type'): TYPE_AUTH, - vol.Exclusive('api_password', 'auth'): str, - vol.Exclusive('access_token', 'auth'): str, -}) - - -# Define the possible errors that occur when connections are cancelled. -# Originally, this was just asyncio.CancelledError, but issue #9546 showed -# that futures.CancelledErrors can also occur in some situations. -CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError) - - -def auth_ok_message(): - """Return an auth_ok message.""" - return { - 'type': TYPE_AUTH_OK, - 'ha_version': __version__, - } - - -def auth_required_message(): - """Return an auth_required message.""" - return { - 'type': TYPE_AUTH_REQUIRED, - 'ha_version': __version__, - } - - -def auth_invalid_message(message): - """Return an auth_invalid message.""" - return { - 'type': TYPE_AUTH_INVALID, - 'message': message, - } - @bind_hass @callback @@ -103,255 +36,6 @@ def async_register_command(hass, command, handler, schema): async def async_setup(hass, config): """Initialize the websocket API.""" - hass.http.register_view(WebsocketAPIView) + hass.http.register_view(http.WebsocketAPIView) commands.async_register_commands(hass) return True - - -class WebsocketAPIView(HomeAssistantView): - """View to serve a websockets endpoint.""" - - name = "websocketapi" - url = URL - requires_auth = False - - async def get(self, request): - """Handle an incoming websocket connection.""" - return await ActiveConnection(request.app['hass'], request).handle() - - -class ActiveConnection: - """Handle an active websocket client connection.""" - - def __init__(self, hass, request): - """Initialize an active connection.""" - self.hass = hass - self.request = request - self.wsock = None - self.event_listeners = {} - self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop) - self._handle_task = None - self._writer_task = None - - @property - def user(self): - """Return the user associated with the connection.""" - return self.request.get('hass_user') - - def context(self, msg): - """Return a context.""" - user = self.user - if user is None: - return Context() - return Context(user_id=user.id) - - def debug(self, message1, message2=''): - """Print a debug message.""" - _LOGGER.debug("WS %s: %s %s", id(self.wsock), message1, message2) - - def log_error(self, message1, message2=''): - """Print an error message.""" - _LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2) - - async def _writer(self): - """Write outgoing messages.""" - # Exceptions if Socket disconnected or cancelled by connection handler - with suppress(RuntimeError, *CANCELLATION_ERRORS): - while not self.wsock.closed: - message = await self.to_write.get() - if message is None: - break - self.debug("Sending", message) - try: - await self.wsock.send_json(message, dumps=JSON_DUMP) - except TypeError as err: - _LOGGER.error('Unable to serialize to JSON: %s\n%s', - err, message) - - @callback - def send_message_outside(self, message): - """Send a message to the client. - - Closes connection if the client is not reading the messages. - - Async friendly. - """ - try: - self.to_write.put_nowait(message) - except asyncio.QueueFull: - self.log_error("Client exceeded max pending messages [2]:", - MAX_PENDING_MSG) - self.cancel() - - @callback - def cancel(self): - """Cancel the connection.""" - self._handle_task.cancel() - self._writer_task.cancel() - - async def handle(self): - """Handle the websocket connection.""" - request = self.request - wsock = self.wsock = web.WebSocketResponse(heartbeat=55) - await wsock.prepare(request) - self.debug("Connected") - - self._handle_task = asyncio.Task.current_task(loop=self.hass.loop) - - @callback - def handle_hass_stop(event): - """Cancel this connection.""" - self.cancel() - - unsub_stop = self.hass.bus.async_listen( - EVENT_HOMEASSISTANT_STOP, handle_hass_stop) - self._writer_task = self.hass.async_add_job(self._writer()) - final_message = None - msg = None - authenticated = False - - try: - if request[KEY_AUTHENTICATED]: - authenticated = True - - # always request auth when auth is active - # even request passed pre-authentication (trusted networks) - # or when using legacy api_password - if self.hass.auth.active or not authenticated: - self.debug("Request auth") - await self.wsock.send_json(auth_required_message()) - msg = await wsock.receive_json() - msg = AUTH_MESSAGE_SCHEMA(msg) - - if self.hass.auth.active and 'access_token' in msg: - self.debug("Received access_token") - refresh_token = \ - await self.hass.auth.async_validate_access_token( - msg['access_token']) - authenticated = refresh_token is not None - if authenticated: - request['hass_user'] = refresh_token.user - request['refresh_token_id'] = refresh_token.id - - elif ((not self.hass.auth.active or - self.hass.auth.support_legacy) and - 'api_password' in msg): - self.debug("Received api_password") - authenticated = validate_password( - request, msg['api_password']) - - if not authenticated: - self.debug("Authorization failed") - await self.wsock.send_json( - auth_invalid_message('Invalid access token or password')) - await process_wrong_login(request) - return wsock - - self.debug("Auth OK") - await process_success_login(request) - await self.wsock.send_json(auth_ok_message()) - - # ---------- AUTH PHASE OVER ---------- - - msg = await wsock.receive_json() - last_id = 0 - handlers = self.hass.data[DOMAIN] - - while msg: - self.debug("Received", msg) - msg = messages.MINIMAL_MESSAGE_SCHEMA(msg) - cur_id = msg['id'] - - if cur_id <= last_id: - self.to_write.put_nowait(messages.error_message( - cur_id, const.ERR_ID_REUSE, - 'Identifier values have to increase.')) - - elif msg['type'] not in handlers: - self.log_error( - 'Received invalid command: {}'.format(msg['type'])) - self.to_write.put_nowait(messages.error_message( - cur_id, const.ERR_UNKNOWN_COMMAND, - 'Unknown command.')) - - else: - handler, schema = handlers[msg['type']] - try: - handler(self.hass, self, schema(msg)) - except Exception: # pylint: disable=broad-except - _LOGGER.exception('Error handling message: %s', msg) - self.to_write.put_nowait(messages.error_message( - cur_id, const.ERR_UNKNOWN_ERROR, - 'Unknown error.')) - - last_id = cur_id - msg = await wsock.receive_json() - - except vol.Invalid as err: - error_msg = "Message incorrectly formatted: " - if msg: - error_msg += humanize_error(msg, err) - else: - error_msg += str(err) - - self.log_error(error_msg) - - if not authenticated: - final_message = auth_invalid_message(error_msg) - - else: - if isinstance(msg, dict): - iden = msg.get('id') - else: - iden = None - - final_message = messages.error_message( - iden, const.ERR_INVALID_FORMAT, error_msg) - - except TypeError as err: - if wsock.closed: - self.debug("Connection closed by client") - else: - _LOGGER.exception("Unexpected TypeError: %s", err) - - except ValueError as err: - msg = "Received invalid JSON" - value = getattr(err, 'doc', None) # Py3.5+ only - if value: - msg += ': {}'.format(value) - self.log_error(msg) - self._writer_task.cancel() - - except CANCELLATION_ERRORS: - self.debug("Connection cancelled") - - except asyncio.QueueFull: - self.log_error("Client exceeded max pending messages [1]:", - MAX_PENDING_MSG) - self._writer_task.cancel() - - except Exception: # pylint: disable=broad-except - error = "Unexpected error inside websocket API. " - if msg is not None: - error += str(msg) - _LOGGER.exception(error) - - finally: - unsub_stop() - - for unsub in self.event_listeners.values(): - unsub() - - try: - if final_message is not None: - self.to_write.put_nowait(final_message) - self.to_write.put_nowait(None) - # Make sure all error messages are written before closing - await self._writer_task - except asyncio.QueueFull: - self._writer_task.cancel() - - await wsock.close() - self.debug("Closed connection") - - return wsock diff --git a/homeassistant/components/websocket_api/auth.py b/homeassistant/components/websocket_api/auth.py new file mode 100644 index 00000000000..db41f3df06d --- /dev/null +++ b/homeassistant/components/websocket_api/auth.py @@ -0,0 +1,99 @@ +"""Handle the auth of a connection.""" +import voluptuous as vol +from voluptuous.humanize import humanize_error + +from homeassistant.const import __version__ +from homeassistant.components.http.auth import validate_password +from homeassistant.components.http.ban import process_wrong_login, \ + process_success_login + +from .connection import ActiveConnection +from .error import Disconnect + +TYPE_AUTH = 'auth' +TYPE_AUTH_INVALID = 'auth_invalid' +TYPE_AUTH_OK = 'auth_ok' +TYPE_AUTH_REQUIRED = 'auth_required' + +AUTH_MESSAGE_SCHEMA = vol.Schema({ + vol.Required('type'): TYPE_AUTH, + vol.Exclusive('api_password', 'auth'): str, + vol.Exclusive('access_token', 'auth'): str, +}) + + +def auth_ok_message(): + """Return an auth_ok message.""" + return { + 'type': TYPE_AUTH_OK, + 'ha_version': __version__, + } + + +def auth_required_message(): + """Return an auth_required message.""" + return { + 'type': TYPE_AUTH_REQUIRED, + 'ha_version': __version__, + } + + +def auth_invalid_message(message): + """Return an auth_invalid message.""" + return { + 'type': TYPE_AUTH_INVALID, + 'message': message, + } + + +class AuthPhase: + """Connection that requires client to authenticate first.""" + + def __init__(self, logger, hass, send_message, request): + """Initialize the authentiated connection.""" + self._hass = hass + self._send_message = send_message + self._logger = logger + self._request = request + self._authenticated = False + self._connection = None + + async def async_handle(self, msg): + """Handle authentication.""" + try: + msg = AUTH_MESSAGE_SCHEMA(msg) + except vol.Invalid as err: + error_msg = 'Auth message incorrectly formatted: {}'.format( + humanize_error(msg, err)) + self._logger.warning(error_msg) + self._send_message(auth_invalid_message(error_msg)) + raise Disconnect + + if self._hass.auth.active and 'access_token' in msg: + self._logger.debug("Received access_token") + refresh_token = \ + await self._hass.auth.async_validate_access_token( + msg['access_token']) + if refresh_token is not None: + return await self._async_finish_auth( + refresh_token.user, refresh_token) + + elif ((not self._hass.auth.active or self._hass.auth.support_legacy) + and 'api_password' in msg): + self._logger.debug("Received api_password") + if validate_password(self._request, msg['api_password']): + return await self._async_finish_auth(None, None) + + self._send_message(auth_invalid_message( + 'Invalid access token or password')) + await process_wrong_login(self._request) + raise Disconnect + + async def _async_finish_auth(self, user, refresh_token) \ + -> ActiveConnection: + """Create an active connection.""" + self._logger.debug("Auth OK") + await process_success_login(self._request) + self._send_message(auth_ok_message()) + return ActiveConnection( + self._logger, self._hass, self._send_message, user, refresh_token) diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index c9808f3a692..8e1dac4af8e 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -103,12 +103,12 @@ def handle_subscribe_events(hass, connection, msg): if event.event_type == EVENT_TIME_CHANGED: return - connection.send_message_outside(event_message(msg['id'], event)) + connection.send_message(event_message(msg['id'], event)) connection.event_listeners[msg['id']] = hass.bus.async_listen( msg['event_type'], forward_events) - connection.to_write.put_nowait(messages.result_message(msg['id'])) + connection.send_message(messages.result_message(msg['id'])) @callback @@ -121,9 +121,9 @@ def handle_unsubscribe_events(hass, connection, msg): if subscription in connection.event_listeners: connection.event_listeners.pop(subscription)() - connection.to_write.put_nowait(messages.result_message(msg['id'])) + connection.send_message(messages.result_message(msg['id'])) else: - connection.to_write.put_nowait(messages.error_message( + connection.send_message(messages.error_message( msg['id'], const.ERR_NOT_FOUND, 'Subscription not found.')) @@ -140,7 +140,7 @@ async def handle_call_service(hass, connection, msg): await hass.services.async_call( msg['domain'], msg['service'], msg.get('service_data'), blocking, connection.context(msg)) - connection.send_message_outside(messages.result_message(msg['id'])) + connection.send_message(messages.result_message(msg['id'])) @callback @@ -149,7 +149,7 @@ def handle_get_states(hass, connection, msg): Async friendly. """ - connection.to_write.put_nowait(messages.result_message( + connection.send_message(messages.result_message( msg['id'], hass.states.async_all())) @@ -160,7 +160,7 @@ async def handle_get_services(hass, connection, msg): Async friendly. """ descriptions = await async_get_all_descriptions(hass) - connection.send_message_outside( + connection.send_message( messages.result_message(msg['id'], descriptions)) @@ -170,7 +170,7 @@ def handle_get_config(hass, connection, msg): Async friendly. """ - connection.to_write.put_nowait(messages.result_message( + connection.send_message(messages.result_message( msg['id'], hass.config.as_dict())) @@ -180,4 +180,4 @@ def handle_ping(hass, connection, msg): Async friendly. """ - connection.to_write.put_nowait(pong_message(msg['id'])) + connection.send_message(pong_message(msg['id'])) diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py new file mode 100644 index 00000000000..1cb58591a0a --- /dev/null +++ b/homeassistant/components/websocket_api/connection.py @@ -0,0 +1,78 @@ +"""Connection session.""" +import voluptuous as vol + +from homeassistant.core import callback, Context + +from . import const, messages + + +class ActiveConnection: + """Handle an active websocket client connection.""" + + def __init__(self, logger, hass, send_message, user, refresh_token): + """Initialize an active connection.""" + self.logger = logger + self.hass = hass + self.send_message = send_message + self.user = user + if refresh_token: + self.refresh_token_id = refresh_token.id + else: + self.refresh_token_id = None + + self.event_listeners = {} + self.last_id = 0 + + def context(self, msg): + """Return a context.""" + user = self.user + if user is None: + return Context() + return Context(user_id=user.id) + + @callback + def async_handle(self, msg): + """Handle a single incoming message.""" + handlers = self.hass.data[const.DOMAIN] + + try: + msg = messages.MINIMAL_MESSAGE_SCHEMA(msg) + cur_id = msg['id'] + except vol.Invalid: + self.logger.error('Received invalid command', msg) + self.send_message(messages.error_message( + msg.get('id'), const.ERR_INVALID_FORMAT, + 'Message incorrectly formatted.')) + return + + if cur_id <= self.last_id: + self.send_message(messages.error_message( + cur_id, const.ERR_ID_REUSE, + 'Identifier values have to increase.')) + return + + if msg['type'] not in handlers: + self.logger.error( + 'Received invalid command: {}'.format(msg['type'])) + self.send_message(messages.error_message( + cur_id, const.ERR_UNKNOWN_COMMAND, + 'Unknown command.')) + return + + handler, schema = handlers[msg['type']] + + try: + handler(self.hass, self, schema(msg)) + except Exception: # pylint: disable=broad-except + self.logger.exception('Error handling message: %s', msg) + self.send_message(messages.error_message( + cur_id, const.ERR_UNKNOWN_ERROR, + 'Unknown error.')) + + self.last_id = cur_id + + @callback + def async_close(self): + """Close down connection.""" + for unsub in self.event_listeners.values(): + unsub() diff --git a/homeassistant/components/websocket_api/const.py b/homeassistant/components/websocket_api/const.py index cbc56b168c6..8d452959ca5 100644 --- a/homeassistant/components/websocket_api/const.py +++ b/homeassistant/components/websocket_api/const.py @@ -1,4 +1,11 @@ """Websocket constants.""" +import asyncio +from concurrent import futures + +DOMAIN = 'websocket_api' +URL = '/api/websocket' +MAX_PENDING_MSG = 512 + ERR_ID_REUSE = 1 ERR_INVALID_FORMAT = 2 ERR_NOT_FOUND = 3 @@ -6,3 +13,8 @@ ERR_UNKNOWN_COMMAND = 4 ERR_UNKNOWN_ERROR = 5 TYPE_RESULT = 'result' + +# Define the possible errors that occur when connections are cancelled. +# Originally, this was just asyncio.CancelledError, but issue #9546 showed +# that futures.CancelledErrors can also occur in some situations. +CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError) diff --git a/homeassistant/components/websocket_api/decorators.py b/homeassistant/components/websocket_api/decorators.py index df32dd06d2b..aaa054e4054 100644 --- a/homeassistant/components/websocket_api/decorators.py +++ b/homeassistant/components/websocket_api/decorators.py @@ -18,7 +18,7 @@ def async_response(func): await func(hass, connection, msg) except Exception: # pylint: disable=broad-except _LOGGER.exception("Unexpected exception") - connection.send_message_outside(messages.error_message( + connection.send_message(messages.error_message( msg['id'], 'unknown', 'Unexpected error occurred')) @callback @@ -35,10 +35,10 @@ def require_owner(func): @wraps(func) def with_owner(hass, connection, msg): """Check owner and call function.""" - user = connection.request.get('hass_user') + user = connection.user if user is None or not user.is_owner: - connection.to_write.put_nowait(messages.error_message( + connection.send_message(messages.error_message( msg['id'], 'unauthorized', 'This command is for owners only.')) return @@ -61,7 +61,7 @@ def ws_require_user( """Check current user.""" def output_error(message_id, message): """Output error message.""" - connection.send_message_outside(messages.error_message( + connection.send_message(messages.error_message( msg['id'], message_id, message)) if connection.user is None: diff --git a/homeassistant/components/websocket_api/error.py b/homeassistant/components/websocket_api/error.py new file mode 100644 index 00000000000..c0b7ea04554 --- /dev/null +++ b/homeassistant/components/websocket_api/error.py @@ -0,0 +1,8 @@ +"""WebSocket API related errors.""" +from homeassistant.exceptions import HomeAssistantError + + +class Disconnect(HomeAssistantError): + """Disconnect the current session.""" + + pass diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py new file mode 100644 index 00000000000..87f25c9b3ef --- /dev/null +++ b/homeassistant/components/websocket_api/http.py @@ -0,0 +1,189 @@ +"""View to accept incoming websocket connection.""" +import asyncio +from contextlib import suppress +from functools import partial +import json +import logging + +from aiohttp import web, WSMsgType +import async_timeout + +from homeassistant.const import EVENT_HOMEASSISTANT_STOP +from homeassistant.core import callback +from homeassistant.components.http import HomeAssistantView +from homeassistant.helpers.json import JSONEncoder + +from .const import MAX_PENDING_MSG, CANCELLATION_ERRORS, URL +from .auth import AuthPhase, auth_required_message +from .error import Disconnect + +JSON_DUMP = partial(json.dumps, cls=JSONEncoder) + + +class WebsocketAPIView(HomeAssistantView): + """View to serve a websockets endpoint.""" + + name = "websocketapi" + url = URL + requires_auth = False + + async def get(self, request): + """Handle an incoming websocket connection.""" + return await WebSocketHandler( + request.app['hass'], request).async_handle() + + +class WebSocketHandler: + """Handle an active websocket client connection.""" + + def __init__(self, hass, request): + """Initialize an active connection.""" + self.hass = hass + self.request = request + self.wsock = None + self._to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop) + self._handle_task = None + self._writer_task = None + self._logger = logging.getLogger( + "{}.connection.{}".format(__name__, id(self))) + + async def _writer(self): + """Write outgoing messages.""" + # Exceptions if Socket disconnected or cancelled by connection handler + with suppress(RuntimeError, *CANCELLATION_ERRORS): + while not self.wsock.closed: + message = await self._to_write.get() + if message is None: + break + self._logger.debug("Sending %s", message) + try: + await self.wsock.send_json(message, dumps=JSON_DUMP) + except TypeError as err: + self._logger.error('Unable to serialize to JSON: %s\n%s', + err, message) + + @callback + def _send_message(self, message): + """Send a message to the client. + + Closes connection if the client is not reading the messages. + + Async friendly. + """ + try: + self._to_write.put_nowait(message) + except asyncio.QueueFull: + self._logger.error("Client exceeded max pending messages [2]: %s", + MAX_PENDING_MSG) + self._cancel() + + @callback + def _cancel(self): + """Cancel the connection.""" + self._handle_task.cancel() + self._writer_task.cancel() + + async def async_handle(self): + """Handle a websocket response.""" + request = self.request + wsock = self.wsock = web.WebSocketResponse(heartbeat=55) + await wsock.prepare(request) + self._logger.debug("Connected") + + # Py3.7+ + if hasattr(asyncio, 'current_task'): + # pylint: disable=no-member + self._handle_task = asyncio.current_task() + else: + self._handle_task = asyncio.Task.current_task(loop=self.hass.loop) + + @callback + def handle_hass_stop(event): + """Cancel this connection.""" + self._cancel() + + unsub_stop = self.hass.bus.async_listen( + EVENT_HOMEASSISTANT_STOP, handle_hass_stop) + + self._writer_task = self.hass.async_create_task(self._writer()) + + auth = AuthPhase(self._logger, self.hass, self._send_message, request) + connection = None + disconnect_warn = None + + try: + self._send_message(auth_required_message()) + + # Auth Phase + try: + with async_timeout.timeout(10): + msg = await wsock.receive() + except asyncio.TimeoutError: + disconnect_warn = \ + 'Did not receive auth message within 10 seconds' + raise Disconnect + + if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING): + raise Disconnect + + elif msg.type != WSMsgType.TEXT: + disconnect_warn = 'Received non-Text message.' + raise Disconnect + + try: + msg = msg.json() + except ValueError: + disconnect_warn = 'Received invalid JSON.' + raise Disconnect + + self._logger.debug("Received %s", msg) + connection = await auth.async_handle(msg) + + # Command phase + while not wsock.closed: + msg = await wsock.receive() + + if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING): + break + + elif msg.type != WSMsgType.TEXT: + disconnect_warn = 'Received non-Text message.' + break + + try: + msg = msg.json() + except ValueError: + disconnect_warn = 'Received invalid JSON.' + break + + self._logger.debug("Received %s", msg) + connection.async_handle(msg) + + except asyncio.CancelledError: + self._logger.info("Connection closed by client") + + except Disconnect: + pass + + except Exception: # pylint: disable=broad-except + self._logger.exception("Unexpected error inside websocket API") + + finally: + unsub_stop() + + if connection is not None: + connection.async_close() + + try: + self._to_write.put_nowait(None) + # Make sure all error messages are written before closing + await self._writer_task + except asyncio.QueueFull: + self._writer_task.cancel() + + await wsock.close() + + if disconnect_warn is None: + self._logger.debug("Disconnected") + else: + self._logger.warning("Disconnected: %s", disconnect_warn) diff --git a/tests/components/conftest.py b/tests/components/conftest.py index 232405a632c..252d0b1d872 100644 --- a/tests/components/conftest.py +++ b/tests/components/conftest.py @@ -4,7 +4,9 @@ from unittest.mock import patch import pytest from homeassistant.setup import async_setup_component -from homeassistant.components import websocket_api +from homeassistant.components.websocket_api.http import URL +from homeassistant.components.websocket_api.auth import ( + TYPE_AUTH, TYPE_AUTH_OK, TYPE_AUTH_REQUIRED) from tests.common import MockUser, CLIENT_ID @@ -14,41 +16,52 @@ def hass_ws_client(aiohttp_client): """Websocket client fixture connected to websocket server.""" async def create_client(hass, access_token=None): """Create a websocket client.""" - wapi = hass.components.websocket_api assert await async_setup_component(hass, 'websocket_api') client = await aiohttp_client(hass.http.app) - patching = None + patches = [] - if access_token is not None: - patching = patch('homeassistant.auth.AuthManager.active', - return_value=True) - patching.start() + if access_token is None: + patches.append(patch( + 'homeassistant.auth.AuthManager.active', return_value=False)) + patches.append(patch( + 'homeassistant.auth.AuthManager.support_legacy', + return_value=True)) + patches.append(patch( + 'homeassistant.components.websocket_api.auth.' + 'validate_password', return_value=True)) + else: + patches.append(patch( + 'homeassistant.auth.AuthManager.active', return_value=True)) + patches.append(patch( + 'homeassistant.components.http.auth.setup_auth')) + + for p in patches: + p.start() try: - websocket = await client.ws_connect(wapi.URL) + websocket = await client.ws_connect(URL) auth_resp = await websocket.receive_json() + assert auth_resp['type'] == TYPE_AUTH_REQUIRED - if auth_resp['type'] == wapi.TYPE_AUTH_OK: - assert access_token is None, \ - 'Access token given but no auth required' - return websocket - - assert access_token is not None, \ - 'Access token required for fixture' - - await websocket.send_json({ - 'type': websocket_api.TYPE_AUTH, - 'access_token': access_token - }) + if access_token is None: + await websocket.send_json({ + 'type': TYPE_AUTH, + 'api_password': 'bla' + }) + else: + await websocket.send_json({ + 'type': TYPE_AUTH, + 'access_token': access_token + }) auth_ok = await websocket.receive_json() - assert auth_ok['type'] == wapi.TYPE_AUTH_OK + assert auth_ok['type'] == TYPE_AUTH_OK finally: - if patching is not None: - patching.stop() + for p in patches: + p.stop() # wrap in client websocket.client = client diff --git a/tests/components/test_panel_iframe.py b/tests/components/test_panel_iframe.py index 3ac06c09a26..cb868f64b58 100644 --- a/tests/components/test_panel_iframe.py +++ b/tests/components/test_panel_iframe.py @@ -62,7 +62,7 @@ class TestPanelIframe(unittest.TestCase): panels = self.hass.data[frontend.DATA_PANELS] - assert panels.get('router').to_response(self.hass, None) == { + assert panels.get('router').to_response() == { 'component_name': 'iframe', 'config': {'url': 'http://192.168.1.1'}, 'icon': 'mdi:network-wireless', @@ -70,7 +70,7 @@ class TestPanelIframe(unittest.TestCase): 'url_path': 'router' } - assert panels.get('weather').to_response(self.hass, None) == { + assert panels.get('weather').to_response() == { 'component_name': 'iframe', 'config': {'url': 'https://www.wunderground.com/us/ca/san-diego'}, 'icon': 'mdi:weather', @@ -78,7 +78,7 @@ class TestPanelIframe(unittest.TestCase): 'url_path': 'weather', } - assert panels.get('api').to_response(self.hass, None) == { + assert panels.get('api').to_response() == { 'component_name': 'iframe', 'config': {'url': '/api'}, 'icon': 'mdi:weather', @@ -86,7 +86,7 @@ class TestPanelIframe(unittest.TestCase): 'url_path': 'api', } - assert panels.get('ftp').to_response(self.hass, None) == { + assert panels.get('ftp').to_response() == { 'component_name': 'iframe', 'config': {'url': 'ftp://some/ftp'}, 'icon': 'mdi:weather', diff --git a/tests/components/websocket_api/conftest.py b/tests/components/websocket_api/conftest.py index 063e0b43d1b..b7825600cb1 100644 --- a/tests/components/websocket_api/conftest.py +++ b/tests/components/websocket_api/conftest.py @@ -2,7 +2,8 @@ import pytest from homeassistant.setup import async_setup_component -from homeassistant.components import websocket_api as wapi +from homeassistant.components.websocket_api.http import URL +from homeassistant.components.websocket_api.auth import TYPE_AUTH_REQUIRED from . import API_PASSWORD @@ -24,10 +25,10 @@ def no_auth_websocket_client(hass, loop, aiohttp_client): })) client = loop.run_until_complete(aiohttp_client(hass.http.app)) - ws = loop.run_until_complete(client.ws_connect(wapi.URL)) + ws = loop.run_until_complete(client.ws_connect(URL)) auth_ok = loop.run_until_complete(ws.receive_json()) - assert auth_ok['type'] == wapi.TYPE_AUTH_REQUIRED + assert auth_ok['type'] == TYPE_AUTH_REQUIRED yield ws diff --git a/tests/components/websocket_api/test_auth.py b/tests/components/websocket_api/test_auth.py index ee1de906fa1..ed54b509aaa 100644 --- a/tests/components/websocket_api/test_auth.py +++ b/tests/components/websocket_api/test_auth.py @@ -1,7 +1,10 @@ """Test auth of websocket API.""" from unittest.mock import patch -from homeassistant.components import websocket_api as wapi +from homeassistant.components.websocket_api.const import URL +from homeassistant.components.websocket_api.auth import ( + TYPE_AUTH, TYPE_AUTH_INVALID, TYPE_AUTH_OK, TYPE_AUTH_REQUIRED) + from homeassistant.components.websocket_api import commands from homeassistant.setup import async_setup_component @@ -13,28 +16,29 @@ from . import API_PASSWORD async def test_auth_via_msg(no_auth_websocket_client): """Test authenticating.""" await no_auth_websocket_client.send_json({ - 'type': wapi.TYPE_AUTH, + 'type': TYPE_AUTH, 'api_password': API_PASSWORD }) msg = await no_auth_websocket_client.receive_json() - assert msg['type'] == wapi.TYPE_AUTH_OK + assert msg['type'] == TYPE_AUTH_OK async def test_auth_via_msg_incorrect_pass(no_auth_websocket_client): """Test authenticating.""" - with patch('homeassistant.components.websocket_api.process_wrong_login', - return_value=mock_coro()) as mock_process_wrong_login: + with patch('homeassistant.components.websocket_api.auth.' + 'process_wrong_login', return_value=mock_coro()) \ + as mock_process_wrong_login: await no_auth_websocket_client.send_json({ - 'type': wapi.TYPE_AUTH, + 'type': TYPE_AUTH, 'api_password': API_PASSWORD + 'wrong' }) msg = await no_auth_websocket_client.receive_json() assert mock_process_wrong_login.called - assert msg['type'] == wapi.TYPE_AUTH_INVALID + assert msg['type'] == TYPE_AUTH_INVALID assert msg['message'] == 'Invalid access token or password' @@ -51,8 +55,8 @@ async def test_pre_auth_only_auth_allowed(no_auth_websocket_client): msg = await no_auth_websocket_client.receive_json() - assert msg['type'] == wapi.TYPE_AUTH_INVALID - assert msg['message'].startswith('Message incorrectly formatted') + assert msg['type'] == TYPE_AUTH_INVALID + assert msg['message'].startswith('Auth message incorrectly formatted') async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token): @@ -65,19 +69,19 @@ async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token): client = await aiohttp_client(hass.http.app) - async with client.ws_connect(wapi.URL) as ws: + async with client.ws_connect(URL) as ws: with patch('homeassistant.auth.AuthManager.active') as auth_active: auth_active.return_value = True auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + assert auth_msg['type'] == TYPE_AUTH_REQUIRED await ws.send_json({ - 'type': wapi.TYPE_AUTH, + 'type': TYPE_AUTH, 'access_token': hass_access_token }) auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_OK + assert auth_msg['type'] == TYPE_AUTH_OK async def test_auth_active_user_inactive(hass, aiohttp_client, @@ -94,19 +98,19 @@ async def test_auth_active_user_inactive(hass, aiohttp_client, client = await aiohttp_client(hass.http.app) - async with client.ws_connect(wapi.URL) as ws: + async with client.ws_connect(URL) as ws: with patch('homeassistant.auth.AuthManager.active') as auth_active: auth_active.return_value = True auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + assert auth_msg['type'] == TYPE_AUTH_REQUIRED await ws.send_json({ - 'type': wapi.TYPE_AUTH, + 'type': TYPE_AUTH, 'access_token': hass_access_token }) auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID + assert auth_msg['type'] == TYPE_AUTH_INVALID async def test_auth_active_with_password_not_allow(hass, aiohttp_client): @@ -119,19 +123,19 @@ async def test_auth_active_with_password_not_allow(hass, aiohttp_client): client = await aiohttp_client(hass.http.app) - async with client.ws_connect(wapi.URL) as ws: + async with client.ws_connect(URL) as ws: with patch('homeassistant.auth.AuthManager.active', return_value=True): auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + assert auth_msg['type'] == TYPE_AUTH_REQUIRED await ws.send_json({ - 'type': wapi.TYPE_AUTH, + 'type': TYPE_AUTH, 'api_password': API_PASSWORD }) auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID + assert auth_msg['type'] == TYPE_AUTH_INVALID async def test_auth_legacy_support_with_password(hass, aiohttp_client): @@ -144,21 +148,21 @@ async def test_auth_legacy_support_with_password(hass, aiohttp_client): client = await aiohttp_client(hass.http.app) - async with client.ws_connect(wapi.URL) as ws: + async with client.ws_connect(URL) as ws: with patch('homeassistant.auth.AuthManager.active', return_value=True),\ patch('homeassistant.auth.AuthManager.support_legacy', return_value=True): auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + assert auth_msg['type'] == TYPE_AUTH_REQUIRED await ws.send_json({ - 'type': wapi.TYPE_AUTH, + 'type': TYPE_AUTH, 'api_password': API_PASSWORD }) auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_OK + assert auth_msg['type'] == TYPE_AUTH_OK async def test_auth_with_invalid_token(hass, aiohttp_client): @@ -171,16 +175,16 @@ async def test_auth_with_invalid_token(hass, aiohttp_client): client = await aiohttp_client(hass.http.app) - async with client.ws_connect(wapi.URL) as ws: + async with client.ws_connect(URL) as ws: with patch('homeassistant.auth.AuthManager.active') as auth_active: auth_active.return_value = True auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + assert auth_msg['type'] == TYPE_AUTH_REQUIRED await ws.send_json({ - 'type': wapi.TYPE_AUTH, + 'type': TYPE_AUTH, 'access_token': 'incorrect' }) auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID + assert auth_msg['type'] == TYPE_AUTH_INVALID diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index 0eaf215afaa..84c29533859 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -4,7 +4,10 @@ from unittest.mock import patch from async_timeout import timeout from homeassistant.core import callback -from homeassistant.components import websocket_api as wapi +from homeassistant.components.websocket_api.const import URL +from homeassistant.components.websocket_api.auth import ( + TYPE_AUTH, TYPE_AUTH_OK, TYPE_AUTH_REQUIRED +) from homeassistant.components.websocket_api import const, commands from homeassistant.setup import async_setup_component @@ -178,19 +181,19 @@ async def test_call_service_context_with_user(hass, aiohttp_client, calls = async_mock_service(hass, 'domain_test', 'test_service') client = await aiohttp_client(hass.http.app) - async with client.ws_connect(wapi.URL) as ws: + async with client.ws_connect(URL) as ws: with patch('homeassistant.auth.AuthManager.active') as auth_active: auth_active.return_value = True auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + assert auth_msg['type'] == TYPE_AUTH_REQUIRED await ws.send_json({ - 'type': wapi.TYPE_AUTH, + 'type': TYPE_AUTH, 'access_token': hass_access_token }) auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_OK + assert auth_msg['type'] == TYPE_AUTH_OK await ws.send_json({ 'id': 5, @@ -227,17 +230,17 @@ async def test_call_service_context_no_user(hass, aiohttp_client): calls = async_mock_service(hass, 'domain_test', 'test_service') client = await aiohttp_client(hass.http.app) - async with client.ws_connect(wapi.URL) as ws: + async with client.ws_connect(URL) as ws: auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + assert auth_msg['type'] == TYPE_AUTH_REQUIRED await ws.send_json({ - 'type': wapi.TYPE_AUTH, + 'type': TYPE_AUTH, 'api_password': API_PASSWORD }) auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_OK + assert auth_msg['type'] == TYPE_AUTH_OK await ws.send_json({ 'id': 5, diff --git a/tests/components/websocket_api/test_init.py b/tests/components/websocket_api/test_init.py index 97acc1210fc..a7e54e8146a 100644 --- a/tests/components/websocket_api/test_init.py +++ b/tests/components/websocket_api/test_init.py @@ -5,14 +5,14 @@ from unittest.mock import patch, Mock from aiohttp import WSMsgType import pytest -from homeassistant.components import websocket_api as wapi from homeassistant.components.websocket_api import const, commands, messages @pytest.fixture def mock_low_queue(): """Mock a low queue.""" - with patch.object(wapi, 'MAX_PENDING_MSG', 5): + with patch('homeassistant.components.websocket_api.http.MAX_PENDING_MSG', + 5): yield