402 lines
12 KiB
Python
402 lines
12 KiB
Python
|
"""Websocket based API for Home Assistant."""
|
||
|
import asyncio
|
||
|
from functools import partial
|
||
|
import json
|
||
|
import logging
|
||
|
|
||
|
from aiohttp import web
|
||
|
import voluptuous as vol
|
||
|
from voluptuous.humanize import humanize_error
|
||
|
|
||
|
from homeassistant.const import (
|
||
|
MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
|
||
|
__version__)
|
||
|
from homeassistant.components import api, frontend
|
||
|
from homeassistant.core import callback
|
||
|
from homeassistant.remote import JSONEncoder
|
||
|
from homeassistant.helpers import config_validation as cv
|
||
|
from homeassistant.components.http import HomeAssistantView
|
||
|
from homeassistant.components.http.auth import validate_password
|
||
|
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
||
|
|
||
|
# Component identity and the HTTP route the websocket endpoint is served on.
DOMAIN = 'websocket_api'
DEPENDENCIES = ('http',)
URL = "/api/websocket"

# Codes placed in ``error.code`` of unsuccessful result messages.
ERR_ID_REUSE = 1  # command id was not greater than the previous accepted one
ERR_INVALID_FORMAT = 2  # message failed schema validation
ERR_NOT_FOUND = 3  # unsubscribe referenced an unknown subscription

# Message types of the authentication handshake.
TYPE_AUTH = 'auth'
TYPE_AUTH_INVALID = 'auth_invalid'
TYPE_AUTH_OK = 'auth_ok'
TYPE_AUTH_REQUIRED = 'auth_required'

# Command types a client may send; each is dispatched to a method named
# ``handle_<type>`` on the connection object.
TYPE_CALL_SERVICE = 'call_service'
TYPE_GET_CONFIG = 'get_config'
TYPE_GET_PANELS = 'get_panels'
TYPE_GET_SERVICES = 'get_services'
TYPE_GET_STATES = 'get_states'
TYPE_SUBSCRIBE_EVENTS = 'subscribe_events'
TYPE_UNSUBSCRIBE_EVENTS = 'unsubscribe_events'

# Message types emitted by the server after authentication.
TYPE_EVENT = 'event'
TYPE_RESULT = 'result'
|
||
|
|
||
|
# json.dumps pre-bound to Home Assistant's JSONEncoder. It is passed as the
# ``dumps`` callable for every outgoing message; the custom encoder is
# presumably what lets non-builtin objects (e.g. the states returned for
# get_states) be serialized — TODO confirm against homeassistant.remote.
JSON_DUMP = partial(json.dumps, cls=JSONEncoder)

# Module logger used for per-connection debug and error output.
_LOGGER = logging.getLogger(__name__)
|
||
|
|
||
|
# Message a client must send when its HTTP request was not already
# authenticated: carries the API password to validate.
AUTH_MESSAGE_SCHEMA = vol.Schema({
    vol.Required('type'): TYPE_AUTH,
    vol.Required('api_password'): str,
})


def _command_schema(message_type, extra_fields=None):
    """Build the schema for one client command type.

    Every command carries a positive integer ``id`` and a ``type`` equal to
    *message_type*; *extra_fields* (a mapping of voluptuous markers to
    validators) adds the command-specific keys after those two.
    """
    fields = {
        vol.Required('id'): cv.positive_int,
        vol.Required('type'): message_type,
    }
    if extra_fields:
        fields.update(extra_fields)
    return vol.Schema(fields)


SUBSCRIBE_EVENTS_MESSAGE_SCHEMA = _command_schema(TYPE_SUBSCRIBE_EVENTS, {
    vol.Optional('event_type', default=MATCH_ALL): str,
})

UNSUBSCRIBE_EVENTS_MESSAGE_SCHEMA = _command_schema(TYPE_UNSUBSCRIBE_EVENTS, {
    vol.Required('subscription'): cv.positive_int,
})

