Assist Pipeline minor cleanup (#121187)
parent
2b9bddc3fc
commit
22718ca32a
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||||
import array
|
import array
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
|
from collections.abc import AsyncGenerator, AsyncIterable, Callable
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
import logging
|
import logging
|
||||||
|
@ -118,8 +118,10 @@ AUDIO_PROCESSOR_BYTES: Final = AUDIO_PROCESSOR_SAMPLES * 2 # 16-bit samples
|
||||||
@callback
|
@callback
|
||||||
def _async_resolve_default_pipeline_settings(
|
def _async_resolve_default_pipeline_settings(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
stt_engine_id: str | None,
|
*,
|
||||||
tts_engine_id: str | None,
|
conversation_engine_id: str | None = None,
|
||||||
|
stt_engine_id: str | None = None,
|
||||||
|
tts_engine_id: str | None = None,
|
||||||
pipeline_name: str,
|
pipeline_name: str,
|
||||||
) -> dict[str, str | None]:
|
) -> dict[str, str | None]:
|
||||||
"""Resolve settings for a default pipeline.
|
"""Resolve settings for a default pipeline.
|
||||||
|
@ -137,12 +139,13 @@ def _async_resolve_default_pipeline_settings(
|
||||||
wake_word_entity = None
|
wake_word_entity = None
|
||||||
wake_word_id = None
|
wake_word_id = None
|
||||||
|
|
||||||
|
if conversation_engine_id is None:
|
||||||
|
conversation_engine_id = conversation.HOME_ASSISTANT_AGENT
|
||||||
|
|
||||||
# Find a matching language supported by the Home Assistant conversation agent
|
# Find a matching language supported by the Home Assistant conversation agent
|
||||||
conversation_languages = language_util.matches(
|
conversation_languages = language_util.matches(
|
||||||
hass.config.language,
|
hass.config.language,
|
||||||
conversation.async_get_conversation_languages(
|
conversation.async_get_conversation_languages(hass, conversation_engine_id),
|
||||||
hass, conversation.HOME_ASSISTANT_AGENT
|
|
||||||
),
|
|
||||||
country=hass.config.country,
|
country=hass.config.country,
|
||||||
)
|
)
|
||||||
if conversation_languages:
|
if conversation_languages:
|
||||||
|
@ -201,7 +204,7 @@ def _async_resolve_default_pipeline_settings(
|
||||||
tts_engine_id = None
|
tts_engine_id = None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"conversation_engine": conversation.HOME_ASSISTANT_AGENT,
|
"conversation_engine": conversation_engine_id,
|
||||||
"conversation_language": conversation_language,
|
"conversation_language": conversation_language,
|
||||||
"language": hass.config.language,
|
"language": hass.config.language,
|
||||||
"name": pipeline_name,
|
"name": pipeline_name,
|
||||||
|
@ -224,7 +227,7 @@ async def _async_create_default_pipeline(
|
||||||
default stt / tts engines.
|
default stt / tts engines.
|
||||||
"""
|
"""
|
||||||
pipeline_settings = _async_resolve_default_pipeline_settings(
|
pipeline_settings = _async_resolve_default_pipeline_settings(
|
||||||
hass, stt_engine_id=None, tts_engine_id=None, pipeline_name="Home Assistant"
|
hass, pipeline_name="Home Assistant"
|
||||||
)
|
)
|
||||||
return await pipeline_store.async_create_item(pipeline_settings)
|
return await pipeline_store.async_create_item(pipeline_settings)
|
||||||
|
|
||||||
|
@ -243,7 +246,10 @@ async def async_create_default_pipeline(
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
pipeline_store = pipeline_data.pipeline_store
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
pipeline_settings = _async_resolve_default_pipeline_settings(
|
pipeline_settings = _async_resolve_default_pipeline_settings(
|
||||||
hass, stt_engine_id, tts_engine_id, pipeline_name=pipeline_name
|
hass,
|
||||||
|
stt_engine_id=stt_engine_id,
|
||||||
|
tts_engine_id=tts_engine_id,
|
||||||
|
pipeline_name=pipeline_name,
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
pipeline_settings["stt_engine"] != stt_engine_id
|
pipeline_settings["stt_engine"] != stt_engine_id
|
||||||
|
@ -274,11 +280,11 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_pipelines(hass: HomeAssistant) -> Iterable[Pipeline]:
|
def async_get_pipelines(hass: HomeAssistant) -> list[Pipeline]:
|
||||||
"""Get all pipelines."""
|
"""Get all pipelines."""
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
|
||||||
return pipeline_data.pipeline_store.data.values()
|
return list(pipeline_data.pipeline_store.data.values())
|
||||||
|
|
||||||
|
|
||||||
async def async_update_pipeline(
|
async def async_update_pipeline(
|
||||||
|
@ -1675,7 +1681,7 @@ class PipelineStorageCollectionWebsocket(
|
||||||
connection.send_result(
|
connection.send_result(
|
||||||
msg["id"],
|
msg["id"],
|
||||||
{
|
{
|
||||||
"pipelines": self.storage_collection.async_items(),
|
"pipelines": async_get_pipelines(hass),
|
||||||
"preferred_pipeline": self.storage_collection.async_get_preferred_item(),
|
"preferred_pipeline": self.storage_collection.async_get_preferred_item(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue