Migrate VoIP to use Assist Pipeline TTS tokens (#139671)

* Migrate VoIP to use pipeline token

* migrate announcements to use TTS token
pull/142401/head^2
Paulus Schoutsen 2025-04-22 10:24:24 -04:00 committed by GitHub
parent 871a7c87bf
commit 8aa30b0ccb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 64 additions and 92 deletions

View File

@ -408,10 +408,18 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
"""Play an announcement once."""
_LOGGER.debug("Playing announcement")
try:
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
await self._send_tts(announcement.original_media_id, wait_for_tone=False)
if announcement.tts_token is None:
_LOGGER.error("Only TTS announcements are supported")
return
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
stream = tts.async_get_stream(self.hass, announcement.tts_token)
if stream is None:
_LOGGER.error("TTS stream no longer available")
return
try:
await self._send_tts(stream, wait_for_tone=False)
if not self._run_pipeline_after_announce:
# Delay before looping announcement
await asyncio.sleep(_ANNOUNCEMENT_AFTER_DELAY)
@ -442,11 +450,14 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
)
elif event.type == PipelineEventType.TTS_END:
# Send TTS audio to caller over RTP
if event.data and (tts_output := event.data["tts_output"]):
media_id = tts_output["media_id"]
if (
event.data
and (tts_output := event.data["tts_output"])
and (stream := tts.async_get_stream(self.hass, tts_output["token"]))
):
self.config_entry.async_create_background_task(
self.hass,
self._send_tts(media_id),
self._send_tts(tts_stream=stream),
"voip_pipeline_tts",
)
else:
@ -457,19 +468,22 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
self._pipeline_had_error = True
_LOGGER.warning(event)
async def _send_tts(self, media_id: str, wait_for_tone: bool = True) -> None:
async def _send_tts(
self,
tts_stream: tts.ResultStream,
wait_for_tone: bool = True,
) -> None:
"""Send TTS audio to caller via RTP."""
try:
if self.transport is None:
return # not connected
extension, data = await tts.async_get_media_source_audio(
self.hass,
media_id,
)
data = b"".join([chunk async for chunk in tts_stream.async_stream_result()])
if extension != "wav":
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
if tts_stream.extension != "wav":
raise ValueError(
f"Only TTS WAV audio can be streamed, got {tts_stream.extension}"
)
if wait_for_tone and ((self._tones & Tones.PROCESSING) == Tones.PROCESSING):
# Don't overlap TTS and processing beep

View File

@ -38,12 +38,12 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
"""Mock the TTS cache dir with empty dir."""
def _empty_wav() -> bytes:
def _empty_wav(framerate=16000) -> bytes:
"""Return bytes of an empty WAV file."""
with io.BytesIO() as wav_io:
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
with wav_file:
wav_file.setframerate(16000)
wav_file.setframerate(framerate)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
@ -307,10 +307,11 @@ async def test_pipeline(
assert satellite.state == AssistSatelliteState.RESPONDING
# Proceed with media output
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_END,
data={"tts_output": {"media_id": _MEDIA_ID}},
data={"tts_output": {"token": mock_tts_result_stream.token}},
)
)
@ -326,22 +327,11 @@ async def test_pipeline(
original_tts_response_finished()
done.set()
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
assert media_source_id == _MEDIA_ID
return ("wav", _empty_wav())
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
patch.object(satellite, "tts_response_finished", tts_response_finished),
):
satellite._tones = Tones(0)
@ -457,10 +447,11 @@ async def test_tts_timeout(
)
# Proceed with media output
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_END,
data={"tts_output": {"media_id": _MEDIA_ID}},
data={"tts_output": {"token": mock_tts_result_stream.token}},
)
)
@ -474,22 +465,9 @@ async def test_tts_timeout(
# Block here to force a timeout in _send_tts
await asyncio.sleep(2)
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
# Should time out immediately
return ("wav", _empty_wav())
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
with patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
):
satellite._tts_extra_timeout = 0.001
for tone in Tones:
@ -568,29 +546,18 @@ async def test_tts_wrong_extension(
)
# Proceed with media output
# Should fail because it's not "wav"
mock_tts_result_stream = MockResultStream(hass, "mp3", b"")
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_END,
data={"tts_output": {"media_id": _MEDIA_ID}},
data={"tts_output": {"token": mock_tts_result_stream.token}},
)
)
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
# Should fail because it's not "wav"
return ("mp3", b"")
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
with patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
):
satellite.transport = Mock()
@ -663,36 +630,18 @@ async def test_tts_wrong_wav_format(
)
# Proceed with media output
# Should fail because it's not 16Khz
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav(22050))
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_END,
data={"tts_output": {"media_id": _MEDIA_ID}},
data={"tts_output": {"token": mock_tts_result_stream.token}},
)
)
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
# Should fail because it's not 16Khz, 16-bit mono
with io.BytesIO() as wav_io:
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
with wav_file:
wav_file.setframerate(22050)
wav_file.setsampwidth(2)
wav_file.setnchannels(2)
return ("wav", wav_io.getvalue())
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
with patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
):
satellite.transport = Mock()
@ -878,10 +827,11 @@ async def test_announce(
assert err.value.translation_domain == "voip"
assert err.value.translation_key == "non_tts_announcement"
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement",
media_id=_MEDIA_ID,
tts_token="test-token",
tts_token=mock_tts_result_stream.token,
original_media_id=_MEDIA_ID,
media_id_source="tts",
)
@ -907,7 +857,9 @@ async def test_announce(
async with asyncio.timeout(1):
await announce_task
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)
mock_send_tts.assert_called_once_with(
mock_tts_result_stream, wait_for_tone=False
)
@pytest.mark.usefixtures("socket_enabled")
@ -926,10 +878,11 @@ async def test_voip_id_is_ip_address(
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
)
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement",
media_id=_MEDIA_ID,
tts_token="test-token",
tts_token=mock_tts_result_stream.token,
original_media_id=_MEDIA_ID,
media_id_source="tts",
)
@ -960,7 +913,9 @@ async def test_voip_id_is_ip_address(
async with asyncio.timeout(1):
await announce_task
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)
mock_send_tts.assert_called_once_with(
mock_tts_result_stream, wait_for_tone=False
)
@pytest.mark.usefixtures("socket_enabled")
@ -979,10 +934,11 @@ async def test_announce_timeout(
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
)
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement",
media_id=_MEDIA_ID,
tts_token="test-token",
tts_token=mock_tts_result_stream.token,
original_media_id=_MEDIA_ID,
media_id_source="tts",
)
@ -1020,10 +976,11 @@ async def test_start_conversation(
& assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
)
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement",
media_id=_MEDIA_ID,
tts_token="test-token",
tts_token=mock_tts_result_stream.token,
original_media_id=_MEDIA_ID,
media_id_source="tts",
)
@ -1061,10 +1018,11 @@ async def test_start_conversation(
)
# Proceed with media output
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_END,
data={"tts_output": {"media_id": _MEDIA_ID}},
data={"tts_output": {"token": mock_tts_result_stream.token}},
)
)