"""A pool for sqlite connections."""
import logging
import threading
import traceback
from typing import Any

from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.pool import (
    ConnectionPoolEntry,
    NullPool,
    SingletonThreadPool,
    StaticPool,
)

from homeassistant.helpers.frame import report
from homeassistant.util.async_ import check_loop

from .const import DB_WORKER_PREFIX

_LOGGER = logging.getLogger(__name__)

# For debugging the MutexPool
DEBUG_MUTEX_POOL = True
DEBUG_MUTEX_POOL_TRACE = False

POOL_SIZE = 5

ADVISE_MSG = (
    "Use homeassistant.components.recorder.get_instance(hass).async_add_executor_job()"
)


class RecorderPool(SingletonThreadPool, NullPool):  # type: ignore[misc]
    """A hybrid of NullPool and SingletonThreadPool.

    When called from the creating thread or db executor acts like SingletonThreadPool
    When called from any other thread, acts like NullPool
    """

    def __init__(  # pylint: disable=super-init-not-called
        self, *args: Any, **kw: Any
    ) -> None:
        """Create the pool.

        Any caller-supplied ``pool_size`` is overwritten with ``POOL_SIZE``.
        Only ``SingletonThreadPool.__init__`` is run; ``NullPool.__init__`` is
        deliberately not called.
        """
        kw["pool_size"] = POOL_SIZE
        SingletonThreadPool.__init__(self, *args, **kw)

    @property
    def recorder_or_dbworker(self) -> bool:
        """Check if the thread is a recorder or dbworker thread."""
        thread_name = threading.current_thread().name
        return bool(
            thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
        )

    def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
        """Return a connection to the pool.

        Recorder and dbworker threads hand the record back to the parent pool
        implementation; every other thread closes the record instead.
        """
        if self.recorder_or_dbworker:
            return super()._do_return_conn(record)
        record.close()

    def shutdown(self) -> None:
        """Close the connection.

        Does nothing unless called from a recorder or dbworker thread that
        has a current connection.
        """
        if (
            self.recorder_or_dbworker
            and self._conn
            # Guard: ``current`` may not have been set on ``_conn`` yet.
            and hasattr(self._conn, "current")
            and (conn := self._conn.current())
        ):
            conn.close()

    def dispose(self) -> None:
        """Dispose of the connection.

        Calls from threads other than the recorder or a dbworker are ignored.
        """
        if self.recorder_or_dbworker:
            super().dispose()

    def _do_get(self) -> ConnectionPoolEntry:
        """Get a connection for the current thread.

        Recorder and dbworker threads use the parent pool implementation.
        Any other thread goes through ``check_loop`` first and then receives
        a freshly created connection.
        """
        if self.recorder_or_dbworker:
            return super()._do_get()
        # NOTE(review): check_loop presumably rejects calls made from the
        # event loop (strict=True) - confirm against homeassistant.util.async_.
        check_loop(
            self._do_get_db_connection_protected,
            strict=True,
            advise_msg=ADVISE_MSG,
        )
        return self._do_get_db_connection_protected()

    def _do_get_db_connection_protected(self) -> ConnectionPoolEntry:
        """Report the non-executor database access and create a connection.

        The connection is created through ``NullPool._create_connection``
        rather than through the ``SingletonThreadPool`` code path.
        """
        report(
            (
                "accesses the database without the database executor; "
                f"{ADVISE_MSG} "
                "for faster database operations"
            ),
            exclude_integrations={"recorder"},
            error_if_core=False,
        )
        return NullPool._create_connection(self)


class MutexPool(StaticPool):
    """A pool which prevents concurrent accesses from multiple threads.

    This is used in tests to prevent unsafe concurrent accesses to in-memory SQLite
    databases.
    """

    # Number of connections currently checked out; only maintained when
    # DEBUG_MUTEX_POOL is enabled and only used in debug log output.
    _reference_counter = 0
    # Declared here but never assigned in this module: it must be set on the
    # class before the pool is used.
    pool_lock: threading.RLock

    def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
        """Return a connection to the pool and release the pool lock.

        The lock is released even if returning the connection fails, so that
        a single failure cannot block every later ``_do_get`` call.
        """
        if DEBUG_MUTEX_POOL_TRACE:
            trace = traceback.extract_stack()
            trace_msg = "\n" + "".join(traceback.format_list(trace[:-1]))
        else:
            trace_msg = ""

        try:
            super()._do_return_conn(record)
            if DEBUG_MUTEX_POOL:
                self._reference_counter -= 1
                _LOGGER.debug(
                    "%s return conn ctr: %s%s",
                    threading.current_thread().name,
                    self._reference_counter,
                    trace_msg,
                )
        finally:
            MutexPool.pool_lock.release()

    def _do_get(self) -> ConnectionPoolEntry:
        """Acquire the pool lock and get a connection.

        The lock stays held until ``_do_return_conn`` releases it.

        Raises:
            SQLAlchemyError: If the lock cannot be acquired within 10 seconds.
        """
        if DEBUG_MUTEX_POOL_TRACE:
            trace = traceback.extract_stack()
            trace_msg = "".join(traceback.format_list(trace[:-1]))
        else:
            trace_msg = ""

        if DEBUG_MUTEX_POOL:
            _LOGGER.debug("%s wait conn%s", threading.current_thread().name, trace_msg)
        # pylint: disable-next=consider-using-with
        got_lock = MutexPool.pool_lock.acquire(timeout=10)
        if not got_lock:
            raise SQLAlchemyError("Timed out waiting for the MutexPool lock")
        try:
            conn = super()._do_get()
        except BaseException:
            # No connection was handed out, so _do_return_conn will never run
            # for this acquire; release here to avoid leaking the lock.
            MutexPool.pool_lock.release()
            raise
        if DEBUG_MUTEX_POOL:
            self._reference_counter += 1
            _LOGGER.debug(
                "%s get conn: ctr: %s",
                threading.current_thread().name,
                self._reference_counter,
            )
        return conn