core/homeassistant/helpers/label_registry.py

236 lines
6.9 KiB
Python

"""Provide a way to label and group anything."""
from __future__ import annotations
from collections.abc import Iterable
import dataclasses
from dataclasses import dataclass
from typing import Literal, TypedDict, cast
from homeassistant.core import HomeAssistant, callback
from homeassistant.util import slugify
from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems,
normalize_name,
)
from .registry import BaseRegistry
from .storage import Store
from .typing import UNDEFINED, EventType, UndefinedType
# Key under which the LabelRegistry instance lives in hass.data.
DATA_REGISTRY: str = "label_registry"
# Event type fired after a label is created, updated or removed.
EVENT_LABEL_REGISTRY_UPDATED: str = "label_registry_updated"
# Storage key and major schema version passed to the registry's Store.
STORAGE_KEY: str = "core.label_registry"
STORAGE_VERSION_MAJOR: int = 1
class EventLabelRegistryUpdatedData(TypedDict):
    """Event data for when the label registry is updated.

    Passed as the payload of EVENT_LABEL_REGISTRY_UPDATED events fired by
    LabelRegistry.async_create, async_update and async_delete.
    """

    # Which registry operation triggered the event.
    action: Literal["create", "remove", "update"]
    # ID of the label that was created, removed or updated.
    label_id: str


# Typed event carrying EventLabelRegistryUpdatedData; presumably used to
# annotate listeners of EVENT_LABEL_REGISTRY_UPDATED (not referenced here).
EventLabelRegistryUpdated = EventType[EventLabelRegistryUpdatedData]
@dataclass(slots=True, frozen=True, kw_only=True)
class LabelEntry(NormalizedNameBaseRegistryEntry):
    """Label Registry Entry.

    Instances are frozen; the registry produces modified copies with
    dataclasses.replace instead of mutating them. The ``name`` and
    ``normalized_name`` fields come from NormalizedNameBaseRegistryEntry.
    All fields are keyword-only.
    """

    # Unique ID of the label; the registry derives it from the slugified name.
    label_id: str
    # Optional free-form description.
    description: str | None = None
    # Optional color value; its format is not validated here.
    color: str | None = None
    # Optional icon identifier.
    icon: str | None = None
