Add connection test feature to assist_satellite (#126256)

* Add connection test feature to assist_satellite

* Add http to assist_satellite dependencies

* Remove extra logging

* Incorporate feedback

* Fix tests

* ruff

* Apply suggestions from code review

Co-authored-by: Bram Kragten <mail@bramkragten.nl>

* Use asyncio.Event instead of dispatcher

* Respond asap

* Update homeassistant/components/assist_satellite/websocket_api.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

---------

Co-authored-by: Michael Hansen <mike@rhasspy.org>
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Co-authored-by: Bram Kragten <mail@bramkragten.nl>
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
pull/126433/head
Erik Montnemery 2024-09-22 16:55:31 +02:00 committed by GitHub
parent bb2c2d161a
commit 8158ca7c69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 258 additions and 5 deletions

View File

@ -10,7 +10,13 @@ from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN, DOMAIN_DATA, AssistSatelliteEntityFeature
from .connection_test import ConnectionTestView
from .const import (
CONNECTION_TEST_DATA,
DOMAIN,
DOMAIN_DATA,
AssistSatelliteEntityFeature,
)
from .entity import (
AssistSatelliteAnnouncement,
AssistSatelliteConfiguration,
@ -57,7 +63,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"async_internal_announce",
[AssistSatelliteEntityFeature.ANNOUNCE],
)
hass.data[CONNECTION_TEST_DATA] = {}
async_register_websocket_api(hass)
hass.http.register_view(ConnectionTestView())
return True

View File

@ -0,0 +1,43 @@
"""Assist satellite connection test."""
import logging
from pathlib import Path
from aiohttp import web
from homeassistant.components.http import KEY_HASS, HomeAssistantView
from .const import CONNECTION_TEST_DATA
_LOGGER = logging.getLogger(__name__)
CONNECTION_TEST_CONTENT_TYPE = "audio/mpeg"
CONNECTION_TEST_FILENAME = "connection_test.mp3"
CONNECTION_TEST_URL_BASE = "/api/assist_satellite/connection_test"
class ConnectionTestView(HomeAssistantView):
"""View to serve an audio sample for connection test."""
requires_auth = False
url = f"{CONNECTION_TEST_URL_BASE}/{{connection_id}}"
name = "api:assist_satellite_connection_test"
async def get(self, request: web.Request, connection_id: str) -> web.Response:
"""Start a get request."""
_LOGGER.debug("Request for connection test with id %s", connection_id)
hass = request.app[KEY_HASS]
connection_test_data = hass.data[CONNECTION_TEST_DATA]
connection_test_event = connection_test_data.pop(connection_id, None)
if connection_test_event is None:
return web.Response(status=404)
connection_test_event.set()
audio_path = Path(__file__).parent / CONNECTION_TEST_FILENAME
audio_data = await hass.async_add_executor_job(audio_path.read_bytes)
return web.Response(body=audio_data, content_type=CONNECTION_TEST_CONTENT_TYPE)

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from enum import IntFlag
from typing import TYPE_CHECKING
@ -15,6 +16,9 @@ if TYPE_CHECKING:
DOMAIN = "assist_satellite"
DOMAIN_DATA: HassKey[EntityComponent[AssistSatelliteEntity]] = HassKey(DOMAIN)
CONNECTION_TEST_DATA: HassKey[dict[str, asyncio.Event]] = HassKey(
f"{DOMAIN}_connection_tests"
)
class AssistSatelliteEntityFeature(IntFlag):

View File

@ -2,7 +2,7 @@
"domain": "assist_satellite",
"name": "Assist Satellite",
"codeowners": ["@home-assistant/core", "@synesthesiam"],
"dependencies": ["assist_pipeline", "stt", "tts"],
"dependencies": ["assist_pipeline", "http", "stt", "tts"],
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
"integration_type": "entity",
"quality_scale": "internal"

View File

@ -1,5 +1,6 @@
"""Assist satellite Websocket API."""
import asyncio
from dataclasses import asdict, replace
from typing import Any
@ -9,8 +10,19 @@ from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.util import uuid as uuid_util
from .const import DOMAIN, DOMAIN_DATA
from .connection_test import CONNECTION_TEST_URL_BASE
from .const import (
CONNECTION_TEST_DATA,
DOMAIN,
DOMAIN_DATA,
AssistSatelliteEntityFeature,
)
from .entity import AssistSatelliteEntity
CONNECTION_TEST_TIMEOUT = 30
@callback
@ -19,6 +31,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, websocket_intercept_wake_word)
websocket_api.async_register_command(hass, websocket_get_configuration)
websocket_api.async_register_command(hass, websocket_set_wake_words)
websocket_api.async_register_command(hass, websocket_test_connection)
@callback
@ -138,3 +151,57 @@ async def websocket_set_wake_words(
replace(config, active_wake_words=actual_ids)
)
connection.send_result(msg["id"])
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_satellite/test_connection",
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
}
)
@websocket_api.async_response
async def websocket_test_connection(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Test the connection between the device and Home Assistant.
Send an announcement to the device with a special media id.
"""
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
satellite = component.get_entity(msg["entity_id"])
if satellite is None:
connection.send_error(
msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
)
return
if not (satellite.supported_features or 0) & AssistSatelliteEntityFeature.ANNOUNCE:
connection.send_error(
msg["id"],
websocket_api.ERR_NOT_SUPPORTED,
"Entity does not support announce",
)
return
# Announce and wait for event
connection_test_data = hass.data[CONNECTION_TEST_DATA]
connection_id = uuid_util.random_uuid_hex()
connection_test_event = asyncio.Event()
connection_test_data[connection_id] = connection_test_event
hass.async_create_background_task(
satellite.async_internal_announce(
media_id=f"{CONNECTION_TEST_URL_BASE}/{connection_id}"
),
f"assist_satellite_connection_test_{msg['entity_id']}",
)
try:
async with asyncio.timeout(CONNECTION_TEST_TIMEOUT):
await connection_test_event.wait()
connection.send_result(msg["id"], {"status": "success"})
except TimeoutError:
connection.send_result(msg["id"], {"status": "timeout"})
finally:
connection_test_data.pop(connection_id, None)

View File

@ -44,7 +44,7 @@ class MockAssistSatellite(AssistSatelliteEntity):
def __init__(self) -> None:
"""Initialize the mock entity."""
self.events = []
self.announcements = []
self.announcements: list[AssistSatelliteAnnouncement] = []
self.config = AssistSatelliteConfiguration(
available_wake_words=[
AssistSatelliteWakeWord(

View File

@ -1,11 +1,16 @@
"""Test WebSocket API."""
import asyncio
from http import HTTPStatus
from unittest.mock import patch
from freezegun.api import FrozenDateTimeFactory
import pytest
from homeassistant.components.assist_pipeline import PipelineStage
from homeassistant.components.assist_satellite.websocket_api import (
CONNECTION_TEST_TIMEOUT,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
@ -13,7 +18,7 @@ from . import ENTITY_ID
from .conftest import MockAssistSatellite
from tests.common import MockUser
from tests.typing import WebSocketGenerator
from tests.typing import ClientSessionGenerator, WebSocketGenerator
async def test_intercept_wake_word(
@ -385,3 +390,129 @@ async def test_set_wake_words_bad_id(
"code": "not_supported",
"message": "Wake word id is not supported: abcd",
}
async def test_connection_test(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
hass_client: ClientSessionGenerator,
) -> None:
"""Test connection test."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": ENTITY_ID,
}
)
for _ in range(3):
await asyncio.sleep(0)
assert len(entity.announcements) == 1
assert entity.announcements[0].message == ""
announcement_media_id = entity.announcements[0].media_id
hass_url = "http://10.10.10.10:8123"
assert announcement_media_id.startswith(
f"{hass_url}/api/assist_satellite/connection_test/"
)
# Fake satellite fetches the URL
client = await hass_client()
resp = await client.get(announcement_media_id[len(hass_url) :])
assert resp.status == HTTPStatus.OK
response = await ws_client.receive_json()
assert response["success"]
assert response["result"] == {"status": "success"}
async def test_connection_test_timeout(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
hass_client: ClientSessionGenerator,
freezer: FrozenDateTimeFactory,
) -> None:
"""Test connection test timeout."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": ENTITY_ID,
}
)
for _ in range(3):
await asyncio.sleep(0)
assert len(entity.announcements) == 1
assert entity.announcements[0].message == ""
announcement_media_id = entity.announcements[0].media_id
hass_url = "http://10.10.10.10:8123"
assert announcement_media_id.startswith(
f"{hass_url}/api/assist_satellite/connection_test/"
)
freezer.tick(CONNECTION_TEST_TIMEOUT + 1)
# Timeout
response = await ws_client.receive_json()
assert response["success"]
assert response["result"] == {"status": "timeout"}
async def test_connection_test_invalid_satellite(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test connection test with unknown entity id."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": "assist_satellite.invalid",
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
"code": "not_found",
"message": "Entity not found",
}
async def test_connection_test_timeout_announcement_unsupported(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test connection test entity which does not support announce."""
ws_client = await hass_ws_client(hass)
# Disable announce support
entity.supported_features = 0
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/test_connection",
"entity_id": ENTITY_ID,
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
"code": "not_supported",
"message": "Entity does not support announce",
}