core/homeassistant/components/websocket_api.py

496 lines
14 KiB
Python

"""
Websocket based API for Home Assistant.
For more details about this component, please refer to the documentation at
https://home-assistant.io/developers/websocket_api/
"""
import asyncio
from contextlib import suppress
from functools import partial
import json
import logging
from aiohttp import web
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.const import (
MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
__version__)
from homeassistant.components import frontend
from homeassistant.core import callback
from homeassistant.remote import JSONEncoder
from homeassistant.helpers import config_validation as cv
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.auth import validate_password
from homeassistant.components.http.const import KEY_AUTHENTICATED
from homeassistant.components.http.ban import process_wrong_login
DOMAIN = 'websocket_api'
URL = '/api/websocket'
DEPENDENCIES = 'http',
MAX_PENDING_MSG = 512
ERR_ID_REUSE = 1
ERR_INVALID_FORMAT = 2
ERR_NOT_FOUND = 3
TYPE_AUTH = 'auth'
TYPE_AUTH_INVALID = 'auth_invalid'
TYPE_AUTH_OK = 'auth_ok'
TYPE_AUTH_REQUIRED = 'auth_required'
TYPE_CALL_SERVICE = 'call_service'
TYPE_EVENT = 'event'
TYPE_GET_CONFIG = 'get_config'
TYPE_GET_PANELS = 'get_panels'
TYPE_GET_SERVICES = 'get_services'
TYPE_GET_STATES = 'get_states'
TYPE_PING = 'ping'
TYPE_PONG = 'pong'
TYPE_RESULT = 'result'
TYPE_SUBSCRIBE_EVENTS = 'subscribe_events'
TYPE_UNSUBSCRIBE_EVENTS = 'unsubscribe_events'
_LOGGER = logging.getLogger(__name__)
JSON_DUMP = partial(json.dumps, cls=JSONEncoder)
AUTH_MESSAGE_SCHEMA = vol.Schema({
vol.Required('type'): TYPE_AUTH,
vol.Required('api_password'): str,
})
SUBSCRIBE_EVENTS_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
vol.Required('type'): TYPE_SUBSCRIBE_EVENTS,
vol.Optional('event_type', default=MATCH_ALL): str,
})
UNSUBSCRIBE_EVENTS_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
vol.Required('type'): TYPE_UNSUBSCRIBE_EVENTS,
vol.Required('subscription'): cv.positive_int,
})
CALL_SERVICE_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
vol.Required('type'): TYPE_CALL_SERVICE,
vol.Required('domain'): str,
vol.Required('service'): str,
vol.Optional('service_data', default=None): dict
})
GET_STATES_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
vol.Required('type'): TYPE_GET_STATES,
})
GET_SERVICES_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
vol.Required('type'): TYPE_GET_SERVICES,
})
GET_CONFIG_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
vol.Required('type'): TYPE_GET_CONFIG,
})
GET_PANELS_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
vol.Required('type'): TYPE_GET_PANELS,
})
PING_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
vol.Required('type'): TYPE_PING,
})
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
vol.Required('type'): vol.Any(TYPE_CALL_SERVICE,
TYPE_SUBSCRIBE_EVENTS,
TYPE_UNSUBSCRIBE_EVENTS,
TYPE_GET_STATES,
TYPE_GET_SERVICES,
TYPE_GET_CONFIG,
TYPE_GET_PANELS,
TYPE_PING)
}, extra=vol.ALLOW_EXTRA)
def auth_ok_message():
"""Return an auth_ok message."""
return {
'type': TYPE_AUTH_OK,
'ha_version': __version__,
}
def auth_required_message():
"""Return an auth_required message."""
return {
'type': TYPE_AUTH_REQUIRED,
'ha_version': __version__,
}
def auth_invalid_message(message):
"""Return an auth_invalid message."""
return {
'type': TYPE_AUTH_INVALID,
'message': message,
}
def event_message(iden, event):
"""Return an event message."""
return {
'id': iden,
'type': TYPE_EVENT,
'event': event.as_dict(),
}
def error_message(iden, code, message):
"""Return an error result message."""
return {
'id': iden,
'type': TYPE_RESULT,
'success': False,
'error': {
'code': code,
'message': message,
},
}
def pong_message(iden):
"""Return a pong message."""
return {
'id': iden,
'type': TYPE_PONG,
}
def result_message(iden, result=None):
"""Return a success result message."""
return {
'id': iden,
'type': TYPE_RESULT,
'success': True,
'result': result,
}
@asyncio.coroutine
def async_setup(hass, config):
"""Initialize the websocket API."""
hass.http.register_view(WebsocketAPIView)
return True
class WebsocketAPIView(HomeAssistantView):
"""View to serve a websockets endpoint."""
name = "websocketapi"
url = URL
requires_auth = False
@asyncio.coroutine
def get(self, request):
"""Handle an incoming websocket connection."""
# pylint: disable=no-self-use
return ActiveConnection(request.app['hass']).handle(request)
class ActiveConnection:
"""Handle an active websocket client connection."""
def __init__(self, hass):
"""Initialize an active connection."""
self.hass = hass
self.wsock = None
self.event_listeners = {}
self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop)
self._handle_task = None
self._writer_task = None
def debug(self, message1, message2=''):
"""Print a debug message."""
_LOGGER.debug("WS %s: %s %s", id(self.wsock), message1, message2)
def log_error(self, message1, message2=''):
"""Print an error message."""
_LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2)
@asyncio.coroutine
def _writer(self):
"""Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler
with suppress(RuntimeError, asyncio.CancelledError):
while not self.wsock.closed:
message = yield from self.to_write.get()
if message is None:
break
self.debug("Sending", message)
yield from self.wsock.send_json(message, dumps=JSON_DUMP)
@callback
def send_message_outside(self, message):
"""Send a message to the client outside of the main task.
Closes connection if the client is not reading the messages.
Async friendly.
"""
try:
self.to_write.put_nowait(message)
except asyncio.QueueFull:
self.log_error("Client exceeded max pending messages [2]:",
MAX_PENDING_MSG)
self.cancel()
@callback
def cancel(self):
"""Cancel the connection."""
self._handle_task.cancel()
self._writer_task.cancel()
@asyncio.coroutine
def handle(self, request):
"""Handle the websocket connection."""
wsock = self.wsock = web.WebSocketResponse()
yield from wsock.prepare(request)
self.debug("Connected")
# Get a reference to current task so we can cancel our connection
self._handle_task = asyncio.Task.current_task(loop=self.hass.loop)
@callback
def handle_hass_stop(event):
"""Cancel this connection."""
self.cancel()
unsub_stop = self.hass.bus.async_listen(
EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
self._writer_task = self.hass.async_add_job(self._writer())
final_message = None
msg = None
authenticated = False
try:
if request[KEY_AUTHENTICATED]:
authenticated = True
else:
yield from self.wsock.send_json(auth_required_message())
msg = yield from wsock.receive_json()
msg = AUTH_MESSAGE_SCHEMA(msg)
if validate_password(request, msg['api_password']):
authenticated = True
else:
self.debug("Invalid password")
yield from self.wsock.send_json(
auth_invalid_message('Invalid password'))
if not authenticated:
yield from process_wrong_login(request)
return wsock
yield from self.wsock.send_json(auth_ok_message())
# ---------- AUTH PHASE OVER ----------
msg = yield from wsock.receive_json()
last_id = 0
while msg:
self.debug("Received", msg)
msg = BASE_COMMAND_MESSAGE_SCHEMA(msg)
cur_id = msg['id']
if cur_id <= last_id:
self.to_write.put_nowait(error_message(
cur_id, ERR_ID_REUSE,
'Identifier values have to increase.'))
else:
handler_name = 'handle_{}'.format(msg['type'])
getattr(self, handler_name)(msg)
last_id = cur_id
msg = yield from wsock.receive_json()
except vol.Invalid as err:
error_msg = "Message incorrectly formatted: "
if msg:
error_msg += humanize_error(msg, err)
else:
error_msg += str(err)
self.log_error(error_msg)
if not authenticated:
final_message = auth_invalid_message(error_msg)
else:
if isinstance(msg, dict):
iden = msg.get('id')
else:
iden = None
final_message = error_message(
iden, ERR_INVALID_FORMAT, error_msg)
except TypeError as err:
if wsock.closed:
self.debug("Connection closed by client")
else:
self.log_error("Unexpected TypeError", msg)
except ValueError as err:
msg = "Received invalid JSON"
value = getattr(err, 'doc', None) # Py3.5+ only
if value:
msg += ': {}'.format(value)
self.log_error(msg)
self._writer_task.cancel()
except asyncio.CancelledError:
self.debug("Connection cancelled by server")
except asyncio.QueueFull:
self.log_error("Client exceeded max pending messages [1]:",
MAX_PENDING_MSG)
self._writer_task.cancel()
except Exception: # pylint: disable=broad-except
error = "Unexpected error inside websocket API. "
if msg is not None:
error += str(msg)
_LOGGER.exception(error)
finally:
unsub_stop()
for unsub in self.event_listeners.values():
unsub()
try:
if final_message is not None:
self.to_write.put_nowait(final_message)
self.to_write.put_nowait(None)
# Make sure all error messages are written before closing
yield from self._writer_task
except asyncio.QueueFull:
self._writer_task.cancel()
yield from wsock.close()
self.debug("Closed connection")
return wsock
def handle_subscribe_events(self, msg):
"""Handle subscribe events command.
Async friendly.
"""
msg = SUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg)
@asyncio.coroutine
def forward_events(event):
"""Forward events to websocket."""
if event.event_type == EVENT_TIME_CHANGED:
return
self.send_message_outside(event_message(msg['id'], event))
self.event_listeners[msg['id']] = self.hass.bus.async_listen(
msg['event_type'], forward_events)
self.to_write.put_nowait(result_message(msg['id']))
def handle_unsubscribe_events(self, msg):
"""Handle unsubscribe events command.
Async friendly.
"""
msg = UNSUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg)
subscription = msg['subscription']
if subscription in self.event_listeners:
self.event_listeners.pop(subscription)()
self.to_write.put_nowait(result_message(msg['id']))
else:
self.to_write.put_nowait(error_message(
msg['id'], ERR_NOT_FOUND,
'Subscription not found.'))
def handle_call_service(self, msg):
"""Handle call service command.
This is a coroutine.
"""
msg = CALL_SERVICE_MESSAGE_SCHEMA(msg)
@asyncio.coroutine
def call_service_helper(msg):
"""Call a service and fire complete message."""
yield from self.hass.services.async_call(
msg['domain'], msg['service'], msg['service_data'], True)
self.send_message_outside(result_message(msg['id']))
self.hass.async_add_job(call_service_helper(msg))
def handle_get_states(self, msg):
"""Handle get states command.
Async friendly.
"""
msg = GET_STATES_MESSAGE_SCHEMA(msg)
self.to_write.put_nowait(result_message(
msg['id'], self.hass.states.async_all()))
def handle_get_services(self, msg):
"""Handle get services command.
Async friendly.
"""
msg = GET_SERVICES_MESSAGE_SCHEMA(msg)
self.to_write.put_nowait(result_message(
msg['id'], self.hass.services.async_services()))
def handle_get_config(self, msg):
"""Handle get config command.
Async friendly.
"""
msg = GET_CONFIG_MESSAGE_SCHEMA(msg)
self.to_write.put_nowait(result_message(
msg['id'], self.hass.config.as_dict()))
def handle_get_panels(self, msg):
"""Handle get panels command.
Async friendly.
"""
msg = GET_PANELS_MESSAGE_SCHEMA(msg)
self.to_write.put_nowait(result_message(
msg['id'], self.hass.data[frontend.DATA_PANELS]))
def handle_ping(self, msg):
"""Handle ping command.
Async friendly.
"""
self.to_write.put_nowait(pong_message(msg['id']))