"""Provide a label registry: a way to label and group anything."""

from __future__ import annotations
|
|
|
|
from collections.abc import Iterable
|
|
import dataclasses
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from typing import Any, Literal, TypedDict
|
|
|
|
from homeassistant.core import Event, HomeAssistant, callback
|
|
from homeassistant.util.dt import utc_from_timestamp, utcnow
|
|
from homeassistant.util.event_type import EventType
|
|
from homeassistant.util.hass_dict import HassKey
|
|
|
|
from .normalized_name_base_registry import (
|
|
NormalizedNameBaseRegistryEntry,
|
|
NormalizedNameBaseRegistryItems,
|
|
)
|
|
from .registry import BaseRegistry
|
|
from .singleton import singleton
|
|
from .storage import Store
|
|
from .typing import UNDEFINED, UndefinedType
|
|
|
|
# Storage location and schema version of the persisted label registry.
STORAGE_KEY = "core.label_registry"
STORAGE_VERSION_MAJOR = 1
# Minor version 2 added the created_at / modified_at fields to stored labels.
STORAGE_VERSION_MINOR = 2

# Key under which the LabelRegistry singleton is kept in hass.data.
DATA_REGISTRY: HassKey[LabelRegistry] = HassKey("label_registry")

# Fired whenever a label is created, removed or updated.
EVENT_LABEL_REGISTRY_UPDATED: EventType[EventLabelRegistryUpdatedData] = EventType(
    "label_registry_updated"
)
|
|
|
|
|
|
class _LabelStoreData(TypedDict):
    """Data type for an individual label. Used in LabelRegistryStoreData.

    This is the serialized form of a LabelEntry as written to storage.
    """

    color: str | None
    description: str | None
    icon: str | None
    label_id: str
    name: str
    # Timestamps serialized with datetime.isoformat() and parsed back with
    # datetime.fromisoformat() when the registry is loaded.
    created_at: str
    modified_at: str
|
|
|
|
|
|
class LabelRegistryStoreData(TypedDict):
    """Store data type for LabelRegistry.

    Top-level shape of the data persisted by LabelRegistryStore.
    """

    # One serialized entry per label, in registry iteration order.
    labels: list[_LabelStoreData]
|
|
|
|
|
|
class EventLabelRegistryUpdatedData(TypedDict):
    """Event data for when the label registry is updated.

    Sent with EVENT_LABEL_REGISTRY_UPDATED by LabelRegistry.
    """

    # What happened to the label; deletions are reported as "remove".
    action: Literal["create", "remove", "update"]
    # ID of the label that was created, removed or updated.
    label_id: str
|
|
|
|
|
|
# Typed Event carrying EventLabelRegistryUpdatedData (e.g. for annotating
# listeners of EVENT_LABEL_REGISTRY_UPDATED).
type EventLabelRegistryUpdated = Event[EventLabelRegistryUpdatedData]
|
|
|
|
|
|
@dataclass(slots=True, frozen=True, kw_only=True)
class LabelEntry(NormalizedNameBaseRegistryEntry):
    """Label Registry Entry.

    Entries are frozen: LabelRegistry.async_update swaps in a new entry built
    with dataclasses.replace instead of mutating one. The name,
    normalized_name, created_at and modified_at attributes are not declared
    here and come from NormalizedNameBaseRegistryEntry.
    """

    # Registry key; async_create derives it from the label's name.
    label_id: str
    description: str | None = None
    color: str | None = None
    icon: str | None = None
