Support marking an assist pipeline as preferred (#91418)
* Support marking an assist pipeline as preferred * Adjust * Revert unneeded change * Send preferred pipeline id in pipeline list * Don't use property functions for the preferred pipelinepull/91463/head
parent
714ec3f023
commit
8f8a398631
|
@ -149,7 +149,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
await storage_collection.async_load()
|
||||
hass.data[DOMAIN][DATA_STORAGE] = storage_collection
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS
|
||||
).async_setup(hass)
|
||||
|
||||
|
|
|
@ -10,12 +10,15 @@ from typing import Any
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.backports.enum import StrEnum
|
||||
from homeassistant.components import conversation, media_source, stt, tts
|
||||
from homeassistant.components import conversation, media_source, stt, tts, websocket_api
|
||||
from homeassistant.components.tts.media_source import (
|
||||
generate_media_source_id as tts_generate_media_source_id,
|
||||
)
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.helpers.collection import (
|
||||
CollectionError,
|
||||
ItemNotFound,
|
||||
SerializedStorageCollection,
|
||||
StorageCollection,
|
||||
StorageCollectionWebsocket,
|
||||
)
|
||||
|
@ -533,11 +536,39 @@ class PipelineInput:
|
|||
await asyncio.gather(*prepare_tasks)
|
||||
|
||||
|
||||
class PipelineStorageCollection(StorageCollection[Pipeline]):
|
||||
class PipelinePreferred(CollectionError):
|
||||
"""Raised when attempting to delete the preferred pipelen."""
|
||||
|
||||
def __init__(self, item_id: str) -> None:
|
||||
"""Initialize pipeline preferred error."""
|
||||
super().__init__(f"Item {item_id} preferred.")
|
||||
self.item_id = item_id
|
||||
|
||||
|
||||
class SerializedPipelineStorageCollection(SerializedStorageCollection):
|
||||
"""Serialized pipeline storage collection."""
|
||||
|
||||
preferred_item: str | None
|
||||
|
||||
|
||||
class PipelineStorageCollection(
|
||||
StorageCollection[Pipeline, SerializedPipelineStorageCollection]
|
||||
):
|
||||
"""Pipeline storage collection."""
|
||||
|
||||
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
|
||||
|
||||
_preferred_item: str | None = None
|
||||
|
||||
async def _async_load_data(self) -> SerializedPipelineStorageCollection | None:
|
||||
"""Load the data."""
|
||||
if not (data := await super()._async_load_data()):
|
||||
return data
|
||||
|
||||
self._preferred_item = data["preferred_item"]
|
||||
|
||||
return data
|
||||
|
||||
async def _process_create_data(self, data: dict) -> dict:
|
||||
"""Validate the config is valid."""
|
||||
# We don't need to validate, the WS API has already validated
|
||||
|
@ -554,6 +585,8 @@ class PipelineStorageCollection(StorageCollection[Pipeline]):
|
|||
|
||||
def _create_item(self, item_id: str, data: dict) -> Pipeline:
|
||||
"""Create an item from validated config."""
|
||||
if self._preferred_item is None:
|
||||
self._preferred_item = item_id
|
||||
return Pipeline(id=item_id, **data)
|
||||
|
||||
def _deserialize_item(self, data: dict) -> Pipeline:
|
||||
|
@ -561,9 +594,107 @@ class PipelineStorageCollection(StorageCollection[Pipeline]):
|
|||
return Pipeline(**data)
|
||||
|
||||
def _serialize_item(self, item_id: str, item: Pipeline) -> dict:
|
||||
"""Return the serialized representation of an item."""
|
||||
"""Return the serialized representation of an item for storing."""
|
||||
return item.to_json()
|
||||
|
||||
async def async_delete_item(self, item_id: str) -> None:
|
||||
"""Delete item."""
|
||||
if self._preferred_item == item_id:
|
||||
raise PipelinePreferred(item_id)
|
||||
await super().async_delete_item(item_id)
|
||||
|
||||
@callback
|
||||
def async_get_preferred_item(self) -> str | None:
|
||||
"""Get the id of the preferred item."""
|
||||
return self._preferred_item
|
||||
|
||||
@callback
|
||||
def async_set_preferred_item(self, item_id: str) -> None:
|
||||
"""Set the preferred pipeline."""
|
||||
if item_id not in self.data:
|
||||
raise ItemNotFound(item_id)
|
||||
self._preferred_item = item_id
|
||||
self._async_schedule_save()
|
||||
|
||||
@callback
|
||||
def _data_to_save(self) -> SerializedPipelineStorageCollection:
|
||||
"""Return JSON-compatible date for storing to file."""
|
||||
base_data = super()._base_data_to_save()
|
||||
return {
|
||||
"items": base_data["items"],
|
||||
"preferred_item": self._preferred_item,
|
||||
}
|
||||
|
||||
|
||||
class PipelineStorageCollectionWebsocket(
|
||||
StorageCollectionWebsocket[PipelineStorageCollection]
|
||||
):
|
||||
"""Class to expose storage collection management over websocket."""
|
||||
|
||||
@callback
|
||||
def async_setup(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
create_list: bool = True,
|
||||
create_create: bool = True,
|
||||
) -> None:
|
||||
"""Set up the websocket commands."""
|
||||
super().async_setup(hass, create_list=create_list, create_create=create_create)
|
||||
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
f"{self.api_prefix}/set_preferred",
|
||||
websocket_api.require_admin(
|
||||
websocket_api.async_response(self.ws_set_preferred_item)
|
||||
),
|
||||
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{
|
||||
vol.Required("type"): f"{self.api_prefix}/set_preferred",
|
||||
vol.Required(self.item_id_key): str,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
def ws_list_item(
|
||||
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""List items."""
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{
|
||||
"pipelines": self.storage_collection.async_items(),
|
||||
"preferred_pipeline": self.storage_collection.async_get_preferred_item(),
|
||||
},
|
||||
)
|
||||
|
||||
async def ws_delete_item(
|
||||
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""Delete an item."""
|
||||
try:
|
||||
await super().ws_delete_item(hass, connection, msg)
|
||||
except PipelinePreferred as exc:
|
||||
connection.send_error(
|
||||
msg["id"], websocket_api.const.ERR_NOT_ALLOWED, str(exc)
|
||||
)
|
||||
|
||||
async def ws_set_preferred_item(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Set the preferred item."""
|
||||
try:
|
||||
self.storage_collection.async_set_preferred_item(msg[self.item_id_key])
|
||||
except ItemNotFound:
|
||||
connection.send_error(
|
||||
msg["id"], websocket_api.const.ERR_NOT_FOUND, "unknown item"
|
||||
)
|
||||
return
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
|
||||
async def async_setup_pipeline_store(hass):
|
||||
"""Set up the pipeline storage collection."""
|
||||
|
@ -571,7 +702,7 @@ async def async_setup_pipeline_store(hass):
|
|||
Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
)
|
||||
await pipeline_store.async_load()
|
||||
StorageCollectionWebsocket(
|
||||
PipelineStorageCollectionWebsocket(
|
||||
pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS
|
||||
).async_setup(hass)
|
||||
hass.data[DOMAIN] = pipeline_store
|
||||
|
|
|
@ -117,7 +117,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
)
|
||||
await storage_collection.async_load()
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
|
||||
).async_setup(hass)
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
image_dir = pathlib.Path(hass.config.path("image"))
|
||||
hass.data[DOMAIN] = storage_collection = ImageStorageCollection(hass, image_dir)
|
||||
await storage_collection.async_load()
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
storage_collection,
|
||||
"image",
|
||||
"image",
|
||||
|
|
|
@ -121,7 +121,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
)
|
||||
await storage_collection.async_load()
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
|
||||
).async_setup(hass)
|
||||
|
||||
|
|
|
@ -106,7 +106,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
)
|
||||
await storage_collection.async_load()
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
|
||||
).async_setup(hass)
|
||||
|
||||
|
|
|
@ -159,7 +159,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
)
|
||||
await storage_collection.async_load()
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
|
||||
).async_setup(hass)
|
||||
|
||||
|
|
|
@ -136,7 +136,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
)
|
||||
await storage_collection.async_load()
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
|
||||
).async_setup(hass)
|
||||
|
||||
|
|
|
@ -167,7 +167,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
)
|
||||
await storage_collection.async_load()
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
|
||||
).async_setup(hass)
|
||||
|
||||
|
|
|
@ -136,7 +136,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
)
|
||||
await storage_collection.async_load()
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
|
||||
).async_setup(hass)
|
||||
|
||||
|
|
|
@ -119,7 +119,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
|
||||
resource_collection = resources.ResourceStorageCollection(hass, default_config)
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
resource_collection,
|
||||
"lovelace/resources",
|
||||
"resource",
|
||||
|
@ -198,7 +198,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
dashboards_collection.async_add_listener(storage_dashboard_changed)
|
||||
await dashboards_collection.async_load()
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
dashboards_collection,
|
||||
"lovelace/dashboards",
|
||||
"dashboard",
|
||||
|
|
|
@ -354,7 +354,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
|
||||
hass.data[DOMAIN] = (yaml_collection, storage_collection, entity_component)
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS
|
||||
).async_setup(hass, create_list=False)
|
||||
|
||||
|
|
|
@ -21,9 +21,9 @@ from homeassistant.core import HomeAssistant, ServiceCall, callback
|
|||
from homeassistant.helpers.collection import (
|
||||
CollectionEntity,
|
||||
DictStorageCollection,
|
||||
DictStorageCollectionWebsocket,
|
||||
IDManager,
|
||||
SerializedStorageCollection,
|
||||
StorageCollectionWebsocket,
|
||||
YamlCollection,
|
||||
sync_entity_lifecycle,
|
||||
)
|
||||
|
@ -182,7 +182,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
)
|
||||
await storage_collection.async_load()
|
||||
|
||||
StorageCollectionWebsocket(
|
||||
DictStorageCollectionWebsocket(
|
||||
storage_collection,
|
||||
DOMAIN,
|
||||
DOMAIN,
|
||||
|
|
|
@ -98,7 +98,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
id_manager,
|
||||
)
|
||||
await storage_collection.async_load()
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS
|
||||
).async_setup(hass)
|
||||
|
||||
|
|
|
@ -130,7 +130,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
)
|
||||
await storage_collection.async_load()
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
|
||||
).async_setup(hass)
|
||||
|
||||
|
|
|
@ -209,7 +209,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
|
||||
await storage_collection.async_load()
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS
|
||||
).async_setup(hass)
|
||||
|
||||
|
|
|
@ -32,7 +32,9 @@ CHANGE_ADDED = "added"
|
|||
CHANGE_UPDATED = "updated"
|
||||
CHANGE_REMOVED = "removed"
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_ItemT = TypeVar("_ItemT")
|
||||
_StoreT = TypeVar("_StoreT", bound="SerializedStorageCollection")
|
||||
_StorageCollectionT = TypeVar("_StorageCollectionT", bound="StorageCollection")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
|
@ -123,20 +125,20 @@ class CollectionEntity(Entity):
|
|||
"""Handle updated configuration."""
|
||||
|
||||
|
||||
class ObservableCollection(ABC, Generic[_T]):
|
||||
class ObservableCollection(ABC, Generic[_ItemT]):
|
||||
"""Base collection type that can be observed."""
|
||||
|
||||
def __init__(self, id_manager: IDManager | None) -> None:
|
||||
"""Initialize the base collection."""
|
||||
self.id_manager = id_manager or IDManager()
|
||||
self.data: dict[str, _T] = {}
|
||||
self.data: dict[str, _ItemT] = {}
|
||||
self.listeners: list[ChangeListener] = []
|
||||
self.change_set_listeners: list[ChangeSetListener] = []
|
||||
|
||||
self.id_manager.add_collection(self.data)
|
||||
|
||||
@callback
|
||||
def async_items(self) -> list[_T]:
|
||||
def async_items(self) -> list[_ItemT]:
|
||||
"""Return list of items in collection."""
|
||||
return list(self.data.values())
|
||||
|
||||
|
@ -226,12 +228,12 @@ class SerializedStorageCollection(TypedDict):
|
|||
items: list[dict[str, Any]]
|
||||
|
||||
|
||||
class StorageCollection(ObservableCollection[_T], ABC):
|
||||
class StorageCollection(ObservableCollection[_ItemT], Generic[_ItemT, _StoreT]):
|
||||
"""Offer a CRUD interface on top of JSON storage."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: Store[SerializedStorageCollection],
|
||||
store: Store[_StoreT],
|
||||
id_manager: IDManager | None = None,
|
||||
) -> None:
|
||||
"""Initialize the storage collection."""
|
||||
|
@ -250,16 +252,14 @@ class StorageCollection(ObservableCollection[_T], ABC):
|
|||
"""Home Assistant object."""
|
||||
return self.store.hass
|
||||
|
||||
async def _async_load_data(self) -> SerializedStorageCollection | None:
|
||||
async def _async_load_data(self) -> _StoreT | None:
|
||||
"""Load the data."""
|
||||
return await self.store.async_load()
|
||||
|
||||
async def async_load(self) -> None:
|
||||
"""Load the storage Manager."""
|
||||
raw_storage = await self._async_load_data()
|
||||
|
||||
if raw_storage is None:
|
||||
raw_storage = {"items": []}
|
||||
if not (raw_storage := await self._async_load_data()):
|
||||
return
|
||||
|
||||
for item in raw_storage["items"]:
|
||||
self.data[item[CONF_ID]] = self._deserialize_item(item)
|
||||
|
@ -281,25 +281,25 @@ class StorageCollection(ObservableCollection[_T], ABC):
|
|||
"""Suggest an ID based on the config."""
|
||||
|
||||
@abstractmethod
|
||||
async def _update_data(self, item: _T, update_data: dict) -> _T:
|
||||
async def _update_data(self, item: _ItemT, update_data: dict) -> _ItemT:
|
||||
"""Return a new updated item."""
|
||||
|
||||
@abstractmethod
|
||||
def _create_item(self, item_id: str, data: dict) -> _T:
|
||||
def _create_item(self, item_id: str, data: dict) -> _ItemT:
|
||||
"""Create an item from validated config."""
|
||||
|
||||
@abstractmethod
|
||||
def _deserialize_item(self, data: dict) -> _T:
|
||||
def _deserialize_item(self, data: dict) -> _ItemT:
|
||||
"""Create an item from its serialized representation."""
|
||||
|
||||
@abstractmethod
|
||||
def _serialize_item(self, item_id: str, item: _T) -> dict:
|
||||
"""Return the serialized representation of an item.
|
||||
def _serialize_item(self, item_id: str, item: _ItemT) -> dict:
|
||||
"""Return the serialized representation of an item for storing.
|
||||
|
||||
The serialized representation must include the item_id in the "id" key.
|
||||
"""
|
||||
|
||||
async def async_create_item(self, data: dict) -> _T:
|
||||
async def async_create_item(self, data: dict) -> _ItemT:
|
||||
"""Create a new item."""
|
||||
validated_data = await self._process_create_data(data)
|
||||
item_id = self.id_manager.generate_id(self._get_suggested_id(validated_data))
|
||||
|
@ -309,7 +309,7 @@ class StorageCollection(ObservableCollection[_T], ABC):
|
|||
await self.notify_changes([CollectionChangeSet(CHANGE_ADDED, item_id, item)])
|
||||
return item
|
||||
|
||||
async def async_update_item(self, item_id: str, updates: dict) -> _T:
|
||||
async def async_update_item(self, item_id: str, updates: dict) -> _ItemT:
|
||||
"""Update item."""
|
||||
if item_id not in self.data:
|
||||
raise ItemNotFound(item_id)
|
||||
|
@ -346,8 +346,8 @@ class StorageCollection(ObservableCollection[_T], ABC):
|
|||
self.store.async_delay_save(self._data_to_save, SAVE_DELAY)
|
||||
|
||||
@callback
|
||||
def _data_to_save(self) -> SerializedStorageCollection:
|
||||
"""Return JSON-compatible date for storing to file."""
|
||||
def _base_data_to_save(self) -> SerializedStorageCollection:
|
||||
"""Return JSON-compatible data for storing to file."""
|
||||
return {
|
||||
"items": [
|
||||
self._serialize_item(item_id, item)
|
||||
|
@ -355,8 +355,13 @@ class StorageCollection(ObservableCollection[_T], ABC):
|
|||
]
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
@callback
|
||||
def _data_to_save(self) -> _StoreT:
|
||||
"""Return JSON-compatible date for storing to file."""
|
||||
|
||||
class DictStorageCollection(StorageCollection[dict]):
|
||||
|
||||
class DictStorageCollection(StorageCollection[dict, SerializedStorageCollection]):
|
||||
"""A specialized StorageCollection where the items are untyped dicts."""
|
||||
|
||||
def _create_item(self, item_id: str, data: dict) -> dict:
|
||||
|
@ -368,9 +373,14 @@ class DictStorageCollection(StorageCollection[dict]):
|
|||
return data
|
||||
|
||||
def _serialize_item(self, item_id: str, item: dict) -> dict:
|
||||
"""Return the serialized representation of an item."""
|
||||
"""Return the serialized representation of an item for storing."""
|
||||
return item
|
||||
|
||||
@callback
|
||||
def _data_to_save(self) -> SerializedStorageCollection:
|
||||
"""Return JSON-compatible date for storing to file."""
|
||||
return self._base_data_to_save()
|
||||
|
||||
|
||||
class IDLessCollection(YamlCollection):
|
||||
"""A collection without IDs."""
|
||||
|
@ -477,12 +487,12 @@ def sync_entity_lifecycle(
|
|||
collection.async_add_change_set_listener(_collection_changed)
|
||||
|
||||
|
||||
class StorageCollectionWebsocket:
|
||||
class StorageCollectionWebsocket(Generic[_StorageCollectionT]):
|
||||
"""Class to expose storage collection management over websocket."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage_collection: StorageCollection,
|
||||
storage_collection: _StorageCollectionT,
|
||||
api_prefix: str,
|
||||
model_name: str,
|
||||
create_schema: dict,
|
||||
|
@ -635,3 +645,7 @@ class StorageCollectionWebsocket:
|
|||
)
|
||||
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
|
||||
class DictStorageCollectionWebsocket(StorageCollectionWebsocket[DictStorageCollection]):
|
||||
"""Class to expose storage collection management over websocket."""
|
||||
|
|
|
@ -46,6 +46,7 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
|
|||
for pipeline in pipelines:
|
||||
pipeline_ids.append((await store1.async_create_item(pipeline)).id)
|
||||
assert len(store1.data) == 3
|
||||
assert store1.async_get_preferred_item() == list(store1.data)[0]
|
||||
|
||||
await store1.async_delete_item(pipeline_ids[1])
|
||||
assert len(store1.data) == 2
|
||||
|
@ -58,6 +59,7 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
|
|||
|
||||
assert store1.data is not store2.data
|
||||
assert store1.data == store2.data
|
||||
assert store1.async_get_preferred_item() == store2.async_get_preferred_item()
|
||||
|
||||
|
||||
async def test_loading_datasets_from_storage(
|
||||
|
@ -94,7 +96,8 @@ async def test_loading_datasets_from_storage(
|
|||
"stt_engine": "stt_engine_3",
|
||||
"tts_engine": "tts_engine_3",
|
||||
},
|
||||
]
|
||||
],
|
||||
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -102,3 +105,4 @@ async def test_loading_datasets_from_storage(
|
|||
|
||||
store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
assert len(store.data) == 3
|
||||
assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY"
|
||||
|
|
|
@ -482,31 +482,58 @@ async def test_delete_pipeline(
|
|||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert len(pipeline_store.data) == 1
|
||||
|
||||
pipeline_id = msg["result"]["id"]
|
||||
pipeline_id_1 = msg["result"]["id"]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/delete",
|
||||
"pipeline_id": pipeline_id,
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"tts_engine": "test_tts_engine",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert len(pipeline_store.data) == 0
|
||||
pipeline_id_2 = msg["result"]["id"]
|
||||
|
||||
assert len(pipeline_store.data) == 2
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/delete",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_id": pipeline_id_1,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "not_allowed",
|
||||
"message": f"Item {pipeline_id_1} preferred.",
|
||||
}
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/delete",
|
||||
"pipeline_id": pipeline_id_2,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert len(pipeline_store.data) == 1
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/delete",
|
||||
"pipeline_id": pipeline_id_2,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "not_found",
|
||||
"message": f"Unable to find pipeline_id {pipeline_id}",
|
||||
"message": f"Unable to find pipeline_id {pipeline_id_2}",
|
||||
}
|
||||
|
||||
|
||||
|
@ -520,7 +547,7 @@ async def test_list_pipelines(
|
|||
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == []
|
||||
assert msg["result"] == {"pipelines": [], "preferred_pipeline": None}
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
|
@ -539,16 +566,19 @@ async def test_list_pipelines(
|
|||
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == [
|
||||
{
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"id": ANY,
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"tts_engine": "test_tts_engine",
|
||||
}
|
||||
]
|
||||
assert msg["result"] == {
|
||||
"pipelines": [
|
||||
{
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"id": ANY,
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"tts_engine": "test_tts_engine",
|
||||
}
|
||||
],
|
||||
"preferred_pipeline": ANY,
|
||||
}
|
||||
|
||||
|
||||
async def test_update_pipeline(
|
||||
|
@ -606,9 +636,9 @@ async def test_update_pipeline(
|
|||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"conversation_engine": "new_conversation_engine",
|
||||
"id": pipeline_id,
|
||||
"language": "new_language",
|
||||
"name": "new_name",
|
||||
"id": pipeline_id,
|
||||
"stt_engine": "new_stt_engine",
|
||||
"tts_engine": "new_tts_engine",
|
||||
}
|
||||
|
@ -623,3 +653,65 @@ async def test_update_pipeline(
|
|||
stt_engine="new_stt_engine",
|
||||
tts_engine="new_tts_engine",
|
||||
)
|
||||
|
||||
|
||||
async def test_set_preferred_pipeline(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
) -> None:
|
||||
"""Test updating the preferred pipeline."""
|
||||
client = await hass_ws_client(hass)
|
||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"tts_engine": "test_tts_engine",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
pipeline_id_1 = msg["result"]["id"]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"tts_engine": "test_tts_engine",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
pipeline_id_2 = msg["result"]["id"]
|
||||
|
||||
assert pipeline_store.async_get_preferred_item() == pipeline_id_1
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/set_preferred",
|
||||
"pipeline_id": pipeline_id_2,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
assert pipeline_store.async_get_preferred_item() == pipeline_id_2
|
||||
|
||||
|
||||
async def test_set_preferred_pipeline_wrong_id(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
) -> None:
|
||||
"""Test updating the preferred pipeline."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{"type": "assist_pipeline/pipeline/set_preferred", "pipeline_id": "don_t_exist"}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["error"]["code"] == "not_found"
|
||||
|
|
|
@ -436,7 +436,7 @@ async def test_storage_collection_websocket(
|
|||
store = storage.Store(hass, 1, "test-data")
|
||||
coll = MockStorageCollection(store)
|
||||
changes = track_changes(coll)
|
||||
collection.StorageCollectionWebsocket(
|
||||
collection.DictStorageCollectionWebsocket(
|
||||
coll,
|
||||
"test_item/collection",
|
||||
"test_item",
|
||||
|
|
Loading…
Reference in New Issue