core/homeassistant/helpers/collection.py

457 lines
14 KiB
Python

"""Helper to deal with YAML + storage."""
from abc import ABC, abstractmethod
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional, cast
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.components import websocket_api
from homeassistant.const import CONF_ID
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.storage import Store
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.util import slugify
STORAGE_VERSION = 1
# Delay handed to Store.async_delay_save whenever a storage collection changes.
SAVE_DELAY = 10

# Change types passed as the first argument to every change listener.
CHANGE_ADDED = "added"
CHANGE_UPDATED = "updated"
CHANGE_REMOVED = "removed"

# Signature of a change listener: an awaitable callable that receives
# (change type, item ID, new or removed config).
ChangeListener = Callable[[str, str, dict], Awaitable[None]]
class CollectionError(HomeAssistantError):
    """Common ancestor for every error raised by the collection helpers."""
class ItemNotFound(CollectionError):
    """Error signalling that a requested item ID is not in the collection."""

    def __init__(self, item_id: str):
        """Build the error message and remember the missing ID.

        The ID is kept on the instance as ``item_id``.
        """
        message = f"Item {item_id} not found."
        super().__init__(message)
        self.item_id = item_id
class IDManager:
    """Hand out item IDs that are unique across all registered collections."""

    def __init__(self) -> None:
        """Start with no registered collections."""
        self.collections: List[Dict[str, Any]] = []

    def add_collection(self, collection: Dict[str, Any]) -> None:
        """Register a mapping whose keys count as IDs that are in use."""
        self.collections.append(collection)

    def has_id(self, item_id: str) -> bool:
        """Return True when any registered collection contains the ID as a key."""
        for known in self.collections:
            if item_id in known:
                return True
        return False

    def generate_id(self, suggestion: str) -> str:
        """Return an unused ID derived from the suggestion.

        The slugified suggestion is used as-is when free; otherwise a numeric
        suffix is appended, starting at ``_2`` and counting upwards until an
        unused ID is found.
        """
        base = slugify(suggestion)
        candidate = base
        suffix = 2
        while self.has_id(candidate):
            candidate = f"{base}_{suffix}"
            suffix += 1
        return candidate
class ObservableCollection(ABC):
    """Base class for collections that report item changes to listeners."""

    def __init__(self, logger: logging.Logger, id_manager: Optional[IDManager] = None):
        """Set up empty storage and register it with the ID manager.

        A new IDManager is created when none is supplied.
        """
        if not id_manager:
            id_manager = IDManager()
        self.logger = logger
        self.id_manager = id_manager
        self.data: Dict[str, dict] = {}
        self.listeners: List[ChangeListener] = []
        # The manager is given the dict itself, not a copy.
        id_manager.add_collection(self.data)

    @callback
    def async_items(self) -> List[dict]:
        """Return the current items as a new list."""
        return [*self.data.values()]

    @callback
    def async_add_listener(self, listener: ChangeListener) -> None:
        """Register a change listener.

        The listener is awaited with (change_type, item_id, updated_config).
        """
        self.listeners.append(listener)

    async def notify_change(self, change_type: str, item_id: str, item: dict) -> None:
        """Log the change, then await every listener in registration order."""
        self.logger.debug("%s %s: %s", change_type, item_id, item)
        for handler in self.listeners:
            await handler(change_type, item_id, item)
class YamlCollection(ObservableCollection):
    """Collection whose contents are replaced wholesale from static data."""

    async def async_load(self, data: List[dict]) -> None:
        """Replace the collection contents with the given items.

        Items whose ID was already present are reported as updated, new ones
        as added, and previously loaded IDs missing from ``data`` are removed.
        An item whose ID is already known to the ID manager, but was not part
        of this collection before the load, is skipped with a warning.
        """
        stale_ids = set(self.data)

        for config in data:
            key = config[CONF_ID]

            if key in stale_ids:
                stale_ids.remove(key)
                change = CHANGE_UPDATED
            elif not self.id_manager.has_id(key):
                change = CHANGE_ADDED
            else:
                self.logger.warning("Duplicate ID '%s' detected, skipping", key)
                continue

            self.data[key] = config
            await self.notify_change(change, key, config)

        # Whatever is left was loaded before but is absent from the new data.
        for key in stale_ids:
            removed = self.data.pop(key)
            await self.notify_change(CHANGE_REMOVED, key, removed)
