diff --git a/homeassistant/components/assist_satellite/__init__.py b/homeassistant/components/assist_satellite/__init__.py index 3f322beef29..6932fa3180c 100644 --- a/homeassistant/components/assist_satellite/__init__.py +++ b/homeassistant/components/assist_satellite/__init__.py @@ -10,7 +10,13 @@ from homeassistant.helpers import config_validation as cv from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.typing import ConfigType -from .const import DOMAIN, DOMAIN_DATA, AssistSatelliteEntityFeature +from .connection_test import ConnectionTestView +from .const import ( + CONNECTION_TEST_DATA, + DOMAIN, + DOMAIN_DATA, + AssistSatelliteEntityFeature, +) from .entity import ( AssistSatelliteAnnouncement, AssistSatelliteConfiguration, @@ -57,7 +63,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: "async_internal_announce", [AssistSatelliteEntityFeature.ANNOUNCE], ) + hass.data[CONNECTION_TEST_DATA] = {} async_register_websocket_api(hass) + hass.http.register_view(ConnectionTestView()) return True diff --git a/homeassistant/components/assist_satellite/connection_test.mp3 b/homeassistant/components/assist_satellite/connection_test.mp3 new file mode 100755 index 00000000000..5fd79ce8609 Binary files /dev/null and b/homeassistant/components/assist_satellite/connection_test.mp3 differ diff --git a/homeassistant/components/assist_satellite/connection_test.py b/homeassistant/components/assist_satellite/connection_test.py new file mode 100644 index 00000000000..956542dacf3 --- /dev/null +++ b/homeassistant/components/assist_satellite/connection_test.py @@ -0,0 +1,43 @@ +"""Assist satellite connection test.""" + +import logging +from pathlib import Path + +from aiohttp import web + +from homeassistant.components.http import KEY_HASS, HomeAssistantView + +from .const import CONNECTION_TEST_DATA + +_LOGGER = logging.getLogger(__name__) + +CONNECTION_TEST_CONTENT_TYPE = "audio/mpeg" +CONNECTION_TEST_FILENAME = "connection_test.mp3" +CONNECTION_TEST_URL_BASE = "/api/assist_satellite/connection_test" + + +class ConnectionTestView(HomeAssistantView): + """View to serve an audio sample for connection test.""" + + requires_auth = False + url = f"{CONNECTION_TEST_URL_BASE}/{{connection_id}}" + name = "api:assist_satellite_connection_test" + + async def get(self, request: web.Request, connection_id: str) -> web.Response: + """Start a get request.""" + _LOGGER.debug("Request for connection test with id %s", connection_id) + + hass = request.app[KEY_HASS] + connection_test_data = hass.data[CONNECTION_TEST_DATA] + + connection_test_event = connection_test_data.pop(connection_id, None) + + if connection_test_event is None: + return web.Response(status=404) + + connection_test_event.set() + + audio_path = Path(__file__).parent / CONNECTION_TEST_FILENAME + audio_data = await hass.async_add_executor_job(audio_path.read_bytes) + + return web.Response(body=audio_data, content_type=CONNECTION_TEST_CONTENT_TYPE) diff --git a/homeassistant/components/assist_satellite/const.py b/homeassistant/components/assist_satellite/const.py index bd5453e06de..73bc126f7ba 100644 --- a/homeassistant/components/assist_satellite/const.py +++ b/homeassistant/components/assist_satellite/const.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio from enum import IntFlag from typing import TYPE_CHECKING @@ -15,6 +16,9 @@ if TYPE_CHECKING: DOMAIN = "assist_satellite" DOMAIN_DATA: HassKey[EntityComponent[AssistSatelliteEntity]] = HassKey(DOMAIN) +CONNECTION_TEST_DATA: HassKey[dict[str, asyncio.Event]] = HassKey( + f"{DOMAIN}_connection_tests" +) class AssistSatelliteEntityFeature(IntFlag): diff --git a/homeassistant/components/assist_satellite/manifest.json b/homeassistant/components/assist_satellite/manifest.json index b4f89456351..68a3ceafd4f 100644 --- a/homeassistant/components/assist_satellite/manifest.json +++ b/homeassistant/components/assist_satellite/manifest.json @@ -2,7 +2,7 @@ "domain": "assist_satellite", "name": "Assist Satellite", "codeowners": ["@home-assistant/core", "@synesthesiam"], - "dependencies": ["assist_pipeline", "stt", "tts"], + "dependencies": ["assist_pipeline", "http", "stt", "tts"], "documentation": "https://www.home-assistant.io/integrations/assist_satellite", "integration_type": "entity", "quality_scale": "internal" diff --git a/homeassistant/components/assist_satellite/websocket_api.py b/homeassistant/components/assist_satellite/websocket_api.py index ee7bef7e4e8..741f4364e7f 100644 --- a/homeassistant/components/assist_satellite/websocket_api.py +++ b/homeassistant/components/assist_satellite/websocket_api.py @@ -1,5 +1,6 @@ """Assist satellite Websocket API.""" +import asyncio from dataclasses import asdict, replace from typing import Any @@ -9,8 +10,19 @@ from homeassistant.components import websocket_api from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.util import uuid as uuid_util -from .const import DOMAIN, DOMAIN_DATA +from .connection_test import CONNECTION_TEST_URL_BASE +from .const import ( + CONNECTION_TEST_DATA, + DOMAIN, + DOMAIN_DATA, + AssistSatelliteEntityFeature, +) +from .entity import AssistSatelliteEntity + +CONNECTION_TEST_TIMEOUT = 30 @callback @@ -19,6 +31,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None: websocket_api.async_register_command(hass, websocket_intercept_wake_word) websocket_api.async_register_command(hass, websocket_get_configuration) websocket_api.async_register_command(hass, websocket_set_wake_words) + websocket_api.async_register_command(hass, websocket_test_connection) @callback @@ -138,3 +151,57 @@ async def websocket_set_wake_words( replace(config, active_wake_words=actual_ids) ) connection.send_result(msg["id"]) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "assist_satellite/test_connection", + vol.Required("entity_id"): cv.entity_domain(DOMAIN), + } +) +@websocket_api.async_response +async def websocket_test_connection( + hass: HomeAssistant, + connection: websocket_api.connection.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Test the connection between the device and Home Assistant. + + Send an announcement to the device with a special media id. + """ + component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN] + satellite = component.get_entity(msg["entity_id"]) + if satellite is None: + connection.send_error( + msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found" + ) + return + if not (satellite.supported_features or 0) & AssistSatelliteEntityFeature.ANNOUNCE: + connection.send_error( + msg["id"], + websocket_api.ERR_NOT_SUPPORTED, + "Entity does not support announce", + ) + return + + # Announce and wait for event + connection_test_data = hass.data[CONNECTION_TEST_DATA] + connection_id = uuid_util.random_uuid_hex() + connection_test_event = asyncio.Event() + connection_test_data[connection_id] = connection_test_event + + hass.async_create_background_task( + satellite.async_internal_announce( + media_id=f"{CONNECTION_TEST_URL_BASE}/{connection_id}" + ), + f"assist_satellite_connection_test_{msg['entity_id']}", + ) + + try: + async with asyncio.timeout(CONNECTION_TEST_TIMEOUT): + await connection_test_event.wait() + connection.send_result(msg["id"], {"status": "success"}) + except TimeoutError: + connection.send_result(msg["id"], {"status": "timeout"}) + finally: + connection_test_data.pop(connection_id, None) diff --git a/tests/components/assist_satellite/conftest.py b/tests/components/assist_satellite/conftest.py index 489460f8e2c..9e9bfd959e6 100644 --- a/tests/components/assist_satellite/conftest.py +++ b/tests/components/assist_satellite/conftest.py @@ -44,7 +44,7 @@ class MockAssistSatellite(AssistSatelliteEntity): def __init__(self) -> None: """Initialize the mock entity.""" self.events = [] - self.announcements = [] + self.announcements: list[AssistSatelliteAnnouncement] = [] self.config = AssistSatelliteConfiguration( available_wake_words=[ AssistSatelliteWakeWord( diff --git a/tests/components/assist_satellite/test_websocket_api.py b/tests/components/assist_satellite/test_websocket_api.py index 709005e38cf..257961a5b32 100644 --- a/tests/components/assist_satellite/test_websocket_api.py +++ b/tests/components/assist_satellite/test_websocket_api.py @@ -1,11 +1,16 @@ """Test WebSocket API.""" import asyncio +from http import HTTPStatus from unittest.mock import patch +from freezegun.api import FrozenDateTimeFactory import pytest from homeassistant.components.assist_pipeline import PipelineStage +from homeassistant.components.assist_satellite.websocket_api import ( + CONNECTION_TEST_TIMEOUT, +) from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant @@ -13,7 +18,7 @@ from . import ENTITY_ID from .conftest import MockAssistSatellite from tests.common import MockUser -from tests.typing import WebSocketGenerator +from tests.typing import ClientSessionGenerator, WebSocketGenerator async def test_intercept_wake_word( @@ -385,3 +390,129 @@ async def test_set_wake_words_bad_id( "code": "not_supported", "message": "Wake word id is not supported: abcd", } + + +async def test_connection_test( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, + hass_client: ClientSessionGenerator, +) -> None: + """Test connection test.""" + ws_client = await hass_ws_client(hass) + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/test_connection", + "entity_id": ENTITY_ID, + } + ) + + for _ in range(3): + await asyncio.sleep(0) + + assert len(entity.announcements) == 1 + assert entity.announcements[0].message == "" + announcement_media_id = entity.announcements[0].media_id + hass_url = "http://10.10.10.10:8123" + assert announcement_media_id.startswith( + f"{hass_url}/api/assist_satellite/connection_test/" + ) + + # Fake satellite fetches the URL + client = await hass_client() + resp = await client.get(announcement_media_id[len(hass_url) :]) + assert resp.status == HTTPStatus.OK + + response = await ws_client.receive_json() + assert response["success"] + assert response["result"] == {"status": "success"} + + +async def test_connection_test_timeout( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, + hass_client: ClientSessionGenerator, + freezer: FrozenDateTimeFactory, +) -> None: + """Test connection test timeout.""" + ws_client = await hass_ws_client(hass) + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/test_connection", + "entity_id": ENTITY_ID, + } + ) + + for _ in range(3): + await asyncio.sleep(0) + + assert len(entity.announcements) == 1 + assert entity.announcements[0].message == "" + announcement_media_id = entity.announcements[0].media_id + hass_url = "http://10.10.10.10:8123" + assert announcement_media_id.startswith( + f"{hass_url}/api/assist_satellite/connection_test/" + ) + + freezer.tick(CONNECTION_TEST_TIMEOUT + 1) + + # Timeout + response = await ws_client.receive_json() + assert response["success"] + assert response["result"] == {"status": "timeout"} + + +async def test_connection_test_invalid_satellite( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test connection test with unknown entity id.""" + ws_client = await hass_ws_client(hass) + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/test_connection", + "entity_id": "assist_satellite.invalid", + } + ) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"] == { + "code": "not_found", + "message": "Entity not found", + } + + +async def test_connection_test_timeout_announcement_unsupported( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test connection test entity which does not support announce.""" + ws_client = await hass_ws_client(hass) + + # Disable announce support + entity.supported_features = 0 + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/test_connection", + "entity_id": ENTITY_ID, + } + ) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"] == { + "code": "not_supported", + "message": "Entity does not support announce", + }