Use HassKey for assist_pipeline singleton (#135875)
parent
24c50e0988
commit
19e5b091c5
|
@ -50,6 +50,7 @@ from homeassistant.util import (
|
|||
language as language_util,
|
||||
ulid as ulid_util,
|
||||
)
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
||||
|
||||
from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadSpeexEnhancer
|
||||
|
@ -91,6 +92,8 @@ ENGINE_LANGUAGE_PAIRS = (
|
|||
("tts_engine", "tts_language"),
|
||||
)
|
||||
|
||||
KEY_ASSIST_PIPELINE: HassKey[PipelineData] = HassKey(DOMAIN)
|
||||
|
||||
|
||||
def validate_language(data: dict[str, Any]) -> Any:
|
||||
"""Validate language settings."""
|
||||
|
@ -248,7 +251,7 @@ async def async_create_default_pipeline(
|
|||
The default pipeline will use the homeassistant conversation agent and the
|
||||
specified stt / tts engines.
|
||||
"""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_settings = _async_resolve_default_pipeline_settings(
|
||||
hass,
|
||||
|
@ -283,7 +286,7 @@ def _async_get_pipeline_from_conversation_entity(
|
|||
@callback
|
||||
def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
|
||||
"""Get a pipeline by id or the preferred pipeline."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||
|
||||
if pipeline_id is None:
|
||||
# A pipeline was not specified, use the preferred one
|
||||
|
@ -306,7 +309,7 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P
|
|||
@callback
|
||||
def async_get_pipelines(hass: HomeAssistant) -> list[Pipeline]:
|
||||
"""Get all pipelines."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||
|
||||
return list(pipeline_data.pipeline_store.data.values())
|
||||
|
||||
|
@ -329,7 +332,7 @@ async def async_update_pipeline(
|
|||
prefer_local_intents: bool | UndefinedType = UNDEFINED,
|
||||
) -> None:
|
||||
"""Update a pipeline."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||
|
||||
updates: dict[str, Any] = pipeline.to_json()
|
||||
updates.pop("id")
|
||||
|
@ -587,7 +590,7 @@ class PipelineRun:
|
|||
):
|
||||
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
|
||||
|
||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
|
||||
if self.pipeline.id not in pipeline_data.pipeline_debug:
|
||||
pipeline_data.pipeline_debug[self.pipeline.id] = LimitedSizeDict(
|
||||
size_limit=STORED_PIPELINE_RUNS
|
||||
|
@ -615,7 +618,7 @@ class PipelineRun:
|
|||
def process_event(self, event: PipelineEvent) -> None:
|
||||
"""Log an event and call listener."""
|
||||
self.event_callback(event)
|
||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
|
||||
if self.id not in pipeline_data.pipeline_debug[self.pipeline.id]:
|
||||
# This run has been evicted from the logged pipeline runs already
|
||||
return
|
||||
|
@ -650,7 +653,7 @@ class PipelineRun:
|
|||
)
|
||||
)
|
||||
|
||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
|
||||
pipeline_data.pipeline_runs.remove_run(self)
|
||||
|
||||
async def prepare_wake_word_detection(self) -> None:
|
||||
|
@ -1227,7 +1230,7 @@ class PipelineRun:
|
|||
return
|
||||
|
||||
# Forward to device audio capture
|
||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
|
||||
audio_queue = pipeline_data.device_audio_queues.get(self._device_id)
|
||||
if audio_queue is None:
|
||||
return
|
||||
|
@ -1884,7 +1887,7 @@ class PipelineStore(Store[SerializedPipelineStorageCollection]):
|
|||
return old_data
|
||||
|
||||
|
||||
@singleton(DOMAIN)
|
||||
@singleton(KEY_ASSIST_PIPELINE, async_=True)
|
||||
async def async_setup_pipeline_store(hass: HomeAssistant) -> PipelineData:
|
||||
"""Set up the pipeline storage collection."""
|
||||
pipeline_store = PipelineStorageCollection(
|
||||
|
|
|
@ -9,8 +9,8 @@ from homeassistant.const import EntityCategory, Platform
|
|||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
||||
|
||||
from .const import DOMAIN, OPTION_PREFERRED
|
||||
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
|
||||
from .const import OPTION_PREFERRED
|
||||
from .pipeline import KEY_ASSIST_PIPELINE, AssistDevice
|
||||
from .vad import VadSensitivity
|
||||
|
||||
|
||||
|
@ -30,7 +30,7 @@ def get_chosen_pipeline(
|
|||
if state is None or state.state == OPTION_PREFERRED:
|
||||
return None
|
||||
|
||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN].pipeline_store
|
||||
pipeline_store = hass.data[KEY_ASSIST_PIPELINE].pipeline_store
|
||||
return next(
|
||||
(item.id for item in pipeline_store.async_items() if item.name == state.state),
|
||||
None,
|
||||
|
@ -80,7 +80,7 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
|
|||
"""When entity is added to Home Assistant."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
self.async_on_remove(
|
||||
pipeline_store.async_add_change_set_listener(self._pipelines_updated)
|
||||
|
@ -116,9 +116,7 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
|
|||
@callback
|
||||
def _update_options(self) -> None:
|
||||
"""Handle pipeline update."""
|
||||
pipeline_store: PipelineStorageCollection = self.hass.data[
|
||||
DOMAIN
|
||||
].pipeline_store
|
||||
pipeline_store = self.hass.data[KEY_ASSIST_PIPELINE].pipeline_store
|
||||
options = [OPTION_PREFERRED]
|
||||
options.extend(sorted(item.name for item in pipeline_store.async_items()))
|
||||
self._attr_options = options
|
||||
|
|
|
@ -21,7 +21,6 @@ from homeassistant.util import language as language_util
|
|||
from .const import (
|
||||
DEFAULT_PIPELINE_TIMEOUT,
|
||||
DEFAULT_WAKE_WORD_TIMEOUT,
|
||||
DOMAIN,
|
||||
EVENT_RECORDING,
|
||||
SAMPLE_CHANNELS,
|
||||
SAMPLE_RATE,
|
||||
|
@ -29,9 +28,9 @@ from .const import (
|
|||
)
|
||||
from .error import PipelineNotFound
|
||||
from .pipeline import (
|
||||
KEY_ASSIST_PIPELINE,
|
||||
AudioSettings,
|
||||
DeviceAudioQueue,
|
||||
PipelineData,
|
||||
PipelineError,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
|
@ -283,7 +282,7 @@ def websocket_list_runs(
|
|||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""List pipeline runs for which debug data is available."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||
pipeline_id = msg["pipeline_id"]
|
||||
|
||||
if pipeline_id not in pipeline_data.pipeline_debug:
|
||||
|
@ -319,7 +318,7 @@ def websocket_list_devices(
|
|||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""List assist devices."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||
ent_reg = er.async_get(hass)
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
|
@ -350,7 +349,7 @@ def websocket_get_run(
|
|||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Get debug data for a pipeline run."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||
pipeline_id = msg["pipeline_id"]
|
||||
pipeline_run_id = msg["pipeline_run_id"]
|
||||
|
||||
|
@ -455,7 +454,7 @@ async def websocket_device_capture(
|
|||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Capture raw audio from a satellite device and forward to client."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||
device_id = msg["device_id"]
|
||||
|
||||
# Number of seconds to record audio in wall clock time
|
||||
|
|
Loading…
Reference in New Issue