core/homeassistant/components/thread/dataset_store.py

385 lines
14 KiB
Python

"""Persistently store thread datasets."""
from __future__ import annotations
import dataclasses
from datetime import datetime
import logging
from typing import Any, cast
from python_otbr_api import tlv_parser
from python_otbr_api.tlv_parser import MeshcopTLVType
from homeassistant.backports.functools import cached_property
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.singleton import singleton
from homeassistant.helpers.storage import Store
from homeassistant.util import dt as dt_util, ulid as ulid_util
# Key passed to @singleton for the shared DatasetStore (see async_get_store).
DATA_STORE = "thread.datasets"
# Key passed to DatasetStoreStore for the persisted dataset data.
STORAGE_KEY = "thread.datasets"
# Current storage schema version. DatasetStoreStore._async_migrate_func
# upgrades older minor versions: < 2 deduplicates datasets, < 3 adds
# preferred_border_agent_id to every dataset.
STORAGE_VERSION_MAJOR = 1
STORAGE_VERSION_MINOR = 3
# Delay handed to Store.async_delay_save; presumably seconds — confirm.
SAVE_DELAY = 10
_LOGGER = logging.getLogger(__name__)
class DatasetPreferredError(HomeAssistantError):
    """Error raised on an attempt to delete the dataset currently marked preferred."""
@dataclasses.dataclass(frozen=True)
class DatasetEntry:
    """A single Thread dataset held by the dataset store.

    The raw operational dataset is kept as a TLV string in ``tlv``. The
    properties below read individual fields out of the parsed form exposed
    by ``dataset``.
    """

    preferred_border_agent_id: str | None
    source: str
    tlv: str
    created: datetime = dataclasses.field(default_factory=dt_util.utcnow)
    id: str = dataclasses.field(default_factory=ulid_util.ulid_now)

    @property
    def channel(self) -> int | None:
        """Return the channel as an integer, or None if the dataset has none."""
        if (channel := self.dataset.get(MeshcopTLVType.CHANNEL)) is None:
            return None
        return cast(tlv_parser.Channel, channel).channel

    @cached_property
    def dataset(self) -> dict[MeshcopTLVType, tlv_parser.MeshcopTLVItem]:
        """Return the TLV string parsed into a dict keyed by TLV type.

        Computed once per instance. The dataclass is frozen, so ``tlv``
        cannot change underneath the cached value; a modified entry is a
        new instance with its own cache.
        """
        return tlv_parser.parse_tlv(self.tlv)

    @property
    def extended_pan_id(self) -> str:
        """Return the extended PAN ID as a hex string.

        Raises KeyError if the dataset has no extended PAN ID.
        """
        return str(self.dataset[MeshcopTLVType.EXTPANID])

    @property
    def network_name(self) -> str | None:
        """Return the network name, or None if the dataset has none."""
        if (name := self.dataset.get(MeshcopTLVType.NETWORKNAME)) is None:
            return None
        return cast(tlv_parser.NetworkName, name).name

    @property
    def pan_id(self) -> str | None:
        """Return the PAN ID as a hex string, or None if the dataset has none."""
        # Check for absence explicitly: str(None) would produce the string
        # "None" rather than the None promised by the annotation.
        if (pan_id := self.dataset.get(MeshcopTLVType.PANID)) is None:
            return None
        return str(pan_id)

    def to_json(self) -> dict[str, Any]:
        """Return a JSON-serializable dict of the stored fields.

        ``created`` is emitted as an ISO 8601 string; the parsed dataset is
        not included since it is derived from ``tlv``.
        """
        return {
            "created": self.created.isoformat(),
            "id": self.id,
            "preferred_border_agent_id": self.preferred_border_agent_id,
            "source": self.source,
            "tlv": self.tlv,
        }
