"""Test Websocket API http module."""
import asyncio
from datetime import timedelta
from typing import Any, cast
from unittest.mock import patch

from aiohttp import ServerDisconnectedError, WSMsgType, web
import pytest

from homeassistant.components.websocket_api import (
    async_register_command,
    const,
    http,
    websocket_command,
)
from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.core import HomeAssistant, callback
from homeassistant.util.dt import utcnow

from tests.common import async_fire_time_changed
from tests.typing import MockHAClientWebSocket, WebSocketGenerator


@pytest.fixture
def mock_low_queue():
    """Mock a low queue (MAX_PENDING_MSG patched to 1)."""
    with patch("homeassistant.components.websocket_api.http.MAX_PENDING_MSG", 1):
        yield


@pytest.fixture
def mock_low_peak():
    """Mock a low peak (PENDING_MSG_PEAK patched to 5)."""
    with patch("homeassistant.components.websocket_api.http.PENDING_MSG_PEAK", 5):
        yield


async def _async_connect_and_capture_handler(
    hass_ws_client: WebSocketGenerator,
) -> tuple[MockHAClientWebSocket, http.WebSocketHandler]:
    """Connect a websocket client and return it with its server-side handler.

    ``http.WebSocketHandler`` is patched only while the client connects, so
    the handler instance created for this connection is captured and can be
    driven directly by the test (e.g. via its private ``_send_message``).
    """
    orig_handler = http.WebSocketHandler
    setup_instance: http.WebSocketHandler | None = None

    def instantiate_handler(*args):
        nonlocal setup_instance
        setup_instance = orig_handler(*args)
        return setup_instance

    with patch(
        "homeassistant.components.websocket_api.http.WebSocketHandler",
        instantiate_handler,
    ):
        websocket_client = await hass_ws_client()

    return websocket_client, cast(http.WebSocketHandler, setup_instance)


async def test_pending_msg_overflow(
    hass: HomeAssistant, mock_low_queue, websocket_client: MockHAClientWebSocket
) -> None:
    """Test pending messages overflows."""
    # With MAX_PENDING_MSG patched to 1, a burst of pings must overflow.
    for idx in range(10):
        await websocket_client.send_json({"id": idx + 1, "type": "ping"})
    msg = await websocket_client.receive()
    assert msg.type == WSMsgType.CLOSE


async def test_cleanup_on_cancellation(
    hass: HomeAssistant, websocket_client: MockHAClientWebSocket
) -> None:
    """Test cleanup on cancellation."""

    # Captures the connection's subscription dict so the test can inspect it.
    subscriptions = None

    # Register a handler that registers a subscription
    @callback
    @websocket_command(
        {
            "type": "fake_subscription",
        }
    )
    def fake_subscription(
        hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
    ) -> None:
        nonlocal subscriptions
        msg_id: int = msg["id"]
        connection.subscriptions[msg_id] = callback(lambda: None)
        connection.send_result(msg_id)
        subscriptions = connection.subscriptions

    async_register_command(hass, fake_subscription)

    # Register a handler whose unsubscribe callback raises on cancel
    @callback
    @websocket_command(
        {
            "type": "subscription_that_raises_on_cancel",
        }
    )
    def subscription_that_raises_on_cancel(
        hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
    ) -> None:
        nonlocal subscriptions
        msg_id: int = msg["id"]

        @callback
        def _raise() -> None:
            raise ValueError()

        connection.subscriptions[msg_id] = _raise
        connection.send_result(msg_id)
        subscriptions = connection.subscriptions

    async_register_command(hass, subscription_that_raises_on_cancel)

    # Register a handler that cancels in handler
    @callback
    @websocket_command(
        {
            "type": "cancel_in_handler",
        }
    )
    def cancel_in_handler(
        hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
    ) -> None:
        raise asyncio.CancelledError()

    async_register_command(hass, cancel_in_handler)

    await websocket_client.send_json({"id": 1, "type": "ping"})
    msg = await websocket_client.receive_json()
    assert msg["id"] == 1
    assert msg["type"] == "pong"
    assert not subscriptions

    await websocket_client.send_json({"id": 2, "type": "fake_subscription"})
    msg = await websocket_client.receive_json()
    assert msg["id"] == 2
    assert msg["type"] == const.TYPE_RESULT
    assert msg["success"]
    assert len(subscriptions) == 2

    await websocket_client.send_json(
        {"id": 3, "type": "subscription_that_raises_on_cancel"}
    )
    msg = await websocket_client.receive_json()
    assert msg["id"] == 3
    assert msg["type"] == const.TYPE_RESULT
    assert msg["success"]
    assert len(subscriptions) == 3

    # Cancelling inside a handler must close the connection and clear every
    # subscription, even the one whose callback raises.
    await websocket_client.send_json({"id": 4, "type": "cancel_in_handler"})
    await hass.async_block_till_done()
    msg = await websocket_client.receive()
    assert msg.type == WSMsgType.CLOSE
    assert len(subscriptions) == 0


async def test_delayed_response_handler(
    hass: HomeAssistant,
    websocket_client: MockHAClientWebSocket,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test a handler that responds after a connection has already been closed."""

    # Captures the connection's subscription dict so the test can inspect it.
    subscriptions = None

    # Register a handler that responds after it returns
    @callback
    @websocket_command(
        {
            "type": "late_responder",
        }
    )
    def async_late_responder(
        hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
    ) -> None:
        nonlocal subscriptions
        msg_id: int = msg["id"]
        subscriptions = connection.subscriptions
        connection.subscriptions[msg_id] = lambda: None
        connection.send_result(msg_id)

        async def _async_late_send_message() -> None:
            # Sleep long enough for the test to close the client first.
            await asyncio.sleep(0.05)
            connection.send_event(msg_id, {"event": "any"})

        hass.async_create_task(_async_late_send_message())

    async_register_command(hass, async_late_responder)

    await websocket_client.send_json({"id": 1, "type": "ping"})
    msg = await websocket_client.receive_json()
    assert msg["id"] == 1
    assert msg["type"] == "pong"
    assert not subscriptions

    await websocket_client.send_json({"id": 2, "type": "late_responder"})
    msg = await websocket_client.receive_json()
    assert msg["id"] == 2
    assert msg["type"] == "result"
    assert len(subscriptions) == 2

    assert await websocket_client.close()
    await hass.async_block_till_done()
    assert len(subscriptions) == 0

    # The late event is dropped and logged instead of raising.
    assert "Tried to send message" in caplog.text
    assert "on closed connection" in caplog.text


async def test_ensure_disconnect_invalid_json(
    hass: HomeAssistant,
    websocket_client: MockHAClientWebSocket,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test we get disconnected when sending invalid JSON."""
    await websocket_client.send_json({"id": 1, "type": "ping"})
    msg = await websocket_client.receive_json()
    assert msg["id"] == 1
    assert msg["type"] == "pong"

    await websocket_client.send_str("[--INVALID-JSON--]")
    msg = await websocket_client.receive()
    assert msg.type == WSMsgType.CLOSE


async def test_ensure_disconnect_invalid_binary(
    hass: HomeAssistant,
    websocket_client: MockHAClientWebSocket,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test we get disconnected when sending invalid bytes."""
    await websocket_client.send_json({"id": 1, "type": "ping"})
    msg = await websocket_client.receive_json()
    assert msg["id"] == 1
    assert msg["type"] == "pong"

    # An empty binary frame carries no handler prefix.
    await websocket_client.send_bytes(b"")
    msg = await websocket_client.receive()
    assert msg.type == WSMsgType.CLOSE


async def test_pending_msg_peak(
    hass: HomeAssistant,
    mock_low_peak,
    hass_ws_client: WebSocketGenerator,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test pending msg overflow command."""
    websocket_client, instance = await _async_connect_and_capture_handler(
        hass_ws_client
    )

    # Fill the queue past the allowed peak
    for _ in range(10):
        instance._send_message({"overload": "message"})

    # Jump past the peak window so the peak check fires.
    async_fire_time_changed(
        hass, utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1)
    )

    msg = await websocket_client.receive()
    assert msg.type == WSMsgType.CLOSE
    assert "Client unable to keep up with pending messages" in caplog.text
    assert "Stayed over 5 for 5 seconds" in caplog.text
    assert "overload" in caplog.text


async def test_pending_msg_peak_recovery(
    hass: HomeAssistant,
    mock_low_peak,
    hass_ws_client: WebSocketGenerator,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test pending msg nears the peak but recovers."""
    websocket_client, instance = await _async_connect_and_capture_handler(
        hass_ws_client
    )

    # Make sure the call later is started
    for _ in range(10):
        instance._send_message({})

    # Drain the queue so it falls back below the peak.
    for _ in range(10):
        msg = await websocket_client.receive()
        assert msg.type == WSMsgType.TEXT

    instance._send_message({})
    msg = await websocket_client.receive()
    assert msg.type == WSMsgType.TEXT

    # Cleanly shutdown
    instance._send_message({})
    instance._handle_task.cancel()

    msg = await websocket_client.receive()
    assert msg.type == WSMsgType.TEXT
    msg = await websocket_client.receive()
    assert msg.type == WSMsgType.CLOSE
    assert "Client unable to keep up with pending messages" not in caplog.text


async def test_pending_msg_peak_but_does_not_overflow(
    hass: HomeAssistant,
    mock_low_peak,
    hass_ws_client: WebSocketGenerator,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test pending msg hits the low peak but recovers and does not overflow."""
    websocket_client, instance = await _async_connect_and_capture_handler(
        hass_ws_client
    )

    # Kill writer task and fill queue past peak
    for _ in range(5):
        instance._message_queue.append(None)

    # Trigger the peak check
    instance._send_message({})

    # Clear the queue
    instance._message_queue.clear()

    # Trigger the peak clear
    instance._send_message({})

    # Jump past the peak window; the peak was cleared, so no disconnect.
    async_fire_time_changed(
        hass, utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1)
    )

    msg = await websocket_client.receive()
    assert msg.type == WSMsgType.TEXT
    assert "Client unable to keep up with pending messages" not in caplog.text


async def test_non_json_message(
    hass: HomeAssistant,
    websocket_client: MockHAClientWebSocket,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test trying to serialize non JSON objects."""
    # A bare object() has no JSON representation.
    bad_data = object()
    hass.states.async_set("test_domain.entity", "testing", {"bad": bad_data})
    await websocket_client.send_json({"id": 5, "type": "get_states"})

    msg = await websocket_client.receive_json()
    assert msg["id"] == 5
    assert msg["type"] == const.TYPE_RESULT
    assert msg["success"]
    # The unserializable state is left out of the result and logged instead.
    assert msg["result"] == []
    assert "Unable to serialize to JSON. Bad data found" in caplog.text
    assert "State: test_domain.entity" in caplog.text
    assert "bad=<object" in caplog.text


async def test_prepare_fail(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test failing to prepare."""
    # The first prepare() times out, so the server drops the connection.
    with patch(
        "homeassistant.components.websocket_api.http.web.WebSocketResponse.prepare",
        side_effect=(asyncio.TimeoutError, web.WebSocketResponse.prepare),
    ), pytest.raises(ServerDisconnectedError):
        await hass_ws_client(hass)

    assert "Timeout preparing request" in caplog.text


async def test_enable_coalesce(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test enabling coalesce."""
    websocket_client = await hass_ws_client(hass)

    await websocket_client.send_json(
        {
            "id": 1,
            "type": "supported_features",
            "features": {const.FEATURE_COALESCE_MESSAGES: 1},
        }
    )
    msg = await websocket_client.receive_json()
    assert msg["id"] == 1
    assert msg["success"] is True

    # Send a burst of pings concurrently; each must get a pong back.
    send_tasks: list[asyncio.Future] = []
    ids: set[int] = set()
    start_id = 2
    for idx in range(10):
        id_ = idx + start_id
        ids.add(id_)
        send_tasks.append(websocket_client.send_json({"id": id_, "type": "ping"}))

    await asyncio.gather(*send_tasks)
    returned_ids: set[int] = set()
    for _ in range(10):
        msg = await websocket_client.receive_json()
        assert msg["type"] == "pong"
        returned_ids.add(msg["id"])

    assert ids == returned_ids

    # Now close
    send_tasks_with_close: list[asyncio.Future] = []
    start_id = 12
    for idx in range(10):
        id_ = idx + start_id
        send_tasks_with_close.append(
            websocket_client.send_json({"id": id_, "type": "ping"})
        )

    send_tasks_with_close.append(websocket_client.close())
    # Sending after close() is expected to fail.
    send_tasks_with_close.append(websocket_client.send_json({"id": 50, "type": "ping"}))

    with pytest.raises(ConnectionResetError):
        await asyncio.gather(*send_tasks_with_close)


async def test_binary_message(
    hass: HomeAssistant,
    websocket_client: MockHAClientWebSocket,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test binary messages."""
    # Maps message id -> (received chunks, future resolved with joined payload).
    binary_payloads = {
        104: ([], asyncio.Future()),
        105: ([], asyncio.Future()),
    }

    # Register a handler
    @callback
    @websocket_command(
        {
            "type": "get_binary_message_handler",
        }
    )
    def get_binary_message_handler(
        hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
    ) -> None:
        unsub = None

        @callback
        def binary_message_handler(
            hass: HomeAssistant, connection: ActiveConnection, payload: bytes
        ) -> None:
            nonlocal unsub
            if msg["id"] == 103:
                raise ValueError("Boom")

            # A non-empty payload is a chunk; an empty payload ends the stream.
            if payload:
                binary_payloads[msg["id"]][0].append(payload)
            else:
                binary_payloads[msg["id"]][1].set_result(
                    b"".join(binary_payloads[msg["id"]][0])
                )
                unsub()

        prefix, unsub = connection.async_register_binary_handler(binary_message_handler)

        connection.send_result(msg["id"], {"prefix": prefix})

    async_register_command(hass, get_binary_message_handler)

    # Register multiple binary handlers
    for i in range(101, 106):
        await websocket_client.send_json(
            {"id": i, "type": "get_binary_message_handler"}
        )
        result = await websocket_client.receive_json()
        assert result["id"] == i
        assert result["type"] == const.TYPE_RESULT
        assert result["success"]
        assert result["result"]["prefix"] == i - 100

    # Send message to binary. The first byte is the handler prefix and the
    # rest is the payload.
    await websocket_client.send_bytes((0).to_bytes(1, "big") + b"test0")
    await websocket_client.send_bytes((3).to_bytes(1, "big") + b"test3")
    await websocket_client.send_bytes((3).to_bytes(1, "big") + b"test3")
    await websocket_client.send_bytes((10).to_bytes(1, "big") + b"test10")
    await websocket_client.send_bytes((4).to_bytes(1, "big") + b"test4")
    await websocket_client.send_bytes((4).to_bytes(1, "big") + b"")
    await websocket_client.send_bytes((5).to_bytes(1, "big") + b"test5")
    await websocket_client.send_bytes((5).to_bytes(1, "big") + b"test5-2")
    await websocket_client.send_bytes((5).to_bytes(1, "big") + b"")

    # Verify received
    assert await binary_payloads[104][1] == b"test4"
    assert await binary_payloads[105][1] == b"test5test5-2"
    assert "Error handling binary message" in caplog.text
    assert "Received binary message for non-existing handler 0" in caplog.text
    assert "Received binary message for non-existing handler 3" in caplog.text
    assert "Received binary message for non-existing handler 10" in caplog.text