Break up websocket 2 (#17028)

* Break up websocket 2

* Lint+Test

* Lintttt

* Rename
pull/17030/merge
Paulus Schoutsen 2018-10-01 16:09:31 +02:00 committed by Pascal Vizeli
parent b5e3d8c337
commit 2e6346ca43
27 changed files with 641 additions and 568 deletions

View File

@ -432,7 +432,7 @@ def websocket_current_user(
"""Get current user.""" """Get current user."""
enabled_modules = await hass.auth.async_get_enabled_mfa(user) enabled_modules = await hass.auth.async_get_enabled_mfa(user)
connection.send_message_outside( connection.send_message(
websocket_api.result_message(msg['id'], { websocket_api.result_message(msg['id'], {
'id': user.id, 'id': user.id,
'name': user.name, 'name': user.name,
@ -467,7 +467,7 @@ def websocket_create_long_lived_access_token(
access_token = hass.auth.async_create_access_token( access_token = hass.auth.async_create_access_token(
refresh_token) refresh_token)
connection.send_message_outside( connection.send_message(
websocket_api.result_message(msg['id'], access_token)) websocket_api.result_message(msg['id'], access_token))
hass.async_create_task( hass.async_create_task(
@ -479,8 +479,8 @@ def websocket_create_long_lived_access_token(
def websocket_refresh_tokens( def websocket_refresh_tokens(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg): hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg):
"""Return metadata of users refresh tokens.""" """Return metadata of users refresh tokens."""
current_id = connection.request.get('refresh_token_id') current_id = connection.refresh_token_id
connection.to_write.put_nowait(websocket_api.result_message(msg['id'], [{ connection.send_message(websocket_api.result_message(msg['id'], [{
'id': refresh.id, 'id': refresh.id,
'client_id': refresh.client_id, 'client_id': refresh.client_id,
'client_name': refresh.client_name, 'client_name': refresh.client_name,
@ -508,7 +508,7 @@ def websocket_delete_refresh_token(
await hass.auth.async_remove_refresh_token(refresh_token) await hass.auth.async_remove_refresh_token(refresh_token)
connection.send_message_outside( connection.send_message(
websocket_api.result_message(msg['id'], {})) websocket_api.result_message(msg['id'], {}))
hass.async_create_task( hass.async_create_task(

View File

@ -64,7 +64,7 @@ def websocket_setup_mfa(
if flow_id is not None: if flow_id is not None:
result = await flow_manager.async_configure( result = await flow_manager.async_configure(
flow_id, msg.get('user_input')) flow_id, msg.get('user_input'))
connection.send_message_outside( connection.send_message(
websocket_api.result_message( websocket_api.result_message(
msg['id'], _prepare_result_json(result))) msg['id'], _prepare_result_json(result)))
return return
@ -72,7 +72,7 @@ def websocket_setup_mfa(
mfa_module_id = msg.get('mfa_module_id') mfa_module_id = msg.get('mfa_module_id')
mfa_module = hass.auth.get_auth_mfa_module(mfa_module_id) mfa_module = hass.auth.get_auth_mfa_module(mfa_module_id)
if mfa_module is None: if mfa_module is None:
connection.send_message_outside(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], 'no_module', msg['id'], 'no_module',
'MFA module {} is not found'.format(mfa_module_id))) 'MFA module {} is not found'.format(mfa_module_id)))
return return
@ -80,7 +80,7 @@ def websocket_setup_mfa(
result = await flow_manager.async_init( result = await flow_manager.async_init(
mfa_module_id, data={'user_id': connection.user.id}) mfa_module_id, data={'user_id': connection.user.id})
connection.send_message_outside( connection.send_message(
websocket_api.result_message( websocket_api.result_message(
msg['id'], _prepare_result_json(result))) msg['id'], _prepare_result_json(result)))
@ -99,13 +99,13 @@ def websocket_depose_mfa(
await hass.auth.async_disable_user_mfa( await hass.auth.async_disable_user_mfa(
connection.user, msg['mfa_module_id']) connection.user, msg['mfa_module_id'])
except ValueError as err: except ValueError as err:
connection.send_message_outside(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], 'disable_failed', msg['id'], 'disable_failed',
'Cannot disable MFA Module {}: {}'.format( 'Cannot disable MFA Module {}: {}'.format(
mfa_module_id, err))) mfa_module_id, err)))
return return
connection.send_message_outside( connection.send_message(
websocket_api.result_message( websocket_api.result_message(
msg['id'], 'done')) msg['id'], 'done'))

View File

@ -460,14 +460,14 @@ async def websocket_camera_thumbnail(hass, connection, msg):
""" """
try: try:
image = await async_get_image(hass, msg['entity_id']) image = await async_get_image(hass, msg['entity_id'])
connection.send_message_outside(websocket_api.result_message( connection.send_message(websocket_api.result_message(
msg['id'], { msg['id'], {
'content_type': image.content_type, 'content_type': image.content_type,
'content': base64.b64encode(image.content).decode('utf-8') 'content': base64.b64encode(image.content).decode('utf-8')
} }
)) ))
except HomeAssistantError: except HomeAssistantError:
connection.send_message_outside(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], 'image_fetch_failed', 'Unable to fetch image')) msg['id'], 'image_fetch_failed', 'Unable to fetch image'))

View File

@ -231,7 +231,7 @@ def websocket_cloud_status(hass, connection, msg):
Async friendly. Async friendly.
""" """
cloud = hass.data[DOMAIN] cloud = hass.data[DOMAIN]
connection.to_write.put_nowait( connection.send_message(
websocket_api.result_message(msg['id'], _account_data(cloud))) websocket_api.result_message(msg['id'], _account_data(cloud)))
@ -241,7 +241,7 @@ async def websocket_subscription(hass, connection, msg):
cloud = hass.data[DOMAIN] cloud = hass.data[DOMAIN]
if not cloud.is_logged_in: if not cloud.is_logged_in:
connection.to_write.put_nowait(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], 'not_logged_in', msg['id'], 'not_logged_in',
'You need to be logged in to the cloud.')) 'You need to be logged in to the cloud.'))
return return
@ -250,10 +250,10 @@ async def websocket_subscription(hass, connection, msg):
response = await cloud.fetch_subscription_info() response = await cloud.fetch_subscription_info()
if response.status == 200: if response.status == 200:
connection.send_message_outside(websocket_api.result_message( connection.send_message(websocket_api.result_message(
msg['id'], await response.json())) msg['id'], await response.json()))
else: else:
connection.send_message_outside(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], 'request_failed', 'Failed to request subscription')) msg['id'], 'request_failed', 'Failed to request subscription'))
@ -263,7 +263,7 @@ async def websocket_update_prefs(hass, connection, msg):
cloud = hass.data[DOMAIN] cloud = hass.data[DOMAIN]
if not cloud.is_logged_in: if not cloud.is_logged_in:
connection.to_write.put_nowait(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], 'not_logged_in', msg['id'], 'not_logged_in',
'You need to be logged in to the cloud.')) 'You need to be logged in to the cloud.'))
return return
@ -273,7 +273,7 @@ async def websocket_update_prefs(hass, connection, msg):
changes.pop('type') changes.pop('type')
await cloud.update_preferences(**changes) await cloud.update_preferences(**changes)
connection.send_message_outside(websocket_api.result_message( connection.send_message(websocket_api.result_message(
msg['id'], {'success': True})) msg['id'], {'success': True}))

View File

@ -49,7 +49,7 @@ def websocket_list(hass, connection, msg):
"""Send users.""" """Send users."""
result = [_user_info(u) for u in await hass.auth.async_get_users()] result = [_user_info(u) for u in await hass.auth.async_get_users()]
connection.send_message_outside( connection.send_message(
websocket_api.result_message(msg['id'], result)) websocket_api.result_message(msg['id'], result))
hass.async_add_job(send_users()) hass.async_add_job(send_users())
@ -61,8 +61,8 @@ def websocket_delete(hass, connection, msg):
"""Delete a user.""" """Delete a user."""
async def delete_user(): async def delete_user():
"""Delete user.""" """Delete user."""
if msg['user_id'] == connection.request.get('hass_user').id: if msg['user_id'] == connection.user.id:
connection.send_message_outside(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], 'no_delete_self', msg['id'], 'no_delete_self',
'Unable to delete your own account')) 'Unable to delete your own account'))
return return
@ -70,13 +70,13 @@ def websocket_delete(hass, connection, msg):
user = await hass.auth.async_get_user(msg['user_id']) user = await hass.auth.async_get_user(msg['user_id'])
if not user: if not user:
connection.send_message_outside(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], 'not_found', 'User not found')) msg['id'], 'not_found', 'User not found'))
return return
await hass.auth.async_remove_user(user) await hass.auth.async_remove_user(user)
connection.send_message_outside( connection.send_message(
websocket_api.result_message(msg['id'])) websocket_api.result_message(msg['id']))
hass.async_add_job(delete_user()) hass.async_add_job(delete_user())
@ -90,7 +90,7 @@ def websocket_create(hass, connection, msg):
"""Create a user.""" """Create a user."""
user = await hass.auth.async_create_user(msg['name']) user = await hass.auth.async_create_user(msg['name'])
connection.send_message_outside( connection.send_message(
websocket_api.result_message(msg['id'], { websocket_api.result_message(msg['id'], {
'user': _user_info(user) 'user': _user_info(user)
})) }))

