Break up websocket 2 (#17028)
* Break up websocket 2 * Lint+Test * Lintttt * Renamepull/17030/merge
parent
b5e3d8c337
commit
2e6346ca43
|
@ -432,7 +432,7 @@ def websocket_current_user(
|
||||||
"""Get current user."""
|
"""Get current user."""
|
||||||
enabled_modules = await hass.auth.async_get_enabled_mfa(user)
|
enabled_modules = await hass.auth.async_get_enabled_mfa(user)
|
||||||
|
|
||||||
connection.send_message_outside(
|
connection.send_message(
|
||||||
websocket_api.result_message(msg['id'], {
|
websocket_api.result_message(msg['id'], {
|
||||||
'id': user.id,
|
'id': user.id,
|
||||||
'name': user.name,
|
'name': user.name,
|
||||||
|
@ -467,7 +467,7 @@ def websocket_create_long_lived_access_token(
|
||||||
access_token = hass.auth.async_create_access_token(
|
access_token = hass.auth.async_create_access_token(
|
||||||
refresh_token)
|
refresh_token)
|
||||||
|
|
||||||
connection.send_message_outside(
|
connection.send_message(
|
||||||
websocket_api.result_message(msg['id'], access_token))
|
websocket_api.result_message(msg['id'], access_token))
|
||||||
|
|
||||||
hass.async_create_task(
|
hass.async_create_task(
|
||||||
|
@ -479,8 +479,8 @@ def websocket_create_long_lived_access_token(
|
||||||
def websocket_refresh_tokens(
|
def websocket_refresh_tokens(
|
||||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg):
|
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg):
|
||||||
"""Return metadata of users refresh tokens."""
|
"""Return metadata of users refresh tokens."""
|
||||||
current_id = connection.request.get('refresh_token_id')
|
current_id = connection.refresh_token_id
|
||||||
connection.to_write.put_nowait(websocket_api.result_message(msg['id'], [{
|
connection.send_message(websocket_api.result_message(msg['id'], [{
|
||||||
'id': refresh.id,
|
'id': refresh.id,
|
||||||
'client_id': refresh.client_id,
|
'client_id': refresh.client_id,
|
||||||
'client_name': refresh.client_name,
|
'client_name': refresh.client_name,
|
||||||
|
@ -508,7 +508,7 @@ def websocket_delete_refresh_token(
|
||||||
|
|
||||||
await hass.auth.async_remove_refresh_token(refresh_token)
|
await hass.auth.async_remove_refresh_token(refresh_token)
|
||||||
|
|
||||||
connection.send_message_outside(
|
connection.send_message(
|
||||||
websocket_api.result_message(msg['id'], {}))
|
websocket_api.result_message(msg['id'], {}))
|
||||||
|
|
||||||
hass.async_create_task(
|
hass.async_create_task(
|
||||||
|
|
|
@ -64,7 +64,7 @@ def websocket_setup_mfa(
|
||||||
if flow_id is not None:
|
if flow_id is not None:
|
||||||
result = await flow_manager.async_configure(
|
result = await flow_manager.async_configure(
|
||||||
flow_id, msg.get('user_input'))
|
flow_id, msg.get('user_input'))
|
||||||
connection.send_message_outside(
|
connection.send_message(
|
||||||
websocket_api.result_message(
|
websocket_api.result_message(
|
||||||
msg['id'], _prepare_result_json(result)))
|
msg['id'], _prepare_result_json(result)))
|
||||||
return
|
return
|
||||||
|
@ -72,7 +72,7 @@ def websocket_setup_mfa(
|
||||||
mfa_module_id = msg.get('mfa_module_id')
|
mfa_module_id = msg.get('mfa_module_id')
|
||||||
mfa_module = hass.auth.get_auth_mfa_module(mfa_module_id)
|
mfa_module = hass.auth.get_auth_mfa_module(mfa_module_id)
|
||||||
if mfa_module is None:
|
if mfa_module is None:
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], 'no_module',
|
msg['id'], 'no_module',
|
||||||
'MFA module {} is not found'.format(mfa_module_id)))
|
'MFA module {} is not found'.format(mfa_module_id)))
|
||||||
return
|
return
|
||||||
|
@ -80,7 +80,7 @@ def websocket_setup_mfa(
|
||||||
result = await flow_manager.async_init(
|
result = await flow_manager.async_init(
|
||||||
mfa_module_id, data={'user_id': connection.user.id})
|
mfa_module_id, data={'user_id': connection.user.id})
|
||||||
|
|
||||||
connection.send_message_outside(
|
connection.send_message(
|
||||||
websocket_api.result_message(
|
websocket_api.result_message(
|
||||||
msg['id'], _prepare_result_json(result)))
|
msg['id'], _prepare_result_json(result)))
|
||||||
|
|
||||||
|
@ -99,13 +99,13 @@ def websocket_depose_mfa(
|
||||||
await hass.auth.async_disable_user_mfa(
|
await hass.auth.async_disable_user_mfa(
|
||||||
connection.user, msg['mfa_module_id'])
|
connection.user, msg['mfa_module_id'])
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], 'disable_failed',
|
msg['id'], 'disable_failed',
|
||||||
'Cannot disable MFA Module {}: {}'.format(
|
'Cannot disable MFA Module {}: {}'.format(
|
||||||
mfa_module_id, err)))
|
mfa_module_id, err)))
|
||||||
return
|
return
|
||||||
|
|
||||||
connection.send_message_outside(
|
connection.send_message(
|
||||||
websocket_api.result_message(
|
websocket_api.result_message(
|
||||||
msg['id'], 'done'))
|
msg['id'], 'done'))
|
||||||
|
|
||||||
|
|
|
@ -460,14 +460,14 @@ async def websocket_camera_thumbnail(hass, connection, msg):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
image = await async_get_image(hass, msg['entity_id'])
|
image = await async_get_image(hass, msg['entity_id'])
|
||||||
connection.send_message_outside(websocket_api.result_message(
|
connection.send_message(websocket_api.result_message(
|
||||||
msg['id'], {
|
msg['id'], {
|
||||||
'content_type': image.content_type,
|
'content_type': image.content_type,
|
||||||
'content': base64.b64encode(image.content).decode('utf-8')
|
'content': base64.b64encode(image.content).decode('utf-8')
|
||||||
}
|
}
|
||||||
))
|
))
|
||||||
except HomeAssistantError:
|
except HomeAssistantError:
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], 'image_fetch_failed', 'Unable to fetch image'))
|
msg['id'], 'image_fetch_failed', 'Unable to fetch image'))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -231,7 +231,7 @@ def websocket_cloud_status(hass, connection, msg):
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
cloud = hass.data[DOMAIN]
|
cloud = hass.data[DOMAIN]
|
||||||
connection.to_write.put_nowait(
|
connection.send_message(
|
||||||
websocket_api.result_message(msg['id'], _account_data(cloud)))
|
websocket_api.result_message(msg['id'], _account_data(cloud)))
|
||||||
|
|
||||||
|
|
||||||
|
@ -241,7 +241,7 @@ async def websocket_subscription(hass, connection, msg):
|
||||||
cloud = hass.data[DOMAIN]
|
cloud = hass.data[DOMAIN]
|
||||||
|
|
||||||
if not cloud.is_logged_in:
|
if not cloud.is_logged_in:
|
||||||
connection.to_write.put_nowait(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], 'not_logged_in',
|
msg['id'], 'not_logged_in',
|
||||||
'You need to be logged in to the cloud.'))
|
'You need to be logged in to the cloud.'))
|
||||||
return
|
return
|
||||||
|
@ -250,10 +250,10 @@ async def websocket_subscription(hass, connection, msg):
|
||||||
response = await cloud.fetch_subscription_info()
|
response = await cloud.fetch_subscription_info()
|
||||||
|
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
connection.send_message_outside(websocket_api.result_message(
|
connection.send_message(websocket_api.result_message(
|
||||||
msg['id'], await response.json()))
|
msg['id'], await response.json()))
|
||||||
else:
|
else:
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], 'request_failed', 'Failed to request subscription'))
|
msg['id'], 'request_failed', 'Failed to request subscription'))
|
||||||
|
|
||||||
|
|
||||||
|
@ -263,7 +263,7 @@ async def websocket_update_prefs(hass, connection, msg):
|
||||||
cloud = hass.data[DOMAIN]
|
cloud = hass.data[DOMAIN]
|
||||||
|
|
||||||
if not cloud.is_logged_in:
|
if not cloud.is_logged_in:
|
||||||
connection.to_write.put_nowait(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], 'not_logged_in',
|
msg['id'], 'not_logged_in',
|
||||||
'You need to be logged in to the cloud.'))
|
'You need to be logged in to the cloud.'))
|
||||||
return
|
return
|
||||||
|
@ -273,7 +273,7 @@ async def websocket_update_prefs(hass, connection, msg):
|
||||||
changes.pop('type')
|
changes.pop('type')
|
||||||
await cloud.update_preferences(**changes)
|
await cloud.update_preferences(**changes)
|
||||||
|
|
||||||
connection.send_message_outside(websocket_api.result_message(
|
connection.send_message(websocket_api.result_message(
|
||||||
msg['id'], {'success': True}))
|
msg['id'], {'success': True}))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ def websocket_list(hass, connection, msg):
|
||||||
"""Send users."""
|
"""Send users."""
|
||||||
result = [_user_info(u) for u in await hass.auth.async_get_users()]
|
result = [_user_info(u) for u in await hass.auth.async_get_users()]
|
||||||
|
|
||||||
connection.send_message_outside(
|
connection.send_message(
|
||||||
websocket_api.result_message(msg['id'], result))
|
websocket_api.result_message(msg['id'], result))
|
||||||
|
|
||||||
hass.async_add_job(send_users())
|
hass.async_add_job(send_users())
|
||||||
|
@ -61,8 +61,8 @@ def websocket_delete(hass, connection, msg):
|
||||||
"""Delete a user."""
|
"""Delete a user."""
|
||||||
async def delete_user():
|
async def delete_user():
|
||||||
"""Delete user."""
|
"""Delete user."""
|
||||||
if msg['user_id'] == connection.request.get('hass_user').id:
|
if msg['user_id'] == connection.user.id:
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], 'no_delete_self',
|
msg['id'], 'no_delete_self',
|
||||||
'Unable to delete your own account'))
|
'Unable to delete your own account'))
|
||||||
return
|
return
|
||||||
|
@ -70,13 +70,13 @@ def websocket_delete(hass, connection, msg):
|
||||||
user = await hass.auth.async_get_user(msg['user_id'])
|
user = await hass.auth.async_get_user(msg['user_id'])
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], 'not_found', 'User not found'))
|
msg['id'], 'not_found', 'User not found'))
|
||||||
return
|
return
|
||||||
|
|
||||||
await hass.auth.async_remove_user(user)
|
await hass.auth.async_remove_user(user)
|
||||||
|
|
||||||
connection.send_message_outside(
|
connection.send_message(
|
||||||
websocket_api.result_message(msg['id']))
|
websocket_api.result_message(msg['id']))
|
||||||
|
|
||||||
hass.async_add_job(delete_user())
|
hass.async_add_job(delete_user())
|
||||||
|
@ -90,7 +90,7 @@ def websocket_create(hass, connection, msg):
|
||||||
"""Create a user."""
|
"""Create a user."""
|
||||||
user = await hass.auth.async_create_user(msg['name'])
|
user = await hass.auth.async_create_user(msg['name'])
|
||||||
|
|
||||||
connection.send_message_outside(
|
connection.send_message(
|
||||||
websocket_api.result_message(msg['id'], {
|
websocket_api.result_message(msg['id'], {
|
||||||
'user': _user_info(user)
|
'user': _user_info(user)
|
||||||
}))
|
}))
|
||||||
|
|
|
@ -2,7 +2,6 @@
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.auth.providers import homeassistant as auth_ha
|
from homeassistant.auth.providers import homeassistant as auth_ha
|
||||||
from homeassistant.core import callback
|
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.components.websocket_api.decorators import require_owner
|
from homeassistant.components.websocket_api.decorators import require_owner
|
||||||
|
|
||||||
|
@ -55,121 +54,109 @@ def _get_provider(hass):
|
||||||
raise RuntimeError('Provider not found')
|
raise RuntimeError('Provider not found')
|
||||||
|
|
||||||
|
|
||||||
@callback
|
|
||||||
@require_owner
|
@require_owner
|
||||||
def websocket_create(hass, connection, msg):
|
@websocket_api.async_response
|
||||||
|
async def websocket_create(hass, connection, msg):
|
||||||
"""Create credentials and attach to a user."""
|
"""Create credentials and attach to a user."""
|
||||||
async def create_creds():
|
provider = _get_provider(hass)
|
||||||
"""Create credentials."""
|
await provider.async_initialize()
|
||||||
provider = _get_provider(hass)
|
|
||||||
await provider.async_initialize()
|
|
||||||
|
|
||||||
user = await hass.auth.async_get_user(msg['user_id'])
|
user = await hass.auth.async_get_user(msg['user_id'])
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], 'not_found', 'User not found'))
|
msg['id'], 'not_found', 'User not found'))
|
||||||
return
|
return
|
||||||
|
|
||||||
if user.system_generated:
|
if user.system_generated:
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], 'system_generated',
|
msg['id'], 'system_generated',
|
||||||
'Cannot add credentials to a system generated user.'))
|
'Cannot add credentials to a system generated user.'))
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
|
||||||
await hass.async_add_executor_job(
|
|
||||||
provider.data.add_auth, msg['username'], msg['password'])
|
|
||||||
except auth_ha.InvalidUser:
|
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
|
||||||
msg['id'], 'username_exists', 'Username already exists'))
|
|
||||||
return
|
|
||||||
|
|
||||||
credentials = await provider.async_get_or_create_credentials({
|
|
||||||
'username': msg['username']
|
|
||||||
})
|
|
||||||
await hass.auth.async_link_user(user, credentials)
|
|
||||||
|
|
||||||
await provider.data.async_save()
|
|
||||||
connection.to_write.put_nowait(websocket_api.result_message(msg['id']))
|
|
||||||
|
|
||||||
hass.async_add_job(create_creds())
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
|
||||||
@require_owner
|
|
||||||
def websocket_delete(hass, connection, msg):
|
|
||||||
"""Delete username and related credential."""
|
|
||||||
async def delete_creds():
|
|
||||||
"""Delete user credentials."""
|
|
||||||
provider = _get_provider(hass)
|
|
||||||
await provider.async_initialize()
|
|
||||||
|
|
||||||
credentials = await provider.async_get_or_create_credentials({
|
|
||||||
'username': msg['username']
|
|
||||||
})
|
|
||||||
|
|
||||||
# if not new, an existing credential exists.
|
|
||||||
# Removing the credential will also remove the auth.
|
|
||||||
if not credentials.is_new:
|
|
||||||
await hass.auth.async_remove_credentials(credentials)
|
|
||||||
|
|
||||||
connection.to_write.put_nowait(
|
|
||||||
websocket_api.result_message(msg['id']))
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
provider.data.async_remove_auth(msg['username'])
|
|
||||||
await provider.data.async_save()
|
|
||||||
except auth_ha.InvalidUser:
|
|
||||||
connection.to_write.put_nowait(websocket_api.error_message(
|
|
||||||
msg['id'], 'auth_not_found', 'Given username was not found.'))
|
|
||||||
return
|
|
||||||
|
|
||||||
connection.to_write.put_nowait(
|
|
||||||
websocket_api.result_message(msg['id']))
|
|
||||||
|
|
||||||
hass.async_add_job(delete_creds())
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def websocket_change_password(hass, connection, msg):
|
|
||||||
"""Change user password."""
|
|
||||||
async def change_password():
|
|
||||||
"""Change user password."""
|
|
||||||
user = connection.request.get('hass_user')
|
|
||||||
if user is None:
|
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
|
||||||
msg['id'], 'user_not_found', 'User not found'))
|
|
||||||
return
|
|
||||||
|
|
||||||
provider = _get_provider(hass)
|
|
||||||
await provider.async_initialize()
|
|
||||||
|
|
||||||
username = None
|
|
||||||
for credential in user.credentials:
|
|
||||||
if credential.auth_provider_type == provider.type:
|
|
||||||
username = credential.data['username']
|
|
||||||
break
|
|
||||||
|
|
||||||
if username is None:
|
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
|
||||||
msg['id'], 'credentials_not_found', 'Credentials not found'))
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
await provider.async_validate_login(
|
|
||||||
username, msg['current_password'])
|
|
||||||
except auth_ha.InvalidAuth:
|
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
|
||||||
msg['id'], 'invalid_password', 'Invalid password'))
|
|
||||||
return
|
|
||||||
|
|
||||||
|
try:
|
||||||
await hass.async_add_executor_job(
|
await hass.async_add_executor_job(
|
||||||
provider.data.change_password, username, msg['new_password'])
|
provider.data.add_auth, msg['username'], msg['password'])
|
||||||
await provider.data.async_save()
|
except auth_ha.InvalidUser:
|
||||||
|
connection.send_message(websocket_api.error_message(
|
||||||
|
msg['id'], 'username_exists', 'Username already exists'))
|
||||||
|
return
|
||||||
|
|
||||||
connection.send_message_outside(
|
credentials = await provider.async_get_or_create_credentials({
|
||||||
|
'username': msg['username']
|
||||||
|
})
|
||||||
|
await hass.auth.async_link_user(user, credentials)
|
||||||
|
|
||||||
|
await provider.data.async_save()
|
||||||
|
connection.send_message(websocket_api.result_message(msg['id']))
|
||||||
|
|
||||||
|
|
||||||
|
@require_owner
|
||||||
|
@websocket_api.async_response
|
||||||
|
async def websocket_delete(hass, connection, msg):
|
||||||
|
"""Delete username and related credential."""
|
||||||
|
provider = _get_provider(hass)
|
||||||
|
await provider.async_initialize()
|
||||||
|
|
||||||
|
credentials = await provider.async_get_or_create_credentials({
|
||||||
|
'username': msg['username']
|
||||||
|
})
|
||||||
|
|
||||||
|
# if not new, an existing credential exists.
|
||||||
|
# Removing the credential will also remove the auth.
|
||||||
|
if not credentials.is_new:
|
||||||
|
await hass.auth.async_remove_credentials(credentials)
|
||||||
|
|
||||||
|
connection.send_message(
|
||||||
websocket_api.result_message(msg['id']))
|
websocket_api.result_message(msg['id']))
|
||||||
|
return
|
||||||
|
|
||||||
hass.async_add_job(change_password())
|
try:
|
||||||
|
provider.data.async_remove_auth(msg['username'])
|
||||||
|
await provider.data.async_save()
|
||||||
|
except auth_ha.InvalidUser:
|
||||||
|
connection.send_message(websocket_api.error_message(
|
||||||
|
msg['id'], 'auth_not_found', 'Given username was not found.'))
|
||||||
|
return
|
||||||
|
|
||||||
|
connection.send_message(
|
||||||
|
websocket_api.result_message(msg['id']))
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.async_response
|
||||||
|
async def websocket_change_password(hass, connection, msg):
|
||||||
|
"""Change user password."""
|
||||||
|
user = connection.user
|
||||||
|
if user is None:
|
||||||
|
connection.send_message(websocket_api.error_message(
|
||||||
|
msg['id'], 'user_not_found', 'User not found'))
|
||||||
|
return
|
||||||
|
|
||||||
|
provider = _get_provider(hass)
|
||||||
|
await provider.async_initialize()
|
||||||
|
|
||||||
|
username = None
|
||||||
|
for credential in user.credentials:
|
||||||
|
if credential.auth_provider_type == provider.type:
|
||||||
|
username = credential.data['username']
|
||||||
|
break
|
||||||
|
|
||||||
|
if username is None:
|
||||||
|
connection.send_message(websocket_api.error_message(
|
||||||
|
msg['id'], 'credentials_not_found', 'Credentials not found'))
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await provider.async_validate_login(
|
||||||
|
username, msg['current_password'])
|
||||||
|
except auth_ha.InvalidAuth:
|
||||||
|
connection.send_message(websocket_api.error_message(
|
||||||
|
msg['id'], 'invalid_password', 'Invalid password'))
|
||||||
|
return
|
||||||
|
|
||||||
|
await hass.async_add_executor_job(
|
||||||
|
provider.data.change_password, username, msg['new_password'])
|
||||||
|
await provider.data.async_save()
|
||||||
|
|
||||||
|
connection.send_message(
|
||||||
|
websocket_api.result_message(msg['id']))
|
||||||
|
|
|
@ -31,7 +31,7 @@ def websocket_list_devices(hass, connection, msg):
|
||||||
async def retrieve_entities():
|
async def retrieve_entities():
|
||||||
"""Get devices from registry."""
|
"""Get devices from registry."""
|
||||||
registry = await async_get_registry(hass)
|
registry = await async_get_registry(hass)
|
||||||
connection.send_message_outside(websocket_api.result_message(
|
connection.send_message(websocket_api.result_message(
|
||||||
msg['id'], [{
|
msg['id'], [{
|
||||||
'config_entries': list(entry.config_entries),
|
'config_entries': list(entry.config_entries),
|
||||||
'connections': list(entry.connections),
|
'connections': list(entry.connections),
|
||||||
|
|
|
@ -55,7 +55,7 @@ async def websocket_list_entities(hass, connection, msg):
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
registry = await async_get_registry(hass)
|
registry = await async_get_registry(hass)
|
||||||
connection.send_message_outside(websocket_api.result_message(
|
connection.send_message(websocket_api.result_message(
|
||||||
msg['id'], [{
|
msg['id'], [{
|
||||||
'config_entry_id': entry.config_entry_id,
|
'config_entry_id': entry.config_entry_id,
|
||||||
'device_id': entry.device_id,
|
'device_id': entry.device_id,
|
||||||
|
@ -77,11 +77,11 @@ async def websocket_get_entity(hass, connection, msg):
|
||||||
entry = registry.entities.get(msg['entity_id'])
|
entry = registry.entities.get(msg['entity_id'])
|
||||||
|
|
||||||
if entry is None:
|
if entry is None:
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], ERR_NOT_FOUND, 'Entity not found'))
|
msg['id'], ERR_NOT_FOUND, 'Entity not found'))
|
||||||
return
|
return
|
||||||
|
|
||||||
connection.send_message_outside(websocket_api.result_message(
|
connection.send_message(websocket_api.result_message(
|
||||||
msg['id'], _entry_dict(entry)
|
msg['id'], _entry_dict(entry)
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ async def websocket_update_entity(hass, connection, msg):
|
||||||
registry = await async_get_registry(hass)
|
registry = await async_get_registry(hass)
|
||||||
|
|
||||||
if msg['entity_id'] not in registry.entities:
|
if msg['entity_id'] not in registry.entities:
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], ERR_NOT_FOUND, 'Entity not found'))
|
msg['id'], ERR_NOT_FOUND, 'Entity not found'))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -112,11 +112,11 @@ async def websocket_update_entity(hass, connection, msg):
|
||||||
entry = registry.async_update_entity(
|
entry = registry.async_update_entity(
|
||||||
msg['entity_id'], **changes)
|
msg['entity_id'], **changes)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], 'invalid_info', str(err)
|
msg['id'], 'invalid_info', str(err)
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
connection.send_message_outside(websocket_api.result_message(
|
connection.send_message(websocket_api.result_message(
|
||||||
msg['id'], _entry_dict(entry)
|
msg['id'], _entry_dict(entry)
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
|
@ -145,7 +145,7 @@ class Panel:
|
||||||
index_view.get)
|
index_view.get)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def to_response(self, hass, request):
|
def to_response(self):
|
||||||
"""Panel as dictionary."""
|
"""Panel as dictionary."""
|
||||||
return {
|
return {
|
||||||
'component_name': self.component_name,
|
'component_name': self.component_name,
|
||||||
|
@ -485,12 +485,10 @@ def websocket_get_panels(hass, connection, msg):
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
panels = {
|
panels = {
|
||||||
panel:
|
panel: connection.hass.data[DATA_PANELS][panel].to_response()
|
||||||
connection.hass.data[DATA_PANELS][panel].to_response(
|
|
||||||
connection.hass, connection.request)
|
|
||||||
for panel in connection.hass.data[DATA_PANELS]}
|
for panel in connection.hass.data[DATA_PANELS]}
|
||||||
|
|
||||||
connection.to_write.put_nowait(websocket_api.result_message(
|
connection.send_message(websocket_api.result_message(
|
||||||
msg['id'], panels))
|
msg['id'], panels))
|
||||||
|
|
||||||
|
|
||||||
|
@ -500,25 +498,21 @@ def websocket_get_themes(hass, connection, msg):
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
connection.to_write.put_nowait(websocket_api.result_message(msg['id'], {
|
connection.send_message(websocket_api.result_message(msg['id'], {
|
||||||
'themes': hass.data[DATA_THEMES],
|
'themes': hass.data[DATA_THEMES],
|
||||||
'default_theme': hass.data[DATA_DEFAULT_THEME],
|
'default_theme': hass.data[DATA_DEFAULT_THEME],
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@websocket_api.async_response
|
||||||
def websocket_get_translations(hass, connection, msg):
|
async def websocket_get_translations(hass, connection, msg):
|
||||||
"""Handle get translations command.
|
"""Handle get translations command.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
async def send_translations():
|
resources = await async_get_translations(hass, msg['language'])
|
||||||
"""Send a translation."""
|
connection.send_message(websocket_api.result_message(
|
||||||
resources = await async_get_translations(hass, msg['language'])
|
msg['id'], {
|
||||||
connection.send_message_outside(websocket_api.result_message(
|
'resources': resources,
|
||||||
msg['id'], {
|
}
|
||||||
'resources': resources,
|
))
|
||||||
}
|
|
||||||
))
|
|
||||||
|
|
||||||
hass.async_add_job(send_translations())
|
|
||||||
|
|
|
@ -112,6 +112,7 @@ async def async_validate_auth_header(request, api_password=None):
|
||||||
if refresh_token is None:
|
if refresh_token is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
request['hass_refresh_token'] = refresh_token
|
||||||
request['hass_user'] = refresh_token.user
|
request['hass_user'] = refresh_token.user
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
@ -48,4 +48,4 @@ async def websocket_lovelace_config(hass, connection, msg):
|
||||||
if error is not None:
|
if error is not None:
|
||||||
message = websocket_api.error_message(msg['id'], *error)
|
message = websocket_api.error_message(msg['id'], *error)
|
||||||
|
|
||||||
connection.send_message_outside(message)
|
connection.send_message(message)
|
||||||
|
|
|
@ -874,19 +874,19 @@ async def websocket_handle_thumbnail(hass, connection, msg):
|
||||||
player = component.get_entity(msg['entity_id'])
|
player = component.get_entity(msg['entity_id'])
|
||||||
|
|
||||||
if player is None:
|
if player is None:
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], 'entity_not_found', 'Entity not found'))
|
msg['id'], 'entity_not_found', 'Entity not found'))
|
||||||
return
|
return
|
||||||
|
|
||||||
data, content_type = await player.async_get_media_image()
|
data, content_type = await player.async_get_media_image()
|
||||||
|
|
||||||
if data is None:
|
if data is None:
|
||||||
connection.send_message_outside(websocket_api.error_message(
|
connection.send_message(websocket_api.error_message(
|
||||||
msg['id'], 'thumbnail_fetch_failed',
|
msg['id'], 'thumbnail_fetch_failed',
|
||||||
'Failed to fetch thumbnail'))
|
'Failed to fetch thumbnail'))
|
||||||
return
|
return
|
||||||
|
|
||||||
connection.send_message_outside(websocket_api.result_message(
|
connection.send_message(websocket_api.result_message(
|
||||||
msg['id'], {
|
msg['id'], {
|
||||||
'content_type': content_type,
|
'content_type': content_type,
|
||||||
'content': base64.b64encode(data).decode('utf-8')
|
'content': base64.b64encode(data).decode('utf-8')
|
||||||
|
|
|
@ -199,7 +199,7 @@ async def async_setup(hass: HomeAssistant, config: dict) -> Awaitable[bool]:
|
||||||
def websocket_get_notifications(
|
def websocket_get_notifications(
|
||||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg):
|
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg):
|
||||||
"""Return a list of persistent_notifications."""
|
"""Return a list of persistent_notifications."""
|
||||||
connection.to_write.put_nowait(
|
connection.send_message(
|
||||||
websocket_api.result_message(msg['id'], [
|
websocket_api.result_message(msg['id'], [
|
||||||
{
|
{
|
||||||
key: data[key] for key in (ATTR_NOTIFICATION_ID, ATTR_MESSAGE,
|
key: data[key] for key in (ATTR_NOTIFICATION_ID, ATTR_MESSAGE,
|
||||||
|
|
|
@ -4,49 +4,18 @@ Websocket based API for Home Assistant.
|
||||||
For more details about this component, please refer to the documentation at
|
For more details about this component, please refer to the documentation at
|
||||||
https://developers.home-assistant.io/docs/external_api_websocket.html
|
https://developers.home-assistant.io/docs/external_api_websocket.html
|
||||||
"""
|
"""
|
||||||
import asyncio
|
from homeassistant.core import callback
|
||||||
from concurrent import futures
|
|
||||||
from contextlib import suppress
|
|
||||||
from functools import partial
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from aiohttp import web
|
|
||||||
import voluptuous as vol
|
|
||||||
from voluptuous.humanize import humanize_error
|
|
||||||
|
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, __version__
|
|
||||||
from homeassistant.core import Context, callback
|
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
from homeassistant.helpers.json import JSONEncoder
|
|
||||||
from homeassistant.components.http import HomeAssistantView
|
|
||||||
from homeassistant.components.http.auth import validate_password
|
|
||||||
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
|
||||||
from homeassistant.components.http.ban import process_wrong_login, \
|
|
||||||
process_success_login
|
|
||||||
|
|
||||||
from . import commands, const, decorators, messages
|
from . import commands, connection, const, decorators, http, messages
|
||||||
|
|
||||||
DOMAIN = 'websocket_api'
|
DOMAIN = const.DOMAIN
|
||||||
|
|
||||||
URL = '/api/websocket'
|
|
||||||
DEPENDENCIES = ('http',)
|
DEPENDENCIES = ('http',)
|
||||||
|
|
||||||
MAX_PENDING_MSG = 512
|
# Backwards compat / Make it easier to integrate
|
||||||
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
JSON_DUMP = partial(json.dumps, cls=JSONEncoder)
|
|
||||||
|
|
||||||
TYPE_AUTH = 'auth'
|
|
||||||
TYPE_AUTH_INVALID = 'auth_invalid'
|
|
||||||
TYPE_AUTH_OK = 'auth_ok'
|
|
||||||
TYPE_AUTH_REQUIRED = 'auth_required'
|
|
||||||
|
|
||||||
|
|
||||||
# Backwards compat
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
|
ActiveConnection = connection.ActiveConnection
|
||||||
BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA
|
BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA
|
||||||
error_message = messages.error_message
|
error_message = messages.error_message
|
||||||
result_message = messages.result_message
|
result_message = messages.result_message
|
||||||
|
@ -54,42 +23,6 @@ async_response = decorators.async_response
|
||||||
ws_require_user = decorators.ws_require_user
|
ws_require_user = decorators.ws_require_user
|
||||||
# pylint: enable=invalid-name
|
# pylint: enable=invalid-name
|
||||||
|
|
||||||
AUTH_MESSAGE_SCHEMA = vol.Schema({
|
|
||||||
vol.Required('type'): TYPE_AUTH,
|
|
||||||
vol.Exclusive('api_password', 'auth'): str,
|
|
||||||
vol.Exclusive('access_token', 'auth'): str,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
# Define the possible errors that occur when connections are cancelled.
|
|
||||||
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
|
|
||||||
# that futures.CancelledErrors can also occur in some situations.
|
|
||||||
CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError)
|
|
||||||
|
|
||||||
|
|
||||||
def auth_ok_message():
|
|
||||||
"""Return an auth_ok message."""
|
|
||||||
return {
|
|
||||||
'type': TYPE_AUTH_OK,
|
|
||||||
'ha_version': __version__,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def auth_required_message():
|
|
||||||
"""Return an auth_required message."""
|
|
||||||
return {
|
|
||||||
'type': TYPE_AUTH_REQUIRED,
|
|
||||||
'ha_version': __version__,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def auth_invalid_message(message):
|
|
||||||
"""Return an auth_invalid message."""
|
|
||||||
return {
|
|
||||||
'type': TYPE_AUTH_INVALID,
|
|
||||||
'message': message,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
@callback
|
@callback
|
||||||
|
@ -103,255 +36,6 @@ def async_register_command(hass, command, handler, schema):
|
||||||
|
|
||||||
async def async_setup(hass, config):
|
async def async_setup(hass, config):
|
||||||
"""Initialize the websocket API."""
|
"""Initialize the websocket API."""
|
||||||
hass.http.register_view(WebsocketAPIView)
|
hass.http.register_view(http.WebsocketAPIView)
|
||||||
commands.async_register_commands(hass)
|
commands.async_register_commands(hass)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class WebsocketAPIView(HomeAssistantView):
|
|
||||||
"""View to serve a websockets endpoint."""
|
|
||||||
|
|
||||||
name = "websocketapi"
|
|
||||||
url = URL
|
|
||||||
requires_auth = False
|
|
||||||
|
|
||||||
async def get(self, request):
|
|
||||||
"""Handle an incoming websocket connection."""
|
|
||||||
return await ActiveConnection(request.app['hass'], request).handle()
|
|
||||||
|
|
||||||
|
|
||||||
class ActiveConnection:
|
|
||||||
"""Handle an active websocket client connection."""
|
|
||||||
|
|
||||||
def __init__(self, hass, request):
|
|
||||||
"""Initialize an active connection."""
|
|
||||||
self.hass = hass
|
|
||||||
self.request = request
|
|
||||||
self.wsock = None
|
|
||||||
self.event_listeners = {}
|
|
||||||
self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop)
|
|
||||||
self._handle_task = None
|
|
||||||
self._writer_task = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def user(self):
|
|
||||||
"""Return the user associated with the connection."""
|
|
||||||
return self.request.get('hass_user')
|
|
||||||
|
|
||||||
def context(self, msg):
|
|
||||||
"""Return a context."""
|
|
||||||
user = self.user
|
|
||||||
if user is None:
|
|
||||||
return Context()
|
|
||||||
return Context(user_id=user.id)
|
|
||||||
|
|
||||||
def debug(self, message1, message2=''):
|
|
||||||
"""Print a debug message."""
|
|
||||||
_LOGGER.debug("WS %s: %s %s", id(self.wsock), message1, message2)
|
|
||||||
|
|
||||||
def log_error(self, message1, message2=''):
|
|
||||||
"""Print an error message."""
|
|
||||||
_LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2)
|
|
||||||
|
|
||||||
async def _writer(self):
|
|
||||||
"""Write outgoing messages."""
|
|
||||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
|
||||||
with suppress(RuntimeError, *CANCELLATION_ERRORS):
|
|
||||||
while not self.wsock.closed:
|
|
||||||
message = await self.to_write.get()
|
|
||||||
if message is None:
|
|
||||||
break
|
|
||||||
self.debug("Sending", message)
|
|
||||||
try:
|
|
||||||
await self.wsock.send_json(message, dumps=JSON_DUMP)
|
|
||||||
except TypeError as err:
|
|
||||||
_LOGGER.error('Unable to serialize to JSON: %s\n%s',
|
|
||||||
err, message)
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def send_message_outside(self, message):
|
|
||||||
"""Send a message to the client.
|
|
||||||
|
|
||||||
Closes connection if the client is not reading the messages.
|
|
||||||
|
|
||||||
Async friendly.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
self.to_write.put_nowait(message)
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
self.log_error("Client exceeded max pending messages [2]:",
|
|
||||||
MAX_PENDING_MSG)
|
|
||||||
self.cancel()
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def cancel(self):
|
|
||||||
"""Cancel the connection."""
|
|
||||||
self._handle_task.cancel()
|
|
||||||
self._writer_task.cancel()
|
|
||||||
|
|
||||||
async def handle(self):
|
|
||||||
"""Handle the websocket connection."""
|
|
||||||
request = self.request
|
|
||||||
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
|
|
||||||
await wsock.prepare(request)
|
|
||||||
self.debug("Connected")
|
|
||||||
|
|
||||||
self._handle_task = asyncio.Task.current_task(loop=self.hass.loop)
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def handle_hass_stop(event):
|
|
||||||
"""Cancel this connection."""
|
|
||||||
self.cancel()
|
|
||||||
|
|
||||||
unsub_stop = self.hass.bus.async_listen(
|
|
||||||
EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
|
|
||||||
self._writer_task = self.hass.async_add_job(self._writer())
|
|
||||||
final_message = None
|
|
||||||
msg = None
|
|
||||||
authenticated = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
if request[KEY_AUTHENTICATED]:
|
|
||||||
authenticated = True
|
|
||||||
|
|
||||||
# always request auth when auth is active
|
|
||||||
# even request passed pre-authentication (trusted networks)
|
|
||||||
# or when using legacy api_password
|
|
||||||
if self.hass.auth.active or not authenticated:
|
|
||||||
self.debug("Request auth")
|
|
||||||
await self.wsock.send_json(auth_required_message())
|
|
||||||
msg = await wsock.receive_json()
|
|
||||||
msg = AUTH_MESSAGE_SCHEMA(msg)
|
|
||||||
|
|
||||||
if self.hass.auth.active and 'access_token' in msg:
|
|
||||||
self.debug("Received access_token")
|
|
||||||
refresh_token = \
|
|
||||||
await self.hass.auth.async_validate_access_token(
|
|
||||||
msg['access_token'])
|
|
||||||
authenticated = refresh_token is not None
|
|
||||||
if authenticated:
|
|
||||||
request['hass_user'] = refresh_token.user
|
|
||||||
request['refresh_token_id'] = refresh_token.id
|
|
||||||
|
|
||||||
elif ((not self.hass.auth.active or
|
|
||||||
self.hass.auth.support_legacy) and
|
|
||||||
'api_password' in msg):
|
|
||||||
self.debug("Received api_password")
|
|
||||||
authenticated = validate_password(
|
|
||||||
request, msg['api_password'])
|
|
||||||
|
|
||||||
if not authenticated:
|
|
||||||
self.debug("Authorization failed")
|
|
||||||
await self.wsock.send_json(
|
|
||||||
auth_invalid_message('Invalid access token or password'))
|
|
||||||
await process_wrong_login(request)
|
|
||||||
return wsock
|
|
||||||
|
|
||||||
self.debug("Auth OK")
|
|
||||||
await process_success_login(request)
|
|
||||||
await self.wsock.send_json(auth_ok_message())
|
|
||||||
|
|
||||||
# ---------- AUTH PHASE OVER ----------
|
|
||||||
|
|
||||||
msg = await wsock.receive_json()
|
|
||||||
last_id = 0
|
|
||||||
handlers = self.hass.data[DOMAIN]
|
|
||||||
|
|
||||||
while msg:
|
|
||||||
self.debug("Received", msg)
|
|
||||||
msg = messages.MINIMAL_MESSAGE_SCHEMA(msg)
|
|
||||||
cur_id = msg['id']
|
|
||||||
|
|
||||||
if cur_id <= last_id:
|
|
||||||
self.to_write.put_nowait(messages.error_message(
|
|
||||||
cur_id, const.ERR_ID_REUSE,
|
|
||||||
'Identifier values have to increase.'))
|
|
||||||
|
|
||||||
elif msg['type'] not in handlers:
|
|
||||||
self.log_error(
|
|
||||||
'Received invalid command: {}'.format(msg['type']))
|
|
||||||
self.to_write.put_nowait(messages.error_message(
|
|
||||||
cur_id, const.ERR_UNKNOWN_COMMAND,
|
|
||||||
'Unknown command.'))
|
|
||||||
|
|
||||||
else:
|
|
||||||
handler, schema = handlers[msg['type']]
|
|
||||||
try:
|
|
||||||
handler(self.hass, self, schema(msg))
|
|
||||||
except Exception: # pylint: disable=broad-except
|
|
||||||
_LOGGER.exception('Error handling message: %s', msg)
|
|
||||||
self.to_write.put_nowait(messages.error_message(
|
|
||||||
cur_id, const.ERR_UNKNOWN_ERROR,
|
|
||||||
'Unknown error.'))
|
|
||||||
|
|
||||||
last_id = cur_id
|
|
||||||
msg = await wsock.receive_json()
|
|
||||||
|
|
||||||
except vol.Invalid as err:
|
|
||||||
error_msg = "Message incorrectly formatted: "
|
|
||||||
if msg:
|
|
||||||
error_msg += humanize_error(msg, err)
|
|
||||||
else:
|
|
||||||
error_msg += str(err)
|
|
||||||
|
|
||||||
self.log_error(error_msg)
|
|
||||||
|
|
||||||
if not authenticated:
|
|
||||||
final_message = auth_invalid_message(error_msg)
|
|
||||||
|
|
||||||
else:
|
|
||||||
if isinstance(msg, dict):
|
|
||||||
iden = msg.get('id')
|
|
||||||
else:
|
|
||||||
iden = None
|
|
||||||
|
|
||||||
final_message = messages.error_message(
|
|
||||||
iden, const.ERR_INVALID_FORMAT, error_msg)
|
|
||||||
|
|
||||||
except TypeError as err:
|
|
||||||
if wsock.closed:
|
|
||||||
self.debug("Connection closed by client")
|
|
||||||
else:
|
|
||||||
_LOGGER.exception("Unexpected TypeError: %s", err)
|
|
||||||
|
|
||||||
except ValueError as err:
|
|
||||||
msg = "Received invalid JSON"
|
|
||||||
value = getattr(err, 'doc', None) # Py3.5+ only
|
|
||||||
if value:
|
|
||||||
msg += ': {}'.format(value)
|
|
||||||
self.log_error(msg)
|
|
||||||
self._writer_task.cancel()
|
|
||||||
|
|
||||||
except CANCELLATION_ERRORS:
|
|
||||||
self.debug("Connection cancelled")
|
|
||||||
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
self.log_error("Client exceeded max pending messages [1]:",
|
|
||||||
MAX_PENDING_MSG)
|
|
||||||
self._writer_task.cancel()
|
|
||||||
|
|
||||||
except Exception: # pylint: disable=broad-except
|
|
||||||
error = "Unexpected error inside websocket API. "
|
|
||||||
if msg is not None:
|
|
||||||
error += str(msg)
|
|
||||||
_LOGGER.exception(error)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
unsub_stop()
|
|
||||||
|
|
||||||
for unsub in self.event_listeners.values():
|
|
||||||
unsub()
|
|
||||||
|
|
||||||
try:
|
|
||||||
if final_message is not None:
|
|
||||||
self.to_write.put_nowait(final_message)
|
|
||||||
self.to_write.put_nowait(None)
|
|
||||||
# Make sure all error messages are written before closing
|
|
||||||
await self._writer_task
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
self._writer_task.cancel()
|
|
||||||
|
|
||||||
await wsock.close()
|
|
||||||
self.debug("Closed connection")
|
|
||||||
|
|
||||||
return wsock
|
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
"""Handle the auth of a connection."""
|
||||||
|
import voluptuous as vol
|
||||||
|
from voluptuous.humanize import humanize_error
|
||||||
|
|
||||||
|
from homeassistant.const import __version__
|
||||||
|
from homeassistant.components.http.auth import validate_password
|
||||||
|
from homeassistant.components.http.ban import process_wrong_login, \
|
||||||
|
process_success_login
|
||||||
|
|
||||||
|
from .connection import ActiveConnection
|
||||||
|
from .error import Disconnect
|
||||||
|
|
||||||
|
TYPE_AUTH = 'auth'
|
||||||
|
TYPE_AUTH_INVALID = 'auth_invalid'
|
||||||
|
TYPE_AUTH_OK = 'auth_ok'
|
||||||
|
TYPE_AUTH_REQUIRED = 'auth_required'
|
||||||
|
|
||||||
|
AUTH_MESSAGE_SCHEMA = vol.Schema({
|
||||||
|
vol.Required('type'): TYPE_AUTH,
|
||||||
|
vol.Exclusive('api_password', 'auth'): str,
|
||||||
|
vol.Exclusive('access_token', 'auth'): str,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def auth_ok_message():
|
||||||
|
"""Return an auth_ok message."""
|
||||||
|
return {
|
||||||
|
'type': TYPE_AUTH_OK,
|
||||||
|
'ha_version': __version__,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def auth_required_message():
|
||||||
|
"""Return an auth_required message."""
|
||||||
|
return {
|
||||||
|
'type': TYPE_AUTH_REQUIRED,
|
||||||
|
'ha_version': __version__,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def auth_invalid_message(message):
|
||||||
|
"""Return an auth_invalid message."""
|
||||||
|
return {
|
||||||
|
'type': TYPE_AUTH_INVALID,
|
||||||
|
'message': message,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AuthPhase:
|
||||||
|
"""Connection that requires client to authenticate first."""
|
||||||
|
|
||||||
|
def __init__(self, logger, hass, send_message, request):
|
||||||
|
"""Initialize the authentiated connection."""
|
||||||
|
self._hass = hass
|
||||||
|
self._send_message = send_message
|
||||||
|
self._logger = logger
|
||||||
|
self._request = request
|
||||||
|
self._authenticated = False
|
||||||
|
self._connection = None
|
||||||
|
|
||||||
|
async def async_handle(self, msg):
|
||||||
|
"""Handle authentication."""
|
||||||
|
try:
|
||||||
|
msg = AUTH_MESSAGE_SCHEMA(msg)
|
||||||
|
except vol.Invalid as err:
|
||||||
|
error_msg = 'Auth message incorrectly formatted: {}'.format(
|
||||||
|
humanize_error(msg, err))
|
||||||
|
self._logger.warning(error_msg)
|
||||||
|
self._send_message(auth_invalid_message(error_msg))
|
||||||
|
raise Disconnect
|
||||||
|
|
||||||
|
if self._hass.auth.active and 'access_token' in msg:
|
||||||
|
self._logger.debug("Received access_token")
|
||||||
|
refresh_token = \
|
||||||
|
await self._hass.auth.async_validate_access_token(
|
||||||
|
msg['access_token'])
|
||||||
|
if refresh_token is not None:
|
||||||
|
return await self._async_finish_auth(
|
||||||
|
refresh_token.user, refresh_token)
|
||||||
|
|
||||||
|
elif ((not self._hass.auth.active or self._hass.auth.support_legacy)
|
||||||
|
and 'api_password' in msg):
|
||||||
|
self._logger.debug("Received api_password")
|
||||||
|
if validate_password(self._request, msg['api_password']):
|
||||||
|
return await self._async_finish_auth(None, None)
|
||||||
|
|
||||||
|
self._send_message(auth_invalid_message(
|
||||||
|
'Invalid access token or password'))
|
||||||
|
await process_wrong_login(self._request)
|
||||||
|
raise Disconnect
|
||||||
|
|
||||||
|
async def _async_finish_auth(self, user, refresh_token) \
|
||||||
|
-> ActiveConnection:
|
||||||
|
"""Create an active connection."""
|
||||||
|
self._logger.debug("Auth OK")
|
||||||
|
await process_success_login(self._request)
|
||||||
|
self._send_message(auth_ok_message())
|
||||||
|
return ActiveConnection(
|
||||||
|
self._logger, self._hass, self._send_message, user, refresh_token)
|
|
@ -103,12 +103,12 @@ def handle_subscribe_events(hass, connection, msg):
|
||||||
if event.event_type == EVENT_TIME_CHANGED:
|
if event.event_type == EVENT_TIME_CHANGED:
|
||||||
return
|
return
|
||||||
|
|
||||||
connection.send_message_outside(event_message(msg['id'], event))
|
connection.send_message(event_message(msg['id'], event))
|
||||||
|
|
||||||
connection.event_listeners[msg['id']] = hass.bus.async_listen(
|
connection.event_listeners[msg['id']] = hass.bus.async_listen(
|
||||||
msg['event_type'], forward_events)
|
msg['event_type'], forward_events)
|
||||||
|
|
||||||
connection.to_write.put_nowait(messages.result_message(msg['id']))
|
connection.send_message(messages.result_message(msg['id']))
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -121,9 +121,9 @@ def handle_unsubscribe_events(hass, connection, msg):
|
||||||
|
|
||||||
if subscription in connection.event_listeners:
|
if subscription in connection.event_listeners:
|
||||||
connection.event_listeners.pop(subscription)()
|
connection.event_listeners.pop(subscription)()
|
||||||
connection.to_write.put_nowait(messages.result_message(msg['id']))
|
connection.send_message(messages.result_message(msg['id']))
|
||||||
else:
|
else:
|
||||||
connection.to_write.put_nowait(messages.error_message(
|
connection.send_message(messages.error_message(
|
||||||
msg['id'], const.ERR_NOT_FOUND, 'Subscription not found.'))
|
msg['id'], const.ERR_NOT_FOUND, 'Subscription not found.'))
|
||||||
|
|
||||||
|
|
||||||
|
@ -140,7 +140,7 @@ async def handle_call_service(hass, connection, msg):
|
||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
msg['domain'], msg['service'], msg.get('service_data'), blocking,
|
msg['domain'], msg['service'], msg.get('service_data'), blocking,
|
||||||
connection.context(msg))
|
connection.context(msg))
|
||||||
connection.send_message_outside(messages.result_message(msg['id']))
|
connection.send_message(messages.result_message(msg['id']))
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -149,7 +149,7 @@ def handle_get_states(hass, connection, msg):
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
connection.to_write.put_nowait(messages.result_message(
|
connection.send_message(messages.result_message(
|
||||||
msg['id'], hass.states.async_all()))
|
msg['id'], hass.states.async_all()))
|
||||||
|
|
||||||
|
|
||||||
|
@ -160,7 +160,7 @@ async def handle_get_services(hass, connection, msg):
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
descriptions = await async_get_all_descriptions(hass)
|
descriptions = await async_get_all_descriptions(hass)
|
||||||
connection.send_message_outside(
|
connection.send_message(
|
||||||
messages.result_message(msg['id'], descriptions))
|
messages.result_message(msg['id'], descriptions))
|
||||||
|
|
||||||
|
|
||||||
|
@ -170,7 +170,7 @@ def handle_get_config(hass, connection, msg):
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
connection.to_write.put_nowait(messages.result_message(
|
connection.send_message(messages.result_message(
|
||||||
msg['id'], hass.config.as_dict()))
|
msg['id'], hass.config.as_dict()))
|
||||||
|
|
||||||
|
|
||||||
|
@ -180,4 +180,4 @@ def handle_ping(hass, connection, msg):
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
connection.to_write.put_nowait(pong_message(msg['id']))
|
connection.send_message(pong_message(msg['id']))
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
"""Connection session."""
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.core import callback, Context
|
||||||
|
|
||||||
|
from . import const, messages
|
||||||
|
|
||||||
|
|
||||||
|
class ActiveConnection:
|
||||||
|
"""Handle an active websocket client connection."""
|
||||||
|
|
||||||
|
def __init__(self, logger, hass, send_message, user, refresh_token):
|
||||||
|
"""Initialize an active connection."""
|
||||||
|
self.logger = logger
|
||||||
|
self.hass = hass
|
||||||
|
self.send_message = send_message
|
||||||
|
self.user = user
|
||||||
|
if refresh_token:
|
||||||
|
self.refresh_token_id = refresh_token.id
|
||||||
|
else:
|
||||||
|
self.refresh_token_id = None
|
||||||
|
|
||||||
|
self.event_listeners = {}
|
||||||
|
self.last_id = 0
|
||||||
|
|
||||||
|
def context(self, msg):
|
||||||
|
"""Return a context."""
|
||||||
|
user = self.user
|
||||||
|
if user is None:
|
||||||
|
return Context()
|
||||||
|
return Context(user_id=user.id)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_handle(self, msg):
|
||||||
|
"""Handle a single incoming message."""
|
||||||
|
handlers = self.hass.data[const.DOMAIN]
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg = messages.MINIMAL_MESSAGE_SCHEMA(msg)
|
||||||
|
cur_id = msg['id']
|
||||||
|
except vol.Invalid:
|
||||||
|
self.logger.error('Received invalid command', msg)
|
||||||
|
self.send_message(messages.error_message(
|
||||||
|
msg.get('id'), const.ERR_INVALID_FORMAT,
|
||||||
|
'Message incorrectly formatted.'))
|
||||||
|
return
|
||||||
|
|
||||||
|
if cur_id <= self.last_id:
|
||||||
|
self.send_message(messages.error_message(
|
||||||
|
cur_id, const.ERR_ID_REUSE,
|
||||||
|
'Identifier values have to increase.'))
|
||||||
|
return
|
||||||
|
|
||||||
|
if msg['type'] not in handlers:
|
||||||
|
self.logger.error(
|
||||||
|
'Received invalid command: {}'.format(msg['type']))
|
||||||
|
self.send_message(messages.error_message(
|
||||||
|
cur_id, const.ERR_UNKNOWN_COMMAND,
|
||||||
|
'Unknown command.'))
|
||||||
|
return
|
||||||
|
|
||||||
|
handler, schema = handlers[msg['type']]
|
||||||
|
|
||||||
|
try:
|
||||||
|
handler(self.hass, self, schema(msg))
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
self.logger.exception('Error handling message: %s', msg)
|
||||||
|
self.send_message(messages.error_message(
|
||||||
|
cur_id, const.ERR_UNKNOWN_ERROR,
|
||||||
|
'Unknown error.'))
|
||||||
|
|
||||||
|
self.last_id = cur_id
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_close(self):
|
||||||
|
"""Close down connection."""
|
||||||
|
for unsub in self.event_listeners.values():
|
||||||
|
unsub()
|
|
@ -1,4 +1,11 @@
|
||||||
"""Websocket constants."""
|
"""Websocket constants."""
|
||||||
|
import asyncio
|
||||||
|
from concurrent import futures
|
||||||
|
|
||||||
|
DOMAIN = 'websocket_api'
|
||||||
|
URL = '/api/websocket'
|
||||||
|
MAX_PENDING_MSG = 512
|
||||||
|
|
||||||
ERR_ID_REUSE = 1
|
ERR_ID_REUSE = 1
|
||||||
ERR_INVALID_FORMAT = 2
|
ERR_INVALID_FORMAT = 2
|
||||||
ERR_NOT_FOUND = 3
|
ERR_NOT_FOUND = 3
|
||||||
|
@ -6,3 +13,8 @@ ERR_UNKNOWN_COMMAND = 4
|
||||||
ERR_UNKNOWN_ERROR = 5
|
ERR_UNKNOWN_ERROR = 5
|
||||||
|
|
||||||
TYPE_RESULT = 'result'
|
TYPE_RESULT = 'result'
|
||||||
|
|
||||||
|
# Define the possible errors that occur when connections are cancelled.
|
||||||
|
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
|
||||||
|
# that futures.CancelledErrors can also occur in some situations.
|
||||||
|
CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError)
|
||||||
|
|
|
@ -18,7 +18,7 @@ def async_response(func):
|
||||||
await func(hass, connection, msg)
|
await func(hass, connection, msg)
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
_LOGGER.exception("Unexpected exception")
|
_LOGGER.exception("Unexpected exception")
|
||||||
connection.send_message_outside(messages.error_message(
|
connection.send_message(messages.error_message(
|
||||||
msg['id'], 'unknown', 'Unexpected error occurred'))
|
msg['id'], 'unknown', 'Unexpected error occurred'))
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -35,10 +35,10 @@ def require_owner(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def with_owner(hass, connection, msg):
|
def with_owner(hass, connection, msg):
|
||||||
"""Check owner and call function."""
|
"""Check owner and call function."""
|
||||||
user = connection.request.get('hass_user')
|
user = connection.user
|
||||||
|
|
||||||
if user is None or not user.is_owner:
|
if user is None or not user.is_owner:
|
||||||
connection.to_write.put_nowait(messages.error_message(
|
connection.send_message(messages.error_message(
|
||||||
msg['id'], 'unauthorized', 'This command is for owners only.'))
|
msg['id'], 'unauthorized', 'This command is for owners only.'))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ def ws_require_user(
|
||||||
"""Check current user."""
|
"""Check current user."""
|
||||||
def output_error(message_id, message):
|
def output_error(message_id, message):
|
||||||
"""Output error message."""
|
"""Output error message."""
|
||||||
connection.send_message_outside(messages.error_message(
|
connection.send_message(messages.error_message(
|
||||||
msg['id'], message_id, message))
|
msg['id'], message_id, message))
|
||||||
|
|
||||||
if connection.user is None:
|
if connection.user is None:
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
"""WebSocket API related errors."""
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
|
|
||||||
|
class Disconnect(HomeAssistantError):
|
||||||
|
"""Disconnect the current session."""
|
||||||
|
|
||||||
|
pass
|
|
@ -0,0 +1,189 @@
|
||||||
|
"""View to accept incoming websocket connection."""
|
||||||
|
import asyncio
|
||||||
|
from contextlib import suppress
|
||||||
|
from functools import partial
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from aiohttp import web, WSMsgType
|
||||||
|
import async_timeout
|
||||||
|
|
||||||
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||||
|
from homeassistant.core import callback
|
||||||
|
from homeassistant.components.http import HomeAssistantView
|
||||||
|
from homeassistant.helpers.json import JSONEncoder
|
||||||
|
|
||||||
|
from .const import MAX_PENDING_MSG, CANCELLATION_ERRORS, URL
|
||||||
|
from .auth import AuthPhase, auth_required_message
|
||||||
|
from .error import Disconnect
|
||||||
|
|
||||||
|
JSON_DUMP = partial(json.dumps, cls=JSONEncoder)
|
||||||
|
|
||||||
|
|
||||||
|
class WebsocketAPIView(HomeAssistantView):
|
||||||
|
"""View to serve a websockets endpoint."""
|
||||||
|
|
||||||
|
name = "websocketapi"
|
||||||
|
url = URL
|
||||||
|
requires_auth = False
|
||||||
|
|
||||||
|
async def get(self, request):
|
||||||
|
"""Handle an incoming websocket connection."""
|
||||||
|
return await WebSocketHandler(
|
||||||
|
request.app['hass'], request).async_handle()
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketHandler:
|
||||||
|
"""Handle an active websocket client connection."""
|
||||||
|
|
||||||
|
def __init__(self, hass, request):
|
||||||
|
"""Initialize an active connection."""
|
||||||
|
self.hass = hass
|
||||||
|
self.request = request
|
||||||
|
self.wsock = None
|
||||||
|
self._to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop)
|
||||||
|
self._handle_task = None
|
||||||
|
self._writer_task = None
|
||||||
|
self._logger = logging.getLogger(
|
||||||
|
"{}.connection.{}".format(__name__, id(self)))
|
||||||
|
|
||||||
|
async def _writer(self):
|
||||||
|
"""Write outgoing messages."""
|
||||||
|
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||||
|
with suppress(RuntimeError, *CANCELLATION_ERRORS):
|
||||||
|
while not self.wsock.closed:
|
||||||
|
message = await self._to_write.get()
|
||||||
|
if message is None:
|
||||||
|
break
|
||||||
|
self._logger.debug("Sending %s", message)
|
||||||
|
try:
|
||||||
|
await self.wsock.send_json(message, dumps=JSON_DUMP)
|
||||||
|
except TypeError as err:
|
||||||
|
self._logger.error('Unable to serialize to JSON: %s\n%s',
|
||||||
|
err, message)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _send_message(self, message):
|
||||||
|
"""Send a message to the client.
|
||||||
|
|
||||||
|
Closes connection if the client is not reading the messages.
|
||||||
|
|
||||||
|
Async friendly.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._to_write.put_nowait(message)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
self._logger.error("Client exceeded max pending messages [2]: %s",
|
||||||
|
MAX_PENDING_MSG)
|
||||||
|
self._cancel()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _cancel(self):
|
||||||
|
"""Cancel the connection."""
|
||||||
|
self._handle_task.cancel()
|
||||||
|
self._writer_task.cancel()
|
||||||
|
|
||||||
|
async def async_handle(self):
|
||||||
|
"""Handle a websocket response."""
|
||||||
|
request = self.request
|
||||||
|
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
|
||||||
|
await wsock.prepare(request)
|
||||||
|
self._logger.debug("Connected")
|
||||||
|
|
||||||
|
# Py3.7+
|
||||||
|
if hasattr(asyncio, 'current_task'):
|
||||||
|
# pylint: disable=no-member
|
||||||
|
self._handle_task = asyncio.current_task()
|
||||||
|
else:
|
||||||
|
self._handle_task = asyncio.Task.current_task(loop=self.hass.loop)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def handle_hass_stop(event):
|
||||||
|
"""Cancel this connection."""
|
||||||
|
self._cancel()
|
||||||
|
|
||||||
|
unsub_stop = self.hass.bus.async_listen(
|
||||||
|
EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
|
||||||
|
|
||||||
|
self._writer_task = self.hass.async_create_task(self._writer())
|
||||||
|
|
||||||
|
auth = AuthPhase(self._logger, self.hass, self._send_message, request)
|
||||||
|
connection = None
|
||||||
|
disconnect_warn = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._send_message(auth_required_message())
|
||||||
|
|
||||||
|
# Auth Phase
|
||||||
|
try:
|
||||||
|
with async_timeout.timeout(10):
|
||||||
|
msg = await wsock.receive()
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
disconnect_warn = \
|
||||||
|
'Did not receive auth message within 10 seconds'
|
||||||
|
raise Disconnect
|
||||||
|
|
||||||
|
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
|
||||||
|
raise Disconnect
|
||||||
|
|
||||||
|
elif msg.type != WSMsgType.TEXT:
|
||||||
|
disconnect_warn = 'Received non-Text message.'
|
||||||
|
raise Disconnect
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg = msg.json()
|
||||||
|
except ValueError:
|
||||||
|
disconnect_warn = 'Received invalid JSON.'
|
||||||
|
raise Disconnect
|
||||||
|
|
||||||
|
self._logger.debug("Received %s", msg)
|
||||||
|
connection = await auth.async_handle(msg)
|
||||||
|
|
||||||
|
# Command phase
|
||||||
|
while not wsock.closed:
|
||||||
|
msg = await wsock.receive()
|
||||||
|
|
||||||
|
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
|
||||||
|
break
|
||||||
|
|
||||||
|
elif msg.type != WSMsgType.TEXT:
|
||||||
|
disconnect_warn = 'Received non-Text message.'
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg = msg.json()
|
||||||
|
except ValueError:
|
||||||
|
disconnect_warn = 'Received invalid JSON.'
|
||||||
|
break
|
||||||
|
|
||||||
|
self._logger.debug("Received %s", msg)
|
||||||
|
connection.async_handle(msg)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
self._logger.info("Connection closed by client")
|
||||||
|
|
||||||
|
except Disconnect:
|
||||||
|
pass
|
||||||
|
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
self._logger.exception("Unexpected error inside websocket API")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
unsub_stop()
|
||||||
|
|
||||||
|
if connection is not None:
|
||||||
|
connection.async_close()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._to_write.put_nowait(None)
|
||||||
|
# Make sure all error messages are written before closing
|
||||||
|
await self._writer_task
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
self._writer_task.cancel()
|
||||||
|
|
||||||
|
await wsock.close()
|
||||||
|
|
||||||
|
if disconnect_warn is None:
|
||||||
|
self._logger.debug("Disconnected")
|
||||||
|
else:
|
||||||
|
self._logger.warning("Disconnected: %s", disconnect_warn)
|
|
@ -4,7 +4,9 @@ from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components.websocket_api.http import URL
|
||||||
|
from homeassistant.components.websocket_api.auth import (
|
||||||
|
TYPE_AUTH, TYPE_AUTH_OK, TYPE_AUTH_REQUIRED)
|
||||||
|
|
||||||
from tests.common import MockUser, CLIENT_ID
|
from tests.common import MockUser, CLIENT_ID
|
||||||
|
|
||||||
|
@ -14,41 +16,52 @@ def hass_ws_client(aiohttp_client):
|
||||||
"""Websocket client fixture connected to websocket server."""
|
"""Websocket client fixture connected to websocket server."""
|
||||||
async def create_client(hass, access_token=None):
|
async def create_client(hass, access_token=None):
|
||||||
"""Create a websocket client."""
|
"""Create a websocket client."""
|
||||||
wapi = hass.components.websocket_api
|
|
||||||
assert await async_setup_component(hass, 'websocket_api')
|
assert await async_setup_component(hass, 'websocket_api')
|
||||||
|
|
||||||
client = await aiohttp_client(hass.http.app)
|
client = await aiohttp_client(hass.http.app)
|
||||||
|
|
||||||
patching = None
|
patches = []
|
||||||
|
|
||||||
if access_token is not None:
|
if access_token is None:
|
||||||
patching = patch('homeassistant.auth.AuthManager.active',
|
patches.append(patch(
|
||||||
return_value=True)
|
'homeassistant.auth.AuthManager.active', return_value=False))
|
||||||
patching.start()
|
patches.append(patch(
|
||||||
|
'homeassistant.auth.AuthManager.support_legacy',
|
||||||
|
return_value=True))
|
||||||
|
patches.append(patch(
|
||||||
|
'homeassistant.components.websocket_api.auth.'
|
||||||
|
'validate_password', return_value=True))
|
||||||
|
else:
|
||||||
|
patches.append(patch(
|
||||||
|
'homeassistant.auth.AuthManager.active', return_value=True))
|
||||||
|
patches.append(patch(
|
||||||
|
'homeassistant.components.http.auth.setup_auth'))
|
||||||
|
|
||||||
|
for p in patches:
|
||||||
|
p.start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
websocket = await client.ws_connect(wapi.URL)
|
websocket = await client.ws_connect(URL)
|
||||||
auth_resp = await websocket.receive_json()
|
auth_resp = await websocket.receive_json()
|
||||||
|
assert auth_resp['type'] == TYPE_AUTH_REQUIRED
|
||||||
|
|
||||||
if auth_resp['type'] == wapi.TYPE_AUTH_OK:
|
if access_token is None:
|
||||||
assert access_token is None, \
|
await websocket.send_json({
|
||||||
'Access token given but no auth required'
|
'type': TYPE_AUTH,
|
||||||
return websocket
|
'api_password': 'bla'
|
||||||
|
})
|
||||||
assert access_token is not None, \
|
else:
|
||||||
'Access token required for fixture'
|
await websocket.send_json({
|
||||||
|
'type': TYPE_AUTH,
|
||||||
await websocket.send_json({
|
'access_token': access_token
|
||||||
'type': websocket_api.TYPE_AUTH,
|
})
|
||||||
'access_token': access_token
|
|
||||||
})
|
|
||||||
|
|
||||||
auth_ok = await websocket.receive_json()
|
auth_ok = await websocket.receive_json()
|
||||||
assert auth_ok['type'] == wapi.TYPE_AUTH_OK
|
assert auth_ok['type'] == TYPE_AUTH_OK
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if patching is not None:
|
for p in patches:
|
||||||
patching.stop()
|
p.stop()
|
||||||
|
|
||||||
# wrap in client
|
# wrap in client
|
||||||
websocket.client = client
|
websocket.client = client
|
||||||
|
|
|
@ -62,7 +62,7 @@ class TestPanelIframe(unittest.TestCase):
|
||||||
|
|
||||||
panels = self.hass.data[frontend.DATA_PANELS]
|
panels = self.hass.data[frontend.DATA_PANELS]
|
||||||
|
|
||||||
assert panels.get('router').to_response(self.hass, None) == {
|
assert panels.get('router').to_response() == {
|
||||||
'component_name': 'iframe',
|
'component_name': 'iframe',
|
||||||
'config': {'url': 'http://192.168.1.1'},
|
'config': {'url': 'http://192.168.1.1'},
|
||||||
'icon': 'mdi:network-wireless',
|
'icon': 'mdi:network-wireless',
|
||||||
|
@ -70,7 +70,7 @@ class TestPanelIframe(unittest.TestCase):
|
||||||
'url_path': 'router'
|
'url_path': 'router'
|
||||||
}
|
}
|
||||||
|
|
||||||
assert panels.get('weather').to_response(self.hass, None) == {
|
assert panels.get('weather').to_response() == {
|
||||||
'component_name': 'iframe',
|
'component_name': 'iframe',
|
||||||
'config': {'url': 'https://www.wunderground.com/us/ca/san-diego'},
|
'config': {'url': 'https://www.wunderground.com/us/ca/san-diego'},
|
||||||
'icon': 'mdi:weather',
|
'icon': 'mdi:weather',
|
||||||
|
@ -78,7 +78,7 @@ class TestPanelIframe(unittest.TestCase):
|
||||||
'url_path': 'weather',
|
'url_path': 'weather',
|
||||||
}
|
}
|
||||||
|
|
||||||
assert panels.get('api').to_response(self.hass, None) == {
|
assert panels.get('api').to_response() == {
|
||||||
'component_name': 'iframe',
|
'component_name': 'iframe',
|
||||||
'config': {'url': '/api'},
|
'config': {'url': '/api'},
|
||||||
'icon': 'mdi:weather',
|
'icon': 'mdi:weather',
|
||||||
|
@ -86,7 +86,7 @@ class TestPanelIframe(unittest.TestCase):
|
||||||
'url_path': 'api',
|
'url_path': 'api',
|
||||||
}
|
}
|
||||||
|
|
||||||
assert panels.get('ftp').to_response(self.hass, None) == {
|
assert panels.get('ftp').to_response() == {
|
||||||
'component_name': 'iframe',
|
'component_name': 'iframe',
|
||||||
'config': {'url': 'ftp://some/ftp'},
|
'config': {'url': 'ftp://some/ftp'},
|
||||||
'icon': 'mdi:weather',
|
'icon': 'mdi:weather',
|
||||||
|
|
|
@ -2,7 +2,8 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.components import websocket_api as wapi
|
from homeassistant.components.websocket_api.http import URL
|
||||||
|
from homeassistant.components.websocket_api.auth import TYPE_AUTH_REQUIRED
|
||||||
|
|
||||||
from . import API_PASSWORD
|
from . import API_PASSWORD
|
||||||
|
|
||||||
|
@ -24,10 +25,10 @@ def no_auth_websocket_client(hass, loop, aiohttp_client):
|
||||||
}))
|
}))
|
||||||
|
|
||||||
client = loop.run_until_complete(aiohttp_client(hass.http.app))
|
client = loop.run_until_complete(aiohttp_client(hass.http.app))
|
||||||
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
|
ws = loop.run_until_complete(client.ws_connect(URL))
|
||||||
|
|
||||||
auth_ok = loop.run_until_complete(ws.receive_json())
|
auth_ok = loop.run_until_complete(ws.receive_json())
|
||||||
assert auth_ok['type'] == wapi.TYPE_AUTH_REQUIRED
|
assert auth_ok['type'] == TYPE_AUTH_REQUIRED
|
||||||
|
|
||||||
yield ws
|
yield ws
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
"""Test auth of websocket API."""
|
"""Test auth of websocket API."""
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from homeassistant.components import websocket_api as wapi
|
from homeassistant.components.websocket_api.const import URL
|
||||||
|
from homeassistant.components.websocket_api.auth import (
|
||||||
|
TYPE_AUTH, TYPE_AUTH_INVALID, TYPE_AUTH_OK, TYPE_AUTH_REQUIRED)
|
||||||
|
|
||||||
from homeassistant.components.websocket_api import commands
|
from homeassistant.components.websocket_api import commands
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
@ -13,28 +16,29 @@ from . import API_PASSWORD
|
||||||
async def test_auth_via_msg(no_auth_websocket_client):
|
async def test_auth_via_msg(no_auth_websocket_client):
|
||||||
"""Test authenticating."""
|
"""Test authenticating."""
|
||||||
await no_auth_websocket_client.send_json({
|
await no_auth_websocket_client.send_json({
|
||||||
'type': wapi.TYPE_AUTH,
|
'type': TYPE_AUTH,
|
||||||
'api_password': API_PASSWORD
|
'api_password': API_PASSWORD
|
||||||
})
|
})
|
||||||
|
|
||||||
msg = await no_auth_websocket_client.receive_json()
|
msg = await no_auth_websocket_client.receive_json()
|
||||||
|
|
||||||
assert msg['type'] == wapi.TYPE_AUTH_OK
|
assert msg['type'] == TYPE_AUTH_OK
|
||||||
|
|
||||||
|
|
||||||
async def test_auth_via_msg_incorrect_pass(no_auth_websocket_client):
|
async def test_auth_via_msg_incorrect_pass(no_auth_websocket_client):
|
||||||
"""Test authenticating."""
|
"""Test authenticating."""
|
||||||
with patch('homeassistant.components.websocket_api.process_wrong_login',
|
with patch('homeassistant.components.websocket_api.auth.'
|
||||||
return_value=mock_coro()) as mock_process_wrong_login:
|
'process_wrong_login', return_value=mock_coro()) \
|
||||||
|
as mock_process_wrong_login:
|
||||||
await no_auth_websocket_client.send_json({
|
await no_auth_websocket_client.send_json({
|
||||||
'type': wapi.TYPE_AUTH,
|
'type': TYPE_AUTH,
|
||||||
'api_password': API_PASSWORD + 'wrong'
|
'api_password': API_PASSWORD + 'wrong'
|
||||||
})
|
})
|
||||||
|
|
||||||
msg = await no_auth_websocket_client.receive_json()
|
msg = await no_auth_websocket_client.receive_json()
|
||||||
|
|
||||||
assert mock_process_wrong_login.called
|
assert mock_process_wrong_login.called
|
||||||
assert msg['type'] == wapi.TYPE_AUTH_INVALID
|
assert msg['type'] == TYPE_AUTH_INVALID
|
||||||
assert msg['message'] == 'Invalid access token or password'
|
assert msg['message'] == 'Invalid access token or password'
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,8 +55,8 @@ async def test_pre_auth_only_auth_allowed(no_auth_websocket_client):
|
||||||
|
|
||||||
msg = await no_auth_websocket_client.receive_json()
|
msg = await no_auth_websocket_client.receive_json()
|
||||||
|
|
||||||
assert msg['type'] == wapi.TYPE_AUTH_INVALID
|
assert msg['type'] == TYPE_AUTH_INVALID
|
||||||
assert msg['message'].startswith('Message incorrectly formatted')
|
assert msg['message'].startswith('Auth message incorrectly formatted')
|
||||||
|
|
||||||
|
|
||||||
async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
|
async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
|
||||||
|
@ -65,19 +69,19 @@ async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
|
||||||
|
|
||||||
client = await aiohttp_client(hass.http.app)
|
client = await aiohttp_client(hass.http.app)
|
||||||
|
|
||||||
async with client.ws_connect(wapi.URL) as ws:
|
async with client.ws_connect(URL) as ws:
|
||||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||||
auth_active.return_value = True
|
auth_active.return_value = True
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
assert auth_msg['type'] == TYPE_AUTH_REQUIRED
|
||||||
|
|
||||||
await ws.send_json({
|
await ws.send_json({
|
||||||
'type': wapi.TYPE_AUTH,
|
'type': TYPE_AUTH,
|
||||||
'access_token': hass_access_token
|
'access_token': hass_access_token
|
||||||
})
|
})
|
||||||
|
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
assert auth_msg['type'] == TYPE_AUTH_OK
|
||||||
|
|
||||||
|
|
||||||
async def test_auth_active_user_inactive(hass, aiohttp_client,
|
async def test_auth_active_user_inactive(hass, aiohttp_client,
|
||||||
|
@ -94,19 +98,19 @@ async def test_auth_active_user_inactive(hass, aiohttp_client,
|
||||||
|
|
||||||
client = await aiohttp_client(hass.http.app)
|
client = await aiohttp_client(hass.http.app)
|
||||||
|
|
||||||
async with client.ws_connect(wapi.URL) as ws:
|
async with client.ws_connect(URL) as ws:
|
||||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||||
auth_active.return_value = True
|
auth_active.return_value = True
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
assert auth_msg['type'] == TYPE_AUTH_REQUIRED
|
||||||
|
|
||||||
await ws.send_json({
|
await ws.send_json({
|
||||||
'type': wapi.TYPE_AUTH,
|
'type': TYPE_AUTH,
|
||||||
'access_token': hass_access_token
|
'access_token': hass_access_token
|
||||||
})
|
})
|
||||||
|
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
assert auth_msg['type'] == TYPE_AUTH_INVALID
|
||||||
|
|
||||||
|
|
||||||
async def test_auth_active_with_password_not_allow(hass, aiohttp_client):
|
async def test_auth_active_with_password_not_allow(hass, aiohttp_client):
|
||||||
|
@ -119,19 +123,19 @@ async def test_auth_active_with_password_not_allow(hass, aiohttp_client):
|
||||||
|
|
||||||
client = await aiohttp_client(hass.http.app)
|
client = await aiohttp_client(hass.http.app)
|
||||||
|
|
||||||
async with client.ws_connect(wapi.URL) as ws:
|
async with client.ws_connect(URL) as ws:
|
||||||
with patch('homeassistant.auth.AuthManager.active',
|
with patch('homeassistant.auth.AuthManager.active',
|
||||||
return_value=True):
|
return_value=True):
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
assert auth_msg['type'] == TYPE_AUTH_REQUIRED
|
||||||
|
|
||||||
await ws.send_json({
|
await ws.send_json({
|
||||||
'type': wapi.TYPE_AUTH,
|
'type': TYPE_AUTH,
|
||||||
'api_password': API_PASSWORD
|
'api_password': API_PASSWORD
|
||||||
})
|
})
|
||||||
|
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
assert auth_msg['type'] == TYPE_AUTH_INVALID
|
||||||
|
|
||||||
|
|
||||||
async def test_auth_legacy_support_with_password(hass, aiohttp_client):
|
async def test_auth_legacy_support_with_password(hass, aiohttp_client):
|
||||||
|
@ -144,21 +148,21 @@ async def test_auth_legacy_support_with_password(hass, aiohttp_client):
|
||||||
|
|
||||||
client = await aiohttp_client(hass.http.app)
|
client = await aiohttp_client(hass.http.app)
|
||||||
|
|
||||||
async with client.ws_connect(wapi.URL) as ws:
|
async with client.ws_connect(URL) as ws:
|
||||||
with patch('homeassistant.auth.AuthManager.active',
|
with patch('homeassistant.auth.AuthManager.active',
|
||||||
return_value=True),\
|
return_value=True),\
|
||||||
patch('homeassistant.auth.AuthManager.support_legacy',
|
patch('homeassistant.auth.AuthManager.support_legacy',
|
||||||
return_value=True):
|
return_value=True):
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
assert auth_msg['type'] == TYPE_AUTH_REQUIRED
|
||||||
|
|
||||||
await ws.send_json({
|
await ws.send_json({
|
||||||
'type': wapi.TYPE_AUTH,
|
'type': TYPE_AUTH,
|
||||||
'api_password': API_PASSWORD
|
'api_password': API_PASSWORD
|
||||||
})
|
})
|
||||||
|
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
assert auth_msg['type'] == TYPE_AUTH_OK
|
||||||
|
|
||||||
|
|
||||||
async def test_auth_with_invalid_token(hass, aiohttp_client):
|
async def test_auth_with_invalid_token(hass, aiohttp_client):
|
||||||
|
@ -171,16 +175,16 @@ async def test_auth_with_invalid_token(hass, aiohttp_client):
|
||||||
|
|
||||||
client = await aiohttp_client(hass.http.app)
|
client = await aiohttp_client(hass.http.app)
|
||||||
|
|
||||||
async with client.ws_connect(wapi.URL) as ws:
|
async with client.ws_connect(URL) as ws:
|
||||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||||
auth_active.return_value = True
|
auth_active.return_value = True
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
assert auth_msg['type'] == TYPE_AUTH_REQUIRED
|
||||||
|
|
||||||
await ws.send_json({
|
await ws.send_json({
|
||||||
'type': wapi.TYPE_AUTH,
|
'type': TYPE_AUTH,
|
||||||
'access_token': 'incorrect'
|
'access_token': 'incorrect'
|
||||||
})
|
})
|
||||||
|
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
assert auth_msg['type'] == TYPE_AUTH_INVALID
|
||||||
|
|
|
@ -4,7 +4,10 @@ from unittest.mock import patch
|
||||||
from async_timeout import timeout
|
from async_timeout import timeout
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.components import websocket_api as wapi
|
from homeassistant.components.websocket_api.const import URL
|
||||||
|
from homeassistant.components.websocket_api.auth import (
|
||||||
|
TYPE_AUTH, TYPE_AUTH_OK, TYPE_AUTH_REQUIRED
|
||||||
|
)
|
||||||
from homeassistant.components.websocket_api import const, commands
|
from homeassistant.components.websocket_api import const, commands
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
@ -178,19 +181,19 @@ async def test_call_service_context_with_user(hass, aiohttp_client,
|
||||||
calls = async_mock_service(hass, 'domain_test', 'test_service')
|
calls = async_mock_service(hass, 'domain_test', 'test_service')
|
||||||
client = await aiohttp_client(hass.http.app)
|
client = await aiohttp_client(hass.http.app)
|
||||||
|
|
||||||
async with client.ws_connect(wapi.URL) as ws:
|
async with client.ws_connect(URL) as ws:
|
||||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||||
auth_active.return_value = True
|
auth_active.return_value = True
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
assert auth_msg['type'] == TYPE_AUTH_REQUIRED
|
||||||
|
|
||||||
await ws.send_json({
|
await ws.send_json({
|
||||||
'type': wapi.TYPE_AUTH,
|
'type': TYPE_AUTH,
|
||||||
'access_token': hass_access_token
|
'access_token': hass_access_token
|
||||||
})
|
})
|
||||||
|
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
assert auth_msg['type'] == TYPE_AUTH_OK
|
||||||
|
|
||||||
await ws.send_json({
|
await ws.send_json({
|
||||||
'id': 5,
|
'id': 5,
|
||||||
|
@ -227,17 +230,17 @@ async def test_call_service_context_no_user(hass, aiohttp_client):
|
||||||
calls = async_mock_service(hass, 'domain_test', 'test_service')
|
calls = async_mock_service(hass, 'domain_test', 'test_service')
|
||||||
client = await aiohttp_client(hass.http.app)
|
client = await aiohttp_client(hass.http.app)
|
||||||
|
|
||||||
async with client.ws_connect(wapi.URL) as ws:
|
async with client.ws_connect(URL) as ws:
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
assert auth_msg['type'] == TYPE_AUTH_REQUIRED
|
||||||
|
|
||||||
await ws.send_json({
|
await ws.send_json({
|
||||||
'type': wapi.TYPE_AUTH,
|
'type': TYPE_AUTH,
|
||||||
'api_password': API_PASSWORD
|
'api_password': API_PASSWORD
|
||||||
})
|
})
|
||||||
|
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
assert auth_msg['type'] == TYPE_AUTH_OK
|
||||||
|
|
||||||
await ws.send_json({
|
await ws.send_json({
|
||||||
'id': 5,
|
'id': 5,
|
||||||
|
|
|
@ -5,14 +5,14 @@ from unittest.mock import patch, Mock
|
||||||
from aiohttp import WSMsgType
|
from aiohttp import WSMsgType
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import websocket_api as wapi
|
|
||||||
from homeassistant.components.websocket_api import const, commands, messages
|
from homeassistant.components.websocket_api import const, commands, messages
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_low_queue():
|
def mock_low_queue():
|
||||||
"""Mock a low queue."""
|
"""Mock a low queue."""
|
||||||
with patch.object(wapi, 'MAX_PENDING_MSG', 5):
|
with patch('homeassistant.components.websocket_api.http.MAX_PENDING_MSG',
|
||||||
|
5):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue