Handle announcement finished for ESPHome TTS response (#125625)

* Handle announcement finished for TTS response

* Adjust test
pull/126399/head
Michael Hansen 2024-09-13 15:31:38 -05:00 committed by GitHub
parent 970d28bce9
commit 3eed5de367
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 205 additions and 1 deletions

View File

@ -14,6 +14,7 @@ import wave
from aioesphomeapi import (
MediaPlayerFormatPurpose,
VoiceAssistantAnnounceFinished,
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
@ -166,6 +167,7 @@ class EsphomeAssistSatellite(
handle_start=self.handle_pipeline_start,
handle_stop=self.handle_pipeline_stop,
handle_audio=self.handle_audio,
handle_announcement_finished=self.handle_announcement_finished,
)
)
else:
@ -174,6 +176,7 @@ class EsphomeAssistSatellite(
self.cli.subscribe_voice_assistant(
handle_start=self.handle_pipeline_start,
handle_stop=self.handle_pipeline_stop,
handle_announcement_finished=self.handle_announcement_finished,
)
)
@ -194,6 +197,10 @@ class EsphomeAssistSatellite(
assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
)
if not (feature_flags & VoiceAssistantFeature.SPEAKER):
# Will use media player for TTS/announcements
self._update_tts_format()
async def async_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
await super().async_will_remove_from_hass()
@ -382,6 +389,12 @@ class EsphomeAssistSatellite(
timer_info.is_active,
)
async def handle_announcement_finished(
self, announce_finished: VoiceAssistantAnnounceFinished
) -> None:
"""Handle announcement finished message (also sent for TTS)."""
self.tts_response_finished()
def _update_tts_format(self) -> None:
"""Update the TTS format from the first media player."""
for supported_format in chain(*self.entry_data.media_player_formats.values()):

View File

@ -19,6 +19,7 @@ from aioesphomeapi import (
HomeassistantServiceCall,
ReconnectLogic,
UserService,
VoiceAssistantAnnounceFinished,
VoiceAssistantAudioSettings,
VoiceAssistantFeature,
)
@ -214,6 +215,13 @@ class MockESPHomeDevice:
]
| None
)
self.voice_assistant_handle_announcement_finished_callback: (
Callable[
[VoiceAssistantAnnounceFinished],
Coroutine[Any, Any, None],
]
| None
)
self.device_info = device_info
def set_state_callback(self, state_callback: Callable[[EntityState], None]) -> None:
@ -295,11 +303,21 @@ class MockESPHomeDevice:
]
| None
) = None,
handle_announcement_finished: (
Callable[
[VoiceAssistantAnnounceFinished],
Coroutine[Any, Any, None],
]
| None
) = None,
) -> None:
"""Set the voice assistant subscription callbacks."""
self.voice_assistant_handle_start_callback = handle_start
self.voice_assistant_handle_stop_callback = handle_stop
self.voice_assistant_handle_audio_callback = handle_audio
self.voice_assistant_handle_announcement_finished_callback = (
handle_announcement_finished
)
async def mock_voice_assistant_handle_start(
self,
@ -322,6 +340,13 @@ class MockESPHomeDevice:
assert self.voice_assistant_handle_audio_callback is not None
await self.voice_assistant_handle_audio_callback(audio)
async def mock_voice_assistant_handle_announcement_finished(
self, finished: VoiceAssistantAnnounceFinished
) -> None:
"""Mock voice assistant handle announcement finished."""
assert self.voice_assistant_handle_announcement_finished_callback is not None
await self.voice_assistant_handle_announcement_finished_callback(finished)
async def _mock_generic_device_entry(
hass: HomeAssistant,
@ -402,10 +427,17 @@ async def _mock_generic_device_entry(
]
| None
) = None,
handle_announcement_finished: (
Callable[
[VoiceAssistantAnnounceFinished],
Coroutine[Any, Any, None],
]
| None
) = None,
) -> Callable[[], None]:
"""Subscribe to voice assistant."""
mock_device.set_subscribe_voice_assistant_callbacks(
handle_start, handle_stop, handle_audio
handle_start, handle_stop, handle_audio, handle_announcement_finished
)
def unsub():

View File

@ -15,6 +15,7 @@ from aioesphomeapi import (
MediaPlayerInfo,
MediaPlayerSupportedFormat,
UserService,
VoiceAssistantAnnounceFinished,
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
@ -603,6 +604,160 @@ async def test_udp_errors() -> None:
protocol.transport.sendto.assert_not_called()
async def test_pipeline_media_player(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
mock_wav: bytes,
) -> None:
"""Test a complete pipeline run with the TTS response sent to a media player instead of a speaker.
This test is not as comprehensive as test_pipeline_api_audio since we're
mainly focused on tts_response_finished getting automatically called.
"""
conversation_id = "test-conversation-id"
media_url = "http://test.url"
media_id = "test-media-id"
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.API_AUDIO
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
stt_stream = kwargs["stt_stream"]
async for _chunk in stt_stream:
break
event_callback = kwargs["event_callback"]
# STT
event_callback(
PipelineEvent(
type=PipelineEventType.STT_START,
data={"engine": "test-stt-engine", "metadata": {}},
)
)
event_callback(
PipelineEvent(
type=PipelineEventType.STT_END,
data={"stt_output": {"text": "test-stt-text"}},
)
)
# Intent
event_callback(
PipelineEvent(
type=PipelineEventType.INTENT_START,
data={
"engine": "test-intent-engine",
"language": hass.config.language,
"intent_input": "test-intent-text",
"conversation_id": conversation_id,
"device_id": device_id,
},
)
)
event_callback(
PipelineEvent(
type=PipelineEventType.INTENT_END,
data={"intent_output": {"conversation_id": conversation_id}},
)
)
# TTS
event_callback(
PipelineEvent(
type=PipelineEventType.TTS_START,
data={
"engine": "test-stt-engine",
"language": hass.config.language,
"voice": "test-voice",
"tts_input": "test-tts-text",
},
)
)
# Should return mock_wav audio
event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={"tts_output": {"url": media_url, "media_id": media_id}},
)
)
event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
pipeline_finished = asyncio.Event()
original_handle_pipeline_finished = satellite.handle_pipeline_finished
def handle_pipeline_finished():
original_handle_pipeline_finished()
pipeline_finished.set()
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
return ("wav", mock_wav)
tts_finished = asyncio.Event()
original_tts_response_finished = satellite.tts_response_finished
def tts_response_finished():
original_tts_response_finished()
tts_finished.set()
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
patch.object(satellite, "tts_response_finished", tts_response_finished),
):
async with asyncio.timeout(1):
await satellite.handle_pipeline_start(
conversation_id=conversation_id,
flags=VoiceAssistantCommandFlag(0), # stt
audio_settings=VoiceAssistantAudioSettings(),
wake_word_phrase="",
)
await satellite.handle_pipeline_stop(abort=False)
await pipeline_finished.wait()
assert satellite.state == AssistSatelliteState.RESPONDING
# Will trigger tts_response_finished
await mock_device.mock_voice_assistant_handle_announcement_finished(
VoiceAssistantAnnounceFinished(success=True)
)
await tts_finished.wait()
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
async def test_timer_events(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
@ -952,6 +1107,7 @@ async def test_announce_message(
async def send_voice_assistant_announcement_await_response(
media_id: str, timeout: float, text: str
):
assert satellite.state == AssistSatelliteState.RESPONDING
assert media_id == "https://www.home-assistant.io/resolved.mp3"
assert text == "test-text"
@ -983,6 +1139,7 @@ async def test_announce_message(
blocking=True,
)
await done.wait()
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
async def test_announce_media_id(
@ -1016,6 +1173,7 @@ async def test_announce_media_id(
async def send_voice_assistant_announcement_await_response(
media_id: str, timeout: float, text: str
):
assert satellite.state == AssistSatelliteState.RESPONDING
assert media_id == "https://www.home-assistant.io/resolved.mp3"
done.set()
@ -1038,6 +1196,7 @@ async def test_announce_media_id(
blocking=True,
)
await done.wait()
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
async def test_satellite_unloaded_on_disconnect(