"""SQLAlchemy util functions.""" from contextlib import contextmanager import logging import time from sqlalchemy.exc import OperationalError, SQLAlchemyError from .const import DATA_INSTANCE _LOGGER = logging.getLogger(__name__) RETRIES = 3 QUERY_RETRY_WAIT = 0.1 @contextmanager def session_scope(*, hass=None, session=None): """Provide a transactional scope around a series of operations.""" if session is None and hass is not None: session = hass.data[DATA_INSTANCE].get_session() if session is None: raise RuntimeError("Session required") need_rollback = False try: yield session if session.transaction: need_rollback = True session.commit() except Exception as err: _LOGGER.error("Error executing query: %s", err) if need_rollback: session.rollback() raise finally: session.close() def commit(session, work): """Commit & retry work: Either a model or in a function.""" for _ in range(0, RETRIES): try: if callable(work): work(session) else: session.add(work) session.commit() return True except OperationalError as err: _LOGGER.error("Error executing query: %s", err) session.rollback() time.sleep(QUERY_RETRY_WAIT) return False def execute(qry, to_native=True): """Query the database and convert the objects to HA native form. This method also retries a few times in the case of stale connections. """ for tryno in range(0, RETRIES): try: timer_start = time.perf_counter() if to_native: result = [ row for row in (row.to_native() for row in qry) if row is not None ] else: result = list(qry) if _LOGGER.isEnabledFor(logging.DEBUG): elapsed = time.perf_counter() - timer_start if to_native: _LOGGER.debug( "converting %d rows to native objects took %fs", len(result), elapsed, ) else: _LOGGER.debug( "querying %d rows took %fs", len(result), elapsed, ) return result except SQLAlchemyError as err: _LOGGER.error("Error executing query: %s", err) if tryno == RETRIES - 1: raise time.sleep(QUERY_RETRY_WAIT)