Add normalized name registry items base class (#111666)

* Add normalized name base registry items class

* Add tests
pull/111786/head
Jan-Philipp Benecke 2024-02-29 01:31:33 +01:00 committed by GitHub
parent f1398dd127
commit f31244bac4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 174 additions and 187 deletions

View File

@ -1,8 +1,7 @@
"""Provide a way to connect devices to one physical location.""" """Provide a way to connect devices to one physical location."""
from __future__ import annotations from __future__ import annotations
from collections import UserDict from collections.abc import Iterable
from collections.abc import Iterable, ValuesView
import dataclasses import dataclasses
from typing import Any, Literal, TypedDict, cast from typing import Any, Literal, TypedDict, cast
@ -10,6 +9,11 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.util import slugify from homeassistant.util import slugify
from . import device_registry as dr, entity_registry as er from . import device_registry as dr, entity_registry as er
from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems,
normalize_name,
)
from .storage import Store from .storage import Store
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
@ -29,7 +33,7 @@ class EventAreaRegistryUpdatedData(TypedDict):
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) @dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class AreaEntry: class AreaEntry(NormalizedNameBaseRegistryEntry):
"""Area Registry Entry.""" """Area Registry Entry."""
aliases: set[str] aliases: set[str]
@ -37,57 +41,9 @@ class AreaEntry:
icon: str | None icon: str | None
id: str id: str
labels: set[str] = dataclasses.field(default_factory=set) labels: set[str] = dataclasses.field(default_factory=set)
name: str
normalized_name: str
picture: str | None picture: str | None
class AreaRegistryItems(UserDict[str, AreaEntry]):
"""Container for area registry items, maps area id -> entry.
Maintains an additional index:
- normalized name -> entry
"""
def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._normalized_names: dict[str, AreaEntry] = {}
def values(self) -> ValuesView[AreaEntry]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def __setitem__(self, key: str, entry: AreaEntry) -> None:
"""Add an item."""
data = self.data
normalized_name = normalize_area_name(entry.name)
if key in data:
old_entry = data[key]
if (
normalized_name != old_entry.normalized_name
and normalized_name in self._normalized_names
):
raise ValueError(
f"The name {entry.name} ({normalized_name}) is already in use"
)
del self._normalized_names[old_entry.normalized_name]
data[key] = entry
self._normalized_names[normalized_name] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
normalized_name = normalize_area_name(entry.name)
del self._normalized_names[normalized_name]
super().__delitem__(key)
def get_area_by_name(self, name: str) -> AreaEntry | None:
"""Get area by name."""
return self._normalized_names.get(normalize_area_name(name))
class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]): class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
"""Store area registry data.""" """Store area registry data."""
@ -133,7 +89,7 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
class AreaRegistry: class AreaRegistry:
"""Class to hold a registry of areas.""" """Class to hold a registry of areas."""
areas: AreaRegistryItems areas: NormalizedNameBaseRegistryItems[AreaEntry]
_area_data: dict[str, AreaEntry] _area_data: dict[str, AreaEntry]
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
@ -159,7 +115,7 @@ class AreaRegistry:
@callback @callback
def async_get_area_by_name(self, name: str) -> AreaEntry | None: def async_get_area_by_name(self, name: str) -> AreaEntry | None:
"""Get area by name.""" """Get area by name."""
return self.areas.get_area_by_name(name) return self.areas.get_by_name(name)
@callback @callback
def async_list_areas(self) -> Iterable[AreaEntry]: def async_list_areas(self) -> Iterable[AreaEntry]:
@ -185,7 +141,7 @@ class AreaRegistry:
picture: str | None = None, picture: str | None = None,
) -> AreaEntry: ) -> AreaEntry:
"""Create a new area.""" """Create a new area."""
normalized_name = normalize_area_name(name) normalized_name = normalize_name(name)
if self.async_get_area_by_name(name): if self.async_get_area_by_name(name):
raise ValueError(f"The name {name} ({normalized_name}) is already in use") raise ValueError(f"The name {name} ({normalized_name}) is already in use")
@ -281,7 +237,7 @@ class AreaRegistry:
if name is not UNDEFINED and name != old.name: if name is not UNDEFINED and name != old.name:
new_values["name"] = name new_values["name"] = name
new_values["normalized_name"] = normalize_area_name(name) new_values["normalized_name"] = normalize_name(name)
if not new_values: if not new_values:
return old return old
@ -297,12 +253,12 @@ class AreaRegistry:
data = await self._store.async_load() data = await self._store.async_load()
areas = AreaRegistryItems() areas = NormalizedNameBaseRegistryItems[AreaEntry]()
if data is not None: if data is not None:
for area in data["areas"]: for area in data["areas"]:
assert area["name"] is not None and area["id"] is not None assert area["name"] is not None and area["id"] is not None
normalized_name = normalize_area_name(area["name"]) normalized_name = normalize_name(area["name"])
areas[area["id"]] = AreaEntry( areas[area["id"]] = AreaEntry(
aliases=set(area["aliases"]), aliases=set(area["aliases"]),
floor_id=area["floor_id"], floor_id=area["floor_id"],
@ -421,8 +377,3 @@ def async_entries_for_floor(registry: AreaRegistry, floor_id: str) -> list[AreaE
def async_entries_for_label(registry: AreaRegistry, label_id: str) -> list[AreaEntry]: def async_entries_for_label(registry: AreaRegistry, label_id: str) -> list[AreaEntry]:
"""Return entries that match a label.""" """Return entries that match a label."""
return [area for area in registry.areas.values() if label_id in area.labels] return [area for area in registry.areas.values() if label_id in area.labels]
def normalize_area_name(area_name: str) -> str:
"""Normalize an area name by removing whitespace and case folding."""
return area_name.casefold().replace(" ", "")

View File

@ -1,8 +1,7 @@
"""Provide a way to assign areas to floors in one's home.""" """Provide a way to assign areas to floors in one's home."""
from __future__ import annotations from __future__ import annotations
from collections import UserDict from collections.abc import Iterable
from collections.abc import Iterable, ValuesView
import dataclasses import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, TypedDict, cast from typing import TYPE_CHECKING, Literal, TypedDict, cast
@ -10,6 +9,11 @@ from typing import TYPE_CHECKING, Literal, TypedDict, cast
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.util import slugify from homeassistant.util import slugify
from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems,
normalize_name,
)
from .storage import Store from .storage import Store
from .typing import UNDEFINED, EventType, UndefinedType from .typing import UNDEFINED, EventType, UndefinedType
@ -31,67 +35,19 @@ EventFloorRegistryUpdated = EventType[EventFloorRegistryUpdatedData]
@dataclass(slots=True, kw_only=True, frozen=True) @dataclass(slots=True, kw_only=True, frozen=True)
class FloorEntry: class FloorEntry(NormalizedNameBaseRegistryEntry):
"""Floor registry entry.""" """Floor registry entry."""
aliases: set[str] aliases: set[str]
floor_id: str floor_id: str
icon: str | None = None icon: str | None = None
level: int = 0 level: int = 0
name: str
normalized_name: str
class FloorRegistryItems(UserDict[str, FloorEntry]):
"""Container for floor registry items, maps floor id -> entry.
Maintains an additional index:
- normalized name -> entry
"""
def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._normalized_names: dict[str, FloorEntry] = {}
def values(self) -> ValuesView[FloorEntry]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def __setitem__(self, key: str, entry: FloorEntry) -> None:
"""Add an item."""
data = self.data
normalized_name = _normalize_floor_name(entry.name)
if key in data:
old_entry = data[key]
if (
normalized_name != old_entry.normalized_name
and normalized_name in self._normalized_names
):
raise ValueError(
f"The name {entry.name} ({normalized_name}) is already in use"
)
del self._normalized_names[old_entry.normalized_name]
data[key] = entry
self._normalized_names[normalized_name] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
normalized_name = _normalize_floor_name(entry.name)
del self._normalized_names[normalized_name]
super().__delitem__(key)
def get_floor_by_name(self, name: str) -> FloorEntry | None:
"""Get floor by name."""
return self._normalized_names.get(_normalize_floor_name(name))
class FloorRegistry: class FloorRegistry:
"""Class to hold a registry of floors.""" """Class to hold a registry of floors."""
floors: FloorRegistryItems floors: NormalizedNameBaseRegistryItems[FloorEntry]
_floor_data: dict[str, FloorEntry] _floor_data: dict[str, FloorEntry]
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
@ -118,7 +74,7 @@ class FloorRegistry:
@callback @callback
def async_get_floor_by_name(self, name: str) -> FloorEntry | None: def async_get_floor_by_name(self, name: str) -> FloorEntry | None:
"""Get floor by name.""" """Get floor by name."""
return self.floors.get_floor_by_name(name) return self.floors.get_by_name(name)
@callback @callback
def async_list_floors(self) -> Iterable[FloorEntry]: def async_list_floors(self) -> Iterable[FloorEntry]:
@ -150,7 +106,7 @@ class FloorRegistry:
f"The name {name} ({floor.normalized_name}) is already in use" f"The name {name} ({floor.normalized_name}) is already in use"
) )
normalized_name = _normalize_floor_name(name) normalized_name = normalize_name(name)
floor = FloorEntry( floor = FloorEntry(
aliases=aliases or set(), aliases=aliases or set(),
@ -208,7 +164,7 @@ class FloorRegistry:
} }
if name is not UNDEFINED and name != old.name: if name is not UNDEFINED and name != old.name:
changes["name"] = name changes["name"] = name
changes["normalized_name"] = _normalize_floor_name(name) changes["normalized_name"] = normalize_name(name)
if not changes: if not changes:
return old return old
@ -229,7 +185,7 @@ class FloorRegistry:
async def async_load(self) -> None: async def async_load(self) -> None:
"""Load the floor registry.""" """Load the floor registry."""
data = await self._store.async_load() data = await self._store.async_load()
floors = FloorRegistryItems() floors = NormalizedNameBaseRegistryItems[FloorEntry]()
if data is not None: if data is not None:
for floor in data["floors"]: for floor in data["floors"]:
@ -240,7 +196,7 @@ class FloorRegistry:
assert isinstance(floor["name"], str) assert isinstance(floor["name"], str)
assert isinstance(floor["floor_id"], str) assert isinstance(floor["floor_id"], str)
normalized_name = _normalize_floor_name(floor["name"]) normalized_name = normalize_name(floor["name"])
floors[floor["floor_id"]] = FloorEntry( floors[floor["floor_id"]] = FloorEntry(
aliases=set(floor["aliases"]), aliases=set(floor["aliases"]),
icon=floor["icon"], icon=floor["icon"],
@ -286,8 +242,3 @@ async def async_load(hass: HomeAssistant) -> None:
assert DATA_REGISTRY not in hass.data assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = FloorRegistry(hass) hass.data[DATA_REGISTRY] = FloorRegistry(hass)
await hass.data[DATA_REGISTRY].async_load() await hass.data[DATA_REGISTRY].async_load()
def _normalize_floor_name(floor_name: str) -> str:
"""Normalize a floor name by removing whitespace and case folding."""
return floor_name.casefold().replace(" ", "")

View File

@ -1,8 +1,7 @@
"""Provide a way to label and group anything.""" """Provide a way to label and group anything."""
from __future__ import annotations from __future__ import annotations
from collections import UserDict from collections.abc import Iterable
from collections.abc import Iterable, ValuesView
import dataclasses import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal, TypedDict, cast from typing import Literal, TypedDict, cast
@ -10,6 +9,11 @@ from typing import Literal, TypedDict, cast
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.util import slugify from homeassistant.util import slugify
from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems,
normalize_name,
)
from .storage import Store from .storage import Store
from .typing import UNDEFINED, EventType, UndefinedType from .typing import UNDEFINED, EventType, UndefinedType
@ -30,68 +34,20 @@ class EventLabelRegistryUpdatedData(TypedDict):
EventLabelRegistryUpdated = EventType[EventLabelRegistryUpdatedData] EventLabelRegistryUpdated = EventType[EventLabelRegistryUpdatedData]
@dataclass(slots=True, frozen=True) @dataclass(slots=True, frozen=True, kw_only=True)
class LabelEntry: class LabelEntry(NormalizedNameBaseRegistryEntry):
"""Label Registry Entry.""" """Label Registry Entry."""
label_id: str label_id: str
name: str
normalized_name: str
description: str | None = None description: str | None = None
color: str | None = None color: str | None = None
icon: str | None = None icon: str | None = None
class LabelRegistryItems(UserDict[str, LabelEntry]):
"""Container for label registry items, maps label id -> entry.
Maintains an additional index:
- normalized name -> entry
"""
def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._normalized_names: dict[str, LabelEntry] = {}
def values(self) -> ValuesView[LabelEntry]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def __setitem__(self, key: str, entry: LabelEntry) -> None:
"""Add an item."""
data = self.data
normalized_name = _normalize_label_name(entry.name)
if key in data:
old_entry = data[key]
if (
normalized_name != old_entry.normalized_name
and normalized_name in self._normalized_names
):
raise ValueError(
f"The name {entry.name} ({normalized_name}) is already in use"
)
del self._normalized_names[old_entry.normalized_name]
data[key] = entry
self._normalized_names[normalized_name] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
normalized_name = _normalize_label_name(entry.name)
del self._normalized_names[normalized_name]
super().__delitem__(key)
def get_label_by_name(self, name: str) -> LabelEntry | None:
"""Get label by name."""
return self._normalized_names.get(_normalize_label_name(name))
class LabelRegistry: class LabelRegistry:
"""Class to hold a registry of labels.""" """Class to hold a registry of labels."""
labels: LabelRegistryItems labels: NormalizedNameBaseRegistryItems[LabelEntry]
_label_data: dict[str, LabelEntry] _label_data: dict[str, LabelEntry]
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
@ -116,7 +72,7 @@ class LabelRegistry:
@callback @callback
def async_get_label_by_name(self, name: str) -> LabelEntry | None: def async_get_label_by_name(self, name: str) -> LabelEntry | None:
"""Get label by name.""" """Get label by name."""
return self.labels.get_label_by_name(name) return self.labels.get_by_name(name)
@callback @callback
def async_list_labels(self) -> Iterable[LabelEntry]: def async_list_labels(self) -> Iterable[LabelEntry]:
@ -148,7 +104,7 @@ class LabelRegistry:
f"The name {name} ({label.normalized_name}) is already in use" f"The name {name} ({label.normalized_name}) is already in use"
) )
normalized_name = _normalize_label_name(name) normalized_name = normalize_name(name)
label = LabelEntry( label = LabelEntry(
color=color, color=color,
@ -207,7 +163,7 @@ class LabelRegistry:
if name is not UNDEFINED and name != old.name: if name is not UNDEFINED and name != old.name:
changes["name"] = name changes["name"] = name
changes["normalized_name"] = _normalize_label_name(name) changes["normalized_name"] = normalize_name(name)
if not changes: if not changes:
return old return old
@ -228,7 +184,7 @@ class LabelRegistry:
async def async_load(self) -> None: async def async_load(self) -> None:
"""Load the label registry.""" """Load the label registry."""
data = await self._store.async_load() data = await self._store.async_load()
labels = LabelRegistryItems() labels = NormalizedNameBaseRegistryItems[LabelEntry]()
if data is not None: if data is not None:
for label in data["labels"]: for label in data["labels"]:
@ -236,7 +192,7 @@ class LabelRegistry:
if label["label_id"] is None or label["name"] is None: if label["label_id"] is None or label["name"] is None:
continue continue
normalized_name = _normalize_label_name(label["name"]) normalized_name = normalize_name(label["name"])
labels[label["label_id"]] = LabelEntry( labels[label["label_id"]] = LabelEntry(
color=label["color"], color=label["color"],
description=label["description"], description=label["description"],
@ -282,8 +238,3 @@ async def async_load(hass: HomeAssistant) -> None:
assert DATA_REGISTRY not in hass.data assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = LabelRegistry(hass) hass.data[DATA_REGISTRY] = LabelRegistry(hass)
await hass.data[DATA_REGISTRY].async_load() await hass.data[DATA_REGISTRY].async_load()
def _normalize_label_name(label_name: str) -> str:
"""Normalize a label name by removing whitespace and case folding."""
return label_name.casefold().replace(" ", "")

View File

@ -0,0 +1,67 @@
"""Provide a base class for registries that use a normalized name index."""
from collections import UserDict
from collections.abc import ValuesView
from dataclasses import dataclass
from typing import TypeVar
@dataclass(slots=True, frozen=True, kw_only=True)
class NormalizedNameBaseRegistryEntry:
"""Normalized Name Base Registry Entry."""
name: str
normalized_name: str
_VT = TypeVar("_VT", bound=NormalizedNameBaseRegistryEntry)
def normalize_name(name: str) -> str:
"""Normalize a name by removing whitespace and case folding."""
return name.casefold().replace(" ", "")
class NormalizedNameBaseRegistryItems(UserDict[str, _VT]):
"""Base container for normalized name registry items, maps key -> entry.
Maintains an additional index:
- normalized name -> entry
"""
def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._normalized_names: dict[str, _VT] = {}
def values(self) -> ValuesView[_VT]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def __setitem__(self, key: str, entry: _VT) -> None:
"""Add an item."""
data = self.data
normalized_name = normalize_name(entry.name)
if key in data:
old_entry = data[key]
if (
normalized_name != old_entry.normalized_name
and normalized_name in self._normalized_names
):
raise ValueError(
f"The name {entry.name} ({normalized_name}) is already in use"
)
del self._normalized_names[old_entry.normalized_name]
data[key] = entry
self._normalized_names[normalized_name] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
normalized_name = normalize_name(entry.name)
del self._normalized_names[normalized_name]
super().__delitem__(key)
def get_by_name(self, name: str) -> _VT | None:
"""Get entry by name."""
return self._normalized_names.get(normalize_name(name))

View File

@ -0,0 +1,67 @@
"""Tests for the normalized name base registry helper."""
import pytest
from homeassistant.helpers.normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems,
normalize_name,
)
@pytest.fixture
def registry_items():
"""Fixture for registry items."""
return NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry]()
def test_normalize_name():
"""Test normalize_name."""
assert normalize_name("Hello World") == "helloworld"
assert normalize_name("HELLO WORLD") == "helloworld"
assert normalize_name(" Hello World ") == "helloworld"
def test_registry_items(
registry_items: NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry],
):
"""Test registry items."""
entry = NormalizedNameBaseRegistryEntry(
name="Hello World", normalized_name="helloworld"
)
registry_items["key"] = entry
assert registry_items["key"] == entry
assert list(registry_items.values()) == [entry]
assert registry_items.get_by_name("Hello World") == entry
# test update entry
entry2 = NormalizedNameBaseRegistryEntry(
name="Hello World 2", normalized_name="helloworld2"
)
registry_items["key"] = entry2
assert registry_items["key"] == entry2
assert list(registry_items.values()) == [entry2]
assert registry_items.get_by_name("Hello World 2") == entry2
# test delete entry
del registry_items["key"]
assert "key" not in registry_items
assert list(registry_items.values()) == []
def test_key_already_in_use(
registry_items: NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry],
):
"""Test key already in use."""
entry = NormalizedNameBaseRegistryEntry(
name="Hello World", normalized_name="helloworld"
)
registry_items["key"] = entry
# should raise ValueError if we update a
# key with a entry with the same normalized name
with pytest.raises(ValueError):
entry = NormalizedNameBaseRegistryEntry(
name="Hello World 2", normalized_name="helloworld2"
)
registry_items["key2"] = entry
registry_items["key"] = entry