Add a pipeline store to voice_assistant (#90844)
* Add a pipeline store to voice_assistant * Improve error handling * Improve test coverage * Improve test coverage * Use StorageCollectionWebsocket * Correct rebasepull/90509/head
parent
b2bcdf7c19
commit
b3b83b7bb2
|
@ -17,6 +17,7 @@ from .pipeline import (
|
||||||
PipelineRun,
|
PipelineRun,
|
||||||
PipelineStage,
|
PipelineStage,
|
||||||
async_get_pipeline,
|
async_get_pipeline,
|
||||||
|
async_setup_pipeline_store,
|
||||||
)
|
)
|
||||||
from .websocket_api import async_register_websocket_api
|
from .websocket_api import async_register_websocket_api
|
||||||
|
|
||||||
|
@ -31,7 +32,7 @@ __all__ = (
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up Voice Assistant integration."""
|
"""Set up Voice Assistant integration."""
|
||||||
hass.data[DOMAIN] = {}
|
await async_setup_pipeline_store(hass)
|
||||||
async_register_websocket_api(hass)
|
async_register_websocket_api(hass)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
@ -61,7 +62,7 @@ async def async_pipeline_from_audio_stream(
|
||||||
if context is None:
|
if context is None:
|
||||||
context = Context()
|
context = Context()
|
||||||
|
|
||||||
pipeline = async_get_pipeline(
|
pipeline = await async_get_pipeline(
|
||||||
hass,
|
hass,
|
||||||
pipeline_id=pipeline_id,
|
pipeline_id=pipeline_id,
|
||||||
language=language,
|
language=language,
|
||||||
|
|
|
@ -1,3 +1,2 @@
|
||||||
"""Constants for the Voice Assistant integration."""
|
"""Constants for the Voice Assistant integration."""
|
||||||
DOMAIN = "voice_assistant"
|
DOMAIN = "voice_assistant"
|
||||||
DEFAULT_PIPELINE = "default"
|
|
||||||
|
|
|
@ -7,13 +7,20 @@ from dataclasses import asdict, dataclass, field
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.backports.enum import StrEnum
|
from homeassistant.backports.enum import StrEnum
|
||||||
from homeassistant.components import conversation, media_source, stt, tts
|
from homeassistant.components import conversation, media_source, stt, tts
|
||||||
from homeassistant.components.tts.media_source import (
|
from homeassistant.components.tts.media_source import (
|
||||||
generate_media_source_id as tts_generate_media_source_id,
|
generate_media_source_id as tts_generate_media_source_id,
|
||||||
)
|
)
|
||||||
from homeassistant.core import Context, HomeAssistant, callback
|
from homeassistant.core import Context, HomeAssistant, callback
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.helpers.collection import (
|
||||||
|
StorageCollection,
|
||||||
|
StorageCollectionWebsocket,
|
||||||
|
)
|
||||||
|
from homeassistant.helpers.storage import Store
|
||||||
|
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .error import (
|
from .error import (
|
||||||
|
@ -25,23 +32,39 @@ from .error import (
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
STORAGE_KEY = f"{DOMAIN}.pipelines"
|
||||||
|
STORAGE_VERSION = 1
|
||||||
|
|
||||||
@callback
|
STORAGE_FIELDS = {
|
||||||
def async_get_pipeline(
|
vol.Required("conversation_engine"): str,
|
||||||
|
vol.Required("language"): str,
|
||||||
|
vol.Required("name"): str,
|
||||||
|
vol.Required("stt_engine"): str,
|
||||||
|
vol.Required("tts_engine"): str,
|
||||||
|
}
|
||||||
|
|
||||||
|
SAVE_DELAY = 10
|
||||||
|
|
||||||
|
|
||||||
|
async def async_get_pipeline(
|
||||||
hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
|
hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
|
||||||
) -> Pipeline | None:
|
) -> Pipeline | None:
|
||||||
"""Get a pipeline by id or create one for a language."""
|
"""Get a pipeline by id or create one for a language."""
|
||||||
|
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||||
|
|
||||||
if pipeline_id is not None:
|
if pipeline_id is not None:
|
||||||
return hass.data[DOMAIN].get(pipeline_id)
|
return pipeline_store.data.get(pipeline_id)
|
||||||
|
|
||||||
# Construct a pipeline for the required/configured language
|
# Construct a pipeline for the required/configured language
|
||||||
language = language or hass.config.language
|
language = language or hass.config.language
|
||||||
return Pipeline(
|
return await pipeline_store.async_create_item(
|
||||||
name=language,
|
{
|
||||||
language=language,
|
"name": language,
|
||||||
stt_engine=None, # first engine
|
"language": language,
|
||||||
conversation_engine=None, # first agent
|
"stt_engine": None, # first engine
|
||||||
tts_engine=None, # first engine
|
"conversation_engine": None, # first agent
|
||||||
|
"tts_engine": None, # first engine
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,7 +88,7 @@ class PipelineEvent:
|
||||||
|
|
||||||
type: PipelineEventType
|
type: PipelineEventType
|
||||||
data: dict[str, Any] | None = None
|
data: dict[str, Any] | None = None
|
||||||
timestamp: str = field(default_factory=lambda: utcnow().isoformat())
|
timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat())
|
||||||
|
|
||||||
def as_dict(self) -> dict[str, Any]:
|
def as_dict(self) -> dict[str, Any]:
|
||||||
"""Return a dict representation of the event."""
|
"""Return a dict representation of the event."""
|
||||||
|
@ -79,16 +102,29 @@ class PipelineEvent:
|
||||||
PipelineEventCallback = Callable[[PipelineEvent], None]
|
PipelineEventCallback = Callable[[PipelineEvent], None]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(frozen=True)
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
"""A voice assistant pipeline."""
|
"""A voice assistant pipeline."""
|
||||||
|
|
||||||
name: str
|
|
||||||
language: str | None
|
|
||||||
stt_engine: str | None
|
|
||||||
conversation_engine: str | None
|
conversation_engine: str | None
|
||||||
|
language: str | None
|
||||||
|
name: str
|
||||||
|
stt_engine: str | None
|
||||||
tts_engine: str | None
|
tts_engine: str | None
|
||||||
|
|
||||||
|
id: str = field(default_factory=ulid_util.ulid)
|
||||||
|
|
||||||
|
def to_json(self) -> dict[str, Any]:
|
||||||
|
"""Return a JSON serializable representation for storage."""
|
||||||
|
return {
|
||||||
|
"conversation_engine": self.conversation_engine,
|
||||||
|
"id": self.id,
|
||||||
|
"language": self.language,
|
||||||
|
"name": self.name,
|
||||||
|
"stt_engine": self.stt_engine,
|
||||||
|
"tts_engine": self.tts_engine,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class PipelineStage(StrEnum):
|
class PipelineStage(StrEnum):
|
||||||
"""Stages of a pipeline."""
|
"""Stages of a pipeline."""
|
||||||
|
@ -478,3 +514,47 @@ class PipelineInput:
|
||||||
|
|
||||||
if prepare_tasks:
|
if prepare_tasks:
|
||||||
await asyncio.gather(*prepare_tasks)
|
await asyncio.gather(*prepare_tasks)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineStorageCollection(StorageCollection[Pipeline]):
|
||||||
|
"""Pipeline storage collection."""
|
||||||
|
|
||||||
|
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
|
||||||
|
|
||||||
|
async def _process_create_data(self, data: dict) -> dict:
|
||||||
|
"""Validate the config is valid."""
|
||||||
|
# We don't need to validate, the WS API has already validated
|
||||||
|
return data
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _get_suggested_id(self, info: dict) -> str:
|
||||||
|
"""Suggest an ID based on the config."""
|
||||||
|
return ulid_util.ulid()
|
||||||
|
|
||||||
|
async def _update_data(self, item: Pipeline, update_data: dict) -> Pipeline:
|
||||||
|
"""Return a new updated item."""
|
||||||
|
return Pipeline(id=item.id, **update_data)
|
||||||
|
|
||||||
|
def _create_item(self, item_id: str, data: dict) -> Pipeline:
|
||||||
|
"""Create an item from validated config."""
|
||||||
|
return Pipeline(id=item_id, **data)
|
||||||
|
|
||||||
|
def _deserialize_item(self, data: dict) -> Pipeline:
|
||||||
|
"""Create an item from its serialized representation."""
|
||||||
|
return Pipeline(**data)
|
||||||
|
|
||||||
|
def _serialize_item(self, item_id: str, item: Pipeline) -> dict:
|
||||||
|
"""Return the serialized representation of an item."""
|
||||||
|
return item.to_json()
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_pipeline_store(hass):
|
||||||
|
"""Set up the pipeline storage collection."""
|
||||||
|
pipeline_store = PipelineStorageCollection(
|
||||||
|
Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||||
|
)
|
||||||
|
await pipeline_store.async_load()
|
||||||
|
StorageCollectionWebsocket(
|
||||||
|
pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS
|
||||||
|
).async_setup(hass)
|
||||||
|
hass.data[DOMAIN] = pipeline_store
|
||||||
|
|
|
@ -61,7 +61,7 @@ async def websocket_run(
|
||||||
language = "en-US"
|
language = "en-US"
|
||||||
|
|
||||||
pipeline_id = msg.get("pipeline")
|
pipeline_id = msg.get("pipeline")
|
||||||
pipeline = async_get_pipeline(
|
pipeline = await async_get_pipeline(
|
||||||
hass,
|
hass,
|
||||||
pipeline_id=pipeline_id,
|
pipeline_id=pipeline_id,
|
||||||
language=language,
|
language=language,
|
||||||
|
|
|
@ -117,7 +117,7 @@ async def mock_stt_provider(hass) -> MockSttProvider:
|
||||||
return MockSttProvider(hass, _TRANSCRIPT)
|
return MockSttProvider(hass, _TRANSCRIPT)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture
|
||||||
async def init_components(
|
async def init_components(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider: MockSttProvider,
|
mock_stt_provider: MockSttProvider,
|
||||||
|
|
|
@ -7,7 +7,7 @@ from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
|
||||||
async def test_pipeline_from_audio_stream(
|
async def test_pipeline_from_audio_stream(
|
||||||
hass: HomeAssistant, mock_stt_provider, snapshot: SnapshotAssertion
|
hass: HomeAssistant, mock_stt_provider, init_components, snapshot: SnapshotAssertion
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test creating a pipeline from an audio stream."""
|
"""Test creating a pipeline from an audio stream."""
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,104 @@
|
||||||
|
"""Websocket tests for Voice Assistant integration."""
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from homeassistant.components.voice_assistant.const import DOMAIN
|
||||||
|
from homeassistant.components.voice_assistant.pipeline import (
|
||||||
|
STORAGE_KEY,
|
||||||
|
STORAGE_VERSION,
|
||||||
|
PipelineStorageCollection,
|
||||||
|
)
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.storage import Store
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.common import flush_store
|
||||||
|
|
||||||
|
|
||||||
|
async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
|
||||||
|
"""Make sure that we can load/save data correctly."""
|
||||||
|
|
||||||
|
pipelines = [
|
||||||
|
{
|
||||||
|
"conversation_engine": "conversation_engine_1",
|
||||||
|
"language": "language_1",
|
||||||
|
"name": "name_1",
|
||||||
|
"stt_engine": "stt_engine_1",
|
||||||
|
"tts_engine": "tts_engine_1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"conversation_engine": "conversation_engine_2",
|
||||||
|
"language": "language_2",
|
||||||
|
"name": "name_2",
|
||||||
|
"stt_engine": "stt_engine_2",
|
||||||
|
"tts_engine": "tts_engine_2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"conversation_engine": "conversation_engine_3",
|
||||||
|
"language": "language_3",
|
||||||
|
"name": "name_3",
|
||||||
|
"stt_engine": "stt_engine_3",
|
||||||
|
"tts_engine": "tts_engine_3",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
pipeline_ids = []
|
||||||
|
|
||||||
|
store1: PipelineStorageCollection = hass.data[DOMAIN]
|
||||||
|
for pipeline in pipelines:
|
||||||
|
pipeline_ids.append((await store1.async_create_item(pipeline)).id)
|
||||||
|
assert len(store1.data) == 3
|
||||||
|
|
||||||
|
await store1.async_delete_item(pipeline_ids[1])
|
||||||
|
assert len(store1.data) == 2
|
||||||
|
|
||||||
|
store2 = PipelineStorageCollection(Store(hass, STORAGE_VERSION, STORAGE_KEY))
|
||||||
|
await flush_store(store1.store)
|
||||||
|
await store2.async_load()
|
||||||
|
|
||||||
|
assert len(store2.data) == 2
|
||||||
|
|
||||||
|
assert store1.data is not store2.data
|
||||||
|
assert store1.data == store2.data
|
||||||
|
|
||||||
|
|
||||||
|
async def test_loading_datasets_from_storage(
|
||||||
|
hass: HomeAssistant, hass_storage: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Test loading stored datasets on start."""
|
||||||
|
hass_storage[STORAGE_KEY] = {
|
||||||
|
"version": 1,
|
||||||
|
"minor_version": 1,
|
||||||
|
"key": "voice_assistant.pipelines",
|
||||||
|
"data": {
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"conversation_engine": "conversation_engine_1",
|
||||||
|
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||||
|
"language": "language_1",
|
||||||
|
"name": "name_1",
|
||||||
|
"stt_engine": "stt_engine_1",
|
||||||
|
"tts_engine": "tts_engine_1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"conversation_engine": "conversation_engine_2",
|
||||||
|
"id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
|
||||||
|
"language": "language_2",
|
||||||
|
"name": "name_2",
|
||||||
|
"stt_engine": "stt_engine_2",
|
||||||
|
"tts_engine": "tts_engine_2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"conversation_engine": "conversation_engine_3",
|
||||||
|
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
|
||||||
|
"language": "language_3",
|
||||||
|
"name": "name_3",
|
||||||
|
"stt_engine": "stt_engine_3",
|
||||||
|
"tts_engine": "tts_engine_3",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert await async_setup_component(hass, "voice_assistant", {})
|
||||||
|
|
||||||
|
store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||||
|
assert len(store.data) == 3
|
|
@ -1,9 +1,14 @@
|
||||||
"""Websocket tests for Voice Assistant integration."""
|
"""Websocket tests for Voice Assistant integration."""
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import ANY, MagicMock, patch
|
||||||
|
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
|
from homeassistant.components.voice_assistant.const import DOMAIN
|
||||||
|
from homeassistant.components.voice_assistant.pipeline import (
|
||||||
|
Pipeline,
|
||||||
|
PipelineStorageCollection,
|
||||||
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
|
@ -12,6 +17,7 @@ from tests.typing import WebSocketGenerator
|
||||||
async def test_text_only_pipeline(
|
async def test_text_only_pipeline(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test events from a pipeline run with text input (no STT/TTS)."""
|
"""Test events from a pipeline run with text input (no STT/TTS)."""
|
||||||
|
@ -51,7 +57,10 @@ async def test_text_only_pipeline(
|
||||||
|
|
||||||
|
|
||||||
async def test_audio_pipeline(
|
async def test_audio_pipeline(
|
||||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, snapshot: SnapshotAssertion
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test events from a pipeline run with audio input/output."""
|
"""Test events from a pipeline run with audio input/output."""
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
@ -271,6 +280,7 @@ async def test_audio_pipeline_timeout(
|
||||||
async def test_stt_provider_missing(
|
async def test_stt_provider_missing(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test events from a pipeline run with a non-existent STT provider."""
|
"""Test events from a pipeline run with a non-existent STT provider."""
|
||||||
|
@ -297,6 +307,7 @@ async def test_stt_provider_missing(
|
||||||
async def test_stt_stream_failed(
|
async def test_stt_stream_failed(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test events from a pipeline run with a non-existent STT provider."""
|
"""Test events from a pipeline run with a non-existent STT provider."""
|
||||||
|
@ -398,3 +409,205 @@ async def test_invalid_stage_order(
|
||||||
# result
|
# result
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert not msg["success"]
|
assert not msg["success"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_pipeline(
|
||||||
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
|
) -> None:
|
||||||
|
"""Test we can add a pipeline."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "voice_assistant/pipeline/create",
|
||||||
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"language": "test_language",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": "test_stt_engine",
|
||||||
|
"tts_engine": "test_tts_engine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {
|
||||||
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"id": ANY,
|
||||||
|
"language": "test_language",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": "test_stt_engine",
|
||||||
|
"tts_engine": "test_tts_engine",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert len(pipeline_store.data) == 1
|
||||||
|
pipeline = pipeline_store.data[msg["result"]["id"]]
|
||||||
|
assert pipeline == Pipeline(
|
||||||
|
conversation_engine="test_conversation_engine",
|
||||||
|
id=msg["result"]["id"],
|
||||||
|
language="test_language",
|
||||||
|
name="test_name",
|
||||||
|
stt_engine="test_stt_engine",
|
||||||
|
tts_engine="test_tts_engine",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_pipeline(
|
||||||
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
|
) -> None:
|
||||||
|
"""Test we can delete a pipeline."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "voice_assistant/pipeline/create",
|
||||||
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"language": "test_language",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": "test_stt_engine",
|
||||||
|
"tts_engine": "test_tts_engine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert len(pipeline_store.data) == 1
|
||||||
|
|
||||||
|
pipeline_id = msg["result"]["id"]
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "voice_assistant/pipeline/delete",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert len(pipeline_store.data) == 0
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "voice_assistant/pipeline/delete",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"] == {
|
||||||
|
"code": "not_found",
|
||||||
|
"message": f"Unable to find pipeline_id {pipeline_id}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_pipelines(
|
||||||
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
|
) -> None:
|
||||||
|
"""Test we can list pipelines."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||||
|
|
||||||
|
await client.send_json_auto_id({"type": "voice_assistant/pipeline/list"})
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == []
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "voice_assistant/pipeline/create",
|
||||||
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"language": "test_language",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": "test_stt_engine",
|
||||||
|
"tts_engine": "test_tts_engine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert len(pipeline_store.data) == 1
|
||||||
|
|
||||||
|
await client.send_json_auto_id({"type": "voice_assistant/pipeline/list"})
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == [
|
||||||
|
{
|
||||||
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"id": ANY,
|
||||||
|
"language": "test_language",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": "test_stt_engine",
|
||||||
|
"tts_engine": "test_tts_engine",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update_pipeline(
|
||||||
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
|
) -> None:
|
||||||
|
"""Test we can list pipelines."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "voice_assistant/pipeline/update",
|
||||||
|
"conversation_engine": "new_conversation_engine",
|
||||||
|
"language": "new_language",
|
||||||
|
"name": "new_name",
|
||||||
|
"pipeline_id": "no_such_pipeline",
|
||||||
|
"stt_engine": "new_stt_engine",
|
||||||
|
"tts_engine": "new_tts_engine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"] == {
|
||||||
|
"code": "not_found",
|
||||||
|
"message": "Unable to find pipeline_id no_such_pipeline",
|
||||||
|
}
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "voice_assistant/pipeline/create",
|
||||||
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"language": "test_language",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": "test_stt_engine",
|
||||||
|
"tts_engine": "test_tts_engine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
pipeline_id = msg["result"]["id"]
|
||||||
|
assert len(pipeline_store.data) == 1
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "voice_assistant/pipeline/update",
|
||||||
|
"conversation_engine": "new_conversation_engine",
|
||||||
|
"language": "new_language",
|
||||||
|
"name": "new_name",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
"stt_engine": "new_stt_engine",
|
||||||
|
"tts_engine": "new_tts_engine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {
|
||||||
|
"conversation_engine": "new_conversation_engine",
|
||||||
|
"language": "new_language",
|
||||||
|
"name": "new_name",
|
||||||
|
"id": pipeline_id,
|
||||||
|
"stt_engine": "new_stt_engine",
|
||||||
|
"tts_engine": "new_tts_engine",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert len(pipeline_store.data) == 1
|
||||||
|
pipeline = pipeline_store.data[pipeline_id]
|
||||||
|
assert pipeline == Pipeline(
|
||||||
|
conversation_engine="new_conversation_engine",
|
||||||
|
id=pipeline_id,
|
||||||
|
language="new_language",
|
||||||
|
name="new_name",
|
||||||
|
stt_engine="new_stt_engine",
|
||||||
|
tts_engine="new_tts_engine",
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue