Time out TTS based on audio length (#92032)
* Time out TTS based on audio length * Use async mockpull/91787/head^2
parent
257944c3b7
commit
8dfecac013
|
@ -106,6 +106,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||||
listening_tone_enabled: bool = True,
|
listening_tone_enabled: bool = True,
|
||||||
processing_tone_enabled: bool = True,
|
processing_tone_enabled: bool = True,
|
||||||
tone_delay: float = 0.2,
|
tone_delay: float = 0.2,
|
||||||
|
tts_extra_timeout: float = 1.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up pipeline RTP server."""
|
"""Set up pipeline RTP server."""
|
||||||
super().__init__(rate=RATE, width=WIDTH, channels=CHANNELS)
|
super().__init__(rate=RATE, width=WIDTH, channels=CHANNELS)
|
||||||
|
@ -120,6 +121,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||||
self.listening_tone_enabled = listening_tone_enabled
|
self.listening_tone_enabled = listening_tone_enabled
|
||||||
self.processing_tone_enabled = processing_tone_enabled
|
self.processing_tone_enabled = processing_tone_enabled
|
||||||
self.tone_delay = tone_delay
|
self.tone_delay = tone_delay
|
||||||
|
self.tts_extra_timeout = tts_extra_timeout
|
||||||
|
|
||||||
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||||
self._context = context
|
self._context = context
|
||||||
|
@ -219,8 +221,11 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||||
tts_audio_output="raw",
|
tts_audio_output="raw",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Block until TTS is done speaking
|
# Block until TTS is done speaking.
|
||||||
await self._tts_done.wait()
|
#
|
||||||
|
# This is set in _send_tts and has a timeout that's based on the
|
||||||
|
# length of the TTS audio.
|
||||||
|
await self._tts_done.wait()
|
||||||
|
|
||||||
_LOGGER.debug("Pipeline finished")
|
_LOGGER.debug("Pipeline finished")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
@ -316,10 +321,18 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||||
|
|
||||||
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
||||||
|
|
||||||
# Assume TTS audio is 16Khz 16-bit mono
|
# Time out 1 second after TTS audio should be finished
|
||||||
await self.hass.async_add_executor_job(
|
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
|
||||||
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS)
|
tts_seconds = tts_samples / RATE
|
||||||
)
|
|
||||||
|
async with async_timeout.timeout(tts_seconds + self.tts_extra_timeout):
|
||||||
|
# Assume TTS audio is 16Khz 16-bit mono
|
||||||
|
await self.hass.async_add_executor_job(
|
||||||
|
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS)
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError as err:
|
||||||
|
_LOGGER.warning("TTS timeout")
|
||||||
|
raise err
|
||||||
finally:
|
finally:
|
||||||
# Signal pipeline to restart
|
# Signal pipeline to restart
|
||||||
self._tts_done.set()
|
self._tts_done.set()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Test VoIP protocol."""
|
"""Test VoIP protocol."""
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import Mock, patch
|
import time
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
import async_timeout
|
import async_timeout
|
||||||
|
|
||||||
|
@ -191,3 +192,96 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
|
||||||
# Wait for mock pipeline to time out
|
# Wait for mock pipeline to time out
|
||||||
async with async_timeout.timeout(1):
|
async with async_timeout.timeout(1):
|
||||||
await done.wait()
|
await done.wait()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tts_timeout(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
voip_device: VoIPDevice,
|
||||||
|
) -> None:
|
||||||
|
"""Test that TTS will time out based on its length."""
|
||||||
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
|
def is_speech(self, chunk, sample_rate):
|
||||||
|
"""Anything non-zero is speech."""
|
||||||
|
return sum(chunk) > 0
|
||||||
|
|
||||||
|
done = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||||
|
stt_stream = kwargs["stt_stream"]
|
||||||
|
event_callback = kwargs["event_callback"]
|
||||||
|
async for _chunk in stt_stream:
|
||||||
|
# Stream will end when VAD detects end of "speech"
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fake intent result
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.INTENT_END,
|
||||||
|
data={
|
||||||
|
"intent_output": {
|
||||||
|
"conversation_id": "fake-conversation",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Proceed with media output
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.TTS_END,
|
||||||
|
data={"tts_output": {"media_id": _MEDIA_ID}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def send_audio(*args, **kwargs):
|
||||||
|
# Block here to force a timeout in _send_tts
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
async def async_get_media_source_audio(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
media_source_id: str,
|
||||||
|
) -> tuple[str, bytes]:
|
||||||
|
# Should time out immediately
|
||||||
|
return ("raw", bytes(0))
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"webrtcvad.Vad.is_speech",
|
||||||
|
new=is_speech,
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||||
|
new=async_pipeline_from_audio_stream,
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||||
|
new=async_get_media_source_audio,
|
||||||
|
):
|
||||||
|
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||||
|
hass,
|
||||||
|
hass.config.language,
|
||||||
|
voip_device,
|
||||||
|
Context(),
|
||||||
|
listening_tone_enabled=False,
|
||||||
|
processing_tone_enabled=False,
|
||||||
|
)
|
||||||
|
rtp_protocol.transport = Mock()
|
||||||
|
rtp_protocol.send_audio = Mock(side_effect=send_audio)
|
||||||
|
|
||||||
|
async def send_tts(*args, **kwargs):
|
||||||
|
# Call original then end test successfully
|
||||||
|
rtp_protocol._send_tts(*args, **kwargs)
|
||||||
|
done.set()
|
||||||
|
|
||||||
|
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts)
|
||||||
|
|
||||||
|
# silence
|
||||||
|
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
|
# "speech"
|
||||||
|
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||||
|
|
||||||
|
# silence
|
||||||
|
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
|
# Wait for mock pipeline to exhaust the audio stream
|
||||||
|
async with async_timeout.timeout(1):
|
||||||
|
await done.wait()
|
||||||
|
|
Loading…
Reference in New Issue