core/homeassistant/components/websocket_api.py

402 lines
12 KiB
Python
Raw Normal View History

"""Websocket based API for Home Assistant."""
import asyncio
from functools import partial
import json
import logging
from aiohttp import web
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.const import (
MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
__version__)
from homeassistant.components import api, frontend
from homeassistant.core import callback
from homeassistant.remote import JSONEncoder
from homeassistant.helpers import config_validation as cv
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.auth import validate_password
from homeassistant.components.http.const import KEY_AUTHENTICATED
# Component identity and the HTTP path the websocket endpoint is served on.
DOMAIN = 'websocket_api'
URL = "/api/websocket"
DEPENDENCIES = ('http',)

# Values placed in the ``error.code`` field of failed result messages.
ERR_ID_REUSE = 1
ERR_INVALID_FORMAT = 2
ERR_NOT_FOUND = 3

# Message ``type`` values exchanged during the authentication handshake.
TYPE_AUTH = 'auth'
TYPE_AUTH_OK = 'auth_ok'
TYPE_AUTH_REQUIRED = 'auth_required'
TYPE_AUTH_INVALID = 'auth_invalid'

# Message ``type`` values for pushed events, client commands and results.
TYPE_EVENT = 'event'
TYPE_SUBSCRIBE_EVENTS = 'subscribe_events'
TYPE_UNSUBSCRIBE_EVENTS = 'unsubscribe_events'
TYPE_CALL_SERVICE = 'call_service'
TYPE_GET_STATES = 'get_states'
TYPE_GET_SERVICES = 'get_services'
TYPE_GET_CONFIG = 'get_config'
TYPE_GET_PANELS = 'get_panels'
TYPE_RESULT = 'result'

_LOGGER = logging.getLogger(__name__)
# JSON serializer used for every outgoing websocket message.
JSON_DUMP = partial(json.dumps, cls=JSONEncoder)

# First message a client must send when the request is not already
# authenticated.
AUTH_MESSAGE_SCHEMA = vol.Schema({
    vol.Required('type'): TYPE_AUTH,
    vol.Required('api_password'): str,
})

# Per-command schemas. Each handler re-validates its message against the
# matching schema after the base schema below has accepted it.
SUBSCRIBE_EVENTS_MESSAGE_SCHEMA = vol.Schema({
    vol.Required('id'): cv.positive_int,
    vol.Required('type'): TYPE_SUBSCRIBE_EVENTS,
    vol.Optional('event_type', default=MATCH_ALL): str,
})

UNSUBSCRIBE_EVENTS_MESSAGE_SCHEMA = vol.Schema({
    vol.Required('id'): cv.positive_int,
    vol.Required('type'): TYPE_UNSUBSCRIBE_EVENTS,
    vol.Required('subscription'): cv.positive_int,
})

# NOTE(review): some voluptuous releases validate default values as well;
# confirm that ``default=None`` passes the ``dict`` validator on the
# version this project pins.
CALL_SERVICE_MESSAGE_SCHEMA = vol.Schema({
    vol.Required('id'): cv.positive_int,
    vol.Required('type'): TYPE_CALL_SERVICE,
    vol.Required('domain'): str,
    vol.Required('service'): str,
    vol.Optional('service_data', default=None): dict,
})

GET_STATES_MESSAGE_SCHEMA = vol.Schema({
    vol.Required('id'): cv.positive_int,
    vol.Required('type'): TYPE_GET_STATES,
})

GET_SERVICES_MESSAGE_SCHEMA = vol.Schema({
    vol.Required('id'): cv.positive_int,
    vol.Required('type'): TYPE_GET_SERVICES,
})

GET_CONFIG_MESSAGE_SCHEMA = vol.Schema({
    vol.Required('id'): cv.positive_int,
    vol.Required('type'): TYPE_GET_CONFIG,
})

GET_PANELS_MESSAGE_SCHEMA = vol.Schema({
    vol.Required('id'): cv.positive_int,
    vol.Required('type'): TYPE_GET_PANELS,
})

# Envelope check applied to every command before dispatch: only ``id`` and a
# known ``type`` are verified here, command-specific keys are let through
# (ALLOW_EXTRA).
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({
    vol.Required('id'): cv.positive_int,
    vol.Required('type'): vol.Any(
        TYPE_CALL_SERVICE,
        TYPE_SUBSCRIBE_EVENTS,
        TYPE_UNSUBSCRIBE_EVENTS,
        TYPE_GET_STATES,
        TYPE_GET_SERVICES,
        TYPE_GET_CONFIG,
        TYPE_GET_PANELS,
    ),
}, extra=vol.ALLOW_EXTRA)
def auth_ok_message():
    """Build the message that tells a client authentication succeeded.

    The message carries the running Home Assistant version.
    """
    return dict(type=TYPE_AUTH_OK, ha_version=__version__)
def auth_required_message():
    """Build the message that asks a client to authenticate.

    The message carries the running Home Assistant version.
    """
    return dict(type=TYPE_AUTH_REQUIRED, ha_version=__version__)
