Additional strict typing for recorder (#68860)

parent fa33ac73f3
commit d75f577b88
.strict-typing

@@ -172,8 +172,12 @@ homeassistant.components.pure_energie.*
 homeassistant.components.rainmachine.*
 homeassistant.components.rdw.*
 homeassistant.components.recollect_waste.*
-homeassistant.components.recorder.models
+homeassistant.components.recorder
+homeassistant.components.recorder.const
+homeassistant.components.recorder.backup
+homeassistant.components.recorder.executor
+homeassistant.components.recorder.history
+homeassistant.components.recorder.models
 homeassistant.components.recorder.pool
 homeassistant.components.recorder.purge
 homeassistant.components.recorder.repack
homeassistant/components/recorder/__init__.py

@@ -11,7 +11,7 @@ import queue
 import sqlite3
 import threading
 import time
-from typing import Any, TypeVar
+from typing import Any, TypeVar, cast
 
 from lru import LRU  # pylint: disable=no-name-in-module
 from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
@@ -214,7 +214,8 @@ MAX_DB_EXECUTOR_WORKERS = POOL_SIZE - 1
 
 def get_instance(hass: HomeAssistant) -> Recorder:
     """Get the recorder instance."""
-    return hass.data[DATA_INSTANCE]
+    instance: Recorder = hass.data[DATA_INSTANCE]
+    return instance
 
 
 @bind_hass
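Note on the get_instance change above: hass.data is a plain dict[str, Any], so returning hass.data[DATA_INSTANCE] directly is an implicit Any return, which the newly enabled warn_return_any flag rejects. Annotating a local first pins the type down. A minimal standalone sketch of the pattern (illustrative names, not the repository's code):

    from typing import Any

    class Recorder:
        """Stand-in for the real Recorder class."""

    data: dict[str, Any] = {"recorder_instance": Recorder()}

    def get_instance_untyped() -> Recorder:
        return data["recorder_instance"]  # mypy: Returning Any from function declared to return "Recorder"

    def get_instance_typed() -> Recorder:
        instance: Recorder = data["recorder_instance"]  # annotation narrows Any to Recorder
        return instance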
@@ -225,10 +226,13 @@ def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool:
     """
     if DATA_INSTANCE not in hass.data:
         return False
-    return hass.data[DATA_INSTANCE].entity_filter(entity_id)
+    instance: Recorder = hass.data[DATA_INSTANCE]
+    return instance.entity_filter(entity_id)
 
 
-def run_information(hass, point_in_time: datetime | None = None) -> RecorderRuns | None:
+def run_information(
+    hass: HomeAssistant, point_in_time: datetime | None = None
+) -> RecorderRuns | None:
     """Return information about current run.
 
     There is also the run that covers point_in_time.
@@ -241,21 +245,20 @@ def run_information(hass, point_in_time: datetime | None = None) -> RecorderRuns
 
 
 def run_information_from_instance(
-    hass, point_in_time: datetime | None = None
+    hass: HomeAssistant, point_in_time: datetime | None = None
 ) -> RecorderRuns | None:
     """Return information about current run from the existing instance.
 
     Does not query the database for older runs.
     """
-    ins = hass.data[DATA_INSTANCE]
-
+    ins = get_instance(hass)
     if point_in_time is None or point_in_time > ins.recording_start:
         return ins.run_info
     return None
 
 
 def run_information_with_session(
-    session, point_in_time: datetime | None = None
+    session: Session, point_in_time: datetime | None = None
 ) -> RecorderRuns | None:
     """Return information about current run from the database."""
     recorder_runs = RecorderRuns
@@ -266,9 +269,9 @@ def run_information_with_session(
             (recorder_runs.start < point_in_time) & (recorder_runs.end > point_in_time)
         )
 
-    res = query.first()
-    if res:
+    if (res := query.first()) is not None:
         session.expunge(res)
-    return res
+        return cast(RecorderRuns, res)
+    return res
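Note on the walrus-plus-cast change above: under SQLAlchemy 1.4, query.first() is effectively untyped (Any) to mypy, so the result is cast once the None case is ruled out, and session.expunge only runs on a real row. A hedged sketch with stand-in types:

    from __future__ import annotations

    from typing import Any, cast

    class RecorderRuns:
        """Stand-in for the real model class."""

    def first() -> Any:
        """Stand-in for query.first(), which mypy sees as Any."""
        return RecorderRuns()

    def lookup() -> RecorderRuns | None:
        if (res := first()) is not None:
            return cast(RecorderRuns, res)  # safe: None was excluded above
        return res  # res is None here, so the declared return type still holds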
@@ -318,9 +321,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
     return await instance.async_db_ready
 
 
-async def _process_recorder_platform(hass, domain, platform):
+async def _process_recorder_platform(
+    hass: HomeAssistant, domain: str, platform: Any
+) -> None:
     """Process a recorder platform."""
-    hass.data[DOMAIN][domain] = platform
+    platforms: dict[str, Any] = hass.data[DOMAIN]
+    platforms[domain] = platform
     if hasattr(platform, "exclude_attributes"):
         hass.data[EXCLUDE_ATTRIBUTES][domain] = platform.exclude_attributes(hass)
@@ -586,11 +592,11 @@ class Recorder(threading.Thread):
         self.db_url = uri
         self.db_max_retries = db_max_retries
         self.db_retry_wait = db_retry_wait
-        self.async_db_ready: asyncio.Future = asyncio.Future()
+        self.async_db_ready: asyncio.Future[bool] = asyncio.Future()
         self.async_recorder_ready = asyncio.Event()
         self._queue_watch = threading.Event()
         self.engine: Engine | None = None
-        self.run_info: Any = None
+        self.run_info: RecorderRuns | None = None
 
         self.entity_filter = entity_filter
         self.exclude_t = exclude_t
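Note on asyncio.Future[bool] above: the bare annotation asyncio.Future is implicitly Future[Any]; parameterizing it lets mypy check both set_result() arguments and the awaited value. Runnable sketch, nothing recorder-specific:

    import asyncio

    async def main() -> None:
        fut: asyncio.Future[bool] = asyncio.get_running_loop().create_future()
        fut.set_result(True)       # OK
        # fut.set_result("ready")  # mypy: has type "str", expected "bool"
        db_ready: bool = await fut
        print(db_ready)

    asyncio.run(main())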
@@ -616,12 +622,12 @@ class Recorder(threading.Thread):
 
         self.enabled = True
 
-    def set_enable(self, enable):
+    def set_enable(self, enable: bool) -> None:
         """Enable or disable recording events and states."""
         self.enabled = enable
 
     @callback
-    def async_start_executor(self):
+    def async_start_executor(self) -> None:
         """Start the executor."""
         self._db_executor = DBInterruptibleThreadPoolExecutor(
             thread_name_prefix=DB_WORKER_PREFIX,
@@ -629,13 +635,13 @@ class Recorder(threading.Thread):
             shutdown_hook=self._shutdown_pool,
         )
 
-    def _shutdown_pool(self):
+    def _shutdown_pool(self) -> None:
         """Close the dbpool connections in the current thread."""
-        if hasattr(self.engine.pool, "shutdown"):
+        if self.engine and hasattr(self.engine.pool, "shutdown"):
             self.engine.pool.shutdown()
 
     @callback
-    def async_initialize(self):
+    def async_initialize(self) -> None:
         """Initialize the recorder."""
         self._event_listener = self.hass.bus.async_listen(
             MATCH_ALL, self.event_listener, event_filter=self._async_event_filter
@@ -658,7 +664,7 @@ class Recorder(threading.Thread):
         self._db_executor = None
 
     @callback
-    def _async_check_queue(self, *_):
+    def _async_check_queue(self, *_: Any) -> None:
         """Periodic check of the queue size to ensure we do not exaust memory.
 
         The queue grows during migraton or if something really goes wrong.
@@ -704,21 +710,23 @@ class Recorder(threading.Thread):
             # Unknown what it is.
         return True
 
-    def do_adhoc_purge(self, **kwargs):
+    def do_adhoc_purge(self, **kwargs: Any) -> None:
         """Trigger an adhoc purge retaining keep_days worth of data."""
         keep_days = kwargs.get(ATTR_KEEP_DAYS, self.keep_days)
-        repack = kwargs.get(ATTR_REPACK)
-        apply_filter = kwargs.get(ATTR_APPLY_FILTER)
+        repack = cast(bool, kwargs[ATTR_REPACK])
+        apply_filter = cast(bool, kwargs[ATTR_APPLY_FILTER])
 
         purge_before = dt_util.utcnow() - timedelta(days=keep_days)
         self.queue.put(PurgeTask(purge_before, repack, apply_filter))
 
-    def do_adhoc_purge_entities(self, entity_ids, domains, entity_globs):
+    def do_adhoc_purge_entities(
+        self, entity_ids: set[str], domains: list[str], entity_globs: list[str]
+    ) -> None:
         """Trigger an adhoc purge of requested entities."""
-        entity_filter = generate_filter(domains, entity_ids, [], [], entity_globs)
+        entity_filter = generate_filter(domains, list(entity_ids), [], [], entity_globs)
         self.queue.put(PurgeEntitiesTask(entity_filter))
 
-    def do_adhoc_statistics(self, **kwargs):
+    def do_adhoc_statistics(self, **kwargs: Any) -> None:
         """Trigger an adhoc statistics run."""
         if not (start := kwargs.get("start")):
             start = statistics.get_start_time()
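Note on the kwargs change in do_adhoc_purge: kwargs.get(ATTR_REPACK) could silently produce None for a missing key, while cast(bool, kwargs[ATTR_REPACK]) encodes the assumption that the service-call schema has already validated the key as a bool. Sketch of the idea (only the attribute name is taken from the hunk; the rest is illustrative):

    from typing import Any, cast

    ATTR_REPACK = "repack"

    def do_adhoc_purge(**kwargs: Any) -> None:
        # Assumption: schema validation upstream guarantees the key exists and is a bool.
        repack = cast(bool, kwargs[ATTR_REPACK])
        print("repack:", repack)

    do_adhoc_purge(repack=True)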
@@ -812,22 +820,26 @@ class Recorder(threading.Thread):
         self.queue.put(StatisticsTask(start))
 
     @callback
-    def async_adjust_statistics(self, statistic_id, start_time, sum_adjustment):
+    def async_adjust_statistics(
+        self, statistic_id: str, start_time: datetime, sum_adjustment: float
+    ) -> None:
         """Adjust statistics."""
         self.queue.put(AdjustStatisticsTask(statistic_id, start_time, sum_adjustment))
 
     @callback
-    def async_clear_statistics(self, statistic_ids):
+    def async_clear_statistics(self, statistic_ids: list[str]) -> None:
         """Clear statistics for a list of statistic_ids."""
         self.queue.put(ClearStatisticsTask(statistic_ids))
 
     @callback
-    def async_update_statistics_metadata(self, statistic_id, unit_of_measurement):
+    def async_update_statistics_metadata(
+        self, statistic_id: str, unit_of_measurement: str | None
+    ) -> None:
         """Update statistics metadata for a statistic_id."""
         self.queue.put(UpdateStatisticsMetadataTask(statistic_id, unit_of_measurement))
 
     @callback
-    def async_external_statistics(self, metadata, stats):
+    def async_external_statistics(self, metadata: dict, stats: Iterable[dict]) -> None:
         """Schedule external statistics."""
         self.queue.put(ExternalStatisticsTask(metadata, stats))
@@ -995,7 +1007,7 @@ class Recorder(threading.Thread):
 
     def _lock_database(self, task: DatabaseLockTask) -> None:
         @callback
-        def _async_set_database_locked(task: DatabaseLockTask):
+        def _async_set_database_locked(task: DatabaseLockTask) -> None:
             task.database_locked.set()
 
         with write_lock_db_sqlite(self):
@@ -1285,8 +1297,11 @@ class Recorder(threading.Thread):
         kwargs: dict[str, Any] = {}
         self._completed_first_database_setup = False
 
-        def setup_recorder_connection(dbapi_connection, connection_record):
+        def setup_recorder_connection(
+            dbapi_connection: Any, connection_record: Any
+        ) -> None:
             """Dbapi specific connection settings."""
+            assert self.engine is not None
             setup_connection_for_dialect(
                 self,
                 self.engine.dialect.name,
@@ -1366,6 +1381,7 @@ class Recorder(threading.Thread):
         """End the recorder session."""
         if self.event_session is None:
             return
+        assert self.run_info is not None
         try:
             self.run_info.end = dt_util.utcnow()
             self.event_session.add(self.run_info)
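Note on the two asserts above: self.engine is Engine | None and self.run_info is now RecorderRuns | None, so mypy flags attribute access until None is excluded; assert narrows the type for the checker and doubles as a runtime guard. Minimal sketch:

    from __future__ import annotations

    class RecorderRuns:
        end: object | None = None

    class Recorder:
        run_info: RecorderRuns | None = None

        def _end_session(self) -> None:
            assert self.run_info is not None  # narrows RecorderRuns | None to RecorderRuns
            self.run_info.end = object()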
homeassistant/components/recorder/executor.py

@@ -10,7 +10,9 @@ import weakref
 from homeassistant.util.executor import InterruptibleThreadPoolExecutor
 
 
-def _worker_with_shutdown_hook(shutdown_hook, *args, **kwargs):
+def _worker_with_shutdown_hook(
+    shutdown_hook: Callable[[], None], *args: Any, **kwargs: Any
+) -> None:
     """Create a worker that calls a function after its finished."""
     _worker(*args, **kwargs)
     shutdown_hook()
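Note on the executor typing: Callable[[], None] gives the shutdown hook a precise no-argument, no-return signature, while *args/**kwargs stay Any because they are forwarded verbatim to the stdlib's untyped _worker. Sketch of the same wrapper shape (stand-in body):

    from collections.abc import Callable
    from typing import Any

    def worker_with_hook(hook: Callable[[], None], *args: Any, **kwargs: Any) -> None:
        print(*args, **kwargs)  # stand-in for the wrapped _worker call
        hook()

    worker_with_hook(lambda: print("connections closed"), "job", "done")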
@@ -37,7 +39,7 @@ class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
 
         # When the executor gets lost, the weakref callback will wake up
         # the worker threads.
-        def weakref_cb(_, q=self._work_queue):  # pylint: disable=invalid-name
+        def weakref_cb(_: Any, q=self._work_queue) -> None:  # type: ignore[no-untyped-def] # pylint: disable=invalid-name
            q.put(None)
 
         num_threads = len(self._threads)
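Note on the targeted ignore above: the q=self._work_queue default is deliberately left unannotated (the queue type is internal to ThreadPoolExecutor), which makes the whole def count as untyped under disallow_untyped_defs; the ignore is scoped to that single error code rather than blanket-silencing the line. Sketch:

    import queue
    from typing import Any, Callable

    def make_waker(work_queue: "queue.Queue[Any]") -> Callable[..., None]:
        # Unannotated default -> untyped def for mypy; ignore only that error code.
        def weakref_cb(_: Any, q=work_queue) -> None:  # type: ignore[no-untyped-def]
            q.put(None)
        return weakref_cb

    make_waker(queue.Queue())(None)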
homeassistant/components/recorder/migration.py

@@ -2,6 +2,7 @@
 import contextlib
 from datetime import timedelta
 import logging
+from typing import Any
 
 import sqlalchemy
 from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
@@ -43,8 +44,9 @@ def raise_if_exception_missing_str(ex, match_substrs):
     raise ex
 
 
-def get_schema_version(instance):
+def get_schema_version(instance: Any) -> int:
     """Get the schema version."""
+    assert instance.get_session is not None
     with session_scope(session=instance.get_session()) as session:
         res = (
             session.query(SchemaChanges)
@@ -62,13 +64,14 @@ def get_schema_version(instance):
     return current_version
 
 
-def schema_is_current(current_version):
+def schema_is_current(current_version: int) -> bool:
     """Check if the schema is current."""
     return current_version == SCHEMA_VERSION
 
 
-def migrate_schema(instance, current_version):
+def migrate_schema(instance: Any, current_version: int) -> None:
     """Check if the schema needs to be upgraded."""
+    assert instance.get_session is not None
     _LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
     for version in range(current_version, SCHEMA_VERSION):
         new_version = version + 1
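Note on instance: Any in migration.py: annotating instance as the concrete Recorder class would presumably create an import cycle with the package __init__ (that rationale is inferred from the module layout, not stated in the diff), so the parameter stays Any and the assert remains a runtime-only guard; mypy cannot check attribute access on Any. Sketch of the shape:

    from typing import Any

    def get_schema_version_sketch(instance: Any) -> int:
        # instance is Any, so the checker cannot verify get_session; keep a runtime guard.
        assert instance.get_session is not None
        return 0  # placeholder for the real SchemaChanges query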
homeassistant/components/recorder/purge.py

@@ -291,6 +291,7 @@ def _purge_old_recorder_runs(
 ) -> None:
     """Purge all old recorder runs."""
     # Recorder runs is small, no need to batch run it
+    assert instance.run_info is not None
     deleted_rows = (
         session.query(RecorderRuns)
         .filter(RecorderRuns.start < purge_before)
mypy.ini (46 changes)
@@ -1694,7 +1694,40 @@ no_implicit_optional = true
 warn_return_any = true
 warn_unreachable = true
 
-[mypy-homeassistant.components.recorder.models]
+[mypy-homeassistant.components.recorder]
+check_untyped_defs = true
+disallow_incomplete_defs = true
+disallow_subclassing_any = true
+disallow_untyped_calls = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = true
+no_implicit_optional = true
+warn_return_any = true
+warn_unreachable = true
+
+[mypy-homeassistant.components.recorder.const]
+check_untyped_defs = true
+disallow_incomplete_defs = true
+disallow_subclassing_any = true
+disallow_untyped_calls = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = true
+no_implicit_optional = true
+warn_return_any = true
+warn_unreachable = true
+
+[mypy-homeassistant.components.recorder.backup]
+check_untyped_defs = true
+disallow_incomplete_defs = true
+disallow_subclassing_any = true
+disallow_untyped_calls = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = true
+no_implicit_optional = true
+warn_return_any = true
+warn_unreachable = true
+
+[mypy-homeassistant.components.recorder.executor]
 check_untyped_defs = true
 disallow_incomplete_defs = true
 disallow_subclassing_any = true
@@ -1716,6 +1749,17 @@ no_implicit_optional = true
 warn_return_any = true
 warn_unreachable = true
 
+[mypy-homeassistant.components.recorder.models]
+check_untyped_defs = true
+disallow_incomplete_defs = true
+disallow_subclassing_any = true
+disallow_untyped_calls = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = true
+no_implicit_optional = true
+warn_return_any = true
+warn_unreachable = true
+
 [mypy-homeassistant.components.recorder.pool]
 check_untyped_defs = true
 disallow_incomplete_defs = true
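Note on the new mypy.ini sections: each one opts a recorder module into the same strict flag set used elsewhere in the file. For reference, a sketch of what the two most common flags reject (illustrative code, not from the repository):

    def untyped(x):  # rejected by disallow_untyped_defs
        return x

    def incomplete(x: int, y) -> int:  # rejected by disallow_incomplete_defs
        return x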