Support marking an assist pipeline as preferred (#91418)

* Support marking an assist pipeline as preferred

* Adjust

* Revert unneeded change

* Send preferred pipeline id in pipeline list

* Don't use property functions for the preferred pipeline
pull/91463/head
Erik Montnemery 2023-04-15 16:05:46 +02:00 committed by GitHub
parent 714ec3f023
commit 8f8a398631
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 308 additions and 67 deletions

View File

@ -149,7 +149,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
await storage_collection.async_load()
hass.data[DOMAIN][DATA_STORAGE] = storage_collection
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS
).async_setup(hass)

View File

@ -10,12 +10,15 @@ from typing import Any
import voluptuous as vol
from homeassistant.backports.enum import StrEnum
from homeassistant.components import conversation, media_source, stt, tts
from homeassistant.components import conversation, media_source, stt, tts, websocket_api
from homeassistant.components.tts.media_source import (
generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers.collection import (
CollectionError,
ItemNotFound,
SerializedStorageCollection,
StorageCollection,
StorageCollectionWebsocket,
)
@ -533,11 +536,39 @@ class PipelineInput:
await asyncio.gather(*prepare_tasks)
class PipelineStorageCollection(StorageCollection[Pipeline]):
class PipelinePreferred(CollectionError):
"""Raised when attempting to delete the preferred pipelen."""
def __init__(self, item_id: str) -> None:
"""Initialize pipeline preferred error."""
super().__init__(f"Item {item_id} preferred.")
self.item_id = item_id
class SerializedPipelineStorageCollection(SerializedStorageCollection):
"""Serialized pipeline storage collection."""
preferred_item: str | None
class PipelineStorageCollection(
StorageCollection[Pipeline, SerializedPipelineStorageCollection]
):
"""Pipeline storage collection."""
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
_preferred_item: str | None = None
async def _async_load_data(self) -> SerializedPipelineStorageCollection | None:
"""Load the data."""
if not (data := await super()._async_load_data()):
return data
self._preferred_item = data["preferred_item"]
return data
async def _process_create_data(self, data: dict) -> dict:
"""Validate the config is valid."""
# We don't need to validate, the WS API has already validated
@ -554,6 +585,8 @@ class PipelineStorageCollection(StorageCollection[Pipeline]):
def _create_item(self, item_id: str, data: dict) -> Pipeline:
"""Create an item from validated config."""
if self._preferred_item is None:
self._preferred_item = item_id
return Pipeline(id=item_id, **data)
def _deserialize_item(self, data: dict) -> Pipeline:
@ -561,9 +594,107 @@ class PipelineStorageCollection(StorageCollection[Pipeline]):
return Pipeline(**data)
def _serialize_item(self, item_id: str, item: Pipeline) -> dict:
"""Return the serialized representation of an item."""
"""Return the serialized representation of an item for storing."""
return item.to_json()
async def async_delete_item(self, item_id: str) -> None:
"""Delete item."""
if self._preferred_item == item_id:
raise PipelinePreferred(item_id)
await super().async_delete_item(item_id)
@callback
def async_get_preferred_item(self) -> str | None:
"""Get the id of the preferred item."""
return self._preferred_item
@callback
def async_set_preferred_item(self, item_id: str) -> None:
"""Set the preferred pipeline."""
if item_id not in self.data:
raise ItemNotFound(item_id)
self._preferred_item = item_id
self._async_schedule_save()
@callback
def _data_to_save(self) -> SerializedPipelineStorageCollection:
"""Return JSON-compatible date for storing to file."""
base_data = super()._base_data_to_save()
return {
"items": base_data["items"],
"preferred_item": self._preferred_item,
}
class PipelineStorageCollectionWebsocket(
StorageCollectionWebsocket[PipelineStorageCollection]
):
"""Class to expose storage collection management over websocket."""
@callback
def async_setup(
self,
hass: HomeAssistant,
*,
create_list: bool = True,
create_create: bool = True,
) -> None:
"""Set up the websocket commands."""
super().async_setup(hass, create_list=create_list, create_create=create_create)
websocket_api.async_register_command(
hass,
f"{self.api_prefix}/set_preferred",
websocket_api.require_admin(
websocket_api.async_response(self.ws_set_preferred_item)
),
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"): f"{self.api_prefix}/set_preferred",
vol.Required(self.item_id_key): str,
}
),
)
def ws_list_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List items."""
connection.send_result(
msg["id"],
{
"pipelines": self.storage_collection.async_items(),
"preferred_pipeline": self.storage_collection.async_get_preferred_item(),
},
)
async def ws_delete_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Delete an item."""
try:
await super().ws_delete_item(hass, connection, msg)
except PipelinePreferred as exc:
connection.send_error(
msg["id"], websocket_api.const.ERR_NOT_ALLOWED, str(exc)
)
async def ws_set_preferred_item(
self,
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Set the preferred item."""
try:
self.storage_collection.async_set_preferred_item(msg[self.item_id_key])
except ItemNotFound:
connection.send_error(
msg["id"], websocket_api.const.ERR_NOT_FOUND, "unknown item"
)
return
connection.send_result(msg["id"])
async def async_setup_pipeline_store(hass):
"""Set up the pipeline storage collection."""
@ -571,7 +702,7 @@ async def async_setup_pipeline_store(hass):
Store(hass, STORAGE_VERSION, STORAGE_KEY)
)
await pipeline_store.async_load()
StorageCollectionWebsocket(
PipelineStorageCollectionWebsocket(
pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS
).async_setup(hass)
hass.data[DOMAIN] = pipeline_store

View File

@ -117,7 +117,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
)
await storage_collection.async_load()
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
).async_setup(hass)

View File

@ -44,7 +44,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
image_dir = pathlib.Path(hass.config.path("image"))
hass.data[DOMAIN] = storage_collection = ImageStorageCollection(hass, image_dir)
await storage_collection.async_load()
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
storage_collection,
"image",
"image",

View File

@ -121,7 +121,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
)
await storage_collection.async_load()
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
).async_setup(hass)

View File

@ -106,7 +106,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
)
await storage_collection.async_load()
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
).async_setup(hass)

View File

@ -159,7 +159,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
)
await storage_collection.async_load()
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
).async_setup(hass)

View File

@ -136,7 +136,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
)
await storage_collection.async_load()
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
).async_setup(hass)

View File

@ -167,7 +167,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
)
await storage_collection.async_load()
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
).async_setup(hass)

View File

@ -136,7 +136,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
)
await storage_collection.async_load()
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
).async_setup(hass)

View File

@ -119,7 +119,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
resource_collection = resources.ResourceStorageCollection(hass, default_config)
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
resource_collection,
"lovelace/resources",
"resource",
@ -198,7 +198,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
dashboards_collection.async_add_listener(storage_dashboard_changed)
await dashboards_collection.async_load()
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
dashboards_collection,
"lovelace/dashboards",
"dashboard",

View File

@ -354,7 +354,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass.data[DOMAIN] = (yaml_collection, storage_collection, entity_component)
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS
).async_setup(hass, create_list=False)

View File

@ -21,9 +21,9 @@ from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.helpers.collection import (
CollectionEntity,
DictStorageCollection,
DictStorageCollectionWebsocket,
IDManager,
SerializedStorageCollection,
StorageCollectionWebsocket,
YamlCollection,
sync_entity_lifecycle,
)
@ -182,7 +182,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
)
await storage_collection.async_load()
StorageCollectionWebsocket(
DictStorageCollectionWebsocket(
storage_collection,
DOMAIN,
DOMAIN,

View File

@ -98,7 +98,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
id_manager,
)
await storage_collection.async_load()
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS
).async_setup(hass)

View File

@ -130,7 +130,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
)
await storage_collection.async_load()
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
).async_setup(hass)

View File

@ -209,7 +209,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
await storage_collection.async_load()
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS
).async_setup(hass)

View File

@ -32,7 +32,9 @@ CHANGE_ADDED = "added"
CHANGE_UPDATED = "updated"
CHANGE_REMOVED = "removed"
_T = TypeVar("_T")
_ItemT = TypeVar("_ItemT")
_StoreT = TypeVar("_StoreT", bound="SerializedStorageCollection")
_StorageCollectionT = TypeVar("_StorageCollectionT", bound="StorageCollection")
@dataclass(slots=True)
@ -123,20 +125,20 @@ class CollectionEntity(Entity):
"""Handle updated configuration."""
class ObservableCollection(ABC, Generic[_T]):
class ObservableCollection(ABC, Generic[_ItemT]):
"""Base collection type that can be observed."""
def __init__(self, id_manager: IDManager | None) -> None:
"""Initialize the base collection."""
self.id_manager = id_manager or IDManager()
self.data: dict[str, _T] = {}
self.data: dict[str, _ItemT] = {}
self.listeners: list[ChangeListener] = []
self.change_set_listeners: list[ChangeSetListener] = []
self.id_manager.add_collection(self.data)
@callback
def async_items(self) -> list[_T]:
def async_items(self) -> list[_ItemT]:
"""Return list of items in collection."""
return list(self.data.values())
@ -226,12 +228,12 @@ class SerializedStorageCollection(TypedDict):
items: list[dict[str, Any]]
class StorageCollection(ObservableCollection[_T], ABC):
class StorageCollection(ObservableCollection[_ItemT], Generic[_ItemT, _StoreT]):
"""Offer a CRUD interface on top of JSON storage."""
def __init__(
self,
store: Store[SerializedStorageCollection],
store: Store[_StoreT],
id_manager: IDManager | None = None,
) -> None:
"""Initialize the storage collection."""
@ -250,16 +252,14 @@ class StorageCollection(ObservableCollection[_T], ABC):
"""Home Assistant object."""
return self.store.hass
async def _async_load_data(self) -> SerializedStorageCollection | None:
async def _async_load_data(self) -> _StoreT | None:
"""Load the data."""
return await self.store.async_load()
async def async_load(self) -> None:
"""Load the storage Manager."""
raw_storage = await self._async_load_data()
if raw_storage is None:
raw_storage = {"items": []}
if not (raw_storage := await self._async_load_data()):
return
for item in raw_storage["items"]:
self.data[item[CONF_ID]] = self._deserialize_item(item)
@ -281,25 +281,25 @@ class StorageCollection(ObservableCollection[_T], ABC):
"""Suggest an ID based on the config."""
@abstractmethod
async def _update_data(self, item: _T, update_data: dict) -> _T:
async def _update_data(self, item: _ItemT, update_data: dict) -> _ItemT:
"""Return a new updated item."""
@abstractmethod
def _create_item(self, item_id: str, data: dict) -> _T:
def _create_item(self, item_id: str, data: dict) -> _ItemT:
"""Create an item from validated config."""
@abstractmethod
def _deserialize_item(self, data: dict) -> _T:
def _deserialize_item(self, data: dict) -> _ItemT:
"""Create an item from its serialized representation."""
@abstractmethod
def _serialize_item(self, item_id: str, item: _T) -> dict:
"""Return the serialized representation of an item.
def _serialize_item(self, item_id: str, item: _ItemT) -> dict:
"""Return the serialized representation of an item for storing.
The serialized representation must include the item_id in the "id" key.
"""
async def async_create_item(self, data: dict) -> _T:
async def async_create_item(self, data: dict) -> _ItemT:
"""Create a new item."""
validated_data = await self._process_create_data(data)
item_id = self.id_manager.generate_id(self._get_suggested_id(validated_data))
@ -309,7 +309,7 @@ class StorageCollection(ObservableCollection[_T], ABC):
await self.notify_changes([CollectionChangeSet(CHANGE_ADDED, item_id, item)])
return item
async def async_update_item(self, item_id: str, updates: dict) -> _T:
async def async_update_item(self, item_id: str, updates: dict) -> _ItemT:
"""Update item."""
if item_id not in self.data:
raise ItemNotFound(item_id)
@ -346,8 +346,8 @@ class StorageCollection(ObservableCollection[_T], ABC):
self.store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback
def _data_to_save(self) -> SerializedStorageCollection:
"""Return JSON-compatible date for storing to file."""
def _base_data_to_save(self) -> SerializedStorageCollection:
"""Return JSON-compatible data for storing to file."""
return {
"items": [
self._serialize_item(item_id, item)
@ -355,8 +355,13 @@ class StorageCollection(ObservableCollection[_T], ABC):
]
}
@abstractmethod
@callback
def _data_to_save(self) -> _StoreT:
"""Return JSON-compatible date for storing to file."""
class DictStorageCollection(StorageCollection[dict]):
class DictStorageCollection(StorageCollection[dict, SerializedStorageCollection]):
"""A specialized StorageCollection where the items are untyped dicts."""
def _create_item(self, item_id: str, data: dict) -> dict:
@ -368,9 +373,14 @@ class DictStorageCollection(StorageCollection[dict]):
return data
def _serialize_item(self, item_id: str, item: dict) -> dict:
"""Return the serialized representation of an item."""
"""Return the serialized representation of an item for storing."""
return item
@callback
def _data_to_save(self) -> SerializedStorageCollection:
"""Return JSON-compatible date for storing to file."""
return self._base_data_to_save()
class IDLessCollection(YamlCollection):
"""A collection without IDs."""
@ -477,12 +487,12 @@ def sync_entity_lifecycle(
collection.async_add_change_set_listener(_collection_changed)
class StorageCollectionWebsocket:
class StorageCollectionWebsocket(Generic[_StorageCollectionT]):
"""Class to expose storage collection management over websocket."""
def __init__(
self,
storage_collection: StorageCollection,
storage_collection: _StorageCollectionT,
api_prefix: str,
model_name: str,
create_schema: dict,
@ -635,3 +645,7 @@ class StorageCollectionWebsocket:
)
connection.send_result(msg["id"])
class DictStorageCollectionWebsocket(StorageCollectionWebsocket[DictStorageCollection]):
"""Class to expose storage collection management over websocket."""

View File

@ -46,6 +46,7 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
for pipeline in pipelines:
pipeline_ids.append((await store1.async_create_item(pipeline)).id)
assert len(store1.data) == 3
assert store1.async_get_preferred_item() == list(store1.data)[0]
await store1.async_delete_item(pipeline_ids[1])
assert len(store1.data) == 2
@ -58,6 +59,7 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
assert store1.data is not store2.data
assert store1.data == store2.data
assert store1.async_get_preferred_item() == store2.async_get_preferred_item()
async def test_loading_datasets_from_storage(
@ -94,7 +96,8 @@ async def test_loading_datasets_from_storage(
"stt_engine": "stt_engine_3",
"tts_engine": "tts_engine_3",
},
]
],
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
},
}
@ -102,3 +105,4 @@ async def test_loading_datasets_from_storage(
store: PipelineStorageCollection = hass.data[DOMAIN]
assert len(store.data) == 3
assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY"

View File

@ -482,31 +482,58 @@ async def test_delete_pipeline(
)
msg = await client.receive_json()
assert msg["success"]
assert len(pipeline_store.data) == 1
pipeline_id = msg["result"]["id"]
pipeline_id_1 = msg["result"]["id"]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/delete",
"pipeline_id": pipeline_id,
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "test_conversation_engine",
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"tts_engine": "test_tts_engine",
}
)
msg = await client.receive_json()
assert msg["success"]
assert len(pipeline_store.data) == 0
pipeline_id_2 = msg["result"]["id"]
assert len(pipeline_store.data) == 2
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/delete",
"pipeline_id": pipeline_id,
"pipeline_id": pipeline_id_1,
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "not_allowed",
"message": f"Item {pipeline_id_1} preferred.",
}
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/delete",
"pipeline_id": pipeline_id_2,
}
)
msg = await client.receive_json()
assert msg["success"]
assert len(pipeline_store.data) == 1
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/delete",
"pipeline_id": pipeline_id_2,
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "not_found",
"message": f"Unable to find pipeline_id {pipeline_id}",
"message": f"Unable to find pipeline_id {pipeline_id_2}",
}
@ -520,7 +547,7 @@ async def test_list_pipelines(
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == []
assert msg["result"] == {"pipelines": [], "preferred_pipeline": None}
await client.send_json_auto_id(
{
@ -539,16 +566,19 @@ async def test_list_pipelines(
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == [
{
"conversation_engine": "test_conversation_engine",
"id": ANY,
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"tts_engine": "test_tts_engine",
}
]
assert msg["result"] == {
"pipelines": [
{
"conversation_engine": "test_conversation_engine",
"id": ANY,
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"tts_engine": "test_tts_engine",
}
],
"preferred_pipeline": ANY,
}
async def test_update_pipeline(
@ -606,9 +636,9 @@ async def test_update_pipeline(
assert msg["success"]
assert msg["result"] == {
"conversation_engine": "new_conversation_engine",
"id": pipeline_id,
"language": "new_language",
"name": "new_name",
"id": pipeline_id,
"stt_engine": "new_stt_engine",
"tts_engine": "new_tts_engine",
}
@ -623,3 +653,65 @@ async def test_update_pipeline(
stt_engine="new_stt_engine",
tts_engine="new_tts_engine",
)
async def test_set_preferred_pipeline(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
"""Test updating the preferred pipeline."""
client = await hass_ws_client(hass)
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "test_conversation_engine",
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"tts_engine": "test_tts_engine",
}
)
msg = await client.receive_json()
assert msg["success"]
pipeline_id_1 = msg["result"]["id"]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "test_conversation_engine",
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"tts_engine": "test_tts_engine",
}
)
msg = await client.receive_json()
assert msg["success"]
pipeline_id_2 = msg["result"]["id"]
assert pipeline_store.async_get_preferred_item() == pipeline_id_1
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/set_preferred",
"pipeline_id": pipeline_id_2,
}
)
msg = await client.receive_json()
assert msg["success"]
assert pipeline_store.async_get_preferred_item() == pipeline_id_2
async def test_set_preferred_pipeline_wrong_id(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
"""Test updating the preferred pipeline."""
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{"type": "assist_pipeline/pipeline/set_preferred", "pipeline_id": "don_t_exist"}
)
msg = await client.receive_json()
assert msg["error"]["code"] == "not_found"

View File

@ -436,7 +436,7 @@ async def test_storage_collection_websocket(
store = storage.Store(hass, 1, "test-data")
coll = MockStorageCollection(store)
changes = track_changes(coll)
collection.StorageCollectionWebsocket(
collection.DictStorageCollectionWebsocket(
coll,
"test_item/collection",
"test_item",