"""SQLAlchemy util functions."""
from contextlib import contextmanager
from datetime import timedelta
import logging
import os
import time

from sqlalchemy.exc import OperationalError, SQLAlchemyError

import homeassistant.util.dt as dt_util

from .const import CONF_DB_INTEGRITY_CHECK, DATA_INSTANCE, SQLITE_URL_PREFIX
from .models import ALL_TABLES, process_timestamp

_LOGGER = logging.getLogger(__name__)

RETRIES = 3
QUERY_RETRY_WAIT = 0.1
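# Companion files that sqlite3 creates alongside the main database file
# (write-ahead log and shared-memory index); they are handled together
# when a corrupt database is moved out of the way below.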
SQLITE3_POSTFIXES = ["", "-wal", "-shm"]

# This is the maximum time after the recorder ends the session
# before we no longer consider startup to be a "restart" and we
# should do a check on the sqlite3 database.
MAX_RESTART_TIME = timedelta(minutes=10)


@contextmanager
def session_scope(*, hass=None, session=None):
    """Provide a transactional scope around a series of operations."""
    if session is None and hass is not None:
        session = hass.data[DATA_INSTANCE].get_session()

    if session is None:
        raise RuntimeError("Session required")

    need_rollback = False
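    # The flag is flipped to True just before commit() below, so the
    # exception handler only rolls back when a commit was actually attempted.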
    try:
        yield session
        if session.transaction:
            need_rollback = True
            session.commit()
    except Exception as err:
        _LOGGER.error("Error executing query: %s", err)
        if need_rollback:
            session.rollback()
        raise
    finally:
        session.close()


def commit(session, work):
    """Commit & retry work: either a model to add or a function using the session."""
    for _ in range(0, RETRIES):
        try:
            if callable(work):
                work(session)
            else:
                session.add(work)
            session.commit()
            return True
        except OperationalError as err:
            _LOGGER.error("Error executing query: %s", err)
            session.rollback()
            time.sleep(QUERY_RETRY_WAIT)
    return False


def execute(qry, to_native=False, validate_entity_ids=True):
    """Query the database and convert the objects to HA native form.

    This method also retries a few times in the case of stale connections.
    """
    for tryno in range(0, RETRIES):
        try:
            timer_start = time.perf_counter()
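            # to_native() is applied through a generator expression; any row
            # whose conversion returns None is dropped from the result.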
            if to_native:
                result = [
                    row
                    for row in (
                        row.to_native(validate_entity_id=validate_entity_ids)
                        for row in qry
                    )
                    if row is not None
                ]
            else:
                result = list(qry)

            if _LOGGER.isEnabledFor(logging.DEBUG):
                elapsed = time.perf_counter() - timer_start
                if to_native:
                    _LOGGER.debug(
                        "converting %d rows to native objects took %fs",
                        len(result),
                        elapsed,
                    )
                else:
                    _LOGGER.debug(
                        "querying %d rows took %fs",
                        len(result),
                        elapsed,
                    )

            return result
        except SQLAlchemyError as err:
            _LOGGER.error("Error executing query: %s", err)

            if tryno == RETRIES - 1:
                raise
            time.sleep(QUERY_RETRY_WAIT)


def validate_or_move_away_sqlite_database(dburl: str, db_integrity_check: bool) -> bool:
    """Ensure that the database is valid or move it away."""
    dbpath = dburl_to_path(dburl)

    if not os.path.exists(dbpath):
        # Database does not exist yet, this is OK
        return True

    if not validate_sqlite_database(dbpath, db_integrity_check):
        move_away_broken_database(dbpath)
        return False

    return True


def dburl_to_path(dburl):
    """Convert the db url into a filesystem path."""
    return dburl[len(SQLITE_URL_PREFIX) :]


def last_run_was_recently_clean(cursor):
    """Verify the last recorder run was recently clean."""
    cursor.execute("SELECT end FROM recorder_runs ORDER BY start DESC LIMIT 1;")
    end_time = cursor.fetchone()
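    # A missing end time means the previous run did not shut down cleanly.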
    if not end_time or not end_time[0]:
        return False

    last_run_end_time = process_timestamp(dt_util.parse_datetime(end_time[0]))
    now = dt_util.utcnow()

    _LOGGER.debug("The last run ended at: %s (now: %s)", last_run_end_time, now)

    if last_run_end_time + MAX_RESTART_TIME < now:
        return False

    return True


def basic_sanity_check(cursor):
    """Check tables to make sure select does not fail."""
    for table in ALL_TABLES:
        cursor.execute(f"SELECT * FROM {table} LIMIT 1;")  # nosec # not injection

    return True


def validate_sqlite_database(dbpath: str, db_integrity_check: bool) -> bool:
    """Run a quick check on an sqlite database to see if it is corrupt."""
    import sqlite3  # pylint: disable=import-outside-toplevel
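    # Any sqlite3.DatabaseError raised by the checks below (including the
    # PRAGMA quick_check) marks the database file as corrupt.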
    try:
        conn = sqlite3.connect(dbpath)
        run_checks_on_open_db(dbpath, conn.cursor(), db_integrity_check)
        conn.close()
    except sqlite3.DatabaseError:
        _LOGGER.exception("The database at %s is corrupt or malformed.", dbpath)
        return False

    return True


def run_checks_on_open_db(dbpath, cursor, db_integrity_check):
    """Run checks that will generate a sqlite3 exception if there is corruption."""
    sanity_check_passed = basic_sanity_check(cursor)
    last_run_was_clean = last_run_was_recently_clean(cursor)
    if sanity_check_passed and last_run_was_clean:
        _LOGGER.debug(
            "The quick_check will be skipped as the system was restarted cleanly and passed the basic sanity check"
        )
        return

    if not db_integrity_check:
        # Always warn so when it does fail they remember it has
        # been manually disabled
        _LOGGER.warning(
            "The quick_check on the sqlite3 database at %s was skipped because %s was disabled",
            dbpath,
            CONF_DB_INTEGRITY_CHECK,
        )
        return

    if not sanity_check_passed:
        _LOGGER.warning(
            "The database sanity check failed to validate the sqlite3 database at %s",
            dbpath,
        )

    if not last_run_was_clean:
        _LOGGER.warning(
            "The system could not validate that the sqlite3 database at %s was shut down cleanly.",
            dbpath,
        )

    _LOGGER.info(
        "A quick_check is being performed on the sqlite3 database at %s", dbpath
    )
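    # quick_check scans the database structure; on a malformed file sqlite3
    # raises DatabaseError, which the caller treats as corruption.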
    cursor.execute("PRAGMA QUICK_CHECK")


def move_away_broken_database(dbfile: str) -> None:
    """Move away a broken sqlite3 database."""
    isotime = dt_util.utcnow().isoformat()
    corrupt_postfix = f".corrupt.{isotime}"

    _LOGGER.error(
        "The system will rename the corrupt database file %s to %s in order to allow startup to proceed",
        dbfile,
        f"{dbfile}{corrupt_postfix}",
    )
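    # Rename the main database file together with its -wal and -shm companion
    # files so sqlite3 does not try to recover from stale journal files later.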
    for postfix in SQLITE3_POSTFIXES:
        path = f"{dbfile}{postfix}"
        if not os.path.exists(path):
            continue
        os.rename(path, f"{path}{corrupt_postfix}")