"""View to accept incoming websocket connection.""" from __future__ import annotations import asyncio from collections.abc import Callable from contextlib import suppress import datetime as dt import logging from typing import Any, Final from aiohttp import WSMsgType, web import async_timeout from homeassistant.components.http import HomeAssistantView from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import Event, HomeAssistant, callback from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.event import async_call_later from .auth import AuthPhase, auth_required_message from .const import ( CANCELLATION_ERRORS, DATA_CONNECTIONS, MAX_PENDING_MSG, PENDING_MSG_PEAK, PENDING_MSG_PEAK_TIME, SIGNAL_WEBSOCKET_CONNECTED, SIGNAL_WEBSOCKET_DISCONNECTED, URL, ) from .error import Disconnect from .messages import message_to_json _WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection") class WebsocketAPIView(HomeAssistantView): """View to serve a websockets endpoint.""" name: str = "websocketapi" url: str = URL requires_auth: bool = False async def get(self, request: web.Request) -> web.WebSocketResponse: """Handle an incoming websocket connection.""" return await WebSocketHandler(request.app["hass"], request).async_handle() class WebSocketAdapter(logging.LoggerAdapter): """Add connection id to websocket messages.""" def process(self, msg: str, kwargs: Any) -> tuple[str, Any]: """Add connid to websocket log messages.""" return f'[{self.extra["connid"]}] {msg}', kwargs class WebSocketHandler: """Handle an active websocket client connection.""" def __init__(self, hass: HomeAssistant, request: web.Request) -> None: """Initialize an active connection.""" self.hass = hass self.request = request self.wsock = web.WebSocketResponse(heartbeat=55) self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG) self._handle_task: asyncio.Task | None = None self._writer_task: asyncio.Task | None = None self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)}) self._peak_checker_unsub: Callable[[], None] | None = None async def _writer(self) -> None: """Write outgoing messages.""" # Exceptions if Socket disconnected or cancelled by connection handler with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS): while not self.wsock.closed: if (process := await self._to_write.get()) is None: break if not isinstance(process, str): message: str = process() else: message = process self._logger.debug("Sending %s", message) await self.wsock.send_str(message) # Clean up the peaker checker when we shut down the writer if self._peak_checker_unsub is not None: self._peak_checker_unsub() self._peak_checker_unsub = None @callback def _send_message(self, message: str | dict[str, Any] | Callable[[], str]) -> None: """Send a message to the client. Closes connection if the client is not reading the messages. Async friendly. """ if isinstance(message, dict): message = message_to_json(message) try: self._to_write.put_nowait(message) except asyncio.QueueFull: self._logger.error( "Client exceeded max pending messages [2]: %s", MAX_PENDING_MSG ) self._cancel() if self._to_write.qsize() < PENDING_MSG_PEAK: if self._peak_checker_unsub: self._peak_checker_unsub() self._peak_checker_unsub = None return if self._peak_checker_unsub is None: self._peak_checker_unsub = async_call_later( self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak ) @callback def _check_write_peak(self, _utc_time: dt.datetime) -> None: """Check that we are no longer above the write peak.""" self._peak_checker_unsub = None if self._to_write.qsize() < PENDING_MSG_PEAK: return self._logger.error( "Client unable to keep up with pending messages. Stayed over %s for %s seconds", PENDING_MSG_PEAK, PENDING_MSG_PEAK_TIME, ) self._cancel() @callback def _cancel(self) -> None: """Cancel the connection.""" if self._handle_task is not None: self._handle_task.cancel() if self._writer_task is not None: self._writer_task.cancel() async def async_handle(self) -> web.WebSocketResponse: """Handle a websocket response.""" request = self.request wsock = self.wsock try: async with async_timeout.timeout(10): await wsock.prepare(request) except asyncio.TimeoutError: self._logger.warning("Timeout preparing request from %s", request.remote) return wsock self._logger.debug("Connected from %s", request.remote) self._handle_task = asyncio.current_task() @callback def handle_hass_stop(event: Event) -> None: """Cancel this connection.""" self._cancel() unsub_stop = self.hass.bus.async_listen( EVENT_HOMEASSISTANT_STOP, handle_hass_stop ) # As the webserver is now started before the start # event we do not want to block for websocket responses self._writer_task = asyncio.create_task(self._writer()) auth = AuthPhase( self._logger, self.hass, self._send_message, self._cancel, request ) connection = None disconnect_warn = None try: self._send_message(auth_required_message()) # Auth Phase try: async with async_timeout.timeout(10): msg = await wsock.receive() except asyncio.TimeoutError as err: disconnect_warn = "Did not receive auth message within 10 seconds" raise Disconnect from err if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING): raise Disconnect if msg.type != WSMsgType.TEXT: disconnect_warn = "Received non-Text message." raise Disconnect try: msg_data = msg.json() except ValueError as err: disconnect_warn = "Received invalid JSON." raise Disconnect from err self._logger.debug("Received %s", msg_data) connection = await auth.async_handle(msg_data) self.hass.data[DATA_CONNECTIONS] = ( self.hass.data.get(DATA_CONNECTIONS, 0) + 1 ) async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_CONNECTED) # Command phase while not wsock.closed: msg = await wsock.receive() if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING): break if msg.type != WSMsgType.TEXT: disconnect_warn = "Received non-Text message." break try: msg_data = msg.json() except ValueError: disconnect_warn = "Received invalid JSON." break self._logger.debug("Received %s", msg_data) connection.async_handle(msg_data) except asyncio.CancelledError: self._logger.info("Connection closed by client") except Disconnect: pass except Exception: # pylint: disable=broad-except self._logger.exception("Unexpected error inside websocket API") finally: unsub_stop() if connection is not None: connection.async_handle_close() try: self._to_write.put_nowait(None) # Make sure all error messages are written before closing await self._writer_task await wsock.close() except asyncio.QueueFull: # can be raised by put_nowait self._writer_task.cancel() finally: if disconnect_warn is None: self._logger.debug("Disconnected") else: self._logger.warning("Disconnected: %s", disconnect_warn) if connection is not None: self.hass.data[DATA_CONNECTIONS] -= 1 async_dispatcher_send(self.hass, SIGNAL_WEBSOCKET_DISCONNECTED) return wsock