Add a pipeline store to voice_assistant (#90844)

* Add a pipeline store to voice_assistant

* Improve error handling

* Improve test coverage

* Improve test coverage

* Use StorageCollectionWebsocket

* Correct rebase
pull/90509/head
Erik Montnemery 2023-04-06 18:55:16 +02:00 committed by GitHub
parent b2bcdf7c19
commit b3b83b7bb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 420 additions and 23 deletions

View File

@ -17,6 +17,7 @@ from .pipeline import (
PipelineRun,
PipelineStage,
async_get_pipeline,
async_setup_pipeline_store,
)
from .websocket_api import async_register_websocket_api
@ -31,7 +32,7 @@ __all__ = (
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up Voice Assistant integration."""
hass.data[DOMAIN] = {}
await async_setup_pipeline_store(hass)
async_register_websocket_api(hass)
return True
@ -61,7 +62,7 @@ async def async_pipeline_from_audio_stream(
if context is None:
context = Context()
pipeline = async_get_pipeline(
pipeline = await async_get_pipeline(
hass,
pipeline_id=pipeline_id,
language=language,

View File

@ -1,3 +1,2 @@
"""Constants for the Voice Assistant integration."""
DOMAIN = "voice_assistant"
DEFAULT_PIPELINE = "default"

View File

@ -7,13 +7,20 @@ from dataclasses import asdict, dataclass, field
import logging
from typing import Any
import voluptuous as vol
from homeassistant.backports.enum import StrEnum
from homeassistant.components import conversation, media_source, stt, tts
from homeassistant.components.tts.media_source import (
generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.util.dt import utcnow
from homeassistant.helpers.collection import (
StorageCollection,
StorageCollectionWebsocket,
)
from homeassistant.helpers.storage import Store
from homeassistant.util import dt as dt_util, ulid as ulid_util
from .const import DOMAIN
from .error import (
@ -25,23 +32,39 @@ from .error import (
_LOGGER = logging.getLogger(__name__)
STORAGE_KEY = f"{DOMAIN}.pipelines"
STORAGE_VERSION = 1
@callback
def async_get_pipeline(
STORAGE_FIELDS = {
vol.Required("conversation_engine"): str,
vol.Required("language"): str,
vol.Required("name"): str,
vol.Required("stt_engine"): str,
vol.Required("tts_engine"): str,
}
SAVE_DELAY = 10
async def async_get_pipeline(
hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
) -> Pipeline | None:
"""Get a pipeline by id or create one for a language."""
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
if pipeline_id is not None:
return hass.data[DOMAIN].get(pipeline_id)
return pipeline_store.data.get(pipeline_id)
# Construct a pipeline for the required/configured language
language = language or hass.config.language
return Pipeline(
name=language,
language=language,
stt_engine=None, # first engine
conversation_engine=None, # first agent
tts_engine=None, # first engine
return await pipeline_store.async_create_item(
{
"name": language,
"language": language,
"stt_engine": None, # first engine
"conversation_engine": None, # first agent
"tts_engine": None, # first engine
}
)
@ -65,7 +88,7 @@ class PipelineEvent:
type: PipelineEventType
data: dict[str, Any] | None = None
timestamp: str = field(default_factory=lambda: utcnow().isoformat())
timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat())
def as_dict(self) -> dict[str, Any]:
"""Return a dict representation of the event."""
@ -79,16 +102,29 @@ class PipelineEvent:
PipelineEventCallback = Callable[[PipelineEvent], None]
@dataclass
@dataclass(frozen=True)
class Pipeline:
"""A voice assistant pipeline."""
name: str
language: str | None
stt_engine: str | None
conversation_engine: str | None
language: str | None
name: str
stt_engine: str | None
tts_engine: str | None
id: str = field(default_factory=ulid_util.ulid)
def to_json(self) -> dict[str, Any]:
"""Return a JSON serializable representation for storage."""
return {
"conversation_engine": self.conversation_engine,
"id": self.id,
"language": self.language,
"name": self.name,
"stt_engine": self.stt_engine,
"tts_engine": self.tts_engine,
}
class PipelineStage(StrEnum):
"""Stages of a pipeline."""
@ -478,3 +514,47 @@ class PipelineInput:
if prepare_tasks:
await asyncio.gather(*prepare_tasks)
class PipelineStorageCollection(StorageCollection[Pipeline]):
"""Pipeline storage collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
async def _process_create_data(self, data: dict) -> dict:
"""Validate the config is valid."""
# We don't need to validate, the WS API has already validated
return data
@callback
def _get_suggested_id(self, info: dict) -> str:
"""Suggest an ID based on the config."""
return ulid_util.ulid()
async def _update_data(self, item: Pipeline, update_data: dict) -> Pipeline:
"""Return a new updated item."""
return Pipeline(id=item.id, **update_data)
def _create_item(self, item_id: str, data: dict) -> Pipeline:
"""Create an item from validated config."""
return Pipeline(id=item_id, **data)
def _deserialize_item(self, data: dict) -> Pipeline:
"""Create an item from its serialized representation."""
return Pipeline(**data)
def _serialize_item(self, item_id: str, item: Pipeline) -> dict:
"""Return the serialized representation of an item."""
return item.to_json()
async def async_setup_pipeline_store(hass):
"""Set up the pipeline storage collection."""
pipeline_store = PipelineStorageCollection(
Store(hass, STORAGE_VERSION, STORAGE_KEY)
)
await pipeline_store.async_load()
StorageCollectionWebsocket(
pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS
).async_setup(hass)
hass.data[DOMAIN] = pipeline_store

View File

@ -61,7 +61,7 @@ async def websocket_run(
language = "en-US"
pipeline_id = msg.get("pipeline")
pipeline = async_get_pipeline(
pipeline = await async_get_pipeline(
hass,
pipeline_id=pipeline_id,
language=language,

View File

@ -117,7 +117,7 @@ async def mock_stt_provider(hass) -> MockSttProvider:
return MockSttProvider(hass, _TRANSCRIPT)
@pytest.fixture(autouse=True)
@pytest.fixture
async def init_components(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,

View File

@ -7,7 +7,7 @@ from homeassistant.core import HomeAssistant
async def test_pipeline_from_audio_stream(
hass: HomeAssistant, mock_stt_provider, snapshot: SnapshotAssertion
hass: HomeAssistant, mock_stt_provider, init_components, snapshot: SnapshotAssertion
) -> None:
"""Test creating a pipeline from an audio stream."""

View File

@ -0,0 +1,104 @@
"""Websocket tests for Voice Assistant integration."""
from typing import Any
from homeassistant.components.voice_assistant.const import DOMAIN
from homeassistant.components.voice_assistant.pipeline import (
STORAGE_KEY,
STORAGE_VERSION,
PipelineStorageCollection,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers.storage import Store
from homeassistant.setup import async_setup_component
from tests.common import flush_store
async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
"""Make sure that we can load/save data correctly."""
pipelines = [
{
"conversation_engine": "conversation_engine_1",
"language": "language_1",
"name": "name_1",
"stt_engine": "stt_engine_1",
"tts_engine": "tts_engine_1",
},
{
"conversation_engine": "conversation_engine_2",
"language": "language_2",
"name": "name_2",
"stt_engine": "stt_engine_2",
"tts_engine": "tts_engine_2",
},
{
"conversation_engine": "conversation_engine_3",
"language": "language_3",
"name": "name_3",
"stt_engine": "stt_engine_3",
"tts_engine": "tts_engine_3",
},
]
pipeline_ids = []
store1: PipelineStorageCollection = hass.data[DOMAIN]
for pipeline in pipelines:
pipeline_ids.append((await store1.async_create_item(pipeline)).id)
assert len(store1.data) == 3
await store1.async_delete_item(pipeline_ids[1])
assert len(store1.data) == 2
store2 = PipelineStorageCollection(Store(hass, STORAGE_VERSION, STORAGE_KEY))
await flush_store(store1.store)
await store2.async_load()
assert len(store2.data) == 2
assert store1.data is not store2.data
assert store1.data == store2.data
async def test_loading_datasets_from_storage(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test loading stored datasets on start."""
hass_storage[STORAGE_KEY] = {
"version": 1,
"minor_version": 1,
"key": "voice_assistant.pipelines",
"data": {
"items": [
{
"conversation_engine": "conversation_engine_1",
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
"language": "language_1",
"name": "name_1",
"stt_engine": "stt_engine_1",
"tts_engine": "tts_engine_1",
},
{
"conversation_engine": "conversation_engine_2",
"id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
"language": "language_2",
"name": "name_2",
"stt_engine": "stt_engine_2",
"tts_engine": "tts_engine_2",
},
{
"conversation_engine": "conversation_engine_3",
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
"language": "language_3",
"name": "name_3",
"stt_engine": "stt_engine_3",
"tts_engine": "tts_engine_3",
},
]
},
}
assert await async_setup_component(hass, "voice_assistant", {})
store: PipelineStorageCollection = hass.data[DOMAIN]
assert len(store.data) == 3

View File

@ -1,9 +1,14 @@
"""Websocket tests for Voice Assistant integration."""
import asyncio
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.voice_assistant.const import DOMAIN
from homeassistant.components.voice_assistant.pipeline import (
Pipeline,
PipelineStorageCollection,
)
from homeassistant.core import HomeAssistant
from tests.typing import WebSocketGenerator
@ -12,6 +17,7 @@ from tests.typing import WebSocketGenerator
async def test_text_only_pipeline(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with text input (no STT/TTS)."""
@ -51,7 +57,10 @@ async def test_text_only_pipeline(
async def test_audio_pipeline(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, snapshot: SnapshotAssertion
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with audio input/output."""
client = await hass_ws_client(hass)
@ -271,6 +280,7 @@ async def test_audio_pipeline_timeout(
async def test_stt_provider_missing(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with a non-existent STT provider."""
@ -297,6 +307,7 @@ async def test_stt_provider_missing(
async def test_stt_stream_failed(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test events from a pipeline run with a non-existent STT provider."""
@ -398,3 +409,205 @@ async def test_invalid_stage_order(
# result
msg = await client.receive_json()
assert not msg["success"]
async def test_add_pipeline(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
"""Test we can add a pipeline."""
client = await hass_ws_client(hass)
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
await client.send_json_auto_id(
{
"type": "voice_assistant/pipeline/create",
"conversation_engine": "test_conversation_engine",
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"tts_engine": "test_tts_engine",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"conversation_engine": "test_conversation_engine",
"id": ANY,
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"tts_engine": "test_tts_engine",
}
assert len(pipeline_store.data) == 1
pipeline = pipeline_store.data[msg["result"]["id"]]
assert pipeline == Pipeline(
conversation_engine="test_conversation_engine",
id=msg["result"]["id"],
language="test_language",
name="test_name",
stt_engine="test_stt_engine",
tts_engine="test_tts_engine",
)
async def test_delete_pipeline(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
"""Test we can delete a pipeline."""
client = await hass_ws_client(hass)
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
await client.send_json_auto_id(
{
"type": "voice_assistant/pipeline/create",
"conversation_engine": "test_conversation_engine",
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"tts_engine": "test_tts_engine",
}
)
msg = await client.receive_json()
assert msg["success"]
assert len(pipeline_store.data) == 1
pipeline_id = msg["result"]["id"]
await client.send_json_auto_id(
{
"type": "voice_assistant/pipeline/delete",
"pipeline_id": pipeline_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert len(pipeline_store.data) == 0
await client.send_json_auto_id(
{
"type": "voice_assistant/pipeline/delete",
"pipeline_id": pipeline_id,
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "not_found",
"message": f"Unable to find pipeline_id {pipeline_id}",
}
async def test_list_pipelines(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
"""Test we can list pipelines."""
client = await hass_ws_client(hass)
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
await client.send_json_auto_id({"type": "voice_assistant/pipeline/list"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == []
await client.send_json_auto_id(
{
"type": "voice_assistant/pipeline/create",
"conversation_engine": "test_conversation_engine",
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"tts_engine": "test_tts_engine",
}
)
msg = await client.receive_json()
assert msg["success"]
assert len(pipeline_store.data) == 1
await client.send_json_auto_id({"type": "voice_assistant/pipeline/list"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == [
{
"conversation_engine": "test_conversation_engine",
"id": ANY,
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"tts_engine": "test_tts_engine",
}
]
async def test_update_pipeline(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
"""Test we can list pipelines."""
client = await hass_ws_client(hass)
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
await client.send_json_auto_id(
{
"type": "voice_assistant/pipeline/update",
"conversation_engine": "new_conversation_engine",
"language": "new_language",
"name": "new_name",
"pipeline_id": "no_such_pipeline",
"stt_engine": "new_stt_engine",
"tts_engine": "new_tts_engine",
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "not_found",
"message": "Unable to find pipeline_id no_such_pipeline",
}
await client.send_json_auto_id(
{
"type": "voice_assistant/pipeline/create",
"conversation_engine": "test_conversation_engine",
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"tts_engine": "test_tts_engine",
}
)
msg = await client.receive_json()
assert msg["success"]
pipeline_id = msg["result"]["id"]
assert len(pipeline_store.data) == 1
await client.send_json_auto_id(
{
"type": "voice_assistant/pipeline/update",
"conversation_engine": "new_conversation_engine",
"language": "new_language",
"name": "new_name",
"pipeline_id": pipeline_id,
"stt_engine": "new_stt_engine",
"tts_engine": "new_tts_engine",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"conversation_engine": "new_conversation_engine",
"language": "new_language",
"name": "new_name",
"id": pipeline_id,
"stt_engine": "new_stt_engine",
"tts_engine": "new_tts_engine",
}
assert len(pipeline_store.data) == 1
pipeline = pipeline_store.data[pipeline_id]
assert pipeline == Pipeline(
conversation_engine="new_conversation_engine",
id=pipeline_id,
language="new_language",
name="new_name",
stt_engine="new_stt_engine",
tts_engine="new_tts_engine",
)