Use a dedicated executor pool for database operations (#68105)
Co-authored-by: Erik Montnemery <erik@montnemery.com>
Co-authored-by: Franck Nijhof <git@frenck.dev>
parent 0655ebbd84
commit bc862e97ed
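The commit moves blocking database work off Home Assistant's shared default executor and onto a pool of worker threads owned by the recorder. Every integration hunk below makes the same call-site substitution, sketched here with hypothetical `fn` and `arg` (not taken from this diff):

    # Before: fn runs in Home Assistant's shared default executor.
    result = await hass.async_add_executor_job(fn, arg)

    # After: fn runs in the recorder's dedicated database executor, whose
    # worker threads each hold on to a pooled database connection.
    from homeassistant.components.recorder import get_instance

    result = await get_instance(hass).async_add_executor_job(fn, arg)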
homeassistant/components/history/__init__.py

@@ -14,7 +14,11 @@ import voluptuous as vol
 from homeassistant.components import frontend, websocket_api
 from homeassistant.components.http import HomeAssistantView
-from homeassistant.components.recorder import history, models as history_models
+from homeassistant.components.recorder import (
+    get_instance,
+    history,
+    models as history_models,
+)
 from homeassistant.components.recorder.statistics import (
     list_statistic_ids,
     statistics_during_period,
@@ -142,7 +146,7 @@ async def ws_get_statistics_during_period(
     else:
         end_time = None
 
-    statistics = await hass.async_add_executor_job(
+    statistics = await get_instance(hass).async_add_executor_job(
         statistics_during_period,
         hass,
         start_time,
@@ -164,7 +168,7 @@ async def ws_get_list_statistic_ids(
     hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
 ) -> None:
     """Fetch a list of available statistic_id."""
-    statistic_ids = await hass.async_add_executor_job(
+    statistic_ids = await get_instance(hass).async_add_executor_job(
         list_statistic_ids,
         hass,
         msg.get("statistic_type"),
@@ -232,7 +236,7 @@ class HistoryPeriodView(HomeAssistantView):
 
         return cast(
             web.Response,
-            await hass.async_add_executor_job(
+            await get_instance(hass).async_add_executor_job(
                 self._sorted_significant_states_json,
                 hass,
                 start_time,
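The history websocket handlers and `HistoryPeriodView` can issue slow queries; on the shared executor those could starve unrelated blocking jobs. The routing mechanism is plain asyncio, shown here as a minimal self-contained sketch (assumed names, not Home Assistant code):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor


    def slow_query() -> int:
        """Stand-in for a blocking database call."""
        return sum(range(1_000_000))


    async def main() -> None:
        loop = asyncio.get_running_loop()
        # A dedicated pool, analogous to the recorder's db executor: slow
        # jobs submitted here cannot occupy the default executor's threads.
        db_pool = ThreadPoolExecutor(max_workers=4, thread_name_prefix="DbWorker")
        result = await loop.run_in_executor(db_pool, slow_query)
        print(result)
        db_pool.shutdown()


    asyncio.run(main())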
homeassistant/components/history_stats/sensor.py

@@ -7,7 +7,7 @@ import math
 
 import voluptuous as vol
 
-from homeassistant.components.recorder import history
+from homeassistant.components.recorder import get_instance, history
 from homeassistant.components.sensor import PLATFORM_SCHEMA, SensorEntity
 from homeassistant.const import (
     CONF_ENTITY_ID,
@@ -225,7 +225,7 @@ class HistoryStatsSensor(SensorEntity):
             # Don't compute anything as the value cannot have changed
             return
 
-        await self.hass.async_add_executor_job(
+        await get_instance(self.hass).async_add_executor_job(
            self._update, start, end, now_timestamp, start_timestamp, end_timestamp
        )
homeassistant/components/logbook/__init__.py

@@ -15,6 +15,7 @@ from homeassistant.components import frontend
 from homeassistant.components.automation import EVENT_AUTOMATION_TRIGGERED
 from homeassistant.components.history import sqlalchemy_filter_from_include_exclude_conf
 from homeassistant.components.http import HomeAssistantView
+from homeassistant.components.recorder import get_instance
 from homeassistant.components.recorder.models import (
     Events,
     States,
@@ -254,7 +255,7 @@ class LogbookView(HomeAssistantView):
             )
         )
 
-        return await hass.async_add_executor_job(json_events)
+        return await get_instance(hass).async_add_executor_job(json_events)
 
 
 def humanify(hass, events, entity_attr_cache, context_lookup):
homeassistant/components/plant/__init__.py

@@ -6,6 +6,7 @@ import logging
 
 import voluptuous as vol
 
+from homeassistant.components.recorder import get_instance
 from homeassistant.components.recorder.models import States
 from homeassistant.components.recorder.util import execute, session_scope
 from homeassistant.const import (
@@ -283,7 +284,9 @@ class Plant(Entity):
         """After being added to hass, load from history."""
         if ENABLE_LOAD_HISTORY and "recorder" in self.hass.config.components:
             # only use the database if it's configured
-            await self.hass.async_add_executor_job(self._load_history_from_db)
+            await get_instance(self.hass).async_add_executor_job(
+                self._load_history_from_db
+            )
             self.async_write_ha_state()
 
         async_track_state_change_event(
homeassistant/components/recorder/__init__.py

@@ -12,7 +12,7 @@ import queue
 import sqlite3
 import threading
 import time
-from typing import Any
+from typing import Any, TypeVar
 
 from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
 from sqlalchemy.exc import SQLAlchemyError
@@ -57,10 +57,12 @@ from . import history, migration, purge, statistics, websocket_api
 from .const import (
     CONF_DB_INTEGRITY_CHECK,
     DATA_INSTANCE,
+    DB_WORKER_PREFIX,
     DOMAIN,
     MAX_QUEUE_BACKLOG,
     SQLITE_URL_PREFIX,
 )
+from .executor import DBInterruptibleThreadPoolExecutor
 from .models import (
     Base,
     Events,
@@ -69,7 +71,7 @@ from .models import (
     StatisticsRuns,
     process_timestamp,
 )
-from .pool import RecorderPool
+from .pool import POOL_SIZE, RecorderPool
 from .util import (
     dburl_to_path,
     end_incomplete_runs,
@@ -83,6 +85,9 @@ from .util import (
 
 _LOGGER = logging.getLogger(__name__)
 
+T = TypeVar("T")
+
+
 SERVICE_PURGE = "purge"
 SERVICE_PURGE_ENTITIES = "purge_entities"
 SERVICE_ENABLE = "enable"
@@ -182,6 +187,15 @@ CONFIG_SCHEMA = vol.Schema(
 )
 
 
+# Pool size must accommodate the Recorder thread + all db executors
+MAX_DB_EXECUTOR_WORKERS = POOL_SIZE - 1
+
+
+def get_instance(hass: HomeAssistant) -> Recorder:
+    """Get the recorder instance."""
+    return hass.data[DATA_INSTANCE]
+
+
 @bind_hass
 def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool:
     """Check if an entity is being recorded.
@@ -537,6 +551,7 @@ class Recorder(threading.Thread):
         self._queue_watcher = None
         self._db_supports_row_number = True
         self._database_lock_task: DatabaseLockTask | None = None
+        self._db_executor: DBInterruptibleThreadPoolExecutor | None = None
 
         self.enabled = True
@@ -544,6 +559,20 @@
         """Enable or disable recording events and states."""
         self.enabled = enable
 
+    @callback
+    def async_start_executor(self):
+        """Start the executor."""
+        self._db_executor = DBInterruptibleThreadPoolExecutor(
+            thread_name_prefix=DB_WORKER_PREFIX,
+            max_workers=MAX_DB_EXECUTOR_WORKERS,
+            shutdown_hook=self._shutdown_pool,
+        )
+
+    def _shutdown_pool(self):
+        """Close the db pool connections in the current thread."""
+        if hasattr(self.engine.pool, "shutdown"):
+            self.engine.pool.shutdown()
+
     @callback
     def async_initialize(self):
         """Initialize the recorder."""
@@ -554,6 +583,19 @@
             self.hass, self._async_check_queue, timedelta(minutes=10)
         )
 
+    @callback
+    def async_add_executor_job(
+        self, target: Callable[..., T], *args: Any
+    ) -> asyncio.Future[T]:
+        """Add an executor job from within the event loop."""
+        return self.hass.loop.run_in_executor(self._db_executor, target, *args)
+
+    def _stop_executor(self) -> None:
+        """Stop the executor."""
+        assert self._db_executor is not None
+        self._db_executor.shutdown()
+        self._db_executor = None
+
     @callback
     def _async_check_queue(self, *_):
         """Periodic check of the queue size to ensure we do not exhaust memory.
@@ -680,6 +722,7 @@
     def async_connection_success(self):
         """Connect success tasks."""
         self.async_db_ready.set_result(True)
+        self.async_start_executor()
 
     @callback
     def _async_recorder_ready(self):
@@ -1212,6 +1255,7 @@
     def _shutdown(self):
         """Save end time for current run."""
         self.hass.add_job(self._async_stop_queue_watcher_and_event_listener)
+        self._stop_executor()
         self._end_session()
         self._close_connection()
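Taken together, the recorder hunks tie the executor to the database lifecycle: it starts in `async_connection_success` (only once a connection is known-good), is sized to `POOL_SIZE - 1` so the recorder thread keeps the remaining pooled connection for itself, and stops in `_shutdown`. A condensed model of that shape (assumed names, not the actual class):

    from __future__ import annotations

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    POOL_SIZE = 5
    MAX_DB_EXECUTOR_WORKERS = POOL_SIZE - 1  # one connection stays with the recorder thread


    class MiniRecorder:
        """Toy model of the executor lifecycle implemented above."""

        def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
            self.loop = loop
            self._db_executor: ThreadPoolExecutor | None = None

        def start_executor(self) -> None:
            # Mirrors async_start_executor: created only after the
            # database connection has succeeded.
            self._db_executor = ThreadPoolExecutor(
                max_workers=MAX_DB_EXECUTOR_WORKERS, thread_name_prefix="DbWorker"
            )

        def add_executor_job(self, target, *args) -> asyncio.Future:
            # Mirrors async_add_executor_job.
            return self.loop.run_in_executor(self._db_executor, target, *args)

        def stop_executor(self) -> None:
            # Mirrors _stop_executor, called from _shutdown.
            if self._db_executor is not None:
                self._db_executor.shutdown()
                self._db_executor = None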
homeassistant/components/recorder/const.py

@@ -15,3 +15,5 @@ MAX_QUEUE_BACKLOG = 30000
 # We can increase this back to 1000 once most
 # have upgraded their sqlite version
 MAX_ROWS_TO_PURGE = 998
+
+DB_WORKER_PREFIX = "DbWorker"
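`DB_WORKER_PREFIX` is not cosmetic: the pool in `pool.py` dispatches on it. `ThreadPoolExecutor` names worker threads `<prefix>_<n>` (the `"%s_%d"` format visible in `executor.py` below), which is what `RecorderPool.recorder_or_dbworker` matches with `startswith`. A quick standard-library check:

    from concurrent.futures import ThreadPoolExecutor
    import threading

    DB_WORKER_PREFIX = "DbWorker"

    executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix=DB_WORKER_PREFIX)
    name = executor.submit(lambda: threading.current_thread().name).result()
    print(name)  # e.g. "DbWorker_0"
    assert name.startswith(DB_WORKER_PREFIX)
    executor.shutdown()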
homeassistant/components/recorder/executor.py (new file)

@@ -0,0 +1,59 @@
+"""Database executor helpers."""
+from __future__ import annotations
+
+from collections.abc import Callable
+from concurrent.futures.thread import _threads_queues, _worker
+import threading
+from typing import Any
+import weakref
+
+from homeassistant.util.executor import InterruptibleThreadPoolExecutor
+
+
+def _worker_with_shutdown_hook(shutdown_hook, *args, **kwargs):
+    """Create a worker that calls a function after it's finished."""
+    _worker(*args, **kwargs)
+    shutdown_hook()
+
+
+class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
+    """A database executor that will not deadlock on shutdown."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        """Init the executor with shutdown hook support."""
+        self._shutdown_hook: Callable[[], None] = kwargs.pop("shutdown_hook")
+        super().__init__(*args, **kwargs)
+
+    def _adjust_thread_count(self) -> None:
+        """Overridden to add support for shutdown hook.
+
+        Based on the CPython 3.10 implementation.
+        """
+        # if idle threads are available, don't spin new threads
+        if self._idle_semaphore.acquire(  # pylint: disable=consider-using-with
+            timeout=0
+        ):
+            return
+
+        # When the executor gets lost, the weakref callback will wake up
+        # the worker threads.
+        def weakref_cb(_, q=self._work_queue):  # pylint: disable=invalid-name
+            q.put(None)
+
+        num_threads = len(self._threads)
+        if num_threads < self._max_workers:
+            thread_name = "%s_%d" % (self._thread_name_prefix or self, num_threads)
+            executor_thread = threading.Thread(
+                name=thread_name,
+                target=_worker_with_shutdown_hook,
+                args=(
+                    self._shutdown_hook,
+                    weakref.ref(self, weakref_cb),
+                    self._work_queue,
+                    self._initializer,
+                    self._initargs,
+                ),
+            )
+            executor_thread.start()
+            self._threads.add(executor_thread)  # type: ignore[attr-defined]
+            _threads_queues[executor_thread] = self._work_queue  # type: ignore[index]
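The point of the shutdown hook is that each worker runs it on its own thread as the pool winds down; wrapping CPython's private `_worker` is what makes the hook fire after a worker's main loop drains, which is also why the code pins itself to the CPython 3.10 implementation. The hook must run per thread because a `sqlite3` connection may, by default, only be used by the thread that created it, as this standalone demonstration shows (illustrative, not part of the diff):

    import sqlite3
    import threading

    conn = sqlite3.connect(":memory:")  # created in the main thread


    def use_elsewhere() -> None:
        try:
            conn.execute("SELECT 1")
        except sqlite3.ProgrammingError as err:
            # "SQLite objects created in a thread can only be used in that
            # same thread" -- hence a per-thread shutdown hook.
            print(err)


    thread = threading.Thread(target=use_elsewhere)
    thread.start()
    thread.join()
    conn.close()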
homeassistant/components/recorder/pool.py

@@ -1,34 +1,60 @@
 """A pool for sqlite connections."""
 import threading
 
-from sqlalchemy.pool import NullPool, StaticPool
+from sqlalchemy.pool import NullPool, SingletonThreadPool
 
 from homeassistant.helpers.frame import report
 
+from .const import DB_WORKER_PREFIX
+
+POOL_SIZE = 5
+
 
-class RecorderPool(StaticPool, NullPool):
-    """A hybrid of NullPool and StaticPool.
+class RecorderPool(SingletonThreadPool, NullPool):
+    """A hybrid of NullPool and SingletonThreadPool.
 
-    When called from the creating thread acts like StaticPool
+    When called from the creating thread or db executor acts like SingletonThreadPool
     When called from any other thread, acts like NullPool
     """
 
     def __init__(self, *args, **kw):  # pylint: disable=super-init-not-called
         """Create the pool."""
-        self._tid = threading.current_thread().ident
-        StaticPool.__init__(self, *args, **kw)
+        kw["pool_size"] = POOL_SIZE
+        SingletonThreadPool.__init__(self, *args, **kw)
+
+    @property
+    def recorder_or_dbworker(self) -> bool:
+        """Check if the thread is a recorder or dbworker thread."""
+        thread_name = threading.current_thread().name
+        return bool(
+            thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
+        )
 
     def _do_return_conn(self, conn):
-        if threading.current_thread().ident == self._tid:
+        if self.recorder_or_dbworker:
             return super()._do_return_conn(conn)
         conn.close()
 
+    def shutdown(self):
+        """Close the connection."""
+        if self.recorder_or_dbworker and (conn := self._conn.current()):
+            conn.close()
+
     def dispose(self):
         """Dispose of the connection."""
-        if threading.current_thread().ident == self._tid:
+        if self.recorder_or_dbworker:
             return super().dispose()
 
     def _do_get(self):
-        if threading.current_thread().ident == self._tid:
+        if self.recorder_or_dbworker:
             return super()._do_get()
         report(
             "accesses the database without the database executor; "
             "Use homeassistant.components.recorder.get_instance(hass).async_add_executor_job() "
             "for faster database operations",
             exclude_integrations={"recorder"},
             error_if_core=False,
         )
         return super(  # pylint: disable=bad-super-call
             NullPool, self
         )._create_connection()
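The resulting dispatch: the recorder thread and `DbWorker` threads see `SingletonThreadPool` behavior (one cached connection per thread, up to `POOL_SIZE` threads), while any other thread is warned via `report()` and handed a throwaway NullPool-style connection. The per-thread caching is stock SQLAlchemy behavior, isolated below (assumes SQLAlchemy 1.4.24+ for `dbapi_connection`):

    import threading

    from sqlalchemy import create_engine
    from sqlalchemy.pool import SingletonThreadPool

    engine = create_engine("sqlite://", poolclass=SingletonThreadPool)
    seen = []


    def checkout() -> None:
        conn = engine.raw_connection()
        seen.append(conn.dbapi_connection)  # the underlying sqlite3 connection
        conn.close()  # returns it to the per-thread slot


    checkout()
    checkout()
    assert seen[0] is seen[1]  # same thread -> same cached connection

    worker = threading.Thread(target=checkout)
    worker.start()
    worker.join()
    assert seen[2] is not seen[0]  # another thread -> its own connection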
homeassistant/components/recorder/websocket_api.py

@@ -40,7 +40,8 @@ async def ws_validate_statistics(
     hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
 ) -> None:
     """Fetch a list of available statistic_id."""
-    statistic_ids = await hass.async_add_executor_job(
+    instance: Recorder = hass.data[DATA_INSTANCE]
+    statistic_ids = await instance.async_add_executor_job(
         validate_statistics,
         hass,
     )
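`hass.data[DATA_INSTANCE]` here is exactly what the new `get_instance()` helper returns; the explicit `instance: Recorder` annotation presumably keeps mypy aware that `async_add_executor_job` is the recorder's typed method rather than an untyped lookup.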
tests/components/recorder/test_init.py

@@ -4,7 +4,7 @@ import asyncio
 from datetime import datetime, timedelta
 import sqlite3
 import threading
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
@@ -22,6 +22,7 @@ from homeassistant.components.recorder import (
     SERVICE_PURGE_ENTITIES,
     SQLITE_URL_PREFIX,
     Recorder,
+    get_instance,
     run_information,
     run_information_from_instance,
     run_information_with_session,
@@ -101,6 +102,30 @@ async def test_shutdown_before_startup_finishes(hass):
     assert run_info.end is not None
 
 
+async def test_shutdown_closes_connections(hass):
+    """Test shutdown closes connections."""
+
+    hass.state = CoreState.not_running
+
+    await async_init_recorder_component(hass)
+    instance = get_instance(hass)
+    await instance.async_db_ready
+    await hass.async_block_till_done()
+    pool = instance.engine.pool
+    pool.shutdown = Mock()
+
+    def _ensure_connected():
+        with session_scope(hass=hass) as session:
+            list(session.query(States))
+
+    await instance.async_add_executor_job(_ensure_connected)
+
+    hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
+    await hass.async_block_till_done()
+
+    assert len(pool.shutdown.mock_calls) == 1
+
+
 async def test_state_gets_saved_when_set_before_start_event(
     hass: HomeAssistant, async_setup_recorder_instance: SetupRecorderInstanceT
 ):
tests/components/recorder/test_pool.py

@@ -4,15 +4,16 @@ import threading
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
 
+from homeassistant.components.recorder.const import DB_WORKER_PREFIX
 from homeassistant.components.recorder.pool import RecorderPool
 
 
-def test_recorder_pool():
+def test_recorder_pool(caplog):
     """Test RecorderPool gives the same connection in the creating thread."""
 
     engine = create_engine("sqlite://", poolclass=RecorderPool)
     get_session = sessionmaker(bind=engine)
-
+    shutdown = False
     connections = []
 
     def _get_connection_twice():
@@ -20,15 +21,42 @@ def test_recorder_pool(caplog):
         connections.append(session.connection().connection.connection)
         session.close()
 
+        if shutdown:
+            engine.pool.shutdown()
+
         session = get_session()
         connections.append(session.connection().connection.connection)
         session.close()
 
     _get_connection_twice()
-    assert connections[0] == connections[1]
+    assert "accesses the database without the database executor" in caplog.text
+    assert connections[0] != connections[1]
 
+    caplog.clear()
     new_thread = threading.Thread(target=_get_connection_twice)
     new_thread.start()
     new_thread.join()
+    assert "accesses the database without the database executor" in caplog.text
     assert connections[2] != connections[3]
+
+    caplog.clear()
+    new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
+    new_thread.start()
+    new_thread.join()
+    assert "accesses the database without the database executor" not in caplog.text
+    assert connections[4] == connections[5]
+
+    caplog.clear()
+    new_thread = threading.Thread(target=_get_connection_twice, name="Recorder")
+    new_thread.start()
+    new_thread.join()
+    assert "accesses the database without the database executor" not in caplog.text
+    assert connections[6] == connections[7]
+
+    shutdown = True
+    caplog.clear()
+    new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
+    new_thread.start()
+    new_thread.join()
+    assert "accesses the database without the database executor" not in caplog.text
+    assert connections[8] != connections[9]
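Reading the assertions by pair: connections [0]/[1] (main thread) and [2]/[3] (an unnamed thread) differ and trip the warning, since both fall through to the NullPool path; [4]/[5] (a `DbWorker`-named thread) and [6]/[7] (a `Recorder`-named thread) match because the pool caches one connection per such thread; [8]/[9] differ even on a `DbWorker` thread because `shutdown()` closes the cached connection between the two checkouts.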
tests/components/recorder/test_util.py

@@ -570,8 +570,17 @@ async def test_write_lock_db(hass, tmp_path):
 
     instance = hass.data[DATA_INSTANCE]
 
+    def _drop_table():
+        with instance.engine.connect() as connection:
+            connection.execute(text("DROP TABLE events;"))
+
     with util.write_lock_db_sqlite(instance):
         # Database should be locked now, try writing SQL command
-        with instance.engine.connect() as connection:
-            with pytest.raises(OperationalError):
-                connection.execute(text("DROP TABLE events;"))
+        with pytest.raises(OperationalError):
+            # This needs to be called in another thread since
+            # the lock method is BEGIN IMMEDIATE and since we have
+            # a connection per thread with sqlite now, we cannot do it
+            # in the same thread as the one holding the lock since it
+            # would be allowed to proceed as the goal is to prevent
+            # all the other threads from accessing the database
+            await hass.async_add_executor_job(_drop_table)
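The comment in the hunk is the crux of the test change: `write_lock_db_sqlite` holds the lock via `BEGIN IMMEDIATE`, and with one connection per thread the `DROP TABLE` must now come from a different thread, and therefore a different connection, to actually contend for it. The contention is easy to reproduce with the standard library alone (illustrative; `timeout=0` makes the second writer fail fast instead of waiting):

    import sqlite3
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=".db") as tmp:
        writer = sqlite3.connect(tmp.name)
        blocked = sqlite3.connect(tmp.name, timeout=0)

        writer.execute("BEGIN IMMEDIATE")  # takes the write lock
        try:
            blocked.execute("BEGIN IMMEDIATE")  # a second writer cannot
        except sqlite3.OperationalError as err:
            print(err)  # database is locked
        writer.rollback()
        blocked.close()
        writer.close()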