def auth_invalid_message(message):
    """Build the message that tells a client authentication failed.

    message: human readable reason included in the payload.
    """
    payload = {'type': TYPE_AUTH_INVALID}
    payload['message'] = message
    return payload
def event_message(iden, event):
    """Build the message that forwards a bus event to a subscriber.

    iden: id the message is sent under.
    event: object whose ``as_dict()`` result becomes the ``event`` field.
    """
    payload = {'id': iden, 'type': TYPE_EVENT}
    payload['event'] = event.as_dict()
    return payload
def error_message(iden, code, message):
    """Build a failed result message.

    iden: id of the command the result belongs to (may be None).
    code: one of the ERR_* constants.
    message: human readable description of the failure.
    """
    details = {'code': code, 'message': message}
    return {
        'id': iden,
        'type': TYPE_RESULT,
        'success': False,
        'error': details,
    }
def result_message(iden, result=None):
    """Build a successful result message.

    iden: id of the command the result belongs to.
    result: optional payload placed in the ``result`` field.
    """
    payload = dict(id=iden, type=TYPE_RESULT, success=True)
    payload['result'] = result
    return payload
@asyncio.coroutine
def async_setup(hass, config):
    """Set up the websocket API component.

    Registers WebsocketAPIView on ``hass.http``. ``config`` is not used.
    Always returns True.
    """
    view = WebsocketAPIView
    hass.http.register_view(view)
    return True
class WebsocketAPIView(HomeAssistantView):
    """HTTP view that serves the websocket endpoint at ``URL``."""

    name = "websocketapi"
    url = URL
    # The view itself is open; ActiveConnection.handle() runs its own
    # authentication handshake over the socket.
    requires_auth = False

    @asyncio.coroutine
    def get(self, request):
        """Hand an incoming request over to a new ActiveConnection.

        Returns the result of ``ActiveConnection.handle()``.
        """
        # pylint: disable=no-self-use
        hass = request.app['hass']
        connection = ActiveConnection(hass, request)
        return connection.handle()
