""" Websocket based API for Home Assistant. For more details about this component, please refer to the documentation at https://home-assistant.io/developers/websocket_api/ """ import asyncio from concurrent import futures from contextlib import suppress from functools import partial, wraps import json import logging from aiohttp import web import voluptuous as vol from voluptuous.humanize import humanize_error from homeassistant.const import ( MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP, __version__) from homeassistant.core import Context, callback from homeassistant.loader import bind_hass from homeassistant.remote import JSONEncoder from homeassistant.helpers import config_validation as cv from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.components.http import HomeAssistantView from homeassistant.components.http.auth import validate_password from homeassistant.components.http.const import KEY_AUTHENTICATED from homeassistant.components.http.ban import process_wrong_login, \ process_success_login DOMAIN = 'websocket_api' URL = '/api/websocket' DEPENDENCIES = ('http',) MAX_PENDING_MSG = 512 ERR_ID_REUSE = 1 ERR_INVALID_FORMAT = 2 ERR_NOT_FOUND = 3 ERR_UNKNOWN_COMMAND = 4 TYPE_AUTH = 'auth' TYPE_AUTH_INVALID = 'auth_invalid' TYPE_AUTH_OK = 'auth_ok' TYPE_AUTH_REQUIRED = 'auth_required' TYPE_CALL_SERVICE = 'call_service' TYPE_EVENT = 'event' TYPE_GET_CONFIG = 'get_config' TYPE_GET_SERVICES = 'get_services' TYPE_GET_STATES = 'get_states' TYPE_PING = 'ping' TYPE_PONG = 'pong' TYPE_RESULT = 'result' TYPE_SUBSCRIBE_EVENTS = 'subscribe_events' TYPE_UNSUBSCRIBE_EVENTS = 'unsubscribe_events' _LOGGER = logging.getLogger(__name__) JSON_DUMP = partial(json.dumps, cls=JSONEncoder) AUTH_MESSAGE_SCHEMA = vol.Schema({ vol.Required('type'): TYPE_AUTH, vol.Exclusive('api_password', 'auth'): str, vol.Exclusive('access_token', 'auth'): str, }) # Minimal requirements of a message MINIMAL_MESSAGE_SCHEMA = vol.Schema({ vol.Required('id'): cv.positive_int, vol.Required('type'): cv.string, }, extra=vol.ALLOW_EXTRA) # Base schema to extend by message handlers BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({ vol.Required('id'): cv.positive_int, }) SCHEMA_SUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): TYPE_SUBSCRIBE_EVENTS, vol.Optional('event_type', default=MATCH_ALL): str, }) SCHEMA_UNSUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): TYPE_UNSUBSCRIBE_EVENTS, vol.Required('subscription'): cv.positive_int, }) SCHEMA_CALL_SERVICE = BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): TYPE_CALL_SERVICE, vol.Required('domain'): str, vol.Required('service'): str, vol.Optional('service_data'): dict }) SCHEMA_GET_STATES = BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): TYPE_GET_STATES, }) SCHEMA_GET_SERVICES = BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): TYPE_GET_SERVICES, }) SCHEMA_GET_CONFIG = BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): TYPE_GET_CONFIG, }) SCHEMA_PING = BASE_COMMAND_MESSAGE_SCHEMA.extend({ vol.Required('type'): TYPE_PING, }) # Define the possible errors that occur when connections are cancelled. # Originally, this was just asyncio.CancelledError, but issue #9546 showed # that futures.CancelledErrors can also occur in some situations. CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError) def auth_ok_message(): """Return an auth_ok message.""" return { 'type': TYPE_AUTH_OK, 'ha_version': __version__, } def auth_required_message(): """Return an auth_required message.""" return { 'type': TYPE_AUTH_REQUIRED, 'ha_version': __version__, } def auth_invalid_message(message): """Return an auth_invalid message.""" return { 'type': TYPE_AUTH_INVALID, 'message': message, } def event_message(iden, event): """Return an event message.""" return { 'id': iden, 'type': TYPE_EVENT, 'event': event.as_dict(), } def error_message(iden, code, message): """Return an error result message.""" return { 'id': iden, 'type': TYPE_RESULT, 'success': False, 'error': { 'code': code, 'message': message, }, } def pong_message(iden): """Return a pong message.""" return { 'id': iden, 'type': TYPE_PONG, } def result_message(iden, result=None): """Return a success result message.""" return { 'id': iden, 'type': TYPE_RESULT, 'success': True, 'result': result, } @bind_hass @callback def async_register_command(hass, command, handler, schema): """Register a websocket command.""" handlers = hass.data.get(DOMAIN) if handlers is None: handlers = hass.data[DOMAIN] = {} handlers[command] = (handler, schema) def require_owner(func): """Websocket decorator to require user to be an owner.""" @wraps(func) def with_owner(hass, connection, msg): """Check owner and call function.""" user = connection.request.get('hass_user') if user is None or not user.is_owner: connection.to_write.put_nowait(error_message( msg['id'], 'unauthorized', 'This command is for owners only.')) return func(hass, connection, msg) return with_owner async def async_setup(hass, config): """Initialize the websocket API.""" hass.http.register_view(WebsocketAPIView) async_register_command(hass, TYPE_SUBSCRIBE_EVENTS, handle_subscribe_events, SCHEMA_SUBSCRIBE_EVENTS) async_register_command(hass, TYPE_UNSUBSCRIBE_EVENTS, handle_unsubscribe_events, SCHEMA_UNSUBSCRIBE_EVENTS) async_register_command(hass, TYPE_CALL_SERVICE, handle_call_service, SCHEMA_CALL_SERVICE) async_register_command(hass, TYPE_GET_STATES, handle_get_states, SCHEMA_GET_STATES) async_register_command(hass, TYPE_GET_SERVICES, handle_get_services, SCHEMA_GET_SERVICES) async_register_command(hass, TYPE_GET_CONFIG, handle_get_config, SCHEMA_GET_CONFIG) async_register_command(hass, TYPE_PING, handle_ping, SCHEMA_PING) return True class WebsocketAPIView(HomeAssistantView): """View to serve a websockets endpoint.""" name = "websocketapi" url = URL requires_auth = False async def get(self, request): """Handle an incoming websocket connection.""" return await ActiveConnection(request.app['hass'], request).handle() class ActiveConnection: """Handle an active websocket client connection.""" def __init__(self, hass, request): """Initialize an active connection.""" self.hass = hass self.request = request self.wsock = None self.event_listeners = {} self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop) self._handle_task = None self._writer_task = None @property def user(self): """Return the user associated with the connection.""" return self.request.get('hass_user') def context(self, msg): """Return a context.""" user = self.user if user is None: return Context() return Context(user_id=user.id) def debug(self, message1, message2=''): """Print a debug message.""" _LOGGER.debug("WS %s: %s %s", id(self.wsock), message1, message2) def log_error(self, message1, message2=''): """Print an error message.""" _LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2) async def _writer(self): """Write outgoing messages.""" # Exceptions if Socket disconnected or cancelled by connection handler with suppress(RuntimeError, *CANCELLATION_ERRORS): while not self.wsock.closed: message = await self.to_write.get() if message is None: break self.debug("Sending", message) try: await self.wsock.send_json(message, dumps=JSON_DUMP) except TypeError as err: _LOGGER.error('Unable to serialize to JSON: %s\n%s', err, message) @callback def send_message_outside(self, message): """Send a message to the client. Closes connection if the client is not reading the messages. Async friendly. """ try: self.to_write.put_nowait(message) except asyncio.QueueFull: self.log_error("Client exceeded max pending messages [2]:", MAX_PENDING_MSG) self.cancel() @callback def cancel(self): """Cancel the connection.""" self._handle_task.cancel() self._writer_task.cancel() async def handle(self): """Handle the websocket connection.""" request = self.request wsock = self.wsock = web.WebSocketResponse(heartbeat=55) await wsock.prepare(request) self.debug("Connected") # Get a reference to current task so we can cancel our connection self._handle_task = asyncio.Task.current_task(loop=self.hass.loop) @callback def handle_hass_stop(event): """Cancel this connection.""" self.cancel() unsub_stop = self.hass.bus.async_listen( EVENT_HOMEASSISTANT_STOP, handle_hass_stop) self._writer_task = self.hass.async_add_job(self._writer()) final_message = None msg = None authenticated = False try: if request[KEY_AUTHENTICATED]: authenticated = True else: self.debug("Request auth") await self.wsock.send_json(auth_required_message()) msg = await wsock.receive_json() msg = AUTH_MESSAGE_SCHEMA(msg) if self.hass.auth.active and 'access_token' in msg: self.debug("Received access_token") token = self.hass.auth.async_get_access_token( msg['access_token']) authenticated = token is not None if authenticated: request['hass_user'] = token.refresh_token.user elif ((not self.hass.auth.active or self.hass.auth.support_legacy) and 'api_password' in msg): self.debug("Received api_password") authenticated = validate_password( request, msg['api_password']) if not authenticated: self.debug("Authorization failed") await self.wsock.send_json( auth_invalid_message('Invalid access token or password')) await process_wrong_login(request) return wsock self.debug("Auth OK") await process_success_login(request) await self.wsock.send_json(auth_ok_message()) # ---------- AUTH PHASE OVER ---------- msg = await wsock.receive_json() last_id = 0 handlers = self.hass.data[DOMAIN] while msg: self.debug("Received", msg) msg = MINIMAL_MESSAGE_SCHEMA(msg) cur_id = msg['id'] if cur_id <= last_id: self.to_write.put_nowait(error_message( cur_id, ERR_ID_REUSE, 'Identifier values have to increase.')) elif msg['type'] not in handlers: self.log_error( 'Received invalid command: {}'.format(msg['type'])) self.to_write.put_nowait(error_message( cur_id, ERR_UNKNOWN_COMMAND, 'Unknown command.')) else: handler, schema = handlers[msg['type']] handler(self.hass, self, schema(msg)) last_id = cur_id msg = await wsock.receive_json() except vol.Invalid as err: error_msg = "Message incorrectly formatted: " if msg: error_msg += humanize_error(msg, err) else: error_msg += str(err) self.log_error(error_msg) if not authenticated: final_message = auth_invalid_message(error_msg) else: if isinstance(msg, dict): iden = msg.get('id') else: iden = None final_message = error_message( iden, ERR_INVALID_FORMAT, error_msg) except TypeError as err: if wsock.closed: self.debug("Connection closed by client") else: _LOGGER.exception("Unexpected TypeError: %s", err) except ValueError as err: msg = "Received invalid JSON" value = getattr(err, 'doc', None) # Py3.5+ only if value: msg += ': {}'.format(value) self.log_error(msg) self._writer_task.cancel() except CANCELLATION_ERRORS: self.debug("Connection cancelled") except asyncio.QueueFull: self.log_error("Client exceeded max pending messages [1]:", MAX_PENDING_MSG) self._writer_task.cancel() except Exception: # pylint: disable=broad-except error = "Unexpected error inside websocket API. " if msg is not None: error += str(msg) _LOGGER.exception(error) finally: unsub_stop() for unsub in self.event_listeners.values(): unsub() try: if final_message is not None: self.to_write.put_nowait(final_message) self.to_write.put_nowait(None) # Make sure all error messages are written before closing await self._writer_task except asyncio.QueueFull: self._writer_task.cancel() await wsock.close() self.debug("Closed connection") return wsock @callback def handle_subscribe_events(hass, connection, msg): """Handle subscribe events command. Async friendly. """ async def forward_events(event): """Forward events to websocket.""" if event.event_type == EVENT_TIME_CHANGED: return connection.send_message_outside(event_message(msg['id'], event)) connection.event_listeners[msg['id']] = hass.bus.async_listen( msg['event_type'], forward_events) connection.to_write.put_nowait(result_message(msg['id'])) @callback def handle_unsubscribe_events(hass, connection, msg): """Handle unsubscribe events command. Async friendly. """ subscription = msg['subscription'] if subscription in connection.event_listeners: connection.event_listeners.pop(subscription)() connection.to_write.put_nowait(result_message(msg['id'])) else: connection.to_write.put_nowait(error_message( msg['id'], ERR_NOT_FOUND, 'Subscription not found.')) @callback def handle_call_service(hass, connection, msg): """Handle call service command. Async friendly. """ async def call_service_helper(msg): """Call a service and fire complete message.""" await hass.services.async_call( msg['domain'], msg['service'], msg.get('service_data'), True, connection.context(msg)) connection.send_message_outside(result_message(msg['id'])) hass.async_add_job(call_service_helper(msg)) @callback def handle_get_states(hass, connection, msg): """Handle get states command. Async friendly. """ connection.to_write.put_nowait(result_message( msg['id'], hass.states.async_all())) @callback def handle_get_services(hass, connection, msg): """Handle get services command. Async friendly. """ async def get_services_helper(msg): """Get available services and fire complete message.""" descriptions = await async_get_all_descriptions(hass) connection.send_message_outside( result_message(msg['id'], descriptions)) hass.async_add_job(get_services_helper(msg)) @callback def handle_get_config(hass, connection, msg): """Handle get config command. Async friendly. """ connection.to_write.put_nowait(result_message( msg['id'], hass.config.as_dict())) @callback def handle_ping(hass, connection, msg): """Handle ping command. Async friendly. """ connection.to_write.put_nowait(pong_message(msg['id']))