From 604c848dec7d2ac272ddbb9c841a4d8aac4e073b Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 20 Sep 2024 09:09:37 -0400 Subject: [PATCH] Change assist satellite announce method signature (#126299) --- .../components/assist_satellite/__init__.py | 2 ++ .../components/assist_satellite/entity.py | 29 +++++++++++++++++-- .../components/esphome/assist_satellite.py | 10 ++++--- tests/components/assist_satellite/conftest.py | 5 ++-- .../assist_satellite/test_entity.py | 23 ++++++++++----- 5 files changed, 52 insertions(+), 17 deletions(-) diff --git a/homeassistant/components/assist_satellite/__init__.py b/homeassistant/components/assist_satellite/__init__.py index 77c9d8e678a..3f322beef29 100644 --- a/homeassistant/components/assist_satellite/__init__.py +++ b/homeassistant/components/assist_satellite/__init__.py @@ -12,6 +12,7 @@ from homeassistant.helpers.typing import ConfigType from .const import DOMAIN, DOMAIN_DATA, AssistSatelliteEntityFeature from .entity import ( + AssistSatelliteAnnouncement, AssistSatelliteConfiguration, AssistSatelliteEntity, AssistSatelliteEntityDescription, @@ -22,6 +23,7 @@ from .websocket_api import async_register_websocket_api __all__ = [ "DOMAIN", + "AssistSatelliteAnnouncement", "AssistSatelliteEntity", "AssistSatelliteConfiguration", "AssistSatelliteEntityDescription", diff --git a/homeassistant/components/assist_satellite/entity.py b/homeassistant/components/assist_satellite/entity.py index 079d3ae2948..23b588b569e 100644 --- a/homeassistant/components/assist_satellite/entity.py +++ b/homeassistant/components/assist_satellite/entity.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from enum import StrEnum import logging import time -from typing import Any, Final, final +from typing import Any, Final, Literal, final from homeassistant.components import media_source, stt, tts from homeassistant.components.assist_pipeline import ( @@ -86,6 +86,19 @@ class AssistSatelliteConfiguration: """Maximum number of simultaneous wake words allowed (0 for no limit).""" +@dataclass +class AssistSatelliteAnnouncement: + """Announcement to be made.""" + + message: str + """Message to be spoken.""" + + media_id: str + """Media ID to be played.""" + + media_id_source: Literal["url", "media_id", "tts"] + + class AssistSatelliteEntity(entity.Entity): """Entity encapsulating the state and functionality of an Assist satellite.""" @@ -174,10 +187,13 @@ class AssistSatelliteEntity(entity.Entity): """ await self._cancel_running_pipeline() + media_id_source: Literal["url", "media_id", "tts"] | None = None + if message is None: message = "" if not media_id: + media_id_source = "tts" # Synthesize audio and get URL pipeline_id = self._resolve_pipeline() pipeline = async_get_pipeline(self.hass, pipeline_id) @@ -198,6 +214,8 @@ class AssistSatelliteEntity(entity.Entity): ) if media_source.is_media_source_id(media_id): + if not media_id_source: + media_id_source = "media_id" media = await media_source.async_resolve_media( self.hass, media_id, @@ -205,6 +223,9 @@ class AssistSatelliteEntity(entity.Entity): ) media_id = media.url + if not media_id_source: + media_id_source = "url" + # Resolve to full URL media_id = async_process_play_media_url(self.hass, media_id) @@ -216,12 +237,14 @@ class AssistSatelliteEntity(entity.Entity): try: # Block until announcement is finished - await self.async_announce(message, media_id) + await self.async_announce( + AssistSatelliteAnnouncement(message, media_id, media_id_source) + ) finally: self._is_announcing = False self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD) - async def async_announce(self, message: str, media_id: str) -> None: + async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None: """Announce media on the satellite. Should block until the announcement is done playing. diff --git a/homeassistant/components/esphome/assist_satellite.py b/homeassistant/components/esphome/assist_satellite.py index f8ed4c48651..a0e05a6c565 100644 --- a/homeassistant/components/esphome/assist_satellite.py +++ b/homeassistant/components/esphome/assist_satellite.py @@ -313,18 +313,20 @@ class EsphomeAssistSatellite( self.cli.send_voice_assistant_event(event_type, data_to_send) - async def async_announce(self, message: str, media_id: str) -> None: + async def async_announce( + self, announcement: assist_satellite.AssistSatelliteAnnouncement + ) -> None: """Announce media on the satellite. Should block until the announcement is done playing. """ _LOGGER.debug( "Waiting for announcement to finished (message=%s, media_id=%s)", - message, - media_id, + announcement.message, + announcement.media_id, ) await self.cli.send_voice_assistant_announcement_await_response( - media_id, _ANNOUNCEMENT_TIMEOUT_SEC, message + announcement.media_id, _ANNOUNCEMENT_TIMEOUT_SEC, announcement.message ) async def handle_pipeline_start( diff --git a/tests/components/assist_satellite/conftest.py b/tests/components/assist_satellite/conftest.py index 3a374b312cc..489460f8e2c 100644 --- a/tests/components/assist_satellite/conftest.py +++ b/tests/components/assist_satellite/conftest.py @@ -8,6 +8,7 @@ import pytest from homeassistant.components.assist_pipeline import PipelineEvent from homeassistant.components.assist_satellite import ( DOMAIN as AS_DOMAIN, + AssistSatelliteAnnouncement, AssistSatelliteConfiguration, AssistSatelliteEntity, AssistSatelliteEntityFeature, @@ -63,9 +64,9 @@ class MockAssistSatellite(AssistSatelliteEntity): """Handle pipeline events.""" self.events.append(event) - async def async_announce(self, message: str, media_id: str) -> None: + async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None: """Announce media on a device.""" - self.announcements.append((message, media_id)) + self.announcements.append(announcement) @callback def async_get_configuration(self) -> AssistSatelliteConfiguration: diff --git a/tests/components/assist_satellite/test_entity.py b/tests/components/assist_satellite/test_entity.py index 2af3af89681..b2347184bec 100644 --- a/tests/components/assist_satellite/test_entity.py +++ b/tests/components/assist_satellite/test_entity.py @@ -17,7 +17,10 @@ from homeassistant.components.assist_pipeline import ( async_update_pipeline, vad, ) -from homeassistant.components.assist_satellite import SatelliteBusyError +from homeassistant.components.assist_satellite import ( + AssistSatelliteAnnouncement, + SatelliteBusyError, +) from homeassistant.components.assist_satellite.entity import AssistSatelliteState from homeassistant.components.media_source import PlayMedia from homeassistant.config_entries import ConfigEntry @@ -159,18 +162,22 @@ async def test_new_pipeline_cancels_pipeline( [ ( {"message": "Hello"}, - ("Hello", "https://www.home-assistant.io/resolved.mp3"), + AssistSatelliteAnnouncement( + "Hello", "https://www.home-assistant.io/resolved.mp3", "tts" + ), ), ( { "message": "Hello", - "media_id": "http://example.com/bla.mp3", + "media_id": "media-source://bla", }, - ("Hello", "http://example.com/bla.mp3"), + AssistSatelliteAnnouncement( + "Hello", "https://www.home-assistant.io/resolved.mp3", "media_id" + ), ), ( {"media_id": "http://example.com/bla.mp3"}, - ("", "http://example.com/bla.mp3"), + AssistSatelliteAnnouncement("", "http://example.com/bla.mp3", "url"), ), ], ) @@ -195,10 +202,10 @@ async def test_announce( original_announce = entity.async_announce announce_started = asyncio.Event() - async def async_announce(message, media_id): + async def async_announce(announcement): # Verify state change assert entity.state == AssistSatelliteState.RESPONDING - await original_announce(message, media_id) + await original_announce(announcement) announce_started.set() def tts_generate_media_source_id( @@ -249,7 +256,7 @@ async def test_announce_busy( announce_started = asyncio.Event() got_error = asyncio.Event() - async def async_announce(message, media_id): + async def async_announce(announcement): announce_started.set() # Block so we can do another announcement