236 lines
6.9 KiB
Python
236 lines
6.9 KiB
Python
"""Provide a way to label and group anything."""
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable
|
|
import dataclasses
|
|
from dataclasses import dataclass
|
|
from typing import Literal, TypedDict, cast
|
|
|
|
from homeassistant.core import HomeAssistant, callback
|
|
from homeassistant.util import slugify
|
|
|
|
from .normalized_name_base_registry import (
|
|
NormalizedNameBaseRegistryEntry,
|
|
NormalizedNameBaseRegistryItems,
|
|
normalize_name,
|
|
)
|
|
from .registry import BaseRegistry
|
|
from .storage import Store
|
|
from .typing import UNDEFINED, EventType, UndefinedType
|
|
|
|
# Key under which the LabelRegistry instance is stored in ``hass.data``.
DATA_REGISTRY = "label_registry"

# Event type fired on the bus whenever a label is created, removed or updated.
EVENT_LABEL_REGISTRY_UPDATED = "label_registry_updated"

# Key and major schema version handed to the Store that persists the registry.
STORAGE_KEY = "core.label_registry"
STORAGE_VERSION_MAJOR = 1
|
|
|
|
|
|
class EventLabelRegistryUpdatedData(TypedDict):
    """Payload of the event fired when the label registry changes."""

    # Which kind of change happened to the label.
    action: Literal["create", "remove", "update"]
    # ID of the label that was created, removed or updated.
    label_id: str
|
|
|
|
|
|
# Event type alias: an EventType whose data is EventLabelRegistryUpdatedData.
EventLabelRegistryUpdated = EventType[EventLabelRegistryUpdatedData]
|
|
|
|
|
|
@dataclass(frozen=True, kw_only=True, slots=True)
class LabelEntry(NormalizedNameBaseRegistryEntry):
    """A single label stored in the label registry.

    Instances are immutable; the registry produces modified copies via
    ``dataclasses.replace`` instead of mutating an entry in place.
    """

    # Unique identifier of the label within the registry.
    label_id: str
    # Optional metadata; ``None`` means the value is not set.
    description: str | None = None
    color: str | None = None
    icon: str | None = None
|
|
|
|
|
|
class LabelRegistry(BaseRegistry):
    """Class to hold a registry of labels."""

    # Both attributes are assigned by async_load(), not by __init__().
    labels: NormalizedNameBaseRegistryItems[LabelEntry]
    # The ``data`` dict of ``labels``; used for direct lookups by label ID.
    _label_data: dict[str, LabelEntry]

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the label registry.

        Only the backing store is set up here; the labels themselves are
        read from storage by ``async_load``.
        """
        self.hass = hass
        self._store: Store[dict[str, list[dict[str, str | None]]]] = Store(
            hass,
            STORAGE_VERSION_MAJOR,
            STORAGE_KEY,
            atomic_writes=True,
        )

    @callback
    def async_get_label(self, label_id: str) -> LabelEntry | None:
        """Get label by ID.

        We retrieve the LabelEntry from the underlying dict to avoid
        the overhead of the UserDict __getitem__.

        Returns ``None`` when no label with that ID exists.
        """
        return self._label_data.get(label_id)

    @callback
    def async_get_label_by_name(self, name: str) -> LabelEntry | None:
        """Get label by name, or ``None`` if no label matches."""
        return self.labels.get_by_name(name)

    @callback
    def async_list_labels(self) -> Iterable[LabelEntry]:
        """Get all labels."""
        return self.labels.values()

    @callback
    def _generate_id(self, name: str) -> str:
        """Generate a label ID from the name that is not yet in the registry.

        The slugified name is used as-is when free; otherwise a numeric
        suffix (``_2``, ``_3``, ...) is appended until an unused ID is found.
        """
        suggestion = suggestion_base = slugify(name)
        tries = 1
        while suggestion in self.labels:
            tries += 1
            suggestion = f"{suggestion_base}_{tries}"
        return suggestion

    @callback
    def async_create(
        self,
        name: str,
        *,
        color: str | None = None,
        icon: str | None = None,
        description: str | None = None,
    ) -> LabelEntry:
        """Create a new label.

        Stores the label, schedules a save and fires an
        ``EVENT_LABEL_REGISTRY_UPDATED`` event with action ``create``.

        Raises:
            ValueError: If a label with the same name already exists.
        """
        if label := self.async_get_label_by_name(name):
            raise ValueError(
                f"The name {name} ({label.normalized_name}) is already in use"
            )

        normalized_name = normalize_name(name)

        label = LabelEntry(
            color=color,
            description=description,
            icon=icon,
            label_id=self._generate_id(name),
            name=name,
            normalized_name=normalized_name,
        )
        label_id = label.label_id
        self.labels[label_id] = label
        self.async_schedule_save()
        self.hass.bus.async_fire(
            EVENT_LABEL_REGISTRY_UPDATED,
            EventLabelRegistryUpdatedData(
                action="create",
                label_id=label_id,
            ),
        )
        return label

    @callback
    def async_delete(self, label_id: str) -> None:
        """Delete label.

        Removes the label, fires an ``EVENT_LABEL_REGISTRY_UPDATED`` event
        with action ``remove`` and schedules a save. Deleting an unknown
        ``label_id`` propagates the lookup error from ``self.labels``
        (presumably ``KeyError``).
        """
        del self.labels[label_id]
        self.hass.bus.async_fire(
            EVENT_LABEL_REGISTRY_UPDATED,
            EventLabelRegistryUpdatedData(
                action="remove",
                label_id=label_id,
            ),
        )
        self.async_schedule_save()

    @callback
    def async_update(
        self,
        label_id: str,
        *,
        color: str | None | UndefinedType = UNDEFINED,
        description: str | None | UndefinedType = UNDEFINED,
        icon: str | None | UndefinedType = UNDEFINED,
        name: str | UndefinedType = UNDEFINED,
    ) -> LabelEntry:
        """Update the attributes of a label.

        Arguments left as ``UNDEFINED`` are not touched. When nothing
        actually changes, the existing entry is returned and neither a save
        nor an event is triggered. Otherwise the entry is replaced, a save
        is scheduled and an ``EVENT_LABEL_REGISTRY_UPDATED`` event with
        action ``update`` is fired.

        Raises:
            ValueError: If ``name`` is already used by a different label.
        """
        # Unknown label_id propagates the lookup error (presumably KeyError).
        old = self.labels[label_id]
        changes = {
            attr_name: value
            for attr_name, value in (
                ("color", color),
                ("description", description),
                ("icon", icon),
            )
            if value is not UNDEFINED and getattr(old, attr_name) != value
        }

        if name is not UNDEFINED and name != old.name:
            # Mirror async_create: a name may only belong to one label.
            # A match on this very label (e.g. only the casing changed) is
            # fine and must not be rejected.
            existing = self.async_get_label_by_name(name)
            if existing is not None and existing.label_id != label_id:
                raise ValueError(
                    f"The name {name} ({existing.normalized_name}) is already in use"
                )
            changes["name"] = name
            changes["normalized_name"] = normalize_name(name)

        if not changes:
            return old

        # Entries are frozen, so build a modified copy and swap it in.
        new = self.labels[label_id] = dataclasses.replace(old, **changes)  # type: ignore[arg-type]

        self.async_schedule_save()
        self.hass.bus.async_fire(
            EVENT_LABEL_REGISTRY_UPDATED,
            EventLabelRegistryUpdatedData(
                action="update",
                label_id=label_id,
            ),
        )

        return new

    async def async_load(self) -> None:
        """Load the label registry from storage.

        Starts from an empty collection when the store has no data.
        """
        data = await self._store.async_load()
        labels = NormalizedNameBaseRegistryItems[LabelEntry]()

        if data is not None:
            for label in data["labels"]:
                # Skip entries whose required values are None; this only
                # checks the values, a missing key still raises KeyError.
                if label["label_id"] is None or label["name"] is None:
                    continue

                normalized_name = normalize_name(label["name"])
                labels[label["label_id"]] = LabelEntry(
                    color=label["color"],
                    description=label["description"],
                    icon=label["icon"],
                    label_id=label["label_id"],
                    name=label["name"],
                    normalized_name=normalized_name,
                )

        self.labels = labels
        # Keep a direct reference to the underlying dict for async_get_label.
        self._label_data = labels.data

    @callback
    def _data_to_save(self) -> dict[str, list[dict[str, str | None]]]:
        """Return data of label registry to store in a file.

        The normalized name is not persisted; async_load recomputes it.
        """
        return {
            "labels": [
                {
                    "color": entry.color,
                    "description": entry.description,
                    "icon": entry.icon,
                    "label_id": entry.label_id,
                    "name": entry.name,
                }
                for entry in self.labels.values()
            ]
        }
|
|
|
|
|
|
@callback
def async_get(hass: HomeAssistant) -> LabelRegistry:
    """Return the label registry stored in ``hass.data``.

    The lookup fails if the registry has not been stored under
    ``DATA_REGISTRY`` yet.
    """
    registry = hass.data[DATA_REGISTRY]
    return cast(LabelRegistry, registry)
|
|
|
|
|
|
async def async_load(hass: HomeAssistant) -> None:
    """Create the label registry, store it in ``hass.data`` and load it.

    Must only be called once per hass instance; a second call trips the
    assertion below.
    """
    assert DATA_REGISTRY not in hass.data
    registry = LabelRegistry(hass)
    # Store the registry before loading, as the original ordering does.
    hass.data[DATA_REGISTRY] = registry
    await registry.async_load()
|