Additional strict typing for recorder (#68860)

pull/68914/head
J. Nick Koston 2022-03-30 06:20:44 -10:00 committed by GitHub
parent fa33ac73f3
commit d75f577b88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 109 additions and 39 deletions

View File

@ -172,8 +172,12 @@ homeassistant.components.pure_energie.*
homeassistant.components.rainmachine.*
homeassistant.components.rdw.*
homeassistant.components.recollect_waste.*
homeassistant.components.recorder.models
homeassistant.components.recorder
homeassistant.components.recorder.const
homeassistant.components.recorder.backup
homeassistant.components.recorder.executor
homeassistant.components.recorder.history
homeassistant.components.recorder.models
homeassistant.components.recorder.pool
homeassistant.components.recorder.purge
homeassistant.components.recorder.repack

View File

@ -11,7 +11,7 @@ import queue
import sqlite3
import threading
import time
from typing import Any, TypeVar
from typing import Any, TypeVar, cast
from lru import LRU # pylint: disable=no-name-in-module
from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
@ -214,7 +214,8 @@ MAX_DB_EXECUTOR_WORKERS = POOL_SIZE - 1
def get_instance(hass: HomeAssistant) -> Recorder:
"""Get the recorder instance."""
return hass.data[DATA_INSTANCE]
instance: Recorder = hass.data[DATA_INSTANCE]
return instance
@bind_hass
@ -225,10 +226,13 @@ def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool:
"""
if DATA_INSTANCE not in hass.data:
return False
return hass.data[DATA_INSTANCE].entity_filter(entity_id)
instance: Recorder = hass.data[DATA_INSTANCE]
return instance.entity_filter(entity_id)
def run_information(hass, point_in_time: datetime | None = None) -> RecorderRuns | None:
def run_information(
hass: HomeAssistant, point_in_time: datetime | None = None
) -> RecorderRuns | None:
"""Return information about current run.
There is also the run that covers point_in_time.
@ -241,21 +245,20 @@ def run_information(hass, point_in_time: datetime | None = None) -> RecorderRuns
def run_information_from_instance(
hass, point_in_time: datetime | None = None
hass: HomeAssistant, point_in_time: datetime | None = None
) -> RecorderRuns | None:
"""Return information about current run from the existing instance.
Does not query the database for older runs.
"""
ins = hass.data[DATA_INSTANCE]
ins = get_instance(hass)
if point_in_time is None or point_in_time > ins.recording_start:
return ins.run_info
return None
def run_information_with_session(
session, point_in_time: datetime | None = None
session: Session, point_in_time: datetime | None = None
) -> RecorderRuns | None:
"""Return information about current run from the database."""
recorder_runs = RecorderRuns
@ -266,9 +269,9 @@ def run_information_with_session(
(recorder_runs.start < point_in_time) & (recorder_runs.end > point_in_time)
)
res = query.first()
if res:
if (res := query.first()) is not None:
session.expunge(res)
return cast(RecorderRuns, res)
return res
@ -318,9 +321,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return await instance.async_db_ready
async def _process_recorder_platform(hass, domain, platform):
async def _process_recorder_platform(
hass: HomeAssistant, domain: str, platform: Any
) -> None:
"""Process a recorder platform."""
hass.data[DOMAIN][domain] = platform
platforms: dict[str, Any] = hass.data[DOMAIN]
platforms[domain] = platform
if hasattr(platform, "exclude_attributes"):
hass.data[EXCLUDE_ATTRIBUTES][domain] = platform.exclude_attributes(hass)
@ -586,11 +592,11 @@ class Recorder(threading.Thread):
self.db_url = uri
self.db_max_retries = db_max_retries
self.db_retry_wait = db_retry_wait
self.async_db_ready: asyncio.Future = asyncio.Future()
self.async_db_ready: asyncio.Future[bool] = asyncio.Future()
self.async_recorder_ready = asyncio.Event()
self._queue_watch = threading.Event()
self.engine: Engine | None = None
self.run_info: Any = None
self.run_info: RecorderRuns | None = None
self.entity_filter = entity_filter
self.exclude_t = exclude_t
@ -616,12 +622,12 @@ class Recorder(threading.Thread):
self.enabled = True
def set_enable(self, enable):
def set_enable(self, enable: bool) -> None:
"""Enable or disable recording events and states."""
self.enabled = enable
@callback
def async_start_executor(self):
def async_start_executor(self) -> None:
"""Start the executor."""
self._db_executor = DBInterruptibleThreadPoolExecutor(
thread_name_prefix=DB_WORKER_PREFIX,
@ -629,13 +635,13 @@ class Recorder(threading.Thread):
shutdown_hook=self._shutdown_pool,
)
def _shutdown_pool(self):
def _shutdown_pool(self) -> None:
"""Close the dbpool connections in the current thread."""
if hasattr(self.engine.pool, "shutdown"):
if self.engine and hasattr(self.engine.pool, "shutdown"):
self.engine.pool.shutdown()
@callback
def async_initialize(self):
def async_initialize(self) -> None:
"""Initialize the recorder."""
self._event_listener = self.hass.bus.async_listen(
MATCH_ALL, self.event_listener, event_filter=self._async_event_filter
@ -658,7 +664,7 @@ class Recorder(threading.Thread):
self._db_executor = None
@callback
def _async_check_queue(self, *_):
def _async_check_queue(self, *_: Any) -> None:
"""Periodic check of the queue size to ensure we do not exaust memory.
The queue grows during migraton or if something really goes wrong.
@ -704,21 +710,23 @@ class Recorder(threading.Thread):
# Unknown what it is.
return True
def do_adhoc_purge(self, **kwargs):
def do_adhoc_purge(self, **kwargs: Any) -> None:
"""Trigger an adhoc purge retaining keep_days worth of data."""
keep_days = kwargs.get(ATTR_KEEP_DAYS, self.keep_days)
repack = kwargs.get(ATTR_REPACK)
apply_filter = kwargs.get(ATTR_APPLY_FILTER)
repack = cast(bool, kwargs[ATTR_REPACK])
apply_filter = cast(bool, kwargs[ATTR_APPLY_FILTER])
purge_before = dt_util.utcnow() - timedelta(days=keep_days)
self.queue.put(PurgeTask(purge_before, repack, apply_filter))
def do_adhoc_purge_entities(self, entity_ids, domains, entity_globs):
def do_adhoc_purge_entities(
self, entity_ids: set[str], domains: list[str], entity_globs: list[str]
) -> None:
"""Trigger an adhoc purge of requested entities."""
entity_filter = generate_filter(domains, entity_ids, [], [], entity_globs)
entity_filter = generate_filter(domains, list(entity_ids), [], [], entity_globs)
self.queue.put(PurgeEntitiesTask(entity_filter))
def do_adhoc_statistics(self, **kwargs):
def do_adhoc_statistics(self, **kwargs: Any) -> None:
"""Trigger an adhoc statistics run."""
if not (start := kwargs.get("start")):
start = statistics.get_start_time()
@ -812,22 +820,26 @@ class Recorder(threading.Thread):
self.queue.put(StatisticsTask(start))
@callback
def async_adjust_statistics(self, statistic_id, start_time, sum_adjustment):
def async_adjust_statistics(
self, statistic_id: str, start_time: datetime, sum_adjustment: float
) -> None:
"""Adjust statistics."""
self.queue.put(AdjustStatisticsTask(statistic_id, start_time, sum_adjustment))
@callback
def async_clear_statistics(self, statistic_ids):
def async_clear_statistics(self, statistic_ids: list[str]) -> None:
"""Clear statistics for a list of statistic_ids."""
self.queue.put(ClearStatisticsTask(statistic_ids))
@callback
def async_update_statistics_metadata(self, statistic_id, unit_of_measurement):
def async_update_statistics_metadata(
self, statistic_id: str, unit_of_measurement: str | None
) -> None:
"""Update statistics metadata for a statistic_id."""
self.queue.put(UpdateStatisticsMetadataTask(statistic_id, unit_of_measurement))
@callback
def async_external_statistics(self, metadata, stats):
def async_external_statistics(self, metadata: dict, stats: Iterable[dict]) -> None:
"""Schedule external statistics."""
self.queue.put(ExternalStatisticsTask(metadata, stats))
@ -995,7 +1007,7 @@ class Recorder(threading.Thread):
def _lock_database(self, task: DatabaseLockTask) -> None:
@callback
def _async_set_database_locked(task: DatabaseLockTask):
def _async_set_database_locked(task: DatabaseLockTask) -> None:
task.database_locked.set()
with write_lock_db_sqlite(self):
@ -1285,8 +1297,11 @@ class Recorder(threading.Thread):
kwargs: dict[str, Any] = {}
self._completed_first_database_setup = False
def setup_recorder_connection(dbapi_connection, connection_record):
def setup_recorder_connection(
dbapi_connection: Any, connection_record: Any
) -> None:
"""Dbapi specific connection settings."""
assert self.engine is not None
setup_connection_for_dialect(
self,
self.engine.dialect.name,
@ -1366,6 +1381,7 @@ class Recorder(threading.Thread):
"""End the recorder session."""
if self.event_session is None:
return
assert self.run_info is not None
try:
self.run_info.end = dt_util.utcnow()
self.event_session.add(self.run_info)

View File

@ -10,7 +10,9 @@ import weakref
from homeassistant.util.executor import InterruptibleThreadPoolExecutor
def _worker_with_shutdown_hook(shutdown_hook, *args, **kwargs):
def _worker_with_shutdown_hook(
shutdown_hook: Callable[[], None], *args: Any, **kwargs: Any
) -> None:
"""Create a worker that calls a function after its finished."""
_worker(*args, **kwargs)
shutdown_hook()
@ -37,7 +39,7 @@ class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
# When the executor gets lost, the weakref callback will wake up
# the worker threads.
def weakref_cb(_, q=self._work_queue): # pylint: disable=invalid-name
def weakref_cb(_: Any, q=self._work_queue) -> None: # type: ignore[no-untyped-def] # pylint: disable=invalid-name
q.put(None)
num_threads = len(self._threads)

View File

@ -2,6 +2,7 @@
import contextlib
from datetime import timedelta
import logging
from typing import Any
import sqlalchemy
from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
@ -43,8 +44,9 @@ def raise_if_exception_missing_str(ex, match_substrs):
raise ex
def get_schema_version(instance):
def get_schema_version(instance: Any) -> int:
"""Get the schema version."""
assert instance.get_session is not None
with session_scope(session=instance.get_session()) as session:
res = (
session.query(SchemaChanges)
@ -62,13 +64,14 @@ def get_schema_version(instance):
return current_version
def schema_is_current(current_version):
def schema_is_current(current_version: int) -> bool:
"""Check if the schema is current."""
return current_version == SCHEMA_VERSION
def migrate_schema(instance, current_version):
def migrate_schema(instance: Any, current_version: int) -> None:
"""Check if the schema needs to be upgraded."""
assert instance.get_session is not None
_LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
for version in range(current_version, SCHEMA_VERSION):
new_version = version + 1

View File

@ -291,6 +291,7 @@ def _purge_old_recorder_runs(
) -> None:
"""Purge all old recorder runs."""
# Recorder runs is small, no need to batch run it
assert instance.run_info is not None
deleted_rows = (
session.query(RecorderRuns)
.filter(RecorderRuns.start < purge_before)

View File

@ -1694,7 +1694,40 @@ no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.models]
[mypy-homeassistant.components.recorder]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.const]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.backup]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.executor]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
@ -1716,6 +1749,17 @@ no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.models]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.pool]
check_untyped_defs = true
disallow_incomplete_defs = true