diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index 24c9cd249ce..8d9b47fdd06 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -125,6 +125,7 @@ as part of a config flow. from __future__ import annotations +import asyncio from collections.abc import Callable from datetime import datetime, timedelta from http import HTTPStatus @@ -168,6 +169,8 @@ type RetrieveResultType = Callable[[str, str], Credentials | None] CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) +DELETE_CURRENT_TOKEN_DELAY = 2 + @bind_hass def create_auth_code( @@ -644,11 +647,34 @@ def websocket_delete_all_refresh_tokens( else: connection.send_result(msg["id"], {}) + async def _delete_current_token_soon() -> None: + """Delete the current token after a delay. + + We do not want to delete the current token immediately as it will + close the connection. + + This is implemented as a tracked task to ensure the token + is still deleted if Home Assistant is shut down during + the delay. + + It should not be refactored to use a call_later as that + would not be tracked and the token would not be deleted + if Home Assistant was shut down during the delay. + """ + try: + await asyncio.sleep(DELETE_CURRENT_TOKEN_DELAY) + finally: + # If the task is cancelled because we are shutting down, delete + # the token right away. + hass.auth.async_remove_refresh_token(current_refresh_token) + if delete_current_token and ( not limit_token_types or current_refresh_token.token_type == token_type ): - # This will close the connection so we need to send the result first. - hass.loop.call_soon(hass.auth.async_remove_refresh_token, current_refresh_token) + # Deleting the token will close the connection so we need + # to do it with a delay in a tracked task to ensure it still + # happens if Home Assistant is shutting down. + hass.async_create_task(_delete_current_token_soon()) @websocket_api.websocket_command( diff --git a/homeassistant/components/websocket_api/const.py b/homeassistant/components/websocket_api/const.py index 3a81508addc..a0d031834ae 100644 --- a/homeassistant/components/websocket_api/const.py +++ b/homeassistant/components/websocket_api/const.py @@ -25,8 +25,15 @@ PENDING_MSG_PEAK_TIME: Final = 5 # Maximum number of messages that can be pending at any given time. # This is effectively the upper limit of the number of entities # that can fire state changes within ~1 second. +# Ideally we would use homeassistant.const.MAX_EXPECTED_ENTITY_IDS +# but since chrome will lock up with too many messages we need to +# limit it to a lower number. MAX_PENDING_MSG: Final = 4096 +# Maximum number of messages that are pending before we force +# resolve the ready future. +PENDING_MSG_MAX_FORCE_READY: Final = 256 + ERR_ID_REUSE: Final = "id_reuse" ERR_INVALID_FORMAT: Final = "invalid_format" ERR_NOT_ALLOWED: Final = "not_allowed" diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index ef5b010171a..c65c4c65988 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -24,6 +24,7 @@ from .auth import AUTH_REQUIRED_MESSAGE, AuthPhase from .const import ( DATA_CONNECTIONS, MAX_PENDING_MSG, + PENDING_MSG_MAX_FORCE_READY, PENDING_MSG_PEAK, PENDING_MSG_PEAK_TIME, SIGNAL_WEBSOCKET_CONNECTED, @@ -67,6 +68,7 @@ class WebSocketHandler: __slots__ = ( "_hass", + "_loop", "_request", "_wsock", "_handle_task", @@ -78,11 +80,13 @@ class WebSocketHandler: "_connection", "_message_queue", "_ready_future", + "_release_ready_queue_size", ) def __init__(self, hass: HomeAssistant, request: web.Request) -> None: """Initialize an active connection.""" self._hass = hass + self._loop = hass.loop self._request: web.Request = request self._wsock = web.WebSocketResponse(heartbeat=55) self._handle_task: asyncio.Task | None = None @@ -97,8 +101,9 @@ class WebSocketHandler: # to where messages are queued. This allows the implementation # to use a deque and an asyncio.Future to avoid the overhead of # an asyncio.Queue. - self._message_queue: deque[bytes | None] = deque() - self._ready_future: asyncio.Future[None] | None = None + self._message_queue: deque[bytes] = deque() + self._ready_future: asyncio.Future[int] | None = None + self._release_ready_queue_size: int = 0 def __repr__(self) -> str: """Return the representation.""" @@ -126,45 +131,35 @@ class WebSocketHandler: message_queue = self._message_queue logger = self._logger wsock = self._wsock - loop = self._hass.loop + loop = self._loop + is_debug_log_enabled = partial(logger.isEnabledFor, logging.DEBUG) debug = logger.debug - is_enabled_for = logger.isEnabledFor - logging_debug = logging.DEBUG + can_coalesce = self._connection and self._connection.can_coalesce + ready_message_count = len(message_queue) # Exceptions if Socket disconnected or cancelled by connection handler try: while not wsock.closed: - if (messages_remaining := len(message_queue)) == 0: + if not message_queue: self._ready_future = loop.create_future() - await self._ready_future - messages_remaining = len(message_queue) + ready_message_count = await self._ready_future - # A None message is used to signal the end of the connection - if (message := message_queue.popleft()) is None: + if self._closing: return - debug_enabled = is_enabled_for(logging_debug) - messages_remaining -= 1 + if not can_coalesce: + # coalesce may be enabled later in the connection + can_coalesce = self._connection and self._connection.can_coalesce - if ( - not messages_remaining - or not (connection := self._connection) - or not connection.can_coalesce - ): - if debug_enabled: + if not can_coalesce or ready_message_count == 1: + message = message_queue.popleft() + if is_debug_log_enabled(): debug("%s: Sending %s", self.description, message) await send_bytes_text(message) continue - messages: list[bytes] = [message] - while messages_remaining: - # A None message is used to signal the end of the connection - if (message := message_queue.popleft()) is None: - return - messages.append(message) - messages_remaining -= 1 - - coalesced_messages = b"".join((b"[", b",".join(messages), b"]")) - if debug_enabled: + coalesced_messages = b"".join((b"[", b",".join(message_queue), b"]")) + message_queue.clear() + if is_debug_log_enabled(): debug("%s: Sending %s", self.description, coalesced_messages) await send_bytes_text(coalesced_messages) except asyncio.CancelledError: @@ -197,14 +192,15 @@ class WebSocketHandler: # max pending messages. return - if isinstance(message, dict): - message = message_to_json_bytes(message) - elif isinstance(message, str): - message = message.encode("utf-8") + if type(message) is not bytes: # noqa: E721 + if isinstance(message, dict): + message = message_to_json_bytes(message) + elif isinstance(message, str): + message = message.encode("utf-8") message_queue = self._message_queue - queue_size_before_add = len(message_queue) - if queue_size_before_add >= MAX_PENDING_MSG: + message_queue.append(message) + if (queue_size_after_add := len(message_queue)) >= MAX_PENDING_MSG: self._logger.error( ( "%s: Client unable to keep up with pending messages. Reached %s pending" @@ -218,14 +214,14 @@ class WebSocketHandler: self._cancel() return - message_queue.append(message) - ready_future = self._ready_future - if ready_future and not ready_future.done(): - ready_future.set_result(None) + if self._release_ready_queue_size == 0: + # Try to coalesce more messages to reduce the number of writes + self._release_ready_queue_size = queue_size_after_add + self._loop.call_soon(self._release_ready_future_or_reschedule) peak_checker_active = self._peak_checker_unsub is not None - if queue_size_before_add <= PENDING_MSG_PEAK: + if queue_size_after_add <= PENDING_MSG_PEAK: if peak_checker_active: self._cancel_peak_checker() return @@ -235,6 +231,32 @@ class WebSocketHandler: self._hass, PENDING_MSG_PEAK_TIME, self._check_write_peak ) + @callback + def _release_ready_future_or_reschedule(self) -> None: + """Release the ready future or reschedule. + + We will release the ready future if the queue did not grow since the + last time we tried to release the ready future. + + If we reach PENDING_MSG_MAX_FORCE_READY, we will release the ready future + immediately so avoid the coalesced messages from growing too large. + """ + if not (ready_future := self._ready_future) or not ( + queue_size := len(self._message_queue) + ): + self._release_ready_queue_size = 0 + return + # If we are below the max pending to force ready, and there are new messages + # in the queue since the last time we tried to release the ready future, we + # try again later so we can coalesce more messages. + if queue_size > self._release_ready_queue_size < PENDING_MSG_MAX_FORCE_READY: + self._release_ready_queue_size = queue_size + self._loop.call_soon(self._release_ready_future_or_reschedule) + return + self._release_ready_queue_size = 0 + if not ready_future.done(): + ready_future.set_result(queue_size) + @callback def _check_write_peak(self, _utc_time: dt.datetime) -> None: """Check that we are no longer above the write peak.""" @@ -440,10 +462,8 @@ class WebSocketHandler: connection.async_handle_close() self._closing = True - - self._message_queue.append(None) if self._ready_future and not self._ready_future.done(): - self._ready_future.set_result(None) + self._ready_future.set_result(len(self._message_queue)) # If the writer gets canceled we still need to close the websocket # so we have another finally block to make sure we close the websocket diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py index c6f03f8bd64..09079337e07 100644 --- a/tests/components/auth/test_init.py +++ b/tests/components/auth/test_init.py @@ -546,20 +546,21 @@ async def test_ws_delete_all_refresh_tokens_error( tokens = result["result"] - await ws_client.send_json( - { - "id": 6, - "type": "auth/delete_all_refresh_tokens", - } - ) + with patch("homeassistant.components.auth.DELETE_CURRENT_TOKEN_DELAY", 0.001): + await ws_client.send_json( + { + "id": 6, + "type": "auth/delete_all_refresh_tokens", + } + ) - caplog.clear() - result = await ws_client.receive_json() - assert result, result["success"] is False - assert result["error"] == { - "code": "token_removing_error", - "message": "During removal, an error was raised.", - } + caplog.clear() + result = await ws_client.receive_json() + assert result, result["success"] is False + assert result["error"] == { + "code": "token_removing_error", + "message": "During removal, an error was raised.", + } records = [ record @@ -571,6 +572,7 @@ async def test_ws_delete_all_refresh_tokens_error( assert records[0].exc_info and str(records[0].exc_info[1]) == "I'm bad" assert records[0].name == "homeassistant.components.auth" + await hass.async_block_till_done() for token in tokens: refresh_token = hass.auth.async_get_refresh_token(token["id"]) assert refresh_token is None @@ -629,18 +631,20 @@ async def test_ws_delete_all_refresh_tokens( result = await ws_client.receive_json() assert result["success"], result - await ws_client.send_json( - { - "id": 6, - "type": "auth/delete_all_refresh_tokens", - **delete_token_type, - **delete_current_token, - } - ) + with patch("homeassistant.components.auth.DELETE_CURRENT_TOKEN_DELAY", 0.001): + await ws_client.send_json( + { + "id": 6, + "type": "auth/delete_all_refresh_tokens", + **delete_token_type, + **delete_current_token, + } + ) - result = await ws_client.receive_json() - assert result, result["success"] + result = await ws_client.receive_json() + assert result, result["success"] + await hass.async_block_till_done() # We need to enumerate the user since we may remove the token # that is used to authenticate the user which will prevent the websocket # connection from working diff --git a/tests/components/websocket_api/test_http.py b/tests/components/websocket_api/test_http.py index 6ce46a5d9fe..794dd410661 100644 --- a/tests/components/websocket_api/test_http.py +++ b/tests/components/websocket_api/test_http.py @@ -294,8 +294,6 @@ async def test_pending_msg_peak_recovery( instance._send_message({}) instance._handle_task.cancel() - msg = await websocket_client.receive() - assert msg.type == WSMsgType.TEXT msg = await websocket_client.receive() assert msg.type is WSMsgType.CLOSE assert "Client unable to keep up with pending messages" not in caplog.text