Skip TTS when text is empty (#104741)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
pull/105135/head
Michael Hansen 2023-11-29 18:31:27 -06:00 committed by Franck Nijhof
parent 34c65749e2
commit 90bcad31b5
No known key found for this signature in database
GPG Key ID: D62583BA8AB11CA3
7 changed files with 225 additions and 43 deletions

View File

@ -1024,39 +1024,38 @@ class PipelineRun:
)
)
try:
# Synthesize audio and get URL
tts_media_id = tts_generate_media_source_id(
self.hass,
tts_input,
engine=self.tts_engine,
language=self.pipeline.tts_language,
options=self.tts_options,
)
tts_media = await media_source.async_resolve_media(
self.hass,
tts_media_id,
None,
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during text-to-speech")
raise TextToSpeechError(
code="tts-failed",
message="Unexpected error during text-to-speech",
) from src_error
if tts_input := tts_input.strip():
try:
# Synthesize audio and get URL
tts_media_id = tts_generate_media_source_id(
self.hass,
tts_input,
engine=self.tts_engine,
language=self.pipeline.tts_language,
options=self.tts_options,
)
tts_media = await media_source.async_resolve_media(
self.hass,
tts_media_id,
None,
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during text-to-speech")
raise TextToSpeechError(
code="tts-failed",
message="Unexpected error during text-to-speech",
) from src_error
_LOGGER.debug("TTS result %s", tts_media)
_LOGGER.debug("TTS result %s", tts_media)
tts_output = {
"media_id": tts_media_id,
**asdict(tts_media),
}
else:
tts_output = {}
self.process_event(
PipelineEvent(
PipelineEventType.TTS_END,
{
"tts_output": {
"media_id": tts_media_id,
**asdict(tts_media),
}
},
)
PipelineEvent(PipelineEventType.TTS_END, {"tts_output": tts_output})
)
return tts_media.url

View File

@ -186,16 +186,22 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
data_to_send = {"text": event.data["tts_input"]}
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
assert event.data is not None
path = event.data["tts_output"]["url"]
url = async_process_play_media_url(self.hass, path)
data_to_send = {"url": url}
tts_output = event.data["tts_output"]
if tts_output:
path = tts_output["url"]
url = async_process_play_media_url(self.hass, path)
data_to_send = {"url": url}
if self.device_info.voice_assistant_version >= 2:
media_id = event.data["tts_output"]["media_id"]
self._tts_task = self.hass.async_create_background_task(
self._send_tts(media_id), "esphome_voice_assistant_tts"
)
if self.device_info.voice_assistant_version >= 2:
media_id = tts_output["media_id"]
self._tts_task = self.hass.async_create_background_task(
self._send_tts(media_id), "esphome_voice_assistant_tts"
)
else:
self._tts_done.set()
else:
# Empty TTS response
data_to_send = {}
self._tts_done.set()
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
assert event.data is not None

View File

@ -389,11 +389,16 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
self._conversation_id = event.data["intent_output"]["conversation_id"]
elif event.type == PipelineEventType.TTS_END:
# Send TTS audio to caller over RTP
media_id = event.data["tts_output"]["media_id"]
self.hass.async_create_background_task(
self._send_tts(media_id),
"voip_pipeline_tts",
)
tts_output = event.data["tts_output"]
if tts_output:
media_id = tts_output["media_id"]
self.hass.async_create_background_task(
self._send_tts(media_id),
"voip_pipeline_tts",
)
else:
# Empty TTS response
self._tts_done.set()
elif event.type == PipelineEventType.ERROR:
# Play error tone instead of wait for TTS
self._pipeline_error = True

View File

@ -650,6 +650,33 @@
'message': 'Timeout running pipeline',
})
# ---
# name: test_pipeline_empty_tts_output
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 300,
}),
})
# ---
# name: test_pipeline_empty_tts_output.1
dict({
'engine': 'test',
'language': 'en-US',
'tts_input': '',
'voice': 'james_earl_jones',
})
# ---
# name: test_pipeline_empty_tts_output.2
dict({
'tts_output': dict({
}),
})
# ---
# name: test_pipeline_empty_tts_output.3
None
# ---
# name: test_stt_provider_missing
dict({
'language': 'en',

View File

@ -2452,3 +2452,54 @@ async def test_device_capture_queue_full(
assert msg["event"] == snapshot
assert msg["event"]["type"] == "end"
assert msg["event"]["overflow"]
async def test_pipeline_empty_tts_output(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test events from a pipeline run with an empty text-to-speech text.

    The pipeline is started and ended at the "tts" stage with "" as input
    text. The run is expected to emit run-start, tts-start, a tts-end event
    whose "tts_output" is empty, and finally run-end.
    """
    client = await hass_ws_client(hass)

    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/run",
            "start_stage": "tts",
            "end_stage": "tts",
            "input": {
                "text": "",
            },
        }
    )

    # result
    msg = await client.receive_json()
    assert msg["success"]

    # NOTE: the snapshot comparisons below are kept in event order; the stored
    # snapshots for this test are numbered sequentially, so reordering them
    # would presumably pair events with the wrong snapshot.

    # run start
    msg = await client.receive_json()
    assert msg["event"]["type"] == "run-start"
    # The pipeline value is replaced with ANY before the snapshot comparison.
    msg["event"]["data"]["pipeline"] = ANY
    assert msg["event"]["data"] == snapshot

    # text-to-speech
    msg = await client.receive_json()
    assert msg["event"]["type"] == "tts-start"
    assert msg["event"]["data"] == snapshot

    msg = await client.receive_json()
    assert msg["event"]["type"] == "tts-end"
    assert msg["event"]["data"] == snapshot
    # Empty input text must yield an empty (falsy) TTS output.
    assert not msg["event"]["data"]["tts_output"]

    # run end
    msg = await client.receive_json()
    assert msg["event"]["type"] == "run-end"
    assert msg["event"]["data"] == snapshot

View File

@ -337,6 +337,28 @@ async def test_send_tts_called(
mock_send_tts.assert_called_with(_TEST_MEDIA_ID)
async def test_send_tts_not_called_when_empty(
    hass: HomeAssistant,
    voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
    voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None:
    """Check that v1 and v2 UDP servers skip _send_tts when TTS output is empty."""
    # v1 is exercised first, then v2, each with its own TTS_END event.
    servers = (voice_assistant_udp_server_v1, voice_assistant_udp_server_v2)

    with patch(
        "homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._send_tts"
    ) as send_tts:
        for server in servers:
            empty_tts_end = PipelineEvent(
                type=PipelineEventType.TTS_END, data={"tts_output": {}}
            )
            server._event_callback(empty_tts_end)

            # Checked after every callback so a call from either server fails.
            send_tts.assert_not_called()
async def test_send_tts(
hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,

View File

@ -528,3 +528,75 @@ async def test_tts_wrong_wav_format(
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
await done.wait()
async def test_empty_tts_output(
    hass: HomeAssistant,
    voip_device: VoIPDevice,
) -> None:
    """Verify that no TTS audio is streamed when the TTS output is empty."""
    assert await async_setup_component(hass, "voip", {})

    def fake_is_speech(self, chunk):
        """Report speech for any chunk whose bytes sum to more than zero."""
        return sum(chunk) > 0

    async def fake_pipeline(*args, **kwargs):
        """Drain the audio stream, then emit an intent end and an empty TTS end."""
        stt_stream = kwargs["stt_stream"]
        event_callback = kwargs["event_callback"]

        # The stream finishes once VAD detects the end of "speech"
        async for _chunk in stt_stream:
            pass

        # Fake intent result
        intent_end = assist_pipeline.PipelineEvent(
            type=assist_pipeline.PipelineEventType.INTENT_END,
            data={
                "intent_output": {
                    "conversation_id": "fake-conversation",
                }
            },
        )
        event_callback(intent_end)

        # Empty TTS output
        empty_tts_end = assist_pipeline.PipelineEvent(
            type=assist_pipeline.PipelineEventType.TTS_END,
            data={"tts_output": {}},
        )
        event_callback(empty_tts_end)

    with (
        patch(
            "homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
            new=fake_is_speech,
        ),
        patch(
            "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
            new=fake_pipeline,
        ),
        patch(
            "homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._send_tts",
        ) as send_tts,
    ):
        protocol = voip.voip.PipelineRtpDatagramProtocol(
            hass,
            hass.config.language,
            voip_device,
            Context(),
            opus_payload_type=123,
        )
        protocol.transport = Mock()

        # Feed audio in three phases: silence, "speech", then silence again
        # (the trailing silence assumes relaxed VAD sensitivity).
        protocol.on_chunk(bytes(_ONE_SECOND))
        protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
        protocol.on_chunk(bytes(_ONE_SECOND * 4))

        # Wait for the mock pipeline to finish; the done event must be set
        # even though no TTS audio is sent.
        async with asyncio.timeout(1):
            await protocol._tts_done.wait()

        send_tts.assert_not_called()