core/homeassistant/components/websocket_api.py

559 lines
16 KiB
Python

"""
Websocket based API for Home Assistant.
For more details about this component, please refer to the documentation at
https://home-assistant.io/developers/websocket_api/
"""
import asyncio
from concurrent import futures
from contextlib import suppress
from functools import partial, wraps
import json
import logging
from aiohttp import web
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.const import (
MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
__version__)
from homeassistant.core import callback
from homeassistant.loader import bind_hass
from homeassistant.remote import JSONEncoder
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.auth import validate_password
from homeassistant.components.http.const import KEY_AUTHENTICATED
from homeassistant.components.http.ban import process_wrong_login, \
process_success_login
DOMAIN = 'websocket_api'
URL = '/api/websocket'
DEPENDENCIES = ('http',)
MAX_PENDING_MSG = 512
ERR_ID_REUSE = 1
ERR_INVALID_FORMAT = 2
ERR_NOT_FOUND = 3
ERR_UNKNOWN_COMMAND = 4
TYPE_AUTH = 'auth'
TYPE_AUTH_INVALID = 'auth_invalid'
TYPE_AUTH_OK = 'auth_ok'
TYPE_AUTH_REQUIRED = 'auth_required'
TYPE_CALL_SERVICE = 'call_service'
TYPE_EVENT = 'event'
TYPE_GET_CONFIG = 'get_config'
TYPE_GET_SERVICES = 'get_services'
TYPE_GET_STATES = 'get_states'
TYPE_PING = 'ping'
TYPE_PONG = 'pong'
TYPE_RESULT = 'result'
TYPE_SUBSCRIBE_EVENTS = 'subscribe_events'
TYPE_UNSUBSCRIBE_EVENTS = 'unsubscribe_events'
_LOGGER = logging.getLogger(__name__)
JSON_DUMP = partial(json.dumps, cls=JSONEncoder)
AUTH_MESSAGE_SCHEMA = vol.Schema({
vol.Required('type'): TYPE_AUTH,
vol.Exclusive('api_password', 'auth'): str,
vol.Exclusive('access_token', 'auth'): str,
})
# Minimal requirements of a message
MINIMAL_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
vol.Required('type'): cv.string,
}, extra=vol.ALLOW_EXTRA)
# Base schema to extend by message handlers
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
})
SCHEMA_SUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_SUBSCRIBE_EVENTS,
vol.Optional('event_type', default=MATCH_ALL): str,
})
SCHEMA_UNSUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_UNSUBSCRIBE_EVENTS,
vol.Required('subscription'): cv.positive_int,
})
SCHEMA_CALL_SERVICE = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_CALL_SERVICE,
vol.Required('domain'): str,
vol.Required('service'): str,
vol.Optional('service_data'): dict
})
SCHEMA_GET_STATES = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_STATES,
})
SCHEMA_GET_SERVICES = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_SERVICES,
})
SCHEMA_GET_CONFIG = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_CONFIG,
})
SCHEMA_PING = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_PING,
})
# Define the possible errors that occur when connections are cancelled.
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
# that futures.CancelledErrors can also occur in some situations.
CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError)
def auth_ok_message():
"""Return an auth_ok message."""
return {
'type': TYPE_AUTH_OK,
'ha_version': __version__,
}
def auth_required_message():
"""Return an auth_required message."""
return {
'type': TYPE_AUTH_REQUIRED,
'ha_version': __version__,
}
def auth_invalid_message(message):
"""Return an auth_invalid message."""
return {
'type': TYPE_AUTH_INVALID,
'message': message,
}
def event_message(iden, event):
"""Return an event message."""
return {
'id': iden,
'type': TYPE_EVENT,
'event': event.as_dict(),
}
def error_message(iden, code, message):
"""Return an error result message."""
return {
'id': iden,
'type': TYPE_RESULT,
'success': False,
'error': {
'code': code,
'message': message,
},
}
def pong_message(iden):
"""Return a pong message."""
return {
'id': iden,
'type': TYPE_PONG,
}
def result_message(iden, result=None):
"""Return a success result message."""
return {
'id': iden,
'type': TYPE_RESULT,
'success': True,
'result': result,
}
@bind_hass
@callback
def async_register_command(hass, command, handler, schema):
"""Register a websocket command."""
handlers = hass.data.get(DOMAIN)
if handlers is None:
handlers = hass.data[DOMAIN] = {}
handlers[command] = (handler, schema)
def require_owner(func):
"""Websocket decorator to require user to be an owner."""
@wraps(func)
def with_owner(hass, connection, msg):
"""Check owner and call function."""
user = connection.request.get('hass_user')
if user is None or not user.is_owner:
connection.to_write.put_nowait(error_message(
msg['id'], 'unauthorized', 'This command is for owners only.'))
return
func(hass, connection, msg)
return with_owner
async def async_setup(hass, config):
"""Initialize the websocket API."""
hass.http.register_view(WebsocketAPIView)
async_register_command(hass, TYPE_SUBSCRIBE_EVENTS,
handle_subscribe_events, SCHEMA_SUBSCRIBE_EVENTS)
async_register_command(hass, TYPE_UNSUBSCRIBE_EVENTS,
handle_unsubscribe_events,
SCHEMA_UNSUBSCRIBE_EVENTS)
async_register_command(hass, TYPE_CALL_SERVICE,
handle_call_service, SCHEMA_CALL_SERVICE)
async_register_command(hass, TYPE_GET_STATES,
handle_get_states, SCHEMA_GET_STATES)
async_register_command(hass, TYPE_GET_SERVICES,
handle_get_services, SCHEMA_GET_SERVICES)
async_register_command(hass, TYPE_GET_CONFIG,
handle_get_config, SCHEMA_GET_CONFIG)
async_register_command(hass, TYPE_PING,
handle_ping, SCHEMA_PING)
return True
class WebsocketAPIView(HomeAssistantView):
"""View to serve a websockets endpoint."""
name = "websocketapi"
url = URL
requires_auth = False
async def get(self, request):
"""Handle an incoming websocket connection."""
return await ActiveConnection(request.app['hass'], request).handle()
class ActiveConnection:
"""Handle an active websocket client connection."""
def __init__(self, hass, request):
"""Initialize an active connection."""
self.hass = hass
self.request = request
self.wsock = None
self.event_listeners = {}
self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop)
self._handle_task = None
self._writer_task = None
def debug(self, message1, message2=''):
"""Print a debug message."""
_LOGGER.debug("WS %s: %s %s", id(self.wsock), message1, message2)
def log_error(self, message1, message2=''):
"""Print an error message."""
_LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2)
async def _writer(self):
"""Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler
with suppress(RuntimeError, *CANCELLATION_ERRORS):
while not self.wsock.closed:
message = await self.to_write.get()
if message is None:
break
self.debug("Sending", message)
try:
await self.wsock.send_json(message, dumps=JSON_DUMP)
except TypeError as err:
_LOGGER.error('Unable to serialize to JSON: %s\n%s',
err, message)
@callback
def send_message_outside(self, message):
"""Send a message to the client outside of the main task.
Closes connection if the client is not reading the messages.
Async friendly.
"""
try:
self.to_write.put_nowait(message)
except asyncio.QueueFull:
self.log_error("Client exceeded max pending messages [2]:",
MAX_PENDING_MSG)
self.cancel()
@callback
def cancel(self):
"""Cancel the connection."""
self._handle_task.cancel()
self._writer_task.cancel()
async def handle(self):
"""Handle the websocket connection."""
request = self.request
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
await wsock.prepare(request)
self.debug("Connected")
# Get a reference to current task so we can cancel our connection
self._handle_task = asyncio.Task.current_task(loop=self.hass.loop)
@callback
def handle_hass_stop(event):
"""Cancel this connection."""
self.cancel()
unsub_stop = self.hass.bus.async_listen(
EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
self._writer_task = self.hass.async_add_job(self._writer())
final_message = None
msg = None
authenticated = False
try:
if request[KEY_AUTHENTICATED]:
authenticated = True
else:
self.debug("Request auth")
await self.wsock.send_json(auth_required_message())
msg = await wsock.receive_json()
msg = AUTH_MESSAGE_SCHEMA(msg)
if self.hass.auth.active and 'access_token' in msg:
self.debug("Received access_token")
token = self.hass.auth.async_get_access_token(
msg['access_token'])
authenticated = token is not None
if authenticated:
request['hass_user'] = token.refresh_token.user
elif ((not self.hass.auth.active or
self.hass.auth.support_legacy) and
'api_password' in msg):
self.debug("Received api_password")
authenticated = validate_password(
request, msg['api_password'])
if not authenticated:
self.debug("Authorization failed")
await self.wsock.send_json(
auth_invalid_message('Invalid access token or password'))
await process_wrong_login(request)
return wsock
self.debug("Auth OK")
await process_success_login(request)
await self.wsock.send_json(auth_ok_message())
# ---------- AUTH PHASE OVER ----------
msg = await wsock.receive_json()
last_id = 0
handlers = self.hass.data[DOMAIN]
while msg:
self.debug("Received", msg)
msg = MINIMAL_MESSAGE_SCHEMA(msg)
cur_id = msg['id']
if cur_id <= last_id:
self.to_write.put_nowait(error_message(
cur_id, ERR_ID_REUSE,
'Identifier values have to increase.'))
elif msg['type'] not in handlers:
self.log_error(
'Received invalid command: {}'.format(msg['type']))
self.to_write.put_nowait(error_message(
cur_id, ERR_UNKNOWN_COMMAND,
'Unknown command.'))
else:
handler, schema = handlers[msg['type']]
handler(self.hass, self, schema(msg))
last_id = cur_id
msg = await wsock.receive_json()
except vol.Invalid as err:
error_msg = "Message incorrectly formatted: "
if msg:
error_msg += humanize_error(msg, err)
else:
error_msg += str(err)
self.log_error(error_msg)
if not authenticated:
final_message = auth_invalid_message(error_msg)
else:
if isinstance(msg, dict):
iden = msg.get('id')
else:
iden = None
final_message = error_message(
iden, ERR_INVALID_FORMAT, error_msg)
except TypeError as err:
if wsock.closed:
self.debug("Connection closed by client")
else:
_LOGGER.exception("Unexpected TypeError: %s", err)
except ValueError as err:
msg = "Received invalid JSON"
value = getattr(err, 'doc', None) # Py3.5+ only
if value:
msg += ': {}'.format(value)
self.log_error(msg)
self._writer_task.cancel()
except CANCELLATION_ERRORS:
self.debug("Connection cancelled")
except asyncio.QueueFull:
self.log_error("Client exceeded max pending messages [1]:",
MAX_PENDING_MSG)
self._writer_task.cancel()
except Exception: # pylint: disable=broad-except
error = "Unexpected error inside websocket API. "
if msg is not None:
error += str(msg)
_LOGGER.exception(error)
finally:
unsub_stop()
for unsub in self.event_listeners.values():
unsub()
try:
if final_message is not None:
self.to_write.put_nowait(final_message)
self.to_write.put_nowait(None)
# Make sure all error messages are written before closing
await self._writer_task
except asyncio.QueueFull:
self._writer_task.cancel()
await wsock.close()
self.debug("Closed connection")
return wsock
@callback
def handle_subscribe_events(hass, connection, msg):
"""Handle subscribe events command.
Async friendly.
"""
async def forward_events(event):
"""Forward events to websocket."""
if event.event_type == EVENT_TIME_CHANGED:
return
connection.send_message_outside(event_message(msg['id'], event))
connection.event_listeners[msg['id']] = hass.bus.async_listen(
msg['event_type'], forward_events)
connection.to_write.put_nowait(result_message(msg['id']))
@callback
def handle_unsubscribe_events(hass, connection, msg):
"""Handle unsubscribe events command.
Async friendly.
"""
subscription = msg['subscription']
if subscription in connection.event_listeners:
connection.event_listeners.pop(subscription)()
connection.to_write.put_nowait(result_message(msg['id']))
else:
connection.to_write.put_nowait(error_message(
msg['id'], ERR_NOT_FOUND, 'Subscription not found.'))
@callback
def handle_call_service(hass, connection, msg):
"""Handle call service command.
Async friendly.
"""
async def call_service_helper(msg):
"""Call a service and fire complete message."""
await hass.services.async_call(
msg['domain'], msg['service'], msg.get('service_data'), True)
connection.send_message_outside(result_message(msg['id']))
hass.async_add_job(call_service_helper(msg))
@callback
def handle_get_states(hass, connection, msg):
"""Handle get states command.
Async friendly.
"""
connection.to_write.put_nowait(result_message(
msg['id'], hass.states.async_all()))
@callback
def handle_get_services(hass, connection, msg):
"""Handle get services command.
Async friendly.
"""
async def get_services_helper(msg):
"""Get available services and fire complete message."""
descriptions = await async_get_all_descriptions(hass)
connection.send_message_outside(
result_message(msg['id'], descriptions))
hass.async_add_job(get_services_helper(msg))
@callback
def handle_get_config(hass, connection, msg):
"""Handle get config command.
Async friendly.
"""
connection.to_write.put_nowait(result_message(
msg['id'], hass.config.as_dict()))
@callback
def handle_ping(hass, connection, msg):
"""Handle ping command.
Async friendly.
"""
connection.to_write.put_nowait(pong_message(msg['id']))