Add a pipeline store to voice_assistant (#90844)
* Add a pipeline store to voice_assistant * Improve error handling * Improve test coverage * Improve test coverage * Use StorageCollectionWebsocket * Correct rebasepull/90509/head
parent
b2bcdf7c19
commit
b3b83b7bb2
|
@ -17,6 +17,7 @@ from .pipeline import (
|
|||
PipelineRun,
|
||||
PipelineStage,
|
||||
async_get_pipeline,
|
||||
async_setup_pipeline_store,
|
||||
)
|
||||
from .websocket_api import async_register_websocket_api
|
||||
|
||||
|
@ -31,7 +32,7 @@ __all__ = (
|
|||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up Voice Assistant integration."""
|
||||
hass.data[DOMAIN] = {}
|
||||
await async_setup_pipeline_store(hass)
|
||||
async_register_websocket_api(hass)
|
||||
|
||||
return True
|
||||
|
@ -61,7 +62,7 @@ async def async_pipeline_from_audio_stream(
|
|||
if context is None:
|
||||
context = Context()
|
||||
|
||||
pipeline = async_get_pipeline(
|
||||
pipeline = await async_get_pipeline(
|
||||
hass,
|
||||
pipeline_id=pipeline_id,
|
||||
language=language,
|
||||
|
|
|
@ -1,3 +1,2 @@
|
|||
"""Constants for the Voice Assistant integration."""
|
||||
DOMAIN = "voice_assistant"
|
||||
DEFAULT_PIPELINE = "default"
|
||||
|
|
|
@ -7,13 +7,20 @@ from dataclasses import asdict, dataclass, field
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.backports.enum import StrEnum
|
||||
from homeassistant.components import conversation, media_source, stt, tts
|
||||
from homeassistant.components.tts.media_source import (
|
||||
generate_media_source_id as tts_generate_media_source_id,
|
||||
)
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.util.dt import utcnow
|
||||
from homeassistant.helpers.collection import (
|
||||
StorageCollection,
|
||||
StorageCollectionWebsocket,
|
||||
)
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
||||
|
||||
from .const import DOMAIN
|
||||
from .error import (
|
||||
|
@ -25,23 +32,39 @@ from .error import (
|
|||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
STORAGE_KEY = f"{DOMAIN}.pipelines"
|
||||
STORAGE_VERSION = 1
|
||||
|
||||
@callback
|
||||
def async_get_pipeline(
|
||||
STORAGE_FIELDS = {
|
||||
vol.Required("conversation_engine"): str,
|
||||
vol.Required("language"): str,
|
||||
vol.Required("name"): str,
|
||||
vol.Required("stt_engine"): str,
|
||||
vol.Required("tts_engine"): str,
|
||||
}
|
||||
|
||||
SAVE_DELAY = 10
|
||||
|
||||
|
||||
async def async_get_pipeline(
|
||||
hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
|
||||
) -> Pipeline | None:
|
||||
"""Get a pipeline by id or create one for a language."""
|
||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
|
||||
if pipeline_id is not None:
|
||||
return hass.data[DOMAIN].get(pipeline_id)
|
||||
return pipeline_store.data.get(pipeline_id)
|
||||
|
||||
# Construct a pipeline for the required/configured language
|
||||
language = language or hass.config.language
|
||||
return Pipeline(
|
||||
name=language,
|
||||
language=language,
|
||||
stt_engine=None, # first engine
|
||||
conversation_engine=None, # first agent
|
||||
tts_engine=None, # first engine
|
||||
return await pipeline_store.async_create_item(
|
||||
{
|
||||
"name": language,
|
||||
"language": language,
|
||||
"stt_engine": None, # first engine
|
||||
"conversation_engine": None, # first agent
|
||||
"tts_engine": None, # first engine
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
@ -65,7 +88,7 @@ class PipelineEvent:
|
|||
|
||||
type: PipelineEventType
|
||||
data: dict[str, Any] | None = None
|
||||
timestamp: str = field(default_factory=lambda: utcnow().isoformat())
|
||||
timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat())
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return a dict representation of the event."""
|
||||
|
@ -79,16 +102,29 @@ class PipelineEvent:
|
|||
PipelineEventCallback = Callable[[PipelineEvent], None]
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class Pipeline:
|
||||
"""A voice assistant pipeline."""
|
||||
|
||||
name: str
|
||||
language: str | None
|
||||
stt_engine: str | None
|
||||
conversation_engine: str | None
|
||||
language: str | None
|
||||
name: str
|
||||
stt_engine: str | None
|
||||
tts_engine: str | None
|
||||
|
||||
id: str = field(default_factory=ulid_util.ulid)
|
||||
|
||||
def to_json(self) -> dict[str, Any]:
|
||||
"""Return a JSON serializable representation for storage."""
|
||||
return {
|
||||
"conversation_engine": self.conversation_engine,
|
||||
"id": self.id,
|
||||
"language": self.language,
|
||||
"name": self.name,
|
||||
"stt_engine": self.stt_engine,
|
||||
"tts_engine": self.tts_engine,
|
||||
}
|
||||
|
||||
|
||||
class PipelineStage(StrEnum):
|
||||
"""Stages of a pipeline."""
|
||||
|
@ -478,3 +514,47 @@ class PipelineInput:
|
|||
|
||||
if prepare_tasks:
|
||||
await asyncio.gather(*prepare_tasks)
|
||||
|
||||
|
||||
class PipelineStorageCollection(StorageCollection[Pipeline]):
|
||||
"""Pipeline storage collection."""
|
||||
|
||||
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
|
||||
|
||||
async def _process_create_data(self, data: dict) -> dict:
|
||||
"""Validate the config is valid."""
|
||||
# We don't need to validate, the WS API has already validated
|
||||
return data
|
||||
|
||||
@callback
|
||||
def _get_suggested_id(self, info: dict) -> str:
|
||||
"""Suggest an ID based on the config."""
|
||||
return ulid_util.ulid()
|
||||
|
||||
async def _update_data(self, item: Pipeline, update_data: dict) -> Pipeline:
|
||||
"""Return a new updated item."""
|
||||
return Pipeline(id=item.id, **update_data)
|
||||
|
||||
def _create_item(self, item_id: str, data: dict) -> Pipeline:
|
||||
"""Create an item from validated config."""
|
||||
return Pipeline(id=item_id, **data)
|
||||
|
||||
def _deserialize_item(self, data: dict) -> Pipeline:
|
||||
"""Create an item from its serialized representation."""
|
||||
return Pipeline(**data)
|
||||
|
||||
def _serialize_item(self, item_id: str, item: Pipeline) -> dict:
|
||||
"""Return the serialized representation of an item."""
|
||||
return item.to_json()
|
||||
|
||||
|
||||
async def async_setup_pipeline_store(hass):
|
||||
"""Set up the pipeline storage collection."""
|
||||
pipeline_store = PipelineStorageCollection(
|
||||
Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
)
|
||||
await pipeline_store.async_load()
|
||||
StorageCollectionWebsocket(
|
||||
pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS
|
||||
).async_setup(hass)
|
||||
hass.data[DOMAIN] = pipeline_store
|
||||
|
|
|
@ -61,7 +61,7 @@ async def websocket_run(
|
|||
language = "en-US"
|
||||
|
||||
pipeline_id = msg.get("pipeline")
|
||||
pipeline = async_get_pipeline(
|
||||
pipeline = await async_get_pipeline(
|
||||
hass,
|
||||
pipeline_id=pipeline_id,
|
||||
language=language,
|
||||
|
|
|
@ -117,7 +117,7 @@ async def mock_stt_provider(hass) -> MockSttProvider:
|
|||
return MockSttProvider(hass, _TRANSCRIPT)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@pytest.fixture
|
||||
async def init_components(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
|
|
|
@ -7,7 +7,7 @@ from homeassistant.core import HomeAssistant
|
|||
|
||||
|
||||
async def test_pipeline_from_audio_stream(
|
||||
hass: HomeAssistant, mock_stt_provider, snapshot: SnapshotAssertion
|
||||
hass: HomeAssistant, mock_stt_provider, init_components, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
"""Test creating a pipeline from an audio stream."""
|
||||
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
"""Websocket tests for Voice Assistant integration."""
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.voice_assistant.const import DOMAIN
|
||||
from homeassistant.components.voice_assistant.pipeline import (
|
||||
STORAGE_KEY,
|
||||
STORAGE_VERSION,
|
||||
PipelineStorageCollection,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import flush_store
|
||||
|
||||
|
||||
async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
|
||||
"""Make sure that we can load/save data correctly."""
|
||||
|
||||
pipelines = [
|
||||
{
|
||||
"conversation_engine": "conversation_engine_1",
|
||||
"language": "language_1",
|
||||
"name": "name_1",
|
||||
"stt_engine": "stt_engine_1",
|
||||
"tts_engine": "tts_engine_1",
|
||||
},
|
||||
{
|
||||
"conversation_engine": "conversation_engine_2",
|
||||
"language": "language_2",
|
||||
"name": "name_2",
|
||||
"stt_engine": "stt_engine_2",
|
||||
"tts_engine": "tts_engine_2",
|
||||
},
|
||||
{
|
||||
"conversation_engine": "conversation_engine_3",
|
||||
"language": "language_3",
|
||||
"name": "name_3",
|
||||
"stt_engine": "stt_engine_3",
|
||||
"tts_engine": "tts_engine_3",
|
||||
},
|
||||
]
|
||||
pipeline_ids = []
|
||||
|
||||
store1: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
for pipeline in pipelines:
|
||||
pipeline_ids.append((await store1.async_create_item(pipeline)).id)
|
||||
assert len(store1.data) == 3
|
||||
|
||||
await store1.async_delete_item(pipeline_ids[1])
|
||||
assert len(store1.data) == 2
|
||||
|
||||
store2 = PipelineStorageCollection(Store(hass, STORAGE_VERSION, STORAGE_KEY))
|
||||
await flush_store(store1.store)
|
||||
await store2.async_load()
|
||||
|
||||
assert len(store2.data) == 2
|
||||
|
||||
assert store1.data is not store2.data
|
||||
assert store1.data == store2.data
|
||||
|
||||
|
||||
async def test_loading_datasets_from_storage(
|
||||
hass: HomeAssistant, hass_storage: dict[str, Any]
|
||||
) -> None:
|
||||
"""Test loading stored datasets on start."""
|
||||
hass_storage[STORAGE_KEY] = {
|
||||
"version": 1,
|
||||
"minor_version": 1,
|
||||
"key": "voice_assistant.pipelines",
|
||||
"data": {
|
||||
"items": [
|
||||
{
|
||||
"conversation_engine": "conversation_engine_1",
|
||||
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||
"language": "language_1",
|
||||
"name": "name_1",
|
||||
"stt_engine": "stt_engine_1",
|
||||
"tts_engine": "tts_engine_1",
|
||||
},
|
||||
{
|
||||
"conversation_engine": "conversation_engine_2",
|
||||
"id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
|
||||
"language": "language_2",
|
||||
"name": "name_2",
|
||||
"stt_engine": "stt_engine_2",
|
||||
"tts_engine": "tts_engine_2",
|
||||
},
|
||||
{
|
||||
"conversation_engine": "conversation_engine_3",
|
||||
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
|
||||
"language": "language_3",
|
||||
"name": "name_3",
|
||||
"stt_engine": "stt_engine_3",
|
||||
"tts_engine": "tts_engine_3",
|
||||
},
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
assert await async_setup_component(hass, "voice_assistant", {})
|
||||
|
||||
store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
assert len(store.data) == 3
|
|
@ -1,9 +1,14 @@
|
|||
"""Websocket tests for Voice Assistant integration."""
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components.voice_assistant.const import DOMAIN
|
||||
from homeassistant.components.voice_assistant.pipeline import (
|
||||
Pipeline,
|
||||
PipelineStorageCollection,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
@ -12,6 +17,7 @@ from tests.typing import WebSocketGenerator
|
|||
async def test_text_only_pipeline(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with text input (no STT/TTS)."""
|
||||
|
@ -51,7 +57,10 @@ async def test_text_only_pipeline(
|
|||
|
||||
|
||||
async def test_audio_pipeline(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, snapshot: SnapshotAssertion
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with audio input/output."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
@ -271,6 +280,7 @@ async def test_audio_pipeline_timeout(
|
|||
async def test_stt_provider_missing(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with a non-existent STT provider."""
|
||||
|
@ -297,6 +307,7 @@ async def test_stt_provider_missing(
|
|||
async def test_stt_stream_failed(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test events from a pipeline run with a non-existent STT provider."""
|
||||
|
@ -398,3 +409,205 @@ async def test_invalid_stage_order(
|
|||
# result
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
|
||||
|
||||
async def test_add_pipeline(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
) -> None:
|
||||
"""Test we can add a pipeline."""
|
||||
client = await hass_ws_client(hass)
|
||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "voice_assistant/pipeline/create",
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"tts_engine": "test_tts_engine",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"id": ANY,
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"tts_engine": "test_tts_engine",
|
||||
}
|
||||
|
||||
assert len(pipeline_store.data) == 1
|
||||
pipeline = pipeline_store.data[msg["result"]["id"]]
|
||||
assert pipeline == Pipeline(
|
||||
conversation_engine="test_conversation_engine",
|
||||
id=msg["result"]["id"],
|
||||
language="test_language",
|
||||
name="test_name",
|
||||
stt_engine="test_stt_engine",
|
||||
tts_engine="test_tts_engine",
|
||||
)
|
||||
|
||||
|
||||
async def test_delete_pipeline(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
) -> None:
|
||||
"""Test we can delete a pipeline."""
|
||||
client = await hass_ws_client(hass)
|
||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "voice_assistant/pipeline/create",
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"tts_engine": "test_tts_engine",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert len(pipeline_store.data) == 1
|
||||
|
||||
pipeline_id = msg["result"]["id"]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "voice_assistant/pipeline/delete",
|
||||
"pipeline_id": pipeline_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert len(pipeline_store.data) == 0
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "voice_assistant/pipeline/delete",
|
||||
"pipeline_id": pipeline_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "not_found",
|
||||
"message": f"Unable to find pipeline_id {pipeline_id}",
|
||||
}
|
||||
|
||||
|
||||
async def test_list_pipelines(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
) -> None:
|
||||
"""Test we can list pipelines."""
|
||||
client = await hass_ws_client(hass)
|
||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
|
||||
await client.send_json_auto_id({"type": "voice_assistant/pipeline/list"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == []
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "voice_assistant/pipeline/create",
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"tts_engine": "test_tts_engine",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert len(pipeline_store.data) == 1
|
||||
|
||||
await client.send_json_auto_id({"type": "voice_assistant/pipeline/list"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == [
|
||||
{
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"id": ANY,
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"tts_engine": "test_tts_engine",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
async def test_update_pipeline(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
) -> None:
|
||||
"""Test we can list pipelines."""
|
||||
client = await hass_ws_client(hass)
|
||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "voice_assistant/pipeline/update",
|
||||
"conversation_engine": "new_conversation_engine",
|
||||
"language": "new_language",
|
||||
"name": "new_name",
|
||||
"pipeline_id": "no_such_pipeline",
|
||||
"stt_engine": "new_stt_engine",
|
||||
"tts_engine": "new_tts_engine",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "not_found",
|
||||
"message": "Unable to find pipeline_id no_such_pipeline",
|
||||
}
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "voice_assistant/pipeline/create",
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"tts_engine": "test_tts_engine",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
pipeline_id = msg["result"]["id"]
|
||||
assert len(pipeline_store.data) == 1
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "voice_assistant/pipeline/update",
|
||||
"conversation_engine": "new_conversation_engine",
|
||||
"language": "new_language",
|
||||
"name": "new_name",
|
||||
"pipeline_id": pipeline_id,
|
||||
"stt_engine": "new_stt_engine",
|
||||
"tts_engine": "new_tts_engine",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"conversation_engine": "new_conversation_engine",
|
||||
"language": "new_language",
|
||||
"name": "new_name",
|
||||
"id": pipeline_id,
|
||||
"stt_engine": "new_stt_engine",
|
||||
"tts_engine": "new_tts_engine",
|
||||
}
|
||||
|
||||
assert len(pipeline_store.data) == 1
|
||||
pipeline = pipeline_store.data[pipeline_id]
|
||||
assert pipeline == Pipeline(
|
||||
conversation_engine="new_conversation_engine",
|
||||
id=pipeline_id,
|
||||
language="new_language",
|
||||
name="new_name",
|
||||
stt_engine="new_stt_engine",
|
||||
tts_engine="new_tts_engine",
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue