Handle announcement finished for ESPHome TTS response (#125625)
* Handle announcement finished for TTS response * Adjust testpull/126399/head
parent
970d28bce9
commit
3eed5de367
|
@ -14,6 +14,7 @@ import wave
|
|||
|
||||
from aioesphomeapi import (
|
||||
MediaPlayerFormatPurpose,
|
||||
VoiceAssistantAnnounceFinished,
|
||||
VoiceAssistantAudioSettings,
|
||||
VoiceAssistantCommandFlag,
|
||||
VoiceAssistantEventType,
|
||||
|
@ -166,6 +167,7 @@ class EsphomeAssistSatellite(
|
|||
handle_start=self.handle_pipeline_start,
|
||||
handle_stop=self.handle_pipeline_stop,
|
||||
handle_audio=self.handle_audio,
|
||||
handle_announcement_finished=self.handle_announcement_finished,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
@ -174,6 +176,7 @@ class EsphomeAssistSatellite(
|
|||
self.cli.subscribe_voice_assistant(
|
||||
handle_start=self.handle_pipeline_start,
|
||||
handle_stop=self.handle_pipeline_stop,
|
||||
handle_announcement_finished=self.handle_announcement_finished,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -194,6 +197,10 @@ class EsphomeAssistSatellite(
|
|||
assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||
)
|
||||
|
||||
if not (feature_flags & VoiceAssistantFeature.SPEAKER):
|
||||
# Will use media player for TTS/announcements
|
||||
self._update_tts_format()
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
await super().async_will_remove_from_hass()
|
||||
|
@ -382,6 +389,12 @@ class EsphomeAssistSatellite(
|
|||
timer_info.is_active,
|
||||
)
|
||||
|
||||
async def handle_announcement_finished(
|
||||
self, announce_finished: VoiceAssistantAnnounceFinished
|
||||
) -> None:
|
||||
"""Handle announcement finished message (also sent for TTS)."""
|
||||
self.tts_response_finished()
|
||||
|
||||
def _update_tts_format(self) -> None:
|
||||
"""Update the TTS format from the first media player."""
|
||||
for supported_format in chain(*self.entry_data.media_player_formats.values()):
|
||||
|
|
|
@ -19,6 +19,7 @@ from aioesphomeapi import (
|
|||
HomeassistantServiceCall,
|
||||
ReconnectLogic,
|
||||
UserService,
|
||||
VoiceAssistantAnnounceFinished,
|
||||
VoiceAssistantAudioSettings,
|
||||
VoiceAssistantFeature,
|
||||
)
|
||||
|
@ -214,6 +215,13 @@ class MockESPHomeDevice:
|
|||
]
|
||||
| None
|
||||
)
|
||||
self.voice_assistant_handle_announcement_finished_callback: (
|
||||
Callable[
|
||||
[VoiceAssistantAnnounceFinished],
|
||||
Coroutine[Any, Any, None],
|
||||
]
|
||||
| None
|
||||
)
|
||||
self.device_info = device_info
|
||||
|
||||
def set_state_callback(self, state_callback: Callable[[EntityState], None]) -> None:
|
||||
|
@ -295,11 +303,21 @@ class MockESPHomeDevice:
|
|||
]
|
||||
| None
|
||||
) = None,
|
||||
handle_announcement_finished: (
|
||||
Callable[
|
||||
[VoiceAssistantAnnounceFinished],
|
||||
Coroutine[Any, Any, None],
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
) -> None:
|
||||
"""Set the voice assistant subscription callbacks."""
|
||||
self.voice_assistant_handle_start_callback = handle_start
|
||||
self.voice_assistant_handle_stop_callback = handle_stop
|
||||
self.voice_assistant_handle_audio_callback = handle_audio
|
||||
self.voice_assistant_handle_announcement_finished_callback = (
|
||||
handle_announcement_finished
|
||||
)
|
||||
|
||||
async def mock_voice_assistant_handle_start(
|
||||
self,
|
||||
|
@ -322,6 +340,13 @@ class MockESPHomeDevice:
|
|||
assert self.voice_assistant_handle_audio_callback is not None
|
||||
await self.voice_assistant_handle_audio_callback(audio)
|
||||
|
||||
async def mock_voice_assistant_handle_announcement_finished(
|
||||
self, finished: VoiceAssistantAnnounceFinished
|
||||
) -> None:
|
||||
"""Mock voice assistant handle announcement finished."""
|
||||
assert self.voice_assistant_handle_announcement_finished_callback is not None
|
||||
await self.voice_assistant_handle_announcement_finished_callback(finished)
|
||||
|
||||
|
||||
async def _mock_generic_device_entry(
|
||||
hass: HomeAssistant,
|
||||
|
@ -402,10 +427,17 @@ async def _mock_generic_device_entry(
|
|||
]
|
||||
| None
|
||||
) = None,
|
||||
handle_announcement_finished: (
|
||||
Callable[
|
||||
[VoiceAssistantAnnounceFinished],
|
||||
Coroutine[Any, Any, None],
|
||||
]
|
||||
| None
|
||||
) = None,
|
||||
) -> Callable[[], None]:
|
||||
"""Subscribe to voice assistant."""
|
||||
mock_device.set_subscribe_voice_assistant_callbacks(
|
||||
handle_start, handle_stop, handle_audio
|
||||
handle_start, handle_stop, handle_audio, handle_announcement_finished
|
||||
)
|
||||
|
||||
def unsub():
|
||||
|
|
|
@ -15,6 +15,7 @@ from aioesphomeapi import (
|
|||
MediaPlayerInfo,
|
||||
MediaPlayerSupportedFormat,
|
||||
UserService,
|
||||
VoiceAssistantAnnounceFinished,
|
||||
VoiceAssistantAudioSettings,
|
||||
VoiceAssistantCommandFlag,
|
||||
VoiceAssistantEventType,
|
||||
|
@ -603,6 +604,160 @@ async def test_udp_errors() -> None:
|
|||
protocol.transport.sendto.assert_not_called()
|
||||
|
||||
|
||||
async def test_pipeline_media_player(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
mock_wav: bytes,
|
||||
) -> None:
|
||||
"""Test a complete pipeline run with the TTS response sent to a media player instead of a speaker.
|
||||
|
||||
This test is not as comprehensive as test_pipeline_api_audio since we're
|
||||
mainly focused on tts_response_finished getting automatically called.
|
||||
"""
|
||||
conversation_id = "test-conversation-id"
|
||||
media_url = "http://test.url"
|
||||
media_id = "test-media-id"
|
||||
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.API_AUDIO
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
|
||||
async for _chunk in stt_stream:
|
||||
break
|
||||
|
||||
event_callback = kwargs["event_callback"]
|
||||
|
||||
# STT
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.STT_START,
|
||||
data={"engine": "test-stt-engine", "metadata": {}},
|
||||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "test-stt-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Intent
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.INTENT_START,
|
||||
data={
|
||||
"engine": "test-intent-engine",
|
||||
"language": hass.config.language,
|
||||
"intent_input": "test-intent-text",
|
||||
"conversation_id": conversation_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.INTENT_END,
|
||||
data={"intent_output": {"conversation_id": conversation_id}},
|
||||
)
|
||||
)
|
||||
|
||||
# TTS
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_START,
|
||||
data={
|
||||
"engine": "test-stt-engine",
|
||||
"language": hass.config.language,
|
||||
"voice": "test-voice",
|
||||
"tts_input": "test-tts-text",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Should return mock_wav audio
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={"tts_output": {"url": media_url, "media_id": media_id}},
|
||||
)
|
||||
)
|
||||
|
||||
event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
|
||||
|
||||
pipeline_finished = asyncio.Event()
|
||||
original_handle_pipeline_finished = satellite.handle_pipeline_finished
|
||||
|
||||
def handle_pipeline_finished():
|
||||
original_handle_pipeline_finished()
|
||||
pipeline_finished.set()
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
return ("wav", mock_wav)
|
||||
|
||||
tts_finished = asyncio.Event()
|
||||
original_tts_response_finished = satellite.tts_response_finished
|
||||
|
||||
def tts_response_finished():
|
||||
original_tts_response_finished()
|
||||
tts_finished.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
|
||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await satellite.handle_pipeline_start(
|
||||
conversation_id=conversation_id,
|
||||
flags=VoiceAssistantCommandFlag(0), # stt
|
||||
audio_settings=VoiceAssistantAudioSettings(),
|
||||
wake_word_phrase="",
|
||||
)
|
||||
|
||||
await satellite.handle_pipeline_stop(abort=False)
|
||||
await pipeline_finished.wait()
|
||||
|
||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
|
||||
# Will trigger tts_response_finished
|
||||
await mock_device.mock_voice_assistant_handle_announcement_finished(
|
||||
VoiceAssistantAnnounceFinished(success=True)
|
||||
)
|
||||
await tts_finished.wait()
|
||||
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
|
||||
async def test_timer_events(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
|
@ -952,6 +1107,7 @@ async def test_announce_message(
|
|||
async def send_voice_assistant_announcement_await_response(
|
||||
media_id: str, timeout: float, text: str
|
||||
):
|
||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
assert media_id == "https://www.home-assistant.io/resolved.mp3"
|
||||
assert text == "test-text"
|
||||
|
||||
|
@ -983,6 +1139,7 @@ async def test_announce_message(
|
|||
blocking=True,
|
||||
)
|
||||
await done.wait()
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
|
||||
async def test_announce_media_id(
|
||||
|
@ -1016,6 +1173,7 @@ async def test_announce_media_id(
|
|||
async def send_voice_assistant_announcement_await_response(
|
||||
media_id: str, timeout: float, text: str
|
||||
):
|
||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
assert media_id == "https://www.home-assistant.io/resolved.mp3"
|
||||
|
||||
done.set()
|
||||
|
@ -1038,6 +1196,7 @@ async def test_announce_media_id(
|
|||
blocking=True,
|
||||
)
|
||||
await done.wait()
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
|
||||
async def test_satellite_unloaded_on_disconnect(
|
||||
|
|
Loading…
Reference in New Issue