CALL_SERVICE_MESSAGE_SCHEMA = _command_schema(TYPE_CALL_SERVICE, {
    vol.Required('domain'): str,
    vol.Required('service'): str,
    vol.Optional('service_data', default=None): dict,
})

GET_STATES_MESSAGE_SCHEMA = _command_schema(TYPE_GET_STATES)
GET_SERVICES_MESSAGE_SCHEMA = _command_schema(TYPE_GET_SERVICES)
GET_CONFIG_MESSAGE_SCHEMA = _command_schema(TYPE_GET_CONFIG)
GET_PANELS_MESSAGE_SCHEMA = _command_schema(TYPE_GET_PANELS)

# First-pass check applied to every command: only ``id`` and a known
# ``type`` are enforced; other keys are allowed through so the per-command
# schema above can validate them once the command is dispatched.
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({
    vol.Required('id'): cv.positive_int,
    vol.Required('type'): vol.Any(
        TYPE_CALL_SERVICE, TYPE_SUBSCRIBE_EVENTS, TYPE_UNSUBSCRIBE_EVENTS,
        TYPE_GET_STATES, TYPE_GET_SERVICES, TYPE_GET_CONFIG, TYPE_GET_PANELS),
}, extra=vol.ALLOW_EXTRA)
|
||
|
|
||
|
|
||
|
def auth_ok_message():
    """Build the message telling the client it is authenticated.

    The running Home Assistant version is reported under ``ha_version``.
    """
    return dict(type=TYPE_AUTH_OK, ha_version=__version__)
|
||
|
|
||
|
|
||
|
def auth_required_message():
    """Build the message asking the client to send its credentials.

    The running Home Assistant version is reported under ``ha_version``.
    """
    return dict(type=TYPE_AUTH_REQUIRED, ha_version=__version__)
|
||
|
|
||
|
|
||
|
def auth_invalid_message(message):
    """Build the message rejecting authentication, with *message* as reason."""
    return dict(type=TYPE_AUTH_INVALID, message=message)
|
||
|
|
||
|
|
||
|
def event_message(iden, event):
    """Build a message forwarding *event* for subscription *iden*.

    The event is serialized through its ``as_dict()`` method.
    """
    payload = event.as_dict()
    return {'id': iden, 'type': TYPE_EVENT, 'event': payload}
|
||
|
|
||
|
|
||
|
def error_message(iden, code, message):
    """Build an unsuccessful result message for command *iden*.

    *code* is one of the ERR_* constants and *message* a human-readable
    explanation; both are nested under the ``error`` key.
    """
    error = dict(code=code, message=message)
    return dict(id=iden, type=TYPE_RESULT, success=False, error=error)
|
||
|
|
||
|
|
||
|
def result_message(iden, result=None):
    """Build a successful result message for command *iden*.

    *result* is the command's payload; it defaults to None for commands
    that only acknowledge completion.
    """
    return dict(id=iden, type=TYPE_RESULT, success=True, result=result)
|
||
|
|
||
|
|
||
|
@asyncio.coroutine
def async_setup(hass, config):
    """Set up the websocket API component.

    Registers WebsocketAPIView with Home Assistant's HTTP server. *config*
    is not read. Always reports success.
    """
    http_server = hass.http
    http_server.register_view(WebsocketAPIView)
    return True
|
||
|
|
||
|
|
||
|
class WebsocketAPIView(HomeAssistantView):
    """HTTP view that upgrades requests on URL to websocket connections."""

    name = "websocketapi"
    url = URL
    # The view itself accepts unauthenticated requests: the connection
    # handler checks the request's authenticated flag and, when it is not
    # set, runs an api_password handshake over the socket instead.
    requires_auth = False

    @asyncio.coroutine
    def get(self, request):
        """Serve *request* with a new ActiveConnection and return its handler."""
        # pylint: disable=no-self-use
        hass = request.app['hass']
        connection = ActiveConnection(hass, request)
        return connection.handle()
|
||
|
|
||
|
|
||
|
class ActiveConnection:
    """Handle an active websocket client connection.

    One instance serves one client: it authenticates the client, then reads
    JSON commands in a loop and dispatches each to the matching
    ``handle_<type>`` method.
    """

    def __init__(self, hass, request):
        """Initialize an active connection.

        Args:
            hass: the Home Assistant instance the commands operate on.
            request: the incoming aiohttp request to upgrade.
        """
        self.hass = hass
        self.request = request
        # Populated by handle(): the websocket response and the task
        # running it (the latter is cancelled on Home Assistant stop).
        self.wsock = None
        self.socket_task = None
        # Command id of each subscribe_events request -> the callable that
        # removes that bus listener (called on unsubscribe and on close).
        self.event_listeners = {}

    def debug(self, message1, message2=''):
        """Log a debug message tagged with this connection's socket id."""
        _LOGGER.debug('WS %s: %s %s', id(self.wsock), message1, message2)

    def log_error(self, message1, message2=''):
        """Log an error message tagged with this connection's socket id."""
        _LOGGER.error('WS %s: %s %s', id(self.wsock), message1, message2)

    def send_message(self, message):
        """Serialize *message* with JSON_DUMP and send it to the client.

        Errors from the socket propagate; callers that may run after the
        socket has closed catch RuntimeError themselves.
        """
        self.debug('Sending', message)
        self.wsock.send_json(message, dumps=JSON_DUMP)

    @callback
    def _cancel_connection(self, event):
        """Cancel the task serving this connection (Home Assistant stop)."""
        self.socket_task.cancel()

    @asyncio.coroutine
    def _call_service_helper(self, msg):
        """Call the requested service, then send a success result.

        The fourth positional argument ``True`` is presumably ``blocking``,
        so the result is only sent once the service call finishes — TODO
        confirm against hass.services.async_call.
        """
        yield from self.hass.services.async_call(msg['domain'], msg['service'],
                                                 msg['service_data'], True)
        try:
            self.send_message(result_message(msg['id']))
        except RuntimeError:
            # Socket has been closed.
            pass

    @callback
    def _forward_event(self, iden, event):
        """Forward *event* to the client as part of subscription *iden*.

        time_changed events are never forwarded.
        """
        if event.event_type == EVENT_TIME_CHANGED:
            return

        try:
            self.send_message(event_message(iden, event))
        except RuntimeError:
            # Socket has been closed.
            pass

    @asyncio.coroutine
    def handle(self):
        """Handle the websocket connection.

        Runs the authentication handshake (unless the HTTP request is
        already authenticated), then processes commands until the client
        sends a falsy message, an error ends the loop, or the task is
        cancelled at Home Assistant stop. Event subscriptions and the stop
        listener are removed before the socket is closed.

        Returns:
            The WebSocketResponse used for this connection.
        """
        wsock = self.wsock = web.WebSocketResponse()
        yield from wsock.prepare(self.request)

        # Set up to cancel this connection when Home Assistant shuts down.
        # Keep the remover so the listener is dropped when this connection
        # ends instead of accumulating one stale listener per connection.
        self.socket_task = asyncio.Task.current_task(loop=self.hass.loop)
        unsub_stop = self.hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP,
                                                self._cancel_connection)

        self.debug('Connected')

        msg = None
        authenticated = False

        try:
            if self.request[KEY_AUTHENTICATED]:
                authenticated = True

            else:
                self.send_message(auth_required_message())
                msg = yield from wsock.receive_json()
                msg = AUTH_MESSAGE_SCHEMA(msg)

                if validate_password(self.request, msg['api_password']):
                    authenticated = True

                else:
                    self.debug('Invalid password')
                    self.send_message(auth_invalid_message('Invalid password'))
                    return wsock

            if not authenticated:
                return wsock

            self.send_message(auth_ok_message())

            msg = yield from wsock.receive_json()

            last_id = 0

            while msg:
                self.debug('Received', msg)
                msg = BASE_COMMAND_MESSAGE_SCHEMA(msg)
                cur_id = msg['id']

                if cur_id <= last_id:
                    self.send_message(error_message(
                        cur_id, ERR_ID_REUSE,
                        'Identifier values have to increase.'))

                else:
                    handler_name = 'handle_{}'.format(msg['type'])
                    getattr(self, handler_name)(msg)
                    # Only accepted ids advance the high-water mark. Moving
                    # it down on a rejected id would let the client reuse an
                    # earlier id afterwards and overwrite (leak) an existing
                    # event subscription stored under that id.
                    last_id = cur_id

                msg = yield from wsock.receive_json()

        except vol.Invalid as err:
            error_msg = 'Message incorrectly formatted: '
            if msg:
                error_msg += humanize_error(msg, err)
            else:
                error_msg += str(err)

            self.log_error(error_msg)

            if not authenticated:
                self.send_message(auth_invalid_message(error_msg))

            else:
                if isinstance(msg, dict):
                    iden = msg.get('id')
                else:
                    iden = None

                self.send_message(error_message(iden, ERR_INVALID_FORMAT,
                                                error_msg))

        except TypeError:
            if wsock.closed:
                self.debug('Connection closed by client')
            else:
                self.log_error('Unexpected TypeError', msg)

        except ValueError as err:
            error_msg = 'Received invalid JSON'
            value = getattr(err, 'doc', None)  # Py3.5+ only
            if value:
                error_msg += ': {}'.format(value)
            self.log_error(error_msg)

        except asyncio.CancelledError:
            self.debug('Connection cancelled by server')

        except Exception:  # pylint: disable=broad-except
            error = 'Unexpected error inside websocket API. '
            if msg is not None:
                error += str(msg)
            _LOGGER.exception(error)

        finally:
            # Remove the stop listener first so it never outlives us, even
            # if closing the socket below raises.
            unsub_stop()

            for unsub in self.event_listeners.values():
                unsub()

            yield from wsock.close()
            self.debug('Closed connection')

        return wsock

    def handle_subscribe_events(self, msg):
        """Handle subscribe events command.

        Listens on the bus for ``event_type`` (MATCH_ALL by default) and
        forwards matching events tagged with this command's id, then sends
        a success result.
        """
        msg = SUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg)

        self.event_listeners[msg['id']] = self.hass.bus.async_listen(
            msg['event_type'], partial(self._forward_event, msg['id']))

        self.send_message(result_message(msg['id']))

    def handle_unsubscribe_events(self, msg):
        """Handle unsubscribe events command.

        Removes the listener created by the subscribe command whose id is
        ``subscription``; replies ERR_NOT_FOUND if there is none.
        """
        msg = UNSUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg)

        subscription = msg['subscription']

        if subscription not in self.event_listeners:
            self.send_message(error_message(
                msg['id'], ERR_NOT_FOUND,
                'Subscription not found.'))
        else:
            self.event_listeners.pop(subscription)()
            self.send_message(result_message(msg['id']))

    def handle_call_service(self, msg):
        """Handle call service command.

        The call is scheduled as a separate job; its result message is sent
        by _call_service_helper when the call completes.
        """
        msg = CALL_SERVICE_MESSAGE_SCHEMA(msg)

        self.hass.async_add_job(self._call_service_helper(msg))

    def handle_get_states(self, msg):
        """Handle get states command: reply with all current states."""
        msg = GET_STATES_MESSAGE_SCHEMA(msg)

        self.send_message(result_message(msg['id'],
                                         self.hass.states.async_all()))

    def handle_get_services(self, msg):
        """Handle get services command: reply with the services listing."""
        msg = GET_SERVICES_MESSAGE_SCHEMA(msg)

        self.send_message(result_message(msg['id'],
                                         api.async_services_json(self.hass)))

    def handle_get_config(self, msg):
        """Handle get config command: reply with hass.config.as_dict()."""
        msg = GET_CONFIG_MESSAGE_SCHEMA(msg)

        self.send_message(result_message(msg['id'],
                                         self.hass.config.as_dict()))

    def handle_get_panels(self, msg):
        """Handle get panels command: reply with the registered panels."""
        msg = GET_PANELS_MESSAGE_SCHEMA(msg)

        self.send_message(result_message(
            msg['id'], self.hass.data[frontend.DATA_PANELS]))
|