|
|
|
|
|
|
class LabelRegistryStore(Store[LabelRegistryStoreData]):
    """Persistent store for label registry data, with schema migration."""

    async def _async_migrate_func(
        self,
        old_major_version: int,
        old_minor_version: int,
        old_data: dict[str, list[dict[str, Any]]],
    ) -> LabelRegistryStoreData:
        """Bring stored label data written by an older schema up to date.

        The stored label dicts are modified in place and returned.

        Raises ValueError if the data was written by a newer major version
        than this code understands.
        """
        if old_major_version > STORAGE_VERSION_MAJOR:
            raise ValueError("Can't migrate to future version")

        if old_major_version == 1 and old_minor_version < 2:
            # Schema 1.2 introduced created_at/modified_at. The real times of
            # pre-existing labels are unknown, so both are backfilled with the
            # UTC Unix epoch.
            epoch_iso = utc_from_timestamp(0).isoformat()
            for stored_label in old_data["labels"]:
                stored_label["created_at"] = epoch_iso
                stored_label["modified_at"] = epoch_iso

        return old_data  # type: ignore[return-value]
|
|
|
|
|
|
class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
    """Class to hold a registry of labels.

    Labels are persisted through a LabelRegistryStore. async_create,
    async_delete and async_update verify they run on the event loop thread,
    schedule a save and fire EVENT_LABEL_REGISTRY_UPDATED. async_update does
    none of this when the call would not change anything.
    """

    # All labels keyed by label_id. The container also resolves labels by
    # name (get_by_name) and derives new IDs from names (generate_id_from_name).
    labels: NormalizedNameBaseRegistryItems[LabelEntry]
    # The plain dict behind ``labels`` (``labels.data``). async_get_label
    # reads it directly to skip the UserDict overhead.
    _label_data: dict[str, LabelEntry]

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the label registry.

        ``labels`` is not populated here; async_load must run first.
        """
        self.hass = hass
        self._store = LabelRegistryStore(
            hass,
            STORAGE_VERSION_MAJOR,
            STORAGE_KEY,
            atomic_writes=True,
            minor_version=STORAGE_VERSION_MINOR,
        )

    @callback
    def async_get_label(self, label_id: str) -> LabelEntry | None:
        """Get label by ID, or None if there is no such label.

        We retrieve the LabelEntry from the underlying dict to avoid
        the overhead of the UserDict __getitem__.
        """
        return self._label_data.get(label_id)

    @callback
    def async_get_label_by_name(self, name: str) -> LabelEntry | None:
        """Get label by name, or None if no label matches.

        Matching is delegated to ``labels.get_by_name`` (presumably comparing
        normalized names; see NormalizedNameBaseRegistryItems).
        """
        return self.labels.get_by_name(name)

    @callback
    def async_list_labels(self) -> Iterable[LabelEntry]:
        """Get all labels."""
        return self.labels.values()

    def _generate_id(self, name: str) -> str:
        """Generate a label ID derived from *name*."""
        return self.labels.generate_id_from_name(name)

    @callback
    def async_create(
        self,
        name: str,
        *,
        color: str | None = None,
        icon: str | None = None,
        description: str | None = None,
    ) -> LabelEntry:
        """Create a new label and return it.

        Raises ValueError if *name* is already used by another label.
        """
        self.hass.verify_event_loop_thread("label_registry.async_create")

        if label := self.async_get_label_by_name(name):
            raise ValueError(
                f"The name {name} ({label.normalized_name}) is already in use"
            )

        label = LabelEntry(
            color=color,
            description=description,
            icon=icon,
            label_id=self._generate_id(name),
            name=name,
        )
        label_id = label.label_id
        self.labels[label_id] = label
        self.async_schedule_save()

        self.hass.bus.async_fire_internal(
            EVENT_LABEL_REGISTRY_UPDATED,
            EventLabelRegistryUpdatedData(action="create", label_id=label_id),
        )
        return label

    @callback
    def async_delete(self, label_id: str) -> None:
        """Delete the label with the given ID.

        Removing an unknown label_id raises from the underlying container
        (presumably KeyError).
        """
        self.hass.verify_event_loop_thread("label_registry.async_delete")
        del self.labels[label_id]
        self.hass.bus.async_fire_internal(
            EVENT_LABEL_REGISTRY_UPDATED,
            EventLabelRegistryUpdatedData(
                action="remove",
                label_id=label_id,
            ),
        )
        self.async_schedule_save()

    @callback
    def async_update(
        self,
        label_id: str,
        *,
        color: str | None | UndefinedType = UNDEFINED,
        description: str | None | UndefinedType = UNDEFINED,
        icon: str | None | UndefinedType = UNDEFINED,
        name: str | UndefinedType = UNDEFINED,
    ) -> LabelEntry:
        """Update the name, color, description and/or icon of a label.

        An argument is applied only if it is not UNDEFINED and differs from
        the current value. If nothing changes, the existing entry is returned
        as-is: no save is scheduled and no event is fired. Otherwise the
        entry is replaced, with modified_at set to the current UTC time.

        Raises ValueError if the new name is already used by another label.
        """
        old = self.labels[label_id]
        changes: dict[str, Any] = {
            attr_name: value
            for attr_name, value in (
                ("color", color),
                ("description", description),
                ("icon", icon),
            )
            if value is not UNDEFINED and getattr(old, attr_name) != value
        }

        if name is not UNDEFINED and name != old.name:
            # Same uniqueness rule as async_create. A match on this very label
            # (a new name normalizing to its current one) is allowed.
            existing = self.async_get_label_by_name(name)
            if existing is not None and existing.label_id != label_id:
                raise ValueError(
                    f"The name {name} ({existing.normalized_name}) is already in use"
                )
            changes["name"] = name

        if not changes:
            return old

        changes["modified_at"] = utcnow()

        self.hass.verify_event_loop_thread("label_registry.async_update")
        # Entries are frozen dataclasses, so store a modified copy.
        new = self.labels[label_id] = dataclasses.replace(old, **changes)

        self.async_schedule_save()
        self.hass.bus.async_fire_internal(
            EVENT_LABEL_REGISTRY_UPDATED,
            EventLabelRegistryUpdatedData(
                action="update",
                label_id=label_id,
            ),
        )

        return new

    async def async_load(self) -> None:
        """Load the label registry from storage.

        With no stored data yet, the registry starts out empty.
        """
        data = await self._store.async_load()
        labels = NormalizedNameBaseRegistryItems[LabelEntry]()

        if data is not None:
            for label in data["labels"]:
                labels[label["label_id"]] = LabelEntry(
                    color=label["color"],
                    description=label["description"],
                    icon=label["icon"],
                    label_id=label["label_id"],
                    name=label["name"],
                    created_at=datetime.fromisoformat(label["created_at"]),
                    modified_at=datetime.fromisoformat(label["modified_at"]),
                )

        self.labels = labels
        self._label_data = labels.data

    @callback
    def _data_to_save(self) -> LabelRegistryStoreData:
        """Return data of label registry to store in a file.

        Timestamps are serialized as ISO 8601 strings, which async_load
        parses back with datetime.fromisoformat.
        """
        return {
            "labels": [
                {
                    "color": entry.color,
                    "description": entry.description,
                    "icon": entry.icon,
                    "label_id": entry.label_id,
                    "name": entry.name,
                    "created_at": entry.created_at.isoformat(),
                    "modified_at": entry.modified_at.isoformat(),
                }
                for entry in self.labels.values()
            ]
        }
|
|
|
|
|
|
@callback
@singleton(DATA_REGISTRY)
def async_get(hass: HomeAssistant) -> LabelRegistry:
    """Return the label registry for *hass*.

    The @singleton(DATA_REGISTRY) wrapper presumably caches the result under
    DATA_REGISTRY, so the registry is only constructed on the first call.
    """
    registry = LabelRegistry(hass)
    return registry
|
|
|
|
|
|
async def async_load(hass: HomeAssistant) -> None:
    """Create the label registry and load its contents from storage."""
    # Guard against loading twice: nothing may be stored under DATA_REGISTRY
    # yet, so this check must run before async_get is called.
    assert DATA_REGISTRY not in hass.data
    registry = async_get(hass)
    await registry.async_load()
|