Add a pipeline store to voice_assistant (#90844)

* Add a pipeline store to voice_assistant

* Improve error handling

* Improve test coverage

* Improve test coverage

* Use StorageCollectionWebsocket

* Correct rebase
pull/90509/head
Erik Montnemery 2023-04-06 18:55:16 +02:00 committed by GitHub
parent b2bcdf7c19
commit b3b83b7bb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 420 additions and 23 deletions

View File

@ -17,6 +17,7 @@ from .pipeline import (
PipelineRun, PipelineRun,
PipelineStage, PipelineStage,
async_get_pipeline, async_get_pipeline,
async_setup_pipeline_store,
) )
from .websocket_api import async_register_websocket_api from .websocket_api import async_register_websocket_api
@ -31,7 +32,7 @@ __all__ = (
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up Voice Assistant integration.""" """Set up Voice Assistant integration."""
hass.data[DOMAIN] = {} await async_setup_pipeline_store(hass)
async_register_websocket_api(hass) async_register_websocket_api(hass)
return True return True
@ -61,7 +62,7 @@ async def async_pipeline_from_audio_stream(
if context is None: if context is None:
context = Context() context = Context()
pipeline = async_get_pipeline( pipeline = await async_get_pipeline(
hass, hass,
pipeline_id=pipeline_id, pipeline_id=pipeline_id,
language=language, language=language,

View File

@ -1,3 +1,2 @@
"""Constants for the Voice Assistant integration.""" """Constants for the Voice Assistant integration."""
DOMAIN = "voice_assistant" DOMAIN = "voice_assistant"
DEFAULT_PIPELINE = "default"

View File

@ -7,13 +7,20 @@ from dataclasses import asdict, dataclass, field
import logging import logging
from typing import Any from typing import Any
import voluptuous as vol
from homeassistant.backports.enum import StrEnum from homeassistant.backports.enum import StrEnum
from homeassistant.components import conversation, media_source, stt, tts from homeassistant.components import conversation, media_source, stt, tts
from homeassistant.components.tts.media_source import ( from homeassistant.components.tts.media_source import (
generate_media_source_id as tts_generate_media_source_id, generate_media_source_id as tts_generate_media_source_id,
) )
from homeassistant.core import Context, HomeAssistant, callback from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.util.dt import utcnow from homeassistant.helpers.collection import (
StorageCollection,
StorageCollectionWebsocket,
)
from homeassistant.helpers.storage import Store
from homeassistant.util import dt as dt_util, ulid as ulid_util
from .const import DOMAIN from .const import DOMAIN
from .error import ( from .error import (
@ -25,23 +32,39 @@ from .error import (
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
STORAGE_KEY = f"{DOMAIN}.pipelines"
STORAGE_VERSION = 1
@callback STORAGE_FIELDS = {
def async_get_pipeline( vol.Required("conversation_engine"): str,
vol.Required("language"): str,
vol.Required("name"): str,
vol.Required("stt_engine"): str,
vol.Required("tts_engine"): str,
}
SAVE_DELAY = 10
async def async_get_pipeline(
hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
) -> Pipeline | None: ) -> Pipeline | None:
"""Get a pipeline by id or create one for a language.""" """Get a pipeline by id or create one for a language."""
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
if pipeline_id is not None: if pipeline_id is not None:
return hass.data[DOMAIN].get(pipeline_id) return pipeline_store.data.get(pipeline_id)
# Construct a pipeline for the required/configured language # Construct a pipeline for the required/configured language
language = language or hass.config.language language = language or hass.config.language
return Pipeline( return await pipeline_store.async_create_item(
name=language, {
language=language, "name": language,
stt_engine=None, # first engine "language": language,
conversation_engine=None, # first agent "stt_engine": None, # first engine
tts_engine=None, # first engine "conversation_engine": None, # first agent
"tts_engine": None, # first engine
}
) )
@ -65,7 +88,7 @@ class PipelineEvent:
type: PipelineEventType type: PipelineEventType
data: dict[str, Any] | None = None data: dict[str, Any] | None = None
timestamp: str = field(default_factory=lambda: utcnow().isoformat()) timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat())
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
"""Return a dict representation of the event.""" """Return a dict representation of the event."""
@ -79,16 +102,29 @@ class PipelineEvent:
PipelineEventCallback = Callable[[PipelineEvent], None] PipelineEventCallback = Callable[[PipelineEvent], None]
@dataclass @dataclass(frozen=True)
class Pipeline: class Pipeline:
"""A voice assistant pipeline.""" """A voice assistant pipeline."""
name: str
language: str | None
stt_engine: str | None
conversation_engine: str | None conversation_engine: str | None
language: str | None
name: str
stt_engine: str | None
tts_engine: str | None tts_engine: str | None
id: str = field(default_factory=ulid_util.ulid)
def to_json(self) -> dict[str, Any]:
"""Return a JSON serializable representation for storage."""
return {
"conversation_engine": self.conversation_engine,
"id": self.id,
"language": self.language,
"name": self.name,
"stt_engine": self.stt_engine,
"tts_engine": self.tts_engine,
}
class PipelineStage(StrEnum): class PipelineStage(StrEnum):
"""Stages of a pipeline.""" """Stages of a pipeline."""
@ -478,3 +514,47 @@ class PipelineInput:
if prepare_tasks: if prepare_tasks:
await asyncio.gather(*prepare_tasks) await asyncio.gather(*prepare_tasks)
class PipelineStorageCollection(StorageCollection[Pipeline]):
    """Storage collection whose items are ``Pipeline`` objects."""

    # Schema built from the same fields the websocket create/update commands use.
    CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)

    async def _process_create_data(self, data: dict) -> dict:
        """Return the create payload as-is, without validating it.

        NOTE(review): async_get_pipeline calls async_create_item directly with
        ``None`` engine values, so not every payload has necessarily been
        through the websocket schema - confirm this is intended.
        """
        # We don't need to validate, the WS API has already validated
        return data

    @callback
    def _get_suggested_id(self, info: dict) -> str:
        """Return a newly generated ULID; ``info`` is not consulted."""
        suggested_id = ulid_util.ulid()
        return suggested_id

    async def _update_data(self, item: Pipeline, update_data: dict) -> Pipeline:
        """Build a new pipeline from ``update_data`` that keeps the id of ``item``."""
        updated = Pipeline(id=item.id, **update_data)
        return updated

    def _create_item(self, item_id: str, data: dict) -> Pipeline:
        """Build a pipeline with id ``item_id`` from the given config."""
        created = Pipeline(id=item_id, **data)
        return created

    def _deserialize_item(self, data: dict) -> Pipeline:
        """Rebuild a pipeline from its stored dict (which includes ``id``)."""
        restored = Pipeline(**data)
        return restored

    def _serialize_item(self, item_id: str, item: Pipeline) -> dict:
        """Return the dict stored for ``item``; ``item_id`` is not used."""
        serialized = item.to_json()
        return serialized
async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
    """Set up the pipeline storage collection.

    Creates the collection on top of a ``Store`` keyed by ``STORAGE_KEY``,
    loads it, registers the ``voice_assistant/pipeline`` websocket commands
    and finally exposes the collection as ``hass.data[DOMAIN]``.
    """
    pipeline_store = PipelineStorageCollection(
        Store(hass, STORAGE_VERSION, STORAGE_KEY)
    )
    await pipeline_store.async_load()

    # The same field set is used for both the create and the update schema.
    StorageCollectionWebsocket(
        pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS
    ).async_setup(hass)

    hass.data[DOMAIN] = pipeline_store

View File

@ -61,7 +61,7 @@ async def websocket_run(
language = "en-US" language = "en-US"
pipeline_id = msg.get("pipeline") pipeline_id = msg.get("pipeline")
pipeline = async_get_pipeline( pipeline = await async_get_pipeline(
hass, hass,
pipeline_id=pipeline_id, pipeline_id=pipeline_id,
language=language, language=language,

View File

@ -117,7 +117,7 @@ async def mock_stt_provider(hass) -> MockSttProvider:
return MockSttProvider(hass, _TRANSCRIPT) return MockSttProvider(hass, _TRANSCRIPT)
@pytest.fixture(autouse=True) @pytest.fixture
async def init_components( async def init_components(
hass: HomeAssistant, hass: HomeAssistant,
mock_stt_provider: MockSttProvider, mock_stt_provider: MockSttProvider,

View File

@ -7,7 +7,7 @@ from homeassistant.core import HomeAssistant
async def test_pipeline_from_audio_stream( async def test_pipeline_from_audio_stream(
hass: HomeAssistant, mock_stt_provider, snapshot: SnapshotAssertion hass: HomeAssistant, mock_stt_provider, init_components, snapshot: SnapshotAssertion
) -> None: ) -> None:
"""Test creating a pipeline from an audio stream.""" """Test creating a pipeline from an audio stream."""

View File

@ -0,0 +1,104 @@
"""Pipeline storage tests for Voice Assistant integration."""
from typing import Any
from homeassistant.components.voice_assistant.const import DOMAIN
from homeassistant.components.voice_assistant.pipeline import (
STORAGE_KEY,
STORAGE_VERSION,
PipelineStorageCollection,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers.storage import Store
from homeassistant.setup import async_setup_component
from tests.common import flush_store
async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
    """Check that pipelines saved by one collection load into another."""
    pipeline_configs = [
        {
            "conversation_engine": "conversation_engine_1",
            "language": "language_1",
            "name": "name_1",
            "stt_engine": "stt_engine_1",
            "tts_engine": "tts_engine_1",
        },
        {
            "conversation_engine": "conversation_engine_2",
            "language": "language_2",
            "name": "name_2",
            "stt_engine": "stt_engine_2",
            "tts_engine": "tts_engine_2",
        },
        {
            "conversation_engine": "conversation_engine_3",
            "language": "language_3",
            "name": "name_3",
            "stt_engine": "stt_engine_3",
            "tts_engine": "tts_engine_3",
        },
    ]

    original_store: PipelineStorageCollection = hass.data[DOMAIN]
    # Create the items one by one, keeping the generated ids in creation order
    created_ids = [
        (await original_store.async_create_item(config)).id
        for config in pipeline_configs
    ]
    assert len(original_store.data) == 3

    # Remove the middle item so the reloaded data must reflect a deletion too
    await original_store.async_delete_item(created_ids[1])
    assert len(original_store.data) == 2

    reloaded_store = PipelineStorageCollection(
        Store(hass, STORAGE_VERSION, STORAGE_KEY)
    )
    await flush_store(original_store.store)
    await reloaded_store.async_load()

    assert len(reloaded_store.data) == 2

    # Equal content, but held in a separate container
    assert original_store.data is not reloaded_store.data
    assert original_store.data == reloaded_store.data
async def test_loading_datasets_from_storage(
    hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
    """Check that pipelines already in storage are loaded during setup."""
    stored_pipelines = [
        {
            "conversation_engine": "conversation_engine_1",
            "id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
            "language": "language_1",
            "name": "name_1",
            "stt_engine": "stt_engine_1",
            "tts_engine": "tts_engine_1",
        },
        {
            "conversation_engine": "conversation_engine_2",
            "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
            "language": "language_2",
            "name": "name_2",
            "stt_engine": "stt_engine_2",
            "tts_engine": "tts_engine_2",
        },
        {
            "conversation_engine": "conversation_engine_3",
            "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
            "language": "language_3",
            "name": "name_3",
            "stt_engine": "stt_engine_3",
            "tts_engine": "tts_engine_3",
        },
    ]
    # Seed hass_storage before the integration is set up
    hass_storage[STORAGE_KEY] = {
        "version": 1,
        "minor_version": 1,
        "key": "voice_assistant.pipelines",
        "data": {"items": stored_pipelines},
    }

    setup_ok = await async_setup_component(hass, "voice_assistant", {})
    assert setup_ok

    loaded_store: PipelineStorageCollection = hass.data[DOMAIN]
    assert len(loaded_store.data) == 3

View File

@ -1,9 +1,14 @@
"""Websocket tests for Voice Assistant integration.""" """Websocket tests for Voice Assistant integration."""
import asyncio import asyncio
from unittest.mock import MagicMock, patch from unittest.mock import ANY, MagicMock, patch
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components.voice_assistant.const import DOMAIN
from homeassistant.components.voice_assistant.pipeline import (
Pipeline,
PipelineStorageCollection,
)
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from tests.typing import WebSocketGenerator from tests.typing import WebSocketGenerator
@ -12,6 +17,7 @@ from tests.typing import WebSocketGenerator
async def test_text_only_pipeline( async def test_text_only_pipeline(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test events from a pipeline run with text input (no STT/TTS).""" """Test events from a pipeline run with text input (no STT/TTS)."""
@ -51,7 +57,10 @@ async def test_text_only_pipeline(
async def test_audio_pipeline( async def test_audio_pipeline(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, snapshot: SnapshotAssertion hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test events from a pipeline run with audio input/output.""" """Test events from a pipeline run with audio input/output."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -271,6 +280,7 @@ async def test_audio_pipeline_timeout(
async def test_stt_provider_missing( async def test_stt_provider_missing(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test events from a pipeline run with a non-existent STT provider.""" """Test events from a pipeline run with a non-existent STT provider."""
@ -297,6 +307,7 @@ async def test_stt_provider_missing(
async def test_stt_stream_failed( async def test_stt_stream_failed(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test events from a pipeline run with a non-existent STT provider.""" """Test events from a pipeline run with a non-existent STT provider."""
@ -398,3 +409,205 @@ async def test_invalid_stage_order(
# result # result
msg = await client.receive_json() msg = await client.receive_json()
assert not msg["success"] assert not msg["success"]
async def test_add_pipeline(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Check that the create command stores a pipeline and returns it."""
    pipeline_config = {
        "conversation_engine": "test_conversation_engine",
        "language": "test_language",
        "name": "test_name",
        "stt_engine": "test_stt_engine",
        "tts_engine": "test_tts_engine",
    }
    client = await hass_ws_client(hass)
    pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]

    # A fresh dict per request, so pipeline_config itself stays untouched
    await client.send_json_auto_id(
        {"type": "voice_assistant/pipeline/create", **pipeline_config}
    )
    response = await client.receive_json()
    assert response["success"]
    # The reply echoes the config plus an id we cannot predict
    assert response["result"] == {**pipeline_config, "id": ANY}

    assert len(pipeline_store.data) == 1
    new_id = response["result"]["id"]
    stored_pipeline = pipeline_store.data[new_id]
    assert stored_pipeline == Pipeline(id=new_id, **pipeline_config)
async def test_delete_pipeline(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Check that a pipeline can be deleted once and not a second time."""
    client = await hass_ws_client(hass)
    pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]

    # Create the pipeline that will be deleted
    create_request = {
        "type": "voice_assistant/pipeline/create",
        "conversation_engine": "test_conversation_engine",
        "language": "test_language",
        "name": "test_name",
        "stt_engine": "test_stt_engine",
        "tts_engine": "test_tts_engine",
    }
    await client.send_json_auto_id(create_request)
    response = await client.receive_json()
    assert response["success"]
    assert len(pipeline_store.data) == 1

    pipeline_id = response["result"]["id"]
    delete_request = {
        "type": "voice_assistant/pipeline/delete",
        "pipeline_id": pipeline_id,
    }

    # First delete succeeds and leaves the store empty
    await client.send_json_auto_id(dict(delete_request))
    response = await client.receive_json()
    assert response["success"]
    assert len(pipeline_store.data) == 0

    # Deleting the same id again is reported as not found
    await client.send_json_auto_id(dict(delete_request))
    response = await client.receive_json()
    assert not response["success"]
    expected_error = {
        "code": "not_found",
        "message": f"Unable to find pipeline_id {pipeline_id}",
    }
    assert response["error"] == expected_error
async def test_list_pipelines(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Check the list command before and after a pipeline is created."""
    pipeline_config = {
        "conversation_engine": "test_conversation_engine",
        "language": "test_language",
        "name": "test_name",
        "stt_engine": "test_stt_engine",
        "tts_engine": "test_tts_engine",
    }
    client = await hass_ws_client(hass)
    pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]

    # Nothing has been created yet, so the listing is empty
    await client.send_json_auto_id({"type": "voice_assistant/pipeline/list"})
    response = await client.receive_json()
    assert response["success"]
    assert response["result"] == []

    await client.send_json_auto_id(
        {"type": "voice_assistant/pipeline/create", **pipeline_config}
    )
    response = await client.receive_json()
    assert response["success"]
    assert len(pipeline_store.data) == 1

    # The created pipeline now appears, with a generated id
    await client.send_json_auto_id({"type": "voice_assistant/pipeline/list"})
    response = await client.receive_json()
    assert response["success"]
    assert response["result"] == [{**pipeline_config, "id": ANY}]
async def test_update_pipeline(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Test we can update a pipeline.

    Covers updating an unknown id (rejected with ``not_found``) and updating
    an existing pipeline (all fields replaced, id preserved).
    """
    initial_config = {
        "conversation_engine": "test_conversation_engine",
        "language": "test_language",
        "name": "test_name",
        "stt_engine": "test_stt_engine",
        "tts_engine": "test_tts_engine",
    }
    updated_config = {
        "conversation_engine": "new_conversation_engine",
        "language": "new_language",
        "name": "new_name",
        "stt_engine": "new_stt_engine",
        "tts_engine": "new_tts_engine",
    }
    client = await hass_ws_client(hass)
    pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]

    # Updating a pipeline that does not exist must fail
    await client.send_json_auto_id(
        {
            "type": "voice_assistant/pipeline/update",
            "pipeline_id": "no_such_pipeline",
            **updated_config,
        }
    )
    msg = await client.receive_json()
    assert not msg["success"]
    assert msg["error"] == {
        "code": "not_found",
        "message": "Unable to find pipeline_id no_such_pipeline",
    }

    # Create the pipeline that will be updated
    await client.send_json_auto_id(
        {"type": "voice_assistant/pipeline/create", **initial_config}
    )
    msg = await client.receive_json()
    assert msg["success"]
    pipeline_id = msg["result"]["id"]
    assert len(pipeline_store.data) == 1

    # Replace every field; the id must stay the same
    await client.send_json_auto_id(
        {
            "type": "voice_assistant/pipeline/update",
            "pipeline_id": pipeline_id,
            **updated_config,
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    assert msg["result"] == {**updated_config, "id": pipeline_id}

    # The update replaced the stored item instead of adding a second one
    assert len(pipeline_store.data) == 1
    pipeline = pipeline_store.data[pipeline_id]
    assert pipeline == Pipeline(id=pipeline_id, **updated_config)