class LabelRegistry(BaseRegistry):
    """Class to hold a registry of labels.

    Entries are kept in ``labels`` (keyed by ``label_id``, with lookup by
    name via ``get_by_name``) and persisted through a ``Store``. Every
    mutation schedules a save and fires EVENT_LABEL_REGISTRY_UPDATED.
    """

    labels: NormalizedNameBaseRegistryItems[LabelEntry]
    # The plain dict backing ``labels``; read directly by async_get_label to
    # skip the UserDict __getitem__ overhead.
    _label_data: dict[str, LabelEntry]

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the label registry.

        Labels are not available until ``async_load`` has been awaited.
        """
        self.hass = hass
        self._store: Store[dict[str, list[dict[str, str | None]]]] = Store(
            hass,
            STORAGE_VERSION_MAJOR,
            STORAGE_KEY,
            atomic_writes=True,
        )

    @callback
    def async_get_label(self, label_id: str) -> LabelEntry | None:
        """Get label by ID.

        We retrieve the LabelEntry from the underlying dict to avoid
        the overhead of the UserDict __getitem__.
        """
        return self._label_data.get(label_id)

    @callback
    def async_get_label_by_name(self, name: str) -> LabelEntry | None:
        """Get label by name (matched via the container's name lookup)."""
        return self.labels.get_by_name(name)

    @callback
    def async_list_labels(self) -> Iterable[LabelEntry]:
        """Get all labels."""
        return self.labels.values()

    @callback
    def _generate_id(self, name: str) -> str:
        """Return an unused label ID derived from ``name``.

        The slugified name is used as-is if free; otherwise ``_2``, ``_3``,
        ... is appended until an unused ID is found.
        """
        suggestion = suggestion_base = slugify(name)
        tries = 1
        while suggestion in self.labels:
            tries += 1
            suggestion = f"{suggestion_base}_{tries}"
        return suggestion

    @callback
    def async_create(
        self,
        name: str,
        *,
        color: str | None = None,
        icon: str | None = None,
        description: str | None = None,
    ) -> LabelEntry:
        """Create a new label.

        Raises ValueError if ``name`` is already used by an existing label.
        """
        if label := self.async_get_label_by_name(name):
            raise ValueError(
                f"The name {name} ({label.normalized_name}) is already in use"
            )

        normalized_name = normalize_name(name)

        label = LabelEntry(
            color=color,
            description=description,
            icon=icon,
            label_id=self._generate_id(name),
            name=name,
            normalized_name=normalized_name,
        )
        label_id = label.label_id
        self.labels[label_id] = label
        self.async_schedule_save()

        self.hass.bus.async_fire(
            EVENT_LABEL_REGISTRY_UPDATED,
            EventLabelRegistryUpdatedData(
                action="create",
                label_id=label_id,
            ),
        )
        return label

    @callback
    def async_delete(self, label_id: str) -> None:
        """Delete label.

        Raises KeyError if ``label_id`` is not registered.
        """
        del self.labels[label_id]
        # Schedule the save before notifying listeners, matching
        # async_create and async_update.
        self.async_schedule_save()
        self.hass.bus.async_fire(
            EVENT_LABEL_REGISTRY_UPDATED,
            EventLabelRegistryUpdatedData(
                action="remove",
                label_id=label_id,
            ),
        )

    @callback
    def async_update(
        self,
        label_id: str,
        *,
        color: str | None | UndefinedType = UNDEFINED,
        description: str | None | UndefinedType = UNDEFINED,
        icon: str | None | UndefinedType = UNDEFINED,
        name: str | UndefinedType = UNDEFINED,
    ) -> LabelEntry:
        """Update the attributes of a label.

        Only arguments that are not UNDEFINED and differ from the stored value
        are applied. If nothing changes, the existing entry is returned and no
        save is scheduled and no event is fired.

        Raises KeyError if ``label_id`` is not registered, and ValueError if
        the new name is already used by a different label.
        """
        old = self.labels[label_id]
        changes = {
            attr_name: value
            for attr_name, value in (
                ("color", color),
                ("description", description),
                ("icon", icon),
            )
            if value is not UNDEFINED and getattr(old, attr_name) != value
        }

        if name is not UNDEFINED and name != old.name:
            # Mirror async_create: refuse a name that already belongs to
            # another label. Renaming a label to a variant of its own name
            # (which looks up to itself) is still allowed.
            if (
                existing := self.async_get_label_by_name(name)
            ) is not None and existing.label_id != label_id:
                raise ValueError(
                    f"The name {name} ({existing.normalized_name}) is already in use"
                )
            changes["name"] = name
            changes["normalized_name"] = normalize_name(name)

        if not changes:
            return old

        new = self.labels[label_id] = dataclasses.replace(old, **changes)  # type: ignore[arg-type]

        self.async_schedule_save()
        self.hass.bus.async_fire(
            EVENT_LABEL_REGISTRY_UPDATED,
            EventLabelRegistryUpdatedData(
                action="update",
                label_id=label_id,
            ),
        )

        return new

    async def async_load(self) -> None:
        """Load the label registry from storage.

        Stored entries without a ``label_id`` or ``name`` are skipped;
        missing optional fields (color, description, icon) become None.
        """
        data = await self._store.async_load()
        labels = NormalizedNameBaseRegistryItems[LabelEntry]()

        if data is not None:
            for label in data["labels"]:
                # Use .get() so an entry lacking a required key is skipped
                # instead of raising KeyError and aborting the whole load.
                label_id = label.get("label_id")
                name = label.get("name")
                if label_id is None or name is None:
                    continue

                labels[label_id] = LabelEntry(
                    color=label.get("color"),
                    description=label.get("description"),
                    icon=label.get("icon"),
                    label_id=label_id,
                    name=name,
                    normalized_name=normalize_name(name),
                )

        self.labels = labels
        self._label_data = labels.data

    @callback
    def _data_to_save(self) -> dict[str, list[dict[str, str | None]]]:
        """Return data of label registry to store in a file."""
        return {
            "labels": [
                {
                    "color": entry.color,
                    "description": entry.description,
                    "icon": entry.icon,
                    "label_id": entry.label_id,
                    "name": entry.name,
                }
                for entry in self.labels.values()
            ]
        }
@callback
def async_get(hass: HomeAssistant) -> LabelRegistry:
    """Return the LabelRegistry instance stored in hass.data."""
    registry = hass.data[DATA_REGISTRY]
    return cast(LabelRegistry, registry)
async def async_load(hass: HomeAssistant) -> None:
    """Create the label registry, register it in hass.data, then load it."""
    # Guard against a second call, which would overwrite the existing registry.
    assert DATA_REGISTRY not in hass.data
    registry = LabelRegistry(hass)
    hass.data[DATA_REGISTRY] = registry
    await registry.async_load()