core/homeassistant/helpers/storage.py

383 lines
13 KiB
Python
Raw Normal View History

"""Helper to help store data."""
2021-03-17 17:34:19 +00:00
from __future__ import annotations
import asyncio
2022-07-09 20:32:57 +00:00
from collections.abc import Callable, Mapping, Sequence
from contextlib import suppress
2021-11-18 23:56:22 +00:00
from copy import deepcopy
import inspect
from json import JSONDecodeError, JSONEncoder
import logging
import os
from typing import Any, Generic, TypeVar
from homeassistant.const import EVENT_HOMEASSISTANT_FINAL_WRITE
from homeassistant.core import (
CALLBACK_TYPE,
DOMAIN as HOMEASSISTANT_DOMAIN,
CoreState,
Event,
HomeAssistant,
callback,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import MAX_LOAD_CONCURRENTLY, bind_hass
from homeassistant.util import json as json_util
import homeassistant.util.dt as dt_util
from homeassistant.util.file import WriteError
from . import json as json_helper
# mypy: allow-untyped-calls, allow-untyped-defs, no-warn-return-any
# mypy: no-check-untyped-defs
2019-07-31 19:25:30 +00:00
# Name of the directory, passed to hass.config.path together with a store's
# key (see Store.path), under which every Store keeps its JSON file.
STORAGE_DIR = ".storage"
_LOGGER = logging.getLogger(__name__)
# hass.data key under which the shared asyncio.Semaphore is kept that caps how
# many Store loads run at the same time (MAX_LOAD_CONCURRENTLY).
STORAGE_SEMAPHORE = "storage_semaphore"
# Payload type held by a Store: a string-keyed mapping or a sequence; the
# payload is later serialized to JSON by json_helper.save_json.
_T = TypeVar("_T", bound=Mapping[str, Any] | Sequence[Any])
2022-07-09 20:32:57 +00:00
@bind_hass
async def async_migrator(
    hass: HomeAssistant,
    old_path: str,
    store: Store[_T],
    *,
    old_conf_load_func: Callable | None = None,
    old_conf_migrate_func: Callable | None = None,
) -> _T | None:
    """Move data from a legacy file at ``old_path`` into ``store`` and return it.

    1. If ``store`` already holds data, a previous run completed the migration.
       That data is returned and ``old_path`` is left untouched.
    2. Otherwise the legacy file is read in the executor, with
       ``old_conf_load_func(old_path)`` when given, else with
       ``json_util.load_json(old_path)``. A missing file, or a loader result of
       ``None``, ends the migration and returns ``None``.
    3. When given, ``old_conf_migrate_func`` is awaited with the loaded data
       (expected shape: ``async def old_conf_migrate_func(old_data)``). Its
       result replaces the data.
    4. The data is saved to ``store``, the legacy file is deleted and the data
       is returned.
    """
    existing = await store.async_load()
    if existing is not None:
        # Store already populated: nothing left to migrate.
        return existing

    def _read_legacy_file():
        """Read the legacy file, returning None when it does not exist."""
        if not os.path.isfile(old_path):
            return None
        loader = (
            json_util.load_json if old_conf_load_func is None else old_conf_load_func
        )
        return loader(old_path)

    legacy = await hass.async_add_executor_job(_read_legacy_file)
    if legacy is None:
        return None

    if old_conf_migrate_func is not None:
        legacy = await old_conf_migrate_func(legacy)

    await store.async_save(legacy)
    # Only delete the legacy file once the data is safely handed to the store.
    await hass.async_add_executor_job(os.remove, old_path)
    return legacy
@bind_hass
class Store(Generic[_T]):
    """Class to help storing data.

    Each store persists one payload as a JSON file at
    ``hass.config.path(STORAGE_DIR, key)``. The payload is wrapped in an
    envelope with ``version``, ``minor_version``, ``key`` and ``data`` entries.

    Saves are either immediate (``async_save``) or deferred
    (``async_delay_save``). While a save is pending, the envelope is held in
    ``self._data``, and ``async_load`` works from a copy of it instead of
    reading the file. A listener on ``EVENT_HOMEASSISTANT_FINAL_WRITE`` is
    registered so that a pending payload is still written during shutdown.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        version: int,
        key: str,
        private: bool = False,
        *,
        atomic_writes: bool = False,
        encoder: type[JSONEncoder] | None = None,
        minor_version: int = 1,
        read_only: bool = False,
    ) -> None:
        """Initialize storage class.

        Args:
            hass: Home Assistant instance, used for paths, executor jobs,
                the event bus and task creation.
            version: Major version of the stored format.
            key: Storage key; also the file name inside ``STORAGE_DIR``.
            private: Forwarded to ``json_helper.save_json``.
            atomic_writes: Forwarded to ``json_helper.save_json``.
            encoder: Optional JSON encoder class forwarded to ``save_json``.
            minor_version: Minor version of the stored format.
            read_only: When True, pending data is consumed but never written.
        """
        self.version = version
        self.minor_version = minor_version
        self.key = key
        self.hass = hass
        self._private = private
        # Pending envelope not yet written. A delayed save stores "data_func"
        # here instead of "data" so the payload is only built when needed.
        self._data: dict[str, Any] | None = None
        self._unsub_delay_listener: CALLBACK_TYPE | None = None
        self._unsub_final_write_listener: CALLBACK_TYPE | None = None
        # Serializes writes so two writers never consume the same envelope.
        self._write_lock = asyncio.Lock()
        # Shared load task so concurrent async_load callers await one load.
        self._load_task: asyncio.Future[_T | None] | None = None
        self._encoder = encoder
        self._atomic_writes = atomic_writes
        self._read_only = read_only

    @property
    def path(self):
        """Return the config path."""
        return self.hass.config.path(STORAGE_DIR, self.key)

    async def async_load(self) -> _T | None:
        """Load data.

        If the expected version and minor version do not match the given
        versions, the migrate function will be invoked with
        migrate_func(version, minor_version, config).

        Will ensure that when a call comes in while another one is in progress,
        the second call will wait and return the result of the first call.
        """
        if self._load_task is None:
            self._load_task = self.hass.async_create_task(
                self._async_load(), f"Storage load {self.key}"
            )

        return await self._load_task

    async def _async_load(self) -> _T | None:
        """Load the data and ensure the task is removed."""
        # The semaphore is shared by all stores through hass.data, which limits
        # how many stores load at the same time.
        if STORAGE_SEMAPHORE not in self.hass.data:
            self.hass.data[STORAGE_SEMAPHORE] = asyncio.Semaphore(MAX_LOAD_CONCURRENTLY)

        try:
            async with self.hass.data[STORAGE_SEMAPHORE]:
                return await self._async_load_data()
        finally:
            self._load_task = None

    async def _async_load_data(self):
        """Load the data, migrating and re-saving it on a version mismatch.

        Returns None when there is nothing stored, or when the file was corrupt
        and has been moved aside.
        """
        # Check if we have a pending write
        if self._data is not None:
            data = self._data

            # If we didn't generate data yet, do it now.
            if "data_func" in data:
                data["data"] = data.pop("data_func")()

            # We make a copy because code might assume it's safe to mutate loaded data
            # and we don't want that to mess with what we're trying to store.
            data = deepcopy(data)
        else:
            try:
                data = await self.hass.async_add_executor_job(
                    json_util.load_json, self.path
                )
            except HomeAssistantError as err:
                if isinstance(err.__cause__, JSONDecodeError):
                    await self._async_handle_corrupt_file(err)
                    return None
                raise

            if data == {}:
                return None

        # Add minor_version if not set
        if "minor_version" not in data:
            data["minor_version"] = 1

        if (
            data["version"] == self.version
            and data["minor_version"] == self.minor_version
        ):
            stored = data["data"]
        else:
            _LOGGER.info(
                "Migrating %s storage from %s.%s to %s.%s",
                self.key,
                data["version"],
                data["minor_version"],
                self.version,
                self.minor_version,
            )
            # Legacy migrate functions only take (old_version, old_data).
            if len(inspect.signature(self._async_migrate_func).parameters) == 2:
                stored = await self._async_migrate_func(data["version"], data["data"])
            else:
                try:
                    stored = await self._async_migrate_func(
                        data["version"], data["minor_version"], data["data"]
                    )
                except NotImplementedError:
                    # A missing migrate function is only acceptable when just
                    # the minor version differs.
                    if data["version"] != self.version:
                        raise
                    stored = data["data"]
            await self.async_save(stored)

        return stored

    async def _async_handle_corrupt_file(self, err: HomeAssistantError) -> None:
        """Move an undecodable storage file aside and create a repair issue."""
        # The file is corrupt and we can't recover from it. Log an error and
        # rename the file; the caller then returns None. This gives a clean
        # slate, so startup can continue and the user can restore from a backup.
        isotime = dt_util.utcnow().isoformat()
        corrupt_postfix = f".corrupt.{isotime}"
        corrupt_path = f"{self.path}{corrupt_postfix}"
        await self.hass.async_add_executor_job(os.rename, self.path, corrupt_path)

        storage_key = self.key
        _LOGGER.error(
            "Unrecoverable error decoding storage %s at %s; "
            "This may indicate an unclean shutdown, invalid syntax "
            "from manual edits, or disk corruption; "
            "The corrupt file has been saved as %s; "
            "It is recommended to restore from backup: %s",
            storage_key,
            self.path,
            corrupt_path,
            err,
        )

        from .issue_registry import (  # pylint: disable=import-outside-toplevel
            IssueSeverity,
            async_create_issue,
        )

        # Attribute the issue to the integration owning the key (the part
        # before the first ".") when that integration is loaded.
        issue_domain = HOMEASSISTANT_DOMAIN
        if (
            domain := (storage_key.partition(".")[0])
        ) and domain in self.hass.config.components:
            issue_domain = domain

        async_create_issue(
            self.hass,
            HOMEASSISTANT_DOMAIN,
            f"storage_corruption_{storage_key}_{isotime}",
            is_fixable=True,
            issue_domain=issue_domain,
            translation_key="storage_corruption",
            is_persistent=True,
            severity=IssueSeverity.CRITICAL,
            translation_placeholders={
                "storage_key": storage_key,
                "original_path": self.path,
                "corrupt_path": corrupt_path,
                "error": str(err),
            },
        )

    async def async_save(self, data: _T) -> None:
        """Save data.

        The new envelope replaces any pending one and is written right away.
        While Home Assistant is stopping, the write is instead left to the
        final-write listener.
        """
        self._data = {
            "version": self.version,
            "minor_version": self.minor_version,
            "key": self.key,
            "data": data,
        }

        if self.hass.state == CoreState.stopping:
            self._async_ensure_final_write_listener()
            return

        await self._async_handle_write_data()

    @callback
    def async_delay_save(
        self,
        data_func: Callable[[], _T],
        delay: float = 0,
    ) -> None:
        """Save data with an optional delay.

        ``data_func`` is only called when the write happens, or when a load
        encounters the pending envelope. Each call replaces the pending envelope
        and restarts the delay timer.
        """
        # pylint: disable-next=import-outside-toplevel
        from .event import async_call_later

        self._data = {
            "version": self.version,
            "minor_version": self.minor_version,
            "key": self.key,
            "data_func": data_func,
        }

        self._async_cleanup_delay_listener()
        self._async_ensure_final_write_listener()

        if self.hass.state == CoreState.stopping:
            return

        self._unsub_delay_listener = async_call_later(
            self.hass, delay, self._async_callback_delayed_write
        )

    @callback
    def _async_ensure_final_write_listener(self) -> None:
        """Ensure that we write if we quit before delay has passed."""
        if self._unsub_final_write_listener is None:
            self._unsub_final_write_listener = self.hass.bus.async_listen_once(
                EVENT_HOMEASSISTANT_FINAL_WRITE, self._async_callback_final_write
            )

    @callback
    def _async_cleanup_final_write_listener(self) -> None:
        """Clean up a stop listener."""
        if self._unsub_final_write_listener is not None:
            self._unsub_final_write_listener()
            self._unsub_final_write_listener = None

    @callback
    def _async_cleanup_delay_listener(self) -> None:
        """Clean up a delay listener."""
        if self._unsub_delay_listener is not None:
            self._unsub_delay_listener()
            self._unsub_delay_listener = None

    async def _async_callback_delayed_write(self, _now):
        """Handle a delayed write callback."""
        # catch the case where a call is scheduled and then we stop Home Assistant
        if self.hass.state == CoreState.stopping:
            self._async_ensure_final_write_listener()
            return
        await self._async_handle_write_data()

    async def _async_callback_final_write(self, _event: Event) -> None:
        """Handle a write because Home Assistant is in final write state."""
        # async_listen_once already removed the listener; forget its handle.
        self._unsub_final_write_listener = None
        await self._async_handle_write_data()

    async def _async_handle_write_data(self, *_args):
        """Handle writing the config.

        Under the write lock, this takes the pending envelope and resolves a
        deferred ``data_func`` if present. The envelope is then written unless
        the store is read-only. Serialization and write errors are logged, not
        raised.
        """
        async with self._write_lock:
            self._async_cleanup_delay_listener()
            self._async_cleanup_final_write_listener()

            if self._data is None:
                # Another write already consumed the data
                return

            data = self._data

            if "data_func" in data:
                data["data"] = data.pop("data_func")()

            self._data = None

            if self._read_only:
                return

            try:
                await self._async_write_data(self.path, data)
            except (json_util.SerializationError, WriteError) as err:
                _LOGGER.error("Error writing config for %s: %s", self.key, err)

    async def _async_write_data(self, path: str, data: dict) -> None:
        """Write ``data`` to ``path`` in the executor."""
        # Pass the given path through (previously ``self.path`` was used and
        # the argument silently ignored), so callers control the target file.
        await self.hass.async_add_executor_job(self._write_data, path, data)

    def _write_data(self, path: str, data: dict) -> None:
        """Write the data."""
        os.makedirs(os.path.dirname(path), exist_ok=True)

        _LOGGER.debug("Writing data for %s to %s", self.key, path)
        json_helper.save_json(
            path,
            data,
            self._private,
            encoder=self._encoder,
            atomic_writes=self._atomic_writes,
        )

    async def _async_migrate_func(self, old_major_version, old_minor_version, old_data):
        """Migrate to the new version.

        Subclasses override this to convert ``old_data``, stored at
        ``old_major_version``.``old_minor_version``, into the current format.
        The base implementation raises NotImplementedError. ``_async_load_data``
        tolerates that only when just the minor version differs.
        """
        raise NotImplementedError

    async def async_remove(self) -> None:
        """Remove all data."""
        self._async_cleanup_delay_listener()
        self._async_cleanup_final_write_listener()

        # NOTE(review): a pending envelope in self._data is not cleared here,
        # so a later async_load may still return it — confirm this is intended.
        with suppress(FileNotFoundError):
            await self.hass.async_add_executor_job(os.unlink, self.path)