Allow to lock SQLite database during backup (#60874)
* Allow setting CONF_DB_URL. This is useful for tests which need a custom DB path.
* Introduce write_lock_db helper to lock SQLite database
* Introduce WebSocket API which allows locking the database during backup
* Fix isort
* Avoid mutable default arguments
* Address pylint issues
* Avoid holding executor thread
* Set unlock event in case timeout occurs. This makes sure the database is left unlocked even in case of a race condition.
* Add more unit tests
* Address new pylint errors
* Lower timeout to speed up tests
* Introduce queue overflow test
* Unlock database if necessary. This makes sure that the test runs through in case locking actually succeeds (and the test fails).
* Make DB_LOCK_TIMEOUT a global. There is no good reason for this to be an argument; the recorder needs to pick a sensible value.
* Add WebSocket timeout test
* Test lock_database() return
* Update homeassistant/components/recorder/__init__.py
* Fix format

Co-authored-by: Erik Montnemery <erik@montnemery.com>
Co-authored-by: J. Nick Koston <nick@koston.org>
parent 4eeee79517
commit f0006b92be
homeassistant/components/recorder/__init__.py

@@ -4,6 +4,7 @@ from __future__ import annotations
+import asyncio
 from collections.abc import Callable, Iterable
 import concurrent.futures
 from dataclasses import dataclass
 from datetime import datetime, timedelta
 import logging
 import queue
@@ -76,6 +77,7 @@ from .util import (
     session_scope,
     setup_connection_for_dialect,
     validate_or_move_away_sqlite_database,
+    write_lock_db,
 )

 _LOGGER = logging.getLogger(__name__)
@@ -123,6 +125,9 @@ KEEPALIVE_TIME = 30
 # States and Events objects
 EXPIRE_AFTER_COMMITS = 120

+DB_LOCK_TIMEOUT = 30
+DB_LOCK_QUEUE_CHECK_TIMEOUT = 1
+
 CONF_AUTO_PURGE = "auto_purge"
 CONF_DB_URL = "db_url"
 CONF_DB_MAX_RETRIES = "db_max_retries"
@@ -370,6 +375,15 @@ class WaitTask:
     """An object to insert into the recorder queue to tell it set the _queue_watch event."""


+@dataclass
+class DatabaseLockTask:
+    """An object to insert into the recorder queue to prevent writes to the database."""
+
+    database_locked: asyncio.Event
+    database_unlock: threading.Event
+    queue_overflow: bool
+
+
 class Recorder(threading.Thread):
     """A threaded recorder class."""

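Note: the dataclass couples two events from different worlds. database_locked is an asyncio.Event set from the event loop once the recorder thread actually holds the lock; database_unlock is a threading.Event the recorder thread blocks on until the backup finishes. A minimal standalone sketch of that two-event handshake (names and structure are illustrative, not the recorder's actual code):

    import asyncio
    import threading

    async def main() -> None:
        loop = asyncio.get_running_loop()
        locked = asyncio.Event()    # set from the worker, awaited on the loop
        unlock = threading.Event()  # set from the loop, waited on by the worker

        def worker() -> None:
            # ... take the database write lock here ...
            loop.call_soon_threadsafe(locked.set)  # tell the loop the lock is held
            unlock.wait()                          # hold until the loop releases us
            # ... release the lock here ...

        thread = threading.Thread(target=worker)
        thread.start()
        await locked.wait()  # safe to start the backup now
        unlock.set()         # backup done, let the recorder resume writing
        thread.join()

    asyncio.run(main())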
@@ -419,6 +433,7 @@ class Recorder(threading.Thread):
         self.migration_in_progress = False
         self._queue_watcher = None
         self._db_supports_row_number = True
+        self._database_lock_task: DatabaseLockTask | None = None

         self.enabled = True

@@ -687,6 +702,8 @@ class Recorder(threading.Thread):
     def _process_one_event_or_recover(self, event):
         """Process an event, reconnect, or recover a malformed database."""
         try:
+            if self._process_one_task(event):
+                return
             self._process_one_event(event)
             return
         except exc.DatabaseError as err:
@@ -788,34 +805,63 @@ class Recorder(threading.Thread):
         # Schedule a new statistics task if this one didn't finish
         self.queue.put(ExternalStatisticsTask(metadata, stats))

-    def _process_one_event(self, event):
+    def _lock_database(self, task: DatabaseLockTask):
+        @callback
+        def _async_set_database_locked(task: DatabaseLockTask):
+            task.database_locked.set()
+
+        with write_lock_db(self):
+            # Notify that lock is being held, wait until database can be used again.
+            self.hass.add_job(_async_set_database_locked, task)
+            while not task.database_unlock.wait(timeout=DB_LOCK_QUEUE_CHECK_TIMEOUT):
+                if self.queue.qsize() > MAX_QUEUE_BACKLOG * 0.9:
+                    _LOGGER.warning(
+                        "Database queue backlog reached more than 90% of maximum queue "
+                        "length while waiting for backup to finish; recorder will now "
+                        "resume writing to database. The backup can not be trusted and "
+                        "must be restarted"
+                    )
+                    task.queue_overflow = True
+                    break
+        _LOGGER.info(
+            "Database queue backlog reached %d entries during backup",
+            self.queue.qsize(),
+        )
+
+    def _process_one_task(self, event) -> bool:
         """Process one event."""
         if isinstance(event, PurgeTask):
             self._run_purge(event.purge_before, event.repack, event.apply_filter)
-            return
+            return True
         if isinstance(event, PurgeEntitiesTask):
             self._run_purge_entities(event.entity_filter)
-            return
+            return True
         if isinstance(event, PerodicCleanupTask):
             perodic_db_cleanups(self)
-            return
+            return True
         if isinstance(event, StatisticsTask):
             self._run_statistics(event.start)
-            return
+            return True
         if isinstance(event, ClearStatisticsTask):
             statistics.clear_statistics(self, event.statistic_ids)
-            return
+            return True
         if isinstance(event, UpdateStatisticsMetadataTask):
             statistics.update_statistics_metadata(
                 self, event.statistic_id, event.unit_of_measurement
             )
-            return
+            return True
         if isinstance(event, ExternalStatisticsTask):
             self._run_external_statistics(event.metadata, event.statistics)
-            return
+            return True
         if isinstance(event, WaitTask):
             self._queue_watch.set()
-            return
+            return True
+        if isinstance(event, DatabaseLockTask):
+            self._lock_database(event)
+            return True
+        return False
+
+    def _process_one_event(self, event):
         if event.event_type == EVENT_TIME_CHANGED:
             self._keepalive_count += 1
             if self._keepalive_count >= KEEPALIVE_TIME:
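Note: while the lock is held, events keep piling up in the recorder queue, so _lock_database wakes every DB_LOCK_QUEUE_CHECK_TIMEOUT second and aborts once the backlog passes 90% of MAX_QUEUE_BACKLOG. Assuming the recorder's MAX_QUEUE_BACKLOG of 30000 at the time, the abort threshold is 27000 queued entries. A sketch of just that guard in isolation:

    import queue
    import threading

    MAX_QUEUE_BACKLOG = 30000        # assumed value of the recorder constant
    DB_LOCK_QUEUE_CHECK_TIMEOUT = 1  # seconds between backlog checks

    def hold_lock(task_queue: queue.Queue, unlock: threading.Event) -> bool:
        """Hold the write lock; return False if the backlog forced an abort."""
        while not unlock.wait(timeout=DB_LOCK_QUEUE_CHECK_TIMEOUT):
            # 30000 * 0.9 = 27000: past that, resume writes and void the backup.
            if task_queue.qsize() > MAX_QUEUE_BACKLOG * 0.9:
                return False
        return True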
@@ -982,6 +1028,42 @@ class Recorder(threading.Thread):
         self.queue.put(WaitTask())
         self._queue_watch.wait()

+    async def lock_database(self) -> bool:
+        """Lock database so it can be backed up safely."""
+        if self._database_lock_task:
+            _LOGGER.warning("Database already locked")
+            return False
+
+        database_locked = asyncio.Event()
+        task = DatabaseLockTask(database_locked, threading.Event(), False)
+        self.queue.put(task)
+        try:
+            await asyncio.wait_for(database_locked.wait(), timeout=DB_LOCK_TIMEOUT)
+        except asyncio.TimeoutError as err:
+            task.database_unlock.set()
+            raise TimeoutError(
+                f"Could not lock database within {DB_LOCK_TIMEOUT} seconds."
+            ) from err
+        self._database_lock_task = task
+        return True
+
+    @callback
+    def unlock_database(self) -> bool:
+        """Unlock database.
+
+        Returns true if database lock has been held throughout the process.
+        """
+        if not self._database_lock_task:
+            _LOGGER.warning("Database currently not locked")
+            return False
+
+        self._database_lock_task.database_unlock.set()
+        success = not self._database_lock_task.queue_overflow
+
+        self._database_lock_task = None
+
+        return success
+
     def _setup_connection(self):
         """Ensure database is ready to fly."""
         kwargs = {}
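Note: together the two methods give callers a simple protocol: lock_database() resolves once writes are suspended (returning False if a lock is already held, or raising TimeoutError), and unlock_database() reports whether the lock survived the whole window. A hypothetical caller, sketched under the assumption of a running recorder instance; the function name is illustrative:

    from homeassistant.components.recorder.const import DATA_INSTANCE

    async def backup_sqlite_file(hass) -> None:
        """Hypothetical backup flow on top of the new lock API."""
        instance = hass.data[DATA_INSTANCE]
        if not await instance.lock_database():  # may also raise TimeoutError
            raise RuntimeError("database already locked by another backup")
        try:
            ...  # copy the .db file while the recorder queues events
        finally:
            if not instance.unlock_database():
                # Queue overflowed mid-backup: writes already resumed, so the
                # copied file cannot be trusted and the backup must restart.
                raise RuntimeError("backup invalidated by queue overflow")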
homeassistant/components/recorder/util.py

@@ -457,6 +457,25 @@ def perodic_db_cleanups(instance: Recorder):
         connection.execute(text("PRAGMA wal_checkpoint(TRUNCATE);"))


+@contextmanager
+def write_lock_db(instance: Recorder):
+    """Lock database for writes."""
+
+    if instance.engine.dialect.name == "sqlite":
+        with instance.engine.connect() as connection:
+            # Execute sqlite to create a wal checkpoint
+            # This is optional but makes sure the backup is going to be minimal
+            connection.execute(text("PRAGMA wal_checkpoint(TRUNCATE)"))
+            # Create write lock
+            _LOGGER.debug("Lock database")
+            connection.execute(text("BEGIN IMMEDIATE;"))
+            try:
+                yield
+            finally:
+                _LOGGER.debug("Unlock database")
+                connection.execute(text("END;"))
+
+
 def async_migration_in_progress(hass: HomeAssistant) -> bool:
     """Determine if a migration is in progress.

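Note: BEGIN IMMEDIATE is what makes this a write lock rather than a full lock: SQLite takes a RESERVED lock, so concurrent readers keep working while a second writer gets SQLITE_BUSY. A quick stdlib demonstration of that behavior (illustrative, not part of the commit):

    import os
    import sqlite3
    import tempfile

    path = os.path.join(tempfile.mkdtemp(), "demo.db")

    # isolation_level=None lets us issue BEGIN/END statements ourselves.
    locker = sqlite3.connect(path, isolation_level=None)
    locker.execute("CREATE TABLE events (id INTEGER)")

    locker.execute("BEGIN IMMEDIATE")  # take the write (RESERVED) lock

    writer = sqlite3.connect(path, timeout=0.1)
    try:
        writer.execute("INSERT INTO events VALUES (1)")  # second writer is blocked
    except sqlite3.OperationalError as err:
        print(err)  # "database is locked"

    reader = sqlite3.connect(path)
    print(reader.execute("SELECT COUNT(*) FROM events").fetchone())  # reads still work: (0,)

    locker.execute("END")  # release the lock; writers may proceed again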
homeassistant/components/recorder/websocket_api.py

@@ -1,6 +1,7 @@
 """The Energy websocket API."""
 from __future__ import annotations

+import logging
 from typing import TYPE_CHECKING

 import voluptuous as vol

@@ -15,6 +16,8 @@ from .util import async_migration_in_progress
 if TYPE_CHECKING:
     from . import Recorder

+_LOGGER: logging.Logger = logging.getLogger(__package__)
+

 @callback
 def async_setup(hass: HomeAssistant) -> None:
@@ -23,6 +26,8 @@ def async_setup(hass: HomeAssistant) -> None:
     websocket_api.async_register_command(hass, ws_clear_statistics)
     websocket_api.async_register_command(hass, ws_update_statistics_metadata)
     websocket_api.async_register_command(hass, ws_info)
+    websocket_api.async_register_command(hass, ws_backup_start)
+    websocket_api.async_register_command(hass, ws_backup_end)


 @websocket_api.websocket_command(
@@ -106,3 +111,38 @@ def ws_info(
         "thread_running": thread_alive,
     }
     connection.send_result(msg["id"], recorder_info)
+
+
+@websocket_api.require_admin
+@websocket_api.websocket_command({vol.Required("type"): "backup/start"})
+@websocket_api.async_response
+async def ws_backup_start(
+    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
+) -> None:
+    """Backup start notification."""
+
+    _LOGGER.info("Backup start notification, locking database for writes")
+    instance: Recorder = hass.data[DATA_INSTANCE]
+    try:
+        await instance.lock_database()
+    except TimeoutError as err:
+        connection.send_error(msg["id"], "timeout_error", str(err))
+        return
+    connection.send_result(msg["id"])
+
+
+@websocket_api.require_admin
+@websocket_api.websocket_command({vol.Required("type"): "backup/end"})
+@websocket_api.async_response
+async def ws_backup_end(
+    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
+) -> None:
+    """Backup end notification."""
+
+    instance: Recorder = hass.data[DATA_INSTANCE]
+    _LOGGER.info("Backup end notification, releasing write lock")
+    if not instance.unlock_database():
+        connection.send_error(
+            msg["id"], "database_unlock_failed", "Failed to unlock database."
+        )
+    connection.send_result(msg["id"])
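Note: for reference, the wire-level exchange a backup client (e.g. the Supervisor) would perform against these handlers; the response payloads are illustrative, following the standard WebSocket API result envelope:

    # Client -> server
    BACKUP_START = {"id": 1, "type": "backup/start"}
    # Server -> client on success:
    #   {"id": 1, "type": "result", "success": True, "result": None}
    # or, when the lock cannot be taken within DB_LOCK_TIMEOUT:
    #   {"id": 1, "type": "result", "success": False,
    #    "error": {"code": "timeout_error", "message": "Could not lock database ..."}}

    # Client -> server once the backup is finished
    BACKUP_END = {"id": 2, "type": "backup/end"}
    # Failure case, when no lock was held or it was lost:
    #   "error": {"code": "database_unlock_failed", ...}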
tests/common.py

@@ -902,8 +902,9 @@ def init_recorder_component(hass, add_config=None):

 async def async_init_recorder_component(hass, add_config=None):
     """Initialize the recorder asynchronously."""
-    config = dict(add_config) if add_config else {}
-    config[recorder.CONF_DB_URL] = "sqlite://"
+    config = add_config or {}
+    if recorder.CONF_DB_URL not in config:
+        config[recorder.CONF_DB_URL] = "sqlite://"

     with patch("homeassistant.components.recorder.migration.migrate_schema"):
         assert await async_setup_component(
tests/components/recorder/test_init.py

@@ -1,5 +1,6 @@
 """The tests for the Recorder component."""
 # pylint: disable=protected-access
+import asyncio
 from datetime import datetime, timedelta
 import sqlite3
 from unittest.mock import patch
@@ -1134,3 +1135,81 @@ def test_entity_id_filter(hass_recorder):
             db_events = list(session.query(Events).filter_by(event_type="hello"))
             # Keep referring idx + 1, as no new events are being added
             assert len(db_events) == idx + 1, data
+
+
+async def test_database_lock_and_unlock(hass: HomeAssistant, tmp_path):
+    """Test writing events during lock getting written after unlocking."""
+    # Use file DB, in memory DB cannot do write locks.
+    config = {recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db")}
+    await async_init_recorder_component(hass, config)
+    await hass.async_block_till_done()
+
+    instance: Recorder = hass.data[DATA_INSTANCE]
+
+    assert await instance.lock_database()
+
+    assert not await instance.lock_database()
+
+    event_type = "EVENT_TEST"
+    event_data = {"test_attr": 5, "test_attr_10": "nice"}
+    hass.bus.fire(event_type, event_data)
+    task = asyncio.create_task(async_wait_recording_done(hass, instance))
+
+    # Recording can't be finished while lock is held
+    with pytest.raises(asyncio.TimeoutError):
+        await asyncio.wait_for(asyncio.shield(task), timeout=1)
+
+    with session_scope(hass=hass) as session:
+        db_events = list(session.query(Events).filter_by(event_type=event_type))
+        assert len(db_events) == 0
+
+    assert instance.unlock_database()
+
+    await task
+    with session_scope(hass=hass) as session:
+        db_events = list(session.query(Events).filter_by(event_type=event_type))
+        assert len(db_events) == 1
+
+
+async def test_database_lock_and_overflow(hass: HomeAssistant, tmp_path):
+    """Test writing events during lock leading to overflow the queue causes the database to unlock."""
+    # Use file DB, in memory DB cannot do write locks.
+    config = {recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db")}
+    await async_init_recorder_component(hass, config)
+    await hass.async_block_till_done()
+
+    instance: Recorder = hass.data[DATA_INSTANCE]
+
+    with patch.object(recorder, "MAX_QUEUE_BACKLOG", 1), patch.object(
+        recorder, "DB_LOCK_QUEUE_CHECK_TIMEOUT", 0.1
+    ):
+        await instance.lock_database()
+
+        event_type = "EVENT_TEST"
+        event_data = {"test_attr": 5, "test_attr_10": "nice"}
+        hass.bus.fire(event_type, event_data)
+
+        # Check that this causes the queue to overflow and write succeeds
+        # even before unlocking.
+        await async_wait_recording_done(hass, instance)
+
+        with session_scope(hass=hass) as session:
+            db_events = list(session.query(Events).filter_by(event_type=event_type))
+            assert len(db_events) == 1
+
+        assert not instance.unlock_database()
+
+
+async def test_database_lock_timeout(hass):
+    """Test locking database timeout when recorder stopped."""
+    await async_init_recorder_component(hass)
+
+    hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
+    await hass.async_block_till_done()
+
+    instance: Recorder = hass.data[DATA_INSTANCE]
+    with patch.object(recorder, "DB_LOCK_TIMEOUT", 0.1):
+        try:
+            with pytest.raises(TimeoutError):
+                await instance.lock_database()
+        finally:
+            instance.unlock_database()
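Note: the overflow test needs only a single event because patching MAX_QUEUE_BACKLOG down to 1 drops the 90% threshold below one queue entry, so the first queued event already trips the guard in _lock_database. A quick check of the arithmetic:

    MAX_QUEUE_BACKLOG = 1                # patched value in the test
    threshold = MAX_QUEUE_BACKLOG * 0.9  # 0.9 entries
    assert 1 > threshold                 # one queued event counts as overflow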
tests/components/recorder/test_util.py

@@ -8,6 +8,7 @@ import pytest
 from sqlalchemy import text
 from sqlalchemy.sql.elements import TextClause

+from homeassistant.components import recorder
 from homeassistant.components.recorder import run_information_with_session, util
 from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
 from homeassistant.components.recorder.models import RecorderRuns
@@ -556,3 +557,21 @@ def test_perodic_db_cleanups(hass_recorder):
     ][0]
     assert isinstance(text_obj, TextClause)
     assert str(text_obj) == "PRAGMA wal_checkpoint(TRUNCATE);"
+
+
+async def test_write_lock_db(hass, tmp_path):
+    """Test database write lock."""
+    from sqlalchemy.exc import OperationalError
+
+    # Use file DB, in memory DB cannot do write locks.
+    config = {recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db")}
+    await async_init_recorder_component(hass, config)
+    await hass.async_block_till_done()
+
+    instance = hass.data[DATA_INSTANCE]
+
+    with util.write_lock_db(instance):
+        # Database should be locked now, try writing SQL command
+        with instance.engine.connect() as connection:
+            with pytest.raises(OperationalError):
+                connection.execute(text("DROP TABLE events;"))
tests/components/recorder/test_websocket_api.py

@@ -358,3 +358,62 @@ async def test_recorder_info_migration_queue_exhausted(hass, hass_ws_client):
     assert response["result"]["migration_in_progress"] is False
     assert response["result"]["recording"] is True
     assert response["result"]["thread_running"] is True
+
+
+async def test_backup_start_no_recorder(hass, hass_ws_client):
+    """Test getting backup start when recorder is not present."""
+    client = await hass_ws_client()
+
+    await client.send_json({"id": 1, "type": "backup/start"})
+    response = await client.receive_json()
+    assert not response["success"]
+    assert response["error"]["code"] == "unknown_command"
+
+
+async def test_backup_start_timeout(hass, hass_ws_client):
+    """Test getting backup start when locking the database times out."""
+    client = await hass_ws_client()
+    await async_init_recorder_component(hass)
+
+    # Ensure there are no queued events
+    await async_wait_recording_done_without_instance(hass)
+
+    with patch.object(recorder, "DB_LOCK_TIMEOUT", 0):
+        try:
+            await client.send_json({"id": 1, "type": "backup/start"})
+            response = await client.receive_json()
+            assert not response["success"]
+            assert response["error"]["code"] == "timeout_error"
+        finally:
+            await client.send_json({"id": 2, "type": "backup/end"})
+
+
+async def test_backup_end(hass, hass_ws_client):
+    """Test backup start and end."""
+    client = await hass_ws_client()
+    await async_init_recorder_component(hass)
+
+    # Ensure there are no queued events
+    await async_wait_recording_done_without_instance(hass)
+
+    await client.send_json({"id": 1, "type": "backup/start"})
+    response = await client.receive_json()
+    assert response["success"]
+
+    await client.send_json({"id": 2, "type": "backup/end"})
+    response = await client.receive_json()
+    assert response["success"]
+
+
+async def test_backup_end_without_start(hass, hass_ws_client):
+    """Test backup end without a preceding backup start."""
+    client = await hass_ws_client()
+    await async_init_recorder_component(hass)
+
+    # Ensure there are no queued events
+    await async_wait_recording_done_without_instance(hass)
+
+    await client.send_json({"id": 1, "type": "backup/end"})
+    response = await client.receive_json()
+    assert not response["success"]
+    assert response["error"]["code"] == "database_unlock_failed"