Use a dedicated executor pool for database operations (#68105)

Co-authored-by: Erik Montnemery <erik@montnemery.com>
Co-authored-by: Franck Nijhof <git@frenck.dev>
J. Nick Koston 2022-03-17 23:09:01 -10:00 committed by GitHub
parent 0655ebbd84
commit bc862e97ed
12 changed files with 230 additions and 28 deletions

homeassistant/components/history/__init__.py

@@ -14,7 +14,11 @@ import voluptuous as vol
from homeassistant.components import frontend, websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.recorder import history, models as history_models
from homeassistant.components.recorder import (
get_instance,
history,
models as history_models,
)
from homeassistant.components.recorder.statistics import (
list_statistic_ids,
statistics_during_period,
@@ -142,7 +146,7 @@ async def ws_get_statistics_during_period(
else:
end_time = None
statistics = await hass.async_add_executor_job(
statistics = await get_instance(hass).async_add_executor_job(
statistics_during_period,
hass,
start_time,
@@ -164,7 +168,7 @@ async def ws_get_list_statistic_ids(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Fetch a list of available statistic_id."""
statistic_ids = await hass.async_add_executor_job(
statistic_ids = await get_instance(hass).async_add_executor_job(
list_statistic_ids,
hass,
msg.get("statistic_type"),
@@ -232,7 +236,7 @@ class HistoryPeriodView(HomeAssistantView):
return cast(
web.Response,
await hass.async_add_executor_job(
await get_instance(hass).async_add_executor_job(
self._sorted_significant_states_json,
hass,
start_time,
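
The same substitution repeats in each file below: blocking recorder queries move off Home Assistant's shared executor and onto the recorder's dedicated database executor. A minimal sketch of the new calling convention (the entity id and time window are illustrative, not taken from this diff):

from datetime import timedelta

import homeassistant.util.dt as dt_util
from homeassistant.components.recorder import get_instance, history
from homeassistant.core import HomeAssistant

async def fetch_last_hour(hass: HomeAssistant) -> dict:
    """Sketch: run a blocking history query on the recorder's DB executor."""
    start = dt_util.utcnow() - timedelta(hours=1)
    # get_significant_states hits the database, so it must run off the event
    # loop; the dedicated pool keeps each DbWorker thread on its own
    # long-lived connection.
    return await get_instance(hass).async_add_executor_job(
        history.get_significant_states, hass, start, None, ["sensor.example"]
    )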

homeassistant/components/history_stats/sensor.py

@@ -7,7 +7,7 @@ import math
import voluptuous as vol
from homeassistant.components.recorder import history
from homeassistant.components.recorder import get_instance, history
from homeassistant.components.sensor import PLATFORM_SCHEMA, SensorEntity
from homeassistant.const import (
CONF_ENTITY_ID,
@@ -225,7 +225,7 @@ class HistoryStatsSensor(SensorEntity):
# Don't compute anything as the value cannot have changed
return
await self.hass.async_add_executor_job(
await get_instance(self.hass).async_add_executor_job(
self._update, start, end, now_timestamp, start_timestamp, end_timestamp
)

homeassistant/components/logbook/__init__.py

@@ -15,6 +15,7 @@ from homeassistant.components import frontend
from homeassistant.components.automation import EVENT_AUTOMATION_TRIGGERED
from homeassistant.components.history import sqlalchemy_filter_from_include_exclude_conf
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.recorder import get_instance
from homeassistant.components.recorder.models import (
Events,
States,
@@ -254,7 +255,7 @@
)
)
return await hass.async_add_executor_job(json_events)
return await get_instance(hass).async_add_executor_job(json_events)
def humanify(hass, events, entity_attr_cache, context_lookup):

homeassistant/components/plant/__init__.py

@@ -6,6 +6,7 @@ import logging
import voluptuous as vol
from homeassistant.components.recorder import get_instance
from homeassistant.components.recorder.models import States
from homeassistant.components.recorder.util import execute, session_scope
from homeassistant.const import (
@@ -283,7 +284,9 @@ class Plant(Entity):
"""After being added to hass, load from history."""
if ENABLE_LOAD_HISTORY and "recorder" in self.hass.config.components:
# only use the database if it's configured
await self.hass.async_add_executor_job(self._load_history_from_db)
await get_instance(self.hass).async_add_executor_job(
self._load_history_from_db
)
self.async_write_ha_state()
async_track_state_change_event(

homeassistant/components/recorder/__init__.py

@@ -12,7 +12,7 @@ import queue
import sqlite3
import threading
import time
from typing import Any
from typing import Any, TypeVar
from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
from sqlalchemy.exc import SQLAlchemyError
@@ -57,10 +57,12 @@ from . import history, migration, purge, statistics, websocket_api
from .const import (
CONF_DB_INTEGRITY_CHECK,
DATA_INSTANCE,
DB_WORKER_PREFIX,
DOMAIN,
MAX_QUEUE_BACKLOG,
SQLITE_URL_PREFIX,
)
from .executor import DBInterruptibleThreadPoolExecutor
from .models import (
Base,
Events,
@@ -69,7 +71,7 @@ from .models import (
StatisticsRuns,
process_timestamp,
)
from .pool import RecorderPool
from .pool import POOL_SIZE, RecorderPool
from .util import (
dburl_to_path,
end_incomplete_runs,
@@ -83,6 +85,9 @@ from .util import (
_LOGGER = logging.getLogger(__name__)
T = TypeVar("T")
SERVICE_PURGE = "purge"
SERVICE_PURGE_ENTITIES = "purge_entities"
SERVICE_ENABLE = "enable"
@@ -182,6 +187,15 @@ CONFIG_SCHEMA = vol.Schema(
)
# Pool size must accommodate Recorder thread + All db executors
MAX_DB_EXECUTOR_WORKERS = POOL_SIZE - 1
def get_instance(hass: HomeAssistant) -> Recorder:
"""Get the recorder instance."""
return hass.data[DATA_INSTANCE]
@bind_hass
def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool:
"""Check if an entity is being recorded.
@@ -537,6 +551,7 @@ class Recorder(threading.Thread):
self._queue_watcher = None
self._db_supports_row_number = True
self._database_lock_task: DatabaseLockTask | None = None
self._db_executor: DBInterruptibleThreadPoolExecutor | None = None
self.enabled = True
@@ -544,6 +559,20 @@
"""Enable or disable recording events and states."""
self.enabled = enable
@callback
def async_start_executor(self):
"""Start the executor."""
self._db_executor = DBInterruptibleThreadPoolExecutor(
thread_name_prefix=DB_WORKER_PREFIX,
max_workers=MAX_DB_EXECUTOR_WORKERS,
shutdown_hook=self._shutdown_pool,
)
def _shutdown_pool(self):
"""Close the dbpool connections in the current thread."""
if hasattr(self.engine.pool, "shutdown"):
self.engine.pool.shutdown()
@callback
def async_initialize(self):
"""Initialize the recorder."""
@@ -554,6 +583,19 @@
self.hass, self._async_check_queue, timedelta(minutes=10)
)
@callback
def async_add_executor_job(
self, target: Callable[..., T], *args: Any
) -> asyncio.Future[T]:
"""Add an executor job from within the event loop."""
return self.hass.loop.run_in_executor(self._db_executor, target, *args)
def _stop_executor(self) -> None:
"""Stop the executor."""
assert self._db_executor is not None
self._db_executor.shutdown()
self._db_executor = None
@callback
def _async_check_queue(self, *_):
"""Periodic check of the queue size to ensure we do not exaust memory.
@@ -680,6 +722,7 @@
def async_connection_success(self):
"""Connect success tasks."""
self.async_db_ready.set_result(True)
self.async_start_executor()
@callback
def _async_recorder_ready(self):
@@ -1212,6 +1255,7 @@
def _shutdown(self):
"""Save end time for current run."""
self.hass.add_job(self._async_stop_queue_watcher_and_event_listener)
self._stop_executor()
self._end_session()
self._close_connection()
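
Summing up the recorder changes: the executor is created only once the database connection succeeds, is sized one worker short of the connection pool so the Recorder thread itself can also hold a pooled connection (with POOL_SIZE = 5 that gives 4 workers), and is torn down before the session ends. A condensed sketch of the ordering implied by the hunks above (comments are mine, not part of the diff):

# Sizing: 4 DbWorker threads + 1 Recorder thread = POOL_SIZE (5) connections.
POOL_SIZE = 5
MAX_DB_EXECUTOR_WORKERS = POOL_SIZE - 1

# Lifecycle:
# 1. async_connection_success() calls async_start_executor(), which builds
#    the DBInterruptibleThreadPoolExecutor with DbWorker-prefixed threads.
# 2. async_add_executor_job() dispatches callables onto that executor via
#    hass.loop.run_in_executor(self._db_executor, target, *args).
# 3. _shutdown() calls _stop_executor(); as each worker exits, the shutdown
#    hook closes that thread's pooled database connection.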

homeassistant/components/recorder/const.py

@@ -15,3 +15,5 @@ MAX_QUEUE_BACKLOG = 30000
# We can increase this back to 1000 once most
# have upgraded their sqlite version
MAX_ROWS_TO_PURGE = 998
DB_WORKER_PREFIX = "DbWorker"

homeassistant/components/recorder/executor.py

@@ -0,0 +1,59 @@
"""Database executor helpers."""
from __future__ import annotations
from collections.abc import Callable
from concurrent.futures.thread import _threads_queues, _worker
import threading
from typing import Any
import weakref
from homeassistant.util.executor import InterruptibleThreadPoolExecutor
def _worker_with_shutdown_hook(shutdown_hook, *args, **kwargs):
"""Create a worker that calls a function after its finished."""
_worker(*args, **kwargs)
shutdown_hook()
class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
"""A database instance that will not deadlock on shutdown."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Init the executor with a shutdown hook support."""
self._shutdown_hook: Callable[[], None] = kwargs.pop("shutdown_hook")
super().__init__(*args, **kwargs)
def _adjust_thread_count(self) -> None:
"""Overridden to add support for shutdown hook.
Based on the CPython 3.10 implementation.
"""
# if idle threads are available, don't spin new threads
if self._idle_semaphore.acquire( # pylint: disable=consider-using-with
timeout=0
):
return
# When the executor gets lost, the weakref callback will wake up
# the worker threads.
def weakref_cb(_, q=self._work_queue): # pylint: disable=invalid-name
q.put(None)
num_threads = len(self._threads)
if num_threads < self._max_workers:
thread_name = "%s_%d" % (self._thread_name_prefix or self, num_threads)
executor_thread = threading.Thread(
name=thread_name,
target=_worker_with_shutdown_hook,
args=(
self._shutdown_hook,
weakref.ref(self, weakref_cb),
self._work_queue,
self._initializer,
self._initargs,
),
)
executor_thread.start()
self._threads.add(executor_thread) # type: ignore[attr-defined]
_threads_queues[executor_thread] = self._work_queue # type: ignore[index]
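
The override above clones CPython 3.10's _adjust_thread_count so that each worker thread runs a wrapper invoking the shutdown hook once the worker loop returns. A hedged usage sketch; the hook body here is illustrative, and only thread_name_prefix, max_workers, and shutdown_hook come from this diff:

import threading

from homeassistant.components.recorder.executor import (
    DBInterruptibleThreadPoolExecutor,
)

def close_thread_state() -> None:
    # Runs once in each worker thread as it exits; the recorder uses this
    # to close the per-thread database connections held by the pool.
    print(f"{threading.current_thread().name} cleaning up")

executor = DBInterruptibleThreadPoolExecutor(
    thread_name_prefix="DbWorker",
    max_workers=4,
    shutdown_hook=close_thread_state,
)
executor.submit(print, "runs on a DbWorker thread").result()
executor.shutdown()  # workers drain the queue, then each calls the hook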

homeassistant/components/recorder/pool.py

@@ -1,34 +1,60 @@
"""A pool for sqlite connections."""
import threading
from sqlalchemy.pool import NullPool, StaticPool
from sqlalchemy.pool import NullPool, SingletonThreadPool
from homeassistant.helpers.frame import report
from .const import DB_WORKER_PREFIX
POOL_SIZE = 5
class RecorderPool(StaticPool, NullPool):
"""A hybrid of NullPool and StaticPool.
class RecorderPool(SingletonThreadPool, NullPool):
"""A hybrid of NullPool and SingletonThreadPool.
When called from the creating thread acts like StaticPool
When called from the creating thread or db executor acts like SingletonThreadPool
When called from any other thread, acts like NullPool
"""
def __init__(self, *args, **kw): # pylint: disable=super-init-not-called
"""Create the pool."""
self._tid = threading.current_thread().ident
StaticPool.__init__(self, *args, **kw)
kw["pool_size"] = POOL_SIZE
SingletonThreadPool.__init__(self, *args, **kw)
@property
def recorder_or_dbworker(self) -> bool:
"""Check if the thread is a recorder or dbworker thread."""
thread_name = threading.current_thread().name
return bool(
thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
)
def _do_return_conn(self, conn):
if threading.current_thread().ident == self._tid:
if self.recorder_or_dbworker:
return super()._do_return_conn(conn)
conn.close()
def shutdown(self):
"""Close the connection."""
if self.recorder_or_dbworker and (conn := self._conn.current()):
conn.close()
def dispose(self):
"""Dispose of the connection."""
if threading.current_thread().ident == self._tid:
if self.recorder_or_dbworker:
return super().dispose()
def _do_get(self):
if threading.current_thread().ident == self._tid:
if self.recorder_or_dbworker:
return super()._do_get()
report(
"accesses the database without the database executor; "
"Use homeassistant.components.recorder.get_instance(hass).async_add_executor_job() "
"for faster database operations",
exclude_integrations={"recorder"},
error_if_core=False,
)
return super( # pylint: disable=bad-super-call
NullPool, self
)._create_connection()
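
The dispatch that report() guards is easiest to see outside Home Assistant. A minimal sketch, assuming an in-memory SQLite engine and SQLAlchemy 1.4 attribute names, mirroring the check-out/check-in sequence of the pool test below:

import threading

from sqlalchemy import create_engine

from homeassistant.components.recorder.pool import RecorderPool

engine = create_engine("sqlite://", poolclass=RecorderPool)

def connect_twice() -> None:
    # Check out a connection, return it, then check out again.
    first = engine.raw_connection()
    first_dbapi = first.connection
    first.close()
    second = engine.raw_connection()
    # True on Recorder/DbWorker threads (SingletonThreadPool keeps the
    # connection); False elsewhere (NullPool closes it and report() warns).
    print(threading.current_thread().name, first_dbapi is second.connection)
    second.close()

for name in ("DbWorker_0", "SomeOtherThread"):
    worker = threading.Thread(target=connect_twice, name=name)
    worker.start()
    worker.join()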

homeassistant/components/recorder/websocket_api.py

@@ -40,7 +40,8 @@ async def ws_validate_statistics(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Fetch a list of available statistic_id."""
statistic_ids = await hass.async_add_executor_job(
instance: Recorder = hass.data[DATA_INSTANCE]
statistic_ids = await instance.async_add_executor_job(
validate_statistics,
hass,
)

tests/components/recorder/test_init.py

@@ -4,7 +4,7 @@ import asyncio
from datetime import datetime, timedelta
import sqlite3
import threading
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
@@ -22,6 +22,7 @@ from homeassistant.components.recorder import (
SERVICE_PURGE_ENTITIES,
SQLITE_URL_PREFIX,
Recorder,
get_instance,
run_information,
run_information_from_instance,
run_information_with_session,
@@ -101,6 +102,30 @@ async def test_shutdown_before_startup_finishes(hass):
assert run_info.end is not None
async def test_shutdown_closes_connections(hass):
"""Test shutdown closes connections."""
hass.state = CoreState.not_running
await async_init_recorder_component(hass)
instance = get_instance(hass)
await instance.async_db_ready
await hass.async_block_till_done()
pool = instance.engine.pool
pool.shutdown = Mock()
def _ensure_connected():
with session_scope(hass=hass) as session:
list(session.query(States))
await instance.async_add_executor_job(_ensure_connected)
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done()
assert len(pool.shutdown.mock_calls) == 1
async def test_state_gets_saved_when_set_before_start_event(
hass: HomeAssistant, async_setup_recorder_instance: SetupRecorderInstanceT
):

tests/components/recorder/test_pool.py

@@ -4,15 +4,16 @@ import threading
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from homeassistant.components.recorder.const import DB_WORKER_PREFIX
from homeassistant.components.recorder.pool import RecorderPool
def test_recorder_pool():
def test_recorder_pool(caplog):
"""Test RecorderPool gives the same connection in the creating thread."""
engine = create_engine("sqlite://", poolclass=RecorderPool)
get_session = sessionmaker(bind=engine)
shutdown = False
connections = []
def _get_connection_twice():
@@ -20,15 +21,42 @@ def test_recorder_pool():
connections.append(session.connection().connection.connection)
session.close()
if shutdown:
engine.pool.shutdown()
session = get_session()
connections.append(session.connection().connection.connection)
session.close()
_get_connection_twice()
assert connections[0] == connections[1]
assert "accesses the database without the database executor" in caplog.text
assert connections[0] != connections[1]
caplog.clear()
new_thread = threading.Thread(target=_get_connection_twice)
new_thread.start()
new_thread.join()
assert "accesses the database without the database executor" in caplog.text
assert connections[2] != connections[3]
caplog.clear()
new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
new_thread.start()
new_thread.join()
assert "accesses the database without the database executor" not in caplog.text
assert connections[4] == connections[5]
caplog.clear()
new_thread = threading.Thread(target=_get_connection_twice, name="Recorder")
new_thread.start()
new_thread.join()
assert "accesses the database without the database executor" not in caplog.text
assert connections[6] == connections[7]
shutdown = True
caplog.clear()
new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
new_thread.start()
new_thread.join()
assert "accesses the database without the database executor" not in caplog.text
assert connections[8] != connections[9]

tests/components/recorder/test_util.py

@@ -570,8 +570,17 @@ async def test_write_lock_db(hass, tmp_path):
instance = hass.data[DATA_INSTANCE]
def _drop_table():
with instance.engine.connect() as connection:
connection.execute(text("DROP TABLE events;"))
with util.write_lock_db_sqlite(instance):
# Database should be locked now, try writing SQL command
with instance.engine.connect() as connection:
with pytest.raises(OperationalError):
connection.execute(text("DROP TABLE events;"))
with pytest.raises(OperationalError):
# This needs to be called in another thread since
# the lock method is BEGIN IMMEDIATE and since we have
# a connection per thread with sqlite now, we cannot do it
# in the same thread as the one holding the lock since it
# would be allowed to proceed as the goal is to prevent
# all the other threads from accessing the database
await hass.async_add_executor_job(_drop_table)