Time out TTS based on audio length (#92032)

* Time out TTS based on audio length

* Use async mock
pull/91787/head^2
Michael Hansen 2023-04-25 23:35:14 -05:00 committed by GitHub
parent 257944c3b7
commit 8dfecac013
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 114 additions and 7 deletions

View File

@ -106,6 +106,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
listening_tone_enabled: bool = True, listening_tone_enabled: bool = True,
processing_tone_enabled: bool = True, processing_tone_enabled: bool = True,
tone_delay: float = 0.2, tone_delay: float = 0.2,
tts_extra_timeout: float = 1.0,
) -> None: ) -> None:
"""Set up pipeline RTP server.""" """Set up pipeline RTP server."""
super().__init__(rate=RATE, width=WIDTH, channels=CHANNELS) super().__init__(rate=RATE, width=WIDTH, channels=CHANNELS)
@ -120,6 +121,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
self.listening_tone_enabled = listening_tone_enabled self.listening_tone_enabled = listening_tone_enabled
self.processing_tone_enabled = processing_tone_enabled self.processing_tone_enabled = processing_tone_enabled
self.tone_delay = tone_delay self.tone_delay = tone_delay
self.tts_extra_timeout = tts_extra_timeout
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue() self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
self._context = context self._context = context
@ -219,8 +221,11 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
tts_audio_output="raw", tts_audio_output="raw",
) )
# Block until TTS is done speaking # Block until TTS is done speaking.
await self._tts_done.wait() #
# This is set in _send_tts and has a timeout that's based on the
# length of the TTS audio.
await self._tts_done.wait()
_LOGGER.debug("Pipeline finished") _LOGGER.debug("Pipeline finished")
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -316,10 +321,18 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes)) _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
# Assume TTS audio is 16Khz 16-bit mono # Time out 1 second after TTS audio should be finished
await self.hass.async_add_executor_job( tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS) tts_seconds = tts_samples / RATE
)
async with async_timeout.timeout(tts_seconds + self.tts_extra_timeout):
# Assume TTS audio is 16Khz 16-bit mono
await self.hass.async_add_executor_job(
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS)
)
except asyncio.TimeoutError as err:
_LOGGER.warning("TTS timeout")
raise err
finally: finally:
# Signal pipeline to restart # Signal pipeline to restart
self._tts_done.set() self._tts_done.set()

View File

@ -1,6 +1,7 @@
"""Test VoIP protocol.""" """Test VoIP protocol."""
import asyncio import asyncio
from unittest.mock import Mock, patch import time
from unittest.mock import AsyncMock, Mock, patch
import async_timeout import async_timeout
@ -191,3 +192,96 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
# Wait for mock pipeline to time out # Wait for mock pipeline to time out
async with async_timeout.timeout(1): async with async_timeout.timeout(1):
await done.wait() await done.wait()
async def test_tts_timeout(
    hass: HomeAssistant,
    voip_device: VoIPDevice,
) -> None:
    """Test that TTS will time out based on its length.

    The RTP protocol's _send_tts wraps audio transmission in a timeout of
    (audio seconds + tts_extra_timeout). We feed it zero-length TTS audio and
    a send_audio that blocks, so the timeout must fire.
    """
    assert await async_setup_component(hass, "voip", {})

    def is_speech(self, chunk, sample_rate):
        """Anything non-zero is speech."""
        return sum(chunk) > 0

    done = asyncio.Event()

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        stt_stream = kwargs["stt_stream"]
        event_callback = kwargs["event_callback"]
        async for _chunk in stt_stream:
            # Stream will end when VAD detects end of "speech"
            pass

        # Fake intent result
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.INTENT_END,
                data={
                    "intent_output": {
                        "conversation_id": "fake-conversation",
                    }
                },
            )
        )

        # Proceed with media output
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.TTS_END,
                data={"tts_output": {"media_id": _MEDIA_ID}},
            )
        )

    def send_audio(*args, **kwargs):
        # Block here to force a timeout in _send_tts
        time.sleep(1)

    async def async_get_media_source_audio(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        # Empty audio means the timeout is tts_extra_timeout alone,
        # so it should fire almost immediately.
        return ("raw", bytes(0))

    with patch(
        "webrtcvad.Vad.is_speech",
        new=is_speech,
    ), patch(
        "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
        new=async_pipeline_from_audio_stream,
    ), patch(
        "homeassistant.components.voip.voip.tts.async_get_media_source_audio",
        new=async_get_media_source_audio,
    ):
        rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
            hass,
            hass.config.language,
            voip_device,
            Context(),
            listening_tone_enabled=False,
            processing_tone_enabled=False,
            # Much shorter than send_audio's 1 second block, so the
            # timeout fires deterministically instead of racing it.
            tts_extra_timeout=0.1,
        )
        rtp_protocol.transport = Mock()
        rtp_protocol.send_audio = Mock(side_effect=send_audio)

        # Capture the real implementation BEFORE replacing the attribute,
        # otherwise the wrapper below would call the mock (itself) and recurse.
        original_send_tts = rtp_protocol._send_tts

        async def send_tts(*args, **kwargs):
            # Await the original; it must raise due to the blocked send_audio.
            # Only then is the test considered successful.
            try:
                await original_send_tts(*args, **kwargs)
            except asyncio.TimeoutError:
                done.set()

        rtp_protocol._send_tts = AsyncMock(side_effect=send_tts)

        # silence
        rtp_protocol.on_chunk(bytes(_ONE_SECOND))

        # "speech"
        rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))

        # silence
        rtp_protocol.on_chunk(bytes(_ONE_SECOND))

        # Wait for the TTS timeout to be observed by the wrapper above
        async with async_timeout.timeout(1):
            await done.wait()