Break up websocket 2 (#17028)
* Break up websocket 2 * Lint+Test * Lintttt * Renamepull/17030/merge
parent
b5e3d8c337
commit
2e6346ca43
|
@ -432,7 +432,7 @@ def websocket_current_user(
|
|||
"""Get current user."""
|
||||
enabled_modules = await hass.auth.async_get_enabled_mfa(user)
|
||||
|
||||
connection.send_message_outside(
|
||||
connection.send_message(
|
||||
websocket_api.result_message(msg['id'], {
|
||||
'id': user.id,
|
||||
'name': user.name,
|
||||
|
@ -467,7 +467,7 @@ def websocket_create_long_lived_access_token(
|
|||
access_token = hass.auth.async_create_access_token(
|
||||
refresh_token)
|
||||
|
||||
connection.send_message_outside(
|
||||
connection.send_message(
|
||||
websocket_api.result_message(msg['id'], access_token))
|
||||
|
||||
hass.async_create_task(
|
||||
|
@ -479,8 +479,8 @@ def websocket_create_long_lived_access_token(
|
|||
def websocket_refresh_tokens(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg):
|
||||
"""Return metadata of users refresh tokens."""
|
||||
current_id = connection.request.get('refresh_token_id')
|
||||
connection.to_write.put_nowait(websocket_api.result_message(msg['id'], [{
|
||||
current_id = connection.refresh_token_id
|
||||
connection.send_message(websocket_api.result_message(msg['id'], [{
|
||||
'id': refresh.id,
|
||||
'client_id': refresh.client_id,
|
||||
'client_name': refresh.client_name,
|
||||
|
@ -508,7 +508,7 @@ def websocket_delete_refresh_token(
|
|||
|
||||
await hass.auth.async_remove_refresh_token(refresh_token)
|
||||
|
||||
connection.send_message_outside(
|
||||
connection.send_message(
|
||||
websocket_api.result_message(msg['id'], {}))
|
||||
|
||||
hass.async_create_task(
|
||||
|
|
|
@ -64,7 +64,7 @@ def websocket_setup_mfa(
|
|||
if flow_id is not None:
|
||||
result = await flow_manager.async_configure(
|
||||
flow_id, msg.get('user_input'))
|
||||
connection.send_message_outside(
|
||||
connection.send_message(
|
||||
websocket_api.result_message(
|
||||
msg['id'], _prepare_result_json(result)))
|
||||
return
|
||||
|
@ -72,7 +72,7 @@ def websocket_setup_mfa(
|
|||
mfa_module_id = msg.get('mfa_module_id')
|
||||
mfa_module = hass.auth.get_auth_mfa_module(mfa_module_id)
|
||||
if mfa_module is None:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'no_module',
|
||||
'MFA module {} is not found'.format(mfa_module_id)))
|
||||
return
|
||||
|
@ -80,7 +80,7 @@ def websocket_setup_mfa(
|
|||
result = await flow_manager.async_init(
|
||||
mfa_module_id, data={'user_id': connection.user.id})
|
||||
|
||||
connection.send_message_outside(
|
||||
connection.send_message(
|
||||
websocket_api.result_message(
|
||||
msg['id'], _prepare_result_json(result)))
|
||||
|
||||
|
@ -99,13 +99,13 @@ def websocket_depose_mfa(
|
|||
await hass.auth.async_disable_user_mfa(
|
||||
connection.user, msg['mfa_module_id'])
|
||||
except ValueError as err:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'disable_failed',
|
||||
'Cannot disable MFA Module {}: {}'.format(
|
||||
mfa_module_id, err)))
|
||||
return
|
||||
|
||||
connection.send_message_outside(
|
||||
connection.send_message(
|
||||
websocket_api.result_message(
|
||||
msg['id'], 'done'))
|
||||
|
||||
|
|
|
@ -460,14 +460,14 @@ async def websocket_camera_thumbnail(hass, connection, msg):
|
|||
"""
|
||||
try:
|
||||
image = await async_get_image(hass, msg['entity_id'])
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
connection.send_message(websocket_api.result_message(
|
||||
msg['id'], {
|
||||
'content_type': image.content_type,
|
||||
'content': base64.b64encode(image.content).decode('utf-8')
|
||||
}
|
||||
))
|
||||
except HomeAssistantError:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'image_fetch_failed', 'Unable to fetch image'))
|
||||
|
||||
|
||||
|
|
|
@ -231,7 +231,7 @@ def websocket_cloud_status(hass, connection, msg):
|
|||
Async friendly.
|
||||
"""
|
||||
cloud = hass.data[DOMAIN]
|
||||
connection.to_write.put_nowait(
|
||||
connection.send_message(
|
||||
websocket_api.result_message(msg['id'], _account_data(cloud)))
|
||||
|
||||
|
||||
|
@ -241,7 +241,7 @@ async def websocket_subscription(hass, connection, msg):
|
|||
cloud = hass.data[DOMAIN]
|
||||
|
||||
if not cloud.is_logged_in:
|
||||
connection.to_write.put_nowait(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'not_logged_in',
|
||||
'You need to be logged in to the cloud.'))
|
||||
return
|
||||
|
@ -250,10 +250,10 @@ async def websocket_subscription(hass, connection, msg):
|
|||
response = await cloud.fetch_subscription_info()
|
||||
|
||||
if response.status == 200:
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
connection.send_message(websocket_api.result_message(
|
||||
msg['id'], await response.json()))
|
||||
else:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'request_failed', 'Failed to request subscription'))
|
||||
|
||||
|
||||
|
@ -263,7 +263,7 @@ async def websocket_update_prefs(hass, connection, msg):
|
|||
cloud = hass.data[DOMAIN]
|
||||
|
||||
if not cloud.is_logged_in:
|
||||
connection.to_write.put_nowait(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'not_logged_in',
|
||||
'You need to be logged in to the cloud.'))
|
||||
return
|
||||
|
@ -273,7 +273,7 @@ async def websocket_update_prefs(hass, connection, msg):
|
|||
changes.pop('type')
|
||||
await cloud.update_preferences(**changes)
|
||||
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
connection.send_message(websocket_api.result_message(
|
||||
msg['id'], {'success': True}))
|
||||
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ def websocket_list(hass, connection, msg):
|
|||
"""Send users."""
|
||||
result = [_user_info(u) for u in await hass.auth.async_get_users()]
|
||||
|
||||
connection.send_message_outside(
|
||||
connection.send_message(
|
||||
websocket_api.result_message(msg['id'], result))
|
||||
|
||||
hass.async_add_job(send_users())
|
||||
|
@ -61,8 +61,8 @@ def websocket_delete(hass, connection, msg):
|
|||
"""Delete a user."""
|
||||
async def delete_user():
|
||||
"""Delete user."""
|
||||
if msg['user_id'] == connection.request.get('hass_user').id:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
if msg['user_id'] == connection.user.id:
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'no_delete_self',
|
||||
'Unable to delete your own account'))
|
||||
return
|
||||
|
@ -70,13 +70,13 @@ def websocket_delete(hass, connection, msg):
|
|||
user = await hass.auth.async_get_user(msg['user_id'])
|
||||
|
||||
if not user:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'not_found', 'User not found'))
|
||||
return
|
||||
|
||||
await hass.auth.async_remove_user(user)
|
||||
|
||||
connection.send_message_outside(
|
||||
connection.send_message(
|
||||
websocket_api.result_message(msg['id']))
|
||||
|
||||
hass.async_add_job(delete_user())
|
||||
|
@ -90,7 +90,7 @@ def websocket_create(hass, connection, msg):
|
|||
"""Create a user."""
|
||||
user = await hass.auth.async_create_user(msg['name'])
|
||||
|
||||
connection.send_message_outside(
|
||||
connection.send_message(
|
||||
websocket_api.result_message(msg['id'], {
|
||||
'user': _user_info(user)
|
||||
}))
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.auth.providers import homeassistant as auth_ha
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.websocket_api.decorators import require_owner
|
||||
|
||||
|
@ -55,24 +54,22 @@ def _get_provider(hass):
|
|||
raise RuntimeError('Provider not found')
|
||||
|
||||
|
||||
@callback
|
||||
@require_owner
|
||||
def websocket_create(hass, connection, msg):
|
||||
@websocket_api.async_response
|
||||
async def websocket_create(hass, connection, msg):
|
||||
"""Create credentials and attach to a user."""
|
||||
async def create_creds():
|
||||
"""Create credentials."""
|
||||
provider = _get_provider(hass)
|
||||
await provider.async_initialize()
|
||||
|
||||
user = await hass.auth.async_get_user(msg['user_id'])
|
||||
|
||||
if user is None:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'not_found', 'User not found'))
|
||||
return
|
||||
|
||||
if user.system_generated:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'system_generated',
|
||||
'Cannot add credentials to a system generated user.'))
|
||||
return
|
||||
|
@ -81,7 +78,7 @@ def websocket_create(hass, connection, msg):
|
|||
await hass.async_add_executor_job(
|
||||
provider.data.add_auth, msg['username'], msg['password'])
|
||||
except auth_ha.InvalidUser:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'username_exists', 'Username already exists'))
|
||||
return
|
||||
|
||||
|
@ -91,17 +88,13 @@ def websocket_create(hass, connection, msg):
|
|||
await hass.auth.async_link_user(user, credentials)
|
||||
|
||||
await provider.data.async_save()
|
||||
connection.to_write.put_nowait(websocket_api.result_message(msg['id']))
|
||||
|
||||
hass.async_add_job(create_creds())
|
||||
connection.send_message(websocket_api.result_message(msg['id']))
|
||||
|
||||
|
||||
@callback
|
||||
@require_owner
|
||||
def websocket_delete(hass, connection, msg):
|
||||
@websocket_api.async_response
|
||||
async def websocket_delete(hass, connection, msg):
|
||||
"""Delete username and related credential."""
|
||||
async def delete_creds():
|
||||
"""Delete user credentials."""
|
||||
provider = _get_provider(hass)
|
||||
await provider.async_initialize()
|
||||
|
||||
|
@ -114,7 +107,7 @@ def websocket_delete(hass, connection, msg):
|
|||
if not credentials.is_new:
|
||||
await hass.auth.async_remove_credentials(credentials)
|
||||
|
||||
connection.to_write.put_nowait(
|
||||
connection.send_message(
|
||||
websocket_api.result_message(msg['id']))
|
||||
return
|
||||
|
||||
|
@ -122,24 +115,20 @@ def websocket_delete(hass, connection, msg):
|
|||
provider.data.async_remove_auth(msg['username'])
|
||||
await provider.data.async_save()
|
||||
except auth_ha.InvalidUser:
|
||||
connection.to_write.put_nowait(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'auth_not_found', 'Given username was not found.'))
|
||||
return
|
||||
|
||||
connection.to_write.put_nowait(
|
||||
connection.send_message(
|
||||
websocket_api.result_message(msg['id']))
|
||||
|
||||
hass.async_add_job(delete_creds())
|
||||
|
||||
|
||||
@callback
|
||||
def websocket_change_password(hass, connection, msg):
|
||||
@websocket_api.async_response
|
||||
async def websocket_change_password(hass, connection, msg):
|
||||
"""Change user password."""
|
||||
async def change_password():
|
||||
"""Change user password."""
|
||||
user = connection.request.get('hass_user')
|
||||
user = connection.user
|
||||
if user is None:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'user_not_found', 'User not found'))
|
||||
return
|
||||
|
||||
|
@ -153,7 +142,7 @@ def websocket_change_password(hass, connection, msg):
|
|||
break
|
||||
|
||||
if username is None:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'credentials_not_found', 'Credentials not found'))
|
||||
return
|
||||
|
||||
|
@ -161,7 +150,7 @@ def websocket_change_password(hass, connection, msg):
|
|||
await provider.async_validate_login(
|
||||
username, msg['current_password'])
|
||||
except auth_ha.InvalidAuth:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'invalid_password', 'Invalid password'))
|
||||
return
|
||||
|
||||
|
@ -169,7 +158,5 @@ def websocket_change_password(hass, connection, msg):
|
|||
provider.data.change_password, username, msg['new_password'])
|
||||
await provider.data.async_save()
|
||||
|
||||
connection.send_message_outside(
|
||||
connection.send_message(
|
||||
websocket_api.result_message(msg['id']))
|
||||
|
||||
hass.async_add_job(change_password())
|
||||
|
|
|
@ -31,7 +31,7 @@ def websocket_list_devices(hass, connection, msg):
|
|||
async def retrieve_entities():
|
||||
"""Get devices from registry."""
|
||||
registry = await async_get_registry(hass)
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
connection.send_message(websocket_api.result_message(
|
||||
msg['id'], [{
|
||||
'config_entries': list(entry.config_entries),
|
||||
'connections': list(entry.connections),
|
||||
|
|
|
@ -55,7 +55,7 @@ async def websocket_list_entities(hass, connection, msg):
|
|||
Async friendly.
|
||||
"""
|
||||
registry = await async_get_registry(hass)
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
connection.send_message(websocket_api.result_message(
|
||||
msg['id'], [{
|
||||
'config_entry_id': entry.config_entry_id,
|
||||
'device_id': entry.device_id,
|
||||
|
@ -77,11 +77,11 @@ async def websocket_get_entity(hass, connection, msg):
|
|||
entry = registry.entities.get(msg['entity_id'])
|
||||
|
||||
if entry is None:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], ERR_NOT_FOUND, 'Entity not found'))
|
||||
return
|
||||
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
connection.send_message(websocket_api.result_message(
|
||||
msg['id'], _entry_dict(entry)
|
||||
))
|
||||
|
||||
|
@ -95,7 +95,7 @@ async def websocket_update_entity(hass, connection, msg):
|
|||
registry = await async_get_registry(hass)
|
||||
|
||||
if msg['entity_id'] not in registry.entities:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], ERR_NOT_FOUND, 'Entity not found'))
|
||||
return
|
||||
|
||||
|
@ -112,11 +112,11 @@ async def websocket_update_entity(hass, connection, msg):
|
|||
entry = registry.async_update_entity(
|
||||
msg['entity_id'], **changes)
|
||||
except ValueError as err:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'invalid_info', str(err)
|
||||
))
|
||||
else:
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
connection.send_message(websocket_api.result_message(
|
||||
msg['id'], _entry_dict(entry)
|
||||
))
|
||||
|
||||
|
|
|
@ -145,7 +145,7 @@ class Panel:
|
|||
index_view.get)
|
||||
|
||||
@callback
|
||||
def to_response(self, hass, request):
|
||||
def to_response(self):
|
||||
"""Panel as dictionary."""
|
||||
return {
|
||||
'component_name': self.component_name,
|
||||
|
@ -485,12 +485,10 @@ def websocket_get_panels(hass, connection, msg):
|
|||
Async friendly.
|
||||
"""
|
||||
panels = {
|
||||
panel:
|
||||
connection.hass.data[DATA_PANELS][panel].to_response(
|
||||
connection.hass, connection.request)
|
||||
panel: connection.hass.data[DATA_PANELS][panel].to_response()
|
||||
for panel in connection.hass.data[DATA_PANELS]}
|
||||
|
||||
connection.to_write.put_nowait(websocket_api.result_message(
|
||||
connection.send_message(websocket_api.result_message(
|
||||
msg['id'], panels))
|
||||
|
||||
|
||||
|
@ -500,25 +498,21 @@ def websocket_get_themes(hass, connection, msg):
|
|||
|
||||
Async friendly.
|
||||
"""
|
||||
connection.to_write.put_nowait(websocket_api.result_message(msg['id'], {
|
||||
connection.send_message(websocket_api.result_message(msg['id'], {
|
||||
'themes': hass.data[DATA_THEMES],
|
||||
'default_theme': hass.data[DATA_DEFAULT_THEME],
|
||||
}))
|
||||
|
||||
|
||||
@callback
|
||||
def websocket_get_translations(hass, connection, msg):
|
||||
@websocket_api.async_response
|
||||
async def websocket_get_translations(hass, connection, msg):
|
||||
"""Handle get translations command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
async def send_translations():
|
||||
"""Send a translation."""
|
||||
resources = await async_get_translations(hass, msg['language'])
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
connection.send_message(websocket_api.result_message(
|
||||
msg['id'], {
|
||||
'resources': resources,
|
||||
}
|
||||
))
|
||||
|
||||
hass.async_add_job(send_translations())
|
||||
|
|
|
@ -112,6 +112,7 @@ async def async_validate_auth_header(request, api_password=None):
|
|||
if refresh_token is None:
|
||||
return False
|
||||
|
||||
request['hass_refresh_token'] = refresh_token
|
||||
request['hass_user'] = refresh_token.user
|
||||
return True
|
||||
|
||||
|
|
|
@ -48,4 +48,4 @@ async def websocket_lovelace_config(hass, connection, msg):
|
|||
if error is not None:
|
||||
message = websocket_api.error_message(msg['id'], *error)
|
||||
|
||||
connection.send_message_outside(message)
|
||||
connection.send_message(message)
|
||||
|
|
|
@ -874,19 +874,19 @@ async def websocket_handle_thumbnail(hass, connection, msg):
|
|||
player = component.get_entity(msg['entity_id'])
|
||||
|
||||
if player is None:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'entity_not_found', 'Entity not found'))
|
||||
return
|
||||
|
||||
data, content_type = await player.async_get_media_image()
|
||||
|
||||
if data is None:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
connection.send_message(websocket_api.error_message(
|
||||
msg['id'], 'thumbnail_fetch_failed',
|
||||
'Failed to fetch thumbnail'))
|
||||
return
|
||||
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
connection.send_message(websocket_api.result_message(
|
||||
msg['id'], {
|
||||
'content_type': content_type,
|
||||
'content': base64.b64encode(data).decode('utf-8')
|
||||
|
|
|
@ -199,7 +199,7 @@ async def async_setup(hass: HomeAssistant, config: dict) -> Awaitable[bool]:
|
|||
def websocket_get_notifications(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg):
|
||||
"""Return a list of persistent_notifications."""
|
||||
connection.to_write.put_nowait(
|
||||
connection.send_message(
|
||||
websocket_api.result_message(msg['id'], [
|
||||
{
|
||||
key: data[key] for key in (ATTR_NOTIFICATION_ID, ATTR_MESSAGE,
|
||||
|
|
|
@ -4,49 +4,18 @@ Websocket based API for Home Assistant.
|
|||
For more details about this component, please refer to the documentation at
|
||||
https://developers.home-assistant.io/docs/external_api_websocket.html
|
||||
"""
|
||||
import asyncio
|
||||
from concurrent import futures
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
import json
|
||||
import logging
|
||||
|
||||
from aiohttp import web
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, __version__
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.helpers.json import JSONEncoder
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.http.auth import validate_password
|
||||
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
||||
from homeassistant.components.http.ban import process_wrong_login, \
|
||||
process_success_login
|
||||
|
||||
from . import commands, const, decorators, messages
|
||||
from . import commands, connection, const, decorators, http, messages
|
||||
|
||||
DOMAIN = 'websocket_api'
|
||||
DOMAIN = const.DOMAIN
|
||||
|
||||
URL = '/api/websocket'
|
||||
DEPENDENCIES = ('http',)
|
||||
|
||||
MAX_PENDING_MSG = 512
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
JSON_DUMP = partial(json.dumps, cls=JSONEncoder)
|
||||
|
||||
TYPE_AUTH = 'auth'
|
||||
TYPE_AUTH_INVALID = 'auth_invalid'
|
||||
TYPE_AUTH_OK = 'auth_ok'
|
||||
TYPE_AUTH_REQUIRED = 'auth_required'
|
||||
|
||||
|
||||
# Backwards compat
|
||||
# Backwards compat / Make it easier to integrate
|
||||
# pylint: disable=invalid-name
|
||||
ActiveConnection = connection.ActiveConnection
|
||||
BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA
|
||||
error_message = messages.error_message
|
||||
result_message = messages.result_message
|
||||
|
@ -54,42 +23,6 @@ async_response = decorators.async_response
|
|||
ws_require_user = decorators.ws_require_user
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
AUTH_MESSAGE_SCHEMA = vol.Schema({
|
||||
vol.Required('type'): TYPE_AUTH,
|
||||
vol.Exclusive('api_password', 'auth'): str,
|
||||
vol.Exclusive('access_token', 'auth'): str,
|
||||
})
|
||||
|
||||
|
||||
# Define the possible errors that occur when connections are cancelled.
|
||||
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
|
||||
# that futures.CancelledErrors can also occur in some situations.
|
||||
CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError)
|
||||
|
||||
|
||||
def auth_ok_message():
|
||||
"""Return an auth_ok message."""
|
||||
return {
|
||||
'type': TYPE_AUTH_OK,
|
||||
'ha_version': __version__,
|
||||
}
|
||||
|
||||
|
||||
def auth_required_message():
|
||||
"""Return an auth_required message."""
|
||||
return {
|
||||
'type': TYPE_AUTH_REQUIRED,
|
||||
'ha_version': __version__,
|
||||
}
|
||||
|
||||
|
||||
def auth_invalid_message(message):
|
||||
"""Return an auth_invalid message."""
|
||||
return {
|
||||
'type': TYPE_AUTH_INVALID,
|
||||
'message': message,
|
||||
}
|
||||
|
||||
|
||||
@bind_hass
|
||||
@callback
|
||||
|
@ -103,255 +36,6 @@ def async_register_command(hass, command, handler, schema):
|
|||
|
||||
async def async_setup(hass, config):
|
||||
"""Initialize the websocket API."""
|
||||
hass.http.register_view(WebsocketAPIView)
|
||||
hass.http.register_view(http.WebsocketAPIView)
|
||||
commands.async_register_commands(hass)
|
||||
return True
|
||||
|
||||
|
||||
class WebsocketAPIView(HomeAssistantView):
|
||||
"""View to serve a websockets endpoint."""
|
||||
|
||||
name = "websocketapi"
|
||||
url = URL
|
||||
requires_auth = False
|
||||
|
||||
async def get(self, request):
|
||||
"""Handle an incoming websocket connection."""
|
||||
return await ActiveConnection(request.app['hass'], request).handle()
|
||||
|
||||
|
||||
class ActiveConnection:
|
||||
"""Handle an active websocket client connection."""
|
||||
|
||||
def __init__(self, hass, request):
|
||||
"""Initialize an active connection."""
|
||||
self.hass = hass
|
||||
self.request = request
|
||||
self.wsock = None
|
||||
self.event_listeners = {}
|
||||
self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop)
|
||||
self._handle_task = None
|
||||
self._writer_task = None
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
"""Return the user associated with the connection."""
|
||||
return self.request.get('hass_user')
|
||||
|
||||
def context(self, msg):
|
||||
"""Return a context."""
|
||||
user = self.user
|
||||
if user is None:
|
||||
return Context()
|
||||
return Context(user_id=user.id)
|
||||
|
||||
def debug(self, message1, message2=''):
|
||||
"""Print a debug message."""
|
||||
_LOGGER.debug("WS %s: %s %s", id(self.wsock), message1, message2)
|
||||
|
||||
def log_error(self, message1, message2=''):
|
||||
"""Print an error message."""
|
||||
_LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2)
|
||||
|
||||
async def _writer(self):
|
||||
"""Write outgoing messages."""
|
||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||
with suppress(RuntimeError, *CANCELLATION_ERRORS):
|
||||
while not self.wsock.closed:
|
||||
message = await self.to_write.get()
|
||||
if message is None:
|
||||
break
|
||||
self.debug("Sending", message)
|
||||
try:
|
||||
await self.wsock.send_json(message, dumps=JSON_DUMP)
|
||||
except TypeError as err:
|
||||
_LOGGER.error('Unable to serialize to JSON: %s\n%s',
|
||||
err, message)
|
||||
|
||||
@callback
|
||||
def send_message_outside(self, message):
|
||||
"""Send a message to the client.
|
||||
|
||||
Closes connection if the client is not reading the messages.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
try:
|
||||
self.to_write.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
self.log_error("Client exceeded max pending messages [2]:",
|
||||
MAX_PENDING_MSG)
|
||||
self.cancel()
|
||||
|
||||
@callback
|
||||
def cancel(self):
|
||||
"""Cancel the connection."""
|
||||
self._handle_task.cancel()
|
||||
self._writer_task.cancel()
|
||||
|
||||
async def handle(self):
|
||||
"""Handle the websocket connection."""
|
||||
request = self.request
|
||||
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
|
||||
await wsock.prepare(request)
|
||||
self.debug("Connected")
|
||||
|
||||
self._handle_task = asyncio.Task.current_task(loop=self.hass.loop)
|
||||
|
||||
@callback
|
||||
def handle_hass_stop(event):
|
||||
"""Cancel this connection."""
|
||||
self.cancel()
|
||||
|
||||
unsub_stop = self.hass.bus.async_listen(
|
||||
EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
|
||||
self._writer_task = self.hass.async_add_job(self._writer())
|
||||
final_message = None
|
||||
msg = None
|
||||
authenticated = False
|
||||
|
||||
try:
|
||||
if request[KEY_AUTHENTICATED]:
|
||||
authenticated = True
|
||||
|
||||
# always request auth when auth is active
|
||||
# even request passed pre-authentication (trusted networks)
|
||||
# or when using legacy api_password
|
||||
if self.hass.auth.active or not authenticated:
|
||||
self.debug("Request auth")
|
||||
await self.wsock.send_json(auth_required_message())
|
||||
msg = await wsock.receive_json()
|
||||
msg = AUTH_MESSAGE_SCHEMA(msg)
|
||||
|
||||
if self.hass.auth.active and 'access_token' in msg:
|
||||
self.debug("Received access_token")
|
||||
refresh_token = \
|
||||
await self.hass.auth.async_validate_access_token(
|
||||
msg['access_token'])
|
||||
authenticated = refresh_token is not None
|
||||
if authenticated:
|
||||
request['hass_user'] = refresh_token.user
|
||||
request['refresh_token_id'] = refresh_token.id
|
||||
|
||||
elif ((not self.hass.auth.active or
|
||||
self.hass.auth.support_legacy) and
|
||||
'api_password' in msg):
|
||||
self.debug("Received api_password")
|
||||
authenticated = validate_password(
|
||||
request, msg['api_password'])
|
||||
|
||||
if not authenticated:
|
||||
self.debug("Authorization failed")
|
||||
await self.wsock.send_json(
|
||||
auth_invalid_message('Invalid access token or password'))
|
||||
await process_wrong_login(request)
|
||||
return wsock
|
||||
|
||||
self.debug("Auth OK")
|
||||
await process_success_login(request)
|
||||
await self.wsock.send_json(auth_ok_message())
|
||||
|
||||
# ---------- AUTH PHASE OVER ----------
|
||||
|
||||
msg = await wsock.receive_json()
|
||||
last_id = 0
|
||||
handlers = self.hass.data[DOMAIN]
|
||||
|
||||
while msg:
|
||||
self.debug("Received", msg)
|
||||
msg = messages.MINIMAL_MESSAGE_SCHEMA(msg)
|
||||
cur_id = msg['id']
|
||||
|
||||
if cur_id <= last_id:
|
||||
self.to_write.put_nowait(messages.error_message(
|
||||
cur_id, const.ERR_ID_REUSE,
|
||||
'Identifier values have to increase.'))
|
||||
|
||||
elif msg['type'] not in handlers:
|
||||
self.log_error(
|
||||
'Received invalid command: {}'.format(msg['type']))
|
||||
self.to_write.put_nowait(messages.error_message(
|
||||
cur_id, const.ERR_UNKNOWN_COMMAND,
|
||||
'Unknown command.'))
|
||||
|
||||
else:
|
||||
handler, schema = handlers[msg['type']]
|
||||
try:
|
||||
handler(self.hass, self, schema(msg))
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception('Error handling message: %s', msg)
|
||||
self.to_write.put_nowait(messages.error_message(
|
||||
cur_id, const.ERR_UNKNOWN_ERROR,
|
||||
'Unknown error.'))
|
||||
|
||||
last_id = cur_id
|
||||
msg = await wsock.receive_json()
|
||||
|
||||
except vol.Invalid as err:
|
||||
error_msg = "Message incorrectly formatted: "
|
||||
if msg:
|
||||
error_msg += humanize_error(msg, err)
|
||||
else:
|
||||
error_msg += str(err)
|
||||
|
||||
self.log_error(error_msg)
|
||||
|
||||
if not authenticated:
|
||||
final_message = auth_invalid_message(error_msg)
|
||||
|
||||
else:
|
||||
if isinstance(msg, dict):
|
||||
iden = msg.get('id')
|
||||
else:
|
||||
iden = None
|
||||
|
||||
final_message = messages.error_message(
|
||||
iden, const.ERR_INVALID_FORMAT, error_msg)
|
||||
|
||||
except TypeError as err:
|
||||
if wsock.closed:
|
||||
self.debug("Connection closed by client")
|
||||
else:
|
||||
_LOGGER.exception("Unexpected TypeError: %s", err)
|
||||
|
||||
except ValueError as err:
|
||||
msg = "Received invalid JSON"
|
||||
value = getattr(err, 'doc', None) # Py3.5+ only
|
||||
if value:
|
||||
msg += ': {}'.format(value)
|
||||
self.log_error(msg)
|
||||
self._writer_task.cancel()
|
||||
|
||||
except CANCELLATION_ERRORS:
|
||||
self.debug("Connection cancelled")
|
||||
|
||||
except asyncio.QueueFull:
|
||||
self.log_error("Client exceeded max pending messages [1]:",
|
||||
MAX_PENDING_MSG)
|
||||
self._writer_task.cancel()
|
||||
|
||||
except Exception: # pylint: disable=broad-except
|
||||
error = "Unexpected error inside websocket API. "
|
||||
if msg is not None:
|
||||
error += str(msg)
|
||||
_LOGGER.exception(error)
|
||||
|
||||
finally:
|
||||
unsub_stop()
|
||||
|
||||
for unsub in self.event_listeners.values():
|
||||
unsub()
|
||||
|
||||
try:
|
||||
if final_message is not None:
|
||||
self.to_write.put_nowait(final_message)
|
||||
self.to_write.put_nowait(None)
|
||||
# Make sure all error messages are written before closing
|
||||
await self._writer_task
|
||||
except asyncio.QueueFull:
|
||||
self._writer_task.cancel()
|
||||
|
||||
await wsock.close()
|
||||
self.debug("Closed connection")
|
||||
|
||||
return wsock
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
"""Handle the auth of a connection."""
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.components.http.auth import validate_password
|
||||
from homeassistant.components.http.ban import process_wrong_login, \
|
||||
process_success_login
|
||||
|
||||
from .connection import ActiveConnection
|
||||
from .error import Disconnect
|
||||
|
||||
TYPE_AUTH = 'auth'
|
||||
TYPE_AUTH_INVALID = 'auth_invalid'
|
||||
TYPE_AUTH_OK = 'auth_ok'
|
||||
TYPE_AUTH_REQUIRED = 'auth_required'
|
||||
|
||||
AUTH_MESSAGE_SCHEMA = vol.Schema({
|
||||
vol.Required('type'): TYPE_AUTH,
|
||||
vol.Exclusive('api_password', 'auth'): str,
|
||||
vol.Exclusive('access_token', 'auth'): str,
|
||||
})
|
||||
|
||||
|
||||
def auth_ok_message():
|
||||
"""Return an auth_ok message."""
|
||||
return {
|
||||
'type': TYPE_AUTH_OK,
|
||||
'ha_version': __version__,
|
||||
}
|
||||
|
||||
|
||||
def auth_required_message():
|
||||
"""Return an auth_required message."""
|
||||
return {
|
||||
'type': TYPE_AUTH_REQUIRED,
|
||||
'ha_version': __version__,
|
||||
}
|
||||
|
||||
|
||||
def auth_invalid_message(message):
|
||||
"""Return an auth_invalid message."""
|
||||
return {
|
||||
'type': TYPE_AUTH_INVALID,
|
||||
'message': message,
|
||||
}
|
||||
|
||||
|
||||
class AuthPhase:
|
||||
"""Connection that requires client to authenticate first."""
|
||||
|
||||
def __init__(self, logger, hass, send_message, request):
|
||||
"""Initialize the authentiated connection."""
|
||||
self._hass = hass
|
||||
self._send_message = send_message
|
||||
self._logger = logger
|
||||
self._request = request
|
||||
self._authenticated = False
|
||||
self._connection = None
|
||||
|
||||
async def async_handle(self, msg):
|
||||
"""Handle authentication."""
|
||||
try:
|
||||
msg = AUTH_MESSAGE_SCHEMA(msg)
|
||||
except vol.Invalid as err:
|
||||
error_msg = 'Auth message incorrectly formatted: {}'.format(
|
||||
humanize_error(msg, err))
|
||||
self._logger.warning(error_msg)
|
||||
self._send_message(auth_invalid_message(error_msg))
|
||||
raise Disconnect
|
||||
|
||||
if self._hass.auth.active and 'access_token' in msg:
|
||||
self._logger.debug("Received access_token")
|
||||
refresh_token = \
|
||||
await self._hass.auth.async_validate_access_token(
|
||||
msg['access_token'])
|
||||
if refresh_token is not None:
|
||||
return await self._async_finish_auth(
|
||||
refresh_token.user, refresh_token)
|
||||
|
||||
elif ((not self._hass.auth.active or self._hass.auth.support_legacy)
|
||||
and 'api_password' in msg):
|
||||
self._logger.debug("Received api_password")
|
||||
if validate_password(self._request, msg['api_password']):
|
||||
return await self._async_finish_auth(None, None)
|
||||
|
||||
self._send_message(auth_invalid_message(
|
||||
'Invalid access token or password'))
|
||||
await process_wrong_login(self._request)
|
||||
raise Disconnect
|
||||
|
||||
async def _async_finish_auth(self, user, refresh_token) \
|
||||
-> ActiveConnection:
|
||||
"""Create an active connection."""
|
||||
self._logger.debug("Auth OK")
|
||||
await process_success_login(self._request)
|
||||
self._send_message(auth_ok_message())
|
||||
return ActiveConnection(
|
||||
self._logger, self._hass, self._send_message, user, refresh_token)
|
|
@ -103,12 +103,12 @@ def handle_subscribe_events(hass, connection, msg):
|
|||
if event.event_type == EVENT_TIME_CHANGED:
|
||||
return
|
||||
|
||||
connection.send_message_outside(event_message(msg['id'], event))
|
||||
connection.send_message(event_message(msg['id'], event))
|
||||
|
||||
connection.event_listeners[msg['id']] = hass.bus.async_listen(
|
||||
msg['event_type'], forward_events)
|
||||
|
||||
connection.to_write.put_nowait(messages.result_message(msg['id']))
|
||||
connection.send_message(messages.result_message(msg['id']))
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -121,9 +121,9 @@ def handle_unsubscribe_events(hass, connection, msg):
|
|||
|
||||
if subscription in connection.event_listeners:
|
||||
connection.event_listeners.pop(subscription)()
|
||||
connection.to_write.put_nowait(messages.result_message(msg['id']))
|
||||
connection.send_message(messages.result_message(msg['id']))
|
||||
else:
|
||||
connection.to_write.put_nowait(messages.error_message(
|
||||
connection.send_message(messages.error_message(
|
||||
msg['id'], const.ERR_NOT_FOUND, 'Subscription not found.'))
|
||||
|
||||
|
||||
|
@ -140,7 +140,7 @@ async def handle_call_service(hass, connection, msg):
|
|||
await hass.services.async_call(
|
||||
msg['domain'], msg['service'], msg.get('service_data'), blocking,
|
||||
connection.context(msg))
|
||||
connection.send_message_outside(messages.result_message(msg['id']))
|
||||
connection.send_message(messages.result_message(msg['id']))
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -149,7 +149,7 @@ def handle_get_states(hass, connection, msg):
|
|||
|
||||
Async friendly.
|
||||
"""
|
||||
connection.to_write.put_nowait(messages.result_message(
|
||||
connection.send_message(messages.result_message(
|
||||
msg['id'], hass.states.async_all()))
|
||||
|
||||
|
||||
|
@ -160,7 +160,7 @@ async def handle_get_services(hass, connection, msg):
|
|||
Async friendly.
|
||||
"""
|
||||
descriptions = await async_get_all_descriptions(hass)
|
||||
connection.send_message_outside(
|
||||
connection.send_message(
|
||||
messages.result_message(msg['id'], descriptions))
|
||||
|
||||
|
||||
|
@ -170,7 +170,7 @@ def handle_get_config(hass, connection, msg):
|
|||
|
||||
Async friendly.
|
||||
"""
|
||||
connection.to_write.put_nowait(messages.result_message(
|
||||
connection.send_message(messages.result_message(
|
||||
msg['id'], hass.config.as_dict()))
|
||||
|
||||
|
||||
|
@ -180,4 +180,4 @@ def handle_ping(hass, connection, msg):
|
|||
|
||||
Async friendly.
|
||||
"""
|
||||
connection.to_write.put_nowait(pong_message(msg['id']))
|
||||
connection.send_message(pong_message(msg['id']))
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
"""Connection session."""
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback, Context
|
||||
|
||||
from . import const, messages
|
||||
|
||||
|
||||
class ActiveConnection:
|
||||
"""Handle an active websocket client connection."""
|
||||
|
||||
def __init__(self, logger, hass, send_message, user, refresh_token):
|
||||
"""Initialize an active connection."""
|
||||
self.logger = logger
|
||||
self.hass = hass
|
||||
self.send_message = send_message
|
||||
self.user = user
|
||||
if refresh_token:
|
||||
self.refresh_token_id = refresh_token.id
|
||||
else:
|
||||
self.refresh_token_id = None
|
||||
|
||||
self.event_listeners = {}
|
||||
self.last_id = 0
|
||||
|
||||
def context(self, msg):
|
||||
"""Return a context."""
|
||||
user = self.user
|
||||
if user is None:
|
||||
return Context()
|
||||
return Context(user_id=user.id)
|
||||
|
||||
@callback
|
||||
def async_handle(self, msg):
|
||||
"""Handle a single incoming message."""
|
||||
handlers = self.hass.data[const.DOMAIN]
|
||||
|
||||
try:
|
||||
msg = messages.MINIMAL_MESSAGE_SCHEMA(msg)
|
||||
cur_id = msg['id']
|
||||
except vol.Invalid:
|
||||
self.logger.error('Received invalid command', msg)
|
||||
self.send_message(messages.error_message(
|
||||
msg.get('id'), const.ERR_INVALID_FORMAT,
|
||||
'Message incorrectly formatted.'))
|
||||
return
|
||||
|
||||
if cur_id <= self.last_id:
|
||||
self.send_message(messages.error_message(
|
||||
cur_id, const.ERR_ID_REUSE,
|
||||
'Identifier values have to increase.'))
|
||||
return
|
||||
|
||||
if msg['type'] not in handlers:
|
||||
self.logger.error(
|
||||
'Received invalid command: {}'.format(msg['type']))
|
||||
self.send_message(messages.error_message(
|
||||
cur_id, const.ERR_UNKNOWN_COMMAND,
|
||||
'Unknown command.'))
|
||||
return
|
||||
|
||||
handler, schema = handlers[msg['type']]
|
||||
|
||||
try:
|
||||
handler(self.hass, self, schema(msg))
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.logger.exception('Error handling message: %s', msg)
|
||||
self.send_message(messages.error_message(
|
||||
cur_id, const.ERR_UNKNOWN_ERROR,
|
||||
'Unknown error.'))
|
||||
|
||||
self.last_id = cur_id
|
||||
|
||||
@callback
|
||||
def async_close(self):
|
||||
"""Close down connection."""
|
||||
for unsub in self.event_listeners.values():
|
||||
unsub()
|
|
@ -1,4 +1,11 @@
|
|||
"""Websocket constants."""
|
||||
import asyncio
|
||||
from concurrent import futures
|
||||
|
||||
DOMAIN = 'websocket_api'
|
||||
URL = '/api/websocket'
|
||||
MAX_PENDING_MSG = 512
|
||||
|
||||
ERR_ID_REUSE = 1
|
||||
ERR_INVALID_FORMAT = 2
|
||||
ERR_NOT_FOUND = 3
|
||||
|
@ -6,3 +13,8 @@ ERR_UNKNOWN_COMMAND = 4
|
|||
ERR_UNKNOWN_ERROR = 5
|
||||
|
||||
TYPE_RESULT = 'result'
|
||||
|
||||
# Define the possible errors that occur when connections are cancelled.
|
||||
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
|
||||
# that futures.CancelledErrors can also occur in some situations.
|
||||
CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError)
|
||||
|
|
|
@ -18,7 +18,7 @@ def async_response(func):
|
|||
await func(hass, connection, msg)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
connection.send_message_outside(messages.error_message(
|
||||
connection.send_message(messages.error_message(
|
||||
msg['id'], 'unknown', 'Unexpected error occurred'))
|
||||
|
||||
@callback
|
||||
|
@ -35,10 +35,10 @@ def require_owner(func):
|
|||
@wraps(func)
|
||||
def with_owner(hass, connection, msg):
|
||||
"""Check owner and call function."""
|
||||
user = connection.request.get('hass_user')
|
||||
user = connection.user
|
||||
|
||||
if user is None or not user.is_owner:
|
||||
connection.to_write.put_nowait(messages.error_message(
|
||||
connection.send_message(messages.error_message(
|
||||
msg['id'], 'unauthorized', 'This command is for owners only.'))
|
||||
return
|
||||
|
||||
|
@ -61,7 +61,7 @@ def ws_require_user(
|
|||
"""Check current user."""
|
||||
def output_error(message_id, message):
|
||||
"""Output error message."""
|
||||
connection.send_message_outside(messages.error_message(
|
||||
connection.send_message(messages.error_message(
|
||||
msg['id'], message_id, message))
|
||||
|
||||
if connection.user is None:
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
"""WebSocket API related errors."""
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
|
||||
class Disconnect(HomeAssistantError):
|
||||
"""Disconnect the current session."""
|
||||
|
||||
pass
|
|
@ -0,0 +1,189 @@
|
|||
"""View to accept incoming websocket connection."""
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
import json
|
||||
import logging
|
||||
|
||||
from aiohttp import web, WSMsgType
|
||||
import async_timeout
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.helpers.json import JSONEncoder
|
||||
|
||||
from .const import MAX_PENDING_MSG, CANCELLATION_ERRORS, URL
|
||||
from .auth import AuthPhase, auth_required_message
|
||||
from .error import Disconnect
|
||||
|
||||
JSON_DUMP = partial(json.dumps, cls=JSONEncoder)
|
||||
|
||||
|
||||
class WebsocketAPIView(HomeAssistantView):
|
||||
"""View to serve a websockets endpoint."""
|
||||
|
||||
name = "websocketapi"
|
||||
url = URL
|
||||
requires_auth = False
|
||||
|
||||
async def get(self, request):
|
||||
"""Handle an incoming websocket connection."""
|
||||
return await WebSocketHandler(
|
||||
request.app['hass'], request).async_handle()
|
||||
|
||||
|
||||
class WebSocketHandler:
|
||||
"""Handle an active websocket client connection."""
|
||||
|
||||
def __init__(self, hass, request):
|
||||
"""Initialize an active connection."""
|
||||
self.hass = hass
|
||||
self.request = request
|
||||
self.wsock = None
|
||||
self._to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop)
|
||||
self._handle_task = None
|
||||
self._writer_task = None
|
||||
self._logger = logging.getLogger(
|
||||
"{}.connection.{}".format(__name__, id(self)))
|
||||
|
||||
async def _writer(self):
|
||||
"""Write outgoing messages."""
|
||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||
with suppress(RuntimeError, *CANCELLATION_ERRORS):
|
||||
while not self.wsock.closed:
|
||||
message = await self._to_write.get()
|
||||
if message is None:
|
||||
break
|
||||
self._logger.debug("Sending %s", message)
|
||||
try:
|
||||
await self.wsock.send_json(message, dumps=JSON_DUMP)
|
||||
except TypeError as err:
|
||||
self._logger.error('Unable to serialize to JSON: %s\n%s',
|
||||
err, message)
|
||||
|
||||
@callback
|
||||
def _send_message(self, message):
|
||||
"""Send a message to the client.
|
||||
|
||||
Closes connection if the client is not reading the messages.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
try:
|
||||
self._to_write.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
self._logger.error("Client exceeded max pending messages [2]: %s",
|
||||
MAX_PENDING_MSG)
|
||||
self._cancel()
|
||||
|
||||
@callback
|
||||
def _cancel(self):
|
||||
"""Cancel the connection."""
|
||||
self._handle_task.cancel()
|
||||
self._writer_task.cancel()
|
||||
|
||||
async def async_handle(self):
|
||||
"""Handle a websocket response."""
|
||||
request = self.request
|
||||
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
|
||||
await wsock.prepare(request)
|
||||
self._logger.debug("Connected")
|
||||
|
||||
# Py3.7+
|
||||
if hasattr(asyncio, 'current_task'):
|
||||
# pylint: disable=no-member
|
||||
self._handle_task = asyncio.current_task()
|
||||
else:
|
||||
self._handle_task = asyncio.Task.current_task(loop=self.hass.loop)
|
||||
|
||||
@callback
|
||||
def handle_hass_stop(event):
|
||||
"""Cancel this connection."""
|
||||
self._cancel()
|
||||
|
||||
unsub_stop = self.hass.bus.async_listen(
|
||||
EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
|
||||
|
||||
self._writer_task = self.hass.async_create_task(self._writer())
|
||||
|
||||
auth = AuthPhase(self._logger, self.hass, self._send_message, request)
|
||||
connection = None
|
||||
disconnect_warn = None
|
||||
|
||||
try:
|
||||
self._send_message(auth_required_message())
|
||||
|
||||
# Auth Phase
|
||||
try:
|
||||
with async_timeout.timeout(10):
|
||||
msg = await wsock.receive()
|
||||
except asyncio.TimeoutError:
|
||||
disconnect_warn = \
|
||||
'Did not receive auth message within 10 seconds'
|
||||
raise Disconnect
|
||||
|
||||
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
|
||||
raise Disconnect
|
||||
|
||||
elif msg.type != WSMsgType.TEXT:
|
||||
disconnect_warn = 'Received non-Text message.'
|
||||
raise Disconnect
|
||||
|
||||
try:
|
||||
msg = msg.json()
|
||||
except ValueError:
|
||||
disconnect_warn = 'Received invalid JSON.'
|
||||
raise Disconnect
|
||||
|
||||
self._logger.debug("Received %s", msg)
|
||||
connection = await auth.async_handle(msg)
|
||||
|
||||
# Command phase
|
||||
while not wsock.closed:
|
||||
msg = await wsock.receive()
|
||||
|
||||
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
|
||||
break
|
||||
|
||||
elif msg.type != WSMsgType.TEXT:
|
||||
disconnect_warn = 'Received non-Text message.'
|
||||
break
|
||||
|
||||
try:
|
||||
msg = msg.json()
|
||||
except ValueError:
|
||||
disconnect_warn = 'Received invalid JSON.'
|
||||
break
|
||||
|
||||
self._logger.debug("Received %s", msg)
|
||||
connection.async_handle(msg)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self._logger.info("Connection closed by client")
|
||||
|
||||
except Disconnect:
|
||||
pass
|
||||
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self._logger.exception("Unexpected error inside websocket API")
|
||||
|
||||
finally:
|
||||
unsub_stop()
|
||||
|
||||
if connection is not None:
|
||||
connection.async_close()
|
||||
|
||||
try:
|
||||
self._to_write.put_nowait(None)
|
||||
# Make sure all error messages are written before closing
|
||||
await self._writer_task
|
||||
except asyncio.QueueFull:
|
||||
self._writer_task.cancel()
|
||||
|
||||
await wsock.close()
|
||||
|
||||
if disconnect_warn is None:
|
||||
self._logger.debug("Disconnected")
|
||||
else:
|
||||
self._logger.warning("Disconnected: %s", disconnect_warn)
|
|
@ -4,7 +4,9 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.websocket_api.http import URL
|
||||
from homeassistant.components.websocket_api.auth import (
|
||||
TYPE_AUTH, TYPE_AUTH_OK, TYPE_AUTH_REQUIRED)
|
||||
|
||||
from tests.common import MockUser, CLIENT_ID
|
||||
|
||||
|
@ -14,41 +16,52 @@ def hass_ws_client(aiohttp_client):
|
|||
"""Websocket client fixture connected to websocket server."""
|
||||
async def create_client(hass, access_token=None):
|
||||
"""Create a websocket client."""
|
||||
wapi = hass.components.websocket_api
|
||||
assert await async_setup_component(hass, 'websocket_api')
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
patching = None
|
||||
patches = []
|
||||
|
||||
if access_token is not None:
|
||||
patching = patch('homeassistant.auth.AuthManager.active',
|
||||
return_value=True)
|
||||
patching.start()
|
||||
if access_token is None:
|
||||
patches.append(patch(
|
||||
'homeassistant.auth.AuthManager.active', return_value=False))
|
||||
patches.append(patch(
|
||||
'homeassistant.auth.AuthManager.support_legacy',
|
||||
return_value=True))
|
||||
patches.append(patch(
|
||||
'homeassistant.components.websocket_api.auth.'
|
||||
'validate_password', return_value=True))
|
||||
else:
|
||||
patches.append(patch(
|
||||
'homeassistant.auth.AuthManager.active', return_value=True))
|
||||
patches.append(patch(
|
||||
'homeassistant.components.http.auth.setup_auth'))
|
||||
|
||||
for p in patches:
|
||||
p.start()
|
||||
|
||||
try:
|
||||
websocket = await client.ws_connect(wapi.URL)
|
||||
websocket = await client.ws_connect(URL)
|
||||
auth_resp = await websocket.receive_json()
|
||||
assert auth_resp['type'] == TYPE_AUTH_REQUIRED
|
||||
|
||||
if auth_resp['type'] == wapi.TYPE_AUTH_OK:
|
||||
assert access_token is None, \
|
||||
'Access token given but no auth required'
|
||||
return websocket
|
||||
|
||||
assert access_token is not None, \
|
||||
'Access token required for fixture'
|
||||
|
||||
if access_token is None:
|
||||
await websocket.send_json({
|
||||
'type': websocket_api.TYPE_AUTH,
|
||||
'type': TYPE_AUTH,
|
||||
'api_password': 'bla'
|
||||
})
|
||||
else:
|
||||
await websocket.send_json({
|
||||
'type': TYPE_AUTH,
|
||||
'access_token': access_token
|
||||
})
|
||||
|
||||
auth_ok = await websocket.receive_json()
|
||||
assert auth_ok['type'] == wapi.TYPE_AUTH_OK
|
||||
assert auth_ok['type'] == TYPE_AUTH_OK
|
||||
|
||||
finally:
|
||||
if patching is not None:
|
||||
patching.stop()
|
||||
for p in patches:
|
||||
p.stop()
|
||||
|
||||
# wrap in client
|
||||
websocket.client = client
|
||||
|
|
|
@ -62,7 +62,7 @@ class TestPanelIframe(unittest.TestCase):
|
|||
|
||||
panels = self.hass.data[frontend.DATA_PANELS]
|
||||
|
||||
assert panels.get('router').to_response(self.hass, None) == {
|
||||
assert panels.get('router').to_response() == {
|
||||
'component_name': 'iframe',
|
||||
'config': {'url': 'http://192.168.1.1'},
|
||||
'icon': 'mdi:network-wireless',
|
||||
|
@ -70,7 +70,7 @@ class TestPanelIframe(unittest.TestCase):
|
|||
'url_path': 'router'
|
||||
}
|
||||
|
||||
assert panels.get('weather').to_response(self.hass, None) == {
|
||||
assert panels.get('weather').to_response() == {
|
||||
'component_name': 'iframe',
|
||||
'config': {'url': 'https://www.wunderground.com/us/ca/san-diego'},
|
||||
'icon': 'mdi:weather',
|
||||
|
@ -78,7 +78,7 @@ class TestPanelIframe(unittest.TestCase):
|
|||
'url_path': 'weather',
|
||||
}
|
||||
|
||||
assert panels.get('api').to_response(self.hass, None) == {
|
||||
assert panels.get('api').to_response() == {
|
||||
'component_name': 'iframe',
|
||||
'config': {'url': '/api'},
|
||||
'icon': 'mdi:weather',
|
||||
|
@ -86,7 +86,7 @@ class TestPanelIframe(unittest.TestCase):
|
|||
'url_path': 'api',
|
||||
}
|
||||
|
||||
assert panels.get('ftp').to_response(self.hass, None) == {
|
||||
assert panels.get('ftp').to_response() == {
|
||||
'component_name': 'iframe',
|
||||
'config': {'url': 'ftp://some/ftp'},
|
||||
'icon': 'mdi:weather',
|
||||
|
|
|
@ -2,7 +2,8 @@
|
|||
import pytest
|
||||
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components import websocket_api as wapi
|
||||
from homeassistant.components.websocket_api.http import URL
|
||||
from homeassistant.components.websocket_api.auth import TYPE_AUTH_REQUIRED
|
||||
|
||||
from . import API_PASSWORD
|
||||
|
||||
|
@ -24,10 +25,10 @@ def no_auth_websocket_client(hass, loop, aiohttp_client):
|
|||
}))
|
||||
|
||||
client = loop.run_until_complete(aiohttp_client(hass.http.app))
|
||||
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
|
||||
ws = loop.run_until_complete(client.ws_connect(URL))
|
||||
|
||||
auth_ok = loop.run_until_complete(ws.receive_json())
|
||||
assert auth_ok['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
assert auth_ok['type'] == TYPE_AUTH_REQUIRED
|
||||
|
||||
yield ws
|
||||
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
"""Test auth of websocket API."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.components import websocket_api as wapi
|
||||
from homeassistant.components.websocket_api.const import URL
|
||||
from homeassistant.components.websocket_api.auth import (
|
||||
TYPE_AUTH, TYPE_AUTH_INVALID, TYPE_AUTH_OK, TYPE_AUTH_REQUIRED)
|
||||
|
||||
from homeassistant.components.websocket_api import commands
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
|
@ -13,28 +16,29 @@ from . import API_PASSWORD
|
|||
async def test_auth_via_msg(no_auth_websocket_client):
|
||||
"""Test authenticating."""
|
||||
await no_auth_websocket_client.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'type': TYPE_AUTH,
|
||||
'api_password': API_PASSWORD
|
||||
})
|
||||
|
||||
msg = await no_auth_websocket_client.receive_json()
|
||||
|
||||
assert msg['type'] == wapi.TYPE_AUTH_OK
|
||||
assert msg['type'] == TYPE_AUTH_OK
|
||||
|
||||
|
||||
async def test_auth_via_msg_incorrect_pass(no_auth_websocket_client):
|
||||
"""Test authenticating."""
|
||||
with patch('homeassistant.components.websocket_api.process_wrong_login',
|
||||
return_value=mock_coro()) as mock_process_wrong_login:
|
||||
with patch('homeassistant.components.websocket_api.auth.'
|
||||
'process_wrong_login', return_value=mock_coro()) \
|
||||
as mock_process_wrong_login:
|
||||
await no_auth_websocket_client.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'type': TYPE_AUTH,
|
||||
'api_password': API_PASSWORD + 'wrong'
|
||||
})
|
||||
|
||||
msg = await no_auth_websocket_client.receive_json()
|
||||
|
||||
assert mock_process_wrong_login.called
|
||||
assert msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
assert msg['type'] == TYPE_AUTH_INVALID
|
||||
assert msg['message'] == 'Invalid access token or password'
|
||||
|
||||
|
||||
|
@ -51,8 +55,8 @@ async def test_pre_auth_only_auth_allowed(no_auth_websocket_client):
|
|||
|
||||
msg = await no_auth_websocket_client.receive_json()
|
||||
|
||||
assert msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
assert msg['message'].startswith('Message incorrectly formatted')
|
||||
assert msg['type'] == TYPE_AUTH_INVALID
|
||||
assert msg['message'].startswith('Auth message incorrectly formatted')
|
||||
|
||||
|
||||
async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
|
||||
|
@ -65,19 +69,19 @@ async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
|
|||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
async with client.ws_connect(URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||
auth_active.return_value = True
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
assert auth_msg['type'] == TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'type': TYPE_AUTH,
|
||||
'access_token': hass_access_token
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||
assert auth_msg['type'] == TYPE_AUTH_OK
|
||||
|
||||
|
||||
async def test_auth_active_user_inactive(hass, aiohttp_client,
|
||||
|
@ -94,19 +98,19 @@ async def test_auth_active_user_inactive(hass, aiohttp_client,
|
|||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
async with client.ws_connect(URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||
auth_active.return_value = True
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
assert auth_msg['type'] == TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'type': TYPE_AUTH,
|
||||
'access_token': hass_access_token
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
assert auth_msg['type'] == TYPE_AUTH_INVALID
|
||||
|
||||
|
||||
async def test_auth_active_with_password_not_allow(hass, aiohttp_client):
|
||||
|
@ -119,19 +123,19 @@ async def test_auth_active_with_password_not_allow(hass, aiohttp_client):
|
|||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
async with client.ws_connect(URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active',
|
||||
return_value=True):
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
assert auth_msg['type'] == TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'type': TYPE_AUTH,
|
||||
'api_password': API_PASSWORD
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
assert auth_msg['type'] == TYPE_AUTH_INVALID
|
||||
|
||||
|
||||
async def test_auth_legacy_support_with_password(hass, aiohttp_client):
|
||||
|
@ -144,21 +148,21 @@ async def test_auth_legacy_support_with_password(hass, aiohttp_client):
|
|||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
async with client.ws_connect(URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active',
|
||||
return_value=True),\
|
||||
patch('homeassistant.auth.AuthManager.support_legacy',
|
||||
return_value=True):
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
assert auth_msg['type'] == TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'type': TYPE_AUTH,
|
||||
'api_password': API_PASSWORD
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||
assert auth_msg['type'] == TYPE_AUTH_OK
|
||||
|
||||
|
||||
async def test_auth_with_invalid_token(hass, aiohttp_client):
|
||||
|
@ -171,16 +175,16 @@ async def test_auth_with_invalid_token(hass, aiohttp_client):
|
|||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
async with client.ws_connect(URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||
auth_active.return_value = True
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
assert auth_msg['type'] == TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'type': TYPE_AUTH,
|
||||
'access_token': 'incorrect'
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
assert auth_msg['type'] == TYPE_AUTH_INVALID
|
||||
|
|
|
@ -4,7 +4,10 @@ from unittest.mock import patch
|
|||
from async_timeout import timeout
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components import websocket_api as wapi
|
||||
from homeassistant.components.websocket_api.const import URL
|
||||
from homeassistant.components.websocket_api.auth import (
|
||||
TYPE_AUTH, TYPE_AUTH_OK, TYPE_AUTH_REQUIRED
|
||||
)
|
||||
from homeassistant.components.websocket_api import const, commands
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
|
@ -178,19 +181,19 @@ async def test_call_service_context_with_user(hass, aiohttp_client,
|
|||
calls = async_mock_service(hass, 'domain_test', 'test_service')
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
async with client.ws_connect(URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||
auth_active.return_value = True
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
assert auth_msg['type'] == TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'type': TYPE_AUTH,
|
||||
'access_token': hass_access_token
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||
assert auth_msg['type'] == TYPE_AUTH_OK
|
||||
|
||||
await ws.send_json({
|
||||
'id': 5,
|
||||
|
@ -227,17 +230,17 @@ async def test_call_service_context_no_user(hass, aiohttp_client):
|
|||
calls = async_mock_service(hass, 'domain_test', 'test_service')
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
async with client.ws_connect(URL) as ws:
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
assert auth_msg['type'] == TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'type': TYPE_AUTH,
|
||||
'api_password': API_PASSWORD
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||
assert auth_msg['type'] == TYPE_AUTH_OK
|
||||
|
||||
await ws.send_json({
|
||||
'id': 5,
|
||||
|
|
|
@ -5,14 +5,14 @@ from unittest.mock import patch, Mock
|
|||
from aiohttp import WSMsgType
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import websocket_api as wapi
|
||||
from homeassistant.components.websocket_api import const, commands, messages
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_low_queue():
|
||||
"""Mock a low queue."""
|
||||
with patch.object(wapi, 'MAX_PENDING_MSG', 5):
|
||||
with patch('homeassistant.components.websocket_api.http.MAX_PENDING_MSG',
|
||||
5):
|
||||
yield
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue