Use aliases when listing pipeline languages (#99672)
parent
98ff3e233d
commit
e69c88a0d2
|
@ -332,7 +332,7 @@ async def websocket_list_languages(
|
||||||
dialect = language_util.Dialect.parse(language_tag)
|
dialect = language_util.Dialect.parse(language_tag)
|
||||||
languages.add(dialect.language)
|
languages.add(dialect.language)
|
||||||
if pipeline_languages is not None:
|
if pipeline_languages is not None:
|
||||||
pipeline_languages &= languages
|
pipeline_languages = language_util.intersect(pipeline_languages, languages)
|
||||||
else:
|
else:
|
||||||
pipeline_languages = languages
|
pipeline_languages = languages
|
||||||
|
|
||||||
|
@ -342,11 +342,15 @@ async def websocket_list_languages(
|
||||||
dialect = language_util.Dialect.parse(language_tag)
|
dialect = language_util.Dialect.parse(language_tag)
|
||||||
languages.add(dialect.language)
|
languages.add(dialect.language)
|
||||||
if pipeline_languages is not None:
|
if pipeline_languages is not None:
|
||||||
pipeline_languages &= languages
|
pipeline_languages = language_util.intersect(pipeline_languages, languages)
|
||||||
else:
|
else:
|
||||||
pipeline_languages = languages
|
pipeline_languages = languages
|
||||||
|
|
||||||
connection.send_result(
|
connection.send_result(
|
||||||
msg["id"],
|
msg["id"],
|
||||||
{"languages": pipeline_languages},
|
{
|
||||||
|
"languages": sorted(pipeline_languages)
|
||||||
|
if pipeline_languages
|
||||||
|
else pipeline_languages
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
|
@ -199,3 +199,14 @@ def matches(
|
||||||
|
|
||||||
# Score < 0 is not a match
|
# Score < 0 is not a match
|
||||||
return [tag for _dialect, score, tag in scored if score[0] >= 0]
|
return [tag for _dialect, score, tag in scored if score[0] >= 0]
|
||||||
|
|
||||||
|
|
||||||
|
def intersect(languages_1: set[str], languages_2: set[str]) -> set[str]:
|
||||||
|
"""Intersect two sets of languages using is_match for aliases."""
|
||||||
|
languages = set()
|
||||||
|
for lang_1 in languages_1:
|
||||||
|
for lang_2 in languages_2:
|
||||||
|
if is_language_match(lang_1, lang_2):
|
||||||
|
languages.add(lang_1)
|
||||||
|
|
||||||
|
return languages
|
||||||
|
|
|
@ -1633,3 +1633,29 @@ async def test_list_pipeline_languages(
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == {"languages": ["en"]}
|
assert msg["result"] == {"languages": ["en"]}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_pipeline_languages_with_aliases(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
) -> None:
|
||||||
|
"""Test listing pipeline languages using aliases."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.conversation.async_get_conversation_languages",
|
||||||
|
return_value={"he", "nb"},
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.stt.async_get_speech_to_text_languages",
|
||||||
|
return_value={"he", "no"},
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.tts.async_get_text_to_speech_languages",
|
||||||
|
return_value={"iw", "nb"},
|
||||||
|
):
|
||||||
|
await client.send_json_auto_id({"type": "assist_pipeline/language/list"})
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {"languages": ["he", "nb"]}
|
||||||
|
|
Loading…
Reference in New Issue