View File

@ -2,7 +2,6 @@
import voluptuous as vol import voluptuous as vol
from homeassistant.auth.providers import homeassistant as auth_ha from homeassistant.auth.providers import homeassistant as auth_ha
from homeassistant.core import callback
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.decorators import require_owner from homeassistant.components.websocket_api.decorators import require_owner
@ -55,121 +54,109 @@ def _get_provider(hass):
raise RuntimeError('Provider not found') raise RuntimeError('Provider not found')
@callback
@require_owner @require_owner
def websocket_create(hass, connection, msg): @websocket_api.async_response
async def websocket_create(hass, connection, msg):
"""Create credentials and attach to a user.""" """Create credentials and attach to a user."""
async def create_creds(): provider = _get_provider(hass)
"""Create credentials.""" await provider.async_initialize()
provider = _get_provider(hass)
await provider.async_initialize()
user = await hass.auth.async_get_user(msg['user_id']) user = await hass.auth.async_get_user(msg['user_id'])
if user is None: if user is None:
connection.send_message_outside(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], 'not_found', 'User not found')) msg['id'], 'not_found', 'User not found'))
return return
if user.system_generated: if user.system_generated:
connection.send_message_outside(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], 'system_generated', msg['id'], 'system_generated',
'Cannot add credentials to a system generated user.')) 'Cannot add credentials to a system generated user.'))
return return
try:
await hass.async_add_executor_job(
provider.data.add_auth, msg['username'], msg['password'])
except auth_ha.InvalidUser:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'username_exists', 'Username already exists'))
return
credentials = await provider.async_get_or_create_credentials({
'username': msg['username']
})
await hass.auth.async_link_user(user, credentials)
await provider.data.async_save()
connection.to_write.put_nowait(websocket_api.result_message(msg['id']))
hass.async_add_job(create_creds())
@callback
@require_owner
def websocket_delete(hass, connection, msg):
"""Delete username and related credential."""
async def delete_creds():
"""Delete user credentials."""
provider = _get_provider(hass)
await provider.async_initialize()
credentials = await provider.async_get_or_create_credentials({
'username': msg['username']
})
# if not new, an existing credential exists.
# Removing the credential will also remove the auth.
if not credentials.is_new:
await hass.auth.async_remove_credentials(credentials)
connection.to_write.put_nowait(
websocket_api.result_message(msg['id']))
return
try:
provider.data.async_remove_auth(msg['username'])
await provider.data.async_save()
except auth_ha.InvalidUser:
connection.to_write.put_nowait(websocket_api.error_message(
msg['id'], 'auth_not_found', 'Given username was not found.'))
return
connection.to_write.put_nowait(
websocket_api.result_message(msg['id']))
hass.async_add_job(delete_creds())
@callback
def websocket_change_password(hass, connection, msg):
"""Change user password."""
async def change_password():
"""Change user password."""
user = connection.request.get('hass_user')
if user is None:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'user_not_found', 'User not found'))
return
provider = _get_provider(hass)
await provider.async_initialize()
username = None
for credential in user.credentials:
if credential.auth_provider_type == provider.type:
username = credential.data['username']
break
if username is None:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'credentials_not_found', 'Credentials not found'))
return
try:
await provider.async_validate_login(
username, msg['current_password'])
except auth_ha.InvalidAuth:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'invalid_password', 'Invalid password'))
return
try:
await hass.async_add_executor_job( await hass.async_add_executor_job(
provider.data.change_password, username, msg['new_password']) provider.data.add_auth, msg['username'], msg['password'])
await provider.data.async_save() except auth_ha.InvalidUser:
connection.send_message(websocket_api.error_message(
msg['id'], 'username_exists', 'Username already exists'))
return
connection.send_message_outside( credentials = await provider.async_get_or_create_credentials({
'username': msg['username']
})
await hass.auth.async_link_user(user, credentials)
await provider.data.async_save()
connection.send_message(websocket_api.result_message(msg['id']))
@require_owner
@websocket_api.async_response
async def websocket_delete(hass, connection, msg):
"""Delete username and related credential."""
provider = _get_provider(hass)
await provider.async_initialize()
credentials = await provider.async_get_or_create_credentials({
'username': msg['username']
})
# if not new, an existing credential exists.
# Removing the credential will also remove the auth.
if not credentials.is_new:
await hass.auth.async_remove_credentials(credentials)
connection.send_message(
websocket_api.result_message(msg['id'])) websocket_api.result_message(msg['id']))
return
hass.async_add_job(change_password()) try:
provider.data.async_remove_auth(msg['username'])
await provider.data.async_save()
except auth_ha.InvalidUser:
connection.send_message(websocket_api.error_message(
msg['id'], 'auth_not_found', 'Given username was not found.'))
return
connection.send_message(
websocket_api.result_message(msg['id']))
@websocket_api.async_response
async def websocket_change_password(hass, connection, msg):
"""Change user password."""
user = connection.user
if user is None:
connection.send_message(websocket_api.error_message(
msg['id'], 'user_not_found', 'User not found'))
return
provider = _get_provider(hass)
await provider.async_initialize()
username = None
for credential in user.credentials:
if credential.auth_provider_type == provider.type:
username = credential.data['username']
break
if username is None:
connection.send_message(websocket_api.error_message(
msg['id'], 'credentials_not_found', 'Credentials not found'))
return
try:
await provider.async_validate_login(
username, msg['current_password'])
except auth_ha.InvalidAuth:
connection.send_message(websocket_api.error_message(
msg['id'], 'invalid_password', 'Invalid password'))
return
await hass.async_add_executor_job(
provider.data.change_password, username, msg['new_password'])
await provider.data.async_save()
connection.send_message(
websocket_api.result_message(msg['id']))

View File

