Assist Pipeline minor cleanup (#121187)
parent
2b9bddc3fc
commit
22718ca32a
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||
import array
|
||||
import asyncio
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
|
||||
from collections.abc import AsyncGenerator, AsyncIterable, Callable
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import StrEnum
|
||||
import logging
|
||||
|
@ -118,8 +118,10 @@ AUDIO_PROCESSOR_BYTES: Final = AUDIO_PROCESSOR_SAMPLES * 2 # 16-bit samples
|
|||
@callback
|
||||
def _async_resolve_default_pipeline_settings(
|
||||
hass: HomeAssistant,
|
||||
stt_engine_id: str | None,
|
||||
tts_engine_id: str | None,
|
||||
*,
|
||||
conversation_engine_id: str | None = None,
|
||||
stt_engine_id: str | None = None,
|
||||
tts_engine_id: str | None = None,
|
||||
pipeline_name: str,
|
||||
) -> dict[str, str | None]:
|
||||
"""Resolve settings for a default pipeline.
|
||||
|
@ -137,12 +139,13 @@ def _async_resolve_default_pipeline_settings(
|
|||
wake_word_entity = None
|
||||
wake_word_id = None
|
||||
|
||||
if conversation_engine_id is None:
|
||||
conversation_engine_id = conversation.HOME_ASSISTANT_AGENT
|
||||
|
||||
# Find a matching language supported by the Home Assistant conversation agent
|
||||
conversation_languages = language_util.matches(
|
||||
hass.config.language,
|
||||
conversation.async_get_conversation_languages(
|
||||
hass, conversation.HOME_ASSISTANT_AGENT
|
||||
),
|
||||
conversation.async_get_conversation_languages(hass, conversation_engine_id),
|
||||
country=hass.config.country,
|
||||
)
|
||||
if conversation_languages:
|
||||
|
@ -201,7 +204,7 @@ def _async_resolve_default_pipeline_settings(
|
|||
tts_engine_id = None
|
||||
|
||||
return {
|
||||
"conversation_engine": conversation.HOME_ASSISTANT_AGENT,
|
||||
"conversation_engine": conversation_engine_id,
|
||||
"conversation_language": conversation_language,
|
||||
"language": hass.config.language,
|
||||
"name": pipeline_name,
|
||||
|
@ -224,7 +227,7 @@ async def _async_create_default_pipeline(
|
|||
default stt / tts engines.
|
||||
"""
|
||||
pipeline_settings = _async_resolve_default_pipeline_settings(
|
||||
hass, stt_engine_id=None, tts_engine_id=None, pipeline_name="Home Assistant"
|
||||
hass, pipeline_name="Home Assistant"
|
||||
)
|
||||
return await pipeline_store.async_create_item(pipeline_settings)
|
||||
|
||||
|
@ -243,7 +246,10 @@ async def async_create_default_pipeline(
|
|||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_settings = _async_resolve_default_pipeline_settings(
|
||||
hass, stt_engine_id, tts_engine_id, pipeline_name=pipeline_name
|
||||
hass,
|
||||
stt_engine_id=stt_engine_id,
|
||||
tts_engine_id=tts_engine_id,
|
||||
pipeline_name=pipeline_name,
|
||||
)
|
||||
if (
|
||||
pipeline_settings["stt_engine"] != stt_engine_id
|
||||
|
@ -274,11 +280,11 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P
|
|||
|
||||
|
||||
@callback
|
||||
def async_get_pipelines(hass: HomeAssistant) -> Iterable[Pipeline]:
|
||||
def async_get_pipelines(hass: HomeAssistant) -> list[Pipeline]:
|
||||
"""Get all pipelines."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
|
||||
return pipeline_data.pipeline_store.data.values()
|
||||
return list(pipeline_data.pipeline_store.data.values())
|
||||
|
||||
|
||||
async def async_update_pipeline(
|
||||
|
@ -1675,7 +1681,7 @@ class PipelineStorageCollectionWebsocket(
|
|||
connection.send_result(
|
||||
msg["id"],
|
||||
{
|
||||
"pipelines": self.storage_collection.async_items(),
|
||||
"pipelines": async_get_pipelines(hass),
|
||||
"preferred_pipeline": self.storage_collection.async_get_preferred_item(),
|
||||
},
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue