core/homeassistant/components/recorder/util.py

151 lines
4.3 KiB
Python

"""SQLAlchemy util functions."""
from contextlib import contextmanager
import logging
import os
import time
from sqlalchemy.exc import OperationalError, SQLAlchemyError
import homeassistant.util.dt as dt_util
from .const import DATA_INSTANCE, SQLITE_URL_PREFIX
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# Upper bound on the number of attempts made by the retry loops in
# commit() and execute().
RETRIES = 3
# Value handed to time.sleep() between attempts (i.e. seconds).
QUERY_RETRY_WAIT = 0.1
# Suffixes appended to the database path when a broken database is renamed;
# the empty suffix stands for the database path itself.
SQLITE3_POSTFIXES = ["", "-wal", "-shm"]
@contextmanager
def session_scope(*, hass=None, session=None):
    """Yield a session and commit it when the managed block succeeds.

    The session is either the one passed in, or one obtained from the
    recorder instance stored in ``hass.data``. The session is always
    closed when the context exits.

    Raises:
        RuntimeError: if no session was given and none could be obtained.
    """
    active = session
    if active is None:
        if hass is not None:
            active = hass.data[DATA_INSTANCE].get_session()
        if active is None:
            raise RuntimeError("Session required")

    # Becomes True only once a commit has been attempted; an explicit
    # rollback is issued only in that case.
    commit_started = False
    try:
        yield active
        if active.transaction:
            commit_started = True
            active.commit()
    except Exception as err:
        _LOGGER.error("Error executing query: %s", err)
        if commit_started:
            active.rollback()
        raise
    finally:
        active.close()
def commit(session, work):
    """Apply *work* to the session and commit, retrying on operational errors.

    *work* is either a callable, which is invoked with the session, or an
    object, which is added to the session. Up to ``RETRIES`` attempts are
    made; after each failed attempt the session is rolled back and the
    function sleeps for ``QUERY_RETRY_WAIT``.

    Returns:
        True once a commit succeeds, False if every attempt failed.
    """
    for _attempt in range(RETRIES):
        try:
            if callable(work):
                work(session)
            else:
                session.add(work)
            session.commit()
        except OperationalError as err:
            _LOGGER.error("Error executing query: %s", err)
            session.rollback()
            time.sleep(QUERY_RETRY_WAIT)
        else:
            return True
    return False
def execute(qry, to_native=False, validate_entity_ids=True):
    """Run a query and return its rows as a list, retrying on failure.

    When *to_native* is true every row is converted with
    ``row.to_native(validate_entity_id=validate_entity_ids)`` and rows that
    convert to ``None`` are dropped; otherwise the raw rows are returned.

    A ``SQLAlchemyError`` is logged and the query retried after sleeping
    ``QUERY_RETRY_WAIT``; the error from the final attempt is re-raised.
    """
    for attempt in range(RETRIES):
        try:
            started = time.perf_counter()
            if to_native:
                converted = (
                    item.to_native(validate_entity_id=validate_entity_ids)
                    for item in qry
                )
                rows = [native for native in converted if native is not None]
            else:
                rows = list(qry)

            # Only pay for the timing computation when it will be logged.
            if _LOGGER.isEnabledFor(logging.DEBUG):
                duration = time.perf_counter() - started
                if to_native:
                    message = "converting %d rows to native objects took %fs"
                else:
                    message = "querying %d rows took %fs"
                _LOGGER.debug(message, len(rows), duration)

            return rows
        except SQLAlchemyError as err:
            _LOGGER.error("Error executing query: %s", err)
            # Out of attempts: let the caller see the last error.
            if attempt + 1 == RETRIES:
                raise
            time.sleep(QUERY_RETRY_WAIT)
def validate_or_move_away_sqlite_database(dburl: str) -> bool:
    """Check the sqlite database behind *dburl*, moving it away if broken.

    The file path is whatever follows the first ``len(SQLITE_URL_PREFIX)``
    characters of *dburl*.

    Returns:
        True if the database file does not exist yet or passes validation,
        False if it failed validation and was moved away.
    """
    dbpath = dburl[len(SQLITE_URL_PREFIX) :]

    # A missing file is fine; only an existing file is validated.
    if os.path.exists(dbpath) and not validate_sqlite_database(dbpath):
        _move_away_broken_database(dbpath)
        return False

    return True
def validate_sqlite_database(dbpath: str) -> bool:
    """Run a quick check on an sqlite database to see if it is corrupt.

    Opens the file at *dbpath* and executes ``PRAGMA QUICK_CHECK``. The rows
    returned by the pragma are not inspected: the database is only reported
    as broken when sqlite3 raises a ``DatabaseError``.

    Returns:
        True if the check ran without a ``sqlite3.DatabaseError``, False
        otherwise (the error is logged with its traceback).
    """
    import sqlite3  # pylint: disable=import-outside-toplevel

    try:
        conn = sqlite3.connect(dbpath)
        try:
            conn.cursor().execute("PRAGMA QUICK_CHECK")
        finally:
            # Always release the connection, including when the check fails;
            # otherwise the handle on the corrupt file stays open while the
            # caller tries to rename it.
            conn.close()
    except sqlite3.DatabaseError:
        _LOGGER.exception("The database at %s is corrupt or malformed.", dbpath)
        return False

    return True
def _move_away_broken_database(dbfile: str) -> None:
    """Rename a broken sqlite3 database and its companion files.

    Every existing file named *dbfile* plus one of ``SQLITE3_POSTFIXES`` is
    renamed by appending ``.corrupt.<current UTC time in ISO format>``.
    """
    corrupt_postfix = f".corrupt.{dt_util.utcnow().isoformat()}"

    _LOGGER.error(
        "The system will rename the corrupt database file %s to %s in order to allow startup to proceed",
        dbfile,
        f"{dbfile}{corrupt_postfix}",
    )

    # Both the generator and filter() are lazy, so each candidate is checked
    # for existence immediately before it is renamed.
    candidates = (f"{dbfile}{postfix}" for postfix in SQLITE3_POSTFIXES)
    for existing in filter(os.path.exists, candidates):
        os.rename(existing, f"{existing}{corrupt_postfix}")