Allow to lock SQLite database during backup (#60874)

* Allow to set CONF_DB_URL

This is useful for tests which need a custom DB path.

* Introduce write_lock_db helper to lock SQLite database

* Introduce Websocket API which allows to lock database during backup

* Fix isort

* Avoid mutable default arguments

* Address pylint issues

* Avoid holding executor thread

* Set unlock event in case timeout occurs

This makes sure the database is left unlocked even in case of a race
condition.

* Add more unit tests

* Address new pylint errors

* Lower timeout to speedup tests

* Introduce queue overflow test

* Unlock database if necessary

This makes sure that the test runs through in case locking actually
succeeds (and the test fails).

* Make DB_LOCK_TIMEOUT a global

There is no good reason for this to be an argument. The recorder needs
to pick a sensible value.

* Add Websocket Timeout test

* Test lock_database() return

* Update homeassistant/components/recorder/__init__.py

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* Fix format

Co-authored-by: J. Nick Koston <nick@koston.org>
Co-authored-by: Erik Montnemery <erik@montnemery.com>
pull/61160/head
Stefan Agner 2021-12-07 13:16:24 +01:00 committed by GitHub
parent 4eeee79517
commit f0006b92be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 310 additions and 11 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Iterable
import concurrent.futures
from dataclasses import dataclass
from datetime import datetime, timedelta
import logging
import queue
@ -76,6 +77,7 @@ from .util import (
session_scope,
setup_connection_for_dialect,
validate_or_move_away_sqlite_database,
write_lock_db,
)
_LOGGER = logging.getLogger(__name__)
@ -123,6 +125,9 @@ KEEPALIVE_TIME = 30
# States and Events objects
EXPIRE_AFTER_COMMITS = 120
DB_LOCK_TIMEOUT = 30
DB_LOCK_QUEUE_CHECK_TIMEOUT = 1
CONF_AUTO_PURGE = "auto_purge"
CONF_DB_URL = "db_url"
CONF_DB_MAX_RETRIES = "db_max_retries"
@ -370,6 +375,15 @@ class WaitTask:
"""An object to insert into the recorder queue to tell it set the _queue_watch event."""
@dataclass
class DatabaseLockTask:
    """Queue item asking the recorder thread to hold a database write lock.

    While the recorder thread processes this task it does not write to the
    database, so queued events pile up until the lock is released.
    """

    # Set on the event loop once the recorder thread holds the write lock.
    database_locked: asyncio.Event
    # Set from the event loop to tell the recorder thread to release the lock.
    database_unlock: threading.Event
    # Flipped to True by the recorder thread when it gave up the lock early
    # because the queue backlog grew too large.
    queue_overflow: bool
class Recorder(threading.Thread):
"""A threaded recorder class."""
@ -419,6 +433,7 @@ class Recorder(threading.Thread):
self.migration_in_progress = False
self._queue_watcher = None
self._db_supports_row_number = True
self._database_lock_task: DatabaseLockTask | None = None
self.enabled = True
@ -687,6 +702,8 @@ class Recorder(threading.Thread):
def _process_one_event_or_recover(self, event):
"""Process an event, reconnect, or recover a malformed database."""
try:
if self._process_one_task(event):
return
self._process_one_event(event)
return
except exc.DatabaseError as err:
@ -788,34 +805,63 @@ class Recorder(threading.Thread):
# Schedule a new statistics task if this one didn't finish
self.queue.put(ExternalStatisticsTask(metadata, stats))
def _process_one_event(self, event):
def _lock_database(self, task: DatabaseLockTask):
    """Hold the database write lock until unlock is requested or the queue overflows.

    Runs in the recorder thread. The requester is notified through
    ``task.database_locked`` once the lock is held; the lock is released when
    ``task.database_unlock`` is set, or early when the queue backlog exceeds
    90% of ``MAX_QUEUE_BACKLOG`` (recorded in ``task.queue_overflow``).
    """

    @callback
    def _async_set_database_locked(task: DatabaseLockTask):
        task.database_locked.set()

    with write_lock_db(self):
        # Tell the waiting coroutine that the lock is now in place.
        self.hass.add_job(_async_set_database_locked, task)
        while True:
            # Wake up periodically so the backlog can be inspected.
            if task.database_unlock.wait(timeout=DB_LOCK_QUEUE_CHECK_TIMEOUT):
                break
            if self.queue.qsize() <= MAX_QUEUE_BACKLOG * 0.9:
                continue
            _LOGGER.warning(
                "Database queue backlog reached more than 90% of maximum queue "
                "length while waiting for backup to finish; recorder will now "
                "resume writing to database. The backup can not be trusted and "
                "must be restarted"
            )
            task.queue_overflow = True
            break
    _LOGGER.info(
        "Database queue backlog reached %d entries during backup",
        self.queue.qsize(),
    )
def _process_one_task(self, event) -> bool:
    """Run ``event`` if it is a recorder task.

    Returns True when ``event`` was one of the known task types and has been
    handled, so the caller must not process it further. Returns False when
    ``event`` is not a task and still needs to be processed as an event.
    """
    if isinstance(event, PurgeTask):
        self._run_purge(event.purge_before, event.repack, event.apply_filter)
        return True
    if isinstance(event, PurgeEntitiesTask):
        self._run_purge_entities(event.entity_filter)
        return True
    if isinstance(event, PerodicCleanupTask):
        perodic_db_cleanups(self)
        return True
    if isinstance(event, StatisticsTask):
        self._run_statistics(event.start)
        return True
    if isinstance(event, ClearStatisticsTask):
        statistics.clear_statistics(self, event.statistic_ids)
        return True
    if isinstance(event, UpdateStatisticsMetadataTask):
        statistics.update_statistics_metadata(
            self, event.statistic_id, event.unit_of_measurement
        )
        return True
    if isinstance(event, ExternalStatisticsTask):
        self._run_external_statistics(event.metadata, event.statistics)
        return True
    if isinstance(event, WaitTask):
        self._queue_watch.set()
        return True
    if isinstance(event, DatabaseLockTask):
        self._lock_database(event)
        return True
    # Not a task: the caller handles it as a regular event.
    return False
def _process_one_event(self, event):
if event.event_type == EVENT_TIME_CHANGED:
self._keepalive_count += 1
if self._keepalive_count >= KEEPALIVE_TIME:
@ -982,6 +1028,42 @@ class Recorder(threading.Thread):
self.queue.put(WaitTask())
self._queue_watch.wait()
async def lock_database(self) -> bool:
    """Ask the recorder thread to lock the database for a safe backup.

    Returns False if a lock is already registered, True once the recorder
    thread has confirmed that the lock is held.

    Raises TimeoutError if the lock is not confirmed within
    DB_LOCK_TIMEOUT seconds.
    """
    if self._database_lock_task:
        _LOGGER.warning("Database already locked")
        return False

    lock_task = DatabaseLockTask(asyncio.Event(), threading.Event(), False)
    self.queue.put(lock_task)
    try:
        await asyncio.wait_for(
            lock_task.database_locked.wait(), timeout=DB_LOCK_TIMEOUT
        )
    except asyncio.TimeoutError as err:
        # Pre-set the unlock event so a late lock is released right away.
        lock_task.database_unlock.set()
        raise TimeoutError(
            f"Could not lock database within {DB_LOCK_TIMEOUT} seconds."
        ) from err
    else:
        self._database_lock_task = lock_task
        return True
@callback
def unlock_database(self) -> bool:
    """Release a previously acquired database lock.

    Returns True only if a lock was registered and it was held for the whole
    time, i.e. the recorder did not drop it early due to queue overflow.
    Returns False when no lock is registered or the lock was dropped early.
    """
    lock_task = self._database_lock_task
    if not lock_task:
        _LOGGER.warning("Database currently not locked")
        return False

    lock_task.database_unlock.set()
    held_throughout = not lock_task.queue_overflow
    self._database_lock_task = None
    return held_throughout
def _setup_connection(self):
"""Ensure database is ready to fly."""
kwargs = {}

View File

@ -457,6 +457,25 @@ def perodic_db_cleanups(instance: Recorder):
connection.execute(text("PRAGMA wal_checkpoint(TRUNCATE);"))
@contextmanager
def write_lock_db(instance: Recorder):
    """Hold a write lock on the database for the duration of the ``with`` body.

    On SQLite a WAL checkpoint is executed first, then a transaction is
    opened with ``BEGIN IMMEDIATE`` and kept open until the ``with`` body
    finishes, at which point it is closed with ``END``.

    For any other dialect no lock is taken and the body simply runs.
    """
    if instance.engine.dialect.name != "sqlite":
        # A @contextmanager generator has to yield exactly once; without this
        # entering the context on a non-SQLite database raises RuntimeError.
        yield
        return

    with instance.engine.connect() as connection:
        # Execute sqlite to create a wal checkpoint
        # This is optional but makes sure the backup is going to be minimal
        connection.execute(text("PRAGMA wal_checkpoint(TRUNCATE)"))
        # Create write lock
        _LOGGER.debug("Lock database")
        connection.execute(text("BEGIN IMMEDIATE;"))
        try:
            yield
        finally:
            # Always end the transaction, even if the body raised.
            _LOGGER.debug("Unlock database")
            connection.execute(text("END;"))
def async_migration_in_progress(hass: HomeAssistant) -> bool:
"""Determine is a migration is in progress.

View File

@ -1,6 +1,7 @@
"""The Energy websocket API."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import voluptuous as vol
@ -15,6 +16,8 @@ from .util import async_migration_in_progress
if TYPE_CHECKING:
from . import Recorder
_LOGGER: logging.Logger = logging.getLogger(__package__)
@callback
def async_setup(hass: HomeAssistant) -> None:
@ -23,6 +26,8 @@ def async_setup(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, ws_clear_statistics)
websocket_api.async_register_command(hass, ws_update_statistics_metadata)
websocket_api.async_register_command(hass, ws_info)
websocket_api.async_register_command(hass, ws_backup_start)
websocket_api.async_register_command(hass, ws_backup_end)
@websocket_api.websocket_command(
@ -106,3 +111,38 @@ def ws_info(
"thread_running": thread_alive,
}
connection.send_result(msg["id"], recorder_info)
@websocket_api.require_admin
@websocket_api.websocket_command({vol.Required("type"): "backup/start"})
@websocket_api.async_response
async def ws_backup_start(
    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
    """Handle "backup/start": lock the recorder database for writes.

    Replies with a result once lock_database() returns, or with a
    "timeout_error" if it raises TimeoutError.
    """
    _LOGGER.info("Backup start notification, locking database for writes")
    recorder_instance: Recorder = hass.data[DATA_INSTANCE]
    try:
        await recorder_instance.lock_database()
    except TimeoutError as err:
        connection.send_error(msg["id"], "timeout_error", str(err))
    else:
        connection.send_result(msg["id"])
@websocket_api.require_admin
@websocket_api.websocket_command({vol.Required("type"): "backup/end"})
@websocket_api.async_response
async def ws_backup_end(
    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
    """Handle "backup/end": release the recorder database write lock.

    Replies with a "database_unlock_failed" error when unlock_database()
    reports failure, otherwise with a plain result.
    """
    instance: Recorder = hass.data[DATA_INSTANCE]
    _LOGGER.info("Backup end notification, releasing write lock")
    if not instance.unlock_database():
        connection.send_error(
            msg["id"], "database_unlock_failed", "Failed to unlock database."
        )
        # Stop here so the same message id is not also answered with success.
        return
    connection.send_result(msg["id"])

View File

@ -902,7 +902,8 @@ def init_recorder_component(hass, add_config=None):
async def async_init_recorder_component(hass, add_config=None):
"""Initialize the recorder asynchronously."""
config = dict(add_config) if add_config else {}
config = add_config or {}
if recorder.CONF_DB_URL not in config:
config[recorder.CONF_DB_URL] = "sqlite://"
with patch("homeassistant.components.recorder.migration.migrate_schema"):

View File

@ -1,5 +1,6 @@
"""The tests for the Recorder component."""
# pylint: disable=protected-access
import asyncio
from datetime import datetime, timedelta
import sqlite3
from unittest.mock import patch
@ -1134,3 +1135,81 @@ def test_entity_id_filter(hass_recorder):
db_events = list(session.query(Events).filter_by(event_type="hello"))
# Keep referring idx + 1, as no new events are being added
assert len(db_events) == idx + 1, data
async def test_database_lock_and_unlock(hass: HomeAssistant, tmp_path):
    """Test that events fired while locked are written only after unlocking."""
    # Use file DB, in memory DB cannot do write locks.
    db_url = "sqlite:///" + str(tmp_path / "pytest.db")
    await async_init_recorder_component(hass, {recorder.CONF_DB_URL: db_url})
    await hass.async_block_till_done()

    instance: Recorder = hass.data[DATA_INSTANCE]

    # First lock succeeds, a second attempt while locked is refused.
    assert await instance.lock_database()
    assert not await instance.lock_database()

    event_type = "EVENT_TEST"
    hass.bus.fire(event_type, {"test_attr": 5, "test_attr_10": "nice"})
    recording_done = asyncio.create_task(async_wait_recording_done(hass, instance))

    # Recording can't be finished while lock is held
    with pytest.raises(asyncio.TimeoutError):
        await asyncio.wait_for(asyncio.shield(recording_done), timeout=1)

    with session_scope(hass=hass) as session:
        stored = list(session.query(Events).filter_by(event_type=event_type))
        assert len(stored) == 0

    assert instance.unlock_database()

    await recording_done
    with session_scope(hass=hass) as session:
        stored = list(session.query(Events).filter_by(event_type=event_type))
        assert len(stored) == 1
async def test_database_lock_and_overflow(hass: HomeAssistant, tmp_path):
    """Test that a queue overflow while locked makes the recorder drop the lock."""
    # Use file DB, in memory DB cannot do write locks.
    db_url = "sqlite:///" + str(tmp_path / "pytest.db")
    await async_init_recorder_component(hass, {recorder.CONF_DB_URL: db_url})
    await hass.async_block_till_done()

    instance: Recorder = hass.data[DATA_INSTANCE]

    tiny_backlog = patch.object(recorder, "MAX_QUEUE_BACKLOG", 1)
    fast_check = patch.object(recorder, "DB_LOCK_QUEUE_CHECK_TIMEOUT", 0.1)
    with tiny_backlog, fast_check:
        await instance.lock_database()

        event_type = "EVENT_TEST"
        hass.bus.fire(event_type, {"test_attr": 5, "test_attr_10": "nice"})

        # Check that this causes the queue to overflow and write succeeds
        # even before unlocking.
        await async_wait_recording_done(hass, instance)

        with session_scope(hass=hass) as session:
            stored = list(session.query(Events).filter_by(event_type=event_type))
            assert len(stored) == 1

        # The lock was not held throughout, so unlocking reports failure.
        assert not instance.unlock_database()
async def test_database_lock_timeout(hass):
    """Test that locking times out when the recorder has been stopped."""
    await async_init_recorder_component(hass)

    hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
    await hass.async_block_till_done()

    recorder_instance: Recorder = hass.data[DATA_INSTANCE]
    with patch.object(recorder, "DB_LOCK_TIMEOUT", 0.1):
        try:
            with pytest.raises(TimeoutError):
                await recorder_instance.lock_database()
        finally:
            # Unlock in case locking unexpectedly succeeded.
            recorder_instance.unlock_database()

View File

@ -8,6 +8,7 @@ import pytest
from sqlalchemy import text
from sqlalchemy.sql.elements import TextClause
from homeassistant.components import recorder
from homeassistant.components.recorder import run_information_with_session, util
from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
from homeassistant.components.recorder.models import RecorderRuns
@ -556,3 +557,21 @@ def test_perodic_db_cleanups(hass_recorder):
][0]
assert isinstance(text_obj, TextClause)
assert str(text_obj) == "PRAGMA wal_checkpoint(TRUNCATE);"
async def test_write_lock_db(hass, tmp_path):
    """Test that writes from another connection fail while the write lock is held."""
    from sqlalchemy.exc import OperationalError

    # Use file DB, in memory DB cannot do write locks.
    db_url = "sqlite:///" + str(tmp_path / "pytest.db")
    await async_init_recorder_component(hass, {recorder.CONF_DB_URL: db_url})
    await hass.async_block_till_done()

    recorder_instance = hass.data[DATA_INSTANCE]

    with util.write_lock_db(recorder_instance):
        # Database should be locked now, try writing SQL command
        with recorder_instance.engine.connect() as connection, pytest.raises(
            OperationalError
        ):
            connection.execute(text("DROP TABLE events;"))

View File

@ -358,3 +358,62 @@ async def test_recorder_info_migration_queue_exhausted(hass, hass_ws_client):
assert response["result"]["migration_in_progress"] is False
assert response["result"]["recording"] is True
assert response["result"]["thread_running"] is True
async def test_backup_start_no_recorder(hass, hass_ws_client):
    """Test that backup/start is an unknown command when the recorder is not set up."""
    ws_client = await hass_ws_client()

    await ws_client.send_json({"id": 1, "type": "backup/start"})
    reply = await ws_client.receive_json()

    assert not reply["success"]
    assert reply["error"]["code"] == "unknown_command"
async def test_backup_start_timeout(hass, hass_ws_client):
    """Test that backup/start reports a timeout error when locking times out."""
    ws_client = await hass_ws_client()
    await async_init_recorder_component(hass)

    # Ensure there are no queued events
    await async_wait_recording_done_without_instance(hass)

    with patch.object(recorder, "DB_LOCK_TIMEOUT", 0):
        try:
            await ws_client.send_json({"id": 1, "type": "backup/start"})
            reply = await ws_client.receive_json()
            assert not reply["success"]
            assert reply["error"]["code"] == "timeout_error"
        finally:
            # Release the lock in case it was acquired after all.
            await ws_client.send_json({"id": 2, "type": "backup/end"})
async def test_backup_end(hass, hass_ws_client):
    """Test that backup/end succeeds after a successful backup/start."""
    ws_client = await hass_ws_client()
    await async_init_recorder_component(hass)

    # Ensure there are no queued events
    await async_wait_recording_done_without_instance(hass)

    await ws_client.send_json({"id": 1, "type": "backup/start"})
    start_reply = await ws_client.receive_json()
    assert start_reply["success"]

    await ws_client.send_json({"id": 2, "type": "backup/end"})
    end_reply = await ws_client.receive_json()
    assert end_reply["success"]
async def test_backup_end_without_start(hass, hass_ws_client):
    """Test that backup/end fails when no backup/start preceded it."""
    ws_client = await hass_ws_client()
    await async_init_recorder_component(hass)

    # Ensure there are no queued events
    await async_wait_recording_done_without_instance(hass)

    await ws_client.send_json({"id": 1, "type": "backup/end"})
    reply = await ws_client.receive_json()

    assert not reply["success"]
    assert reply["error"]["code"] == "database_unlock_failed"