From d75f577b88a18e5e64b592b95d7c8930da02c4ef Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 30 Mar 2022 06:20:44 -1000 Subject: [PATCH] Additional strict typing for recorder (#68860) --- .strict-typing | 6 +- homeassistant/components/recorder/__init__.py | 80 +++++++++++-------- homeassistant/components/recorder/executor.py | 6 +- .../components/recorder/migration.py | 9 ++- homeassistant/components/recorder/purge.py | 1 + mypy.ini | 46 ++++++++++- 6 files changed, 109 insertions(+), 39 deletions(-) diff --git a/.strict-typing b/.strict-typing index 6d428fdcb2f..e808b54c85e 100644 --- a/.strict-typing +++ b/.strict-typing @@ -172,8 +172,12 @@ homeassistant.components.pure_energie.* homeassistant.components.rainmachine.* homeassistant.components.rdw.* homeassistant.components.recollect_waste.* -homeassistant.components.recorder.models +homeassistant.components.recorder +homeassistant.components.recorder.const +homeassistant.components.recorder.backup +homeassistant.components.recorder.executor homeassistant.components.recorder.history +homeassistant.components.recorder.models homeassistant.components.recorder.pool homeassistant.components.recorder.purge homeassistant.components.recorder.repack diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index b4ae03a6ef9..0381e5a4671 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -11,7 +11,7 @@ import queue import sqlite3 import threading import time -from typing import Any, TypeVar +from typing import Any, TypeVar, cast from lru import LRU # pylint: disable=no-name-in-module from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select @@ -214,7 +214,8 @@ MAX_DB_EXECUTOR_WORKERS = POOL_SIZE - 1 def get_instance(hass: HomeAssistant) -> Recorder: """Get the recorder instance.""" - return hass.data[DATA_INSTANCE] + instance: Recorder = hass.data[DATA_INSTANCE] + return instance @bind_hass @@ -225,10 +226,13 @@ def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool: """ if DATA_INSTANCE not in hass.data: return False - return hass.data[DATA_INSTANCE].entity_filter(entity_id) + instance: Recorder = hass.data[DATA_INSTANCE] + return instance.entity_filter(entity_id) -def run_information(hass, point_in_time: datetime | None = None) -> RecorderRuns | None: +def run_information( + hass: HomeAssistant, point_in_time: datetime | None = None +) -> RecorderRuns | None: """Return information about current run. There is also the run that covers point_in_time. @@ -241,21 +245,20 @@ def run_information(hass, point_in_time: datetime | None = None) -> RecorderRuns def run_information_from_instance( - hass, point_in_time: datetime | None = None + hass: HomeAssistant, point_in_time: datetime | None = None ) -> RecorderRuns | None: """Return information about current run from the existing instance. Does not query the database for older runs. """ - ins = hass.data[DATA_INSTANCE] - + ins = get_instance(hass) if point_in_time is None or point_in_time > ins.recording_start: return ins.run_info return None def run_information_with_session( - session, point_in_time: datetime | None = None + session: Session, point_in_time: datetime | None = None ) -> RecorderRuns | None: """Return information about current run from the database.""" recorder_runs = RecorderRuns @@ -266,9 +269,9 @@ def run_information_with_session( (recorder_runs.start < point_in_time) & (recorder_runs.end > point_in_time) ) - res = query.first() - if res: + if (res := query.first()) is not None: session.expunge(res) + return cast(RecorderRuns, res) return res @@ -318,9 +321,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return await instance.async_db_ready -async def _process_recorder_platform(hass, domain, platform): +async def _process_recorder_platform( + hass: HomeAssistant, domain: str, platform: Any +) -> None: """Process a recorder platform.""" - hass.data[DOMAIN][domain] = platform + platforms: dict[str, Any] = hass.data[DOMAIN] + platforms[domain] = platform if hasattr(platform, "exclude_attributes"): hass.data[EXCLUDE_ATTRIBUTES][domain] = platform.exclude_attributes(hass) @@ -586,11 +592,11 @@ class Recorder(threading.Thread): self.db_url = uri self.db_max_retries = db_max_retries self.db_retry_wait = db_retry_wait - self.async_db_ready: asyncio.Future = asyncio.Future() + self.async_db_ready: asyncio.Future[bool] = asyncio.Future() self.async_recorder_ready = asyncio.Event() self._queue_watch = threading.Event() self.engine: Engine | None = None - self.run_info: Any = None + self.run_info: RecorderRuns | None = None self.entity_filter = entity_filter self.exclude_t = exclude_t @@ -616,12 +622,12 @@ class Recorder(threading.Thread): self.enabled = True - def set_enable(self, enable): + def set_enable(self, enable: bool) -> None: """Enable or disable recording events and states.""" self.enabled = enable @callback - def async_start_executor(self): + def async_start_executor(self) -> None: """Start the executor.""" self._db_executor = DBInterruptibleThreadPoolExecutor( thread_name_prefix=DB_WORKER_PREFIX, @@ -629,13 +635,13 @@ class Recorder(threading.Thread): shutdown_hook=self._shutdown_pool, ) - def _shutdown_pool(self): + def _shutdown_pool(self) -> None: """Close the dbpool connections in the current thread.""" - if hasattr(self.engine.pool, "shutdown"): + if self.engine and hasattr(self.engine.pool, "shutdown"): self.engine.pool.shutdown() @callback - def async_initialize(self): + def async_initialize(self) -> None: """Initialize the recorder.""" self._event_listener = self.hass.bus.async_listen( MATCH_ALL, self.event_listener, event_filter=self._async_event_filter @@ -658,7 +664,7 @@ class Recorder(threading.Thread): self._db_executor = None @callback - def _async_check_queue(self, *_): + def _async_check_queue(self, *_: Any) -> None: """Periodic check of the queue size to ensure we do not exaust memory. The queue grows during migraton or if something really goes wrong. @@ -704,21 +710,23 @@ class Recorder(threading.Thread): # Unknown what it is. return True - def do_adhoc_purge(self, **kwargs): + def do_adhoc_purge(self, **kwargs: Any) -> None: """Trigger an adhoc purge retaining keep_days worth of data.""" keep_days = kwargs.get(ATTR_KEEP_DAYS, self.keep_days) - repack = kwargs.get(ATTR_REPACK) - apply_filter = kwargs.get(ATTR_APPLY_FILTER) + repack = cast(bool, kwargs[ATTR_REPACK]) + apply_filter = cast(bool, kwargs[ATTR_APPLY_FILTER]) purge_before = dt_util.utcnow() - timedelta(days=keep_days) self.queue.put(PurgeTask(purge_before, repack, apply_filter)) - def do_adhoc_purge_entities(self, entity_ids, domains, entity_globs): + def do_adhoc_purge_entities( + self, entity_ids: set[str], domains: list[str], entity_globs: list[str] + ) -> None: """Trigger an adhoc purge of requested entities.""" - entity_filter = generate_filter(domains, entity_ids, [], [], entity_globs) + entity_filter = generate_filter(domains, list(entity_ids), [], [], entity_globs) self.queue.put(PurgeEntitiesTask(entity_filter)) - def do_adhoc_statistics(self, **kwargs): + def do_adhoc_statistics(self, **kwargs: Any) -> None: """Trigger an adhoc statistics run.""" if not (start := kwargs.get("start")): start = statistics.get_start_time() @@ -812,22 +820,26 @@ class Recorder(threading.Thread): self.queue.put(StatisticsTask(start)) @callback - def async_adjust_statistics(self, statistic_id, start_time, sum_adjustment): + def async_adjust_statistics( + self, statistic_id: str, start_time: datetime, sum_adjustment: float + ) -> None: """Adjust statistics.""" self.queue.put(AdjustStatisticsTask(statistic_id, start_time, sum_adjustment)) @callback - def async_clear_statistics(self, statistic_ids): + def async_clear_statistics(self, statistic_ids: list[str]) -> None: """Clear statistics for a list of statistic_ids.""" self.queue.put(ClearStatisticsTask(statistic_ids)) @callback - def async_update_statistics_metadata(self, statistic_id, unit_of_measurement): + def async_update_statistics_metadata( + self, statistic_id: str, unit_of_measurement: str | None + ) -> None: """Update statistics metadata for a statistic_id.""" self.queue.put(UpdateStatisticsMetadataTask(statistic_id, unit_of_measurement)) @callback - def async_external_statistics(self, metadata, stats): + def async_external_statistics(self, metadata: dict, stats: Iterable[dict]) -> None: """Schedule external statistics.""" self.queue.put(ExternalStatisticsTask(metadata, stats)) @@ -995,7 +1007,7 @@ class Recorder(threading.Thread): def _lock_database(self, task: DatabaseLockTask) -> None: @callback - def _async_set_database_locked(task: DatabaseLockTask): + def _async_set_database_locked(task: DatabaseLockTask) -> None: task.database_locked.set() with write_lock_db_sqlite(self): @@ -1285,8 +1297,11 @@ class Recorder(threading.Thread): kwargs: dict[str, Any] = {} self._completed_first_database_setup = False - def setup_recorder_connection(dbapi_connection, connection_record): + def setup_recorder_connection( + dbapi_connection: Any, connection_record: Any + ) -> None: """Dbapi specific connection settings.""" + assert self.engine is not None setup_connection_for_dialect( self, self.engine.dialect.name, @@ -1366,6 +1381,7 @@ class Recorder(threading.Thread): """End the recorder session.""" if self.event_session is None: return + assert self.run_info is not None try: self.run_info.end = dt_util.utcnow() self.event_session.add(self.run_info) diff --git a/homeassistant/components/recorder/executor.py b/homeassistant/components/recorder/executor.py index 782c3422e19..0d913310e74 100644 --- a/homeassistant/components/recorder/executor.py +++ b/homeassistant/components/recorder/executor.py @@ -10,7 +10,9 @@ import weakref from homeassistant.util.executor import InterruptibleThreadPoolExecutor -def _worker_with_shutdown_hook(shutdown_hook, *args, **kwargs): +def _worker_with_shutdown_hook( + shutdown_hook: Callable[[], None], *args: Any, **kwargs: Any +) -> None: """Create a worker that calls a function after its finished.""" _worker(*args, **kwargs) shutdown_hook() @@ -37,7 +39,7 @@ class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor): # When the executor gets lost, the weakref callback will wake up # the worker threads. - def weakref_cb(_, q=self._work_queue): # pylint: disable=invalid-name + def weakref_cb(_: Any, q=self._work_queue) -> None: # type: ignore[no-untyped-def] # pylint: disable=invalid-name q.put(None) num_threads = len(self._threads) diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 5db43aa760f..26234be0502 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -2,6 +2,7 @@ import contextlib from datetime import timedelta import logging +from typing import Any import sqlalchemy from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text @@ -43,8 +44,9 @@ def raise_if_exception_missing_str(ex, match_substrs): raise ex -def get_schema_version(instance): +def get_schema_version(instance: Any) -> int: """Get the schema version.""" + assert instance.get_session is not None with session_scope(session=instance.get_session()) as session: res = ( session.query(SchemaChanges) @@ -62,13 +64,14 @@ def get_schema_version(instance): return current_version -def schema_is_current(current_version): +def schema_is_current(current_version: int) -> bool: """Check if the schema is current.""" return current_version == SCHEMA_VERSION -def migrate_schema(instance, current_version): +def migrate_schema(instance: Any, current_version: int) -> None: """Check if the schema needs to be upgraded.""" + assert instance.get_session is not None _LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version) for version in range(current_version, SCHEMA_VERSION): new_version = version + 1 diff --git a/homeassistant/components/recorder/purge.py b/homeassistant/components/recorder/purge.py index a15d22810f4..9bbe13ca5a7 100644 --- a/homeassistant/components/recorder/purge.py +++ b/homeassistant/components/recorder/purge.py @@ -291,6 +291,7 @@ def _purge_old_recorder_runs( ) -> None: """Purge all old recorder runs.""" # Recorder runs is small, no need to batch run it + assert instance.run_info is not None deleted_rows = ( session.query(RecorderRuns) .filter(RecorderRuns.start < purge_before) diff --git a/mypy.ini b/mypy.ini index cc114633787..9c98db3ca10 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1694,7 +1694,40 @@ no_implicit_optional = true warn_return_any = true warn_unreachable = true -[mypy-homeassistant.components.recorder.models] +[mypy-homeassistant.components.recorder] +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +no_implicit_optional = true +warn_return_any = true +warn_unreachable = true + +[mypy-homeassistant.components.recorder.const] +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +no_implicit_optional = true +warn_return_any = true +warn_unreachable = true + +[mypy-homeassistant.components.recorder.backup] +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +no_implicit_optional = true +warn_return_any = true +warn_unreachable = true + +[mypy-homeassistant.components.recorder.executor] check_untyped_defs = true disallow_incomplete_defs = true disallow_subclassing_any = true @@ -1716,6 +1749,17 @@ no_implicit_optional = true warn_return_any = true warn_unreachable = true +[mypy-homeassistant.components.recorder.models] +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +no_implicit_optional = true +warn_return_any = true +warn_unreachable = true + [mypy-homeassistant.components.recorder.pool] check_untyped_defs = true disallow_incomplete_defs = true