Use STT/TTS languages for LLM fallback (#135533)
parent
3e9b410b7c
commit
b897e6a85f
|
@ -1021,9 +1021,18 @@ class PipelineRun:
|
||||||
raise RuntimeError("Recognize intent was not prepared")
|
raise RuntimeError("Recognize intent was not prepared")
|
||||||
|
|
||||||
if self.pipeline.conversation_language == MATCH_ALL:
|
if self.pipeline.conversation_language == MATCH_ALL:
|
||||||
# LLMs support all languages ('*') so use pipeline language for
|
# LLMs support all languages ('*') so use languages from the
|
||||||
# intent fallback.
|
# pipeline for intent fallback.
|
||||||
input_language = self.pipeline.language
|
#
|
||||||
|
# We prioritize the STT and TTS languages because they may be more
|
||||||
|
# specific, such as "zh-CN" instead of just "zh". This is necessary
|
||||||
|
# for languages whose intents are split out by region when
|
||||||
|
# preferring local intent matching.
|
||||||
|
input_language = (
|
||||||
|
self.pipeline.stt_language
|
||||||
|
or self.pipeline.tts_language
|
||||||
|
or self.pipeline.language
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
input_language = self.pipeline.conversation_language
|
input_language = self.pipeline.conversation_language
|
||||||
|
|
||||||
|
|
|
@ -474,6 +474,108 @@
|
||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_stt_language_used_instead_of_conversation_language
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': None,
|
||||||
|
'device_id': None,
|
||||||
|
'engine': 'conversation.home_assistant',
|
||||||
|
'intent_input': 'test input',
|
||||||
|
'language': 'en-US',
|
||||||
|
'prefer_local_intents': False,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'intent_output': dict({
|
||||||
|
'conversation_id': None,
|
||||||
|
'response': dict({
|
||||||
|
'card': dict({
|
||||||
|
}),
|
||||||
|
'data': dict({
|
||||||
|
'failed': list([
|
||||||
|
]),
|
||||||
|
'success': list([
|
||||||
|
]),
|
||||||
|
'targets': list([
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'language': 'en',
|
||||||
|
'response_type': 'action_done',
|
||||||
|
'speech': dict({
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'processed_locally': True,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': None,
|
||||||
|
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_tts_language_used_instead_of_conversation_language
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': None,
|
||||||
|
'device_id': None,
|
||||||
|
'engine': 'conversation.home_assistant',
|
||||||
|
'intent_input': 'test input',
|
||||||
|
'language': 'en-us',
|
||||||
|
'prefer_local_intents': False,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'intent_output': dict({
|
||||||
|
'conversation_id': None,
|
||||||
|
'response': dict({
|
||||||
|
'card': dict({
|
||||||
|
}),
|
||||||
|
'data': dict({
|
||||||
|
'failed': list([
|
||||||
|
]),
|
||||||
|
'success': list([
|
||||||
|
]),
|
||||||
|
'targets': list([
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'language': 'en',
|
||||||
|
'response_type': 'action_done',
|
||||||
|
'speech': dict({
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'processed_locally': True,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': None,
|
||||||
|
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
# name: test_wake_word_detection_aborted
|
# name: test_wake_word_detection_aborted
|
||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
|
|
|
@ -1102,13 +1102,13 @@ async def test_prefer_local_intents(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_pipeline_language_used_instead_of_conversation_language(
|
async def test_stt_language_used_instead_of_conversation_language(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
init_components,
|
init_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the pipeline language is used when the conversation language is '*' (all languages)."""
|
"""Test that the STT language is used first when the conversation language is '*' (all languages)."""
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
events: list[assist_pipeline.PipelineEvent] = []
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
@ -1165,7 +1165,155 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
||||||
|
|
||||||
assert intent_start is not None
|
assert intent_start is not None
|
||||||
|
|
||||||
# Pipeline language (en) should be used instead of '*'
|
# STT language (en-US) should be used instead of '*'
|
||||||
|
assert intent_start.data.get("language") == pipeline.stt_language
|
||||||
|
|
||||||
|
# Check input to async_converse
|
||||||
|
mock_async_converse.assert_called_once()
|
||||||
|
assert (
|
||||||
|
mock_async_converse.call_args_list[0].kwargs.get("language")
|
||||||
|
== pipeline.stt_language
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tts_language_used_instead_of_conversation_language(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that the TTS language is used after STT when the conversation language is '*' (all languages)."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/create",
|
||||||
|
"conversation_engine": "homeassistant",
|
||||||
|
"conversation_language": MATCH_ALL,
|
||||||
|
"language": "en",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": None,
|
||||||
|
"stt_language": None,
|
||||||
|
"tts_engine": None,
|
||||||
|
"tts_language": "en-us",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
|
"wake_word_entity": None,
|
||||||
|
"wake_word_id": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
pipeline_id = msg["result"]["id"]
|
||||||
|
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
intent_input="test input",
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
event_callback=events.append,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||||
|
return_value=conversation.ConversationResult(
|
||||||
|
intent.IntentResponse(pipeline.language)
|
||||||
|
),
|
||||||
|
) as mock_async_converse:
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
# Check intent start event
|
||||||
|
assert process_events(events) == snapshot
|
||||||
|
intent_start: assist_pipeline.PipelineEvent | None = None
|
||||||
|
for event in events:
|
||||||
|
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
|
||||||
|
intent_start = event
|
||||||
|
break
|
||||||
|
|
||||||
|
assert intent_start is not None
|
||||||
|
|
||||||
|
# TTS language (en-us) should be used instead of '*' because no STT language is set
|
||||||
|
assert intent_start.data.get("language") == pipeline.tts_language
|
||||||
|
|
||||||
|
# Check input to async_converse
|
||||||
|
mock_async_converse.assert_called_once()
|
||||||
|
assert (
|
||||||
|
mock_async_converse.call_args_list[0].kwargs.get("language")
|
||||||
|
== pipeline.tts_language
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_language_used_instead_of_conversation_language(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that the pipeline language is used last when the conversation language is '*' (all languages)."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/create",
|
||||||
|
"conversation_engine": "homeassistant",
|
||||||
|
"conversation_language": MATCH_ALL,
|
||||||
|
"language": "en",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": None,
|
||||||
|
"stt_language": None,
|
||||||
|
"tts_engine": None,
|
||||||
|
"tts_language": None,
|
||||||
|
"tts_voice": None,
|
||||||
|
"wake_word_entity": None,
|
||||||
|
"wake_word_id": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
pipeline_id = msg["result"]["id"]
|
||||||
|
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
intent_input="test input",
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
event_callback=events.append,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||||
|
return_value=conversation.ConversationResult(
|
||||||
|
intent.IntentResponse(pipeline.language)
|
||||||
|
),
|
||||||
|
) as mock_async_converse:
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
# Check intent start event
|
||||||
|
assert process_events(events) == snapshot
|
||||||
|
intent_start: assist_pipeline.PipelineEvent | None = None
|
||||||
|
for event in events:
|
||||||
|
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
|
||||||
|
intent_start = event
|
||||||
|
break
|
||||||
|
|
||||||
|
assert intent_start is not None
|
||||||
|
|
||||||
|
# Pipeline language (en) should be used instead of '*' because neither STT nor TTS language is set
|
||||||
assert intent_start.data.get("language") == pipeline.language
|
assert intent_start.data.get("language") == pipeline.language
|
||||||
|
|
||||||
# Check input to async_converse
|
# Check input to async_converse
|
||||||
|
|
Loading…
Reference in New Issue