Ensure recorder always attempts clean shutdown if recorder thread raises (#91261)

* Ensure recorder run shutdown if the run loop raises

If anything goes wrong with the recorder, we should
still try to shut down cleanly.

* tweak

* tests

* tests

* handle migration failure

* tweak comment

* naming

* order

* order

* order

* reword

* adjust test

* fixes

* threading

* failure case

* fix test

* have to wait for stop because the task blocks on thread join
pull/89456/head^2
J. Nick Koston 2023-04-14 15:03:24 -10:00 committed by GitHub
parent 56cc6633f5
commit 1379ad60c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 124 additions and 17 deletions

View File

@ -444,10 +444,17 @@ class Recorder(threading.Thread):
async_at_started(self.hass, self._async_hass_started)
@callback
def async_connection_failed(self) -> None:
"""Connect failed tasks."""
self.async_db_connected.set_result(False)
self.async_db_ready.set_result(False)
def _async_startup_failed(self) -> None:
"""Report startup failure."""
# If a live migration failed, we were able to connect (async_db_connected
# marked True) and the database was marked ready (async_db_ready marked
# True), but the data in the queue cannot be written to the database because
# the schema is not in the correct format, so we must stop listeners and
# report failure.
if not self.async_db_connected.done():
self.async_db_connected.set_result(False)
if not self.async_db_ready.done():
self.async_db_ready.set_result(False)
persistent_notification.async_create(
self.hass,
"The recorder could not start, check [the logs](/config/logs)",
@ -645,19 +652,26 @@ class Recorder(threading.Thread):
return SHUTDOWN_TASK
def run(self) -> None:
"""Run the recorder thread."""
try:
self._run()
finally:
# Ensure shutdown happens cleanly if
# anything goes wrong in the run loop
self._shutdown()
def _run(self) -> None:
"""Start processing events to save."""
self.thread_id = threading.get_ident()
setup_result = self._setup_recorder()
if not setup_result:
# Give up if we could not connect
self.hass.add_job(self.async_connection_failed)
return
schema_status = migration.validate_db_schema(self.hass, self, self.get_session)
if schema_status is None:
# Give up if we could not validate the schema
self.hass.add_job(self.async_connection_failed)
return
self.schema_version = schema_status.current_version
@ -684,7 +698,6 @@ class Recorder(threading.Thread):
self.migration_in_progress = False
# Make sure we cleanly close the run if
# we restart before startup finishes
self._shutdown()
return
if not schema_status.valid:
@ -702,8 +715,6 @@ class Recorder(threading.Thread):
"Database Migration Failed",
"recorder_database_migration",
)
self.hass.add_job(self.async_set_db_ready)
self._shutdown()
return
if not database_was_ready:
@ -715,7 +726,6 @@ class Recorder(threading.Thread):
self._adjust_lru_size()
self.hass.add_job(self._async_set_recorder_ready_migration_done)
self._run_event_loop()
self._shutdown()
def _activate_and_set_db_ready(self) -> None:
"""Activate the table managers or schedule migrations and mark the db as ready."""
@ -1355,9 +1365,9 @@ class Recorder(threading.Thread):
def _close_connection(self) -> None:
"""Close the connection."""
assert self.engine is not None
self.engine.dispose()
self.engine = None
if self.engine:
self.engine.dispose()
self.engine = None
self._get_session = None
def _setup_run(self) -> None:
@ -1389,9 +1399,19 @@ class Recorder(threading.Thread):
def _shutdown(self) -> None:
"""Save end time for current run."""
_LOGGER.debug("Shutting down recorder")
self.hass.add_job(self._async_stop_listeners)
self._stop_executor()
if not self.schema_version or self.schema_version != SCHEMA_VERSION:
# If the schema version is not set, we never had a working
# connection to the database or the schema never reached a
# good state.
#
# In either case, we want to mark startup as failed.
#
self.hass.add_job(self._async_startup_failed)
else:
self.hass.add_job(self._async_stop_listeners)
try:
self._end_session()
finally:
self._stop_executor()
self._close_connection()

View File

@ -338,7 +338,6 @@ def test_state_changes_during_period_descending(
> hist_states[1].last_changed
> hist_states[2].last_changed
)
hist = history.state_changes_during_period(
hass,
start_time, # Pick a point where we will generate a start time state

View File

@ -8,7 +8,7 @@ from pathlib import Path
import sqlite3
import threading
from typing import cast
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
from freezegun.api import FrozenDateTimeFactory
import pytest
@ -27,6 +27,7 @@ from homeassistant.components.recorder import (
SQLITE_URL_PREFIX,
Recorder,
get_instance,
migration,
pool,
statistics,
)
@ -2239,3 +2240,90 @@ async def test_lru_increases_with_many_entities(
== mock_entity_count * 2
)
assert recorder_mock.states_meta_manager._id_map.get_size() == mock_entity_count * 2
async def test_clean_shutdown_when_recorder_thread_raises_during_initialize_database(
hass: HomeAssistant,
) -> None:
"""Test we still shutdown cleanly when the recorder thread raises during initialize_database."""
with patch.object(migration, "initialize_database", side_effect=Exception), patch(
"homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True
):
if recorder.DOMAIN not in hass.data:
recorder_helper.async_initialize_recorder(hass)
assert not await async_setup_component(
hass,
recorder.DOMAIN,
{
recorder.DOMAIN: {
CONF_DB_URL: "sqlite://",
CONF_DB_RETRY_WAIT: 0,
CONF_DB_MAX_RETRIES: 1,
}
},
)
await hass.async_block_till_done()
instance = recorder.get_instance(hass)
await hass.async_stop()
assert instance.engine is None
async def test_clean_shutdown_when_recorder_thread_raises_during_validate_db_schema(
hass: HomeAssistant,
) -> None:
"""Test we still shutdown cleanly when the recorder thread raises during validate_db_schema."""
with patch.object(migration, "validate_db_schema", side_effect=Exception), patch(
"homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True
):
if recorder.DOMAIN not in hass.data:
recorder_helper.async_initialize_recorder(hass)
assert not await async_setup_component(
hass,
recorder.DOMAIN,
{
recorder.DOMAIN: {
CONF_DB_URL: "sqlite://",
CONF_DB_RETRY_WAIT: 0,
CONF_DB_MAX_RETRIES: 1,
}
},
)
await hass.async_block_till_done()
instance = recorder.get_instance(hass)
await hass.async_stop()
assert instance.engine is None
async def test_clean_shutdown_when_schema_migration_fails(hass: HomeAssistant) -> None:
"""Test we still shutdown cleanly when schema migration fails."""
with patch.object(
migration,
"validate_db_schema",
return_value=MagicMock(valid=False, current_version=1),
), patch(
"homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True
), patch.object(
migration,
"migrate_schema",
side_effect=Exception,
):
if recorder.DOMAIN not in hass.data:
recorder_helper.async_initialize_recorder(hass)
assert await async_setup_component(
hass,
recorder.DOMAIN,
{
recorder.DOMAIN: {
CONF_DB_URL: "sqlite://",
CONF_DB_RETRY_WAIT: 0,
CONF_DB_MAX_RETRIES: 1,
}
},
)
await hass.async_block_till_done()
instance = recorder.get_instance(hass)
await hass.async_stop()
assert instance.engine is None