diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index da3955cb9b8..8a907a8d9fa 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio from collections.abc import Callable, Iterable import concurrent.futures +from dataclasses import dataclass from datetime import datetime, timedelta import logging import queue @@ -76,6 +77,7 @@ from .util import ( session_scope, setup_connection_for_dialect, validate_or_move_away_sqlite_database, + write_lock_db, ) _LOGGER = logging.getLogger(__name__) @@ -123,6 +125,9 @@ KEEPALIVE_TIME = 30 # States and Events objects EXPIRE_AFTER_COMMITS = 120 +DB_LOCK_TIMEOUT = 30 +DB_LOCK_QUEUE_CHECK_TIMEOUT = 1 + CONF_AUTO_PURGE = "auto_purge" CONF_DB_URL = "db_url" CONF_DB_MAX_RETRIES = "db_max_retries" @@ -370,6 +375,15 @@ class WaitTask: """An object to insert into the recorder queue to tell it set the _queue_watch event.""" +@dataclass +class DatabaseLockTask: + """An object to insert into the recorder queue to prevent writes to the database.""" + + database_locked: asyncio.Event + database_unlock: threading.Event + queue_overflow: bool + + class Recorder(threading.Thread): """A threaded recorder class.""" @@ -419,6 +433,7 @@ class Recorder(threading.Thread): self.migration_in_progress = False self._queue_watcher = None self._db_supports_row_number = True + self._database_lock_task: DatabaseLockTask | None = None self.enabled = True @@ -687,6 +702,8 @@ class Recorder(threading.Thread): def _process_one_event_or_recover(self, event): """Process an event, reconnect, or recover a malformed database.""" try: + if self._process_one_task(event): + return self._process_one_event(event) return except exc.DatabaseError as err: @@ -788,34 +805,63 @@ class Recorder(threading.Thread): # Schedule a new statistics task if this one didn't finish self.queue.put(ExternalStatisticsTask(metadata, stats)) - def _process_one_event(self, event): + def _lock_database(self, task: DatabaseLockTask): + @callback + def _async_set_database_locked(task: DatabaseLockTask): + task.database_locked.set() + + with write_lock_db(self): + # Notify that lock is being held, wait until database can be used again. + self.hass.add_job(_async_set_database_locked, task) + while not task.database_unlock.wait(timeout=DB_LOCK_QUEUE_CHECK_TIMEOUT): + if self.queue.qsize() > MAX_QUEUE_BACKLOG * 0.9: + _LOGGER.warning( + "Database queue backlog reached more than 90% of maximum queue " + "length while waiting for backup to finish; recorder will now " + "resume writing to database. The backup can not be trusted and " + "must be restarted" + ) + task.queue_overflow = True + break + _LOGGER.info( + "Database queue backlog reached %d entries during backup", + self.queue.qsize(), + ) + + def _process_one_task(self, event) -> bool: """Process one event.""" if isinstance(event, PurgeTask): self._run_purge(event.purge_before, event.repack, event.apply_filter) - return + return True if isinstance(event, PurgeEntitiesTask): self._run_purge_entities(event.entity_filter) - return + return True if isinstance(event, PerodicCleanupTask): perodic_db_cleanups(self) - return + return True if isinstance(event, StatisticsTask): self._run_statistics(event.start) - return + return True if isinstance(event, ClearStatisticsTask): statistics.clear_statistics(self, event.statistic_ids) - return + return True if isinstance(event, UpdateStatisticsMetadataTask): statistics.update_statistics_metadata( self, event.statistic_id, event.unit_of_measurement ) - return + return True if isinstance(event, ExternalStatisticsTask): self._run_external_statistics(event.metadata, event.statistics) - return + return True if isinstance(event, WaitTask): self._queue_watch.set() - return + return True + if isinstance(event, DatabaseLockTask): + self._lock_database(event) + return True + return False + + def _process_one_event(self, event): if event.event_type == EVENT_TIME_CHANGED: self._keepalive_count += 1 if self._keepalive_count >= KEEPALIVE_TIME: @@ -982,6 +1028,42 @@ class Recorder(threading.Thread): self.queue.put(WaitTask()) self._queue_watch.wait() + async def lock_database(self) -> bool: + """Lock database so it can be backed up safely.""" + if self._database_lock_task: + _LOGGER.warning("Database already locked") + return False + + database_locked = asyncio.Event() + task = DatabaseLockTask(database_locked, threading.Event(), False) + self.queue.put(task) + try: + await asyncio.wait_for(database_locked.wait(), timeout=DB_LOCK_TIMEOUT) + except asyncio.TimeoutError as err: + task.database_unlock.set() + raise TimeoutError( + f"Could not lock database within {DB_LOCK_TIMEOUT} seconds." + ) from err + self._database_lock_task = task + return True + + @callback + def unlock_database(self) -> bool: + """Unlock database. + + Returns true if database lock has been held throughout the process. + """ + if not self._database_lock_task: + _LOGGER.warning("Database currently not locked") + return False + + self._database_lock_task.database_unlock.set() + success = not self._database_lock_task.queue_overflow + + self._database_lock_task = None + + return success + def _setup_connection(self): """Ensure database is ready to fly.""" kwargs = {} diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index c63f6abee3a..3900641db63 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -457,6 +457,25 @@ def perodic_db_cleanups(instance: Recorder): connection.execute(text("PRAGMA wal_checkpoint(TRUNCATE);")) +@contextmanager +def write_lock_db(instance: Recorder): + """Lock database for writes.""" + + if instance.engine.dialect.name == "sqlite": + with instance.engine.connect() as connection: + # Execute sqlite to create a wal checkpoint + # This is optional but makes sure the backup is going to be minimal + connection.execute(text("PRAGMA wal_checkpoint(TRUNCATE)")) + # Create write lock + _LOGGER.debug("Lock database") + connection.execute(text("BEGIN IMMEDIATE;")) + try: + yield + finally: + _LOGGER.debug("Unlock database") + connection.execute(text("END;")) + + def async_migration_in_progress(hass: HomeAssistant) -> bool: """Determine is a migration is in progress. diff --git a/homeassistant/components/recorder/websocket_api.py b/homeassistant/components/recorder/websocket_api.py index 5a4f0425919..f6d4d57a7e5 100644 --- a/homeassistant/components/recorder/websocket_api.py +++ b/homeassistant/components/recorder/websocket_api.py @@ -1,6 +1,7 @@ """The Energy websocket API.""" from __future__ import annotations +import logging from typing import TYPE_CHECKING import voluptuous as vol @@ -15,6 +16,8 @@ from .util import async_migration_in_progress if TYPE_CHECKING: from . import Recorder +_LOGGER: logging.Logger = logging.getLogger(__package__) + @callback def async_setup(hass: HomeAssistant) -> None: @@ -23,6 +26,8 @@ def async_setup(hass: HomeAssistant) -> None: websocket_api.async_register_command(hass, ws_clear_statistics) websocket_api.async_register_command(hass, ws_update_statistics_metadata) websocket_api.async_register_command(hass, ws_info) + websocket_api.async_register_command(hass, ws_backup_start) + websocket_api.async_register_command(hass, ws_backup_end) @websocket_api.websocket_command( @@ -106,3 +111,38 @@ def ws_info( "thread_running": thread_alive, } connection.send_result(msg["id"], recorder_info) + + +@websocket_api.require_admin +@websocket_api.websocket_command({vol.Required("type"): "backup/start"}) +@websocket_api.async_response +async def ws_backup_start( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict +) -> None: + """Backup start notification.""" + + _LOGGER.info("Backup start notification, locking database for writes") + instance: Recorder = hass.data[DATA_INSTANCE] + try: + await instance.lock_database() + except TimeoutError as err: + connection.send_error(msg["id"], "timeout_error", str(err)) + return + connection.send_result(msg["id"]) + + +@websocket_api.require_admin +@websocket_api.websocket_command({vol.Required("type"): "backup/end"}) +@websocket_api.async_response +async def ws_backup_end( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict +) -> None: + """Backup end notification.""" + + instance: Recorder = hass.data[DATA_INSTANCE] + _LOGGER.info("Backup end notification, releasing write lock") + if not instance.unlock_database(): + connection.send_error( + msg["id"], "database_unlock_failed", "Failed to unlock database." + ) + connection.send_result(msg["id"]) diff --git a/tests/common.py b/tests/common.py index 55c76e953cd..9d4a9cfe366 100644 --- a/tests/common.py +++ b/tests/common.py @@ -902,8 +902,9 @@ def init_recorder_component(hass, add_config=None): async def async_init_recorder_component(hass, add_config=None): """Initialize the recorder asynchronously.""" - config = dict(add_config) if add_config else {} - config[recorder.CONF_DB_URL] = "sqlite://" + config = add_config or {} + if recorder.CONF_DB_URL not in config: + config[recorder.CONF_DB_URL] = "sqlite://" with patch("homeassistant.components.recorder.migration.migrate_schema"): assert await async_setup_component( diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index e41a0da34ba..7d7c3f27fb6 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -1,5 +1,6 @@ """The tests for the Recorder component.""" # pylint: disable=protected-access +import asyncio from datetime import datetime, timedelta import sqlite3 from unittest.mock import patch @@ -1134,3 +1135,81 @@ def test_entity_id_filter(hass_recorder): db_events = list(session.query(Events).filter_by(event_type="hello")) # Keep referring idx + 1, as no new events are being added assert len(db_events) == idx + 1, data + + +async def test_database_lock_and_unlock(hass: HomeAssistant, tmp_path): + """Test writing events during lock getting written after unlocking.""" + # Use file DB, in memory DB cannot do write locks. + config = {recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db")} + await async_init_recorder_component(hass, config) + await hass.async_block_till_done() + + instance: Recorder = hass.data[DATA_INSTANCE] + + assert await instance.lock_database() + + assert not await instance.lock_database() + + event_type = "EVENT_TEST" + event_data = {"test_attr": 5, "test_attr_10": "nice"} + hass.bus.fire(event_type, event_data) + task = asyncio.create_task(async_wait_recording_done(hass, instance)) + + # Recording can't be finished while lock is held + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(asyncio.shield(task), timeout=1) + + with session_scope(hass=hass) as session: + db_events = list(session.query(Events).filter_by(event_type=event_type)) + assert len(db_events) == 0 + + assert instance.unlock_database() + + await task + with session_scope(hass=hass) as session: + db_events = list(session.query(Events).filter_by(event_type=event_type)) + assert len(db_events) == 1 + + +async def test_database_lock_and_overflow(hass: HomeAssistant, tmp_path): + """Test writing events during lock leading to overflow the queue causes the database to unlock.""" + # Use file DB, in memory DB cannot do write locks. + config = {recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db")} + await async_init_recorder_component(hass, config) + await hass.async_block_till_done() + + instance: Recorder = hass.data[DATA_INSTANCE] + + with patch.object(recorder, "MAX_QUEUE_BACKLOG", 1), patch.object( + recorder, "DB_LOCK_QUEUE_CHECK_TIMEOUT", 0.1 + ): + await instance.lock_database() + + event_type = "EVENT_TEST" + event_data = {"test_attr": 5, "test_attr_10": "nice"} + hass.bus.fire(event_type, event_data) + + # Check that this causes the queue to overflow and write succeeds + # even before unlocking. + await async_wait_recording_done(hass, instance) + + with session_scope(hass=hass) as session: + db_events = list(session.query(Events).filter_by(event_type=event_type)) + assert len(db_events) == 1 + + assert not instance.unlock_database() + + +async def test_database_lock_timeout(hass): + """Test locking database timeout when recorder stopped.""" + await async_init_recorder_component(hass) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + await hass.async_block_till_done() + + instance: Recorder = hass.data[DATA_INSTANCE] + with patch.object(recorder, "DB_LOCK_TIMEOUT", 0.1): + try: + with pytest.raises(TimeoutError): + await instance.lock_database() + finally: + instance.unlock_database() diff --git a/tests/components/recorder/test_util.py b/tests/components/recorder/test_util.py index 940925c48ca..fa449aefefc 100644 --- a/tests/components/recorder/test_util.py +++ b/tests/components/recorder/test_util.py @@ -8,6 +8,7 @@ import pytest from sqlalchemy import text from sqlalchemy.sql.elements import TextClause +from homeassistant.components import recorder from homeassistant.components.recorder import run_information_with_session, util from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX from homeassistant.components.recorder.models import RecorderRuns @@ -556,3 +557,21 @@ def test_perodic_db_cleanups(hass_recorder): ][0] assert isinstance(text_obj, TextClause) assert str(text_obj) == "PRAGMA wal_checkpoint(TRUNCATE);" + + +async def test_write_lock_db(hass, tmp_path): + """Test database write lock.""" + from sqlalchemy.exc import OperationalError + + # Use file DB, in memory DB cannot do write locks. + config = {recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db")} + await async_init_recorder_component(hass, config) + await hass.async_block_till_done() + + instance = hass.data[DATA_INSTANCE] + + with util.write_lock_db(instance): + # Database should be locked now, try writing SQL command + with instance.engine.connect() as connection: + with pytest.raises(OperationalError): + connection.execute(text("DROP TABLE events;")) diff --git a/tests/components/recorder/test_websocket_api.py b/tests/components/recorder/test_websocket_api.py index 7a45dea0379..994d1c677af 100644 --- a/tests/components/recorder/test_websocket_api.py +++ b/tests/components/recorder/test_websocket_api.py @@ -358,3 +358,62 @@ async def test_recorder_info_migration_queue_exhausted(hass, hass_ws_client): assert response["result"]["migration_in_progress"] is False assert response["result"]["recording"] is True assert response["result"]["thread_running"] is True + + +async def test_backup_start_no_recorder(hass, hass_ws_client): + """Test getting backup start when recorder is not present.""" + client = await hass_ws_client() + + await client.send_json({"id": 1, "type": "backup/start"}) + response = await client.receive_json() + assert not response["success"] + assert response["error"]["code"] == "unknown_command" + + +async def test_backup_start_timeout(hass, hass_ws_client): + """Test getting backup start when recorder is not present.""" + client = await hass_ws_client() + await async_init_recorder_component(hass) + + # Ensure there are no queued events + await async_wait_recording_done_without_instance(hass) + + with patch.object(recorder, "DB_LOCK_TIMEOUT", 0): + try: + await client.send_json({"id": 1, "type": "backup/start"}) + response = await client.receive_json() + assert not response["success"] + assert response["error"]["code"] == "timeout_error" + finally: + await client.send_json({"id": 2, "type": "backup/end"}) + + +async def test_backup_end(hass, hass_ws_client): + """Test backup start.""" + client = await hass_ws_client() + await async_init_recorder_component(hass) + + # Ensure there are no queued events + await async_wait_recording_done_without_instance(hass) + + await client.send_json({"id": 1, "type": "backup/start"}) + response = await client.receive_json() + assert response["success"] + + await client.send_json({"id": 2, "type": "backup/end"}) + response = await client.receive_json() + assert response["success"] + + +async def test_backup_end_without_start(hass, hass_ws_client): + """Test backup start.""" + client = await hass_ws_client() + await async_init_recorder_component(hass) + + # Ensure there are no queued events + await async_wait_recording_done_without_instance(hass) + + await client.send_json({"id": 1, "type": "backup/end"}) + response = await client.receive_json() + assert not response["success"] + assert response["error"]["code"] == "database_unlock_failed"