Add normalized name registry items base class (#111666)
* Add normalized name base registry items class * Add testspull/111786/head
parent
f1398dd127
commit
f31244bac4
|
@ -1,8 +1,7 @@
|
||||||
"""Provide a way to connect devices to one physical location."""
|
"""Provide a way to connect devices to one physical location."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import UserDict
|
from collections.abc import Iterable
|
||||||
from collections.abc import Iterable, ValuesView
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, Literal, TypedDict, cast
|
from typing import Any, Literal, TypedDict, cast
|
||||||
|
|
||||||
|
@ -10,6 +9,11 @@ from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.util import slugify
|
from homeassistant.util import slugify
|
||||||
|
|
||||||
from . import device_registry as dr, entity_registry as er
|
from . import device_registry as dr, entity_registry as er
|
||||||
|
from .normalized_name_base_registry import (
|
||||||
|
NormalizedNameBaseRegistryEntry,
|
||||||
|
NormalizedNameBaseRegistryItems,
|
||||||
|
normalize_name,
|
||||||
|
)
|
||||||
from .storage import Store
|
from .storage import Store
|
||||||
from .typing import UNDEFINED, UndefinedType
|
from .typing import UNDEFINED, UndefinedType
|
||||||
|
|
||||||
|
@ -29,7 +33,7 @@ class EventAreaRegistryUpdatedData(TypedDict):
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
|
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
|
||||||
class AreaEntry:
|
class AreaEntry(NormalizedNameBaseRegistryEntry):
|
||||||
"""Area Registry Entry."""
|
"""Area Registry Entry."""
|
||||||
|
|
||||||
aliases: set[str]
|
aliases: set[str]
|
||||||
|
@ -37,57 +41,9 @@ class AreaEntry:
|
||||||
icon: str | None
|
icon: str | None
|
||||||
id: str
|
id: str
|
||||||
labels: set[str] = dataclasses.field(default_factory=set)
|
labels: set[str] = dataclasses.field(default_factory=set)
|
||||||
name: str
|
|
||||||
normalized_name: str
|
|
||||||
picture: str | None
|
picture: str | None
|
||||||
|
|
||||||
|
|
||||||
class AreaRegistryItems(UserDict[str, AreaEntry]):
|
|
||||||
"""Container for area registry items, maps area id -> entry.
|
|
||||||
|
|
||||||
Maintains an additional index:
|
|
||||||
- normalized name -> entry
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""Initialize the container."""
|
|
||||||
super().__init__()
|
|
||||||
self._normalized_names: dict[str, AreaEntry] = {}
|
|
||||||
|
|
||||||
def values(self) -> ValuesView[AreaEntry]:
|
|
||||||
"""Return the underlying values to avoid __iter__ overhead."""
|
|
||||||
return self.data.values()
|
|
||||||
|
|
||||||
def __setitem__(self, key: str, entry: AreaEntry) -> None:
|
|
||||||
"""Add an item."""
|
|
||||||
data = self.data
|
|
||||||
normalized_name = normalize_area_name(entry.name)
|
|
||||||
|
|
||||||
if key in data:
|
|
||||||
old_entry = data[key]
|
|
||||||
if (
|
|
||||||
normalized_name != old_entry.normalized_name
|
|
||||||
and normalized_name in self._normalized_names
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"The name {entry.name} ({normalized_name}) is already in use"
|
|
||||||
)
|
|
||||||
del self._normalized_names[old_entry.normalized_name]
|
|
||||||
data[key] = entry
|
|
||||||
self._normalized_names[normalized_name] = entry
|
|
||||||
|
|
||||||
def __delitem__(self, key: str) -> None:
|
|
||||||
"""Remove an item."""
|
|
||||||
entry = self[key]
|
|
||||||
normalized_name = normalize_area_name(entry.name)
|
|
||||||
del self._normalized_names[normalized_name]
|
|
||||||
super().__delitem__(key)
|
|
||||||
|
|
||||||
def get_area_by_name(self, name: str) -> AreaEntry | None:
|
|
||||||
"""Get area by name."""
|
|
||||||
return self._normalized_names.get(normalize_area_name(name))
|
|
||||||
|
|
||||||
|
|
||||||
class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
|
class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
|
||||||
"""Store area registry data."""
|
"""Store area registry data."""
|
||||||
|
|
||||||
|
@ -133,7 +89,7 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
|
||||||
class AreaRegistry:
|
class AreaRegistry:
|
||||||
"""Class to hold a registry of areas."""
|
"""Class to hold a registry of areas."""
|
||||||
|
|
||||||
areas: AreaRegistryItems
|
areas: NormalizedNameBaseRegistryItems[AreaEntry]
|
||||||
_area_data: dict[str, AreaEntry]
|
_area_data: dict[str, AreaEntry]
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
|
@ -159,7 +115,7 @@ class AreaRegistry:
|
||||||
@callback
|
@callback
|
||||||
def async_get_area_by_name(self, name: str) -> AreaEntry | None:
|
def async_get_area_by_name(self, name: str) -> AreaEntry | None:
|
||||||
"""Get area by name."""
|
"""Get area by name."""
|
||||||
return self.areas.get_area_by_name(name)
|
return self.areas.get_by_name(name)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_list_areas(self) -> Iterable[AreaEntry]:
|
def async_list_areas(self) -> Iterable[AreaEntry]:
|
||||||
|
@ -185,7 +141,7 @@ class AreaRegistry:
|
||||||
picture: str | None = None,
|
picture: str | None = None,
|
||||||
) -> AreaEntry:
|
) -> AreaEntry:
|
||||||
"""Create a new area."""
|
"""Create a new area."""
|
||||||
normalized_name = normalize_area_name(name)
|
normalized_name = normalize_name(name)
|
||||||
|
|
||||||
if self.async_get_area_by_name(name):
|
if self.async_get_area_by_name(name):
|
||||||
raise ValueError(f"The name {name} ({normalized_name}) is already in use")
|
raise ValueError(f"The name {name} ({normalized_name}) is already in use")
|
||||||
|
@ -281,7 +237,7 @@ class AreaRegistry:
|
||||||
|
|
||||||
if name is not UNDEFINED and name != old.name:
|
if name is not UNDEFINED and name != old.name:
|
||||||
new_values["name"] = name
|
new_values["name"] = name
|
||||||
new_values["normalized_name"] = normalize_area_name(name)
|
new_values["normalized_name"] = normalize_name(name)
|
||||||
|
|
||||||
if not new_values:
|
if not new_values:
|
||||||
return old
|
return old
|
||||||
|
@ -297,12 +253,12 @@ class AreaRegistry:
|
||||||
|
|
||||||
data = await self._store.async_load()
|
data = await self._store.async_load()
|
||||||
|
|
||||||
areas = AreaRegistryItems()
|
areas = NormalizedNameBaseRegistryItems[AreaEntry]()
|
||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
for area in data["areas"]:
|
for area in data["areas"]:
|
||||||
assert area["name"] is not None and area["id"] is not None
|
assert area["name"] is not None and area["id"] is not None
|
||||||
normalized_name = normalize_area_name(area["name"])
|
normalized_name = normalize_name(area["name"])
|
||||||
areas[area["id"]] = AreaEntry(
|
areas[area["id"]] = AreaEntry(
|
||||||
aliases=set(area["aliases"]),
|
aliases=set(area["aliases"]),
|
||||||
floor_id=area["floor_id"],
|
floor_id=area["floor_id"],
|
||||||
|
@ -421,8 +377,3 @@ def async_entries_for_floor(registry: AreaRegistry, floor_id: str) -> list[AreaE
|
||||||
def async_entries_for_label(registry: AreaRegistry, label_id: str) -> list[AreaEntry]:
|
def async_entries_for_label(registry: AreaRegistry, label_id: str) -> list[AreaEntry]:
|
||||||
"""Return entries that match a label."""
|
"""Return entries that match a label."""
|
||||||
return [area for area in registry.areas.values() if label_id in area.labels]
|
return [area for area in registry.areas.values() if label_id in area.labels]
|
||||||
|
|
||||||
|
|
||||||
def normalize_area_name(area_name: str) -> str:
|
|
||||||
"""Normalize an area name by removing whitespace and case folding."""
|
|
||||||
return area_name.casefold().replace(" ", "")
|
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
"""Provide a way to assign areas to floors in one's home."""
|
"""Provide a way to assign areas to floors in one's home."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import UserDict
|
from collections.abc import Iterable
|
||||||
from collections.abc import Iterable, ValuesView
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Literal, TypedDict, cast
|
from typing import TYPE_CHECKING, Literal, TypedDict, cast
|
||||||
|
@ -10,6 +9,11 @@ from typing import TYPE_CHECKING, Literal, TypedDict, cast
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.util import slugify
|
from homeassistant.util import slugify
|
||||||
|
|
||||||
|
from .normalized_name_base_registry import (
|
||||||
|
NormalizedNameBaseRegistryEntry,
|
||||||
|
NormalizedNameBaseRegistryItems,
|
||||||
|
normalize_name,
|
||||||
|
)
|
||||||
from .storage import Store
|
from .storage import Store
|
||||||
from .typing import UNDEFINED, EventType, UndefinedType
|
from .typing import UNDEFINED, EventType, UndefinedType
|
||||||
|
|
||||||
|
@ -31,67 +35,19 @@ EventFloorRegistryUpdated = EventType[EventFloorRegistryUpdatedData]
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True, kw_only=True, frozen=True)
|
@dataclass(slots=True, kw_only=True, frozen=True)
|
||||||
class FloorEntry:
|
class FloorEntry(NormalizedNameBaseRegistryEntry):
|
||||||
"""Floor registry entry."""
|
"""Floor registry entry."""
|
||||||
|
|
||||||
aliases: set[str]
|
aliases: set[str]
|
||||||
floor_id: str
|
floor_id: str
|
||||||
icon: str | None = None
|
icon: str | None = None
|
||||||
level: int = 0
|
level: int = 0
|
||||||
name: str
|
|
||||||
normalized_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class FloorRegistryItems(UserDict[str, FloorEntry]):
|
|
||||||
"""Container for floor registry items, maps floor id -> entry.
|
|
||||||
|
|
||||||
Maintains an additional index:
|
|
||||||
- normalized name -> entry
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""Initialize the container."""
|
|
||||||
super().__init__()
|
|
||||||
self._normalized_names: dict[str, FloorEntry] = {}
|
|
||||||
|
|
||||||
def values(self) -> ValuesView[FloorEntry]:
|
|
||||||
"""Return the underlying values to avoid __iter__ overhead."""
|
|
||||||
return self.data.values()
|
|
||||||
|
|
||||||
def __setitem__(self, key: str, entry: FloorEntry) -> None:
|
|
||||||
"""Add an item."""
|
|
||||||
data = self.data
|
|
||||||
normalized_name = _normalize_floor_name(entry.name)
|
|
||||||
|
|
||||||
if key in data:
|
|
||||||
old_entry = data[key]
|
|
||||||
if (
|
|
||||||
normalized_name != old_entry.normalized_name
|
|
||||||
and normalized_name in self._normalized_names
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"The name {entry.name} ({normalized_name}) is already in use"
|
|
||||||
)
|
|
||||||
del self._normalized_names[old_entry.normalized_name]
|
|
||||||
data[key] = entry
|
|
||||||
self._normalized_names[normalized_name] = entry
|
|
||||||
|
|
||||||
def __delitem__(self, key: str) -> None:
|
|
||||||
"""Remove an item."""
|
|
||||||
entry = self[key]
|
|
||||||
normalized_name = _normalize_floor_name(entry.name)
|
|
||||||
del self._normalized_names[normalized_name]
|
|
||||||
super().__delitem__(key)
|
|
||||||
|
|
||||||
def get_floor_by_name(self, name: str) -> FloorEntry | None:
|
|
||||||
"""Get floor by name."""
|
|
||||||
return self._normalized_names.get(_normalize_floor_name(name))
|
|
||||||
|
|
||||||
|
|
||||||
class FloorRegistry:
|
class FloorRegistry:
|
||||||
"""Class to hold a registry of floors."""
|
"""Class to hold a registry of floors."""
|
||||||
|
|
||||||
floors: FloorRegistryItems
|
floors: NormalizedNameBaseRegistryItems[FloorEntry]
|
||||||
_floor_data: dict[str, FloorEntry]
|
_floor_data: dict[str, FloorEntry]
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
|
@ -118,7 +74,7 @@ class FloorRegistry:
|
||||||
@callback
|
@callback
|
||||||
def async_get_floor_by_name(self, name: str) -> FloorEntry | None:
|
def async_get_floor_by_name(self, name: str) -> FloorEntry | None:
|
||||||
"""Get floor by name."""
|
"""Get floor by name."""
|
||||||
return self.floors.get_floor_by_name(name)
|
return self.floors.get_by_name(name)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_list_floors(self) -> Iterable[FloorEntry]:
|
def async_list_floors(self) -> Iterable[FloorEntry]:
|
||||||
|
@ -150,7 +106,7 @@ class FloorRegistry:
|
||||||
f"The name {name} ({floor.normalized_name}) is already in use"
|
f"The name {name} ({floor.normalized_name}) is already in use"
|
||||||
)
|
)
|
||||||
|
|
||||||
normalized_name = _normalize_floor_name(name)
|
normalized_name = normalize_name(name)
|
||||||
|
|
||||||
floor = FloorEntry(
|
floor = FloorEntry(
|
||||||
aliases=aliases or set(),
|
aliases=aliases or set(),
|
||||||
|
@ -208,7 +164,7 @@ class FloorRegistry:
|
||||||
}
|
}
|
||||||
if name is not UNDEFINED and name != old.name:
|
if name is not UNDEFINED and name != old.name:
|
||||||
changes["name"] = name
|
changes["name"] = name
|
||||||
changes["normalized_name"] = _normalize_floor_name(name)
|
changes["normalized_name"] = normalize_name(name)
|
||||||
|
|
||||||
if not changes:
|
if not changes:
|
||||||
return old
|
return old
|
||||||
|
@ -229,7 +185,7 @@ class FloorRegistry:
|
||||||
async def async_load(self) -> None:
|
async def async_load(self) -> None:
|
||||||
"""Load the floor registry."""
|
"""Load the floor registry."""
|
||||||
data = await self._store.async_load()
|
data = await self._store.async_load()
|
||||||
floors = FloorRegistryItems()
|
floors = NormalizedNameBaseRegistryItems[FloorEntry]()
|
||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
for floor in data["floors"]:
|
for floor in data["floors"]:
|
||||||
|
@ -240,7 +196,7 @@ class FloorRegistry:
|
||||||
assert isinstance(floor["name"], str)
|
assert isinstance(floor["name"], str)
|
||||||
assert isinstance(floor["floor_id"], str)
|
assert isinstance(floor["floor_id"], str)
|
||||||
|
|
||||||
normalized_name = _normalize_floor_name(floor["name"])
|
normalized_name = normalize_name(floor["name"])
|
||||||
floors[floor["floor_id"]] = FloorEntry(
|
floors[floor["floor_id"]] = FloorEntry(
|
||||||
aliases=set(floor["aliases"]),
|
aliases=set(floor["aliases"]),
|
||||||
icon=floor["icon"],
|
icon=floor["icon"],
|
||||||
|
@ -286,8 +242,3 @@ async def async_load(hass: HomeAssistant) -> None:
|
||||||
assert DATA_REGISTRY not in hass.data
|
assert DATA_REGISTRY not in hass.data
|
||||||
hass.data[DATA_REGISTRY] = FloorRegistry(hass)
|
hass.data[DATA_REGISTRY] = FloorRegistry(hass)
|
||||||
await hass.data[DATA_REGISTRY].async_load()
|
await hass.data[DATA_REGISTRY].async_load()
|
||||||
|
|
||||||
|
|
||||||
def _normalize_floor_name(floor_name: str) -> str:
|
|
||||||
"""Normalize a floor name by removing whitespace and case folding."""
|
|
||||||
return floor_name.casefold().replace(" ", "")
|
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
"""Provide a way to label and group anything."""
|
"""Provide a way to label and group anything."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import UserDict
|
from collections.abc import Iterable
|
||||||
from collections.abc import Iterable, ValuesView
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Literal, TypedDict, cast
|
from typing import Literal, TypedDict, cast
|
||||||
|
@ -10,6 +9,11 @@ from typing import Literal, TypedDict, cast
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.util import slugify
|
from homeassistant.util import slugify
|
||||||
|
|
||||||
|
from .normalized_name_base_registry import (
|
||||||
|
NormalizedNameBaseRegistryEntry,
|
||||||
|
NormalizedNameBaseRegistryItems,
|
||||||
|
normalize_name,
|
||||||
|
)
|
||||||
from .storage import Store
|
from .storage import Store
|
||||||
from .typing import UNDEFINED, EventType, UndefinedType
|
from .typing import UNDEFINED, EventType, UndefinedType
|
||||||
|
|
||||||
|
@ -30,68 +34,20 @@ class EventLabelRegistryUpdatedData(TypedDict):
|
||||||
EventLabelRegistryUpdated = EventType[EventLabelRegistryUpdatedData]
|
EventLabelRegistryUpdated = EventType[EventLabelRegistryUpdatedData]
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True, frozen=True)
|
@dataclass(slots=True, frozen=True, kw_only=True)
|
||||||
class LabelEntry:
|
class LabelEntry(NormalizedNameBaseRegistryEntry):
|
||||||
"""Label Registry Entry."""
|
"""Label Registry Entry."""
|
||||||
|
|
||||||
label_id: str
|
label_id: str
|
||||||
name: str
|
|
||||||
normalized_name: str
|
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
color: str | None = None
|
color: str | None = None
|
||||||
icon: str | None = None
|
icon: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class LabelRegistryItems(UserDict[str, LabelEntry]):
|
|
||||||
"""Container for label registry items, maps label id -> entry.
|
|
||||||
|
|
||||||
Maintains an additional index:
|
|
||||||
- normalized name -> entry
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""Initialize the container."""
|
|
||||||
super().__init__()
|
|
||||||
self._normalized_names: dict[str, LabelEntry] = {}
|
|
||||||
|
|
||||||
def values(self) -> ValuesView[LabelEntry]:
|
|
||||||
"""Return the underlying values to avoid __iter__ overhead."""
|
|
||||||
return self.data.values()
|
|
||||||
|
|
||||||
def __setitem__(self, key: str, entry: LabelEntry) -> None:
|
|
||||||
"""Add an item."""
|
|
||||||
data = self.data
|
|
||||||
normalized_name = _normalize_label_name(entry.name)
|
|
||||||
|
|
||||||
if key in data:
|
|
||||||
old_entry = data[key]
|
|
||||||
if (
|
|
||||||
normalized_name != old_entry.normalized_name
|
|
||||||
and normalized_name in self._normalized_names
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"The name {entry.name} ({normalized_name}) is already in use"
|
|
||||||
)
|
|
||||||
del self._normalized_names[old_entry.normalized_name]
|
|
||||||
data[key] = entry
|
|
||||||
self._normalized_names[normalized_name] = entry
|
|
||||||
|
|
||||||
def __delitem__(self, key: str) -> None:
|
|
||||||
"""Remove an item."""
|
|
||||||
entry = self[key]
|
|
||||||
normalized_name = _normalize_label_name(entry.name)
|
|
||||||
del self._normalized_names[normalized_name]
|
|
||||||
super().__delitem__(key)
|
|
||||||
|
|
||||||
def get_label_by_name(self, name: str) -> LabelEntry | None:
|
|
||||||
"""Get label by name."""
|
|
||||||
return self._normalized_names.get(_normalize_label_name(name))
|
|
||||||
|
|
||||||
|
|
||||||
class LabelRegistry:
|
class LabelRegistry:
|
||||||
"""Class to hold a registry of labels."""
|
"""Class to hold a registry of labels."""
|
||||||
|
|
||||||
labels: LabelRegistryItems
|
labels: NormalizedNameBaseRegistryItems[LabelEntry]
|
||||||
_label_data: dict[str, LabelEntry]
|
_label_data: dict[str, LabelEntry]
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
|
@ -116,7 +72,7 @@ class LabelRegistry:
|
||||||
@callback
|
@callback
|
||||||
def async_get_label_by_name(self, name: str) -> LabelEntry | None:
|
def async_get_label_by_name(self, name: str) -> LabelEntry | None:
|
||||||
"""Get label by name."""
|
"""Get label by name."""
|
||||||
return self.labels.get_label_by_name(name)
|
return self.labels.get_by_name(name)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_list_labels(self) -> Iterable[LabelEntry]:
|
def async_list_labels(self) -> Iterable[LabelEntry]:
|
||||||
|
@ -148,7 +104,7 @@ class LabelRegistry:
|
||||||
f"The name {name} ({label.normalized_name}) is already in use"
|
f"The name {name} ({label.normalized_name}) is already in use"
|
||||||
)
|
)
|
||||||
|
|
||||||
normalized_name = _normalize_label_name(name)
|
normalized_name = normalize_name(name)
|
||||||
|
|
||||||
label = LabelEntry(
|
label = LabelEntry(
|
||||||
color=color,
|
color=color,
|
||||||
|
@ -207,7 +163,7 @@ class LabelRegistry:
|
||||||
|
|
||||||
if name is not UNDEFINED and name != old.name:
|
if name is not UNDEFINED and name != old.name:
|
||||||
changes["name"] = name
|
changes["name"] = name
|
||||||
changes["normalized_name"] = _normalize_label_name(name)
|
changes["normalized_name"] = normalize_name(name)
|
||||||
|
|
||||||
if not changes:
|
if not changes:
|
||||||
return old
|
return old
|
||||||
|
@ -228,7 +184,7 @@ class LabelRegistry:
|
||||||
async def async_load(self) -> None:
|
async def async_load(self) -> None:
|
||||||
"""Load the label registry."""
|
"""Load the label registry."""
|
||||||
data = await self._store.async_load()
|
data = await self._store.async_load()
|
||||||
labels = LabelRegistryItems()
|
labels = NormalizedNameBaseRegistryItems[LabelEntry]()
|
||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
for label in data["labels"]:
|
for label in data["labels"]:
|
||||||
|
@ -236,7 +192,7 @@ class LabelRegistry:
|
||||||
if label["label_id"] is None or label["name"] is None:
|
if label["label_id"] is None or label["name"] is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
normalized_name = _normalize_label_name(label["name"])
|
normalized_name = normalize_name(label["name"])
|
||||||
labels[label["label_id"]] = LabelEntry(
|
labels[label["label_id"]] = LabelEntry(
|
||||||
color=label["color"],
|
color=label["color"],
|
||||||
description=label["description"],
|
description=label["description"],
|
||||||
|
@ -282,8 +238,3 @@ async def async_load(hass: HomeAssistant) -> None:
|
||||||
assert DATA_REGISTRY not in hass.data
|
assert DATA_REGISTRY not in hass.data
|
||||||
hass.data[DATA_REGISTRY] = LabelRegistry(hass)
|
hass.data[DATA_REGISTRY] = LabelRegistry(hass)
|
||||||
await hass.data[DATA_REGISTRY].async_load()
|
await hass.data[DATA_REGISTRY].async_load()
|
||||||
|
|
||||||
|
|
||||||
def _normalize_label_name(label_name: str) -> str:
|
|
||||||
"""Normalize a label name by removing whitespace and case folding."""
|
|
||||||
return label_name.casefold().replace(" ", "")
|
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
"""Provide a base class for registries that use a normalized name index."""
|
||||||
|
from collections import UserDict
|
||||||
|
from collections.abc import ValuesView
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True, frozen=True, kw_only=True)
|
||||||
|
class NormalizedNameBaseRegistryEntry:
|
||||||
|
"""Normalized Name Base Registry Entry."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
normalized_name: str
|
||||||
|
|
||||||
|
|
||||||
|
_VT = TypeVar("_VT", bound=NormalizedNameBaseRegistryEntry)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_name(name: str) -> str:
|
||||||
|
"""Normalize a name by removing whitespace and case folding."""
|
||||||
|
return name.casefold().replace(" ", "")
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizedNameBaseRegistryItems(UserDict[str, _VT]):
|
||||||
|
"""Base container for normalized name registry items, maps key -> entry.
|
||||||
|
|
||||||
|
Maintains an additional index:
|
||||||
|
- normalized name -> entry
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the container."""
|
||||||
|
super().__init__()
|
||||||
|
self._normalized_names: dict[str, _VT] = {}
|
||||||
|
|
||||||
|
def values(self) -> ValuesView[_VT]:
|
||||||
|
"""Return the underlying values to avoid __iter__ overhead."""
|
||||||
|
return self.data.values()
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, entry: _VT) -> None:
|
||||||
|
"""Add an item."""
|
||||||
|
data = self.data
|
||||||
|
normalized_name = normalize_name(entry.name)
|
||||||
|
|
||||||
|
if key in data:
|
||||||
|
old_entry = data[key]
|
||||||
|
if (
|
||||||
|
normalized_name != old_entry.normalized_name
|
||||||
|
and normalized_name in self._normalized_names
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"The name {entry.name} ({normalized_name}) is already in use"
|
||||||
|
)
|
||||||
|
del self._normalized_names[old_entry.normalized_name]
|
||||||
|
data[key] = entry
|
||||||
|
self._normalized_names[normalized_name] = entry
|
||||||
|
|
||||||
|
def __delitem__(self, key: str) -> None:
|
||||||
|
"""Remove an item."""
|
||||||
|
entry = self[key]
|
||||||
|
normalized_name = normalize_name(entry.name)
|
||||||
|
del self._normalized_names[normalized_name]
|
||||||
|
super().__delitem__(key)
|
||||||
|
|
||||||
|
def get_by_name(self, name: str) -> _VT | None:
|
||||||
|
"""Get entry by name."""
|
||||||
|
return self._normalized_names.get(normalize_name(name))
|
|
@ -0,0 +1,67 @@
|
||||||
|
"""Tests for the normalized name base registry helper."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.helpers.normalized_name_base_registry import (
|
||||||
|
NormalizedNameBaseRegistryEntry,
|
||||||
|
NormalizedNameBaseRegistryItems,
|
||||||
|
normalize_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def registry_items():
|
||||||
|
"""Fixture for registry items."""
|
||||||
|
return NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry]()
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_name():
|
||||||
|
"""Test normalize_name."""
|
||||||
|
assert normalize_name("Hello World") == "helloworld"
|
||||||
|
assert normalize_name("HELLO WORLD") == "helloworld"
|
||||||
|
assert normalize_name(" Hello World ") == "helloworld"
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_items(
|
||||||
|
registry_items: NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry],
|
||||||
|
):
|
||||||
|
"""Test registry items."""
|
||||||
|
entry = NormalizedNameBaseRegistryEntry(
|
||||||
|
name="Hello World", normalized_name="helloworld"
|
||||||
|
)
|
||||||
|
registry_items["key"] = entry
|
||||||
|
assert registry_items["key"] == entry
|
||||||
|
assert list(registry_items.values()) == [entry]
|
||||||
|
assert registry_items.get_by_name("Hello World") == entry
|
||||||
|
|
||||||
|
# test update entry
|
||||||
|
entry2 = NormalizedNameBaseRegistryEntry(
|
||||||
|
name="Hello World 2", normalized_name="helloworld2"
|
||||||
|
)
|
||||||
|
registry_items["key"] = entry2
|
||||||
|
assert registry_items["key"] == entry2
|
||||||
|
assert list(registry_items.values()) == [entry2]
|
||||||
|
assert registry_items.get_by_name("Hello World 2") == entry2
|
||||||
|
|
||||||
|
# test delete entry
|
||||||
|
del registry_items["key"]
|
||||||
|
assert "key" not in registry_items
|
||||||
|
assert list(registry_items.values()) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_key_already_in_use(
|
||||||
|
registry_items: NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry],
|
||||||
|
):
|
||||||
|
"""Test key already in use."""
|
||||||
|
entry = NormalizedNameBaseRegistryEntry(
|
||||||
|
name="Hello World", normalized_name="helloworld"
|
||||||
|
)
|
||||||
|
registry_items["key"] = entry
|
||||||
|
|
||||||
|
# should raise ValueError if we update a
|
||||||
|
# key with a entry with the same normalized name
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
entry = NormalizedNameBaseRegistryEntry(
|
||||||
|
name="Hello World 2", normalized_name="helloworld2"
|
||||||
|
)
|
||||||
|
registry_items["key2"] = entry
|
||||||
|
registry_items["key"] = entry
|
Loading…
Reference in New Issue