457 lines
14 KiB
Python
457 lines
14 KiB
Python
"""Helper to deal with YAML + storage."""
|
|
from abc import ABC, abstractmethod
|
|
import logging
|
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, cast
|
|
|
|
import voluptuous as vol
|
|
from voluptuous.humanize import humanize_error
|
|
|
|
from homeassistant.components import websocket_api
|
|
from homeassistant.const import CONF_ID
|
|
from homeassistant.core import HomeAssistant, callback
|
|
from homeassistant.exceptions import HomeAssistantError
|
|
from homeassistant.helpers import entity_registry
|
|
from homeassistant.helpers.entity import Entity
|
|
from homeassistant.helpers.entity_component import EntityComponent
|
|
from homeassistant.helpers.storage import Store
|
|
from homeassistant.helpers.typing import HomeAssistantType
|
|
from homeassistant.util import slugify
|
|
|
|
STORAGE_VERSION = 1
|
|
SAVE_DELAY = 10
|
|
|
|
CHANGE_ADDED = "added"
|
|
CHANGE_UPDATED = "updated"
|
|
CHANGE_REMOVED = "removed"
|
|
|
|
|
|
ChangeListener = Callable[
|
|
[
|
|
# Change type
|
|
str,
|
|
# Item ID
|
|
str,
|
|
# New or removed config
|
|
dict,
|
|
],
|
|
Awaitable[None],
|
|
]
|
|
|
|
|
|
class CollectionError(HomeAssistantError):
|
|
"""Base class for collection related errors."""
|
|
|
|
|
|
class ItemNotFound(CollectionError):
|
|
"""Raised when an item is not found."""
|
|
|
|
def __init__(self, item_id: str):
|
|
"""Initialize item not found error."""
|
|
super().__init__(f"Item {item_id} not found.")
|
|
self.item_id = item_id
|
|
|
|
|
|
class IDManager:
|
|
"""Keep track of IDs across different collections."""
|
|
|
|
def __init__(self) -> None:
|
|
"""Initiate the ID manager."""
|
|
self.collections: List[Dict[str, Any]] = []
|
|
|
|
def add_collection(self, collection: Dict[str, Any]) -> None:
|
|
"""Add a collection to check for ID usage."""
|
|
self.collections.append(collection)
|
|
|
|
def has_id(self, item_id: str) -> bool:
|
|
"""Test if the ID exists."""
|
|
return any(item_id in collection for collection in self.collections)
|
|
|
|
def generate_id(self, suggestion: str) -> str:
|
|
"""Generate an ID."""
|
|
base = slugify(suggestion)
|
|
proposal = base
|
|
attempt = 1
|
|
|
|
while self.has_id(proposal):
|
|
attempt += 1
|
|
proposal = f"{base}_{attempt}"
|
|
|
|
return proposal
|
|
|
|
|
|
class ObservableCollection(ABC):
|
|
"""Base collection type that can be observed."""
|
|
|
|
def __init__(self, logger: logging.Logger, id_manager: Optional[IDManager] = None):
|
|
"""Initialize the base collection."""
|
|
self.logger = logger
|
|
self.id_manager = id_manager or IDManager()
|
|
self.data: Dict[str, dict] = {}
|
|
self.listeners: List[ChangeListener] = []
|
|
|
|
self.id_manager.add_collection(self.data)
|
|
|
|
@callback
|
|
def async_items(self) -> List[dict]:
|
|
"""Return list of items in collection."""
|
|
return list(self.data.values())
|
|
|
|
@callback
|
|
def async_add_listener(self, listener: ChangeListener) -> None:
|
|
"""Add a listener.
|
|
|
|
Will be called with (change_type, item_id, updated_config).
|
|
"""
|
|
self.listeners.append(listener)
|
|
|
|
async def notify_change(self, change_type: str, item_id: str, item: dict) -> None:
|
|
"""Notify listeners of a change."""
|
|
self.logger.debug("%s %s: %s", change_type, item_id, item)
|
|
for listener in self.listeners:
|
|
await listener(change_type, item_id, item)
|
|
|
|
|
|
class YamlCollection(ObservableCollection):
|
|
"""Offer a collection based on static data."""
|
|
|
|
async def async_load(self, data: List[dict]) -> None:
|
|
"""Load the YAML collection. Overrides existing data."""
|
|
old_ids = set(self.data)
|
|
|
|
for item in data:
|
|
item_id = item[CONF_ID]
|
|
|
|
if item_id in old_ids:
|
|
old_ids.remove(item_id)
|
|
event = CHANGE_UPDATED
|
|
elif self.id_manager.has_id(item_id):
|
|
self.logger.warning("Duplicate ID '%s' detected, skipping", item_id)
|
|
continue
|
|
else:
|
|
event = CHANGE_ADDED
|
|
|
|
self.data[item_id] = item
|
|
await self.notify_change(event, item_id, item)
|
|
|
|
for item_id in old_ids:
|
|
|
|
await self.notify_change(CHANGE_REMOVED, item_id, self.data.pop(item_id))
|
|
|
|
|
|
class StorageCollection(ObservableCollection):
|
|
"""Offer a CRUD interface on top of JSON storage."""
|
|
|
|
def __init__(
|
|
self,
|
|
store: Store,
|
|
logger: logging.Logger,
|
|
id_manager: Optional[IDManager] = None,
|
|
):
|
|
"""Initialize the storage collection."""
|
|
super().__init__(logger, id_manager)
|
|
self.store = store
|
|
|
|
@property
|
|
def hass(self) -> HomeAssistant:
|
|
"""Home Assistant object."""
|
|
return self.store.hass
|
|
|
|
async def _async_load_data(self) -> Optional[dict]:
|
|
"""Load the data."""
|
|
return cast(Optional[dict], await self.store.async_load())
|
|
|
|
async def async_load(self) -> None:
|
|
"""Load the storage Manager."""
|
|
raw_storage = await self._async_load_data()
|
|
|
|
if raw_storage is None:
|
|
raw_storage = {"items": []}
|
|
|
|
for item in raw_storage["items"]:
|
|
self.data[item[CONF_ID]] = item
|
|
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
|
|
|
|
@abstractmethod
|
|
async def _process_create_data(self, data: dict) -> dict:
|
|
"""Validate the config is valid."""
|
|
|
|
@callback
|
|
@abstractmethod
|
|
def _get_suggested_id(self, info: dict) -> str:
|
|
"""Suggest an ID based on the config."""
|
|
|
|
@abstractmethod
|
|
async def _update_data(self, data: dict, update_data: dict) -> dict:
|
|
"""Return a new updated data object."""
|
|
|
|
async def async_create_item(self, data: dict) -> dict:
|
|
"""Create a new item."""
|
|
item = await self._process_create_data(data)
|
|
item[CONF_ID] = self.id_manager.generate_id(self._get_suggested_id(item))
|
|
self.data[item[CONF_ID]] = item
|
|
self._async_schedule_save()
|
|
await self.notify_change(CHANGE_ADDED, item[CONF_ID], item)
|
|
return item
|
|
|
|
async def async_update_item(self, item_id: str, updates: dict) -> dict:
|
|
"""Update item."""
|
|
if item_id not in self.data:
|
|
raise ItemNotFound(item_id)
|
|
|
|
if CONF_ID in updates:
|
|
raise ValueError("Cannot update ID")
|
|
|
|
current = self.data[item_id]
|
|
|
|
updated = await self._update_data(current, updates)
|
|
|
|
self.data[item_id] = updated
|
|
self._async_schedule_save()
|
|
|
|
await self.notify_change(CHANGE_UPDATED, item_id, updated)
|
|
|
|
return self.data[item_id]
|
|
|
|
async def async_delete_item(self, item_id: str) -> None:
|
|
"""Delete item."""
|
|
if item_id not in self.data:
|
|
raise ItemNotFound(item_id)
|
|
|
|
item = self.data.pop(item_id)
|
|
self._async_schedule_save()
|
|
|
|
await self.notify_change(CHANGE_REMOVED, item_id, item)
|
|
|
|
@callback
|
|
def _async_schedule_save(self) -> None:
|
|
"""Schedule saving the area registry."""
|
|
self.store.async_delay_save(self._data_to_save, SAVE_DELAY)
|
|
|
|
@callback
|
|
def _data_to_save(self) -> dict:
|
|
"""Return data of area registry to store in a file."""
|
|
return {"items": list(self.data.values())}
|
|
|
|
|
|
class IDLessCollection(ObservableCollection):
|
|
"""A collection without IDs."""
|
|
|
|
counter = 0
|
|
|
|
async def async_load(self, data: List[dict]) -> None:
|
|
"""Load the collection. Overrides existing data."""
|
|
for item_id, item in list(self.data.items()):
|
|
await self.notify_change(CHANGE_REMOVED, item_id, item)
|
|
|
|
self.data.clear()
|
|
|
|
for item in data:
|
|
self.counter += 1
|
|
item_id = f"fakeid-{self.counter}"
|
|
|
|
self.data[item_id] = item
|
|
await self.notify_change(CHANGE_ADDED, item_id, item)
|
|
|
|
|
|
@callback
|
|
def attach_entity_component_collection(
|
|
entity_component: EntityComponent,
|
|
collection: ObservableCollection,
|
|
create_entity: Callable[[dict], Entity],
|
|
) -> None:
|
|
"""Map a collection to an entity component."""
|
|
entities = {}
|
|
|
|
async def _collection_changed(change_type: str, item_id: str, config: dict) -> None:
|
|
"""Handle a collection change."""
|
|
if change_type == CHANGE_ADDED:
|
|
entity = create_entity(config)
|
|
await entity_component.async_add_entities([entity])
|
|
entities[item_id] = entity
|
|
return
|
|
|
|
if change_type == CHANGE_REMOVED:
|
|
entity = entities.pop(item_id)
|
|
await entity.async_remove()
|
|
return
|
|
|
|
# CHANGE_UPDATED
|
|
await entities[item_id].async_update_config(config) # type: ignore
|
|
|
|
collection.async_add_listener(_collection_changed)
|
|
|
|
|
|
@callback
|
|
def attach_entity_registry_cleaner(
|
|
hass: HomeAssistantType,
|
|
domain: str,
|
|
platform: str,
|
|
collection: ObservableCollection,
|
|
) -> None:
|
|
"""Attach a listener to clean up entity registry on collection changes."""
|
|
|
|
async def _collection_changed(change_type: str, item_id: str, config: Dict) -> None:
|
|
"""Handle a collection change: clean up entity registry on removals."""
|
|
if change_type != CHANGE_REMOVED:
|
|
return
|
|
|
|
ent_reg = await entity_registry.async_get_registry(hass)
|
|
ent_to_remove = ent_reg.async_get_entity_id(domain, platform, item_id)
|
|
if ent_to_remove is not None:
|
|
ent_reg.async_remove(ent_to_remove)
|
|
|
|
collection.async_add_listener(_collection_changed)
|
|
|
|
|
|
class StorageCollectionWebsocket:
|
|
"""Class to expose storage collection management over websocket."""
|
|
|
|
def __init__(
|
|
self,
|
|
storage_collection: StorageCollection,
|
|
api_prefix: str,
|
|
model_name: str,
|
|
create_schema: dict,
|
|
update_schema: dict,
|
|
):
|
|
"""Initialize a websocket CRUD."""
|
|
self.storage_collection = storage_collection
|
|
self.api_prefix = api_prefix
|
|
self.model_name = model_name
|
|
self.create_schema = create_schema
|
|
self.update_schema = update_schema
|
|
|
|
assert self.api_prefix[-1] != "/", "API prefix should not end in /"
|
|
|
|
@property
|
|
def item_id_key(self) -> str:
|
|
"""Return item ID key."""
|
|
return f"{self.model_name}_id"
|
|
|
|
@callback
|
|
def async_setup(self, hass: HomeAssistant, *, create_list: bool = True) -> None:
|
|
"""Set up the websocket commands."""
|
|
if create_list:
|
|
websocket_api.async_register_command(
|
|
hass,
|
|
f"{self.api_prefix}/list",
|
|
self.ws_list_item,
|
|
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
|
{vol.Required("type"): f"{self.api_prefix}/list"}
|
|
),
|
|
)
|
|
|
|
websocket_api.async_register_command(
|
|
hass,
|
|
f"{self.api_prefix}/create",
|
|
websocket_api.require_admin(
|
|
websocket_api.async_response(self.ws_create_item)
|
|
),
|
|
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
|
{
|
|
**self.create_schema,
|
|
vol.Required("type"): f"{self.api_prefix}/create",
|
|
}
|
|
),
|
|
)
|
|
|
|
websocket_api.async_register_command(
|
|
hass,
|
|
f"{self.api_prefix}/update",
|
|
websocket_api.require_admin(
|
|
websocket_api.async_response(self.ws_update_item)
|
|
),
|
|
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
|
{
|
|
**self.update_schema,
|
|
vol.Required("type"): f"{self.api_prefix}/update",
|
|
vol.Required(self.item_id_key): str,
|
|
}
|
|
),
|
|
)
|
|
|
|
websocket_api.async_register_command(
|
|
hass,
|
|
f"{self.api_prefix}/delete",
|
|
websocket_api.require_admin(
|
|
websocket_api.async_response(self.ws_delete_item)
|
|
),
|
|
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
|
{
|
|
vol.Required("type"): f"{self.api_prefix}/delete",
|
|
vol.Required(self.item_id_key): str,
|
|
}
|
|
),
|
|
)
|
|
|
|
def ws_list_item(
|
|
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
|
) -> None:
|
|
"""List items."""
|
|
connection.send_result(msg["id"], self.storage_collection.async_items())
|
|
|
|
async def ws_create_item(
|
|
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
|
) -> None:
|
|
"""Create a item."""
|
|
try:
|
|
data = dict(msg)
|
|
data.pop("id")
|
|
data.pop("type")
|
|
item = await self.storage_collection.async_create_item(data)
|
|
connection.send_result(msg["id"], item)
|
|
except vol.Invalid as err:
|
|
connection.send_error(
|
|
msg["id"],
|
|
websocket_api.const.ERR_INVALID_FORMAT,
|
|
humanize_error(data, err),
|
|
)
|
|
except ValueError as err:
|
|
connection.send_error(
|
|
msg["id"], websocket_api.const.ERR_INVALID_FORMAT, str(err)
|
|
)
|
|
|
|
async def ws_update_item(
|
|
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
|
) -> None:
|
|
"""Update a item."""
|
|
data = dict(msg)
|
|
msg_id = data.pop("id")
|
|
item_id = data.pop(self.item_id_key)
|
|
data.pop("type")
|
|
|
|
try:
|
|
item = await self.storage_collection.async_update_item(item_id, data)
|
|
connection.send_result(msg_id, item)
|
|
except ItemNotFound:
|
|
connection.send_error(
|
|
msg["id"],
|
|
websocket_api.const.ERR_NOT_FOUND,
|
|
f"Unable to find {self.item_id_key} {item_id}",
|
|
)
|
|
except vol.Invalid as err:
|
|
connection.send_error(
|
|
msg["id"],
|
|
websocket_api.const.ERR_INVALID_FORMAT,
|
|
humanize_error(data, err),
|
|
)
|
|
except ValueError as err:
|
|
connection.send_error(
|
|
msg_id, websocket_api.const.ERR_INVALID_FORMAT, str(err)
|
|
)
|
|
|
|
async def ws_delete_item(
|
|
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
|
) -> None:
|
|
"""Delete a item."""
|
|
try:
|
|
await self.storage_collection.async_delete_item(msg[self.item_id_key])
|
|
except ItemNotFound:
|
|
connection.send_error(
|
|
msg["id"],
|
|
websocket_api.const.ERR_NOT_FOUND,
|
|
f"Unable to find {self.item_id_key} {msg[self.item_id_key]}",
|
|
)
|
|
|
|
connection.send_result(msg["id"])
|