diff --git a/homeassistant/components/application_credentials/__init__.py b/homeassistant/components/application_credentials/__init__.py index f57d6c82b7f..f1471f29666 100644 --- a/homeassistant/components/application_credentials/__init__.py +++ b/homeassistant/components/application_credentials/__init__.py @@ -149,7 +149,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: await storage_collection.async_load() hass.data[DOMAIN][DATA_STORAGE] = storage_collection - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS ).async_setup(hass) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index c2eff90530b..5e194b2a31f 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -10,12 +10,15 @@ from typing import Any import voluptuous as vol from homeassistant.backports.enum import StrEnum -from homeassistant.components import conversation, media_source, stt, tts +from homeassistant.components import conversation, media_source, stt, tts, websocket_api from homeassistant.components.tts.media_source import ( generate_media_source_id as tts_generate_media_source_id, ) from homeassistant.core import Context, HomeAssistant, callback from homeassistant.helpers.collection import ( + CollectionError, + ItemNotFound, + SerializedStorageCollection, StorageCollection, StorageCollectionWebsocket, ) @@ -533,11 +536,39 @@ class PipelineInput: await asyncio.gather(*prepare_tasks) -class PipelineStorageCollection(StorageCollection[Pipeline]): +class PipelinePreferred(CollectionError): + """Raised when attempting to delete the preferred pipelen.""" + + def __init__(self, item_id: str) -> None: + """Initialize pipeline preferred error.""" + super().__init__(f"Item {item_id} preferred.") + self.item_id = item_id + + +class SerializedPipelineStorageCollection(SerializedStorageCollection): + """Serialized pipeline storage collection.""" + + preferred_item: str | None + + +class PipelineStorageCollection( + StorageCollection[Pipeline, SerializedPipelineStorageCollection] +): """Pipeline storage collection.""" CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) + _preferred_item: str | None = None + + async def _async_load_data(self) -> SerializedPipelineStorageCollection | None: + """Load the data.""" + if not (data := await super()._async_load_data()): + return data + + self._preferred_item = data["preferred_item"] + + return data + async def _process_create_data(self, data: dict) -> dict: """Validate the config is valid.""" # We don't need to validate, the WS API has already validated @@ -554,6 +585,8 @@ class PipelineStorageCollection(StorageCollection[Pipeline]): def _create_item(self, item_id: str, data: dict) -> Pipeline: """Create an item from validated config.""" + if self._preferred_item is None: + self._preferred_item = item_id return Pipeline(id=item_id, **data) def _deserialize_item(self, data: dict) -> Pipeline: @@ -561,9 +594,107 @@ class PipelineStorageCollection(StorageCollection[Pipeline]): return Pipeline(**data) def _serialize_item(self, item_id: str, item: Pipeline) -> dict: - """Return the serialized representation of an item.""" + """Return the serialized representation of an item for storing.""" return item.to_json() + async def async_delete_item(self, item_id: str) -> None: + """Delete item.""" + if self._preferred_item == item_id: + raise PipelinePreferred(item_id) + await super().async_delete_item(item_id) + + @callback + def async_get_preferred_item(self) -> str | None: + """Get the id of the preferred item.""" + return self._preferred_item + + @callback + def async_set_preferred_item(self, item_id: str) -> None: + """Set the preferred pipeline.""" + if item_id not in self.data: + raise ItemNotFound(item_id) + self._preferred_item = item_id + self._async_schedule_save() + + @callback + def _data_to_save(self) -> SerializedPipelineStorageCollection: + """Return JSON-compatible date for storing to file.""" + base_data = super()._base_data_to_save() + return { + "items": base_data["items"], + "preferred_item": self._preferred_item, + } + + +class PipelineStorageCollectionWebsocket( + StorageCollectionWebsocket[PipelineStorageCollection] +): + """Class to expose storage collection management over websocket.""" + + @callback + def async_setup( + self, + hass: HomeAssistant, + *, + create_list: bool = True, + create_create: bool = True, + ) -> None: + """Set up the websocket commands.""" + super().async_setup(hass, create_list=create_list, create_create=create_create) + + websocket_api.async_register_command( + hass, + f"{self.api_prefix}/set_preferred", + websocket_api.require_admin( + websocket_api.async_response(self.ws_set_preferred_item) + ), + websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( + { + vol.Required("type"): f"{self.api_prefix}/set_preferred", + vol.Required(self.item_id_key): str, + } + ), + ) + + def ws_list_item( + self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict + ) -> None: + """List items.""" + connection.send_result( + msg["id"], + { + "pipelines": self.storage_collection.async_items(), + "preferred_pipeline": self.storage_collection.async_get_preferred_item(), + }, + ) + + async def ws_delete_item( + self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict + ) -> None: + """Delete an item.""" + try: + await super().ws_delete_item(hass, connection, msg) + except PipelinePreferred as exc: + connection.send_error( + msg["id"], websocket_api.const.ERR_NOT_ALLOWED, str(exc) + ) + + async def ws_set_preferred_item( + self, + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], + ) -> None: + """Set the preferred item.""" + try: + self.storage_collection.async_set_preferred_item(msg[self.item_id_key]) + except ItemNotFound: + connection.send_error( + msg["id"], websocket_api.const.ERR_NOT_FOUND, "unknown item" + ) + return + connection.send_result(msg["id"]) + async def async_setup_pipeline_store(hass): """Set up the pipeline storage collection.""" @@ -571,7 +702,7 @@ async def async_setup_pipeline_store(hass): Store(hass, STORAGE_VERSION, STORAGE_KEY) ) await pipeline_store.async_load() - StorageCollectionWebsocket( + PipelineStorageCollectionWebsocket( pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS ).async_setup(hass) hass.data[DOMAIN] = pipeline_store diff --git a/homeassistant/components/counter/__init__.py b/homeassistant/components/counter/__init__.py index db739f3f0db..768491f6085 100644 --- a/homeassistant/components/counter/__init__.py +++ b/homeassistant/components/counter/__init__.py @@ -117,7 +117,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ) await storage_collection.async_load() - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS ).async_setup(hass) diff --git a/homeassistant/components/image_upload/__init__.py b/homeassistant/components/image_upload/__init__.py index 452b23d27be..17c40cfc875 100644 --- a/homeassistant/components/image_upload/__init__.py +++ b/homeassistant/components/image_upload/__init__.py @@ -44,7 +44,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: image_dir = pathlib.Path(hass.config.path("image")) hass.data[DOMAIN] = storage_collection = ImageStorageCollection(hass, image_dir) await storage_collection.async_load() - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( storage_collection, "image", "image", diff --git a/homeassistant/components/input_boolean/__init__.py b/homeassistant/components/input_boolean/__init__.py index 49dcf731f7b..33cb4b9e576 100644 --- a/homeassistant/components/input_boolean/__init__.py +++ b/homeassistant/components/input_boolean/__init__.py @@ -121,7 +121,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ) await storage_collection.async_load() - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS ).async_setup(hass) diff --git a/homeassistant/components/input_button/__init__.py b/homeassistant/components/input_button/__init__.py index d9693a208c1..8a1f0785435 100644 --- a/homeassistant/components/input_button/__init__.py +++ b/homeassistant/components/input_button/__init__.py @@ -106,7 +106,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ) await storage_collection.async_load() - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS ).async_setup(hass) diff --git a/homeassistant/components/input_datetime/__init__.py b/homeassistant/components/input_datetime/__init__.py index c927b71c77e..c51c0fdd67c 100644 --- a/homeassistant/components/input_datetime/__init__.py +++ b/homeassistant/components/input_datetime/__init__.py @@ -159,7 +159,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ) await storage_collection.async_load() - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS ).async_setup(hass) diff --git a/homeassistant/components/input_number/__init__.py b/homeassistant/components/input_number/__init__.py index 9f77bb0a828..061b388ace5 100644 --- a/homeassistant/components/input_number/__init__.py +++ b/homeassistant/components/input_number/__init__.py @@ -136,7 +136,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ) await storage_collection.async_load() - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS ).async_setup(hass) diff --git a/homeassistant/components/input_select/__init__.py b/homeassistant/components/input_select/__init__.py index b7a026352d0..186ab84fb81 100644 --- a/homeassistant/components/input_select/__init__.py +++ b/homeassistant/components/input_select/__init__.py @@ -167,7 +167,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ) await storage_collection.async_load() - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS ).async_setup(hass) diff --git a/homeassistant/components/input_text/__init__.py b/homeassistant/components/input_text/__init__.py index f246779b64c..efd58e38e72 100644 --- a/homeassistant/components/input_text/__init__.py +++ b/homeassistant/components/input_text/__init__.py @@ -136,7 +136,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ) await storage_collection.async_load() - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS ).async_setup(hass) diff --git a/homeassistant/components/lovelace/__init__.py b/homeassistant/components/lovelace/__init__.py index f880f83d766..1412aa085c8 100644 --- a/homeassistant/components/lovelace/__init__.py +++ b/homeassistant/components/lovelace/__init__.py @@ -119,7 +119,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: resource_collection = resources.ResourceStorageCollection(hass, default_config) - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( resource_collection, "lovelace/resources", "resource", @@ -198,7 +198,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: dashboards_collection.async_add_listener(storage_dashboard_changed) await dashboards_collection.async_load() - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( dashboards_collection, "lovelace/dashboards", "dashboard", diff --git a/homeassistant/components/person/__init__.py b/homeassistant/components/person/__init__.py index ba11250f83e..c1373ce1df9 100644 --- a/homeassistant/components/person/__init__.py +++ b/homeassistant/components/person/__init__.py @@ -354,7 +354,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: hass.data[DOMAIN] = (yaml_collection, storage_collection, entity_component) - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS ).async_setup(hass, create_list=False) diff --git a/homeassistant/components/schedule/__init__.py b/homeassistant/components/schedule/__init__.py index 3e91e8ab86d..2e5fcc27715 100644 --- a/homeassistant/components/schedule/__init__.py +++ b/homeassistant/components/schedule/__init__.py @@ -21,9 +21,9 @@ from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.helpers.collection import ( CollectionEntity, DictStorageCollection, + DictStorageCollectionWebsocket, IDManager, SerializedStorageCollection, - StorageCollectionWebsocket, YamlCollection, sync_entity_lifecycle, ) @@ -182,7 +182,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ) await storage_collection.async_load() - StorageCollectionWebsocket( + DictStorageCollectionWebsocket( storage_collection, DOMAIN, DOMAIN, diff --git a/homeassistant/components/tag/__init__.py b/homeassistant/components/tag/__init__.py index 363c28cc3f8..cd0dd00afe5 100644 --- a/homeassistant/components/tag/__init__.py +++ b/homeassistant/components/tag/__init__.py @@ -98,7 +98,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: id_manager, ) await storage_collection.async_load() - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS ).async_setup(hass) diff --git a/homeassistant/components/timer/__init__.py b/homeassistant/components/timer/__init__.py index 214d95c72e5..7cb2c10425e 100644 --- a/homeassistant/components/timer/__init__.py +++ b/homeassistant/components/timer/__init__.py @@ -130,7 +130,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ) await storage_collection.async_load() - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS ).async_setup(hass) diff --git a/homeassistant/components/zone/__init__.py b/homeassistant/components/zone/__init__.py index cad92a2978c..2133c8550da 100644 --- a/homeassistant/components/zone/__init__.py +++ b/homeassistant/components/zone/__init__.py @@ -209,7 +209,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: await storage_collection.async_load() - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS ).async_setup(hass) diff --git a/homeassistant/helpers/collection.py b/homeassistant/helpers/collection.py index 29151221a89..2526a210d70 100644 --- a/homeassistant/helpers/collection.py +++ b/homeassistant/helpers/collection.py @@ -32,7 +32,9 @@ CHANGE_ADDED = "added" CHANGE_UPDATED = "updated" CHANGE_REMOVED = "removed" -_T = TypeVar("_T") +_ItemT = TypeVar("_ItemT") +_StoreT = TypeVar("_StoreT", bound="SerializedStorageCollection") +_StorageCollectionT = TypeVar("_StorageCollectionT", bound="StorageCollection") @dataclass(slots=True) @@ -123,20 +125,20 @@ class CollectionEntity(Entity): """Handle updated configuration.""" -class ObservableCollection(ABC, Generic[_T]): +class ObservableCollection(ABC, Generic[_ItemT]): """Base collection type that can be observed.""" def __init__(self, id_manager: IDManager | None) -> None: """Initialize the base collection.""" self.id_manager = id_manager or IDManager() - self.data: dict[str, _T] = {} + self.data: dict[str, _ItemT] = {} self.listeners: list[ChangeListener] = [] self.change_set_listeners: list[ChangeSetListener] = [] self.id_manager.add_collection(self.data) @callback - def async_items(self) -> list[_T]: + def async_items(self) -> list[_ItemT]: """Return list of items in collection.""" return list(self.data.values()) @@ -226,12 +228,12 @@ class SerializedStorageCollection(TypedDict): items: list[dict[str, Any]] -class StorageCollection(ObservableCollection[_T], ABC): +class StorageCollection(ObservableCollection[_ItemT], Generic[_ItemT, _StoreT]): """Offer a CRUD interface on top of JSON storage.""" def __init__( self, - store: Store[SerializedStorageCollection], + store: Store[_StoreT], id_manager: IDManager | None = None, ) -> None: """Initialize the storage collection.""" @@ -250,16 +252,14 @@ class StorageCollection(ObservableCollection[_T], ABC): """Home Assistant object.""" return self.store.hass - async def _async_load_data(self) -> SerializedStorageCollection | None: + async def _async_load_data(self) -> _StoreT | None: """Load the data.""" return await self.store.async_load() async def async_load(self) -> None: """Load the storage Manager.""" - raw_storage = await self._async_load_data() - - if raw_storage is None: - raw_storage = {"items": []} + if not (raw_storage := await self._async_load_data()): + return for item in raw_storage["items"]: self.data[item[CONF_ID]] = self._deserialize_item(item) @@ -281,25 +281,25 @@ class StorageCollection(ObservableCollection[_T], ABC): """Suggest an ID based on the config.""" @abstractmethod - async def _update_data(self, item: _T, update_data: dict) -> _T: + async def _update_data(self, item: _ItemT, update_data: dict) -> _ItemT: """Return a new updated item.""" @abstractmethod - def _create_item(self, item_id: str, data: dict) -> _T: + def _create_item(self, item_id: str, data: dict) -> _ItemT: """Create an item from validated config.""" @abstractmethod - def _deserialize_item(self, data: dict) -> _T: + def _deserialize_item(self, data: dict) -> _ItemT: """Create an item from its serialized representation.""" @abstractmethod - def _serialize_item(self, item_id: str, item: _T) -> dict: - """Return the serialized representation of an item. + def _serialize_item(self, item_id: str, item: _ItemT) -> dict: + """Return the serialized representation of an item for storing. The serialized representation must include the item_id in the "id" key. """ - async def async_create_item(self, data: dict) -> _T: + async def async_create_item(self, data: dict) -> _ItemT: """Create a new item.""" validated_data = await self._process_create_data(data) item_id = self.id_manager.generate_id(self._get_suggested_id(validated_data)) @@ -309,7 +309,7 @@ class StorageCollection(ObservableCollection[_T], ABC): await self.notify_changes([CollectionChangeSet(CHANGE_ADDED, item_id, item)]) return item - async def async_update_item(self, item_id: str, updates: dict) -> _T: + async def async_update_item(self, item_id: str, updates: dict) -> _ItemT: """Update item.""" if item_id not in self.data: raise ItemNotFound(item_id) @@ -346,8 +346,8 @@ class StorageCollection(ObservableCollection[_T], ABC): self.store.async_delay_save(self._data_to_save, SAVE_DELAY) @callback - def _data_to_save(self) -> SerializedStorageCollection: - """Return JSON-compatible date for storing to file.""" + def _base_data_to_save(self) -> SerializedStorageCollection: + """Return JSON-compatible data for storing to file.""" return { "items": [ self._serialize_item(item_id, item) @@ -355,8 +355,13 @@ class StorageCollection(ObservableCollection[_T], ABC): ] } + @abstractmethod + @callback + def _data_to_save(self) -> _StoreT: + """Return JSON-compatible date for storing to file.""" -class DictStorageCollection(StorageCollection[dict]): + +class DictStorageCollection(StorageCollection[dict, SerializedStorageCollection]): """A specialized StorageCollection where the items are untyped dicts.""" def _create_item(self, item_id: str, data: dict) -> dict: @@ -368,9 +373,14 @@ class DictStorageCollection(StorageCollection[dict]): return data def _serialize_item(self, item_id: str, item: dict) -> dict: - """Return the serialized representation of an item.""" + """Return the serialized representation of an item for storing.""" return item + @callback + def _data_to_save(self) -> SerializedStorageCollection: + """Return JSON-compatible date for storing to file.""" + return self._base_data_to_save() + class IDLessCollection(YamlCollection): """A collection without IDs.""" @@ -477,12 +487,12 @@ def sync_entity_lifecycle( collection.async_add_change_set_listener(_collection_changed) -class StorageCollectionWebsocket: +class StorageCollectionWebsocket(Generic[_StorageCollectionT]): """Class to expose storage collection management over websocket.""" def __init__( self, - storage_collection: StorageCollection, + storage_collection: _StorageCollectionT, api_prefix: str, model_name: str, create_schema: dict, @@ -635,3 +645,7 @@ class StorageCollectionWebsocket: ) connection.send_result(msg["id"]) + + +class DictStorageCollectionWebsocket(StorageCollectionWebsocket[DictStorageCollection]): + """Class to expose storage collection management over websocket.""" diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index 317d395a1e1..1898e3d2237 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -46,6 +46,7 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None: for pipeline in pipelines: pipeline_ids.append((await store1.async_create_item(pipeline)).id) assert len(store1.data) == 3 + assert store1.async_get_preferred_item() == list(store1.data)[0] await store1.async_delete_item(pipeline_ids[1]) assert len(store1.data) == 2 @@ -58,6 +59,7 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None: assert store1.data is not store2.data assert store1.data == store2.data + assert store1.async_get_preferred_item() == store2.async_get_preferred_item() async def test_loading_datasets_from_storage( @@ -94,7 +96,8 @@ async def test_loading_datasets_from_storage( "stt_engine": "stt_engine_3", "tts_engine": "tts_engine_3", }, - ] + ], + "preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY", }, } @@ -102,3 +105,4 @@ async def test_loading_datasets_from_storage( store: PipelineStorageCollection = hass.data[DOMAIN] assert len(store.data) == 3 + assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY" diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 0a6872cf4ec..fad2221fe99 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -482,31 +482,58 @@ async def test_delete_pipeline( ) msg = await client.receive_json() assert msg["success"] - assert len(pipeline_store.data) == 1 - - pipeline_id = msg["result"]["id"] + pipeline_id_1 = msg["result"]["id"] await client.send_json_auto_id( { - "type": "assist_pipeline/pipeline/delete", - "pipeline_id": pipeline_id, + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "test_conversation_engine", + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "tts_engine": "test_tts_engine", } ) msg = await client.receive_json() assert msg["success"] - assert len(pipeline_store.data) == 0 + pipeline_id_2 = msg["result"]["id"] + + assert len(pipeline_store.data) == 2 await client.send_json_auto_id( { "type": "assist_pipeline/pipeline/delete", - "pipeline_id": pipeline_id, + "pipeline_id": pipeline_id_1, + } + ) + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"] == { + "code": "not_allowed", + "message": f"Item {pipeline_id_1} preferred.", + } + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/delete", + "pipeline_id": pipeline_id_2, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert len(pipeline_store.data) == 1 + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/delete", + "pipeline_id": pipeline_id_2, } ) msg = await client.receive_json() assert not msg["success"] assert msg["error"] == { "code": "not_found", - "message": f"Unable to find pipeline_id {pipeline_id}", + "message": f"Unable to find pipeline_id {pipeline_id_2}", } @@ -520,7 +547,7 @@ async def test_list_pipelines( await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"}) msg = await client.receive_json() assert msg["success"] - assert msg["result"] == [] + assert msg["result"] == {"pipelines": [], "preferred_pipeline": None} await client.send_json_auto_id( { @@ -539,16 +566,19 @@ async def test_list_pipelines( await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"}) msg = await client.receive_json() assert msg["success"] - assert msg["result"] == [ - { - "conversation_engine": "test_conversation_engine", - "id": ANY, - "language": "test_language", - "name": "test_name", - "stt_engine": "test_stt_engine", - "tts_engine": "test_tts_engine", - } - ] + assert msg["result"] == { + "pipelines": [ + { + "conversation_engine": "test_conversation_engine", + "id": ANY, + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "tts_engine": "test_tts_engine", + } + ], + "preferred_pipeline": ANY, + } async def test_update_pipeline( @@ -606,9 +636,9 @@ async def test_update_pipeline( assert msg["success"] assert msg["result"] == { "conversation_engine": "new_conversation_engine", + "id": pipeline_id, "language": "new_language", "name": "new_name", - "id": pipeline_id, "stt_engine": "new_stt_engine", "tts_engine": "new_tts_engine", } @@ -623,3 +653,65 @@ async def test_update_pipeline( stt_engine="new_stt_engine", tts_engine="new_tts_engine", ) + + +async def test_set_preferred_pipeline( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components +) -> None: + """Test updating the preferred pipeline.""" + client = await hass_ws_client(hass) + pipeline_store: PipelineStorageCollection = hass.data[DOMAIN] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "test_conversation_engine", + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "tts_engine": "test_tts_engine", + } + ) + msg = await client.receive_json() + assert msg["success"] + pipeline_id_1 = msg["result"]["id"] + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "test_conversation_engine", + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "tts_engine": "test_tts_engine", + } + ) + msg = await client.receive_json() + assert msg["success"] + pipeline_id_2 = msg["result"]["id"] + + assert pipeline_store.async_get_preferred_item() == pipeline_id_1 + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/set_preferred", + "pipeline_id": pipeline_id_2, + } + ) + msg = await client.receive_json() + assert msg["success"] + + assert pipeline_store.async_get_preferred_item() == pipeline_id_2 + + +async def test_set_preferred_pipeline_wrong_id( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components +) -> None: + """Test updating the preferred pipeline.""" + client = await hass_ws_client(hass) + + await client.send_json_auto_id( + {"type": "assist_pipeline/pipeline/set_preferred", "pipeline_id": "don_t_exist"} + ) + msg = await client.receive_json() + assert msg["error"]["code"] == "not_found" diff --git a/tests/helpers/test_collection.py b/tests/helpers/test_collection.py index 52c7f899a6a..7969e02ab2f 100644 --- a/tests/helpers/test_collection.py +++ b/tests/helpers/test_collection.py @@ -436,7 +436,7 @@ async def test_storage_collection_websocket( store = storage.Store(hass, 1, "test-data") coll = MockStorageCollection(store) changes = track_changes(coll) - collection.StorageCollectionWebsocket( + collection.DictStorageCollectionWebsocket( coll, "test_item/collection", "test_item",