Migrate VoIP to use Assist Pipeline TTS tokens (#139671)
* Migrate VoIP to use pipeline token * migrate announcements to use TTS tokenpull/142401/head^2
parent
871a7c87bf
commit
8aa30b0ccb
|
@ -408,10 +408,18 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||
"""Play an announcement once."""
|
||||
_LOGGER.debug("Playing announcement")
|
||||
|
||||
try:
|
||||
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
|
||||
await self._send_tts(announcement.original_media_id, wait_for_tone=False)
|
||||
if announcement.tts_token is None:
|
||||
_LOGGER.error("Only TTS announcements are supported")
|
||||
return
|
||||
|
||||
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
|
||||
stream = tts.async_get_stream(self.hass, announcement.tts_token)
|
||||
if stream is None:
|
||||
_LOGGER.error("TTS stream no longer available")
|
||||
return
|
||||
|
||||
try:
|
||||
await self._send_tts(stream, wait_for_tone=False)
|
||||
if not self._run_pipeline_after_announce:
|
||||
# Delay before looping announcement
|
||||
await asyncio.sleep(_ANNOUNCEMENT_AFTER_DELAY)
|
||||
|
@ -442,11 +450,14 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||
)
|
||||
elif event.type == PipelineEventType.TTS_END:
|
||||
# Send TTS audio to caller over RTP
|
||||
if event.data and (tts_output := event.data["tts_output"]):
|
||||
media_id = tts_output["media_id"]
|
||||
if (
|
||||
event.data
|
||||
and (tts_output := event.data["tts_output"])
|
||||
and (stream := tts.async_get_stream(self.hass, tts_output["token"]))
|
||||
):
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._send_tts(media_id),
|
||||
self._send_tts(tts_stream=stream),
|
||||
"voip_pipeline_tts",
|
||||
)
|
||||
else:
|
||||
|
@ -457,19 +468,22 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||
self._pipeline_had_error = True
|
||||
_LOGGER.warning(event)
|
||||
|
||||
async def _send_tts(self, media_id: str, wait_for_tone: bool = True) -> None:
|
||||
async def _send_tts(
|
||||
self,
|
||||
tts_stream: tts.ResultStream,
|
||||
wait_for_tone: bool = True,
|
||||
) -> None:
|
||||
"""Send TTS audio to caller via RTP."""
|
||||
try:
|
||||
if self.transport is None:
|
||||
return # not connected
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
data = b"".join([chunk async for chunk in tts_stream.async_stream_result()])
|
||||
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||
if tts_stream.extension != "wav":
|
||||
raise ValueError(
|
||||
f"Only TTS WAV audio can be streamed, got {tts_stream.extension}"
|
||||
)
|
||||
|
||||
if wait_for_tone and ((self._tones & Tones.PROCESSING) == Tones.PROCESSING):
|
||||
# Don't overlap TTS and processing beep
|
||||
|
|
|
@ -38,12 +38,12 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
|
|||
"""Mock the TTS cache dir with empty dir."""
|
||||
|
||||
|
||||
def _empty_wav() -> bytes:
|
||||
def _empty_wav(framerate=16000) -> bytes:
|
||||
"""Return bytes of an empty WAV file."""
|
||||
with io.BytesIO() as wav_io:
|
||||
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
|
||||
with wav_file:
|
||||
wav_file.setframerate(16000)
|
||||
wav_file.setframerate(framerate)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
|
||||
|
@ -307,10 +307,11 @@ async def test_pipeline(
|
|||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
|
||||
# Proceed with media output
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.TTS_END,
|
||||
data={"tts_output": {"media_id": _MEDIA_ID}},
|
||||
data={"tts_output": {"token": mock_tts_result_stream.token}},
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -326,22 +327,11 @@ async def test_pipeline(
|
|||
original_tts_response_finished()
|
||||
done.set()
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
assert media_source_id == _MEDIA_ID
|
||||
return ("wav", _empty_wav())
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||
):
|
||||
satellite._tones = Tones(0)
|
||||
|
@ -457,10 +447,11 @@ async def test_tts_timeout(
|
|||
)
|
||||
|
||||
# Proceed with media output
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.TTS_END,
|
||||
data={"tts_output": {"media_id": _MEDIA_ID}},
|
||||
data={"tts_output": {"token": mock_tts_result_stream.token}},
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -474,22 +465,9 @@ async def test_tts_timeout(
|
|||
# Block here to force a timeout in _send_tts
|
||||
await asyncio.sleep(2)
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
# Should time out immediately
|
||||
return ("wav", _empty_wav())
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
with patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
):
|
||||
satellite._tts_extra_timeout = 0.001
|
||||
for tone in Tones:
|
||||
|
@ -568,29 +546,18 @@ async def test_tts_wrong_extension(
|
|||
)
|
||||
|
||||
# Proceed with media output
|
||||
# Should fail because it's not "wav"
|
||||
mock_tts_result_stream = MockResultStream(hass, "mp3", b"")
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.TTS_END,
|
||||
data={"tts_output": {"media_id": _MEDIA_ID}},
|
||||
data={"tts_output": {"token": mock_tts_result_stream.token}},
|
||||
)
|
||||
)
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
# Should fail because it's not "wav"
|
||||
return ("mp3", b"")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
with patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
):
|
||||
satellite.transport = Mock()
|
||||
|
||||
|
@ -663,36 +630,18 @@ async def test_tts_wrong_wav_format(
|
|||
)
|
||||
|
||||
# Proceed with media output
|
||||
# Should fail because it's not 16Khz
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav(22050))
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.TTS_END,
|
||||
data={"tts_output": {"media_id": _MEDIA_ID}},
|
||||
data={"tts_output": {"token": mock_tts_result_stream.token}},
|
||||
)
|
||||
)
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
# Should fail because it's not 16Khz, 16-bit mono
|
||||
with io.BytesIO() as wav_io:
|
||||
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
|
||||
with wav_file:
|
||||
wav_file.setframerate(22050)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(2)
|
||||
|
||||
return ("wav", wav_io.getvalue())
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
with patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
):
|
||||
satellite.transport = Mock()
|
||||
|
||||
|
@ -878,10 +827,11 @@ async def test_announce(
|
|||
assert err.value.translation_domain == "voip"
|
||||
assert err.value.translation_key == "non_tts_announcement"
|
||||
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
|
||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||
message="test announcement",
|
||||
media_id=_MEDIA_ID,
|
||||
tts_token="test-token",
|
||||
tts_token=mock_tts_result_stream.token,
|
||||
original_media_id=_MEDIA_ID,
|
||||
media_id_source="tts",
|
||||
)
|
||||
|
@ -907,7 +857,9 @@ async def test_announce(
|
|||
async with asyncio.timeout(1):
|
||||
await announce_task
|
||||
|
||||
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)
|
||||
mock_send_tts.assert_called_once_with(
|
||||
mock_tts_result_stream, wait_for_tone=False
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("socket_enabled")
|
||||
|
@ -926,10 +878,11 @@ async def test_voip_id_is_ip_address(
|
|||
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||
)
|
||||
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
|
||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||
message="test announcement",
|
||||
media_id=_MEDIA_ID,
|
||||
tts_token="test-token",
|
||||
tts_token=mock_tts_result_stream.token,
|
||||
original_media_id=_MEDIA_ID,
|
||||
media_id_source="tts",
|
||||
)
|
||||
|
@ -960,7 +913,9 @@ async def test_voip_id_is_ip_address(
|
|||
async with asyncio.timeout(1):
|
||||
await announce_task
|
||||
|
||||
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)
|
||||
mock_send_tts.assert_called_once_with(
|
||||
mock_tts_result_stream, wait_for_tone=False
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("socket_enabled")
|
||||
|
@ -979,10 +934,11 @@ async def test_announce_timeout(
|
|||
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||
)
|
||||
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
|
||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||
message="test announcement",
|
||||
media_id=_MEDIA_ID,
|
||||
tts_token="test-token",
|
||||
tts_token=mock_tts_result_stream.token,
|
||||
original_media_id=_MEDIA_ID,
|
||||
media_id_source="tts",
|
||||
)
|
||||
|
@ -1020,10 +976,11 @@ async def test_start_conversation(
|
|||
& assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
|
||||
)
|
||||
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
|
||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||
message="test announcement",
|
||||
media_id=_MEDIA_ID,
|
||||
tts_token="test-token",
|
||||
tts_token=mock_tts_result_stream.token,
|
||||
original_media_id=_MEDIA_ID,
|
||||
media_id_source="tts",
|
||||
)
|
||||
|
@ -1061,10 +1018,11 @@ async def test_start_conversation(
|
|||
)
|
||||
|
||||
# Proceed with media output
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.TTS_END,
|
||||
data={"tts_output": {"media_id": _MEDIA_ID}},
|
||||
data={"tts_output": {"token": mock_tts_result_stream.token}},
|
||||
)
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue