Use a dedicated executor pool for database operations (#68105)

Co-authored-by: Erik Montnemery <erik@montnemery.com>
Co-authored-by: Franck Nijhof <git@frenck.dev>
J. Nick Koston 2022-03-17 23:09:01 -10:00 committed by GitHub
parent 0655ebbd84
commit bc862e97ed
12 changed files with 230 additions and 28 deletions
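The shape of the change, before the per-file diffs: history, logbook, statistics, and plant queries used to run on Home Assistant's shared default executor; they now run on a small thread pool owned by the recorder, whose `DbWorker` threads each keep a reusable per-thread SQLite connection. A minimal sketch of the caller-side pattern as it would appear inside an async handler (`fetch_states`, `hass`, and `start_time` are hypothetical stand-ins; the `get_instance` helper is added in `recorder/__init__.py` below):

```python
from homeassistant.components.recorder import get_instance

# Before: database jobs competed with every other blocking job
# on the shared default executor.
states = await hass.async_add_executor_job(fetch_states, hass, start_time)

# After: database jobs run on the recorder's dedicated executor, whose
# worker threads each hold a long-lived per-thread connection.
states = await get_instance(hass).async_add_executor_job(
    fetch_states, hass, start_time
)
```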

View File

@@ -14,7 +14,11 @@ import voluptuous as vol
 from homeassistant.components import frontend, websocket_api
 from homeassistant.components.http import HomeAssistantView
-from homeassistant.components.recorder import history, models as history_models
+from homeassistant.components.recorder import (
+    get_instance,
+    history,
+    models as history_models,
+)
 from homeassistant.components.recorder.statistics import (
     list_statistic_ids,
     statistics_during_period,
@@ -142,7 +146,7 @@ async def ws_get_statistics_during_period(
     else:
         end_time = None

-    statistics = await hass.async_add_executor_job(
+    statistics = await get_instance(hass).async_add_executor_job(
        statistics_during_period,
        hass,
        start_time,
@@ -164,7 +168,7 @@ async def ws_get_list_statistic_ids(
     hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
 ) -> None:
     """Fetch a list of available statistic_id."""
-    statistic_ids = await hass.async_add_executor_job(
+    statistic_ids = await get_instance(hass).async_add_executor_job(
         list_statistic_ids,
         hass,
         msg.get("statistic_type"),
@@ -232,7 +236,7 @@ class HistoryPeriodView(HomeAssistantView):
         return cast(
             web.Response,
-            await hass.async_add_executor_job(
+            await get_instance(hass).async_add_executor_job(
                 self._sorted_significant_states_json,
                 hass,
                 start_time,

View File

@@ -7,7 +7,7 @@ import math

 import voluptuous as vol

-from homeassistant.components.recorder import history
+from homeassistant.components.recorder import get_instance, history
 from homeassistant.components.sensor import PLATFORM_SCHEMA, SensorEntity
 from homeassistant.const import (
     CONF_ENTITY_ID,
@@ -225,7 +225,7 @@ class HistoryStatsSensor(SensorEntity):
             # Don't compute anything as the value cannot have changed
             return

-        await self.hass.async_add_executor_job(
+        await get_instance(self.hass).async_add_executor_job(
             self._update, start, end, now_timestamp, start_timestamp, end_timestamp
         )

View File

@@ -15,6 +15,7 @@ from homeassistant.components import frontend
 from homeassistant.components.automation import EVENT_AUTOMATION_TRIGGERED
 from homeassistant.components.history import sqlalchemy_filter_from_include_exclude_conf
 from homeassistant.components.http import HomeAssistantView
+from homeassistant.components.recorder import get_instance
 from homeassistant.components.recorder.models import (
     Events,
     States,
@@ -254,7 +255,7 @@ class LogbookView(HomeAssistantView):
             )
         )

-        return await hass.async_add_executor_job(json_events)
+        return await get_instance(hass).async_add_executor_job(json_events)


 def humanify(hass, events, entity_attr_cache, context_lookup):

View File

@@ -6,6 +6,7 @@ import logging

 import voluptuous as vol

+from homeassistant.components.recorder import get_instance
 from homeassistant.components.recorder.models import States
 from homeassistant.components.recorder.util import execute, session_scope
 from homeassistant.const import (
@@ -283,7 +284,9 @@ class Plant(Entity):
         """After being added to hass, load from history."""
         if ENABLE_LOAD_HISTORY and "recorder" in self.hass.config.components:
             # only use the database if it's configured
-            await self.hass.async_add_executor_job(self._load_history_from_db)
+            await get_instance(self.hass).async_add_executor_job(
+                self._load_history_from_db
+            )
             self.async_write_ha_state()

         async_track_state_change_event(

View File

@@ -12,7 +12,7 @@ import queue
 import sqlite3
 import threading
 import time
-from typing import Any
+from typing import Any, TypeVar

 from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
 from sqlalchemy.exc import SQLAlchemyError
@@ -57,10 +57,12 @@ from . import history, migration, purge, statistics, websocket_api
 from .const import (
     CONF_DB_INTEGRITY_CHECK,
     DATA_INSTANCE,
+    DB_WORKER_PREFIX,
     DOMAIN,
     MAX_QUEUE_BACKLOG,
     SQLITE_URL_PREFIX,
 )
+from .executor import DBInterruptibleThreadPoolExecutor
 from .models import (
     Base,
     Events,
@@ -69,7 +71,7 @@ from .models import (
     StatisticsRuns,
     process_timestamp,
 )
-from .pool import RecorderPool
+from .pool import POOL_SIZE, RecorderPool
 from .util import (
     dburl_to_path,
     end_incomplete_runs,
@@ -83,6 +85,9 @@ from .util import (

 _LOGGER = logging.getLogger(__name__)

+T = TypeVar("T")
+
 SERVICE_PURGE = "purge"
 SERVICE_PURGE_ENTITIES = "purge_entities"
 SERVICE_ENABLE = "enable"
@@ -182,6 +187,15 @@ CONFIG_SCHEMA = vol.Schema(
 )

+# Pool size must accommodate Recorder thread + All db executors
+MAX_DB_EXECUTOR_WORKERS = POOL_SIZE - 1
+
+
+def get_instance(hass: HomeAssistant) -> Recorder:
+    """Get the recorder instance."""
+    return hass.data[DATA_INSTANCE]
+
+
 @bind_hass
 def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool:
     """Check if an entity is being recorded.
@@ -537,6 +551,7 @@ class Recorder(threading.Thread):
         self._queue_watcher = None
         self._db_supports_row_number = True
         self._database_lock_task: DatabaseLockTask | None = None
+        self._db_executor: DBInterruptibleThreadPoolExecutor | None = None

         self.enabled = True
@@ -544,6 +559,20 @@ class Recorder(threading.Thread):
         """Enable or disable recording events and states."""
         self.enabled = enable

+    @callback
+    def async_start_executor(self):
+        """Start the executor."""
+        self._db_executor = DBInterruptibleThreadPoolExecutor(
+            thread_name_prefix=DB_WORKER_PREFIX,
+            max_workers=MAX_DB_EXECUTOR_WORKERS,
+            shutdown_hook=self._shutdown_pool,
+        )
+
+    def _shutdown_pool(self):
+        """Close the dbpool connections in the current thread."""
+        if hasattr(self.engine.pool, "shutdown"):
+            self.engine.pool.shutdown()
+
     @callback
     def async_initialize(self):
         """Initialize the recorder."""
@@ -554,6 +583,19 @@ class Recorder(threading.Thread):
             self.hass, self._async_check_queue, timedelta(minutes=10)
         )

+    @callback
+    def async_add_executor_job(
+        self, target: Callable[..., T], *args: Any
+    ) -> asyncio.Future[T]:
+        """Add an executor job from within the event loop."""
+        return self.hass.loop.run_in_executor(self._db_executor, target, *args)
+
+    def _stop_executor(self) -> None:
+        """Stop the executor."""
+        assert self._db_executor is not None
+        self._db_executor.shutdown()
+        self._db_executor = None
+
     @callback
     def _async_check_queue(self, *_):
         """Periodic check of the queue size to ensure we do not exaust memory.
@@ -680,6 +722,7 @@ class Recorder(threading.Thread):
     def async_connection_success(self):
         """Connect success tasks."""
         self.async_db_ready.set_result(True)
+        self.async_start_executor()

     @callback
     def _async_recorder_ready(self):
@@ -1212,6 +1255,7 @@ class Recorder(threading.Thread):
     def _shutdown(self):
         """Save end time for current run."""
         self.hass.add_job(self._async_stop_queue_watcher_and_event_listener)
+        self._stop_executor()
         self._end_session()
         self._close_connection()
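Taken together, the recorder changes form a simple lifecycle: the executor is created once the database connection succeeds (`async_connection_success`), jobs are dispatched onto it with `loop.run_in_executor` (`async_add_executor_job`), and `_shutdown` tears it down before the session ends. A condensed, self-contained sketch of that flow, using a plain `ThreadPoolExecutor` as a stand-in for `DBInterruptibleThreadPoolExecutor` (the real class additionally runs a shutdown hook on each exiting worker):

```python
from __future__ import annotations

import asyncio
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from typing import Any, TypeVar

T = TypeVar("T")

POOL_SIZE = 5  # from .pool: the Recorder thread plus the workers share it
MAX_DB_EXECUTOR_WORKERS = POOL_SIZE - 1


class RecorderLifecycleSketch:
    """Illustrative only; mirrors async_start_executor/_stop_executor above."""

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._db_executor: ThreadPoolExecutor | None = None

    def start_executor(self) -> None:
        # Called once the DB connection succeeds (async_connection_success).
        self._db_executor = ThreadPoolExecutor(
            thread_name_prefix="DbWorker",
            max_workers=MAX_DB_EXECUTOR_WORKERS,
        )

    def add_executor_job(
        self, target: Callable[..., T], *args: Any
    ) -> asyncio.Future[T]:
        # Same dispatch as Recorder.async_add_executor_job in the diff.
        return self._loop.run_in_executor(self._db_executor, target, *args)

    def stop_executor(self) -> None:
        # Called from _shutdown; joins the worker threads.
        assert self._db_executor is not None
        self._db_executor.shutdown()
        self._db_executor = None
```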

View File

@@ -15,3 +15,5 @@ MAX_QUEUE_BACKLOG = 30000
 # We can increase this back to 1000 once most
 # have upgraded their sqlite version
 MAX_ROWS_TO_PURGE = 998
+
+DB_WORKER_PREFIX = "DbWorker"

View File

@@ -0,0 +1,59 @@
+"""Database executor helpers."""
+from __future__ import annotations
+
+from collections.abc import Callable
+from concurrent.futures.thread import _threads_queues, _worker
+import threading
+from typing import Any
+import weakref
+
+from homeassistant.util.executor import InterruptibleThreadPoolExecutor
+
+
+def _worker_with_shutdown_hook(shutdown_hook, *args, **kwargs):
+    """Create a worker that calls a function after it's finished."""
+    _worker(*args, **kwargs)
+    shutdown_hook()
+
+
+class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
+    """A database executor that will not deadlock on shutdown."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        """Init the executor with shutdown hook support."""
+        self._shutdown_hook: Callable[[], None] = kwargs.pop("shutdown_hook")
+        super().__init__(*args, **kwargs)
+
+    def _adjust_thread_count(self) -> None:
+        """Overridden to add support for the shutdown hook.
+
+        Based on the CPython 3.10 implementation.
+        """
+        # if idle threads are available, don't spin new threads
+        if self._idle_semaphore.acquire(  # pylint: disable=consider-using-with
+            timeout=0
+        ):
+            return
+
+        # When the executor gets lost, the weakref callback will wake up
+        # the worker threads.
+        def weakref_cb(_, q=self._work_queue):  # pylint: disable=invalid-name
+            q.put(None)
+
+        num_threads = len(self._threads)
+        if num_threads < self._max_workers:
+            thread_name = "%s_%d" % (self._thread_name_prefix or self, num_threads)
+            executor_thread = threading.Thread(
+                name=thread_name,
+                target=_worker_with_shutdown_hook,
+                args=(
+                    self._shutdown_hook,
+                    weakref.ref(self, weakref_cb),
+                    self._work_queue,
+                    self._initializer,
+                    self._initargs,
+                ),
+            )
+            executor_thread.start()
+            self._threads.add(executor_thread)  # type: ignore[attr-defined]
+            _threads_queues[executor_thread] = self._work_queue  # type: ignore[index]
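The reason for reimplementing `_adjust_thread_count`: a stock `ThreadPoolExecutor` offers an `initializer` but no per-thread finalizer, and SQLite connections should be closed on the thread that opened them. Routing workers through `_worker_with_shutdown_hook` makes each worker run the hook as it exits the worker loop, i.e. on its own thread. A hedged usage sketch (the `thread_local` connection bookkeeping is illustrative, not part of this PR; in the diff the hook is `Recorder._shutdown_pool`, which calls `engine.pool.shutdown()`):

```python
import sqlite3
import threading

from homeassistant.components.recorder.executor import (
    DBInterruptibleThreadPoolExecutor,
)

thread_local = threading.local()


def run_query() -> None:
    # Each worker lazily opens one connection and keeps it for reuse.
    if not hasattr(thread_local, "conn"):
        thread_local.conn = sqlite3.connect("example.db")
    thread_local.conn.execute("SELECT 1").fetchone()


def close_thread_connection() -> None:
    # Runs on every worker thread as it exits, so each connection is
    # closed by the same thread that created it.
    if hasattr(thread_local, "conn"):
        thread_local.conn.close()


executor = DBInterruptibleThreadPoolExecutor(
    thread_name_prefix="DbWorker",
    max_workers=4,
    shutdown_hook=close_thread_connection,
)
executor.submit(run_query).result()
executor.shutdown()  # joins workers; each runs the hook on the way out
```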

View File

@@ -1,34 +1,60 @@
 """A pool for sqlite connections."""
 import threading

-from sqlalchemy.pool import NullPool, StaticPool
+from sqlalchemy.pool import NullPool, SingletonThreadPool
+
+from homeassistant.helpers.frame import report
+
+from .const import DB_WORKER_PREFIX
+
+POOL_SIZE = 5


-class RecorderPool(StaticPool, NullPool):
-    """A hybrid of NullPool and StaticPool.
+class RecorderPool(SingletonThreadPool, NullPool):
+    """A hybrid of NullPool and SingletonThreadPool.

-    When called from the creating thread, acts like StaticPool
+    When called from the creating thread or a db executor, acts like SingletonThreadPool
     When called from any other thread, acts like NullPool
     """

     def __init__(self, *args, **kw):  # pylint: disable=super-init-not-called
         """Create the pool."""
-        self._tid = threading.current_thread().ident
-        StaticPool.__init__(self, *args, **kw)
+        kw["pool_size"] = POOL_SIZE
+        SingletonThreadPool.__init__(self, *args, **kw)
+
+    @property
+    def recorder_or_dbworker(self) -> bool:
+        """Check if the thread is a recorder or dbworker thread."""
+        thread_name = threading.current_thread().name
+        return bool(
+            thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
+        )

     def _do_return_conn(self, conn):
-        if threading.current_thread().ident == self._tid:
+        if self.recorder_or_dbworker:
             return super()._do_return_conn(conn)
         conn.close()

+    def shutdown(self):
+        """Close the connection."""
+        if self.recorder_or_dbworker and (conn := self._conn.current()):
+            conn.close()
+
     def dispose(self):
         """Dispose of the connection."""
-        if threading.current_thread().ident == self._tid:
+        if self.recorder_or_dbworker:
             return super().dispose()

     def _do_get(self):
-        if threading.current_thread().ident == self._tid:
+        if self.recorder_or_dbworker:
             return super()._do_get()
+        report(
+            "accesses the database without the database executor; "
+            "Use homeassistant.components.recorder.get_instance(hass).async_add_executor_job() "
+            "for faster database operations",
+            exclude_integrations={"recorder"},
+            error_if_core=False,
+        )
         return super(  # pylint: disable=bad-super-call
             NullPool, self
         )._create_connection()
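Dispatch is now by thread *name* rather than a stored thread id: the `Recorder` thread and `DbWorker*` threads get sticky per-thread connections (`SingletonThreadPool` behavior, with `pool_size` matched to the executor), while any other thread gets a throwaway `NullPool` connection plus a `report()` warning pointing callers at `get_instance(hass).async_add_executor_job()`. A toy model of that hybrid, with no SQLAlchemy involved (`HybridConnectionCache` and `connect` are illustrative names, not part of this PR):

```python
import threading

DB_WORKER_PREFIX = "DbWorker"


def is_recorder_or_dbworker() -> bool:
    """Mirror RecorderPool.recorder_or_dbworker: dispatch on thread name."""
    name = threading.current_thread().name
    return name == "Recorder" or name.startswith(DB_WORKER_PREFIX)


class HybridConnectionCache:
    """Sticky per-thread connections for known threads (SingletonThreadPool-
    like); a fresh connection every time for strangers (NullPool-like)."""

    def __init__(self, connect):
        self._connect = connect
        self._local = threading.local()

    def get(self):
        if not is_recorder_or_dbworker():
            return self._connect()  # stranger: pays full connect cost each call
        if not hasattr(self._local, "conn"):
            self._local.conn = self._connect()  # known thread: cache and reuse
        return self._local.conn
```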

View File

@@ -40,7 +40,8 @@ async def ws_validate_statistics(
     hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
 ) -> None:
     """Fetch a list of available statistic_id."""
-    statistic_ids = await hass.async_add_executor_job(
+    instance: Recorder = hass.data[DATA_INSTANCE]
+    statistic_ids = await instance.async_add_executor_job(
         validate_statistics,
         hass,
     )
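This handler reaches the instance directly, with a type annotation, instead of using the new helper; the two spellings are interchangeable because `get_instance()` (added in `recorder/__init__.py` above) is exactly this dictionary lookup. Assuming a running `hass` with the recorder set up:

```python
from homeassistant.components.recorder import get_instance
from homeassistant.components.recorder.const import DATA_INSTANCE

# Equivalent ways to reach the running Recorder instance:
assert get_instance(hass) is hass.data[DATA_INSTANCE]
```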

View File

@@ -4,7 +4,7 @@ import asyncio
 from datetime import datetime, timedelta
 import sqlite3
 import threading
-from unittest.mock import patch
+from unittest.mock import Mock, patch

 import pytest
 from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
@@ -22,6 +22,7 @@ from homeassistant.components.recorder import (
     SERVICE_PURGE_ENTITIES,
     SQLITE_URL_PREFIX,
     Recorder,
+    get_instance,
     run_information,
     run_information_from_instance,
     run_information_with_session,
@@ -101,6 +102,30 @@ async def test_shutdown_before_startup_finishes(hass):
     assert run_info.end is not None


+async def test_shutdown_closes_connections(hass):
+    """Test shutdown closes connections."""
+    hass.state = CoreState.not_running
+    await async_init_recorder_component(hass)
+    instance = get_instance(hass)
+    await instance.async_db_ready
+    await hass.async_block_till_done()
+    pool = instance.engine.pool
+    pool.shutdown = Mock()
+
+    def _ensure_connected():
+        with session_scope(hass=hass) as session:
+            list(session.query(States))
+
+    await instance.async_add_executor_job(_ensure_connected)
+
+    hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
+    await hass.async_block_till_done()
+
+    assert len(pool.shutdown.mock_calls) == 1
+
+
 async def test_state_gets_saved_when_set_before_start_event(
     hass: HomeAssistant, async_setup_recorder_instance: SetupRecorderInstanceT
 ):

View File

@@ -4,15 +4,16 @@ import threading
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker

+from homeassistant.components.recorder.const import DB_WORKER_PREFIX
 from homeassistant.components.recorder.pool import RecorderPool


-def test_recorder_pool():
+def test_recorder_pool(caplog):
     """Test RecorderPool gives the same connection in the creating thread."""

     engine = create_engine("sqlite://", poolclass=RecorderPool)
     get_session = sessionmaker(bind=engine)
+    shutdown = False
     connections = []

     def _get_connection_twice():
@@ -20,15 +21,42 @@ def test_recorder_pool():
         connections.append(session.connection().connection.connection)
         session.close()

+        if shutdown:
+            engine.pool.shutdown()
+
         session = get_session()
         connections.append(session.connection().connection.connection)
         session.close()

     _get_connection_twice()
-    assert connections[0] == connections[1]
+    assert "accesses the database without the database executor" in caplog.text
+    assert connections[0] != connections[1]

+    caplog.clear()
     new_thread = threading.Thread(target=_get_connection_twice)
     new_thread.start()
     new_thread.join()
+    assert "accesses the database without the database executor" in caplog.text
     assert connections[2] != connections[3]

+    caplog.clear()
+    new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
+    new_thread.start()
+    new_thread.join()
+    assert "accesses the database without the database executor" not in caplog.text
+    assert connections[4] == connections[5]
+
+    caplog.clear()
+    new_thread = threading.Thread(target=_get_connection_twice, name="Recorder")
+    new_thread.start()
+    new_thread.join()
+    assert "accesses the database without the database executor" not in caplog.text
+    assert connections[6] == connections[7]
+
+    shutdown = True
+    caplog.clear()
+    new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
+    new_thread.start()
+    new_thread.join()
+    assert "accesses the database without the database executor" not in caplog.text
+    assert connections[8] != connections[9]

View File

@@ -570,8 +570,17 @@ async def test_write_lock_db(hass, tmp_path):

     instance = hass.data[DATA_INSTANCE]

+    def _drop_table():
+        with instance.engine.connect() as connection:
+            connection.execute(text("DROP TABLE events;"))
+
     with util.write_lock_db_sqlite(instance):
         # Database should be locked now, try writing SQL command
-        with instance.engine.connect() as connection:
-            with pytest.raises(OperationalError):
-                connection.execute(text("DROP TABLE events;"))
+        with pytest.raises(OperationalError):
+            # This needs to be called in another thread since
+            # the lock method is BEGIN IMMEDIATE and since we have
+            # a connection per thread with sqlite now, we cannot do it
+            # in the same thread as the one holding the lock since it
+            # would be allowed to proceed as the goal is to prevent
+            # all the other threads from accessing the database
+            await hass.async_add_executor_job(_drop_table)
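The comment block explains why `_drop_table` must run on another thread; the SQLite behavior it relies on is easy to reproduce standalone: the connection holding `BEGIN IMMEDIATE` may still write, while writes from any other connection fail once their busy timeout expires. A minimal sketch using plain `sqlite3` (file path and timeout are arbitrary):

```python
import os
import sqlite3
import tempfile

db_path = os.path.join(tempfile.mkdtemp(), "lock_demo.db")
holder = sqlite3.connect(db_path)
holder.execute("CREATE TABLE events (id INTEGER)")
holder.commit()

holder.execute("BEGIN IMMEDIATE")  # the statement write_lock_db_sqlite relies on
holder.execute("INSERT INTO events VALUES (1)")  # lock holder may still write

other = sqlite3.connect(db_path, timeout=0.1)
try:
    other.execute("DROP TABLE events")  # blocked by the write lock
except sqlite3.OperationalError as err:
    print(err)  # "database is locked"
finally:
    other.close()
    holder.rollback()
    holder.close()
```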