class ActiveConnection:
    """Handle an active websocket client connection.

    handle() runs the whole session: it prepares the websocket, performs the
    authentication handshake and then dispatches every command message to
    the ``handle_<type>`` method named after the message's ``type`` field.
    """

    def __init__(self, hass, request):
        """Initialize an active connection.

        hass: the Home Assistant instance the connection operates on.
        request: the HTTP request that is upgraded to a websocket.
        """
        self.hass = hass
        self.request = request
        # Created by handle() once the connection is being served.
        self.wsock = None
        # Task running handle(); cancelled when Home Assistant stops.
        self.socket_task = None
        # Maps the id of a subscribe command to the callable that removes
        # the corresponding bus listener.
        self.event_listeners = {}

    def debug(self, message1, message2=''):
        """Log a debug message tagged with the identity of this socket."""
        _LOGGER.debug('WS %s: %s %s', id(self.wsock), message1, message2)

    def log_error(self, message1, message2=''):
        """Log an error message tagged with the identity of this socket."""
        _LOGGER.error('WS %s: %s %s', id(self.wsock), message1, message2)

    def send_message(self, message):
        """Serialize ``message`` with JSON_DUMP and send it to the client.

        Callers in this class treat a RuntimeError raised here as "the
        socket has already been closed".
        """
        self.debug('Sending', message)
        self.wsock.send_json(message, dumps=JSON_DUMP)

    @callback
    def _cancel_connection(self, event):
        """Cancel the task serving this connection.

        Registered as listener for EVENT_HOMEASSISTANT_STOP; ``event`` is
        not inspected.
        """
        self.socket_task.cancel()

    @asyncio.coroutine
    def _call_service_helper(self, msg):
        """Call the requested service and report completion to the client.

        msg: a message validated by CALL_SERVICE_MESSAGE_SCHEMA.
        """
        yield from self.hass.services.async_call(msg['domain'], msg['service'],
                                                 msg['service_data'], True)
        try:
            self.send_message(result_message(msg['id']))
        except RuntimeError:
            # Socket has been closed.
            pass

    @callback
    def _forward_event(self, iden, event):
        """Forward a bus event to the client under subscription ``iden``.

        EVENT_TIME_CHANGED events are never forwarded.
        """
        if event.event_type == EVENT_TIME_CHANGED:
            return

        try:
            self.send_message(event_message(iden, event))
        except RuntimeError:
            # Socket has been closed.
            pass

    @asyncio.coroutine
    def handle(self):
        """Serve the websocket connection until it ends.

        Authenticates the client (unless the request is already marked as
        authenticated), then reads and dispatches command messages until a
        falsy message is received or an error ends the session. Every bus
        listener registered for this connection is removed and the socket
        is closed before returning.

        Returns the WebSocketResponse used for the connection.
        """
        wsock = self.wsock = web.WebSocketResponse()
        yield from wsock.prepare(self.request)

        # Set up to cancel this connection when Home Assistant shuts down.
        # Keep the unsubscribe callable so the listener does not outlive
        # the connection (it would otherwise keep this object alive on the
        # bus until shutdown).
        self.socket_task = asyncio.Task.current_task(loop=self.hass.loop)
        unsub_stop = self.hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP,
                                                self._cancel_connection)
        self.debug('Connected')

        msg = None
        authenticated = False

        try:
            if self.request[KEY_AUTHENTICATED]:
                authenticated = True

            else:
                self.send_message(auth_required_message())
                msg = yield from wsock.receive_json()
                msg = AUTH_MESSAGE_SCHEMA(msg)

                if validate_password(self.request, msg['api_password']):
                    authenticated = True

                else:
                    self.debug('Invalid password')
                    self.send_message(auth_invalid_message('Invalid password'))
                    return wsock

            if not authenticated:
                return wsock

            self.send_message(auth_ok_message())

            msg = yield from wsock.receive_json()

            last_id = 0

            while msg:
                self.debug('Received', msg)
                msg = BASE_COMMAND_MESSAGE_SCHEMA(msg)
                cur_id = msg['id']

                if cur_id <= last_id:
                    self.send_message(error_message(
                        cur_id, ERR_ID_REUSE,
                        'Identifier values have to increase.'))

                else:
                    handler_name = 'handle_{}'.format(msg['type'])
                    getattr(self, handler_name)(msg)
                    # Only accepted messages advance the watermark, so a
                    # rejected id can never lower it and allow an id (and
                    # its subscription slot) to be reused.
                    last_id = cur_id

                msg = yield from wsock.receive_json()

        except vol.Invalid as err:
            error_msg = 'Message incorrectly formatted: '
            if msg:
                error_msg += humanize_error(msg, err)
            else:
                error_msg += str(err)

            self.log_error(error_msg)

            if not authenticated:
                self.send_message(auth_invalid_message(error_msg))

            else:
                if isinstance(msg, dict):
                    iden = msg.get('id')
                else:
                    iden = None

                self.send_message(error_message(iden, ERR_INVALID_FORMAT,
                                                error_msg))

        except TypeError:
            # A TypeError on a closed socket is treated as the client
            # having disconnected; anything else is unexpected.
            if wsock.closed:
                self.debug('Connection closed by client')
            else:
                self.log_error('Unexpected TypeError', msg)

        except ValueError as err:
            error_msg = 'Received invalid JSON'
            value = getattr(err, 'doc', None)  # Py3.5+ only
            if value:
                error_msg += ': {}'.format(value)
            self.log_error(error_msg)

        except asyncio.CancelledError:
            self.debug('Connection cancelled by server')

        except Exception:  # pylint: disable=broad-except
            error = 'Unexpected error inside websocket API. '
            if msg is not None:
                error += str(msg)
            _LOGGER.exception(error)

        finally:
            # Drop the shutdown listener first, then every event
            # subscription the client created.
            unsub_stop()

            for unsub in self.event_listeners.values():
                unsub()

            yield from wsock.close()
            self.debug('Closed connection')

        return wsock

    def handle_subscribe_events(self, msg):
        """Handle subscribe events command.

        Registers a bus listener for the requested event type and stores its
        unsubscribe callable under the id of this command.
        """
        msg = SUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg)

        self.event_listeners[msg['id']] = self.hass.bus.async_listen(
            msg['event_type'], partial(self._forward_event, msg['id']))

        self.send_message(result_message(msg['id']))

    def handle_unsubscribe_events(self, msg):
        """Handle unsubscribe events command.

        Replies with ERR_NOT_FOUND when ``subscription`` does not name an
        active subscription of this connection.
        """
        msg = UNSUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg)

        subscription = msg['subscription']

        if subscription not in self.event_listeners:
            self.send_message(error_message(
                msg['id'], ERR_NOT_FOUND,
                'Subscription not found.'))
        else:
            # Remove the entry and call the stored unsubscribe callable.
            self.event_listeners.pop(subscription)()
            self.send_message(result_message(msg['id']))

    def handle_call_service(self, msg):
        """Handle call service command.

        The call is scheduled as a job; the result message is sent by
        _call_service_helper once the service call has finished.
        """
        msg = CALL_SERVICE_MESSAGE_SCHEMA(msg)

        self.hass.async_add_job(self._call_service_helper(msg))

    def handle_get_states(self, msg):
        """Handle get states command."""
        msg = GET_STATES_MESSAGE_SCHEMA(msg)

        self.send_message(result_message(msg['id'],
                                         self.hass.states.async_all()))

    def handle_get_services(self, msg):
        """Handle get services command."""
        msg = GET_SERVICES_MESSAGE_SCHEMA(msg)

        self.send_message(result_message(msg['id'],
                                         api.async_services_json(self.hass)))

    def handle_get_config(self, msg):
        """Handle get config command."""
        msg = GET_CONFIG_MESSAGE_SCHEMA(msg)

        self.send_message(result_message(msg['id'],
                                         self.hass.config.as_dict()))

    def handle_get_panels(self, msg):
        """Handle get panels command."""
        msg = GET_PANELS_MESSAGE_SCHEMA(msg)

        self.send_message(result_message(
            msg['id'], self.hass.data[frontend.DATA_PANELS]))