diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index fdb593ff27b..98f93f4e69a 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -43,11 +43,9 @@ from homeassistant.util.enum import try_parse_enum from . import migration, statistics from .const import ( - CONTEXT_ID_AS_BINARY_SCHEMA_VERSION, DB_WORKER_PREFIX, DOMAIN, ESTIMATED_QUEUE_ITEM_SIZE, - EVENT_TYPE_IDS_SCHEMA_VERSION, KEEPALIVE_TIME, LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION, MARIADB_PYMYSQL_URL_PREFIX, @@ -58,7 +56,6 @@ from .const import ( QUEUE_PERCENTAGE_ALLOWED_AVAILABLE_MEMORY, SQLITE_MAX_BIND_VARS, SQLITE_URL_PREFIX, - STATES_META_SCHEMA_VERSION, STATISTICS_ROWS_SCHEMA_VERSION, SupportedDialect, ) @@ -78,14 +75,15 @@ from .db_schema import ( StatisticsShortTerm, ) from .executor import DBInterruptibleThreadPoolExecutor +from .migration import ( + EntityIDMigration, + EventsContextIDMigration, + EventTypeIDMigration, + StatesContextIDMigration, +) from .models import DatabaseEngine, StatisticData, StatisticMetaData, UnsupportedDialect from .pool import POOL_SIZE, MutexPool, RecorderPool -from .queries import ( - has_entity_ids_to_migrate, - has_event_type_to_migrate, - has_events_context_ids_to_migrate, - has_states_context_ids_to_migrate, -) +from .queries import get_migration_changes from .table_managers.event_data import EventDataManager from .table_managers.event_types import EventTypeManager from .table_managers.recorder_runs import RecorderRunsManager @@ -101,17 +99,13 @@ from .tasks import ( CommitTask, CompileMissingStatisticsTask, DatabaseLockTask, - EntityIDMigrationTask, EntityIDPostMigrationTask, EventIdMigrationTask, - EventsContextIDMigrationTask, - EventTypeIDMigrationTask, ImportStatisticsTask, KeepAliveTask, PerodicCleanupTask, PurgeTask, RecorderTask, - StatesContextIDMigrationTask, StatisticsTask, StopTask, SynchronizeTask, @@ -783,44 +777,35 @@ class Recorder(threading.Thread): def _activate_and_set_db_ready(self) -> None: """Activate the table managers or schedule migrations and mark the db as ready.""" - with session_scope(session=self.get_session(), read_only=True) as session: + with session_scope(session=self.get_session()) as session: # Prime the statistics meta manager as soon as possible # since we want the frontend queries to avoid a thundering # herd of queries to find the statistics meta data if # there are a lot of statistics graphs on the frontend. - if self.schema_version >= STATISTICS_ROWS_SCHEMA_VERSION: + schema_version = self.schema_version + if schema_version >= STATISTICS_ROWS_SCHEMA_VERSION: self.statistics_meta_manager.load(session) - if ( - self.schema_version < CONTEXT_ID_AS_BINARY_SCHEMA_VERSION - or execute_stmt_lambda_element( - session, has_states_context_ids_to_migrate() - ) - ): - self.queue_task(StatesContextIDMigrationTask()) + migration_changes: dict[str, int] = { + row[0]: row[1] + for row in execute_stmt_lambda_element(session, get_migration_changes()) + } - if ( - self.schema_version < CONTEXT_ID_AS_BINARY_SCHEMA_VERSION - or execute_stmt_lambda_element( - session, has_events_context_ids_to_migrate() - ) - ): - self.queue_task(EventsContextIDMigrationTask()) + for migrator_cls in (StatesContextIDMigration, EventsContextIDMigration): + migrator = migrator_cls(session, schema_version, migration_changes) + if migrator.needs_migrate(): + self.queue_task(migrator.task()) - if ( - self.schema_version < EVENT_TYPE_IDS_SCHEMA_VERSION - or execute_stmt_lambda_element(session, has_event_type_to_migrate()) - ): - self.queue_task(EventTypeIDMigrationTask()) + migrator = EventTypeIDMigration(session, schema_version, migration_changes) + if migrator.needs_migrate(): + self.queue_task(migrator.task()) else: _LOGGER.debug("Activating event_types manager as all data is migrated") self.event_type_manager.active = True - if ( - self.schema_version < STATES_META_SCHEMA_VERSION - or execute_stmt_lambda_element(session, has_entity_ids_to_migrate()) - ): - self.queue_task(EntityIDMigrationTask()) + migrator = EntityIDMigration(session, schema_version, migration_changes) + if migrator.needs_migrate(): + self.queue_task(migrator.task()) else: _LOGGER.debug("Activating states_meta manager as all data is migrated") self.states_meta_manager.active = True diff --git a/homeassistant/components/recorder/db_schema.py b/homeassistant/components/recorder/db_schema.py index eb2e0b6ade3..6755e9c5c9b 100644 --- a/homeassistant/components/recorder/db_schema.py +++ b/homeassistant/components/recorder/db_schema.py @@ -84,6 +84,7 @@ TABLE_STATISTICS = "statistics" TABLE_STATISTICS_META = "statistics_meta" TABLE_STATISTICS_RUNS = "statistics_runs" TABLE_STATISTICS_SHORT_TERM = "statistics_short_term" +TABLE_MIGRATION_CHANGES = "migration_changes" STATISTICS_TABLES = ("statistics", "statistics_short_term") @@ -100,6 +101,7 @@ ALL_TABLES = [ TABLE_EVENT_TYPES, TABLE_RECORDER_RUNS, TABLE_SCHEMA_CHANGES, + TABLE_MIGRATION_CHANGES, TABLE_STATES_META, TABLE_STATISTICS, TABLE_STATISTICS_META, @@ -771,6 +773,15 @@ class RecorderRuns(Base): return self +class MigrationChanges(Base): + """Representation of migration changes.""" + + __tablename__ = TABLE_MIGRATION_CHANGES + + migration_id: Mapped[str] = mapped_column(String(255), primary_key=True) + version: Mapped[int] = mapped_column(SmallInteger) + + class SchemaChanges(Base): """Representation of schema version changes.""" diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 2d8a2976219..8395b88837c 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -2,6 +2,7 @@ from __future__ import annotations +from abc import ABC, abstractmethod from collections.abc import Callable, Iterable import contextlib from dataclasses import dataclass, replace as dataclass_replace @@ -25,6 +26,7 @@ from sqlalchemy.exc import ( from sqlalchemy.orm.session import Session from sqlalchemy.schema import AddConstraint, DropConstraint from sqlalchemy.sql.expression import true +from sqlalchemy.sql.lambdas import StatementLambdaElement from homeassistant.core import HomeAssistant from homeassistant.util.enum import try_parse_enum @@ -46,7 +48,12 @@ from .auto_repairs.statistics.schema import ( correct_db_schema as statistics_correct_db_schema, validate_db_schema as statistics_validate_db_schema, ) -from .const import SupportedDialect +from .const import ( + CONTEXT_ID_AS_BINARY_SCHEMA_VERSION, + EVENT_TYPE_IDS_SCHEMA_VERSION, + STATES_META_SCHEMA_VERSION, + SupportedDialect, +) from .db_schema import ( CONTEXT_ID_BIN_MAX_LENGTH, DOUBLE_PRECISION_TYPE_SQL, @@ -60,6 +67,7 @@ from .db_schema import ( Base, Events, EventTypes, + MigrationChanges, SchemaChanges, States, StatesMeta, @@ -80,6 +88,10 @@ from .queries import ( find_states_context_ids_to_migrate, find_unmigrated_short_term_statistics_rows, find_unmigrated_statistics_rows, + has_entity_ids_to_migrate, + has_event_type_to_migrate, + has_events_context_ids_to_migrate, + has_states_context_ids_to_migrate, has_used_states_event_ids, migrate_single_short_term_statistics_row_to_timestamp, migrate_single_statistics_row_to_timestamp, @@ -87,11 +99,17 @@ from .queries import ( from .statistics import get_start_time from .tasks import ( CommitTask, + EntityIDMigrationTask, + EventsContextIDMigrationTask, + EventTypeIDMigrationTask, PostSchemaMigrationTask, + RecorderTask, + StatesContextIDMigrationTask, StatisticsTimestampMigrationCleanupTask, ) from .util import ( database_job_retry_wrapper, + execute_stmt_lambda_element, get_index_by_name, retryable_database_job, session_scope, @@ -1478,7 +1496,8 @@ def migrate_states_context_ids(instance: Recorder) -> bool: ) # If there is more work to do return False # so that we can be called again - is_done = not states + if is_done := not states: + _mark_migration_done(session, StatesContextIDMigration) if is_done: _drop_index(session_maker, "states", "ix_states_context_id") @@ -1515,7 +1534,8 @@ def migrate_events_context_ids(instance: Recorder) -> bool: ) # If there is more work to do return False # so that we can be called again - is_done = not events + if is_done := not events: + _mark_migration_done(session, EventsContextIDMigration) if is_done: _drop_index(session_maker, "events", "ix_events_context_id") @@ -1580,7 +1600,8 @@ def migrate_event_type_ids(instance: Recorder) -> bool: # If there is more work to do return False # so that we can be called again - is_done = not events + if is_done := not events: + _mark_migration_done(session, EventTypeIDMigration) if is_done: instance.event_type_manager.active = True @@ -1654,7 +1675,8 @@ def migrate_entity_ids(instance: Recorder) -> bool: # If there is more work to do return False # so that we can be called again - is_done = not states + if is_done := not states: + _mark_migration_done(session, EntityIDMigration) _LOGGER.debug("Migrating entity_ids done=%s", is_done) return is_done @@ -1757,3 +1779,106 @@ def initialize_database(session_maker: Callable[[], Session]) -> bool: except Exception as err: # pylint: disable=broad-except _LOGGER.exception("Error when initialise database: %s", err) return False + + +class BaseRunTimeMigration(ABC): + """Base class for run time migrations.""" + + required_schema_version = 0 + migration_version = 1 + migration_id: str + task: Callable[[], RecorderTask] + + def __init__( + self, session: Session, schema_version: int, migration_changes: dict[str, int] + ) -> None: + """Initialize a new BaseRunTimeMigration.""" + self.schema_version = schema_version + self.session = session + self.migration_changes = migration_changes + + @abstractmethod + def needs_migrate_query(self) -> StatementLambdaElement: + """Return the query to check if the migration needs to run.""" + + def needs_migrate(self) -> bool: + """Return if the migration needs to run. + + If the migration needs to run, it will return True. + + If the migration does not need to run, it will return False and + mark the migration as done in the database if its not already + marked as done. + """ + if self.schema_version < self.required_schema_version: + # Schema is too old, we must have to migrate + return True + if self.migration_changes.get(self.migration_id, -1) >= self.migration_version: + # The migration changes table indicates that the migration has been done + return False + # We do not know if the migration is done from the + # migration changes table so we must check the data + # This is the slow path + if not execute_stmt_lambda_element(self.session, self.needs_migrate_query()): + _mark_migration_done(self.session, self.__class__) + return False + return True + + +class StatesContextIDMigration(BaseRunTimeMigration): + """Migration to migrate states context_ids to binary format.""" + + required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION + migration_id = "state_context_id_as_binary" + task = StatesContextIDMigrationTask + + def needs_migrate_query(self) -> StatementLambdaElement: + """Return the query to check if the migration needs to run.""" + return has_states_context_ids_to_migrate() + + +class EventsContextIDMigration(BaseRunTimeMigration): + """Migration to migrate events context_ids to binary format.""" + + required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION + migration_id = "event_context_id_as_binary" + task = EventsContextIDMigrationTask + + def needs_migrate_query(self) -> StatementLambdaElement: + """Return the query to check if the migration needs to run.""" + return has_events_context_ids_to_migrate() + + +class EventTypeIDMigration(BaseRunTimeMigration): + """Migration to migrate event_type to event_type_ids.""" + + required_schema_version = EVENT_TYPE_IDS_SCHEMA_VERSION + migration_id = "event_type_id_migration" + task = EventTypeIDMigrationTask + + def needs_migrate_query(self) -> StatementLambdaElement: + """Check if the data is migrated.""" + return has_event_type_to_migrate() + + +class EntityIDMigration(BaseRunTimeMigration): + """Migration to migrate entity_ids to states_meta.""" + + required_schema_version = STATES_META_SCHEMA_VERSION + migration_id = "entity_id_migration" + task = EntityIDMigrationTask + + def needs_migrate_query(self) -> StatementLambdaElement: + """Check if the data is migrated.""" + return has_entity_ids_to_migrate() + + +def _mark_migration_done( + session: Session, migration: type[BaseRunTimeMigration] +) -> None: + """Mark a migration as done in the database.""" + session.merge( + MigrationChanges( + migration_id=migration.migration_id, version=migration.migration_version + ) + ) diff --git a/homeassistant/components/recorder/queries.py b/homeassistant/components/recorder/queries.py index fdb15d2d49c..d982576620d 100644 --- a/homeassistant/components/recorder/queries.py +++ b/homeassistant/components/recorder/queries.py @@ -13,6 +13,7 @@ from .db_schema import ( EventData, Events, EventTypes, + MigrationChanges, RecorderRuns, StateAttributes, States, @@ -812,6 +813,13 @@ def find_states_context_ids_to_migrate(max_bind_vars: int) -> StatementLambdaEle ) +def get_migration_changes() -> StatementLambdaElement: + """Query the database for previous migration changes.""" + return lambda_stmt( + lambda: select(MigrationChanges.migration_id, MigrationChanges.version) + ) + + def find_event_types_to_purge() -> StatementLambdaElement: """Find event_type_ids to purge.""" return lambda_stmt( diff --git a/tests/components/recorder/common.py b/tests/components/recorder/common.py index 19ee449ae0b..816378c2f2e 100644 --- a/tests/components/recorder/common.py +++ b/tests/components/recorder/common.py @@ -20,7 +20,13 @@ from sqlalchemy.orm.session import Session from homeassistant import core as ha from homeassistant.components import recorder -from homeassistant.components.recorder import Recorder, core, get_instance, statistics +from homeassistant.components.recorder import ( + Recorder, + core, + get_instance, + migration, + statistics, +) from homeassistant.components.recorder.db_schema import ( Events, EventTypes, @@ -417,7 +423,7 @@ def old_db_schema(schema_version_postfix: str) -> Iterator[None]: core, "States", old_db_schema.States ), patch.object(core, "Events", old_db_schema.Events), patch.object( core, "StateAttributes", old_db_schema.StateAttributes - ), patch.object(core, "EntityIDMigrationTask", core.RecorderTask), patch( + ), patch.object(migration.EntityIDMigration, "task", core.RecorderTask), patch( CREATE_ENGINE_TARGET, new=partial( create_engine_test_for_schema_version_postfix, diff --git a/tests/components/recorder/test_migration_from_schema_32.py b/tests/components/recorder/test_migration_from_schema_32.py index 2e9a71a2a50..e9f51caaee2 100644 --- a/tests/components/recorder/test_migration_from_schema_32.py +++ b/tests/components/recorder/test_migration_from_schema_32.py @@ -22,7 +22,10 @@ from homeassistant.components.recorder.db_schema import ( StatesMeta, ) from homeassistant.components.recorder.models import process_timestamp -from homeassistant.components.recorder.queries import select_event_type_ids +from homeassistant.components.recorder.queries import ( + get_migration_changes, + select_event_type_ids, +) from homeassistant.components.recorder.tasks import ( EntityIDMigrationTask, EntityIDPostMigrationTask, @@ -30,7 +33,10 @@ from homeassistant.components.recorder.tasks import ( EventTypeIDMigrationTask, StatesContextIDMigrationTask, ) -from homeassistant.components.recorder.util import session_scope +from homeassistant.components.recorder.util import ( + execute_stmt_lambda_element, + session_scope, +) from homeassistant.core import HomeAssistant import homeassistant.util.dt as dt_util from homeassistant.util.ulid import bytes_to_ulid, ulid_at_time, ulid_to_bytes @@ -53,6 +59,11 @@ async def _async_wait_migration_done(hass: HomeAssistant) -> None: await async_recorder_block_till_done(hass) +def _get_migration_id(hass: HomeAssistant) -> dict[str, int]: + with session_scope(hass=hass, read_only=True) as session: + return dict(execute_stmt_lambda_element(session, get_migration_changes())) + + def _create_engine_test(*args, **kwargs): """Test version of create_engine that initializes with old schema. @@ -89,7 +100,7 @@ def db_schema_32(): core, "States", old_db_schema.States ), patch.object(core, "Events", old_db_schema.Events), patch.object( core, "StateAttributes", old_db_schema.StateAttributes - ), patch.object(core, "EntityIDMigrationTask", core.RecorderTask), patch( + ), patch.object(migration.EntityIDMigration, "task", core.RecorderTask), patch( CREATE_ENGINE_TARGET, new=_create_engine_test ): yield @@ -308,6 +319,12 @@ async def test_migrate_events_context_ids( event_with_garbage_context_id_no_time_fired_ts["context_parent_id_bin"] is None ) + migration_changes = await instance.async_add_executor_job(_get_migration_id, hass) + assert ( + migration_changes[migration.EventsContextIDMigration.migration_id] + == migration.EventsContextIDMigration.migration_version + ) + @pytest.mark.parametrize("enable_migrate_context_ids", [True]) async def test_migrate_states_context_ids( @@ -495,6 +512,12 @@ async def test_migrate_states_context_ids( == b"\n\xe2\x97\x99\xeeNOE\x81\x16\xf5\x82\xd7\xd3\xeee" ) + migration_changes = await instance.async_add_executor_job(_get_migration_id, hass) + assert ( + migration_changes[migration.StatesContextIDMigration.migration_id] + == migration.StatesContextIDMigration.migration_version + ) + @pytest.mark.parametrize("enable_migrate_event_type_ids", [True]) async def test_migrate_event_type_ids( @@ -578,6 +601,12 @@ async def test_migrate_event_type_ids( assert mapped["event_type_one"] is not None assert mapped["event_type_two"] is not None + migration_changes = await instance.async_add_executor_job(_get_migration_id, hass) + assert ( + migration_changes[migration.EventTypeIDMigration.migration_id] + == migration.EventTypeIDMigration.migration_version + ) + @pytest.mark.parametrize("enable_migrate_entity_ids", [True]) async def test_migrate_entity_ids( @@ -646,6 +675,12 @@ async def test_migrate_entity_ids( assert len(states_by_entity_id["sensor.two"]) == 2 assert len(states_by_entity_id["sensor.one"]) == 1 + migration_changes = await instance.async_add_executor_job(_get_migration_id, hass) + assert ( + migration_changes[migration.EntityIDMigration.migration_id] + == migration.EntityIDMigration.migration_version + ) + @pytest.mark.parametrize("enable_migrate_entity_ids", [True]) async def test_post_migrate_entity_ids( @@ -771,6 +806,16 @@ async def test_migrate_null_entity_ids( assert len(states_by_entity_id[migration._EMPTY_ENTITY_ID]) == 1000 assert len(states_by_entity_id["sensor.one"]) == 2 + def _get_migration_id(): + with session_scope(hass=hass, read_only=True) as session: + return dict(execute_stmt_lambda_element(session, get_migration_changes())) + + migration_changes = await instance.async_add_executor_job(_get_migration_id) + assert ( + migration_changes[migration.EntityIDMigration.migration_id] + == migration.EntityIDMigration.migration_version + ) + @pytest.mark.parametrize("enable_migrate_event_type_ids", [True]) async def test_migrate_null_event_type_ids( @@ -847,6 +892,16 @@ async def test_migrate_null_event_type_ids( assert len(events_by_type["event_type_one"]) == 2 assert len(events_by_type[migration._EMPTY_EVENT_TYPE]) == 1000 + def _get_migration_id(): + with session_scope(hass=hass, read_only=True) as session: + return dict(execute_stmt_lambda_element(session, get_migration_changes())) + + migration_changes = await instance.async_add_executor_job(_get_migration_id) + assert ( + migration_changes[migration.EventTypeIDMigration.migration_id] + == migration.EventTypeIDMigration.migration_version + ) + async def test_stats_timestamp_conversion_is_reentrant( async_setup_recorder_instance: RecorderInstanceGenerator, diff --git a/tests/components/recorder/test_migration_run_time_migrations_remember.py b/tests/components/recorder/test_migration_run_time_migrations_remember.py new file mode 100644 index 00000000000..770d5d684a9 --- /dev/null +++ b/tests/components/recorder/test_migration_run_time_migrations_remember.py @@ -0,0 +1,163 @@ +"""Test run time migrations are remembered in the migration_changes table.""" + +import importlib +from pathlib import Path +import sys +from unittest.mock import patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from homeassistant.components import recorder +from homeassistant.components.recorder import core, migration, statistics +from homeassistant.components.recorder.queries import get_migration_changes +from homeassistant.components.recorder.tasks import StatesContextIDMigrationTask +from homeassistant.components.recorder.util import ( + execute_stmt_lambda_element, + session_scope, +) +from homeassistant.const import EVENT_HOMEASSISTANT_STOP +from homeassistant.core import HomeAssistant + +from .common import async_recorder_block_till_done, async_wait_recording_done + +from tests.common import async_test_home_assistant +from tests.typing import RecorderInstanceGenerator + +CREATE_ENGINE_TARGET = "homeassistant.components.recorder.core.create_engine" +SCHEMA_MODULE = "tests.components.recorder.db_schema_32" + + +async def _async_wait_migration_done(hass: HomeAssistant) -> None: + """Wait for the migration to be done.""" + await recorder.get_instance(hass).async_block_till_done() + await async_recorder_block_till_done(hass) + + +def _get_migration_id(hass: HomeAssistant) -> dict[str, int]: + with session_scope(hass=hass, read_only=True) as session: + return dict(execute_stmt_lambda_element(session, get_migration_changes())) + + +def _create_engine_test(*args, **kwargs): + """Test version of create_engine that initializes with old schema. + + This simulates an existing db with the old schema. + """ + importlib.import_module(SCHEMA_MODULE) + old_db_schema = sys.modules[SCHEMA_MODULE] + engine = create_engine(*args, **kwargs) + old_db_schema.Base.metadata.create_all(engine) + with Session(engine) as session: + session.add( + recorder.db_schema.StatisticsRuns(start=statistics.get_start_time()) + ) + session.add( + recorder.db_schema.SchemaChanges( + schema_version=old_db_schema.SCHEMA_VERSION + ) + ) + session.commit() + return engine + + +@pytest.mark.parametrize("enable_migrate_context_ids", [True]) +async def test_migration_changes_prevent_trying_to_migrate_again( + async_setup_recorder_instance: RecorderInstanceGenerator, + tmp_path: Path, + recorder_db_url: str, +) -> None: + """Test that we do not try to migrate when migration_changes indicate its already migrated. + + This test will start Home Assistant 3 times: + + 1. With schema 32 to populate the data + 2. With current schema so the migration happens + 3. With current schema to verify we do not have to query to see if the migration is done + """ + if recorder_db_url.startswith(("mysql://", "postgresql://")): + # This test uses a test database between runs so its + # SQLite specific + return + + config = { + recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db"), + recorder.CONF_COMMIT_INTERVAL: 1, + } + importlib.import_module(SCHEMA_MODULE) + old_db_schema = sys.modules[SCHEMA_MODULE] + + # Start with db schema that needs migration (version 32) + with patch.object(recorder, "db_schema", old_db_schema), patch.object( + recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION + ), patch.object(core, "StatesMeta", old_db_schema.StatesMeta), patch.object( + core, "EventTypes", old_db_schema.EventTypes + ), patch.object(core, "EventData", old_db_schema.EventData), patch.object( + core, "States", old_db_schema.States + ), patch.object(core, "Events", old_db_schema.Events), patch.object( + core, "StateAttributes", old_db_schema.StateAttributes + ), patch.object(migration.EntityIDMigration, "task", core.RecorderTask), patch( + CREATE_ENGINE_TARGET, new=_create_engine_test + ): + async with async_test_home_assistant() as hass: + await async_setup_recorder_instance(hass, config) + await hass.async_block_till_done() + await async_wait_recording_done(hass) + await _async_wait_migration_done(hass) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + await hass.async_block_till_done() + await hass.async_stop() + + # Now start again with current db schema + async with async_test_home_assistant() as hass: + await async_setup_recorder_instance(hass, config) + await hass.async_block_till_done() + await async_wait_recording_done(hass) + await _async_wait_migration_done(hass) + instance = recorder.get_instance(hass) + migration_changes = await instance.async_add_executor_job( + _get_migration_id, hass + ) + assert ( + migration_changes[migration.StatesContextIDMigration.migration_id] + == migration.StatesContextIDMigration.migration_version + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + await hass.async_block_till_done() + await hass.async_stop() + + original_queue_task = core.Recorder.queue_task + tasks = [] + + def _queue_task(self, task): + tasks.append(task) + original_queue_task(self, task) + + # Finally verify we did not call needs_migrate_query on StatesContextIDMigration + async with async_test_home_assistant() as hass: + with patch( + "homeassistant.components.recorder.core.Recorder.queue_task", _queue_task + ), patch.object( + migration.StatesContextIDMigration, + "needs_migrate_query", + side_effect=RuntimeError("Should not be called"), + ): + await async_setup_recorder_instance(hass, config) + await hass.async_block_till_done() + await async_wait_recording_done(hass) + await _async_wait_migration_done(hass) + instance = recorder.get_instance(hass) + migration_changes = await instance.async_add_executor_job( + _get_migration_id, hass + ) + assert ( + migration_changes[migration.StatesContextIDMigration.migration_id] + == migration.StatesContextIDMigration.migration_version + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + await hass.async_block_till_done() + await hass.async_stop() + + for task in tasks: + assert not isinstance(task, StatesContextIDMigrationTask)