class DatasetStoreStore(Store):
    """Store Thread datasets, migrating older storage versions on load."""

    async def _async_migrate_func(
        self, old_major_version: int, old_minor_version: int, old_data: dict[str, Any]
    ) -> dict[str, Any]:
        """Migrate stored data to the current storage version.

        Minor version 2 deduplicated datasets by extended PAN ID; minor
        version 3 added ``preferred_border_agent_id`` to every dataset.

        Raises NotImplementedError for any major version other than 1.
        """
        if old_major_version != 1:
            # No migration path is known; fail explicitly instead of reaching
            # ``return data`` with ``data`` never assigned.
            raise NotImplementedError

        data = old_data
        if old_minor_version < 2:
            data = self._deduplicate_datasets(old_data)
        if old_minor_version < 3:
            # Add border agent ID
            for dataset in data["datasets"]:
                dataset.setdefault("preferred_border_agent_id", None)
        return data

    @staticmethod
    def _deduplicate_datasets(old_data: dict[str, Any]) -> dict[str, Any]:
        """Drop invalid datasets and keep one dataset per extended PAN ID.

        Datasets missing an extended PAN ID or an active timestamp are
        dropped; if such a dataset was preferred, no dataset stays preferred.
        Among datasets sharing an extended PAN ID, the preferred dataset
        always wins. Otherwise the one with the strictly newest active
        timestamp (seconds, then ticks) wins, and ties keep the earlier one.
        """
        datasets: dict[str, DatasetEntry] = {}
        preferred_dataset = old_data["preferred_dataset"]

        for dataset in old_data["datasets"]:
            created = cast(datetime, dt_util.parse_datetime(dataset["created"]))
            entry = DatasetEntry(
                created=created,
                id=dataset["id"],
                preferred_border_agent_id=None,
                source=dataset["source"],
                tlv=dataset["tlv"],
            )

            if (
                MeshcopTLVType.EXTPANID not in entry.dataset
                or MeshcopTLVType.ACTIVETIMESTAMP not in entry.dataset
            ):
                _LOGGER.warning("Dropped invalid Thread dataset '%s'", entry.tlv)
                if entry.id == preferred_dataset:
                    preferred_dataset = None
                continue

            existing = datasets.get(entry.extended_pan_id)
            if existing is None:
                datasets[entry.extended_pan_id] = entry
                continue

            if existing.id == preferred_dataset:
                _LOGGER.warning(
                    (
                        "Dropped duplicated Thread dataset '%s' "
                        "(duplicate of preferred dataset '%s')"
                    ),
                    entry.tlv,
                    existing.tlv,
                )
                continue

            new_timestamp = cast(
                tlv_parser.Timestamp,
                entry.dataset[MeshcopTLVType.ACTIVETIMESTAMP],
            )
            old_timestamp = cast(
                tlv_parser.Timestamp,
                existing.dataset[MeshcopTLVType.ACTIVETIMESTAMP],
            )
            # Keep the existing entry unless the new one is the preferred
            # dataset (so ``preferred_dataset`` never points at a dropped
            # entry) or has a strictly newer timestamp. The lexicographic
            # tuple comparison checks seconds first and only uses ticks to
            # break ties.
            if entry.id != preferred_dataset and (
                old_timestamp.seconds,
                old_timestamp.ticks,
            ) >= (new_timestamp.seconds, new_timestamp.ticks):
                _LOGGER.warning(
                    (
                        "Dropped duplicated Thread dataset '%s' "
                        "(duplicate of '%s')"
                    ),
                    entry.tlv,
                    existing.tlv,
                )
                continue

            _LOGGER.warning(
                (
                    "Dropped duplicated Thread dataset '%s' "
                    "(duplicate of '%s')"
                ),
                existing.tlv,
                entry.tlv,
            )
            datasets[entry.extended_pan_id] = entry

        return {
            "preferred_dataset": preferred_dataset,
            "datasets": [dataset.to_json() for dataset in datasets.values()],
        }
class DatasetStore:
    """Collection of Thread datasets persisted through ``DatasetStoreStore``."""

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize an empty dataset store; ``async_load`` reads stored data."""
        self.hass = hass
        # Datasets keyed by their ``DatasetEntry.id``
        self.datasets: dict[str, DatasetEntry] = {}
        self._preferred_dataset: str | None = None
        self._store: Store[dict[str, Any]] = DatasetStoreStore(
            hass,
            STORAGE_VERSION_MAJOR,
            STORAGE_KEY,
            atomic_writes=True,
            minor_version=STORAGE_VERSION_MINOR,
        )

    @callback
    def async_add(
        self, source: str, tlv: str, preferred_border_agent_id: str | None
    ) -> None:
        """Add a dataset, or update one with the same extended PAN ID.

        - An identical dataset is not added again. It only gets
          ``preferred_border_agent_id`` filled in if it had none.
        - A dataset with the same extended PAN ID is replaced only when the
          new active timestamp is strictly newer (seconds, then ticks).
        - Otherwise a new entry is created. It becomes the preferred dataset
          if none is set.

        Raises HomeAssistantError if the dataset lacks an extended PAN ID or
        an active timestamp.
        """
        # Make sure the tlv is valid
        dataset = tlv_parser.parse_tlv(tlv)

        # Don't allow adding a dataset which does not have an extended pan id
        # or timestamp
        if (
            MeshcopTLVType.EXTPANID not in dataset
            or MeshcopTLVType.ACTIVETIMESTAMP not in dataset
        ):
            raise HomeAssistantError("Invalid dataset")

        # Bail out if the dataset already exists
        for entry in self.datasets.values():
            if entry.dataset == dataset:
                if (
                    preferred_border_agent_id
                    and entry.preferred_border_agent_id is None
                ):
                    self.async_set_preferred_border_agent_id(
                        entry.id, preferred_border_agent_id
                    )
                return

        # Update if a dataset with the same extended pan id exists and the
        # timestamp is newer
        existing: DatasetEntry | None = next(
            (
                entry
                for entry in self.datasets.values()
                if entry.dataset[MeshcopTLVType.EXTPANID]
                == dataset[MeshcopTLVType.EXTPANID]
            ),
            None,
        )
        if existing is not None:
            new_timestamp = cast(
                tlv_parser.Timestamp, dataset[MeshcopTLVType.ACTIVETIMESTAMP]
            )
            old_timestamp = cast(
                tlv_parser.Timestamp,
                existing.dataset[MeshcopTLVType.ACTIVETIMESTAMP],
            )
            # Compare (seconds, ticks) lexicographically: ticks only matter
            # when seconds are equal.
            if (old_timestamp.seconds, old_timestamp.ticks) >= (
                new_timestamp.seconds,
                new_timestamp.ticks,
            ):
                _LOGGER.warning(
                    (
                        "Got dataset with same extended PAN ID and same or older active"
                        " timestamp, old dataset: '%s', new dataset: '%s'"
                    ),
                    existing.tlv,
                    tlv,
                )
                return

            _LOGGER.debug(
                (
                    "Updating dataset with same extended PAN ID and newer active "
                    "timestamp, old dataset: '%s', new dataset: '%s'"
                ),
                existing.tlv,
                tlv,
            )
            # Frozen dataclass: build a replacement entry that keeps id,
            # source, created and border agent, with the new TLV.
            self.datasets[existing.id] = dataclasses.replace(
                self.datasets[existing.id], tlv=tlv
            )
            self.async_schedule_save()
            if preferred_border_agent_id and existing.preferred_border_agent_id is None:
                self.async_set_preferred_border_agent_id(
                    existing.id, preferred_border_agent_id
                )
            return

        entry = DatasetEntry(
            preferred_border_agent_id=preferred_border_agent_id, source=source, tlv=tlv
        )
        self.datasets[entry.id] = entry
        # Set to preferred if there is no preferred dataset
        if self._preferred_dataset is None:
            self._preferred_dataset = entry.id
        self.async_schedule_save()

    @callback
    def async_delete(self, dataset_id: str) -> None:
        """Delete a dataset.

        Raises DatasetPreferredError if it is the preferred dataset and
        KeyError if ``dataset_id`` is unknown.
        """
        if self._preferred_dataset == dataset_id:
            raise DatasetPreferredError("attempt to remove preferred dataset")
        del self.datasets[dataset_id]
        self.async_schedule_save()

    @callback
    def async_get(self, dataset_id: str) -> DatasetEntry | None:
        """Return the dataset with ``dataset_id``, or None if unknown."""
        return self.datasets.get(dataset_id)

    @callback
    def async_set_preferred_border_agent_id(
        self, dataset_id: str, border_agent_id: str
    ) -> None:
        """Set the preferred border agent id of a dataset.

        Raises KeyError if ``dataset_id`` is unknown.
        """
        self.datasets[dataset_id] = dataclasses.replace(
            self.datasets[dataset_id], preferred_border_agent_id=border_agent_id
        )
        self.async_schedule_save()

    @property
    @callback
    def preferred_dataset(self) -> str | None:
        """Return the id of the preferred dataset, or None if none is set."""
        return self._preferred_dataset

    @preferred_dataset.setter
    @callback
    def preferred_dataset(self, dataset_id: str) -> None:
        """Set the preferred dataset; raise KeyError if the id is unknown."""
        if dataset_id not in self.datasets:
            raise KeyError("unknown dataset")
        self._preferred_dataset = dataset_id
        self.async_schedule_save()

    async def async_load(self) -> None:
        """Load datasets from storage.

        If nothing is stored, the store is left empty with no preferred
        dataset.
        """
        data = await self._store.async_load()

        datasets: dict[str, DatasetEntry] = {}
        preferred_dataset: str | None = None

        if data is not None:
            for dataset in data["datasets"]:
                created = cast(datetime, dt_util.parse_datetime(dataset["created"]))
                datasets[dataset["id"]] = DatasetEntry(
                    created=created,
                    id=dataset["id"],
                    preferred_border_agent_id=dataset["preferred_border_agent_id"],
                    source=dataset["source"],
                    tlv=dataset["tlv"],
                )
            preferred_dataset = data["preferred_dataset"]

        self.datasets = datasets
        self._preferred_dataset = preferred_dataset

    @callback
    def async_schedule_save(self) -> None:
        """Schedule a delayed save of the store, using SAVE_DELAY."""
        self._store.async_delay_save(self._data_to_save, SAVE_DELAY)

    @callback
    def _data_to_save(self) -> dict[str, Any]:
        """Return the data to persist.

        The result holds the serialized datasets and the preferred dataset
        id. The annotation is ``dict[str, Any]`` because ``preferred_dataset``
        is a str or None, not a list.
        """
        data: dict[str, Any] = {}
        data["datasets"] = [dataset.to_json() for dataset in self.datasets.values()]
        data["preferred_dataset"] = self._preferred_dataset
        return data
@singleton(DATA_STORE)
async def async_get_store(hass: HomeAssistant) -> DatasetStore:
    """Build the dataset store and load its persisted contents.

    Wrapped in ``singleton(DATA_STORE)``; presumably later calls receive the
    same cached instance — confirm against the singleton helper.
    """
    dataset_store = DatasetStore(hass)
    await dataset_store.async_load()
    return dataset_store
async def async_add_dataset(
    hass: HomeAssistant,
    source: str,
    tlv: str,
    *,
    preferred_border_agent_id: str | None = None,
) -> None:
    """Add a Thread dataset to the dataset store.

    Delegates to ``DatasetStore.async_add`` on the store returned by
    ``async_get_store``.
    """
    (await async_get_store(hass)).async_add(source, tlv, preferred_border_agent_id)
async def async_get_dataset(hass: HomeAssistant, dataset_id: str) -> str | None:
    """Return the TLV string of dataset ``dataset_id``, or None if it is unknown."""
    found = (await async_get_store(hass)).async_get(dataset_id)
    return None if found is None else found.tlv
async def async_get_preferred_dataset(hass: HomeAssistant) -> str | None:
    """Return the TLV string of the preferred dataset.

    Returns None when no dataset is marked preferred, or when the preferred
    id is not present in the store.
    """
    store = await async_get_store(hass)

    preferred_id = store.preferred_dataset
    if preferred_id is None:
        return None

    preferred_entry = store.async_get(preferred_id)
    if preferred_entry is None:
        return None

    return preferred_entry.tlv