Use a dedicated executor pool for database operations (#68105)
Co-authored-by: Erik Montnemery <erik@montnemery.com> Co-authored-by: Franck Nijhof <git@frenck.dev>pull/68224/head^2
parent
0655ebbd84
commit
bc862e97ed
|
@ -14,7 +14,11 @@ import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import frontend, websocket_api
|
from homeassistant.components import frontend, websocket_api
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from homeassistant.components.recorder import history, models as history_models
|
from homeassistant.components.recorder import (
|
||||||
|
get_instance,
|
||||||
|
history,
|
||||||
|
models as history_models,
|
||||||
|
)
|
||||||
from homeassistant.components.recorder.statistics import (
|
from homeassistant.components.recorder.statistics import (
|
||||||
list_statistic_ids,
|
list_statistic_ids,
|
||||||
statistics_during_period,
|
statistics_during_period,
|
||||||
|
@ -142,7 +146,7 @@ async def ws_get_statistics_during_period(
|
||||||
else:
|
else:
|
||||||
end_time = None
|
end_time = None
|
||||||
|
|
||||||
statistics = await hass.async_add_executor_job(
|
statistics = await get_instance(hass).async_add_executor_job(
|
||||||
statistics_during_period,
|
statistics_during_period,
|
||||||
hass,
|
hass,
|
||||||
start_time,
|
start_time,
|
||||||
|
@ -164,7 +168,7 @@ async def ws_get_list_statistic_ids(
|
||||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Fetch a list of available statistic_id."""
|
"""Fetch a list of available statistic_id."""
|
||||||
statistic_ids = await hass.async_add_executor_job(
|
statistic_ids = await get_instance(hass).async_add_executor_job(
|
||||||
list_statistic_ids,
|
list_statistic_ids,
|
||||||
hass,
|
hass,
|
||||||
msg.get("statistic_type"),
|
msg.get("statistic_type"),
|
||||||
|
@ -232,7 +236,7 @@ class HistoryPeriodView(HomeAssistantView):
|
||||||
|
|
||||||
return cast(
|
return cast(
|
||||||
web.Response,
|
web.Response,
|
||||||
await hass.async_add_executor_job(
|
await get_instance(hass).async_add_executor_job(
|
||||||
self._sorted_significant_states_json,
|
self._sorted_significant_states_json,
|
||||||
hass,
|
hass,
|
||||||
start_time,
|
start_time,
|
||||||
|
|
|
@ -7,7 +7,7 @@ import math
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.recorder import history
|
from homeassistant.components.recorder import get_instance, history
|
||||||
from homeassistant.components.sensor import PLATFORM_SCHEMA, SensorEntity
|
from homeassistant.components.sensor import PLATFORM_SCHEMA, SensorEntity
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
CONF_ENTITY_ID,
|
CONF_ENTITY_ID,
|
||||||
|
@ -225,7 +225,7 @@ class HistoryStatsSensor(SensorEntity):
|
||||||
# Don't compute anything as the value cannot have changed
|
# Don't compute anything as the value cannot have changed
|
||||||
return
|
return
|
||||||
|
|
||||||
await self.hass.async_add_executor_job(
|
await get_instance(self.hass).async_add_executor_job(
|
||||||
self._update, start, end, now_timestamp, start_timestamp, end_timestamp
|
self._update, start, end, now_timestamp, start_timestamp, end_timestamp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ from homeassistant.components import frontend
|
||||||
from homeassistant.components.automation import EVENT_AUTOMATION_TRIGGERED
|
from homeassistant.components.automation import EVENT_AUTOMATION_TRIGGERED
|
||||||
from homeassistant.components.history import sqlalchemy_filter_from_include_exclude_conf
|
from homeassistant.components.history import sqlalchemy_filter_from_include_exclude_conf
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
|
from homeassistant.components.recorder import get_instance
|
||||||
from homeassistant.components.recorder.models import (
|
from homeassistant.components.recorder.models import (
|
||||||
Events,
|
Events,
|
||||||
States,
|
States,
|
||||||
|
@ -254,7 +255,7 @@ class LogbookView(HomeAssistantView):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return await hass.async_add_executor_job(json_events)
|
return await get_instance(hass).async_add_executor_job(json_events)
|
||||||
|
|
||||||
|
|
||||||
def humanify(hass, events, entity_attr_cache, context_lookup):
|
def humanify(hass, events, entity_attr_cache, context_lookup):
|
||||||
|
|
|
@ -6,6 +6,7 @@ import logging
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components.recorder import get_instance
|
||||||
from homeassistant.components.recorder.models import States
|
from homeassistant.components.recorder.models import States
|
||||||
from homeassistant.components.recorder.util import execute, session_scope
|
from homeassistant.components.recorder.util import execute, session_scope
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
|
@ -283,7 +284,9 @@ class Plant(Entity):
|
||||||
"""After being added to hass, load from history."""
|
"""After being added to hass, load from history."""
|
||||||
if ENABLE_LOAD_HISTORY and "recorder" in self.hass.config.components:
|
if ENABLE_LOAD_HISTORY and "recorder" in self.hass.config.components:
|
||||||
# only use the database if it's configured
|
# only use the database if it's configured
|
||||||
await self.hass.async_add_executor_job(self._load_history_from_db)
|
await get_instance(self.hass).async_add_executor_job(
|
||||||
|
self._load_history_from_db
|
||||||
|
)
|
||||||
self.async_write_ha_state()
|
self.async_write_ha_state()
|
||||||
|
|
||||||
async_track_state_change_event(
|
async_track_state_change_event(
|
||||||
|
|
|
@ -12,7 +12,7 @@ import queue
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
|
from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
@ -57,10 +57,12 @@ from . import history, migration, purge, statistics, websocket_api
|
||||||
from .const import (
|
from .const import (
|
||||||
CONF_DB_INTEGRITY_CHECK,
|
CONF_DB_INTEGRITY_CHECK,
|
||||||
DATA_INSTANCE,
|
DATA_INSTANCE,
|
||||||
|
DB_WORKER_PREFIX,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
MAX_QUEUE_BACKLOG,
|
MAX_QUEUE_BACKLOG,
|
||||||
SQLITE_URL_PREFIX,
|
SQLITE_URL_PREFIX,
|
||||||
)
|
)
|
||||||
|
from .executor import DBInterruptibleThreadPoolExecutor
|
||||||
from .models import (
|
from .models import (
|
||||||
Base,
|
Base,
|
||||||
Events,
|
Events,
|
||||||
|
@ -69,7 +71,7 @@ from .models import (
|
||||||
StatisticsRuns,
|
StatisticsRuns,
|
||||||
process_timestamp,
|
process_timestamp,
|
||||||
)
|
)
|
||||||
from .pool import RecorderPool
|
from .pool import POOL_SIZE, RecorderPool
|
||||||
from .util import (
|
from .util import (
|
||||||
dburl_to_path,
|
dburl_to_path,
|
||||||
end_incomplete_runs,
|
end_incomplete_runs,
|
||||||
|
@ -83,6 +85,9 @@ from .util import (
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
SERVICE_PURGE = "purge"
|
SERVICE_PURGE = "purge"
|
||||||
SERVICE_PURGE_ENTITIES = "purge_entities"
|
SERVICE_PURGE_ENTITIES = "purge_entities"
|
||||||
SERVICE_ENABLE = "enable"
|
SERVICE_ENABLE = "enable"
|
||||||
|
@ -182,6 +187,15 @@ CONFIG_SCHEMA = vol.Schema(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Pool size must accommodate Recorder thread + All db executors
|
||||||
|
MAX_DB_EXECUTOR_WORKERS = POOL_SIZE - 1
|
||||||
|
|
||||||
|
|
||||||
|
def get_instance(hass: HomeAssistant) -> Recorder:
|
||||||
|
"""Get the recorder instance."""
|
||||||
|
return hass.data[DATA_INSTANCE]
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool:
|
def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool:
|
||||||
"""Check if an entity is being recorded.
|
"""Check if an entity is being recorded.
|
||||||
|
@ -537,6 +551,7 @@ class Recorder(threading.Thread):
|
||||||
self._queue_watcher = None
|
self._queue_watcher = None
|
||||||
self._db_supports_row_number = True
|
self._db_supports_row_number = True
|
||||||
self._database_lock_task: DatabaseLockTask | None = None
|
self._database_lock_task: DatabaseLockTask | None = None
|
||||||
|
self._db_executor: DBInterruptibleThreadPoolExecutor | None = None
|
||||||
|
|
||||||
self.enabled = True
|
self.enabled = True
|
||||||
|
|
||||||
|
@ -544,6 +559,20 @@ class Recorder(threading.Thread):
|
||||||
"""Enable or disable recording events and states."""
|
"""Enable or disable recording events and states."""
|
||||||
self.enabled = enable
|
self.enabled = enable
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_start_executor(self):
|
||||||
|
"""Start the executor."""
|
||||||
|
self._db_executor = DBInterruptibleThreadPoolExecutor(
|
||||||
|
thread_name_prefix=DB_WORKER_PREFIX,
|
||||||
|
max_workers=MAX_DB_EXECUTOR_WORKERS,
|
||||||
|
shutdown_hook=self._shutdown_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _shutdown_pool(self):
|
||||||
|
"""Close the dbpool connections in the current thread."""
|
||||||
|
if hasattr(self.engine.pool, "shutdown"):
|
||||||
|
self.engine.pool.shutdown()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_initialize(self):
|
def async_initialize(self):
|
||||||
"""Initialize the recorder."""
|
"""Initialize the recorder."""
|
||||||
|
@ -554,6 +583,19 @@ class Recorder(threading.Thread):
|
||||||
self.hass, self._async_check_queue, timedelta(minutes=10)
|
self.hass, self._async_check_queue, timedelta(minutes=10)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_add_executor_job(
|
||||||
|
self, target: Callable[..., T], *args: Any
|
||||||
|
) -> asyncio.Future[T]:
|
||||||
|
"""Add an executor job from within the event loop."""
|
||||||
|
return self.hass.loop.run_in_executor(self._db_executor, target, *args)
|
||||||
|
|
||||||
|
def _stop_executor(self) -> None:
|
||||||
|
"""Stop the executor."""
|
||||||
|
assert self._db_executor is not None
|
||||||
|
self._db_executor.shutdown()
|
||||||
|
self._db_executor = None
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_check_queue(self, *_):
|
def _async_check_queue(self, *_):
|
||||||
"""Periodic check of the queue size to ensure we do not exaust memory.
|
"""Periodic check of the queue size to ensure we do not exaust memory.
|
||||||
|
@ -680,6 +722,7 @@ class Recorder(threading.Thread):
|
||||||
def async_connection_success(self):
|
def async_connection_success(self):
|
||||||
"""Connect success tasks."""
|
"""Connect success tasks."""
|
||||||
self.async_db_ready.set_result(True)
|
self.async_db_ready.set_result(True)
|
||||||
|
self.async_start_executor()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_recorder_ready(self):
|
def _async_recorder_ready(self):
|
||||||
|
@ -1212,6 +1255,7 @@ class Recorder(threading.Thread):
|
||||||
def _shutdown(self):
|
def _shutdown(self):
|
||||||
"""Save end time for current run."""
|
"""Save end time for current run."""
|
||||||
self.hass.add_job(self._async_stop_queue_watcher_and_event_listener)
|
self.hass.add_job(self._async_stop_queue_watcher_and_event_listener)
|
||||||
|
self._stop_executor()
|
||||||
self._end_session()
|
self._end_session()
|
||||||
self._close_connection()
|
self._close_connection()
|
||||||
|
|
||||||
|
|
|
@ -15,3 +15,5 @@ MAX_QUEUE_BACKLOG = 30000
|
||||||
# We can increase this back to 1000 once most
|
# We can increase this back to 1000 once most
|
||||||
# have upgraded their sqlite version
|
# have upgraded their sqlite version
|
||||||
MAX_ROWS_TO_PURGE = 998
|
MAX_ROWS_TO_PURGE = 998
|
||||||
|
|
||||||
|
DB_WORKER_PREFIX = "DbWorker"
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
"""Database executor helpers."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from concurrent.futures.thread import _threads_queues, _worker
|
||||||
|
import threading
|
||||||
|
from typing import Any
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
from homeassistant.util.executor import InterruptibleThreadPoolExecutor
|
||||||
|
|
||||||
|
|
||||||
|
def _worker_with_shutdown_hook(shutdown_hook, *args, **kwargs):
|
||||||
|
"""Create a worker that calls a function after its finished."""
|
||||||
|
_worker(*args, **kwargs)
|
||||||
|
shutdown_hook()
|
||||||
|
|
||||||
|
|
||||||
|
class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
|
||||||
|
"""A database instance that will not deadlock on shutdown."""
|
||||||
|
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
"""Init the executor with a shutdown hook support."""
|
||||||
|
self._shutdown_hook: Callable[[], None] = kwargs.pop("shutdown_hook")
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def _adjust_thread_count(self) -> None:
|
||||||
|
"""Overridden to add support for shutdown hook.
|
||||||
|
|
||||||
|
Based on the CPython 3.10 implementation.
|
||||||
|
"""
|
||||||
|
# if idle threads are available, don't spin new threads
|
||||||
|
if self._idle_semaphore.acquire( # pylint: disable=consider-using-with
|
||||||
|
timeout=0
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# When the executor gets lost, the weakref callback will wake up
|
||||||
|
# the worker threads.
|
||||||
|
def weakref_cb(_, q=self._work_queue): # pylint: disable=invalid-name
|
||||||
|
q.put(None)
|
||||||
|
|
||||||
|
num_threads = len(self._threads)
|
||||||
|
if num_threads < self._max_workers:
|
||||||
|
thread_name = "%s_%d" % (self._thread_name_prefix or self, num_threads)
|
||||||
|
executor_thread = threading.Thread(
|
||||||
|
name=thread_name,
|
||||||
|
target=_worker_with_shutdown_hook,
|
||||||
|
args=(
|
||||||
|
self._shutdown_hook,
|
||||||
|
weakref.ref(self, weakref_cb),
|
||||||
|
self._work_queue,
|
||||||
|
self._initializer,
|
||||||
|
self._initargs,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
executor_thread.start()
|
||||||
|
self._threads.add(executor_thread) # type: ignore[attr-defined]
|
||||||
|
_threads_queues[executor_thread] = self._work_queue # type: ignore[index]
|
|
@ -1,34 +1,60 @@
|
||||||
"""A pool for sqlite connections."""
|
"""A pool for sqlite connections."""
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from sqlalchemy.pool import NullPool, StaticPool
|
from sqlalchemy.pool import NullPool, SingletonThreadPool
|
||||||
|
|
||||||
|
from homeassistant.helpers.frame import report
|
||||||
|
|
||||||
|
from .const import DB_WORKER_PREFIX
|
||||||
|
|
||||||
|
POOL_SIZE = 5
|
||||||
|
|
||||||
|
|
||||||
class RecorderPool(StaticPool, NullPool):
|
class RecorderPool(SingletonThreadPool, NullPool):
|
||||||
"""A hybrid of NullPool and StaticPool.
|
"""A hybrid of NullPool and SingletonThreadPool.
|
||||||
|
|
||||||
When called from the creating thread acts like StaticPool
|
When called from the creating thread or db executor acts like SingletonThreadPool
|
||||||
When called from any other thread, acts like NullPool
|
When called from any other thread, acts like NullPool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kw): # pylint: disable=super-init-not-called
|
def __init__(self, *args, **kw): # pylint: disable=super-init-not-called
|
||||||
"""Create the pool."""
|
"""Create the pool."""
|
||||||
self._tid = threading.current_thread().ident
|
kw["pool_size"] = POOL_SIZE
|
||||||
StaticPool.__init__(self, *args, **kw)
|
SingletonThreadPool.__init__(self, *args, **kw)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def recorder_or_dbworker(self) -> bool:
|
||||||
|
"""Check if the thread is a recorder or dbworker thread."""
|
||||||
|
thread_name = threading.current_thread().name
|
||||||
|
return bool(
|
||||||
|
thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
|
||||||
|
)
|
||||||
|
|
||||||
def _do_return_conn(self, conn):
|
def _do_return_conn(self, conn):
|
||||||
if threading.current_thread().ident == self._tid:
|
if self.recorder_or_dbworker:
|
||||||
return super()._do_return_conn(conn)
|
return super()._do_return_conn(conn)
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
"""Close the connection."""
|
||||||
|
if self.recorder_or_dbworker and (conn := self._conn.current()):
|
||||||
|
conn.close()
|
||||||
|
|
||||||
def dispose(self):
|
def dispose(self):
|
||||||
"""Dispose of the connection."""
|
"""Dispose of the connection."""
|
||||||
if threading.current_thread().ident == self._tid:
|
if self.recorder_or_dbworker:
|
||||||
return super().dispose()
|
return super().dispose()
|
||||||
|
|
||||||
def _do_get(self):
|
def _do_get(self):
|
||||||
if threading.current_thread().ident == self._tid:
|
if self.recorder_or_dbworker:
|
||||||
return super()._do_get()
|
return super()._do_get()
|
||||||
|
report(
|
||||||
|
"accesses the database without the database executor; "
|
||||||
|
"Use homeassistant.components.recorder.get_instance(hass).async_add_executor_job() "
|
||||||
|
"for faster database operations",
|
||||||
|
exclude_integrations={"recorder"},
|
||||||
|
error_if_core=False,
|
||||||
|
)
|
||||||
return super( # pylint: disable=bad-super-call
|
return super( # pylint: disable=bad-super-call
|
||||||
NullPool, self
|
NullPool, self
|
||||||
)._create_connection()
|
)._create_connection()
|
||||||
|
|
|
@ -40,7 +40,8 @@ async def ws_validate_statistics(
|
||||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Fetch a list of available statistic_id."""
|
"""Fetch a list of available statistic_id."""
|
||||||
statistic_ids = await hass.async_add_executor_job(
|
instance: Recorder = hass.data[DATA_INSTANCE]
|
||||||
|
statistic_ids = await instance.async_add_executor_job(
|
||||||
validate_statistics,
|
validate_statistics,
|
||||||
hass,
|
hass,
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,7 +4,7 @@ import asyncio
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from unittest.mock import patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
|
from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
|
||||||
|
@ -22,6 +22,7 @@ from homeassistant.components.recorder import (
|
||||||
SERVICE_PURGE_ENTITIES,
|
SERVICE_PURGE_ENTITIES,
|
||||||
SQLITE_URL_PREFIX,
|
SQLITE_URL_PREFIX,
|
||||||
Recorder,
|
Recorder,
|
||||||
|
get_instance,
|
||||||
run_information,
|
run_information,
|
||||||
run_information_from_instance,
|
run_information_from_instance,
|
||||||
run_information_with_session,
|
run_information_with_session,
|
||||||
|
@ -101,6 +102,30 @@ async def test_shutdown_before_startup_finishes(hass):
|
||||||
assert run_info.end is not None
|
assert run_info.end is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_shutdown_closes_connections(hass):
|
||||||
|
"""Test shutdown closes connections."""
|
||||||
|
|
||||||
|
hass.state = CoreState.not_running
|
||||||
|
|
||||||
|
await async_init_recorder_component(hass)
|
||||||
|
instance = get_instance(hass)
|
||||||
|
await instance.async_db_ready
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
pool = instance.engine.pool
|
||||||
|
pool.shutdown = Mock()
|
||||||
|
|
||||||
|
def _ensure_connected():
|
||||||
|
with session_scope(hass=hass) as session:
|
||||||
|
list(session.query(States))
|
||||||
|
|
||||||
|
await instance.async_add_executor_job(_ensure_connected)
|
||||||
|
|
||||||
|
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(pool.shutdown.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_state_gets_saved_when_set_before_start_event(
|
async def test_state_gets_saved_when_set_before_start_event(
|
||||||
hass: HomeAssistant, async_setup_recorder_instance: SetupRecorderInstanceT
|
hass: HomeAssistant, async_setup_recorder_instance: SetupRecorderInstanceT
|
||||||
):
|
):
|
||||||
|
|
|
@ -4,15 +4,16 @@ import threading
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from homeassistant.components.recorder.const import DB_WORKER_PREFIX
|
||||||
from homeassistant.components.recorder.pool import RecorderPool
|
from homeassistant.components.recorder.pool import RecorderPool
|
||||||
|
|
||||||
|
|
||||||
def test_recorder_pool():
|
def test_recorder_pool(caplog):
|
||||||
"""Test RecorderPool gives the same connection in the creating thread."""
|
"""Test RecorderPool gives the same connection in the creating thread."""
|
||||||
|
|
||||||
engine = create_engine("sqlite://", poolclass=RecorderPool)
|
engine = create_engine("sqlite://", poolclass=RecorderPool)
|
||||||
get_session = sessionmaker(bind=engine)
|
get_session = sessionmaker(bind=engine)
|
||||||
|
shutdown = False
|
||||||
connections = []
|
connections = []
|
||||||
|
|
||||||
def _get_connection_twice():
|
def _get_connection_twice():
|
||||||
|
@ -20,15 +21,42 @@ def test_recorder_pool():
|
||||||
connections.append(session.connection().connection.connection)
|
connections.append(session.connection().connection.connection)
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
|
if shutdown:
|
||||||
|
engine.pool.shutdown()
|
||||||
|
|
||||||
session = get_session()
|
session = get_session()
|
||||||
connections.append(session.connection().connection.connection)
|
connections.append(session.connection().connection.connection)
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
_get_connection_twice()
|
_get_connection_twice()
|
||||||
assert connections[0] == connections[1]
|
assert "accesses the database without the database executor" in caplog.text
|
||||||
|
assert connections[0] != connections[1]
|
||||||
|
|
||||||
|
caplog.clear()
|
||||||
new_thread = threading.Thread(target=_get_connection_twice)
|
new_thread = threading.Thread(target=_get_connection_twice)
|
||||||
new_thread.start()
|
new_thread.start()
|
||||||
new_thread.join()
|
new_thread.join()
|
||||||
|
assert "accesses the database without the database executor" in caplog.text
|
||||||
assert connections[2] != connections[3]
|
assert connections[2] != connections[3]
|
||||||
|
|
||||||
|
caplog.clear()
|
||||||
|
new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
|
||||||
|
new_thread.start()
|
||||||
|
new_thread.join()
|
||||||
|
assert "accesses the database without the database executor" not in caplog.text
|
||||||
|
assert connections[4] == connections[5]
|
||||||
|
|
||||||
|
caplog.clear()
|
||||||
|
new_thread = threading.Thread(target=_get_connection_twice, name="Recorder")
|
||||||
|
new_thread.start()
|
||||||
|
new_thread.join()
|
||||||
|
assert "accesses the database without the database executor" not in caplog.text
|
||||||
|
assert connections[6] == connections[7]
|
||||||
|
|
||||||
|
shutdown = True
|
||||||
|
caplog.clear()
|
||||||
|
new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
|
||||||
|
new_thread.start()
|
||||||
|
new_thread.join()
|
||||||
|
assert "accesses the database without the database executor" not in caplog.text
|
||||||
|
assert connections[8] != connections[9]
|
||||||
|
|
|
@ -570,8 +570,17 @@ async def test_write_lock_db(hass, tmp_path):
|
||||||
|
|
||||||
instance = hass.data[DATA_INSTANCE]
|
instance = hass.data[DATA_INSTANCE]
|
||||||
|
|
||||||
|
def _drop_table():
|
||||||
|
with instance.engine.connect() as connection:
|
||||||
|
connection.execute(text("DROP TABLE events;"))
|
||||||
|
|
||||||
with util.write_lock_db_sqlite(instance):
|
with util.write_lock_db_sqlite(instance):
|
||||||
# Database should be locked now, try writing SQL command
|
# Database should be locked now, try writing SQL command
|
||||||
with instance.engine.connect() as connection:
|
with pytest.raises(OperationalError):
|
||||||
with pytest.raises(OperationalError):
|
# This needs to be called in another thread since
|
||||||
connection.execute(text("DROP TABLE events;"))
|
# the lock method is BEGIN IMMEDIATE and since we have
|
||||||
|
# a connection per thread with sqlite now, we cannot do it
|
||||||
|
# in the same thread as the one holding the lock since it
|
||||||
|
# would be allowed to proceed as the goal is to prevent
|
||||||
|
# all the other threads from accessing the database
|
||||||
|
await hass.async_add_executor_job(_drop_table)
|
||||||
|
|
Loading…
Reference in New Issue