diff --git a/homeassistant/components/voip/voip.py b/homeassistant/components/voip/voip.py index eb4a008a168..2eedfcdcf9b 100644 --- a/homeassistant/components/voip/voip.py +++ b/homeassistant/components/voip/voip.py @@ -106,6 +106,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): listening_tone_enabled: bool = True, processing_tone_enabled: bool = True, tone_delay: float = 0.2, + tts_extra_timeout: float = 1.0, ) -> None: """Set up pipeline RTP server.""" super().__init__(rate=RATE, width=WIDTH, channels=CHANNELS) @@ -120,6 +121,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): self.listening_tone_enabled = listening_tone_enabled self.processing_tone_enabled = processing_tone_enabled self.tone_delay = tone_delay + self.tts_extra_timeout = tts_extra_timeout self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue() self._context = context @@ -219,8 +221,11 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): tts_audio_output="raw", ) - # Block until TTS is done speaking - await self._tts_done.wait() + # Block until TTS is done speaking. + # + # This is set in _send_tts and has a timeout that's based on the + # length of the TTS audio. + await self._tts_done.wait() _LOGGER.debug("Pipeline finished") except asyncio.TimeoutError: @@ -316,10 +321,18 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes)) - # Assume TTS audio is 16Khz 16-bit mono - await self.hass.async_add_executor_job( - partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS) - ) + # Time out 1 second after TTS audio should be finished + tts_samples = len(audio_bytes) / (WIDTH * CHANNELS) + tts_seconds = tts_samples / RATE + + async with async_timeout.timeout(tts_seconds + self.tts_extra_timeout): + # Assume TTS audio is 16Khz 16-bit mono + await self.hass.async_add_executor_job( + partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS) + ) + except asyncio.TimeoutError as err: + _LOGGER.warning("TTS timeout") + raise err finally: # Signal pipeline to restart self._tts_done.set() diff --git a/tests/components/voip/test_voip.py b/tests/components/voip/test_voip.py index c26d9a7a294..19b9806e41e 100644 --- a/tests/components/voip/test_voip.py +++ b/tests/components/voip/test_voip.py @@ -1,6 +1,7 @@ """Test VoIP protocol.""" import asyncio -from unittest.mock import Mock, patch +import time +from unittest.mock import AsyncMock, Mock, patch import async_timeout @@ -191,3 +192,96 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice) # Wait for mock pipeline to time out async with async_timeout.timeout(1): await done.wait() + + +async def test_tts_timeout( + hass: HomeAssistant, + voip_device: VoIPDevice, +) -> None: + """Test that TTS will time out based on its length.""" + assert await async_setup_component(hass, "voip", {}) + + def is_speech(self, chunk, sample_rate): + """Anything non-zero is speech.""" + return sum(chunk) > 0 + + done = asyncio.Event() + + async def async_pipeline_from_audio_stream(*args, **kwargs): + stt_stream = kwargs["stt_stream"] + event_callback = kwargs["event_callback"] + async for _chunk in stt_stream: + # Stream will end when VAD detects end of "speech" + pass + + # Fake intent result + event_callback( + assist_pipeline.PipelineEvent( + type=assist_pipeline.PipelineEventType.INTENT_END, + data={ + "intent_output": { + "conversation_id": "fake-conversation", + } + }, + ) + ) + + # Proceed with media output + event_callback( + assist_pipeline.PipelineEvent( + type=assist_pipeline.PipelineEventType.TTS_END, + data={"tts_output": {"media_id": _MEDIA_ID}}, + ) + ) + + def send_audio(*args, **kwargs): + # Block here to force a timeout in _send_tts + time.sleep(1) + + async def async_get_media_source_audio( + hass: HomeAssistant, + media_source_id: str, + ) -> tuple[str, bytes]: + # Should time out immediately + return ("raw", bytes(0)) + + with patch( + "webrtcvad.Vad.is_speech", + new=is_speech, + ), patch( + "homeassistant.components.voip.voip.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ), patch( + "homeassistant.components.voip.voip.tts.async_get_media_source_audio", + new=async_get_media_source_audio, + ): + rtp_protocol = voip.voip.PipelineRtpDatagramProtocol( + hass, + hass.config.language, + voip_device, + Context(), + listening_tone_enabled=False, + processing_tone_enabled=False, + ) + rtp_protocol.transport = Mock() + rtp_protocol.send_audio = Mock(side_effect=send_audio) + + async def send_tts(*args, **kwargs): + # Call original then end test successfully + rtp_protocol._send_tts(*args, **kwargs) + done.set() + + rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) + + # silence + rtp_protocol.on_chunk(bytes(_ONE_SECOND)) + + # "speech" + rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2)) + + # silence + rtp_protocol.on_chunk(bytes(_ONE_SECOND)) + + # Wait for mock pipeline to exhaust the audio stream + async with async_timeout.timeout(1): + await done.wait()