Add normalized name registry items base class (#111666)

* Add normalized name base registry items class

* Add tests
pull/111786/head
Jan-Philipp Benecke 2024-02-29 01:31:33 +01:00 committed by GitHub
parent f1398dd127
commit f31244bac4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 174 additions and 187 deletions

View File

@ -1,8 +1,7 @@
"""Provide a way to connect devices to one physical location."""
from __future__ import annotations
from collections import UserDict
from collections.abc import Iterable, ValuesView
from collections.abc import Iterable
import dataclasses
from typing import Any, Literal, TypedDict, cast
@ -10,6 +9,11 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.util import slugify
from . import device_registry as dr, entity_registry as er
from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems,
normalize_name,
)
from .storage import Store
from .typing import UNDEFINED, UndefinedType
@ -29,7 +33,7 @@ class EventAreaRegistryUpdatedData(TypedDict):
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class AreaEntry:
class AreaEntry(NormalizedNameBaseRegistryEntry):
"""Area Registry Entry."""
aliases: set[str]
@ -37,57 +41,9 @@ class AreaEntry:
icon: str | None
id: str
labels: set[str] = dataclasses.field(default_factory=set)
name: str
normalized_name: str
picture: str | None
class AreaRegistryItems(UserDict[str, AreaEntry]):
"""Container for area registry items, maps area id -> entry.
Maintains an additional index:
- normalized name -> entry
"""
def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._normalized_names: dict[str, AreaEntry] = {}
def values(self) -> ValuesView[AreaEntry]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def __setitem__(self, key: str, entry: AreaEntry) -> None:
"""Add an item."""
data = self.data
normalized_name = normalize_area_name(entry.name)
if key in data:
old_entry = data[key]
if (
normalized_name != old_entry.normalized_name
and normalized_name in self._normalized_names
):
raise ValueError(
f"The name {entry.name} ({normalized_name}) is already in use"
)
del self._normalized_names[old_entry.normalized_name]
data[key] = entry
self._normalized_names[normalized_name] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
normalized_name = normalize_area_name(entry.name)
del self._normalized_names[normalized_name]
super().__delitem__(key)
def get_area_by_name(self, name: str) -> AreaEntry | None:
"""Get area by name."""
return self._normalized_names.get(normalize_area_name(name))
class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
"""Store area registry data."""
@ -133,7 +89,7 @@ class AreaRegistryStore(Store[dict[str, list[dict[str, Any]]]]):
class AreaRegistry:
"""Class to hold a registry of areas."""
areas: AreaRegistryItems
areas: NormalizedNameBaseRegistryItems[AreaEntry]
_area_data: dict[str, AreaEntry]
def __init__(self, hass: HomeAssistant) -> None:
@ -159,7 +115,7 @@ class AreaRegistry:
@callback
def async_get_area_by_name(self, name: str) -> AreaEntry | None:
"""Get area by name."""
return self.areas.get_area_by_name(name)
return self.areas.get_by_name(name)
@callback
def async_list_areas(self) -> Iterable[AreaEntry]:
@ -185,7 +141,7 @@ class AreaRegistry:
picture: str | None = None,
) -> AreaEntry:
"""Create a new area."""
normalized_name = normalize_area_name(name)
normalized_name = normalize_name(name)
if self.async_get_area_by_name(name):
raise ValueError(f"The name {name} ({normalized_name}) is already in use")
@ -281,7 +237,7 @@ class AreaRegistry:
if name is not UNDEFINED and name != old.name:
new_values["name"] = name
new_values["normalized_name"] = normalize_area_name(name)
new_values["normalized_name"] = normalize_name(name)
if not new_values:
return old
@ -297,12 +253,12 @@ class AreaRegistry:
data = await self._store.async_load()
areas = AreaRegistryItems()
areas = NormalizedNameBaseRegistryItems[AreaEntry]()
if data is not None:
for area in data["areas"]:
assert area["name"] is not None and area["id"] is not None
normalized_name = normalize_area_name(area["name"])
normalized_name = normalize_name(area["name"])
areas[area["id"]] = AreaEntry(
aliases=set(area["aliases"]),
floor_id=area["floor_id"],
@ -421,8 +377,3 @@ def async_entries_for_floor(registry: AreaRegistry, floor_id: str) -> list[AreaE
def async_entries_for_label(registry: AreaRegistry, label_id: str) -> list[AreaEntry]:
"""Return entries that match a label."""
return [area for area in registry.areas.values() if label_id in area.labels]
def normalize_area_name(area_name: str) -> str:
"""Normalize an area name by removing whitespace and case folding."""
return area_name.casefold().replace(" ", "")

View File

@ -1,8 +1,7 @@
"""Provide a way to assign areas to floors in one's home."""
from __future__ import annotations
from collections import UserDict
from collections.abc import Iterable, ValuesView
from collections.abc import Iterable
import dataclasses
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, TypedDict, cast
@ -10,6 +9,11 @@ from typing import TYPE_CHECKING, Literal, TypedDict, cast
from homeassistant.core import HomeAssistant, callback
from homeassistant.util import slugify
from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems,
normalize_name,
)
from .storage import Store
from .typing import UNDEFINED, EventType, UndefinedType
@ -31,67 +35,19 @@ EventFloorRegistryUpdated = EventType[EventFloorRegistryUpdatedData]
@dataclass(slots=True, kw_only=True, frozen=True)
class FloorEntry:
class FloorEntry(NormalizedNameBaseRegistryEntry):
"""Floor registry entry."""
aliases: set[str]
floor_id: str
icon: str | None = None
level: int = 0
name: str
normalized_name: str
class FloorRegistryItems(UserDict[str, FloorEntry]):
"""Container for floor registry items, maps floor id -> entry.
Maintains an additional index:
- normalized name -> entry
"""
def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._normalized_names: dict[str, FloorEntry] = {}
def values(self) -> ValuesView[FloorEntry]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def __setitem__(self, key: str, entry: FloorEntry) -> None:
"""Add an item."""
data = self.data
normalized_name = _normalize_floor_name(entry.name)
if key in data:
old_entry = data[key]
if (
normalized_name != old_entry.normalized_name
and normalized_name in self._normalized_names
):
raise ValueError(
f"The name {entry.name} ({normalized_name}) is already in use"
)
del self._normalized_names[old_entry.normalized_name]
data[key] = entry
self._normalized_names[normalized_name] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
normalized_name = _normalize_floor_name(entry.name)
del self._normalized_names[normalized_name]
super().__delitem__(key)
def get_floor_by_name(self, name: str) -> FloorEntry | None:
"""Get floor by name."""
return self._normalized_names.get(_normalize_floor_name(name))
class FloorRegistry:
"""Class to hold a registry of floors."""
floors: FloorRegistryItems
floors: NormalizedNameBaseRegistryItems[FloorEntry]
_floor_data: dict[str, FloorEntry]
def __init__(self, hass: HomeAssistant) -> None:
@ -118,7 +74,7 @@ class FloorRegistry:
@callback
def async_get_floor_by_name(self, name: str) -> FloorEntry | None:
"""Get floor by name."""
return self.floors.get_floor_by_name(name)
return self.floors.get_by_name(name)
@callback
def async_list_floors(self) -> Iterable[FloorEntry]:
@ -150,7 +106,7 @@ class FloorRegistry:
f"The name {name} ({floor.normalized_name}) is already in use"
)
normalized_name = _normalize_floor_name(name)
normalized_name = normalize_name(name)
floor = FloorEntry(
aliases=aliases or set(),
@ -208,7 +164,7 @@ class FloorRegistry:
}
if name is not UNDEFINED and name != old.name:
changes["name"] = name
changes["normalized_name"] = _normalize_floor_name(name)
changes["normalized_name"] = normalize_name(name)
if not changes:
return old
@ -229,7 +185,7 @@ class FloorRegistry:
async def async_load(self) -> None:
"""Load the floor registry."""
data = await self._store.async_load()
floors = FloorRegistryItems()
floors = NormalizedNameBaseRegistryItems[FloorEntry]()
if data is not None:
for floor in data["floors"]:
@ -240,7 +196,7 @@ class FloorRegistry:
assert isinstance(floor["name"], str)
assert isinstance(floor["floor_id"], str)
normalized_name = _normalize_floor_name(floor["name"])
normalized_name = normalize_name(floor["name"])
floors[floor["floor_id"]] = FloorEntry(
aliases=set(floor["aliases"]),
icon=floor["icon"],
@ -286,8 +242,3 @@ async def async_load(hass: HomeAssistant) -> None:
assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = FloorRegistry(hass)
await hass.data[DATA_REGISTRY].async_load()
def _normalize_floor_name(floor_name: str) -> str:
"""Normalize a floor name by removing whitespace and case folding."""
return floor_name.casefold().replace(" ", "")

View File

@ -1,8 +1,7 @@
"""Provide a way to label and group anything."""
from __future__ import annotations
from collections import UserDict
from collections.abc import Iterable, ValuesView
from collections.abc import Iterable
import dataclasses
from dataclasses import dataclass
from typing import Literal, TypedDict, cast
@ -10,6 +9,11 @@ from typing import Literal, TypedDict, cast
from homeassistant.core import HomeAssistant, callback
from homeassistant.util import slugify
from .normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems,
normalize_name,
)
from .storage import Store
from .typing import UNDEFINED, EventType, UndefinedType
@ -30,68 +34,20 @@ class EventLabelRegistryUpdatedData(TypedDict):
EventLabelRegistryUpdated = EventType[EventLabelRegistryUpdatedData]
@dataclass(slots=True, frozen=True)
class LabelEntry:
@dataclass(slots=True, frozen=True, kw_only=True)
class LabelEntry(NormalizedNameBaseRegistryEntry):
"""Label Registry Entry."""
label_id: str
name: str
normalized_name: str
description: str | None = None
color: str | None = None
icon: str | None = None
class LabelRegistryItems(UserDict[str, LabelEntry]):
"""Container for label registry items, maps label id -> entry.
Maintains an additional index:
- normalized name -> entry
"""
def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._normalized_names: dict[str, LabelEntry] = {}
def values(self) -> ValuesView[LabelEntry]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def __setitem__(self, key: str, entry: LabelEntry) -> None:
"""Add an item."""
data = self.data
normalized_name = _normalize_label_name(entry.name)
if key in data:
old_entry = data[key]
if (
normalized_name != old_entry.normalized_name
and normalized_name in self._normalized_names
):
raise ValueError(
f"The name {entry.name} ({normalized_name}) is already in use"
)
del self._normalized_names[old_entry.normalized_name]
data[key] = entry
self._normalized_names[normalized_name] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
normalized_name = _normalize_label_name(entry.name)
del self._normalized_names[normalized_name]
super().__delitem__(key)
def get_label_by_name(self, name: str) -> LabelEntry | None:
"""Get label by name."""
return self._normalized_names.get(_normalize_label_name(name))
class LabelRegistry:
"""Class to hold a registry of labels."""
labels: LabelRegistryItems
labels: NormalizedNameBaseRegistryItems[LabelEntry]
_label_data: dict[str, LabelEntry]
def __init__(self, hass: HomeAssistant) -> None:
@ -116,7 +72,7 @@ class LabelRegistry:
@callback
def async_get_label_by_name(self, name: str) -> LabelEntry | None:
"""Get label by name."""
return self.labels.get_label_by_name(name)
return self.labels.get_by_name(name)
@callback
def async_list_labels(self) -> Iterable[LabelEntry]:
@ -148,7 +104,7 @@ class LabelRegistry:
f"The name {name} ({label.normalized_name}) is already in use"
)
normalized_name = _normalize_label_name(name)
normalized_name = normalize_name(name)
label = LabelEntry(
color=color,
@ -207,7 +163,7 @@ class LabelRegistry:
if name is not UNDEFINED and name != old.name:
changes["name"] = name
changes["normalized_name"] = _normalize_label_name(name)
changes["normalized_name"] = normalize_name(name)
if not changes:
return old
@ -228,7 +184,7 @@ class LabelRegistry:
async def async_load(self) -> None:
"""Load the label registry."""
data = await self._store.async_load()
labels = LabelRegistryItems()
labels = NormalizedNameBaseRegistryItems[LabelEntry]()
if data is not None:
for label in data["labels"]:
@ -236,7 +192,7 @@ class LabelRegistry:
if label["label_id"] is None or label["name"] is None:
continue
normalized_name = _normalize_label_name(label["name"])
normalized_name = normalize_name(label["name"])
labels[label["label_id"]] = LabelEntry(
color=label["color"],
description=label["description"],
@ -282,8 +238,3 @@ async def async_load(hass: HomeAssistant) -> None:
assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = LabelRegistry(hass)
await hass.data[DATA_REGISTRY].async_load()
def _normalize_label_name(label_name: str) -> str:
"""Normalize a label name by removing whitespace and case folding."""
return label_name.casefold().replace(" ", "")

View File

@ -0,0 +1,67 @@
"""Provide a base class for registries that use a normalized name index."""
from collections import UserDict
from collections.abc import ValuesView
from dataclasses import dataclass
from typing import TypeVar
@dataclass(slots=True, frozen=True, kw_only=True)
class NormalizedNameBaseRegistryEntry:
"""Normalized Name Base Registry Entry."""
name: str
normalized_name: str
_VT = TypeVar("_VT", bound=NormalizedNameBaseRegistryEntry)
def normalize_name(name: str) -> str:
"""Normalize a name by removing whitespace and case folding."""
return name.casefold().replace(" ", "")
class NormalizedNameBaseRegistryItems(UserDict[str, _VT]):
"""Base container for normalized name registry items, maps key -> entry.
Maintains an additional index:
- normalized name -> entry
"""
def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._normalized_names: dict[str, _VT] = {}
def values(self) -> ValuesView[_VT]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()
def __setitem__(self, key: str, entry: _VT) -> None:
"""Add an item."""
data = self.data
normalized_name = normalize_name(entry.name)
if key in data:
old_entry = data[key]
if (
normalized_name != old_entry.normalized_name
and normalized_name in self._normalized_names
):
raise ValueError(
f"The name {entry.name} ({normalized_name}) is already in use"
)
del self._normalized_names[old_entry.normalized_name]
data[key] = entry
self._normalized_names[normalized_name] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
normalized_name = normalize_name(entry.name)
del self._normalized_names[normalized_name]
super().__delitem__(key)
def get_by_name(self, name: str) -> _VT | None:
"""Get entry by name."""
return self._normalized_names.get(normalize_name(name))

View File

@ -0,0 +1,67 @@
"""Tests for the normalized name base registry helper."""
import pytest
from homeassistant.helpers.normalized_name_base_registry import (
NormalizedNameBaseRegistryEntry,
NormalizedNameBaseRegistryItems,
normalize_name,
)
@pytest.fixture
def registry_items():
"""Fixture for registry items."""
return NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry]()
def test_normalize_name():
"""Test normalize_name."""
assert normalize_name("Hello World") == "helloworld"
assert normalize_name("HELLO WORLD") == "helloworld"
assert normalize_name(" Hello World ") == "helloworld"
def test_registry_items(
registry_items: NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry],
):
"""Test registry items."""
entry = NormalizedNameBaseRegistryEntry(
name="Hello World", normalized_name="helloworld"
)
registry_items["key"] = entry
assert registry_items["key"] == entry
assert list(registry_items.values()) == [entry]
assert registry_items.get_by_name("Hello World") == entry
# test update entry
entry2 = NormalizedNameBaseRegistryEntry(
name="Hello World 2", normalized_name="helloworld2"
)
registry_items["key"] = entry2
assert registry_items["key"] == entry2
assert list(registry_items.values()) == [entry2]
assert registry_items.get_by_name("Hello World 2") == entry2
# test delete entry
del registry_items["key"]
assert "key" not in registry_items
assert list(registry_items.values()) == []
def test_key_already_in_use(
registry_items: NormalizedNameBaseRegistryItems[NormalizedNameBaseRegistryEntry],
):
"""Test key already in use."""
entry = NormalizedNameBaseRegistryEntry(
name="Hello World", normalized_name="helloworld"
)
registry_items["key"] = entry
# should raise ValueError if we update a
# key with a entry with the same normalized name
with pytest.raises(ValueError):
entry = NormalizedNameBaseRegistryEntry(
name="Hello World 2", normalized_name="helloworld2"
)
registry_items["key2"] = entry
registry_items["key"] = entry