263 lines
8.6 KiB
Python
263 lines
8.6 KiB
Python
"""View to accept incoming websocket connection."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import Callable
|
|
from contextlib import suppress
|
|
import datetime as dt
|
|
import logging
|
|
from typing import Any, Final
|
|
|
|
from aiohttp import WSMsgType, web
|
|
import async_timeout
|
|
|
|
from homeassistant.components.http import HomeAssistantView
|
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
|
from homeassistant.core import Event, HomeAssistant, callback
|
|
from homeassistant.helpers.event import async_call_later
|
|
|
|
from .auth import AuthPhase, auth_required_message
|
|
from .const import (
|
|
CANCELLATION_ERRORS,
|
|
DATA_CONNECTIONS,
|
|
MAX_PENDING_MSG,
|
|
PENDING_MSG_PEAK,
|
|
PENDING_MSG_PEAK_TIME,
|
|
SIGNAL_WEBSOCKET_CONNECTED,
|
|
SIGNAL_WEBSOCKET_DISCONNECTED,
|
|
URL,
|
|
)
|
|
from .error import Disconnect
|
|
from .messages import message_to_json
|
|
|
|
# Base logger for websocket connections. WebSocketHandler wraps it in a
# WebSocketAdapter so each record is prefixed with the connection id.
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
|
|
|
|
|
|
class WebsocketAPIView(HomeAssistantView):
    """HTTP view that upgrades requests on ``URL`` to websocket sessions."""

    name: str = "websocketapi"
    url: str = URL
    # The view itself is unauthenticated; WebSocketHandler.async_handle runs
    # its own auth phase once the socket is established.
    requires_auth: bool = False

    async def get(self, request: web.Request) -> web.WebSocketResponse:
        """Serve one websocket connection.

        A new ``WebSocketHandler`` is built for the request, using the
        ``hass`` object stored on the aiohttp application, and the response
        it produces is returned once the connection has ended.
        """
        # pylint: disable=no-self-use
        hass = request.app["hass"]
        handler = WebSocketHandler(hass, request)
        return await handler.async_handle()
|
|
|
|
|
|
class WebSocketAdapter(logging.LoggerAdapter):
    """Logger adapter that tags every message with the connection id."""

    def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
        """Prefix ``msg`` with ``[<connid>]`` taken from the adapter's extra data.

        ``kwargs`` is passed through unchanged.
        """
        connection_id = self.extra["connid"]
        tagged_msg = f"[{connection_id}] {msg}"
        return tagged_msg, kwargs
|
|
|
|
|
|
class WebSocketHandler:
    """Handle an active websocket client connection.

    One instance serves one request. ``async_handle`` prepares the socket,
    runs the auth phase and then dispatches incoming commands, while a
    separate writer task drains the queue of outgoing messages.
    """

    def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
        """Initialize an active connection."""
        self.hass = hass
        self.request = request
        self.wsock = web.WebSocketResponse(heartbeat=55)
        # Outgoing messages; a ``None`` entry tells the writer to stop.
        self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
        # Task running ``async_handle``; set once the socket is prepared.
        self._handle_task: asyncio.Task | None = None
        # Task running ``_writer``.
        self._writer_task: asyncio.Task | None = None
        self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
        # Unsubscribe callable for the pending ``_check_write_peak`` timer.
        self._peak_checker_unsub: Callable[[], None] | None = None

    async def _writer(self) -> None:
        """Write outgoing messages.

        Runs until the socket is closed, a ``None`` stop marker is dequeued,
        or sending fails. The pending write-peak timer is always released on
        exit.
        """
        try:
            # Exceptions if Socket disconnected or cancelled by connection handler
            with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
                while not self.wsock.closed:
                    if (message := await self._to_write.get()) is None:
                        break

                    self._logger.debug("Sending %s", message)
                    await self.wsock.send_str(message)
        finally:
            # Clean up the peak checker when we shut down the writer. Doing it
            # in ``finally`` also covers exceptions that are not suppressed
            # above, which would otherwise leave the timer armed.
            self._cancel_peak_checker()

    @callback
    def _cancel_peak_checker(self) -> None:
        """Cancel the pending write-peak timer, if one is armed."""
        if self._peak_checker_unsub is not None:
            self._peak_checker_unsub()
            self._peak_checker_unsub = None

    @callback
    def _send_message(self, message: str | dict[str, Any]) -> None:
        """Send a message to the client.

        Non-string messages are serialized with ``message_to_json`` before
        being queued for the writer task.

        Closes connection if the client is not reading the messages.

        Async friendly.
        """
        if not isinstance(message, str):
            message = message_to_json(message)

        try:
            self._to_write.put_nowait(message)
        except asyncio.QueueFull:
            self._logger.error(
                "Client exceeded max pending messages [2]: %s", MAX_PENDING_MSG
            )

            self._cancel()
            # The connection is being torn down; there is no point in arming
            # a new peak timer for it.
            return

        if self._to_write.qsize() < PENDING_MSG_PEAK:
            # Back under the peak: any pending check is no longer needed.
            self._cancel_peak_checker()
            return

        # At or above the peak: make sure a single delayed check is scheduled.
        if self._peak_checker_unsub is None:
            self._peak_checker_unsub = async_call_later(
                self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
            )

    @callback
    def _check_write_peak(self, _utc_time: dt.datetime) -> None:
        """Check that we are no longer above the write peak.

        Called by the timer armed in ``_send_message``. Cancels the
        connection if the queue is still at or above ``PENDING_MSG_PEAK``.
        """
        # The timer has fired, so there is nothing left to unsubscribe.
        self._peak_checker_unsub = None

        if self._to_write.qsize() < PENDING_MSG_PEAK:
            return

        self._logger.error(
            "Client unable to keep up with pending messages. Stayed over %s for %s seconds",
            PENDING_MSG_PEAK,
            PENDING_MSG_PEAK_TIME,
        )
        self._cancel()

    @callback
    def _cancel(self) -> None:
        """Cancel the connection by cancelling the handler and writer tasks."""
        if self._handle_task is not None:
            self._handle_task.cancel()
        if self._writer_task is not None:
            self._writer_task.cancel()

    async def async_handle(self) -> web.WebSocketResponse:
        """Handle a websocket response.

        Prepares the socket, performs the auth phase, then processes command
        messages until the client disconnects, sends something that is not
        valid JSON text, or the connection is cancelled. Returns the
        ``WebSocketResponse`` in every case.
        """
        request = self.request
        wsock = self.wsock
        try:
            async with async_timeout.timeout(10):
                await wsock.prepare(request)
        except asyncio.TimeoutError:
            self._logger.warning("Timeout preparing request from %s", request.remote)
            return wsock

        self._logger.debug("Connected from %s", request.remote)
        self._handle_task = asyncio.current_task()

        @callback
        def handle_hass_stop(event: Event) -> None:
            """Cancel this connection."""
            self._cancel()

        unsub_stop = self.hass.bus.async_listen(
            EVENT_HOMEASSISTANT_STOP, handle_hass_stop
        )

        # As the webserver is now started before the start
        # event we do not want to block for websocket responses
        self._writer_task = asyncio.create_task(self._writer())

        auth = AuthPhase(
            self._logger, self.hass, self._send_message, self._cancel, request
        )
        # Stays None until authentication succeeds.
        connection = None
        # Reason logged at warning level on disconnect; None means a normal close.
        disconnect_warn = None

        try:
            self._send_message(auth_required_message())

            # Auth Phase
            try:
                async with async_timeout.timeout(10):
                    msg = await wsock.receive()
            except asyncio.TimeoutError as err:
                disconnect_warn = "Did not receive auth message within 10 seconds"
                raise Disconnect from err

            if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
                raise Disconnect

            if msg.type != WSMsgType.TEXT:
                disconnect_warn = "Received non-Text message."
                raise Disconnect

            try:
                msg_data = msg.json()
            except ValueError as err:
                disconnect_warn = "Received invalid JSON."
                raise Disconnect from err

            self._logger.debug("Received %s", msg_data)
            connection = await auth.async_handle(msg_data)
            self.hass.data[DATA_CONNECTIONS] = (
                self.hass.data.get(DATA_CONNECTIONS, 0) + 1
            )
            self.hass.helpers.dispatcher.async_dispatcher_send(
                SIGNAL_WEBSOCKET_CONNECTED
            )

            # Command phase
            while not wsock.closed:
                msg = await wsock.receive()

                if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
                    break

                if msg.type != WSMsgType.TEXT:
                    disconnect_warn = "Received non-Text message."
                    break

                try:
                    msg_data = msg.json()
                except ValueError:
                    disconnect_warn = "Received invalid JSON."
                    break

                self._logger.debug("Received %s", msg_data)
                connection.async_handle(msg_data)

        except asyncio.CancelledError:
            self._logger.info("Connection closed by client")

        except Disconnect:
            pass

        except Exception:  # pylint: disable=broad-except
            self._logger.exception("Unexpected error inside websocket API")

        finally:
            unsub_stop()

            if connection is not None:
                connection.async_handle_close()

            try:
                # Queue the stop marker behind any messages still pending.
                self._to_write.put_nowait(None)
                # Make sure all error messages are written before closing
                await self._writer_task
                await wsock.close()
            except asyncio.QueueFull:  # can be raised by put_nowait
                self._writer_task.cancel()

            finally:
                if disconnect_warn is None:
                    self._logger.debug("Disconnected")
                else:
                    self._logger.warning("Disconnected: %s", disconnect_warn)

                # Only authenticated connections were counted and announced.
                if connection is not None:
                    self.hass.data[DATA_CONNECTIONS] -= 1
                self.hass.helpers.dispatcher.async_dispatcher_send(
                    SIGNAL_WEBSOCKET_DISCONNECTED
                )

        return wsock
|