diff --git a/homeassistant/components/history/__init__.py b/homeassistant/components/history/__init__.py
index 1cbd18f44a7..5d870ffa8ee 100644
--- a/homeassistant/components/history/__init__.py
+++ b/homeassistant/components/history/__init__.py
@@ -14,7 +14,11 @@ import voluptuous as vol
 
 from homeassistant.components import frontend, websocket_api
 from homeassistant.components.http import HomeAssistantView
-from homeassistant.components.recorder import history, models as history_models
+from homeassistant.components.recorder import (
+    get_instance,
+    history,
+    models as history_models,
+)
 from homeassistant.components.recorder.statistics import (
     list_statistic_ids,
     statistics_during_period,
@@ -142,7 +146,7 @@ async def ws_get_statistics_during_period(
     else:
         end_time = None
 
-    statistics = await hass.async_add_executor_job(
+    statistics = await get_instance(hass).async_add_executor_job(
         statistics_during_period,
         hass,
         start_time,
@@ -164,7 +168,7 @@ async def ws_get_list_statistic_ids(
     hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
 ) -> None:
     """Fetch a list of available statistic_id."""
-    statistic_ids = await hass.async_add_executor_job(
+    statistic_ids = await get_instance(hass).async_add_executor_job(
         list_statistic_ids,
         hass,
         msg.get("statistic_type"),
@@ -232,7 +236,7 @@ class HistoryPeriodView(HomeAssistantView):
 
         return cast(
             web.Response,
-            await hass.async_add_executor_job(
+            await get_instance(hass).async_add_executor_job(
                 self._sorted_significant_states_json,
                 hass,
                 start_time,
diff --git a/homeassistant/components/history_stats/sensor.py b/homeassistant/components/history_stats/sensor.py
index 2af3706e4e8..b33b7ca4db9 100644
--- a/homeassistant/components/history_stats/sensor.py
+++ b/homeassistant/components/history_stats/sensor.py
@@ -7,7 +7,7 @@ import math
 
 import voluptuous as vol
 
-from homeassistant.components.recorder import history
+from homeassistant.components.recorder import get_instance, history
 from homeassistant.components.sensor import PLATFORM_SCHEMA, SensorEntity
 from homeassistant.const import (
     CONF_ENTITY_ID,
@@ -225,7 +225,7 @@ class HistoryStatsSensor(SensorEntity):
             # Don't compute anything as the value cannot have changed
             return
 
-        await self.hass.async_add_executor_job(
+        await get_instance(self.hass).async_add_executor_job(
             self._update, start, end, now_timestamp, start_timestamp, end_timestamp
         )
 
diff --git a/homeassistant/components/logbook/__init__.py b/homeassistant/components/logbook/__init__.py
index 28b0460ac7a..66c78e30eab 100644
--- a/homeassistant/components/logbook/__init__.py
+++ b/homeassistant/components/logbook/__init__.py
@@ -15,6 +15,7 @@ from homeassistant.components import frontend
 from homeassistant.components.automation import EVENT_AUTOMATION_TRIGGERED
 from homeassistant.components.history import sqlalchemy_filter_from_include_exclude_conf
 from homeassistant.components.http import HomeAssistantView
+from homeassistant.components.recorder import get_instance
 from homeassistant.components.recorder.models import (
     Events,
     States,
@@ -254,7 +255,7 @@ class LogbookView(HomeAssistantView):
                 )
             )
 
-        return await hass.async_add_executor_job(json_events)
+        return await get_instance(hass).async_add_executor_job(json_events)
 
 
 def humanify(hass, events, entity_attr_cache, context_lookup):
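All three components above make the same mechanical change: blocking database work moves off hass.async_add_executor_job (the shared Home Assistant executor) and onto the recorder's dedicated database executor via get_instance(hass).async_add_executor_job. A minimal sketch of the resulting call-site pattern, using a hypothetical coroutine (async_fetch_today is illustrative, not part of this diff):

from homeassistant.components.recorder import get_instance, history
from homeassistant.core import HomeAssistant
import homeassistant.util.dt as dt_util


async def async_fetch_today(hass: HomeAssistant, entity_id: str):
    """Run a blocking history query on the recorder's database executor."""
    start = dt_util.start_of_local_day()
    # get_instance(hass) returns the Recorder instance; its
    # async_add_executor_job schedules the callable on the DbWorker
    # pool instead of hass's shared executor, so a slow query cannot
    # starve unrelated executor jobs.
    return await get_instance(hass).async_add_executor_job(
        history.state_changes_during_period, hass, start, None, entity_id
    )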
diff --git a/homeassistant/components/plant/__init__.py b/homeassistant/components/plant/__init__.py
index 37ae8422af0..8b46ad7801e 100644
--- a/homeassistant/components/plant/__init__.py
+++ b/homeassistant/components/plant/__init__.py
@@ -6,6 +6,7 @@ import logging
 
 import voluptuous as vol
 
+from homeassistant.components.recorder import get_instance
 from homeassistant.components.recorder.models import States
 from homeassistant.components.recorder.util import execute, session_scope
 from homeassistant.const import (
@@ -283,7 +284,9 @@ class Plant(Entity):
         """After being added to hass, load from history."""
         if ENABLE_LOAD_HISTORY and "recorder" in self.hass.config.components:
             # only use the database if it's configured
-            await self.hass.async_add_executor_job(self._load_history_from_db)
+            await get_instance(self.hass).async_add_executor_job(
+                self._load_history_from_db
+            )
             self.async_write_ha_state()
 
         async_track_state_change_event(
diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py
index 18acfabffaa..6894d1367e6 100644
--- a/homeassistant/components/recorder/__init__.py
+++ b/homeassistant/components/recorder/__init__.py
@@ -12,7 +12,7 @@ import queue
 import sqlite3
 import threading
 import time
-from typing import Any
+from typing import Any, TypeVar
 
 from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
 from sqlalchemy.exc import SQLAlchemyError
@@ -57,10 +57,12 @@ from . import history, migration, purge, statistics, websocket_api
 from .const import (
     CONF_DB_INTEGRITY_CHECK,
     DATA_INSTANCE,
+    DB_WORKER_PREFIX,
     DOMAIN,
     MAX_QUEUE_BACKLOG,
     SQLITE_URL_PREFIX,
 )
+from .executor import DBInterruptibleThreadPoolExecutor
 from .models import (
     Base,
     Events,
@@ -69,7 +71,7 @@ from .models import (
     StatisticsRuns,
     process_timestamp,
 )
-from .pool import RecorderPool
+from .pool import POOL_SIZE, RecorderPool
 from .util import (
     dburl_to_path,
     end_incomplete_runs,
@@ -83,6 +85,9 @@ from .util import (
 
 _LOGGER = logging.getLogger(__name__)
 
+T = TypeVar("T")
+
+
 SERVICE_PURGE = "purge"
 SERVICE_PURGE_ENTITIES = "purge_entities"
 SERVICE_ENABLE = "enable"
@@ -182,6 +187,15 @@ CONFIG_SCHEMA = vol.Schema(
 )
 
 
+# Pool size must accommodate Recorder thread + All db executors
+MAX_DB_EXECUTOR_WORKERS = POOL_SIZE - 1
+
+
+def get_instance(hass: HomeAssistant) -> Recorder:
+    """Get the recorder instance."""
+    return hass.data[DATA_INSTANCE]
+
+
 @bind_hass
 def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool:
     """Check if an entity is being recorded.
@@ -537,6 +551,7 @@ class Recorder(threading.Thread):
         self._queue_watcher = None
         self._db_supports_row_number = True
         self._database_lock_task: DatabaseLockTask | None = None
+        self._db_executor: DBInterruptibleThreadPoolExecutor | None = None
 
         self.enabled = True
 
@@ -544,6 +559,20 @@ class Recorder(threading.Thread):
         """Enable or disable recording events and states."""
         self.enabled = enable
 
+    @callback
+    def async_start_executor(self):
+        """Start the executor."""
+        self._db_executor = DBInterruptibleThreadPoolExecutor(
+            thread_name_prefix=DB_WORKER_PREFIX,
+            max_workers=MAX_DB_EXECUTOR_WORKERS,
+            shutdown_hook=self._shutdown_pool,
+        )
+
+    def _shutdown_pool(self):
+        """Close the dbpool connections in the current thread."""
+        if hasattr(self.engine.pool, "shutdown"):
+            self.engine.pool.shutdown()
+
     @callback
     def async_initialize(self):
         """Initialize the recorder."""
@@ -554,6 +583,19 @@ class Recorder(threading.Thread):
             self.hass, self._async_check_queue, timedelta(minutes=10)
         )
 
+    @callback
+    def async_add_executor_job(
+        self, target: Callable[..., T], *args: Any
+    ) -> asyncio.Future[T]:
+        """Add an executor job from within the event loop."""
+        return self.hass.loop.run_in_executor(self._db_executor, target, *args)
+
+    def _stop_executor(self) -> None:
+        """Stop the executor."""
+        assert self._db_executor is not None
+        self._db_executor.shutdown()
+        self._db_executor = None
+
     @callback
     def _async_check_queue(self, *_):
         """Periodic check of the queue size to ensure we do not exaust memory.
@@ -680,6 +722,7 @@ class Recorder(threading.Thread):
     def async_connection_success(self):
         """Connect success tasks."""
        self.async_db_ready.set_result(True)
+        self.async_start_executor()
 
     @callback
     def _async_recorder_ready(self):
@@ -1212,6 +1255,7 @@ class Recorder(threading.Thread):
     def _shutdown(self):
         """Save end time for current run."""
         self.hass.add_job(self._async_stop_queue_watcher_and_event_listener)
+        self._stop_executor()
         self._end_session()
         self._close_connection()
 
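The new Recorder.async_add_executor_job above is a thin wrapper around loop.run_in_executor with a private pool, so callers still get an ordinary awaitable on the event loop. The same mechanics in a self-contained, standard-library-only sketch (all names here are illustrative):

import asyncio
from concurrent.futures import ThreadPoolExecutor

# Stand-in for the recorder's private pool; the real one is a
# DBInterruptibleThreadPoolExecutor with max_workers=POOL_SIZE - 1.
db_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="DbWorker")


def blocking_query() -> str:
    """Pretend to run a slow database query."""
    return "rows"


async def main() -> None:
    loop = asyncio.get_running_loop()
    # run_in_executor returns an asyncio.Future bound to this loop,
    # which is what Recorder.async_add_executor_job hands back.
    print(await loop.run_in_executor(db_executor, blocking_query))
    db_executor.shutdown()


asyncio.run(main())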
diff --git a/homeassistant/components/recorder/const.py b/homeassistant/components/recorder/const.py
index a04218264ee..ae0b37e211a 100644
--- a/homeassistant/components/recorder/const.py
+++ b/homeassistant/components/recorder/const.py
@@ -15,3 +15,5 @@ MAX_QUEUE_BACKLOG = 30000
 # We can increase this back to 1000 once most
 # have upgraded their sqlite version
 MAX_ROWS_TO_PURGE = 998
+
+DB_WORKER_PREFIX = "DbWorker"
diff --git a/homeassistant/components/recorder/executor.py b/homeassistant/components/recorder/executor.py
new file mode 100644
index 00000000000..782c3422e19
--- /dev/null
+++ b/homeassistant/components/recorder/executor.py
@@ -0,0 +1,59 @@
+"""Database executor helpers."""
+from __future__ import annotations
+
+from collections.abc import Callable
+from concurrent.futures.thread import _threads_queues, _worker
+import threading
+from typing import Any
+import weakref
+
+from homeassistant.util.executor import InterruptibleThreadPoolExecutor
+
+
+def _worker_with_shutdown_hook(shutdown_hook, *args, **kwargs):
+    """Create a worker that calls a function after it's finished."""
+    _worker(*args, **kwargs)
+    shutdown_hook()
+
+
+class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
+    """A database executor that will not deadlock on shutdown."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        """Init the executor with shutdown hook support."""
+        self._shutdown_hook: Callable[[], None] = kwargs.pop("shutdown_hook")
+        super().__init__(*args, **kwargs)
+
+    def _adjust_thread_count(self) -> None:
+        """Overridden to add support for shutdown hook.
+
+        Based on the CPython 3.10 implementation.
+        """
+        # if idle threads are available, don't spin new threads
+        if self._idle_semaphore.acquire(  # pylint: disable=consider-using-with
+            timeout=0
+        ):
+            return
+
+        # When the executor gets lost, the weakref callback will wake up
+        # the worker threads.
+        def weakref_cb(_, q=self._work_queue):  # pylint: disable=invalid-name
+            q.put(None)
+
+        num_threads = len(self._threads)
+        if num_threads < self._max_workers:
+            thread_name = "%s_%d" % (self._thread_name_prefix or self, num_threads)
+            executor_thread = threading.Thread(
+                name=thread_name,
+                target=_worker_with_shutdown_hook,
+                args=(
+                    self._shutdown_hook,
+                    weakref.ref(self, weakref_cb),
+                    self._work_queue,
+                    self._initializer,
+                    self._initargs,
+                ),
+            )
+            executor_thread.start()
+            self._threads.add(executor_thread)  # type: ignore[attr-defined]
+            _threads_queues[executor_thread] = self._work_queue  # type: ignore[index]
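The wrapped worker exists because the pool below keeps one sqlite connection per DbWorker thread, and those connections should be closed by the thread that owns them; the shutdown_hook runs as each worker unwinds. A standard-library-only illustration of the problem being solved (not Home Assistant code):

import sqlite3
import threading
from concurrent.futures import ThreadPoolExecutor

local = threading.local()


def query(_):
    """Use one connection per thread, as SingletonThreadPool does."""
    if not hasattr(local, "conn"):
        local.conn = sqlite3.connect(":memory:")
    return local.conn.execute("SELECT 1").fetchone()


with ThreadPoolExecutor(max_workers=4, thread_name_prefix="DbWorker") as pool:
    print(list(pool.map(query, range(8))))
# The pool has joined its threads here, but no worker ever closed its
# connection; the shutdown_hook gives each worker the chance to do that
# cleanup on its own thread before exiting.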
diff --git a/homeassistant/components/recorder/pool.py b/homeassistant/components/recorder/pool.py
index b30237f98da..76b8aceb30f 100644
--- a/homeassistant/components/recorder/pool.py
+++ b/homeassistant/components/recorder/pool.py
@@ -1,34 +1,60 @@
 """A pool for sqlite connections."""
 import threading
 
-from sqlalchemy.pool import NullPool, StaticPool
+from sqlalchemy.pool import NullPool, SingletonThreadPool
+
+from homeassistant.helpers.frame import report
+
+from .const import DB_WORKER_PREFIX
+
+POOL_SIZE = 5
 
 
-class RecorderPool(StaticPool, NullPool):
-    """A hybrid of NullPool and StaticPool.
+class RecorderPool(SingletonThreadPool, NullPool):
+    """A hybrid of NullPool and SingletonThreadPool.
 
-    When called from the creating thread acts like StaticPool
+    When called from the creating thread or db executor acts like SingletonThreadPool
     When called from any other thread, acts like NullPool
     """
 
     def __init__(self, *args, **kw):  # pylint: disable=super-init-not-called
         """Create the pool."""
-        self._tid = threading.current_thread().ident
-        StaticPool.__init__(self, *args, **kw)
+        kw["pool_size"] = POOL_SIZE
+        SingletonThreadPool.__init__(self, *args, **kw)
+
+    @property
+    def recorder_or_dbworker(self) -> bool:
+        """Check if the thread is a recorder or dbworker thread."""
+        thread_name = threading.current_thread().name
+        return bool(
+            thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
+        )
 
     def _do_return_conn(self, conn):
-        if threading.current_thread().ident == self._tid:
+        if self.recorder_or_dbworker:
             return super()._do_return_conn(conn)
         conn.close()
 
+    def shutdown(self):
+        """Close the connection owned by the current thread."""
+        if self.recorder_or_dbworker and (conn := self._conn.current()):
+            conn.close()
+
     def dispose(self):
         """Dispose of the connection."""
-        if threading.current_thread().ident == self._tid:
+        if self.recorder_or_dbworker:
             return super().dispose()
 
     def _do_get(self):
-        if threading.current_thread().ident == self._tid:
+        if self.recorder_or_dbworker:
             return super()._do_get()
+        report(
+            "accesses the database without the database executor; "
+            "Use homeassistant.components.recorder.get_instance(hass).async_add_executor_job() "
+            "for faster database operations",
+            exclude_integrations={"recorder"},
+            error_if_core=False,
+        )
         return super(  # pylint: disable=bad-super-call
             NullPool, self
         )._create_connection()
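RecorderPool now dispatches purely on thread name: the Recorder thread and DbWorker threads get pooled per-thread connections, while any other thread gets a throwaway NullPool connection plus a report() nudge toward the executor. A quick sketch of that rule, with illustrative driver code that mirrors the updated test below:

import threading

from sqlalchemy import create_engine

from homeassistant.components.recorder.pool import RecorderPool

engine = create_engine("sqlite://", poolclass=RecorderPool)


def connect_twice():
    """Check out two connections and compare the raw DBAPI objects."""
    conns = []
    for _ in range(2):
        with engine.connect() as conn:
            conns.append(conn.connection.connection)
    print(threading.current_thread().name, conns[0] is conns[1])


threading.Thread(target=connect_twice, name="DbWorker_0").start()  # True: pooled
threading.Thread(target=connect_twice, name="worker").start()  # False, and report()s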
diff --git a/homeassistant/components/recorder/websocket_api.py b/homeassistant/components/recorder/websocket_api.py
index aec7905615f..ee5081b3c1c 100644
--- a/homeassistant/components/recorder/websocket_api.py
+++ b/homeassistant/components/recorder/websocket_api.py
@@ -40,7 +40,8 @@ async def ws_validate_statistics(
     hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
 ) -> None:
     """Fetch a list of available statistic_id."""
-    statistic_ids = await hass.async_add_executor_job(
+    instance: Recorder = hass.data[DATA_INSTANCE]
+    statistic_ids = await instance.async_add_executor_job(
         validate_statistics,
         hass,
     )
diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py
index 6bdc8250afc..ab05c6ec05b 100644
--- a/tests/components/recorder/test_init.py
+++ b/tests/components/recorder/test_init.py
@@ -4,7 +4,7 @@ import asyncio
 from datetime import datetime, timedelta
 import sqlite3
 import threading
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
@@ -22,6 +22,7 @@ from homeassistant.components.recorder import (
     SERVICE_PURGE_ENTITIES,
     SQLITE_URL_PREFIX,
     Recorder,
+    get_instance,
     run_information,
     run_information_from_instance,
     run_information_with_session,
@@ -101,6 +102,30 @@ async def test_shutdown_before_startup_finishes(hass):
     assert run_info.end is not None
 
 
+async def test_shutdown_closes_connections(hass):
+    """Test shutdown closes connections."""
+
+    hass.state = CoreState.not_running
+
+    await async_init_recorder_component(hass)
+    instance = get_instance(hass)
+    await instance.async_db_ready
+    await hass.async_block_till_done()
+    pool = instance.engine.pool
+    pool.shutdown = Mock()
+
+    def _ensure_connected():
+        with session_scope(hass=hass) as session:
+            list(session.query(States))
+
+    await instance.async_add_executor_job(_ensure_connected)
+
+    hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
+    await hass.async_block_till_done()
+
+    assert len(pool.shutdown.mock_calls) == 1
+
+
 async def test_state_gets_saved_when_set_before_start_event(
     hass: HomeAssistant, async_setup_recorder_instance: SetupRecorderInstanceT
 ):
diff --git a/tests/components/recorder/test_pool.py b/tests/components/recorder/test_pool.py
index e59dc18fc8b..ca6a88d84a7 100644
--- a/tests/components/recorder/test_pool.py
+++ b/tests/components/recorder/test_pool.py
@@ -4,15 +4,16 @@ import threading
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
 
+from homeassistant.components.recorder.const import DB_WORKER_PREFIX
 from homeassistant.components.recorder.pool import RecorderPool
 
 
-def test_recorder_pool():
+def test_recorder_pool(caplog):
     """Test RecorderPool gives the same connection in the creating thread."""
 
     engine = create_engine("sqlite://", poolclass=RecorderPool)
     get_session = sessionmaker(bind=engine)
-
+    shutdown = False
     connections = []
 
     def _get_connection_twice():
@@ -20,15 +21,42 @@ def test_recorder_pool():
         session = get_session()
         connections.append(session.connection().connection.connection)
         session.close()
+        if shutdown:
+            engine.pool.shutdown()
+
         session = get_session()
         connections.append(session.connection().connection.connection)
         session.close()
 
     _get_connection_twice()
-    assert connections[0] == connections[1]
+    assert "accesses the database without the database executor" in caplog.text
+    assert connections[0] != connections[1]
 
+    caplog.clear()
     new_thread = threading.Thread(target=_get_connection_twice)
     new_thread.start()
     new_thread.join()
-
+    assert "accesses the database without the database executor" in caplog.text
     assert connections[2] != connections[3]
+
+    caplog.clear()
+    new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
+    new_thread.start()
+    new_thread.join()
+    assert "accesses the database without the database executor" not in caplog.text
+    assert connections[4] == connections[5]
+
+    caplog.clear()
+    new_thread = threading.Thread(target=_get_connection_twice, name="Recorder")
+    new_thread.start()
+    new_thread.join()
+    assert "accesses the database without the database executor" not in caplog.text
+    assert connections[6] == connections[7]
+
+    shutdown = True
+    caplog.clear()
+    new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
+    new_thread.start()
+    new_thread.join()
+    assert "accesses the database without the database executor" not in caplog.text
+    assert connections[8] != connections[9]
diff --git a/tests/components/recorder/test_util.py b/tests/components/recorder/test_util.py
index ec74ea73975..fe38aa2ab4f 100644
--- a/tests/components/recorder/test_util.py
+++ b/tests/components/recorder/test_util.py
@@ -570,8 +570,17 @@ async def test_write_lock_db(hass, tmp_path):
 
     instance = hass.data[DATA_INSTANCE]
 
+    def _drop_table():
+        with instance.engine.connect() as connection:
+            connection.execute(text("DROP TABLE events;"))
+
     with util.write_lock_db_sqlite(instance):
         # Database should be locked now, try writing SQL command
-        with instance.engine.connect() as connection:
-            with pytest.raises(OperationalError):
-                connection.execute(text("DROP TABLE events;"))
+        with pytest.raises(OperationalError):
+            # This needs to run in another thread: the lock is taken with
+            # BEGIN IMMEDIATE, and since sqlite now gives us one connection
+            # per thread, we cannot issue the DROP on the connection that
+            # holds the lock, as that connection would be allowed to
+            # proceed. The goal is to verify that every other thread is
+            # prevented from accessing the database while it is locked.
+            await hass.async_add_executor_job(_drop_table)
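For reference, a standard-library demonstration of the sqlite behavior the new comment describes: the connection that issued BEGIN IMMEDIATE may keep writing, while every other connection is locked out until the transaction ends (the scratch path and short timeout are just for the demo):

import sqlite3
import tempfile

db = tempfile.mktemp(suffix=".sqlite")  # throwaway database for the demo

holder = sqlite3.connect(db)
holder.execute("CREATE TABLE events (x)")
holder.execute("BEGIN IMMEDIATE")                 # take the write lock
holder.execute("INSERT INTO events VALUES (1)")   # the holder may still write

other = sqlite3.connect(db, timeout=0.1)
try:
    other.execute("DROP TABLE events")            # any other connection is blocked
except sqlite3.OperationalError as err:
    print(err)                                    # "database is locked"
holder.rollback()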