@ -31,7 +31,7 @@ def websocket_list_devices(hass, connection, msg):
async def retrieve_entities(): async def retrieve_entities():
"""Get devices from registry.""" """Get devices from registry."""
registry = await async_get_registry(hass) registry = await async_get_registry(hass)
connection.send_message_outside(websocket_api.result_message( connection.send_message(websocket_api.result_message(
msg['id'], [{ msg['id'], [{
'config_entries': list(entry.config_entries), 'config_entries': list(entry.config_entries),
'connections': list(entry.connections), 'connections': list(entry.connections),

View File

@ -55,7 +55,7 @@ async def websocket_list_entities(hass, connection, msg):
Async friendly. Async friendly.
""" """
registry = await async_get_registry(hass) registry = await async_get_registry(hass)
connection.send_message_outside(websocket_api.result_message( connection.send_message(websocket_api.result_message(
msg['id'], [{ msg['id'], [{
'config_entry_id': entry.config_entry_id, 'config_entry_id': entry.config_entry_id,
'device_id': entry.device_id, 'device_id': entry.device_id,
@ -77,11 +77,11 @@ async def websocket_get_entity(hass, connection, msg):
entry = registry.entities.get(msg['entity_id']) entry = registry.entities.get(msg['entity_id'])
if entry is None: if entry is None:
connection.send_message_outside(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], ERR_NOT_FOUND, 'Entity not found')) msg['id'], ERR_NOT_FOUND, 'Entity not found'))
return return
connection.send_message_outside(websocket_api.result_message( connection.send_message(websocket_api.result_message(
msg['id'], _entry_dict(entry) msg['id'], _entry_dict(entry)
)) ))
@ -95,7 +95,7 @@ async def websocket_update_entity(hass, connection, msg):
registry = await async_get_registry(hass) registry = await async_get_registry(hass)
if msg['entity_id'] not in registry.entities: if msg['entity_id'] not in registry.entities:
connection.send_message_outside(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], ERR_NOT_FOUND, 'Entity not found')) msg['id'], ERR_NOT_FOUND, 'Entity not found'))
return return
@ -112,11 +112,11 @@ async def websocket_update_entity(hass, connection, msg):
entry = registry.async_update_entity( entry = registry.async_update_entity(
msg['entity_id'], **changes) msg['entity_id'], **changes)
except ValueError as err: except ValueError as err:
connection.send_message_outside(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], 'invalid_info', str(err) msg['id'], 'invalid_info', str(err)
)) ))
else: else:
connection.send_message_outside(websocket_api.result_message( connection.send_message(websocket_api.result_message(
msg['id'], _entry_dict(entry) msg['id'], _entry_dict(entry)
)) ))

View File

@ -145,7 +145,7 @@ class Panel:
index_view.get) index_view.get)
@callback @callback
def to_response(self, hass, request): def to_response(self):
"""Panel as dictionary.""" """Panel as dictionary."""
return { return {
'component_name': self.component_name, 'component_name': self.component_name,
@ -485,12 +485,10 @@ def websocket_get_panels(hass, connection, msg):
Async friendly. Async friendly.
""" """
panels = { panels = {
panel: panel: connection.hass.data[DATA_PANELS][panel].to_response()
connection.hass.data[DATA_PANELS][panel].to_response(
connection.hass, connection.request)
for panel in connection.hass.data[DATA_PANELS]} for panel in connection.hass.data[DATA_PANELS]}
connection.to_write.put_nowait(websocket_api.result_message( connection.send_message(websocket_api.result_message(
msg['id'], panels)) msg['id'], panels))
@ -500,25 +498,21 @@ def websocket_get_themes(hass, connection, msg):
Async friendly. Async friendly.
""" """
connection.to_write.put_nowait(websocket_api.result_message(msg['id'], { connection.send_message(websocket_api.result_message(msg['id'], {
'themes': hass.data[DATA_THEMES], 'themes': hass.data[DATA_THEMES],
'default_theme': hass.data[DATA_DEFAULT_THEME], 'default_theme': hass.data[DATA_DEFAULT_THEME],
})) }))
@callback @websocket_api.async_response
def websocket_get_translations(hass, connection, msg): async def websocket_get_translations(hass, connection, msg):
"""Handle get translations command. """Handle get translations command.
Async friendly. Async friendly.
""" """
async def send_translations(): resources = await async_get_translations(hass, msg['language'])
"""Send a translation.""" connection.send_message(websocket_api.result_message(
resources = await async_get_translations(hass, msg['language']) msg['id'], {
connection.send_message_outside(websocket_api.result_message( 'resources': resources,
msg['id'], { }
'resources': resources, ))
}
))
hass.async_add_job(send_translations())

View File

@ -112,6 +112,7 @@ async def async_validate_auth_header(request, api_password=None):
if refresh_token is None: if refresh_token is None:
return False return False
request['hass_refresh_token'] = refresh_token
request['hass_user'] = refresh_token.user request['hass_user'] = refresh_token.user
return True return True

View File

@ -48,4 +48,4 @@ async def websocket_lovelace_config(hass, connection, msg):
if error is not None: if error is not None:
message = websocket_api.error_message(msg['id'], *error) message = websocket_api.error_message(msg['id'], *error)
connection.send_message_outside(message) connection.send_message(message)

View File

@ -874,19 +874,19 @@ async def websocket_handle_thumbnail(hass, connection, msg):
player = component.get_entity(msg['entity_id']) player = component.get_entity(msg['entity_id'])
if player is None: if player is None:
connection.send_message_outside(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], 'entity_not_found', 'Entity not found')) msg['id'], 'entity_not_found', 'Entity not found'))
return return
data, content_type = await player.async_get_media_image() data, content_type = await player.async_get_media_image()
if data is None: if data is None:
connection.send_message_outside(websocket_api.error_message( connection.send_message(websocket_api.error_message(
msg['id'], 'thumbnail_fetch_failed', msg['id'], 'thumbnail_fetch_failed',
'Failed to fetch thumbnail')) 'Failed to fetch thumbnail'))
return return
connection.send_message_outside(websocket_api.result_message( connection.send_message(websocket_api.result_message(
msg['id'], { msg['id'], {
'content_type': content_type, 'content_type': content_type,
'content': base64.b64encode(data).decode('utf-8') 'content': base64.b64encode(data).decode('utf-8')

View File

@ -199,7 +199,7 @@ async def async_setup(hass: HomeAssistant, config: dict) -> Awaitable[bool]:
def websocket_get_notifications( def websocket_get_notifications(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg): hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg):
"""Return a list of persistent_notifications.""" """Return a list of persistent_notifications."""
connection.to_write.put_nowait( connection.send_message(
websocket_api.result_message(msg['id'], [ websocket_api.result_message(msg['id'], [
{ {
key: data[key] for key in (ATTR_NOTIFICATION_ID, ATTR_MESSAGE, key: data[key] for key in (ATTR_NOTIFICATION_ID, ATTR_MESSAGE,

View File

@ -4,49 +4,18 @@ Websocket based API for Home Assistant.
For more details about this component, please refer to the documentation at For more details about this component, please refer to the documentation at
https://developers.home-assistant.io/docs/external_api_websocket.html https://developers.home-assistant.io/docs/external_api_websocket.html
""" """
import asyncio from homeassistant.core import callback
from concurrent import futures
from contextlib import suppress
from functools import partial
import json
import logging
from aiohttp import web
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, __version__
from homeassistant.core import Context, callback
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.helpers.json import JSONEncoder
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.auth import validate_password
from homeassistant.components.http.const import KEY_AUTHENTICATED
from homeassistant.components.http.ban import process_wrong_login, \
process_success_login
from . import commands, const, decorators, messages from . import commands, connection, const, decorators, http, messages
DOMAIN = 'websocket_api' DOMAIN = const.DOMAIN
URL = '/api/websocket'
DEPENDENCIES = ('http',) DEPENDENCIES = ('http',)
MAX_PENDING_MSG = 512 # Backwards compat / Make it easier to integrate
_LOGGER = logging.getLogger(__name__)
JSON_DUMP = partial(json.dumps, cls=JSONEncoder)
TYPE_AUTH = 'auth'
TYPE_AUTH_INVALID = 'auth_invalid'
TYPE_AUTH_OK = 'auth_ok'
TYPE_AUTH_REQUIRED = 'auth_required'
# Backwards compat
# pylint: disable=invalid-name # pylint: disable=invalid-name
ActiveConnection = connection.ActiveConnection
BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA
error_message = messages.error_message error_message = messages.error_message
result_message = messages.result_message result_message = messages.result_message
@ -54,42 +23,6 @@ async_response = decorators.async_response
ws_require_user = decorators.ws_require_user ws_require_user = decorators.ws_require_user
# pylint: enable=invalid-name # pylint: enable=invalid-name
AUTH_MESSAGE_SCHEMA = vol.Schema({
vol.Required('type'): TYPE_AUTH,
vol.Exclusive('api_password', 'auth'): str,
vol.Exclusive('access_token', 'auth'): str,
})
# Define the possible errors that occur when connections are cancelled.
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
# that futures.CancelledErrors can also occur in some situations.
CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError)
def auth_ok_message():
"""Return an auth_ok message."""
return {
'type': TYPE_AUTH_OK,
'ha_version': __version__,
}
def auth_required_message():
"""Return an auth_required message."""
return {
'type': TYPE_AUTH_REQUIRED,
'ha_version': __version__,
}
def auth_invalid_message(message):
"""Return an auth_invalid message."""
return {
'type': TYPE_AUTH_INVALID,
'message': message,
}
@bind_hass @bind_hass
@callback @callback
@ -103,255 +36,6 @@ def async_register_command(hass, command, handler, schema):
async def async_setup(hass, config): async def async_setup(hass, config):
"""Initialize the websocket API.""" """Initialize the websocket API."""
hass.http.register_view(WebsocketAPIView) hass.http.register_view(http.WebsocketAPIView)
commands.async_register_commands(hass) commands.async_register_commands(hass)
return True return True
class WebsocketAPIView(HomeAssistantView):
"""View to serve a websockets endpoint."""
name = "websocketapi"
url = URL
requires_auth = False
async def get(self, request):
"""Handle an incoming websocket connection."""
return await ActiveConnection(request.app['hass'], request).handle()
class ActiveConnection:
"""Handle an active websocket client connection."""
def __init__(self, hass, request):
"""Initialize an active connection."""
self.hass = hass
self.request = request
self.wsock = None
self.event_listeners = {}
self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop)
self._handle_task = None
self._writer_task = None
@property
def user(self):
"""Return the user associated with the connection."""
return self.request.get('hass_user')
def context(self, msg):
"""Return a context."""
user = self.user
if user is None:
return Context()
return Context(user_id=user.id)
def debug(self, message1, message2=''):
"""Print a debug message."""
_LOGGER.debug("WS %s: %s %s", id(self.wsock), message1, message2)
def log_error(self, message1, message2=''):
"""Print an error message."""
_LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2)
async def _writer(self):
"""Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler
with suppress(RuntimeError, *CANCELLATION_ERRORS):
while not self.wsock.closed:
message = await self.to_write.get()
if message is None:
break
self.debug("Sending", message)
try:
await self.wsock.send_json(message, dumps=JSON_DUMP)
except TypeError as err:
_LOGGER.error('Unable to serialize to JSON: %s\n%s',
err, message)
@callback
def send_message_outside(self, message):
"""Send a message to the client.
Closes connection if the client is not reading the messages.
Async friendly.
"""
try:
self.to_write.put_nowait(message)
except asyncio.QueueFull:
self.log_error("Client exceeded max pending messages [2]:",
MAX_PENDING_MSG)
self.cancel()
@callback
def cancel(self):
"""Cancel the connection."""
self._handle_task.cancel()
self._writer_task.cancel()
async def handle(self):
"""Handle the websocket connection."""
request = self.request
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
await wsock.prepare(request)
self.debug("Connected")
self._handle_task = asyncio.Task.current_task(loop=self.hass.loop)
@callback
def handle_hass_stop(event):
"""Cancel this connection."""
self.cancel()
unsub_stop = self.hass.bus.async_listen(
EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
self._writer_task = self.hass.async_add_job(self._writer())
final_message = None
msg = None
authenticated = False
try:
if request[KEY_AUTHENTICATED]:
authenticated = True
# always request auth when auth is active
# even request passed pre-authentication (trusted networks)
# or when using legacy api_password
if self.hass.auth.active or not authenticated:
self.debug("Request auth")
await self.wsock.send_json(auth_required_message())
msg = await wsock.receive_json()
msg = AUTH_MESSAGE_SCHEMA(msg)
if self.hass.auth.active and 'access_token' in msg:
self.debug("Received access_token")
refresh_token = \
await self.hass.auth.async_validate_access_token(
msg['access_token'])
authenticated = refresh_token is not None
if authenticated:
request['hass_user'] = refresh_token.user
request['refresh_token_id'] = refresh_token.id
elif ((not self.hass.auth.active or
self.hass.auth.support_legacy) and
'api_password' in msg):
self.debug("Received api_password")
authenticated = validate_password(
request, msg['api_password'])
if not authenticated:
self.debug("Authorization failed")
await self.wsock.send_json(
auth_invalid_message('Invalid access token or password'))
await process_wrong_login(request)
return wsock
self.debug("Auth OK")
await process_success_login(request)
await self.wsock.send_json(auth_ok_message())
# ---------- AUTH PHASE OVER ----------
msg = await wsock.receive_json()
last_id = 0
handlers = self.hass.data[DOMAIN]
while msg:
self.debug("Received", msg)
msg = messages.MINIMAL_MESSAGE_SCHEMA(msg)
cur_id = msg['id']
if cur_id <= last_id:
self.to_write.put_nowait(messages.error_message(
cur_id, const.ERR_ID_REUSE,
'Identifier values have to increase.'))
elif msg['type'] not in handlers:
self.log_error(
'Received invalid command: {}'.format(msg['type']))
self.to_write.put_nowait(messages.error_message(
cur_id, const.ERR_UNKNOWN_COMMAND,
'Unknown command.'))
else:
handler, schema = handlers[msg['type']]
try:
handler(self.hass, self, schema(msg))
except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error handling message: %s', msg)
self.to_write.put_nowait(messages.error_message(
cur_id, const.ERR_UNKNOWN_ERROR,
'Unknown error.'))
last_id = cur_id
msg = await wsock.receive_json()
except vol.Invalid as err:
error_msg = "Message incorrectly formatted: "
if msg:
error_msg += humanize_error(msg, err)
else:
error_msg += str(err)
self.log_error(error_msg)
if not authenticated:
final_message = auth_invalid_message(error_msg)
else:
if isinstance(msg, dict):
iden = msg.get('id')
else:
iden = None
final_message = messages.error_message(
iden, const.ERR_INVALID_FORMAT, error_msg)
except TypeError as err:
if wsock.closed:
self.debug("Connection closed by client")
else:
_LOGGER.exception("Unexpected TypeError: %s", err)
except ValueError as err:
msg = "Received invalid JSON"
value = getattr(err, 'doc', None) # Py3.5+ only
if value:
msg += ': {}'.format(value)
self.log_error(msg)
self._writer_task.cancel()
except CANCELLATION_ERRORS:
self.debug("Connection cancelled")
except asyncio.QueueFull:
self.log_error("Client exceeded max pending messages [1]:",
MAX_PENDING_MSG)
self._writer_task.cancel()
except Exception: # pylint: disable=broad-except
error = "Unexpected error inside websocket API. "
if msg is not None:
error += str(msg)
_LOGGER.exception(error)
finally:
unsub_stop()
for unsub in self.event_listeners.values():
unsub()
try:
if final_message is not None:
self.to_write.put_nowait(final_message)
self.to_write.put_nowait(None)
# Make sure all error messages are written before closing
await self._writer_task
except asyncio.QueueFull:
self._writer_task.cancel()
await wsock.close()
self.debug("Closed connection")
return wsock

View File

@ -0,0 +1,99 @@
"""Handle the auth of a connection."""
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.const import __version__
from homeassistant.components.http.auth import validate_password
from homeassistant.components.http.ban import process_wrong_login, \
process_success_login
from .connection import ActiveConnection
from .error import Disconnect
TYPE_AUTH = 'auth'
TYPE_AUTH_INVALID = 'auth_invalid'
TYPE_AUTH_OK = 'auth_ok'
TYPE_AUTH_REQUIRED = 'auth_required'
AUTH_MESSAGE_SCHEMA = vol.Schema({
vol.Required('type'): TYPE_AUTH,
vol.Exclusive('api_password', 'auth'): str,
vol.Exclusive('access_token', 'auth'): str,
})
def auth_ok_message():
"""Return an auth_ok message."""
return {
'type': TYPE_AUTH_OK,
'ha_version': __version__,
}
def auth_required_message():
"""Return an auth_required message."""
return {
'type': TYPE_AUTH_REQUIRED,
'ha_version': __version__,
}
def auth_invalid_message(message):
"""Return an auth_invalid message."""
return {
'type': TYPE_AUTH_INVALID,
'message': message,
}
class AuthPhase:
"""Connection that requires client to authenticate first."""
def __init__(self, logger, hass, send_message, request):
"""Initialize the authentiated connection."""
self._hass = hass
self._send_message = send_message
self._logger = logger
self._request = request
self._authenticated = False
self._connection = None
async def async_handle(self, msg):
"""Handle authentication."""
try:
msg = AUTH_MESSAGE_SCHEMA(msg)
except vol.Invalid as err:
error_msg = 'Auth message incorrectly formatted: {}'.format(
humanize_error(msg, err))
self._logger.warning(error_msg)
self._send_message(auth_invalid_message(error_msg))
raise Disconnect
if self._hass.auth.active and 'access_token' in msg:
self._logger.debug("Received access_token")
refresh_token = \
await self._hass.auth.async_validate_access_token(
msg['access_token'])
if refresh_token is not None:
return await self._async_finish_auth(
refresh_token.user, refresh_token)
elif ((not self._hass.auth.active or self._hass.auth.support_legacy)
and 'api_password' in msg):
self._logger.debug("Received api_password")
if validate_password(self._request, msg['api_password']):
return await self._async_finish_auth(None, None)
self._send_message(auth_invalid_message(
'Invalid access token or password'))
await process_wrong_login(self._request)
raise Disconnect
async def _async_finish_auth(self, user, refresh_token) \
-> ActiveConnection:
"""Create an active connection."""
self._logger.debug("Auth OK")
await process_success_login(self._request)
self._send_message(auth_ok_message())
return ActiveConnection(
self._logger, self._hass, self._send_message, user, refresh_token)

View File

@ -103,12 +103,12 @@ def handle_subscribe_events(hass, connection, msg):
if event.event_type == EVENT_TIME_CHANGED: if event.event_type == EVENT_TIME_CHANGED:
return return
connection.send_message_outside(event_message(msg['id'], event)) connection.send_message(event_message(msg['id'], event))
connection.event_listeners[msg['id']] = hass.bus.async_listen( connection.event_listeners[msg['id']] = hass.bus.async_listen(
msg['event_type'], forward_events) msg['event_type'], forward_events)
connection.to_write.put_nowait(messages.result_message(msg['id'])) connection.send_message(messages.result_message(msg['id']))
@callback @callback
@ -121,9 +121,9 @@ def handle_unsubscribe_events(hass, connection, msg):
if subscription in connection.event_listeners: if subscription in connection.event_listeners:
connection.event_listeners.pop(subscription)() connection.event_listeners.pop(subscription)()
connection.to_write.put_nowait(messages.result_message(msg['id'])) connection.send_message(messages.result_message(msg['id']))
else: else:
connection.to_write.put_nowait(messages.error_message( connection.send_message(messages.error_message(
msg['id'], const.ERR_NOT_FOUND, 'Subscription not found.')) msg['id'], const.ERR_NOT_FOUND, 'Subscription not found.'))
@ -140,7 +140,7 @@ async def handle_call_service(hass, connection, msg):
await hass.services.async_call( await hass.services.async_call(
msg['domain'], msg['service'], msg.get('service_data'), blocking, msg['domain'], msg['service'], msg.get('service_data'), blocking,
connection.context(msg)) connection.context(msg))
connection.send_message_outside(messages.result_message(msg['id'])) connection.send_message(messages.result_message(msg['id']))
@callback @callback
@ -149,7 +149,7 @@ def handle_get_states(hass, connection, msg):
Async friendly. Async friendly.
""" """
connection.to_write.put_nowait(messages.result_message( connection.send_message(messages.result_message(
msg['id'], hass.states.async_all())) msg['id'], hass.states.async_all()))
@ -160,7 +160,7 @@ async def handle_get_services(hass, connection, msg):
Async friendly. Async friendly.
""" """
descriptions = await async_get_all_descriptions(hass) descriptions = await async_get_all_descriptions(hass)
connection.send_message_outside( connection.send_message(
messages.result_message(msg['id'], descriptions)) messages.result_message(msg['id'], descriptions))
@ -170,7 +170,7 @@ def handle_get_config(hass, connection, msg):
Async friendly. Async friendly.
""" """
connection.to_write.put_nowait(messages.result_message( connection.send_message(messages.result_message(
msg['id'], hass.config.as_dict())) msg['id'], hass.config.as_dict()))
@ -180,4 +180,4 @@ def handle_ping(hass, connection, msg):
Async friendly. Async friendly.
""" """
connection.to_write.put_nowait(pong_message(msg['id'])) connection.send_message(pong_message(msg['id']))

View File

@ -0,0 +1,78 @@
"""Connection session."""
import voluptuous as vol
from homeassistant.core import callback, Context
from . import const, messages
class ActiveConnection:
"""Handle an active websocket client connection."""
def __init__(self, logger, hass, send_message, user, refresh_token):
"""Initialize an active connection."""
self.logger = logger
self.hass = hass
self.send_message = send_message
self.user = user
if refresh_token:
self.refresh_token_id = refresh_token.id
else:
self.refresh_token_id = None
self.event_listeners = {}
self.last_id = 0
def context(self, msg):
"""Return a context."""
user = self.user
if user is None:
return Context()
return Context(user_id=user.id)
@callback
def async_handle(self, msg):
"""Handle a single incoming message."""
handlers = self.hass.data[const.DOMAIN]
try:
msg = messages.MINIMAL_MESSAGE_SCHEMA(msg)
cur_id = msg['id']
except vol.Invalid:
self.logger.error('Received invalid command', msg)
self.send_message(messages.error_message(
msg.get('id'), const.ERR_INVALID_FORMAT,
'Message incorrectly formatted.'))
return
if cur_id <= self.last_id:
self.send_message(messages.error_message(
cur_id, const.ERR_ID_REUSE,
'Identifier values have to increase.'))
return
if msg['type'] not in handlers:
self.logger.error(
'Received invalid command: {}'.format(msg['type']))
self.send_message(messages.error_message(
cur_id, const.ERR_UNKNOWN_COMMAND,
'Unknown command.'))
return
handler, schema = handlers[msg['type']]
try:
handler(self.hass, self, schema(msg))
except Exception: # pylint: disable=broad-except
self.logger.exception('Error handling message: %s', msg)
self.send_message(messages.error_message(
cur_id, const.ERR_UNKNOWN_ERROR,
'Unknown error.'))
self.last_id = cur_id
@callback
def async_close(self):
"""Close down connection."""
for unsub in self.event_listeners.values():
unsub()

View File

@ -1,4 +1,11 @@
"""Websocket constants.""" """Websocket constants."""
import asyncio
from concurrent import futures
DOMAIN = 'websocket_api'
URL = '/api/websocket'
MAX_PENDING_MSG = 512
ERR_ID_REUSE = 1 ERR_ID_REUSE = 1
ERR_INVALID_FORMAT = 2 ERR_INVALID_FORMAT = 2
ERR_NOT_FOUND = 3 ERR_NOT_FOUND = 3
@ -6,3 +13,8 @@ ERR_UNKNOWN_COMMAND = 4
ERR_UNKNOWN_ERROR = 5 ERR_UNKNOWN_ERROR = 5
TYPE_RESULT = 'result' TYPE_RESULT = 'result'
# Define the possible errors that occur when connections are cancelled.
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
# that futures.CancelledErrors can also occur in some situations.
CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError)

View File

@ -18,7 +18,7 @@ def async_response(func):
await func(hass, connection, msg) await func(hass, connection, msg)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception") _LOGGER.exception("Unexpected exception")
connection.send_message_outside(messages.error_message( connection.send_message(messages.error_message(
msg['id'], 'unknown', 'Unexpected error occurred')) msg['id'], 'unknown', 'Unexpected error occurred'))
@callback @callback
@ -35,10 +35,10 @@ def require_owner(func):
@wraps(func) @wraps(func)
def with_owner(hass, connection, msg): def with_owner(hass, connection, msg):
"""Check owner and call function.""" """Check owner and call function."""
user = connection.request.get('hass_user') user = connection.user
if user is None or not user.is_owner: if user is None or not user.is_owner:
connection.to_write.put_nowait(messages.error_message( connection.send_message(messages.error_message(
msg['id'], 'unauthorized', 'This command is for owners only.')) msg['id'], 'unauthorized', 'This command is for owners only.'))
return return
@ -61,7 +61,7 @@ def ws_require_user(
"""Check current user.""" """Check current user."""
def output_error(message_id, message): def output_error(message_id, message):
"""Output error message.""" """Output error message."""
connection.send_message_outside(messages.error_message( connection.send_message(messages.error_message(
msg['id'], message_id, message)) msg['id'], message_id, message))
if connection.user is None: if connection.user is None:

View File

@ -0,0 +1,8 @@
"""WebSocket API related errors."""
from homeassistant.exceptions import HomeAssistantError
class Disconnect(HomeAssistantError):
"""Disconnect the current session."""
pass

View File

@ -0,0 +1,189 @@
"""View to accept incoming websocket connection."""
import asyncio
from contextlib import suppress
from functools import partial
import json
import logging
from aiohttp import web, WSMsgType
import async_timeout
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import callback
from homeassistant.components.http import HomeAssistantView
from homeassistant.helpers.json import JSONEncoder
from .const import MAX_PENDING_MSG, CANCELLATION_ERRORS, URL
from .auth import AuthPhase, auth_required_message
from .error import Disconnect
JSON_DUMP = partial(json.dumps, cls=JSONEncoder)
class WebsocketAPIView(HomeAssistantView):
"""View to serve a websockets endpoint."""
name = "websocketapi"
url = URL
requires_auth = False
async def get(self, request):
"""Handle an incoming websocket connection."""
return await WebSocketHandler(
request.app['hass'], request).async_handle()
class WebSocketHandler:
"""Handle an active websocket client connection."""
def __init__(self, hass, request):
"""Initialize an active connection."""
self.hass = hass
self.request = request
self.wsock = None
self._to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop)
self._handle_task = None
self._writer_task = None
self._logger = logging.getLogger(
"{}.connection.{}".format(__name__, id(self)))
async def _writer(self):
"""Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler
with suppress(RuntimeError, *CANCELLATION_ERRORS):
while not self.wsock.closed:
message = await self._to_write.get()
if message is None:
break
self._logger.debug("Sending %s", message)
try:
await self.wsock.send_json(message, dumps=JSON_DUMP)
except TypeError as err:
self._logger.error('Unable to serialize to JSON: %s\n%s',
err, message)
@callback
def _send_message(self, message):
"""Send a message to the client.
Closes connection if the client is not reading the messages.
Async friendly.
"""
try:
self._to_write.put_nowait(message)
except asyncio.QueueFull:
self._logger.error("Client exceeded max pending messages [2]: %s",
MAX_PENDING_MSG)
self._cancel()
@callback
def _cancel(self):
"""Cancel the connection."""
self._handle_task.cancel()
self._writer_task.cancel()
async def async_handle(self):
"""Handle a websocket response."""
request = self.request
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
await wsock.prepare(request)
self._logger.debug("Connected")
# Py3.7+
if hasattr(asyncio, 'current_task'):
# pylint: disable=no-member
self._handle_task = asyncio.current_task()
else:
self._handle_task = asyncio.Task.current_task(loop=self.hass.loop)
@callback
def handle_hass_stop(event):
"""Cancel this connection."""
self._cancel()
unsub_stop = self.hass.bus.async_listen(
EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
self._writer_task = self.hass.async_create_task(self._writer())
auth = AuthPhase(self._logger, self.hass, self._send_message, request)
connection = None
disconnect_warn = None
try:
self._send_message(auth_required_message())
# Auth Phase
try:
with async_timeout.timeout(10):
msg = await wsock.receive()
except asyncio.TimeoutError:
disconnect_warn = \
'Did not receive auth message within 10 seconds'
raise Disconnect
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
raise Disconnect
elif msg.type != WSMsgType.TEXT:
disconnect_warn = 'Received non-Text message.'
raise Disconnect
try:
msg = msg.json()
except ValueError:
disconnect_warn = 'Received invalid JSON.'
raise Disconnect
self._logger.debug("Received %s", msg)
connection = await auth.async_handle(msg)
# Command phase
while not wsock.closed:
msg = await wsock.receive()
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
break
elif msg.type != WSMsgType.TEXT:
disconnect_warn = 'Received non-Text message.'
break
try:
msg = msg.json()
except ValueError:
disconnect_warn = 'Received invalid JSON.'
break
self._logger.debug("Received %s", msg)
connection.async_handle(msg)
except asyncio.CancelledError:
self._logger.info("Connection closed by client")
except Disconnect:
pass
except Exception: # pylint: disable=broad-except
self._logger.exception("Unexpected error inside websocket API")
finally:
unsub_stop()
if connection is not None:
connection.async_close()
try:
self._to_write.put_nowait(None)
# Make sure all error messages are written before closing
await self._writer_task
except asyncio.QueueFull:
self._writer_task.cancel()
await wsock.close()
if disconnect_warn is None:
self._logger.debug("Disconnected")
else:
self._logger.warning("Disconnected: %s", disconnect_warn)

View File

@ -4,7 +4,9 @@ from unittest.mock import patch
import pytest import pytest
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.components import websocket_api from homeassistant.components.websocket_api.http import URL
from homeassistant.components.websocket_api.auth import (
TYPE_AUTH, TYPE_AUTH_OK, TYPE_AUTH_REQUIRED)
from tests.common import MockUser, CLIENT_ID from tests.common import MockUser, CLIENT_ID
@ -14,41 +16,52 @@ def hass_ws_client(aiohttp_client):
"""Websocket client fixture connected to websocket server.""" """Websocket client fixture connected to websocket server."""
async def create_client(hass, access_token=None): async def create_client(hass, access_token=None):
"""Create a websocket client.""" """Create a websocket client."""
wapi = hass.components.websocket_api
assert await async_setup_component(hass, 'websocket_api') assert await async_setup_component(hass, 'websocket_api')
client = await aiohttp_client(hass.http.app) client = await aiohttp_client(hass.http.app)
patching = None patches = []
if access_token is not None: if access_token is None:
patching = patch('homeassistant.auth.AuthManager.active', patches.append(patch(
return_value=True) 'homeassistant.auth.AuthManager.active', return_value=False))
patching.start() patches.append(patch(
'homeassistant.auth.AuthManager.support_legacy',
return_value=True))
patches.append(patch(
'homeassistant.components.websocket_api.auth.'
'validate_password', return_value=True))
else:
patches.append(patch(
'homeassistant.auth.AuthManager.active', return_value=True))
patches.append(patch(
'homeassistant.components.http.auth.setup_auth'))
for p in patches:
p.start()
try: try:
websocket = await client.ws_connect(wapi.URL) websocket = await client.ws_connect(URL)
auth_resp = await websocket.receive_json() auth_resp = await websocket.receive_json()
assert auth_resp['type'] == TYPE_AUTH_REQUIRED
if auth_resp['type'] == wapi.TYPE_AUTH_OK: if access_token is None:
assert access_token is None, \ await websocket.send_json({
'Access token given but no auth required' 'type': TYPE_AUTH,
return websocket 'api_password': 'bla'
})
assert access_token is not None, \ else:
'Access token required for fixture' await websocket.send_json({
'type': TYPE_AUTH,
await websocket.send_json({ 'access_token': access_token
'type': websocket_api.TYPE_AUTH, })
'access_token': access_token
})
auth_ok = await websocket.receive_json() auth_ok = await websocket.receive_json()
assert auth_ok['type'] == wapi.TYPE_AUTH_OK assert auth_ok['type'] == TYPE_AUTH_OK
finally: finally:
if patching is not None: for p in patches:
patching.stop() p.stop()
# wrap in client # wrap in client
websocket.client = client websocket.client = client

View File

@ -62,7 +62,7 @@ class TestPanelIframe(unittest.TestCase):
panels = self.hass.data[frontend.DATA_PANELS] panels = self.hass.data[frontend.DATA_PANELS]
assert panels.get('router').to_response(self.hass, None) == { assert panels.get('router').to_response() == {
'component_name': 'iframe', 'component_name': 'iframe',
'config': {'url': 'http://192.168.1.1'}, 'config': {'url': 'http://192.168.1.1'},
'icon': 'mdi:network-wireless', 'icon': 'mdi:network-wireless',
@ -70,7 +70,7 @@ class TestPanelIframe(unittest.TestCase):
'url_path': 'router' 'url_path': 'router'
} }
assert panels.get('weather').to_response(self.hass, None) == { assert panels.get('weather').to_response() == {
'component_name': 'iframe', 'component_name': 'iframe',
'config': {'url': 'https://www.wunderground.com/us/ca/san-diego'}, 'config': {'url': 'https://www.wunderground.com/us/ca/san-diego'},
'icon': 'mdi:weather', 'icon': 'mdi:weather',
@ -78,7 +78,7 @@ class TestPanelIframe(unittest.TestCase):
'url_path': 'weather', 'url_path': 'weather',
} }
assert panels.get('api').to_response(self.hass, None) == { assert panels.get('api').to_response() == {
'component_name': 'iframe', 'component_name': 'iframe',
'config': {'url': '/api'}, 'config': {'url': '/api'},
'icon': 'mdi:weather', 'icon': 'mdi:weather',
@ -86,7 +86,7 @@ class TestPanelIframe(unittest.TestCase):
'url_path': 'api', 'url_path': 'api',
} }
assert panels.get('ftp').to_response(self.hass, None) == { assert panels.get('ftp').to_response() == {
'component_name': 'iframe', 'component_name': 'iframe',
'config': {'url': 'ftp://some/ftp'}, 'config': {'url': 'ftp://some/ftp'},
'icon': 'mdi:weather', 'icon': 'mdi:weather',

View File

@ -2,7 +2,8 @@
import pytest import pytest
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.components import websocket_api as wapi from homeassistant.components.websocket_api.http import URL
from homeassistant.components.websocket_api.auth import TYPE_AUTH_REQUIRED
from . import API_PASSWORD from . import API_PASSWORD
@ -24,10 +25,10 @@ def no_auth_websocket_client(hass, loop, aiohttp_client):
})) }))
client = loop.run_until_complete(aiohttp_client(hass.http.app)) client = loop.run_until_complete(aiohttp_client(hass.http.app))
ws = loop.run_until_complete(client.ws_connect(wapi.URL)) ws = loop.run_until_complete(client.ws_connect(URL))
auth_ok = loop.run_until_complete(ws.receive_json()) auth_ok = loop.run_until_complete(ws.receive_json())
assert auth_ok['type'] == wapi.TYPE_AUTH_REQUIRED assert auth_ok['type'] == TYPE_AUTH_REQUIRED
yield ws yield ws

View File

@ -1,7 +1,10 @@
"""Test auth of websocket API.""" """Test auth of websocket API."""
from unittest.mock import patch from unittest.mock import patch
from homeassistant.components import websocket_api as wapi from homeassistant.components.websocket_api.const import URL
from homeassistant.components.websocket_api.auth import (
TYPE_AUTH, TYPE_AUTH_INVALID, TYPE_AUTH_OK, TYPE_AUTH_REQUIRED)
from homeassistant.components.websocket_api import commands from homeassistant.components.websocket_api import commands
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -13,28 +16,29 @@ from . import API_PASSWORD
async def test_auth_via_msg(no_auth_websocket_client): async def test_auth_via_msg(no_auth_websocket_client):
"""Test authenticating.""" """Test authenticating."""
await no_auth_websocket_client.send_json({ await no_auth_websocket_client.send_json({
'type': wapi.TYPE_AUTH, 'type': TYPE_AUTH,
'api_password': API_PASSWORD 'api_password': API_PASSWORD
}) })
msg = await no_auth_websocket_client.receive_json() msg = await no_auth_websocket_client.receive_json()
assert msg['type'] == wapi.TYPE_AUTH_OK assert msg['type'] == TYPE_AUTH_OK
async def test_auth_via_msg_incorrect_pass(no_auth_websocket_client): async def test_auth_via_msg_incorrect_pass(no_auth_websocket_client):
"""Test authenticating.""" """Test authenticating."""
with patch('homeassistant.components.websocket_api.process_wrong_login', with patch('homeassistant.components.websocket_api.auth.'
return_value=mock_coro()) as mock_process_wrong_login: 'process_wrong_login', return_value=mock_coro()) \
as mock_process_wrong_login:
await no_auth_websocket_client.send_json({ await no_auth_websocket_client.send_json({
'type': wapi.TYPE_AUTH, 'type': TYPE_AUTH,
'api_password': API_PASSWORD + 'wrong' 'api_password': API_PASSWORD + 'wrong'
}) })
msg = await no_auth_websocket_client.receive_json() msg = await no_auth_websocket_client.receive_json()
assert mock_process_wrong_login.called assert mock_process_wrong_login.called
assert msg['type'] == wapi.TYPE_AUTH_INVALID assert msg['type'] == TYPE_AUTH_INVALID
assert msg['message'] == 'Invalid access token or password' assert msg['message'] == 'Invalid access token or password'
@ -51,8 +55,8 @@ async def test_pre_auth_only_auth_allowed(no_auth_websocket_client):
msg = await no_auth_websocket_client.receive_json() msg = await no_auth_websocket_client.receive_json()
assert msg['type'] == wapi.TYPE_AUTH_INVALID assert msg['type'] == TYPE_AUTH_INVALID
assert msg['message'].startswith('Message incorrectly formatted') assert msg['message'].startswith('Auth message incorrectly formatted')
async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token): async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
@ -65,19 +69,19 @@ async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
client = await aiohttp_client(hass.http.app) client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws: async with client.ws_connect(URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active: with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True auth_active.return_value = True
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED assert auth_msg['type'] == TYPE_AUTH_REQUIRED
await ws.send_json({ await ws.send_json({
'type': wapi.TYPE_AUTH, 'type': TYPE_AUTH,
'access_token': hass_access_token 'access_token': hass_access_token
}) })
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK assert auth_msg['type'] == TYPE_AUTH_OK
async def test_auth_active_user_inactive(hass, aiohttp_client, async def test_auth_active_user_inactive(hass, aiohttp_client,
@ -94,19 +98,19 @@ async def test_auth_active_user_inactive(hass, aiohttp_client,
client = await aiohttp_client(hass.http.app) client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws: async with client.ws_connect(URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active: with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True auth_active.return_value = True
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED assert auth_msg['type'] == TYPE_AUTH_REQUIRED
await ws.send_json({ await ws.send_json({
'type': wapi.TYPE_AUTH, 'type': TYPE_AUTH,
'access_token': hass_access_token 'access_token': hass_access_token
}) })
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID assert auth_msg['type'] == TYPE_AUTH_INVALID
async def test_auth_active_with_password_not_allow(hass, aiohttp_client): async def test_auth_active_with_password_not_allow(hass, aiohttp_client):
@ -119,19 +123,19 @@ async def test_auth_active_with_password_not_allow(hass, aiohttp_client):
client = await aiohttp_client(hass.http.app) client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws: async with client.ws_connect(URL) as ws:
with patch('homeassistant.auth.AuthManager.active', with patch('homeassistant.auth.AuthManager.active',
return_value=True): return_value=True):
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED assert auth_msg['type'] == TYPE_AUTH_REQUIRED
await ws.send_json({ await ws.send_json({
'type': wapi.TYPE_AUTH, 'type': TYPE_AUTH,
'api_password': API_PASSWORD 'api_password': API_PASSWORD
}) })
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID assert auth_msg['type'] == TYPE_AUTH_INVALID
async def test_auth_legacy_support_with_password(hass, aiohttp_client): async def test_auth_legacy_support_with_password(hass, aiohttp_client):
@ -144,21 +148,21 @@ async def test_auth_legacy_support_with_password(hass, aiohttp_client):
client = await aiohttp_client(hass.http.app) client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws: async with client.ws_connect(URL) as ws:
with patch('homeassistant.auth.AuthManager.active', with patch('homeassistant.auth.AuthManager.active',
return_value=True),\ return_value=True),\
patch('homeassistant.auth.AuthManager.support_legacy', patch('homeassistant.auth.AuthManager.support_legacy',
return_value=True): return_value=True):
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED assert auth_msg['type'] == TYPE_AUTH_REQUIRED
await ws.send_json({ await ws.send_json({
'type': wapi.TYPE_AUTH, 'type': TYPE_AUTH,
'api_password': API_PASSWORD 'api_password': API_PASSWORD
}) })
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK assert auth_msg['type'] == TYPE_AUTH_OK
async def test_auth_with_invalid_token(hass, aiohttp_client): async def test_auth_with_invalid_token(hass, aiohttp_client):
@ -171,16 +175,16 @@ async def test_auth_with_invalid_token(hass, aiohttp_client):
client = await aiohttp_client(hass.http.app) client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws: async with client.ws_connect(URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active: with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True auth_active.return_value = True
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED assert auth_msg['type'] == TYPE_AUTH_REQUIRED
await ws.send_json({ await ws.send_json({
'type': wapi.TYPE_AUTH, 'type': TYPE_AUTH,
'access_token': 'incorrect' 'access_token': 'incorrect'
}) })
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID assert auth_msg['type'] == TYPE_AUTH_INVALID

View File

@ -4,7 +4,10 @@ from unittest.mock import patch
from async_timeout import timeout from async_timeout import timeout
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.components import websocket_api as wapi from homeassistant.components.websocket_api.const import URL
from homeassistant.components.websocket_api.auth import (
TYPE_AUTH, TYPE_AUTH_OK, TYPE_AUTH_REQUIRED
)
from homeassistant.components.websocket_api import const, commands from homeassistant.components.websocket_api import const, commands
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -178,19 +181,19 @@ async def test_call_service_context_with_user(hass, aiohttp_client,
calls = async_mock_service(hass, 'domain_test', 'test_service') calls = async_mock_service(hass, 'domain_test', 'test_service')
client = await aiohttp_client(hass.http.app) client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws: async with client.ws_connect(URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active: with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True auth_active.return_value = True
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED assert auth_msg['type'] == TYPE_AUTH_REQUIRED
await ws.send_json({ await ws.send_json({
'type': wapi.TYPE_AUTH, 'type': TYPE_AUTH,
'access_token': hass_access_token 'access_token': hass_access_token
}) })
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK assert auth_msg['type'] == TYPE_AUTH_OK
await ws.send_json({ await ws.send_json({
'id': 5, 'id': 5,
@ -227,17 +230,17 @@ async def test_call_service_context_no_user(hass, aiohttp_client):
calls = async_mock_service(hass, 'domain_test', 'test_service') calls = async_mock_service(hass, 'domain_test', 'test_service')
client = await aiohttp_client(hass.http.app) client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws: async with client.ws_connect(URL) as ws:
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED assert auth_msg['type'] == TYPE_AUTH_REQUIRED
await ws.send_json({ await ws.send_json({
'type': wapi.TYPE_AUTH, 'type': TYPE_AUTH,
'api_password': API_PASSWORD 'api_password': API_PASSWORD
}) })
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK assert auth_msg['type'] == TYPE_AUTH_OK
await ws.send_json({ await ws.send_json({
'id': 5, 'id': 5,

View File

@ -5,14 +5,14 @@ from unittest.mock import patch, Mock
from aiohttp import WSMsgType from aiohttp import WSMsgType
import pytest import pytest
from homeassistant.components import websocket_api as wapi
from homeassistant.components.websocket_api import const, commands, messages from homeassistant.components.websocket_api import const, commands, messages
@pytest.fixture @pytest.fixture
def mock_low_queue(): def mock_low_queue():
"""Mock a low queue.""" """Mock a low queue."""
with patch.object(wapi, 'MAX_PENDING_MSG', 5): with patch('homeassistant.components.websocket_api.http.MAX_PENDING_MSG',
5):
yield yield