# core/homeassistant/components/assist_pipeline/pipeline.py
"""Classes for voice assistant pipelines."""
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterable, Callable, Iterable
from dataclasses import asdict, dataclass, field
import logging
from typing import Any, cast

import voluptuous as vol

from homeassistant.backports.enum import StrEnum
from homeassistant.components import conversation, media_source, stt, tts, websocket_api
from homeassistant.components.tts.media_source import (
    generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.collection import (
    CollectionError,
    ItemNotFound,
    SerializedStorageCollection,
    StorageCollection,
    StorageCollectionWebsocket,
)
from homeassistant.helpers.singleton import singleton
from homeassistant.helpers.storage import Store
from homeassistant.util import (
    dt as dt_util,
    language as language_util,
    ulid as ulid_util,
)
from homeassistant.util.limited_size_dict import LimitedSizeDict

from .const import DOMAIN
from .error import (
    IntentRecognitionError,
    PipelineError,
    PipelineNotFound,
    SpeechToTextError,
    TextToSpeechError,
)

_LOGGER = logging.getLogger(__name__)
STORAGE_KEY = f"{DOMAIN}.pipelines"
STORAGE_VERSION = 1
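
# Engine/language option pairs: if an engine is set, its language must be set too.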
ENGINE_LANGUAGE_PAIRS = (
("stt_engine", "stt_language"),
("tts_engine", "tts_language"),
)
def validate_language(data: dict[str, Any]) -> Any:
"""Validate language settings."""
for engine, language in ENGINE_LANGUAGE_PAIRS:
if data[engine] is not None and data[language] is None:
raise vol.Invalid(f"Need language {language} for {engine} {data[engine]}")
return data
PIPELINE_FIELDS = {
vol.Required("conversation_engine"): str,
vol.Required("conversation_language"): str,
vol.Required("language"): str,
vol.Required("name"): str,
vol.Required("stt_engine"): vol.Any(str, None),
vol.Required("stt_language"): vol.Any(str, None),
vol.Required("tts_engine"): vol.Any(str, None),
vol.Required("tts_language"): vol.Any(str, None),
vol.Required("tts_voice"): vol.Any(str, None),
}
STORED_PIPELINE_RUNS = 10
SAVE_DELAY = 10
async def _async_resolve_default_pipeline_settings(
hass: HomeAssistant,
stt_engine_id: str | None,
tts_engine_id: str | None,
) -> dict[str, str | None]:
"""Resolve settings for a default pipeline.
The default pipeline will use the homeassistant conversation agent and the
default stt / tts engines if none are specified.
"""
conversation_language = "en"
pipeline_language = "en"
pipeline_name = "Home Assistant"
stt_engine = None
stt_language = None
tts_engine = None
tts_language = None
tts_voice = None
# Find a matching language supported by the Home Assistant conversation agent
conversation_languages = language_util.matches(
hass.config.language,
await conversation.async_get_conversation_languages(
hass, conversation.HOME_ASSISTANT_AGENT
),
country=hass.config.country,
)
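
    # Keep the "en" defaults unless the agent supports the configured language.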
if conversation_languages:
pipeline_language = hass.config.language
conversation_language = conversation_languages[0]
if stt_engine_id is None:
stt_engine_id = stt.async_default_engine(hass)
if stt_engine_id is not None:
stt_engine = stt.async_get_speech_to_text_engine(hass, stt_engine_id)
if stt_engine is None:
stt_engine_id = None
if stt_engine:
stt_languages = language_util.matches(
pipeline_language,
stt_engine.supported_languages,
country=hass.config.country,
)
if stt_languages:
stt_language = stt_languages[0]
else:
_LOGGER.debug(
"Speech-to-text engine '%s' does not support language '%s'",
stt_engine_id,
pipeline_language,
)
stt_engine_id = None
if tts_engine_id is None:
tts_engine_id = tts.async_default_engine(hass)
if tts_engine_id is not None:
tts_engine = tts.get_engine_instance(hass, tts_engine_id)
if tts_engine is None:
tts_engine_id = None
if tts_engine:
tts_languages = language_util.matches(
pipeline_language,
tts_engine.supported_languages,
country=hass.config.country,
)
if tts_languages:
tts_language = tts_languages[0]
tts_voices = tts_engine.async_get_supported_voices(tts_language)
if tts_voices:
tts_voice = tts_voices[0].voice_id
else:
_LOGGER.debug(
"Text-to-speech engine '%s' does not support language '%s'",
tts_engine_id,
pipeline_language,
)
tts_engine_id = None
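
    # Both default engines resolved to Home Assistant Cloud; reflect that in the name.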
if stt_engine_id == "cloud" and tts_engine_id == "cloud":
pipeline_name = "Home Assistant Cloud"
return {
"conversation_engine": conversation.HOME_ASSISTANT_AGENT,
"conversation_language": conversation_language,
"language": hass.config.language,
"name": pipeline_name,
"stt_engine": stt_engine_id,
"stt_language": stt_language,
"tts_engine": tts_engine_id,
"tts_language": tts_language,
"tts_voice": tts_voice,
}
async def _async_create_default_pipeline(
hass: HomeAssistant, pipeline_store: PipelineStorageCollection
) -> Pipeline:
"""Create a default pipeline.
The default pipeline will use the homeassistant conversation agent and the
default stt / tts engines.
"""
pipeline_settings = await _async_resolve_default_pipeline_settings(hass, None, None)
return await pipeline_store.async_create_item(pipeline_settings)
async def async_create_default_pipeline(
hass: HomeAssistant, stt_engine_id: str, tts_engine_id: str
) -> Pipeline | None:
"""Create a pipeline with default settings.
The default pipeline will use the homeassistant conversation agent and the
specified stt / tts engines.
"""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
pipeline_settings = await _async_resolve_default_pipeline_settings(
hass, stt_engine_id, tts_engine_id
)
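
    # Bail out if a requested engine was unavailable or did not support the language.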
if (
pipeline_settings["stt_engine"] != stt_engine_id
or pipeline_settings["tts_engine"] != tts_engine_id
):
return None
return await pipeline_store.async_create_item(pipeline_settings)
@callback
def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
"""Get a pipeline by id or the preferred pipeline."""
pipeline_data: PipelineData = hass.data[DOMAIN]
if pipeline_id is None:
# A pipeline was not specified, use the preferred one
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
pipeline = pipeline_data.pipeline_store.data.get(pipeline_id)
    # Raise if no pipeline was found for the given or preferred id
if pipeline is None:
raise PipelineNotFound(
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
)
return pipeline
@callback
def async_get_pipelines(hass: HomeAssistant) -> Iterable[Pipeline]:
"""Get all pipelines."""
pipeline_data: PipelineData = hass.data[DOMAIN]
return pipeline_data.pipeline_store.data.values()
class PipelineEventType(StrEnum):
"""Event types emitted during a pipeline run."""
RUN_START = "run-start"
RUN_END = "run-end"
STT_START = "stt-start"
STT_END = "stt-end"
INTENT_START = "intent-start"
INTENT_END = "intent-end"
TTS_START = "tts-start"
TTS_END = "tts-end"
ERROR = "error"
@dataclass(frozen=True)
class PipelineEvent:
"""Events emitted during a pipeline run."""
type: PipelineEventType
data: dict[str, Any] | None = None
timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat())
PipelineEventCallback = Callable[[PipelineEvent], None]
@dataclass(frozen=True)
class Pipeline:
"""A voice assistant pipeline."""
conversation_engine: str
conversation_language: str
language: str
name: str
stt_engine: str | None
stt_language: str | None
tts_engine: str | None
tts_language: str | None
tts_voice: str | None
id: str = field(default_factory=ulid_util.ulid)
def to_json(self) -> dict[str, Any]:
"""Return a JSON serializable representation for storage."""
return {
"conversation_engine": self.conversation_engine,
"conversation_language": self.conversation_language,
"id": self.id,
"language": self.language,
"name": self.name,
"stt_engine": self.stt_engine,
"stt_language": self.stt_language,
"tts_engine": self.tts_engine,
"tts_language": self.tts_language,
"tts_voice": self.tts_voice,
}
class PipelineStage(StrEnum):
"""Stages of a pipeline."""
STT = "stt"
INTENT = "intent"
TTS = "tts"
PIPELINE_STAGE_ORDER = [
PipelineStage.STT,
PipelineStage.INTENT,
PipelineStage.TTS,
]
class PipelineRunValidationError(Exception):
"""Error when a pipeline run is not valid."""
class InvalidPipelineStagesError(PipelineRunValidationError):
"""Error when given an invalid combination of start/end stages."""
def __init__(
self,
start_stage: PipelineStage,
end_stage: PipelineStage,
) -> None:
"""Set error message."""
super().__init__(
f"Invalid stage combination: start={start_stage}, end={end_stage}"
)
@dataclass
class PipelineRun:
"""Running context for a pipeline."""
hass: HomeAssistant
context: Context
pipeline: Pipeline
start_stage: PipelineStage
end_stage: PipelineStage
event_callback: PipelineEventCallback
language: str = None # type: ignore[assignment]
runner_data: Any | None = None
intent_agent: str | None = None
tts_audio_output: str | None = None
id: str = field(default_factory=ulid_util.ulid)
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
tts_engine: str = field(init=False)
tts_options: dict | None = field(init=False, default=None)
def __post_init__(self) -> None:
"""Set language for pipeline."""
self.language = self.pipeline.language or self.hass.config.language
# stt -> intent -> tts
if PIPELINE_STAGE_ORDER.index(self.end_stage) < PIPELINE_STAGE_ORDER.index(
self.start_stage
):
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
pipeline_data: PipelineData = self.hass.data[DOMAIN]
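        # Keep debug data for the last STORED_PIPELINE_RUNS runs of this pipeline.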
if self.pipeline.id not in pipeline_data.pipeline_runs:
pipeline_data.pipeline_runs[self.pipeline.id] = LimitedSizeDict(
size_limit=STORED_PIPELINE_RUNS
)
pipeline_data.pipeline_runs[self.pipeline.id][self.id] = PipelineRunDebug()
@callback
def process_event(self, event: PipelineEvent) -> None:
"""Log an event and call listener."""
self.event_callback(event)
pipeline_data: PipelineData = self.hass.data[DOMAIN]
if self.id not in pipeline_data.pipeline_runs[self.pipeline.id]:
# This run has been evicted from the logged pipeline runs already
return
pipeline_data.pipeline_runs[self.pipeline.id][self.id].events.append(event)
def start(self) -> None:
"""Emit run start event."""
data = {
"pipeline": self.pipeline.id,
"language": self.language,
}
if self.runner_data is not None:
data["runner_data"] = self.runner_data
self.process_event(PipelineEvent(PipelineEventType.RUN_START, data))
def end(self) -> None:
"""Emit run end event."""
self.process_event(
PipelineEvent(
PipelineEventType.RUN_END,
)
)
async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
"""Prepare speech-to-text."""
        # pipeline.stt_engine can't be None, otherwise this method is not called
stt_provider = stt.async_get_speech_to_text_engine(
self.hass,
self.pipeline.stt_engine, # type: ignore[arg-type]
)
if stt_provider is None:
engine = self.pipeline.stt_engine
raise SpeechToTextError(
code="stt-provider-missing",
message=f"No speech-to-text provider for: {engine}",
)
metadata.language = self.pipeline.stt_language or self.language
if not stt_provider.check_metadata(metadata):
raise SpeechToTextError(
code="stt-provider-unsupported-metadata",
message=(
f"Provider {stt_provider.name} does not support input speech "
f"to text metadata {metadata}"
),
)
self.stt_provider = stt_provider
async def speech_to_text(
self,
metadata: stt.SpeechMetadata,
stream: AsyncIterable[bytes],
) -> str:
"""Run speech-to-text portion of pipeline. Returns the spoken text."""
if isinstance(self.stt_provider, stt.Provider):
engine = self.stt_provider.name
else:
engine = self.stt_provider.entity_id
self.process_event(
PipelineEvent(
PipelineEventType.STT_START,
{
"engine": engine,
"metadata": asdict(metadata),
},
)
)
try:
# Transcribe audio stream
result = await self.stt_provider.async_process_audio_stream(
metadata, stream
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during speech-to-text")
raise SpeechToTextError(
code="stt-stream-failed",
message="Unexpected error during speech-to-text",
) from src_error
_LOGGER.debug("speech-to-text result %s", result)
if result.result != stt.SpeechResultState.SUCCESS:
raise SpeechToTextError(
code="stt-stream-failed",
message="speech-to-text failed",
)
if not result.text:
raise SpeechToTextError(
code="stt-no-text-recognized", message="No text recognized"
)
self.process_event(
PipelineEvent(
PipelineEventType.STT_END,
{
"stt_output": {
"text": result.text,
}
},
)
)
return result.text
async def prepare_recognize_intent(self) -> None:
"""Prepare recognizing an intent."""
agent_info = conversation.async_get_agent_info(
self.hass,
# If no conversation engine is set, use the Home Assistant agent
# (the conversation integration default is currently the last one set)
self.pipeline.conversation_engine or conversation.HOME_ASSISTANT_AGENT,
)
if agent_info is None:
engine = self.pipeline.conversation_engine or "default"
raise IntentRecognitionError(
code="intent-not-supported",
message=f"Intent recognition engine {engine} is not found",
)
self.intent_agent = agent_info.id
async def recognize_intent(
self, intent_input: str, conversation_id: str | None, device_id: str | None
) -> str:
"""Run intent recognition portion of pipeline. Returns text to speak."""
if self.intent_agent is None:
raise RuntimeError("Recognize intent was not prepared")
self.process_event(
PipelineEvent(
PipelineEventType.INTENT_START,
{
"engine": self.intent_agent,
"language": self.pipeline.conversation_language,
"intent_input": intent_input,
"conversation_id": conversation_id,
"device_id": device_id,
},
)
)
try:
conversation_result = await conversation.async_converse(
hass=self.hass,
text=intent_input,
conversation_id=conversation_id,
device_id=device_id,
context=self.context,
language=self.pipeline.conversation_language,
agent_id=self.intent_agent,
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during intent recognition")
raise IntentRecognitionError(
code="intent-failed",
message="Unexpected error during intent recognition",
) from src_error
_LOGGER.debug("conversation result %s", conversation_result)
self.process_event(
PipelineEvent(
PipelineEventType.INTENT_END,
{"intent_output": conversation_result.as_dict()},
)
)
speech: str = conversation_result.response.speech.get("plain", {}).get(
"speech", ""
)
return speech
async def prepare_text_to_speech(self) -> None:
"""Prepare text-to-speech."""
        # pipeline.tts_engine can't be None, otherwise this method is not called
engine = cast(str, self.pipeline.tts_engine)
tts_options = {}
if self.pipeline.tts_voice is not None:
tts_options[tts.ATTR_VOICE] = self.pipeline.tts_voice
if self.tts_audio_output is not None:
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output
try:
options_supported = await tts.async_support_options(
self.hass,
engine,
self.pipeline.tts_language,
tts_options,
)
except HomeAssistantError as err:
raise TextToSpeechError(
code="tts-not-supported",
message=f"Text-to-speech engine '{engine}' not found",
) from err
if not options_supported:
raise TextToSpeechError(
code="tts-not-supported",
message=(
f"Text-to-speech engine {engine} "
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
),
)
self.tts_engine = engine
self.tts_options = tts_options
async def text_to_speech(self, tts_input: str) -> str:
"""Run text-to-speech portion of pipeline. Returns URL of TTS audio."""
self.process_event(
PipelineEvent(
PipelineEventType.TTS_START,
{
"engine": self.tts_engine,
"language": self.pipeline.tts_language,
"voice": self.pipeline.tts_voice,
"tts_input": tts_input,
},
)
)
try:
# Synthesize audio and get URL
tts_media_id = tts_generate_media_source_id(
self.hass,
tts_input,
engine=self.tts_engine,
language=self.pipeline.tts_language,
options=self.tts_options,
)
tts_media = await media_source.async_resolve_media(
self.hass,
tts_media_id,
None,
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during text-to-speech")
raise TextToSpeechError(
code="tts-failed",
message="Unexpected error during text-to-speech",
) from src_error
_LOGGER.debug("TTS result %s", tts_media)
self.process_event(
PipelineEvent(
PipelineEventType.TTS_END,
{
"tts_output": {
"media_id": tts_media_id,
**asdict(tts_media),
}
},
)
)
return tts_media.url
@dataclass
class PipelineInput:
"""Input to a pipeline run."""
run: PipelineRun
stt_metadata: stt.SpeechMetadata | None = None
"""Metadata of stt input audio. Required when start_stage = stt."""
stt_stream: AsyncIterable[bytes] | None = None
"""Input audio for stt. Required when start_stage = stt."""
intent_input: str | None = None
"""Input for conversation agent. Required when start_stage = intent."""
tts_input: str | None = None
"""Input for text-to-speech. Required when start_stage = tts."""
conversation_id: str | None = None
device_id: str | None = None
async def execute(self) -> None:
"""Run pipeline."""
self.run.start()
current_stage = self.run.start_stage
try:
# speech-to-text
intent_input = self.intent_input
if current_stage == PipelineStage.STT:
assert self.stt_metadata is not None
assert self.stt_stream is not None
intent_input = await self.run.speech_to_text(
self.stt_metadata,
self.stt_stream,
)
current_stage = PipelineStage.INTENT
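
            # Only continue if the run does not end at speech-to-text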
if self.run.end_stage != PipelineStage.STT:
tts_input = self.tts_input
if current_stage == PipelineStage.INTENT:
assert intent_input is not None
tts_input = await self.run.recognize_intent(
intent_input,
self.conversation_id,
self.device_id,
)
current_stage = PipelineStage.TTS
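
                # Only continue if the run does not end at intent recognition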
if self.run.end_stage != PipelineStage.INTENT:
if current_stage == PipelineStage.TTS:
assert tts_input is not None
await self.run.text_to_speech(tts_input)
except PipelineError as err:
self.run.process_event(
PipelineEvent(
PipelineEventType.ERROR,
{"code": err.code, "message": err.message},
)
)
return
self.run.end()
async def validate(self) -> None:
"""Validate pipeline input against start stage."""
if self.run.start_stage == PipelineStage.STT:
if self.run.pipeline.stt_engine is None:
raise PipelineRunValidationError(
"the pipeline does not support speech-to-text"
)
if self.stt_metadata is None:
raise PipelineRunValidationError(
"stt_metadata is required for speech-to-text"
)
if self.stt_stream is None:
raise PipelineRunValidationError(
"stt_stream is required for speech-to-text"
)
elif self.run.start_stage == PipelineStage.INTENT:
if self.intent_input is None:
raise PipelineRunValidationError(
"intent_input is required for intent recognition"
)
elif self.run.start_stage == PipelineStage.TTS:
if self.tts_input is None:
raise PipelineRunValidationError(
"tts_input is required for text-to-speech"
)
if self.run.end_stage == PipelineStage.TTS:
if self.run.pipeline.tts_engine is None:
raise PipelineRunValidationError(
"the pipeline does not support text-to-speech"
)
start_stage_index = PIPELINE_STAGE_ORDER.index(self.run.start_stage)
end_stage_index = PIPELINE_STAGE_ORDER.index(self.run.end_stage)
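
        # Prepare each stage from start_stage through end_stage concurrently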
prepare_tasks = []
if (
start_stage_index
<= PIPELINE_STAGE_ORDER.index(PipelineStage.STT)
<= end_stage_index
):
# self.stt_metadata can't be None or we'd raise above
prepare_tasks.append(self.run.prepare_speech_to_text(self.stt_metadata)) # type: ignore[arg-type]
if (
start_stage_index
<= PIPELINE_STAGE_ORDER.index(PipelineStage.INTENT)
<= end_stage_index
):
prepare_tasks.append(self.run.prepare_recognize_intent())
if (
start_stage_index
<= PIPELINE_STAGE_ORDER.index(PipelineStage.TTS)
<= end_stage_index
):
prepare_tasks.append(self.run.prepare_text_to_speech())
if prepare_tasks:
await asyncio.gather(*prepare_tasks)
class PipelinePreferred(CollectionError):
"""Raised when attempting to delete the preferred pipelen."""
def __init__(self, item_id: str) -> None:
"""Initialize pipeline preferred error."""
super().__init__(f"Item {item_id} preferred.")
self.item_id = item_id
class SerializedPipelineStorageCollection(SerializedStorageCollection):
"""Serialized pipeline storage collection."""
preferred_item: str
class PipelineStorageCollection(
StorageCollection[Pipeline, SerializedPipelineStorageCollection]
):
"""Pipeline storage collection."""
_preferred_item: str
async def _async_load_data(self) -> SerializedPipelineStorageCollection | None:
"""Load the data."""
if not (data := await super()._async_load_data()):
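            # Storage is empty: create a default pipeline and mark it preferred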
pipeline = await _async_create_default_pipeline(self.hass, self)
self._preferred_item = pipeline.id
return data
self._preferred_item = data["preferred_item"]
return data
async def _process_create_data(self, data: dict) -> dict:
"""Validate the config is valid."""
validated_data: dict = validate_language(data)
return validated_data
@callback
def _get_suggested_id(self, info: dict) -> str:
"""Suggest an ID based on the config."""
return ulid_util.ulid()
async def _update_data(self, item: Pipeline, update_data: dict) -> Pipeline:
"""Return a new updated item."""
update_data = validate_language(update_data)
return Pipeline(id=item.id, **update_data)
def _create_item(self, item_id: str, data: dict) -> Pipeline:
"""Create an item from validated config."""
return Pipeline(id=item_id, **data)
def _deserialize_item(self, data: dict) -> Pipeline:
"""Create an item from its serialized representation."""
return Pipeline(**data)
def _serialize_item(self, item_id: str, item: Pipeline) -> dict:
"""Return the serialized representation of an item for storing."""
return item.to_json()
async def async_delete_item(self, item_id: str) -> None:
"""Delete item."""
if self._preferred_item == item_id:
raise PipelinePreferred(item_id)
await super().async_delete_item(item_id)
@callback
def async_get_preferred_item(self) -> str:
"""Get the id of the preferred item."""
return self._preferred_item
@callback
def async_set_preferred_item(self, item_id: str) -> None:
"""Set the preferred pipeline."""
if item_id not in self.data:
raise ItemNotFound(item_id)
self._preferred_item = item_id
self._async_schedule_save()
@callback
def _data_to_save(self) -> SerializedPipelineStorageCollection:
"""Return JSON-compatible date for storing to file."""
base_data = super()._base_data_to_save()
return {
"items": base_data["items"],
"preferred_item": self._preferred_item,
}
class PipelineStorageCollectionWebsocket(
StorageCollectionWebsocket[PipelineStorageCollection]
):
"""Class to expose storage collection management over websocket."""
@callback
def async_setup(
self,
hass: HomeAssistant,
*,
create_list: bool = True,
create_create: bool = True,
) -> None:
"""Set up the websocket commands."""
super().async_setup(hass, create_list=create_list, create_create=create_create)
websocket_api.async_register_command(
hass,
f"{self.api_prefix}/get",
self.ws_get_item,
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"): f"{self.api_prefix}/get",
vol.Optional(self.item_id_key): str,
}
),
)
websocket_api.async_register_command(
hass,
f"{self.api_prefix}/set_preferred",
websocket_api.require_admin(
websocket_api.async_response(self.ws_set_preferred_item)
),
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"): f"{self.api_prefix}/set_preferred",
vol.Required(self.item_id_key): str,
}
),
)
async def ws_delete_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Delete an item."""
try:
await super().ws_delete_item(hass, connection, msg)
except PipelinePreferred as exc:
connection.send_error(
msg["id"], websocket_api.const.ERR_NOT_ALLOWED, str(exc)
)
@callback
def ws_get_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Get an item."""
item_id = msg.get(self.item_id_key)
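        # Fall back to the preferred pipeline when no id is given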
if item_id is None:
item_id = self.storage_collection.async_get_preferred_item()
if item_id not in self.storage_collection.data:
connection.send_error(
msg["id"],
websocket_api.const.ERR_NOT_FOUND,
f"Unable to find {self.item_id_key} {item_id}",
)
return
connection.send_result(msg["id"], self.storage_collection.data[item_id])
@callback
def ws_list_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List items."""
connection.send_result(
msg["id"],
{
"pipelines": self.storage_collection.async_items(),
"preferred_pipeline": self.storage_collection.async_get_preferred_item(),
},
)
async def ws_set_preferred_item(
self,
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Set the preferred item."""
try:
self.storage_collection.async_set_preferred_item(msg[self.item_id_key])
except ItemNotFound:
connection.send_error(
msg["id"], websocket_api.const.ERR_NOT_FOUND, "unknown item"
)
return
connection.send_result(msg["id"])
@dataclass
class PipelineData:
"""Store and debug data stored in hass.data."""
pipeline_runs: dict[str, LimitedSizeDict[str, PipelineRunDebug]]
pipeline_store: PipelineStorageCollection
pipeline_devices: set[str] = field(default_factory=set, init=False)
@dataclass
class PipelineRunDebug:
"""Debug data for a pipelinerun."""
events: list[PipelineEvent] = field(default_factory=list, init=False)
timestamp: str = field(
default_factory=lambda: dt_util.utcnow().isoformat(),
init=False,
)
@singleton(DOMAIN)
async def async_setup_pipeline_store(hass: HomeAssistant) -> PipelineData:
"""Set up the pipeline storage collection."""
pipeline_store = PipelineStorageCollection(
Store(hass, STORAGE_VERSION, STORAGE_KEY)
)
await pipeline_store.async_load()
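
    # Expose pipeline management and preferred-pipeline selection over websocket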
PipelineStorageCollectionWebsocket(
pipeline_store,
f"{DOMAIN}/pipeline",
"pipeline",
PIPELINE_FIELDS,
PIPELINE_FIELDS,
).async_setup(hass)
return PipelineData({}, pipeline_store)