Assist Pipeline minor cleanup (#121187)

pull/121283/head
Paulus Schoutsen 2024-07-05 09:26:32 +02:00 committed by GitHub
parent 2b9bddc3fc
commit 22718ca32a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 18 additions and 12 deletions

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import array import array
import asyncio import asyncio
from collections import defaultdict, deque from collections import defaultdict, deque
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable from collections.abc import AsyncGenerator, AsyncIterable, Callable
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from enum import StrEnum from enum import StrEnum
import logging import logging
@ -118,8 +118,10 @@ AUDIO_PROCESSOR_BYTES: Final = AUDIO_PROCESSOR_SAMPLES * 2 # 16-bit samples
@callback @callback
def _async_resolve_default_pipeline_settings( def _async_resolve_default_pipeline_settings(
hass: HomeAssistant, hass: HomeAssistant,
stt_engine_id: str | None, *,
tts_engine_id: str | None, conversation_engine_id: str | None = None,
stt_engine_id: str | None = None,
tts_engine_id: str | None = None,
pipeline_name: str, pipeline_name: str,
) -> dict[str, str | None]: ) -> dict[str, str | None]:
"""Resolve settings for a default pipeline. """Resolve settings for a default pipeline.
@ -137,12 +139,13 @@ def _async_resolve_default_pipeline_settings(
wake_word_entity = None wake_word_entity = None
wake_word_id = None wake_word_id = None
if conversation_engine_id is None:
conversation_engine_id = conversation.HOME_ASSISTANT_AGENT
# Find a matching language supported by the Home Assistant conversation agent # Find a matching language supported by the Home Assistant conversation agent
conversation_languages = language_util.matches( conversation_languages = language_util.matches(
hass.config.language, hass.config.language,
conversation.async_get_conversation_languages( conversation.async_get_conversation_languages(hass, conversation_engine_id),
hass, conversation.HOME_ASSISTANT_AGENT
),
country=hass.config.country, country=hass.config.country,
) )
if conversation_languages: if conversation_languages:
@ -201,7 +204,7 @@ def _async_resolve_default_pipeline_settings(
tts_engine_id = None tts_engine_id = None
return { return {
"conversation_engine": conversation.HOME_ASSISTANT_AGENT, "conversation_engine": conversation_engine_id,
"conversation_language": conversation_language, "conversation_language": conversation_language,
"language": hass.config.language, "language": hass.config.language,
"name": pipeline_name, "name": pipeline_name,
@ -224,7 +227,7 @@ async def _async_create_default_pipeline(
default stt / tts engines. default stt / tts engines.
""" """
pipeline_settings = _async_resolve_default_pipeline_settings( pipeline_settings = _async_resolve_default_pipeline_settings(
hass, stt_engine_id=None, tts_engine_id=None, pipeline_name="Home Assistant" hass, pipeline_name="Home Assistant"
) )
return await pipeline_store.async_create_item(pipeline_settings) return await pipeline_store.async_create_item(pipeline_settings)
@ -243,7 +246,10 @@ async def async_create_default_pipeline(
pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store pipeline_store = pipeline_data.pipeline_store
pipeline_settings = _async_resolve_default_pipeline_settings( pipeline_settings = _async_resolve_default_pipeline_settings(
hass, stt_engine_id, tts_engine_id, pipeline_name=pipeline_name hass,
stt_engine_id=stt_engine_id,
tts_engine_id=tts_engine_id,
pipeline_name=pipeline_name,
) )
if ( if (
pipeline_settings["stt_engine"] != stt_engine_id pipeline_settings["stt_engine"] != stt_engine_id
@ -274,11 +280,11 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P
@callback @callback
def async_get_pipelines(hass: HomeAssistant) -> Iterable[Pipeline]: def async_get_pipelines(hass: HomeAssistant) -> list[Pipeline]:
"""Get all pipelines.""" """Get all pipelines."""
pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_data: PipelineData = hass.data[DOMAIN]
return pipeline_data.pipeline_store.data.values() return list(pipeline_data.pipeline_store.data.values())
async def async_update_pipeline( async def async_update_pipeline(
@ -1675,7 +1681,7 @@ class PipelineStorageCollectionWebsocket(
connection.send_result( connection.send_result(
msg["id"], msg["id"],
{ {
"pipelines": self.storage_collection.async_items(), "pipelines": async_get_pipelines(hass),
"preferred_pipeline": self.storage_collection.async_get_preferred_item(), "preferred_pipeline": self.storage_collection.async_get_preferred_item(),
}, },
) )