class StorageCollection(ObservableCollection):
    """Collection with create/update/delete operations persisted via a Store."""

    def __init__(
        self,
        store: Store,
        logger: logging.Logger,
        id_manager: Optional[IDManager] = None,
    ):
        """Initialise the base collection and remember the backing store."""
        super().__init__(logger, id_manager)
        self.store = store

    @property
    def hass(self) -> HomeAssistant:
        """Return the hass object held by the store."""
        return self.store.hass

    async def _async_load_data(self) -> Optional[dict]:
        """Read the raw payload from the store."""
        stored = await self.store.async_load()
        return cast(Optional[dict], stored)

    async def async_load(self) -> None:
        """Fill the collection from storage, reporting each item as added.

        When the store returns nothing, the collection is treated as empty.
        """
        raw_storage = await self._async_load_data()
        stored_items = [] if raw_storage is None else raw_storage["items"]

        for item in stored_items:
            item_id = item[CONF_ID]
            self.data[item_id] = item
            await self.notify_change(CHANGE_ADDED, item_id, item)

    @abstractmethod
    async def _process_create_data(self, data: dict) -> dict:
        """Validate the creation data and return the item to store."""

    @callback
    @abstractmethod
    def _get_suggested_id(self, info: dict) -> str:
        """Return an ID suggestion derived from the item config."""

    @abstractmethod
    async def _update_data(self, data: dict, update_data: dict) -> dict:
        """Return a new item combining the current data with the updates."""

    async def async_create_item(self, data: dict) -> dict:
        """Create an item, assign it a generated ID and schedule a save."""
        item = await self._process_create_data(data)
        new_id = self.id_manager.generate_id(self._get_suggested_id(item))
        item[CONF_ID] = new_id
        self.data[new_id] = item
        self._async_schedule_save()
        await self.notify_change(CHANGE_ADDED, new_id, item)
        return item

    async def async_update_item(self, item_id: str, updates: dict) -> dict:
        """Apply updates to an existing item and schedule a save.

        Raises ItemNotFound for an unknown ID and ValueError when the updates
        try to change the ID itself.
        """
        if item_id not in self.data:
            raise ItemNotFound(item_id)

        if CONF_ID in updates:
            raise ValueError("Cannot update ID")

        replacement = await self._update_data(self.data[item_id], updates)
        self.data[item_id] = replacement
        self._async_schedule_save()
        await self.notify_change(CHANGE_UPDATED, item_id, replacement)

        # Looked up again after listeners ran, rather than returning the local.
        return self.data[item_id]

    async def async_delete_item(self, item_id: str) -> None:
        """Remove an item and schedule a save.

        Raises ItemNotFound for an unknown ID.
        """
        if item_id not in self.data:
            raise ItemNotFound(item_id)

        removed = self.data.pop(item_id)
        self._async_schedule_save()
        await self.notify_change(CHANGE_REMOVED, item_id, removed)

    @callback
    def _async_schedule_save(self) -> None:
        """Ask the store to save the collection after SAVE_DELAY."""
        self.store.async_delay_save(self._data_to_save, SAVE_DELAY)

    @callback
    def _data_to_save(self) -> dict:
        """Return the payload to store: all items under the "items" key."""
        return {"items": [*self.data.values()]}
class IDLessCollection(ObservableCollection):
    """Collection for items that carry no ID of their own."""

    # Number used for the most recent synthetic ID. The first increment in
    # async_load creates an instance attribute that shadows this default.
    counter = 0

    async def async_load(self, data: List[dict]) -> None:
        """Replace the contents with the given items.

        Every existing item is reported as removed, then each new item gets a
        synthetic ``fakeid-<n>`` ID and is reported as added.
        """
        # Snapshot first, so the loop does not iterate the live dict.
        previous = list(self.data.items())
        for old_id, old_item in previous:
            await self.notify_change(CHANGE_REMOVED, old_id, old_item)

        self.data.clear()

        for new_item in data:
            self.counter += 1
            new_id = f"fakeid-{self.counter}"
            self.data[new_id] = new_item
            await self.notify_change(CHANGE_ADDED, new_id, new_item)
@callback
def attach_entity_component_collection(
    entity_component: EntityComponent,
    collection: ObservableCollection,
    create_entity: Callable[[dict], Entity],
) -> None:
    """Keep an entity component in step with a collection.

    Registers a listener on the collection that adds an entity for every
    added item, removes it when the item is removed, and otherwise forwards
    the new config to the existing entity.
    """
    tracked = {}

    async def _collection_changed(change_type: str, item_id: str, config: dict) -> None:
        """React to a single collection change."""
        if change_type == CHANGE_ADDED:
            new_entity = create_entity(config)
            await entity_component.async_add_entities([new_entity])
            tracked[item_id] = new_entity
        elif change_type == CHANGE_REMOVED:
            gone = tracked.pop(item_id)
            await gone.async_remove()
        else:
            # Any other change type is handled as CHANGE_UPDATED.
            await tracked[item_id].async_update_config(config)  # type: ignore

    collection.async_add_listener(_collection_changed)
@callback
def attach_entity_registry_cleaner(
    hass: HomeAssistantType,
    domain: str,
    platform: str,
    collection: ObservableCollection,
) -> None:
    """Register a listener that drops registry entries of removed items."""

    async def _collection_changed(change_type: str, item_id: str, config: Dict) -> None:
        """Remove the matching entity registry entry when an item is removed."""
        if change_type == CHANGE_REMOVED:
            registry = await entity_registry.async_get_registry(hass)
            entity_id = registry.async_get_entity_id(domain, platform, item_id)
            # Nothing to clean up when the lookup finds no entity ID.
            if entity_id is not None:
                registry.async_remove(entity_id)

    collection.async_add_listener(_collection_changed)
class StorageCollectionWebsocket:
    """Expose list/create/update/delete of a storage collection over websocket."""

    def __init__(
        self,
        storage_collection: StorageCollection,
        api_prefix: str,
        model_name: str,
        create_schema: dict,
        update_schema: dict,
    ):
        """Initialize a websocket CRUD.

        ``api_prefix`` is the command prefix (``<prefix>/list`` etc.) and must
        not end in a slash. ``model_name`` determines the name of the item ID
        field in update and delete messages. ``create_schema`` and
        ``update_schema`` are merged into the schemas of the create and update
        commands.
        """
        self.storage_collection = storage_collection
        self.api_prefix = api_prefix
        self.model_name = model_name
        self.create_schema = create_schema
        self.update_schema = update_schema

        assert self.api_prefix[-1] != "/", "API prefix should not end in /"

    @property
    def item_id_key(self) -> str:
        """Return the message key that carries the item ID."""
        return f"{self.model_name}_id"

    @callback
    def async_setup(self, hass: HomeAssistant, *, create_list: bool = True) -> None:
        """Register the websocket commands.

        The list command is only registered when ``create_list`` is true.
        Create, update and delete are always registered and wrapped in
        ``require_admin``.
        """
        if create_list:
            websocket_api.async_register_command(
                hass,
                f"{self.api_prefix}/list",
                self.ws_list_item,
                websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
                    {vol.Required("type"): f"{self.api_prefix}/list"}
                ),
            )

        websocket_api.async_register_command(
            hass,
            f"{self.api_prefix}/create",
            websocket_api.require_admin(
                websocket_api.async_response(self.ws_create_item)
            ),
            websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
                {
                    **self.create_schema,
                    vol.Required("type"): f"{self.api_prefix}/create",
                }
            ),
        )

        websocket_api.async_register_command(
            hass,
            f"{self.api_prefix}/update",
            websocket_api.require_admin(
                websocket_api.async_response(self.ws_update_item)
            ),
            websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
                {
                    **self.update_schema,
                    vol.Required("type"): f"{self.api_prefix}/update",
                    vol.Required(self.item_id_key): str,
                }
            ),
        )

        websocket_api.async_register_command(
            hass,
            f"{self.api_prefix}/delete",
            websocket_api.require_admin(
                websocket_api.async_response(self.ws_delete_item)
            ),
            websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
                {
                    vol.Required("type"): f"{self.api_prefix}/delete",
                    vol.Required(self.item_id_key): str,
                }
            ),
        )

    def ws_list_item(
        self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
    ) -> None:
        """List items."""
        connection.send_result(msg["id"], self.storage_collection.async_items())

    async def ws_create_item(
        self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
    ) -> None:
        """Create an item from the message payload.

        Everything in the message except ``id`` and ``type`` is passed to the
        collection. Validation failures are reported as an invalid-format
        error instead of a result.
        """
        data = dict(msg)
        msg_id = data.pop("id")
        data.pop("type")

        try:
            item = await self.storage_collection.async_create_item(data)
        except vol.Invalid as err:
            connection.send_error(
                msg_id,
                websocket_api.const.ERR_INVALID_FORMAT,
                humanize_error(data, err),
            )
            return
        except ValueError as err:
            connection.send_error(
                msg_id, websocket_api.const.ERR_INVALID_FORMAT, str(err)
            )
            return

        connection.send_result(msg_id, item)

    async def ws_update_item(
        self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
    ) -> None:
        """Update an item with the message payload.

        Everything in the message except ``id``, ``type`` and the item ID key
        is passed to the collection as the update. An unknown item yields a
        not-found error; validation failures yield an invalid-format error.
        """
        data = dict(msg)
        msg_id = data.pop("id")
        item_id = data.pop(self.item_id_key)
        data.pop("type")

        try:
            item = await self.storage_collection.async_update_item(item_id, data)
        except ItemNotFound:
            connection.send_error(
                msg_id,
                websocket_api.const.ERR_NOT_FOUND,
                f"Unable to find {self.item_id_key} {item_id}",
            )
            return
        except vol.Invalid as err:
            connection.send_error(
                msg_id,
                websocket_api.const.ERR_INVALID_FORMAT,
                humanize_error(data, err),
            )
            return
        except ValueError as err:
            connection.send_error(
                msg_id, websocket_api.const.ERR_INVALID_FORMAT, str(err)
            )
            return

        connection.send_result(msg_id, item)

    async def ws_delete_item(
        self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
    ) -> None:
        """Delete an item.

        Sends a not-found error for an unknown item, otherwise an empty result.
        """
        item_id = msg[self.item_id_key]

        try:
            await self.storage_collection.async_delete_item(item_id)
        except ItemNotFound:
            connection.send_error(
                msg["id"],
                websocket_api.const.ERR_NOT_FOUND,
                f"Unable to find {self.item_id_key} {item_id}",
            )
            # Stop here: the request has been answered with an error, so it
            # must not also receive a success result for the same message id.
            return

        connection.send_result(msg["id"])