Use HassKey for assist_pipeline singleton (#135875)

pull/135939/head
Marc Mueller 2025-01-18 20:52:13 +01:00 committed by GitHub
parent 24c50e0988
commit 19e5b091c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 22 deletions

View File

@ -50,6 +50,7 @@ from homeassistant.util import (
language as language_util,
ulid as ulid_util,
)
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.limited_size_dict import LimitedSizeDict
from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadSpeexEnhancer
@ -91,6 +92,8 @@ ENGINE_LANGUAGE_PAIRS = (
("tts_engine", "tts_language"),
)
KEY_ASSIST_PIPELINE: HassKey[PipelineData] = HassKey(DOMAIN)
def validate_language(data: dict[str, Any]) -> Any:
"""Validate language settings."""
@ -248,7 +251,7 @@ async def async_create_default_pipeline(
The default pipeline will use the homeassistant conversation agent and the
specified stt / tts engines.
"""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
pipeline_store = pipeline_data.pipeline_store
pipeline_settings = _async_resolve_default_pipeline_settings(
hass,
@ -283,7 +286,7 @@ def _async_get_pipeline_from_conversation_entity(
@callback
def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
"""Get a pipeline by id or the preferred pipeline."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
if pipeline_id is None:
# A pipeline was not specified, use the preferred one
@ -306,7 +309,7 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P
@callback
def async_get_pipelines(hass: HomeAssistant) -> list[Pipeline]:
"""Get all pipelines."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
return list(pipeline_data.pipeline_store.data.values())
@ -329,7 +332,7 @@ async def async_update_pipeline(
prefer_local_intents: bool | UndefinedType = UNDEFINED,
) -> None:
"""Update a pipeline."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
updates: dict[str, Any] = pipeline.to_json()
updates.pop("id")
@ -587,7 +590,7 @@ class PipelineRun:
):
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
pipeline_data: PipelineData = self.hass.data[DOMAIN]
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
if self.pipeline.id not in pipeline_data.pipeline_debug:
pipeline_data.pipeline_debug[self.pipeline.id] = LimitedSizeDict(
size_limit=STORED_PIPELINE_RUNS
@ -615,7 +618,7 @@ class PipelineRun:
def process_event(self, event: PipelineEvent) -> None:
"""Log an event and call listener."""
self.event_callback(event)
pipeline_data: PipelineData = self.hass.data[DOMAIN]
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
if self.id not in pipeline_data.pipeline_debug[self.pipeline.id]:
# This run has been evicted from the logged pipeline runs already
return
@ -650,7 +653,7 @@ class PipelineRun:
)
)
pipeline_data: PipelineData = self.hass.data[DOMAIN]
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
pipeline_data.pipeline_runs.remove_run(self)
async def prepare_wake_word_detection(self) -> None:
@ -1227,7 +1230,7 @@ class PipelineRun:
return
# Forward to device audio capture
pipeline_data: PipelineData = self.hass.data[DOMAIN]
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
audio_queue = pipeline_data.device_audio_queues.get(self._device_id)
if audio_queue is None:
return
@ -1884,7 +1887,7 @@ class PipelineStore(Store[SerializedPipelineStorageCollection]):
return old_data
@singleton(DOMAIN)
@singleton(KEY_ASSIST_PIPELINE, async_=True)
async def async_setup_pipeline_store(hass: HomeAssistant) -> PipelineData:
"""Set up the pipeline storage collection."""
pipeline_store = PipelineStorageCollection(

View File

@ -9,8 +9,8 @@ from homeassistant.const import EntityCategory, Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import collection, entity_registry as er, restore_state
from .const import DOMAIN, OPTION_PREFERRED
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
from .const import OPTION_PREFERRED
from .pipeline import KEY_ASSIST_PIPELINE, AssistDevice
from .vad import VadSensitivity
@ -30,7 +30,7 @@ def get_chosen_pipeline(
if state is None or state.state == OPTION_PREFERRED:
return None
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN].pipeline_store
pipeline_store = hass.data[KEY_ASSIST_PIPELINE].pipeline_store
return next(
(item.id for item in pipeline_store.async_items() if item.name == state.state),
None,
@ -80,7 +80,7 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
"""When entity is added to Home Assistant."""
await super().async_added_to_hass()
pipeline_data: PipelineData = self.hass.data[DOMAIN]
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
pipeline_store = pipeline_data.pipeline_store
self.async_on_remove(
pipeline_store.async_add_change_set_listener(self._pipelines_updated)
@ -116,9 +116,7 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
@callback
def _update_options(self) -> None:
"""Handle pipeline update."""
pipeline_store: PipelineStorageCollection = self.hass.data[
DOMAIN
].pipeline_store
pipeline_store = self.hass.data[KEY_ASSIST_PIPELINE].pipeline_store
options = [OPTION_PREFERRED]
options.extend(sorted(item.name for item in pipeline_store.async_items()))
self._attr_options = options

View File

@ -21,7 +21,6 @@ from homeassistant.util import language as language_util
from .const import (
DEFAULT_PIPELINE_TIMEOUT,
DEFAULT_WAKE_WORD_TIMEOUT,
DOMAIN,
EVENT_RECORDING,
SAMPLE_CHANNELS,
SAMPLE_RATE,
@ -29,9 +28,9 @@ from .const import (
)
from .error import PipelineNotFound
from .pipeline import (
KEY_ASSIST_PIPELINE,
AudioSettings,
DeviceAudioQueue,
PipelineData,
PipelineError,
PipelineEvent,
PipelineEventType,
@ -283,7 +282,7 @@ def websocket_list_runs(
msg: dict[str, Any],
) -> None:
"""List pipeline runs for which debug data is available."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
pipeline_id = msg["pipeline_id"]
if pipeline_id not in pipeline_data.pipeline_debug:
@ -319,7 +318,7 @@ def websocket_list_devices(
msg: dict[str, Any],
) -> None:
"""List assist devices."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
ent_reg = er.async_get(hass)
connection.send_result(
msg["id"],
@ -350,7 +349,7 @@ def websocket_get_run(
msg: dict[str, Any],
) -> None:
"""Get debug data for a pipeline run."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
pipeline_id = msg["pipeline_id"]
pipeline_run_id = msg["pipeline_run_id"]
@ -455,7 +454,7 @@ async def websocket_device_capture(
msg: dict[str, Any],
) -> None:
"""Capture raw audio from a satellite device and forward to client."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
device_id = msg["device_id"]
# Number of seconds to record audio in wall clock time