540 lines
20 KiB
Python
540 lines
20 KiB
Python
"""View to accept incoming websocket connection."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections import deque
|
|
from collections.abc import Callable, Coroutine
|
|
import datetime as dt
|
|
from functools import partial
|
|
import logging
|
|
from typing import TYPE_CHECKING, Any, Final
|
|
|
|
from aiohttp import WSMsgType, web
|
|
from aiohttp.http_websocket import WebSocketWriter
|
|
|
|
from homeassistant.components.http import KEY_HASS, HomeAssistantView
|
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
|
from homeassistant.core import Event, HomeAssistant, callback
|
|
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
|
from homeassistant.helpers.event import async_call_later
|
|
from homeassistant.util.async_ import create_eager_task
|
|
from homeassistant.util.json import json_loads
|
|
|
|
from .auth import AUTH_REQUIRED_MESSAGE, AuthPhase
|
|
from .const import (
|
|
DATA_CONNECTIONS,
|
|
MAX_PENDING_MSG,
|
|
PENDING_MSG_MAX_FORCE_READY,
|
|
PENDING_MSG_PEAK,
|
|
PENDING_MSG_PEAK_TIME,
|
|
SIGNAL_WEBSOCKET_CONNECTED,
|
|
SIGNAL_WEBSOCKET_DISCONNECTED,
|
|
URL,
|
|
)
|
|
from .error import Disconnect
|
|
from .messages import message_to_json_bytes
|
|
from .util import describe_request
|
|
|
|
CLOSE_MSG_TYPES = {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING}
|
|
|
|
if TYPE_CHECKING:
|
|
from .connection import ActiveConnection
|
|
|
|
|
|
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
|
|
|
|
|
|
class WebsocketAPIView(HomeAssistantView):
|
|
"""View to serve a websockets endpoint."""
|
|
|
|
name: str = "websocketapi"
|
|
url: str = URL
|
|
requires_auth: bool = False
|
|
|
|
async def get(self, request: web.Request) -> web.WebSocketResponse:
|
|
"""Handle an incoming websocket connection."""
|
|
return await WebSocketHandler(request.app[KEY_HASS], request).async_handle()
|
|
|
|
|
|
class WebSocketAdapter(logging.LoggerAdapter):
|
|
"""Add connection id to websocket messages."""
|
|
|
|
def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
|
|
"""Add connid to websocket log messages."""
|
|
assert self.extra is not None
|
|
return f'[{self.extra["connid"]}] {msg}', kwargs
|
|
|
|
|
|
class WebSocketHandler:
|
|
"""Handle an active websocket client connection."""
|
|
|
|
__slots__ = (
|
|
"_hass",
|
|
"_loop",
|
|
"_request",
|
|
"_wsock",
|
|
"_handle_task",
|
|
"_writer_task",
|
|
"_closing",
|
|
"_authenticated",
|
|
"_logger",
|
|
"_peak_checker_unsub",
|
|
"_connection",
|
|
"_message_queue",
|
|
"_ready_future",
|
|
"_release_ready_queue_size",
|
|
)
|
|
|
|
def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
|
|
"""Initialize an active connection."""
|
|
self._hass = hass
|
|
self._loop = hass.loop
|
|
self._request: web.Request = request
|
|
self._wsock = web.WebSocketResponse(heartbeat=55)
|
|
self._handle_task: asyncio.Task | None = None
|
|
self._writer_task: asyncio.Task | None = None
|
|
self._closing: bool = False
|
|
self._authenticated: bool = False
|
|
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
|
|
self._peak_checker_unsub: Callable[[], None] | None = None
|
|
self._connection: ActiveConnection | None = None
|
|
|
|
# The WebSocketHandler has a single consumer and path
|
|
# to where messages are queued. This allows the implementation
|
|
# to use a deque and an asyncio.Future to avoid the overhead of
|
|
# an asyncio.Queue.
|
|
self._message_queue: deque[bytes] = deque()
|
|
self._ready_future: asyncio.Future[int] | None = None
|
|
self._release_ready_queue_size: int = 0
|
|
|
|
def __repr__(self) -> str:
|
|
"""Return the representation."""
|
|
return (
|
|
"<WebSocketHandler "
|
|
f"closing={self._closing} "
|
|
f"authenticated={self._authenticated} "
|
|
f"description={self.description}>"
|
|
)
|
|
|
|
@property
|
|
def description(self) -> str:
|
|
"""Return a description of the connection."""
|
|
if connection := self._connection:
|
|
return connection.get_description(self._request)
|
|
if request := self._request:
|
|
return describe_request(request)
|
|
return "finished connection"
|
|
|
|
async def _writer(
|
|
self,
|
|
connection: ActiveConnection,
|
|
send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
|
|
) -> None:
|
|
"""Write outgoing messages."""
|
|
# Variables are set locally to avoid lookups in the loop
|
|
message_queue = self._message_queue
|
|
logger = self._logger
|
|
wsock = self._wsock
|
|
loop = self._loop
|
|
is_debug_log_enabled = partial(logger.isEnabledFor, logging.DEBUG)
|
|
debug = logger.debug
|
|
can_coalesce = connection.can_coalesce
|
|
ready_message_count = len(message_queue)
|
|
# Exceptions if Socket disconnected or cancelled by connection handler
|
|
try:
|
|
while not wsock.closed:
|
|
if not message_queue:
|
|
self._ready_future = loop.create_future()
|
|
ready_message_count = await self._ready_future
|
|
|
|
if self._closing:
|
|
return
|
|
|
|
if not can_coalesce:
|
|
# coalesce may be enabled later in the connection
|
|
can_coalesce = connection.can_coalesce
|
|
|
|
if not can_coalesce or ready_message_count == 1:
|
|
message = message_queue.popleft()
|
|
if is_debug_log_enabled():
|
|
debug("%s: Sending %s", self.description, message)
|
|
await send_bytes_text(message)
|
|
continue
|
|
|
|
coalesced_messages = b"".join((b"[", b",".join(message_queue), b"]"))
|
|
message_queue.clear()
|
|
if is_debug_log_enabled():
|
|
debug("%s: Sending %s", self.description, coalesced_messages)
|
|
await send_bytes_text(coalesced_messages)
|
|
except asyncio.CancelledError:
|
|
debug("%s: Writer cancelled", self.description)
|
|
raise
|
|
except (RuntimeError, ConnectionResetError) as ex:
|
|
debug("%s: Unexpected error in writer: %s", self.description, ex)
|
|
finally:
|
|
debug("%s: Writer done", self.description)
|
|
# Clean up the peak checker when we shut down the writer
|
|
self._cancel_peak_checker()
|
|
|
|
@callback
|
|
def _cancel_peak_checker(self) -> None:
|
|
"""Cancel the peak checker."""
|
|
if self._peak_checker_unsub is not None:
|
|
self._peak_checker_unsub()
|
|
self._peak_checker_unsub = None
|
|
|
|
@callback
|
|
def _send_message(self, message: str | bytes | dict[str, Any]) -> None:
|
|
"""Queue sending a message to the client.
|
|
|
|
Closes connection if the client is not reading the messages.
|
|
|
|
Async friendly.
|
|
"""
|
|
if self._closing:
|
|
# Connection is cancelled, don't flood logs about exceeding
|
|
# max pending messages.
|
|
return
|
|
|
|
if type(message) is not bytes: # noqa: E721
|
|
if isinstance(message, dict):
|
|
message = message_to_json_bytes(message)
|
|
elif isinstance(message, str):
|
|
message = message.encode("utf-8")
|
|
|
|
message_queue = self._message_queue
|
|
message_queue.append(message)
|
|
if (queue_size_after_add := len(message_queue)) >= MAX_PENDING_MSG:
|
|
self._logger.error(
|
|
(
|
|
"%s: Client unable to keep up with pending messages. Reached %s pending"
|
|
" messages. The system's load is too high or an integration is"
|
|
" misbehaving; Last message was: %s"
|
|
),
|
|
self.description,
|
|
MAX_PENDING_MSG,
|
|
message,
|
|
)
|
|
self._cancel()
|
|
return
|
|
|
|
if self._release_ready_queue_size == 0:
|
|
# Try to coalesce more messages to reduce the number of writes
|
|
self._release_ready_queue_size = queue_size_after_add
|
|
self._loop.call_soon(self._release_ready_future_or_reschedule)
|
|
|
|
peak_checker_active = self._peak_checker_unsub is not None
|
|
|
|
if queue_size_after_add <= PENDING_MSG_PEAK:
|
|
if peak_checker_active:
|
|
self._cancel_peak_checker()
|
|
return
|
|
|
|
if not peak_checker_active:
|
|
self._peak_checker_unsub = async_call_later(
|
|
self._hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
|
|
)
|
|
|
|
@callback
|
|
def _release_ready_future_or_reschedule(self) -> None:
|
|
"""Release the ready future or reschedule.
|
|
|
|
We will release the ready future if the queue did not grow since the
|
|
last time we tried to release the ready future.
|
|
|
|
If we reach PENDING_MSG_MAX_FORCE_READY, we will release the ready future
|
|
immediately so avoid the coalesced messages from growing too large.
|
|
"""
|
|
if not (ready_future := self._ready_future) or not (
|
|
queue_size := len(self._message_queue)
|
|
):
|
|
self._release_ready_queue_size = 0
|
|
return
|
|
# If we are below the max pending to force ready, and there are new messages
|
|
# in the queue since the last time we tried to release the ready future, we
|
|
# try again later so we can coalesce more messages.
|
|
if queue_size > self._release_ready_queue_size < PENDING_MSG_MAX_FORCE_READY:
|
|
self._release_ready_queue_size = queue_size
|
|
self._loop.call_soon(self._release_ready_future_or_reschedule)
|
|
return
|
|
self._release_ready_queue_size = 0
|
|
if not ready_future.done():
|
|
ready_future.set_result(queue_size)
|
|
|
|
@callback
|
|
def _check_write_peak(self, _utc_time: dt.datetime) -> None:
|
|
"""Check that we are no longer above the write peak."""
|
|
self._peak_checker_unsub = None
|
|
|
|
if len(self._message_queue) < PENDING_MSG_PEAK:
|
|
return
|
|
|
|
self._logger.error(
|
|
(
|
|
"%s: Client unable to keep up with pending messages. Stayed over %s for %s"
|
|
" seconds. The system's load is too high or an integration is"
|
|
" misbehaving; Last message was: %s"
|
|
),
|
|
self.description,
|
|
PENDING_MSG_PEAK,
|
|
PENDING_MSG_PEAK_TIME,
|
|
self._message_queue[-1],
|
|
)
|
|
self._cancel()
|
|
|
|
@callback
|
|
def _cancel(self) -> None:
|
|
"""Cancel the connection."""
|
|
self._closing = True
|
|
self._cancel_peak_checker()
|
|
if self._handle_task is not None:
|
|
self._handle_task.cancel()
|
|
if self._writer_task is not None:
|
|
self._writer_task.cancel()
|
|
|
|
@callback
|
|
def _async_handle_hass_stop(self, event: Event) -> None:
|
|
"""Cancel this connection."""
|
|
self._cancel()
|
|
|
|
async def async_handle(self) -> web.WebSocketResponse:
|
|
"""Handle a websocket response."""
|
|
request = self._request
|
|
wsock = self._wsock
|
|
logger = self._logger
|
|
hass = self._hass
|
|
|
|
try:
|
|
async with asyncio.timeout(10):
|
|
await wsock.prepare(request)
|
|
except ConnectionResetError:
|
|
# Likely the client disconnected before we prepared the websocket
|
|
logger.debug(
|
|
"%s: Connection reset by peer while preparing WebSocket",
|
|
self.description,
|
|
)
|
|
return wsock
|
|
except TimeoutError:
|
|
logger.warning("Timeout preparing request from %s", request.remote)
|
|
return wsock
|
|
|
|
logger.debug("%s: Connected from %s", self.description, request.remote)
|
|
self._handle_task = asyncio.current_task()
|
|
|
|
unsub_stop = hass.bus.async_listen(
|
|
EVENT_HOMEASSISTANT_STOP, self._async_handle_hass_stop
|
|
)
|
|
|
|
writer = wsock._writer # noqa: SLF001
|
|
if TYPE_CHECKING:
|
|
assert writer is not None
|
|
|
|
send_bytes_text = partial(writer.send_frame, opcode=WSMsgType.TEXT)
|
|
auth = AuthPhase(
|
|
logger, hass, self._send_message, self._cancel, request, send_bytes_text
|
|
)
|
|
connection: ActiveConnection | None = None
|
|
disconnect_warn: str | None = None
|
|
|
|
try:
|
|
connection = await self._async_handle_auth_phase(auth, send_bytes_text)
|
|
self._async_increase_writer_limit(writer)
|
|
await self._async_websocket_command_phase(connection)
|
|
except asyncio.CancelledError:
|
|
logger.debug("%s: Connection cancelled", self.description)
|
|
raise
|
|
except Disconnect as ex:
|
|
if disconnect_msg := str(ex):
|
|
disconnect_warn = disconnect_msg
|
|
|
|
logger.debug("%s: Connection closed by client: %s", self.description, ex)
|
|
except Exception:
|
|
logger.exception(
|
|
"%s: Unexpected error inside websocket API", self.description
|
|
)
|
|
finally:
|
|
unsub_stop()
|
|
|
|
self._cancel_peak_checker()
|
|
|
|
if connection is not None:
|
|
connection.async_handle_close()
|
|
|
|
self._closing = True
|
|
if self._ready_future and not self._ready_future.done():
|
|
self._ready_future.set_result(len(self._message_queue))
|
|
|
|
await self._async_cleanup_writer_and_close(disconnect_warn, connection)
|
|
|
|
return wsock
|
|
|
|
async def _async_handle_auth_phase(
|
|
self,
|
|
auth: AuthPhase,
|
|
send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
|
|
) -> ActiveConnection:
|
|
"""Handle the auth phase of the websocket connection."""
|
|
await send_bytes_text(AUTH_REQUIRED_MESSAGE)
|
|
|
|
# Auth Phase
|
|
try:
|
|
msg = await self._wsock.receive(10)
|
|
except TimeoutError as err:
|
|
raise Disconnect("Did not receive auth message within 10 seconds") from err
|
|
|
|
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
|
|
raise Disconnect("Received close message during auth phase")
|
|
|
|
if msg.type is not WSMsgType.TEXT:
|
|
raise Disconnect("Received non-Text message during auth phase")
|
|
|
|
try:
|
|
auth_msg_data = json_loads(msg.data)
|
|
except ValueError as err:
|
|
raise Disconnect("Received invalid JSON during auth phase") from err
|
|
|
|
if self._logger.isEnabledFor(logging.DEBUG):
|
|
self._logger.debug("%s: Received %s", self.description, auth_msg_data)
|
|
connection = await auth.async_handle(auth_msg_data)
|
|
# As the webserver is now started before the start
|
|
# event we do not want to block for websocket responses
|
|
#
|
|
# We only start the writer queue after the auth phase is completed
|
|
# since there is no need to queue messages before the auth phase
|
|
self._connection = connection
|
|
self._writer_task = create_eager_task(self._writer(connection, send_bytes_text))
|
|
self._hass.data[DATA_CONNECTIONS] = self._hass.data.get(DATA_CONNECTIONS, 0) + 1
|
|
async_dispatcher_send(self._hass, SIGNAL_WEBSOCKET_CONNECTED)
|
|
|
|
self._authenticated = True
|
|
return connection
|
|
|
|
@callback
|
|
def _async_increase_writer_limit(self, writer: WebSocketWriter) -> None:
|
|
#
|
|
#
|
|
# Our websocket implementation is backed by a deque
|
|
#
|
|
# As back-pressure builds, the queue will back up and use more memory
|
|
# until we disconnect the client when the queue size reaches
|
|
# MAX_PENDING_MSG. When we are generating a high volume of websocket messages,
|
|
# we hit a bottleneck in aiohttp where it will wait for
|
|
# the buffer to drain before sending the next message and messages
|
|
# start backing up in the queue.
|
|
#
|
|
# https://github.com/aio-libs/aiohttp/issues/1367 added drains
|
|
# to the websocket writer to handle malicious clients and network issues.
|
|
# The drain causes multiple problems for us since the buffer cannot be
|
|
# drained fast enough when we deliver a high volume or large messages:
|
|
#
|
|
# - We end up disconnecting the client. The client will then reconnect,
|
|
# and the cycle repeats itself, which results in a significant amount of
|
|
# CPU usage.
|
|
#
|
|
# - Messages latency increases because messages cannot be moved into
|
|
# the TCP buffer because it is blocked waiting for the drain to happen because
|
|
# of the low default limit of 16KiB. By increasing the limit, we instead
|
|
# rely on the underlying TCP buffer and stack to deliver the messages which
|
|
# can typically happen much faster.
|
|
#
|
|
# After the auth phase is completed, and we are not concerned about
|
|
# the user being a malicious client, we set the limit to force a drain
|
|
# to 1MiB. 1MiB is the maximum expected size of the serialized entity
|
|
# registry, which is the largest message we usually send.
|
|
#
|
|
# https://github.com/aio-libs/aiohttp/commit/b3c80ee3f7d5d8f0b8bc27afe52e4d46621eaf99
|
|
# added a way to set the limit, but there is no way to actually
|
|
# reach the code to set the limit, so we have to set it directly.
|
|
#
|
|
writer._limit = 2**20 # noqa: SLF001
|
|
|
|
async def _async_websocket_command_phase(
|
|
self, connection: ActiveConnection
|
|
) -> None:
|
|
"""Handle the command phase of the websocket connection."""
|
|
wsock = self._wsock
|
|
async_handle_str = connection.async_handle
|
|
async_handle_binary = connection.async_handle_binary
|
|
_debug_enabled = partial(self._logger.isEnabledFor, logging.DEBUG)
|
|
|
|
# Command phase
|
|
while not wsock.closed:
|
|
msg = await wsock.receive()
|
|
msg_type = msg.type
|
|
msg_data = msg.data
|
|
|
|
if msg_type in CLOSE_MSG_TYPES:
|
|
break
|
|
|
|
if msg_type is WSMsgType.BINARY:
|
|
if len(msg_data) < 1:
|
|
raise Disconnect("Received invalid binary message.")
|
|
|
|
handler = msg_data[0]
|
|
payload = msg_data[1:]
|
|
async_handle_binary(handler, payload)
|
|
continue
|
|
|
|
if msg_type is not WSMsgType.TEXT:
|
|
raise Disconnect("Received non-Text message.")
|
|
|
|
try:
|
|
command_msg_data = json_loads(msg_data)
|
|
except ValueError as ex:
|
|
raise Disconnect("Received invalid JSON.") from ex
|
|
|
|
if _debug_enabled():
|
|
self._logger.debug(
|
|
"%s: Received %s", self.description, command_msg_data
|
|
)
|
|
|
|
# command_msg_data is always deserialized from JSON as a list
|
|
if type(command_msg_data) is not list: # noqa: E721
|
|
async_handle_str(command_msg_data)
|
|
continue
|
|
|
|
for split_msg in command_msg_data:
|
|
async_handle_str(split_msg)
|
|
|
|
async def _async_cleanup_writer_and_close(
|
|
self, disconnect_warn: str | None, connection: ActiveConnection | None
|
|
) -> None:
|
|
"""Cleanup the writer and close the websocket."""
|
|
# If the writer gets canceled we still need to close the websocket
|
|
# so we have another finally block to make sure we close the websocket
|
|
# if the writer gets canceled.
|
|
wsock = self._wsock
|
|
hass = self._hass
|
|
logger = self._logger
|
|
try:
|
|
if self._writer_task:
|
|
await self._writer_task
|
|
finally:
|
|
try:
|
|
# Make sure all error messages are written before closing
|
|
await wsock.close()
|
|
finally:
|
|
if disconnect_warn is None:
|
|
logger.debug("%s: Disconnected", self.description)
|
|
else:
|
|
logger.warning(
|
|
"%s: Disconnected: %s", self.description, disconnect_warn
|
|
)
|
|
|
|
if connection is not None:
|
|
hass.data[DATA_CONNECTIONS] -= 1
|
|
self._connection = None
|
|
|
|
async_dispatcher_send(hass, SIGNAL_WEBSOCKET_DISCONNECTED)
|
|
|
|
# Break reference cycles to make sure GC can happen sooner
|
|
self._wsock = None # type: ignore[assignment]
|
|
self._request = None # type: ignore[assignment]
|
|
self._hass = None # type: ignore[assignment]
|
|
self._logger = None # type: ignore[assignment]
|
|
self._message_queue = None # type: ignore[assignment]
|
|
self._handle_task = None
|
|
self._writer_task = None
|
|
self._ready_future = None
|