Skip TTS when text is empty (#104741)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>pull/105135/head
parent
34c65749e2
commit
90bcad31b5
|
@ -1024,39 +1024,38 @@ class PipelineRun:
|
|||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Synthesize audio and get URL
|
||||
tts_media_id = tts_generate_media_source_id(
|
||||
self.hass,
|
||||
tts_input,
|
||||
engine=self.tts_engine,
|
||||
language=self.pipeline.tts_language,
|
||||
options=self.tts_options,
|
||||
)
|
||||
tts_media = await media_source.async_resolve_media(
|
||||
self.hass,
|
||||
tts_media_id,
|
||||
None,
|
||||
)
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during text-to-speech")
|
||||
raise TextToSpeechError(
|
||||
code="tts-failed",
|
||||
message="Unexpected error during text-to-speech",
|
||||
) from src_error
|
||||
if tts_input := tts_input.strip():
|
||||
try:
|
||||
# Synthesize audio and get URL
|
||||
tts_media_id = tts_generate_media_source_id(
|
||||
self.hass,
|
||||
tts_input,
|
||||
engine=self.tts_engine,
|
||||
language=self.pipeline.tts_language,
|
||||
options=self.tts_options,
|
||||
)
|
||||
tts_media = await media_source.async_resolve_media(
|
||||
self.hass,
|
||||
tts_media_id,
|
||||
None,
|
||||
)
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during text-to-speech")
|
||||
raise TextToSpeechError(
|
||||
code="tts-failed",
|
||||
message="Unexpected error during text-to-speech",
|
||||
) from src_error
|
||||
|
||||
_LOGGER.debug("TTS result %s", tts_media)
|
||||
_LOGGER.debug("TTS result %s", tts_media)
|
||||
tts_output = {
|
||||
"media_id": tts_media_id,
|
||||
**asdict(tts_media),
|
||||
}
|
||||
else:
|
||||
tts_output = {}
|
||||
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.TTS_END,
|
||||
{
|
||||
"tts_output": {
|
||||
"media_id": tts_media_id,
|
||||
**asdict(tts_media),
|
||||
}
|
||||
},
|
||||
)
|
||||
PipelineEvent(PipelineEventType.TTS_END, {"tts_output": tts_output})
|
||||
)
|
||||
|
||||
return tts_media.url
|
||||
|
|
|
@ -186,16 +186,22 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
data_to_send = {"text": event.data["tts_input"]}
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
|
||||
assert event.data is not None
|
||||
path = event.data["tts_output"]["url"]
|
||||
url = async_process_play_media_url(self.hass, path)
|
||||
data_to_send = {"url": url}
|
||||
tts_output = event.data["tts_output"]
|
||||
if tts_output:
|
||||
path = tts_output["url"]
|
||||
url = async_process_play_media_url(self.hass, path)
|
||||
data_to_send = {"url": url}
|
||||
|
||||
if self.device_info.voice_assistant_version >= 2:
|
||||
media_id = event.data["tts_output"]["media_id"]
|
||||
self._tts_task = self.hass.async_create_background_task(
|
||||
self._send_tts(media_id), "esphome_voice_assistant_tts"
|
||||
)
|
||||
if self.device_info.voice_assistant_version >= 2:
|
||||
media_id = tts_output["media_id"]
|
||||
self._tts_task = self.hass.async_create_background_task(
|
||||
self._send_tts(media_id), "esphome_voice_assistant_tts"
|
||||
)
|
||||
else:
|
||||
self._tts_done.set()
|
||||
else:
|
||||
# Empty TTS response
|
||||
data_to_send = {}
|
||||
self._tts_done.set()
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
||||
assert event.data is not None
|
||||
|
|
|
@ -389,11 +389,16 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
|||
self._conversation_id = event.data["intent_output"]["conversation_id"]
|
||||
elif event.type == PipelineEventType.TTS_END:
|
||||
# Send TTS audio to caller over RTP
|
||||
media_id = event.data["tts_output"]["media_id"]
|
||||
self.hass.async_create_background_task(
|
||||
self._send_tts(media_id),
|
||||
"voip_pipeline_tts",
|
||||
)
|
||||
tts_output = event.data["tts_output"]
|
||||
if tts_output:
|
||||
media_id = tts_output["media_id"]
|
||||
self.hass.async_create_background_task(
|
||||
self._send_tts(media_id),
|
||||
"voip_pipeline_tts",
|
||||
)
|
||||
else:
|
||||
# Empty TTS response
|
||||
self._tts_done.set()
|
||||
elif event.type == PipelineEventType.ERROR:
|
||||
# Play error tone instead of wait for TTS
|
||||
self._pipeline_error = True
|
||||
|
|
|
@ -650,6 +650,33 @@
|
|||
'message': 'Timeout running pipeline',
|
||||
})
|
||||
# ---
|
||||
# name: test_pipeline_empty_tts_output
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
'timeout': 300,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_pipeline_empty_tts_output.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'language': 'en-US',
|
||||
'tts_input': '',
|
||||
'voice': 'james_earl_jones',
|
||||
})
|
||||
# ---
|
||||
# name: test_pipeline_empty_tts_output.2
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_pipeline_empty_tts_output.3
|
||||
None
|
||||
# ---
|
||||
# name: test_stt_provider_missing
|
||||
dict({
|
||||
'language': 'en',
|
||||
|
|
|
@ -2452,3 +2452,54 @@ async def test_device_capture_queue_full(
|
|||
assert msg["event"] == snapshot
|
||||
assert msg["event"]["type"] == "end"
|
||||
assert msg["event"]["overflow"]
|
||||
|
||||
|
||||
async def test_pipeline_empty_tts_output(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with a empty text-to-speech text."""
|
||||
events = []
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "tts",
|
||||
"end_stage": "tts",
|
||||
"input": {
|
||||
"text": "",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
msg["event"]["data"]["pipeline"] = ANY
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# text-to-speech
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "tts-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "tts-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
assert not msg["event"]["data"]["tts_output"]
|
||||
events.append(msg["event"])
|
||||
|
||||
# run end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
|
|
@ -337,6 +337,28 @@ async def test_send_tts_called(
|
|||
mock_send_tts.assert_called_with(_TEST_MEDIA_ID)
|
||||
|
||||
|
||||
async def test_send_tts_not_called_when_empty(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||
) -> None:
|
||||
"""Test the UDP server with a v1/v2 device doesn't call _send_tts when the output is empty."""
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._send_tts"
|
||||
) as mock_send_tts:
|
||||
voice_assistant_udp_server_v1._event_callback(
|
||||
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
||||
)
|
||||
|
||||
mock_send_tts.assert_not_called()
|
||||
|
||||
voice_assistant_udp_server_v2._event_callback(
|
||||
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
||||
)
|
||||
|
||||
mock_send_tts.assert_not_called()
|
||||
|
||||
|
||||
async def test_send_tts(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||
|
|
|
@ -528,3 +528,75 @@ async def test_tts_wrong_wav_format(
|
|||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
|
||||
|
||||
async def test_empty_tts_output(
|
||||
hass: HomeAssistant,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will not stream when output is empty."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def is_speech(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
return sum(chunk) > 0
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.INTENT_END,
|
||||
data={
|
||||
"intent_output": {
|
||||
"conversation_id": "fake-conversation",
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Empty TTS output
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.TTS_END,
|
||||
data={"tts_output": {}},
|
||||
)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
|
||||
new=is_speech,
|
||||
), patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
), patch(
|
||||
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._send_tts",
|
||||
) as mock_send_tts:
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
|
||||
# Wait for mock pipeline to finish
|
||||
async with asyncio.timeout(1):
|
||||
await rtp_protocol._tts_done.wait()
|
||||
|
||||
mock_send_tts.assert_not_called()
|
||||
|
|
Loading…
Reference in New Issue