From 9cd6bb73350114ed7cbe288b2ec8aa4dde6148c1 Mon Sep 17 00:00:00 2001
From: Erik Montnemery
Date: Fri, 4 Feb 2022 18:55:11 +0100
Subject: [PATCH] Don't use shared session during recorder migration (#65672)

---
 .../components/recorder/migration.py       | 297 ++++++++++--------
 tests/components/recorder/test_migrate.py  |  22 +-
 2 files changed, 176 insertions(+), 143 deletions(-)

diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py
index 32119b85597..b49aee29ba1 100644
--- a/homeassistant/components/recorder/migration.py
+++ b/homeassistant/components/recorder/migration.py
@@ -68,20 +68,18 @@ def schema_is_current(current_version):
 
 def migrate_schema(instance, current_version):
     """Check if the schema needs to be upgraded."""
-    with session_scope(session=instance.get_session()) as session:
-        _LOGGER.warning(
-            "Database is about to upgrade. Schema version: %s", current_version
-        )
-        for version in range(current_version, SCHEMA_VERSION):
-            new_version = version + 1
-            _LOGGER.info("Upgrading recorder db schema to version %s", new_version)
-            _apply_update(instance, session, new_version, current_version)
+    _LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
+    for version in range(current_version, SCHEMA_VERSION):
+        new_version = version + 1
+        _LOGGER.info("Upgrading recorder db schema to version %s", new_version)
+        _apply_update(instance, new_version, current_version)
+        with session_scope(session=instance.get_session()) as session:
             session.add(SchemaChanges(schema_version=new_version))
-            _LOGGER.info("Upgrade to version %s done", new_version)
+        _LOGGER.info("Upgrade to version %s done", new_version)
 
 
-def _create_index(connection, table_name, index_name):
+def _create_index(instance, table_name, index_name):
     """Create an index for the specified table.
 
     The index name should match the name given for the index
@@ -103,7 +101,9 @@
         index_name,
     )
     try:
-        index.create(connection)
+        with session_scope(session=instance.get_session()) as session:
+            connection = session.connection()
+            index.create(connection)
     except (InternalError, ProgrammingError, OperationalError) as err:
         raise_if_exception_missing_str(err, ["already exists", "duplicate"])
         _LOGGER.warning(
@@ -113,7 +113,7 @@
     _LOGGER.debug("Finished creating %s", index_name)
 
 
-def _drop_index(connection, table_name, index_name):
+def _drop_index(instance, table_name, index_name):
     """Drop an index from a specified table.
 
     There is no universal way to do something like `DROP INDEX IF EXISTS`
@@ -129,7 +129,9 @@ def _drop_index(connection, table_name, index_name):
 
     # Engines like DB2/Oracle
     try:
-        connection.execute(text(f"DROP INDEX {index_name}"))
+        with session_scope(session=instance.get_session()) as session:
+            connection = session.connection()
+            connection.execute(text(f"DROP INDEX {index_name}"))
     except SQLAlchemyError:
         pass
     else:
@@ -138,13 +140,15 @@
     # Engines like SQLite, SQL Server
     if not success:
         try:
-            connection.execute(
-                text(
-                    "DROP INDEX {table}.{index}".format(
-                        index=index_name, table=table_name
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(
+                    text(
+                        "DROP INDEX {table}.{index}".format(
+                            index=index_name, table=table_name
+                        )
                     )
                 )
-            )
         except SQLAlchemyError:
             pass
         else:
@@ -153,13 +157,15 @@
     if not success:
         # Engines like MySQL, MS Access
         try:
-            connection.execute(
-                text(
-                    "DROP INDEX {index} ON {table}".format(
-                        index=index_name, table=table_name
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(
+                    text(
+                        "DROP INDEX {index} ON {table}".format(
+                            index=index_name, table=table_name
+                        )
                    )
                )
-            )
         except SQLAlchemyError:
             pass
         else:
@@ -184,7 +190,7 @@
     )
 
 
-def _add_columns(connection, table_name, columns_def):
+def _add_columns(instance, table_name, columns_def):
     """Add columns to a table."""
     _LOGGER.warning(
         "Adding columns %s to table %s. Note: this can take several "
@@ -197,14 +203,16 @@
     columns_def = [f"ADD {col_def}" for col_def in columns_def]
 
     try:
-        connection.execute(
-            text(
-                "ALTER TABLE {table} {columns_def}".format(
-                    table=table_name, columns_def=", ".join(columns_def)
+        with session_scope(session=instance.get_session()) as session:
+            connection = session.connection()
+            connection.execute(
+                text(
+                    "ALTER TABLE {table} {columns_def}".format(
+                        table=table_name, columns_def=", ".join(columns_def)
+                    )
                 )
             )
-        )
-        return
+            return
    except (InternalError, OperationalError):
        # Some engines support adding all columns at once,
        # this error is when they don't
@@ -212,13 +220,15 @@
     for column_def in columns_def:
         try:
-            connection.execute(
-                text(
-                    "ALTER TABLE {table} {column_def}".format(
-                        table=table_name, column_def=column_def
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(
+                    text(
+                        "ALTER TABLE {table} {column_def}".format(
+                            table=table_name, column_def=column_def
+                        )
                    )
                )
-            )
         except (InternalError, OperationalError) as err:
             raise_if_exception_missing_str(err, ["already exists", "duplicate"])
             _LOGGER.warning(
@@ -228,7 +238,7 @@
     )
 
 
-def _modify_columns(connection, engine, table_name, columns_def):
+def _modify_columns(instance, engine, table_name, columns_def):
     """Modify columns in a table."""
     if engine.dialect.name == "sqlite":
         _LOGGER.debug(
@@ -261,33 +271,37 @@
         columns_def = [f"MODIFY {col_def}" for col_def in columns_def]
 
     try:
-        connection.execute(
-            text(
-                "ALTER TABLE {table} {columns_def}".format(
-                    table=table_name, columns_def=", ".join(columns_def)
+        with session_scope(session=instance.get_session()) as session:
+            connection = session.connection()
+            connection.execute(
+                text(
+                    "ALTER TABLE {table} {columns_def}".format(
+                        table=table_name, columns_def=", ".join(columns_def)
+                    )
                 )
             )
-        )
-        return
+            return
     except (InternalError, OperationalError):
         _LOGGER.info("Unable to use quick column modify. Modifying 1 by 1")
 
     for column_def in columns_def:
         try:
-            connection.execute(
-                text(
-                    "ALTER TABLE {table} {column_def}".format(
-                        table=table_name, column_def=column_def
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(
+                    text(
+                        "ALTER TABLE {table} {column_def}".format(
+                            table=table_name, column_def=column_def
+                        )
                    )
                )
-            )
         except (InternalError, OperationalError):
             _LOGGER.exception(
                 "Could not modify column %s in table %s", column_def, table_name
             )
 
 
-def _update_states_table_with_foreign_key_options(connection, engine):
+def _update_states_table_with_foreign_key_options(instance, engine):
     """Add the options to foreign key constraints."""
     inspector = sqlalchemy.inspect(engine)
     alters = []
@@ -316,17 +330,19 @@
 
     for alter in alters:
         try:
-            connection.execute(DropConstraint(alter["old_fk"]))
-            for fkc in states_key_constraints:
-                if fkc.column_keys == alter["columns"]:
-                    connection.execute(AddConstraint(fkc))
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(DropConstraint(alter["old_fk"]))
+                for fkc in states_key_constraints:
+                    if fkc.column_keys == alter["columns"]:
+                        connection.execute(AddConstraint(fkc))
         except (InternalError, OperationalError):
             _LOGGER.exception(
                 "Could not update foreign options in %s table", TABLE_STATES
             )
 
 
-def _drop_foreign_key_constraints(connection, engine, table, columns):
+def _drop_foreign_key_constraints(instance, engine, table, columns):
     """Drop foreign key constraints for a table on specific columns."""
     inspector = sqlalchemy.inspect(engine)
     drops = []
@@ -345,7 +361,9 @@
 
     for drop in drops:
         try:
-            connection.execute(DropConstraint(drop))
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(DropConstraint(drop))
         except (InternalError, OperationalError):
             _LOGGER.exception(
                 "Could not drop foreign constraints in %s table on %s",
@@ -354,17 +372,16 @@
             )
 
 
-def _apply_update(instance, session, new_version, old_version):  # noqa: C901
+def _apply_update(instance, new_version, old_version):  # noqa: C901
     """Perform operations to bring schema up to date."""
     engine = instance.engine
-    connection = session.connection()
     if new_version == 1:
-        _create_index(connection, "events", "ix_events_time_fired")
+        _create_index(instance, "events", "ix_events_time_fired")
     elif new_version == 2:
         # Create compound start/end index for recorder_runs
-        _create_index(connection, "recorder_runs", "ix_recorder_runs_start_end")
+        _create_index(instance, "recorder_runs", "ix_recorder_runs_start_end")
         # Create indexes for states
-        _create_index(connection, "states", "ix_states_last_updated")
+        _create_index(instance, "states", "ix_states_last_updated")
     elif new_version == 3:
         # There used to be a new index here, but it was removed in version 4.
         pass
@@ -374,41 +391,41 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
 
         if old_version == 3:
             # Remove index that was added in version 3
-            _drop_index(connection, "states", "ix_states_created_domain")
+            _drop_index(instance, "states", "ix_states_created_domain")
         if old_version == 2:
             # Remove index that was added in version 2
-            _drop_index(connection, "states", "ix_states_entity_id_created")
+            _drop_index(instance, "states", "ix_states_entity_id_created")
 
         # Remove indexes that were added in version 0
-        _drop_index(connection, "states", "states__state_changes")
-        _drop_index(connection, "states", "states__significant_changes")
-        _drop_index(connection, "states", "ix_states_entity_id_created")
+        _drop_index(instance, "states", "states__state_changes")
+        _drop_index(instance, "states", "states__significant_changes")
+        _drop_index(instance, "states", "ix_states_entity_id_created")
 
-        _create_index(connection, "states", "ix_states_entity_id_last_updated")
+        _create_index(instance, "states", "ix_states_entity_id_last_updated")
     elif new_version == 5:
         # Create supporting index for States.event_id foreign key
-        _create_index(connection, "states", "ix_states_event_id")
+        _create_index(instance, "states", "ix_states_event_id")
     elif new_version == 6:
         _add_columns(
-            session,
+            instance,
             "events",
             ["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"],
         )
-        _create_index(connection, "events", "ix_events_context_id")
-        _create_index(connection, "events", "ix_events_context_user_id")
+        _create_index(instance, "events", "ix_events_context_id")
+        _create_index(instance, "events", "ix_events_context_user_id")
         _add_columns(
-            connection,
+            instance,
             "states",
             ["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"],
         )
-        _create_index(connection, "states", "ix_states_context_id")
-        _create_index(connection, "states", "ix_states_context_user_id")
+        _create_index(instance, "states", "ix_states_context_id")
+        _create_index(instance, "states", "ix_states_context_user_id")
     elif new_version == 7:
-        _create_index(connection, "states", "ix_states_entity_id")
+        _create_index(instance, "states", "ix_states_entity_id")
     elif new_version == 8:
-        _add_columns(connection, "events", ["context_parent_id CHARACTER(36)"])
-        _add_columns(connection, "states", ["old_state_id INTEGER"])
-        _create_index(connection, "events", "ix_events_context_parent_id")
+        _add_columns(instance, "events", ["context_parent_id CHARACTER(36)"])
+        _add_columns(instance, "states", ["old_state_id INTEGER"])
+        _create_index(instance, "events", "ix_events_context_parent_id")
     elif new_version == 9:
         # We now get the context from events with a join
         # since its always there on state_changed events
@@ -418,36 +435,36 @@
         # and we would have to move to something like
         # sqlalchemy alembic to make that work
         #
-        _drop_index(connection, "states", "ix_states_context_id")
-        _drop_index(connection, "states", "ix_states_context_user_id")
+        _drop_index(instance, "states", "ix_states_context_id")
+        _drop_index(instance, "states", "ix_states_context_user_id")
         # This index won't be there if they were not running
         # nightly but we don't treat that as a critical issue
-        _drop_index(connection, "states", "ix_states_context_parent_id")
+        _drop_index(instance, "states", "ix_states_context_parent_id")
         # Redundant keys on composite index:
         # We already have ix_states_entity_id_last_updated
-        _drop_index(connection, "states", "ix_states_entity_id")
_create_index(connection, "events", "ix_events_event_type_time_fired") - _drop_index(connection, "events", "ix_events_event_type") + _drop_index(instance, "states", "ix_states_entity_id") + _create_index(instance, "events", "ix_events_event_type_time_fired") + _drop_index(instance, "events", "ix_events_event_type") elif new_version == 10: # Now done in step 11 pass elif new_version == 11: - _create_index(connection, "states", "ix_states_old_state_id") - _update_states_table_with_foreign_key_options(connection, engine) + _create_index(instance, "states", "ix_states_old_state_id") + _update_states_table_with_foreign_key_options(instance, engine) elif new_version == 12: if engine.dialect.name == "mysql": - _modify_columns(connection, engine, "events", ["event_data LONGTEXT"]) - _modify_columns(connection, engine, "states", ["attributes LONGTEXT"]) + _modify_columns(instance, engine, "events", ["event_data LONGTEXT"]) + _modify_columns(instance, engine, "states", ["attributes LONGTEXT"]) elif new_version == 13: if engine.dialect.name == "mysql": _modify_columns( - connection, + instance, engine, "events", ["time_fired DATETIME(6)", "created DATETIME(6)"], ) _modify_columns( - connection, + instance, engine, "states", [ @@ -457,14 +474,12 @@ def _apply_update(instance, session, new_version, old_version): # noqa: C901 ], ) elif new_version == 14: - _modify_columns(connection, engine, "events", ["event_type VARCHAR(64)"]) + _modify_columns(instance, engine, "events", ["event_type VARCHAR(64)"]) elif new_version == 15: # This dropped the statistics table, done again in version 18. pass elif new_version == 16: - _drop_foreign_key_constraints( - connection, engine, TABLE_STATES, ["old_state_id"] - ) + _drop_foreign_key_constraints(instance, engine, TABLE_STATES, ["old_state_id"]) elif new_version == 17: # This dropped the statistics table, done again in version 18. pass @@ -489,12 +504,13 @@ def _apply_update(instance, session, new_version, old_version): # noqa: C901 elif new_version == 19: # This adds the statistic runs table, insert a fake run to prevent duplicating # statistics. 
-        session.add(StatisticsRuns(start=get_start_time()))
+        with session_scope(session=instance.get_session()) as session:
+            session.add(StatisticsRuns(start=get_start_time()))
     elif new_version == 20:
         # This changed the precision of statistics from float to double
         if engine.dialect.name in ["mysql", "postgresql"]:
             _modify_columns(
-                connection,
+                instance,
                 engine,
                 "statistics",
                 [
@@ -516,14 +532,16 @@
                 table,
             )
             with contextlib.suppress(SQLAlchemyError):
-                connection.execute(
-                    # Using LOCK=EXCLUSIVE to prevent the database from corrupting
-                    # https://github.com/home-assistant/core/issues/56104
-                    text(
-                        f"ALTER TABLE {table} CONVERT TO "
-                        "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci LOCK=EXCLUSIVE"
+                with session_scope(session=instance.get_session()) as session:
+                    connection = session.connection()
+                    connection.execute(
+                        # Using LOCK=EXCLUSIVE to prevent the database from corrupting
+                        # https://github.com/home-assistant/core/issues/56104
+                        text(
+                            f"ALTER TABLE {table} CONVERT TO "
+                            "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci LOCK=EXCLUSIVE"
+                        )
                     )
-                )
     elif new_version == 22:
         # Recreate the all statistics tables for Oracle DB with Identity columns
         #
@@ -549,57 +567,64 @@
         # Block 5-minute statistics for one hour from the last run, or it will overlap
         # with existing hourly statistics. Don't block on a database with no existing
         # statistics.
-        if session.query(Statistics.id).count() and (
-            last_run_string := session.query(func.max(StatisticsRuns.start)).scalar()
-        ):
-            last_run_start_time = process_timestamp(last_run_string)
-            if last_run_start_time:
-                fake_start_time = last_run_start_time + timedelta(minutes=5)
-                while fake_start_time < last_run_start_time + timedelta(hours=1):
-                    session.add(StatisticsRuns(start=fake_start_time))
-                    fake_start_time += timedelta(minutes=5)
+        with session_scope(session=instance.get_session()) as session:
+            if session.query(Statistics.id).count() and (
+                last_run_string := session.query(
+                    func.max(StatisticsRuns.start)
+                ).scalar()
+            ):
+                last_run_start_time = process_timestamp(last_run_string)
+                if last_run_start_time:
+                    fake_start_time = last_run_start_time + timedelta(minutes=5)
+                    while fake_start_time < last_run_start_time + timedelta(hours=1):
+                        session.add(StatisticsRuns(start=fake_start_time))
+                        fake_start_time += timedelta(minutes=5)
 
         # When querying the database, be careful to only explicitly query for columns
         # which were present in schema version 21. If querying the table, SQLAlchemy
         # will refer to future columns.
-        for sum_statistic in session.query(StatisticsMeta.id).filter_by(has_sum=true()):
-            last_statistic = (
-                session.query(
-                    Statistics.start,
-                    Statistics.last_reset,
-                    Statistics.state,
-                    Statistics.sum,
-                )
-                .filter_by(metadata_id=sum_statistic.id)
-                .order_by(Statistics.start.desc())
-                .first()
-            )
-            if last_statistic:
-                session.add(
-                    StatisticsShortTerm(
-                        metadata_id=sum_statistic.id,
-                        start=last_statistic.start,
-                        last_reset=last_statistic.last_reset,
-                        state=last_statistic.state,
-                        sum=last_statistic.sum,
+        with session_scope(session=instance.get_session()) as session:
+            for sum_statistic in session.query(StatisticsMeta.id).filter_by(
+                has_sum=true()
+            ):
+                last_statistic = (
+                    session.query(
+                        Statistics.start,
+                        Statistics.last_reset,
+                        Statistics.state,
+                        Statistics.sum,
+                    )
+                    .filter_by(metadata_id=sum_statistic.id)
+                    .order_by(Statistics.start.desc())
+                    .first()
                 )
+                if last_statistic:
+                    session.add(
+                        StatisticsShortTerm(
+                            metadata_id=sum_statistic.id,
+                            start=last_statistic.start,
+                            last_reset=last_statistic.last_reset,
+                            state=last_statistic.state,
+                            sum=last_statistic.sum,
+                        )
+                    )
     elif new_version == 23:
         # Add name column to StatisticsMeta
-        _add_columns(session, "statistics_meta", ["name VARCHAR(255)"])
+        _add_columns(instance, "statistics_meta", ["name VARCHAR(255)"])
     elif new_version == 24:
         # Delete duplicated statistics
-        delete_duplicates(instance, session)
+        with session_scope(session=instance.get_session()) as session:
+            delete_duplicates(instance, session)
         # Recreate statistics indices to block duplicated statistics
-        _drop_index(connection, "statistics", "ix_statistics_statistic_id_start")
-        _create_index(connection, "statistics", "ix_statistics_statistic_id_start")
+        _drop_index(instance, "statistics", "ix_statistics_statistic_id_start")
+        _create_index(instance, "statistics", "ix_statistics_statistic_id_start")
         _drop_index(
-            connection,
+            instance,
             "statistics_short_term",
             "ix_statistics_short_term_statistic_id_start",
         )
         _create_index(
-            connection,
+            instance,
             "statistics_short_term",
             "ix_statistics_short_term_statistic_id_start",
         )
diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py
index 5c8a1c556c9..5e837eb36ac 100644
--- a/tests/components/recorder/test_migrate.py
+++ b/tests/components/recorder/test_migrate.py
@@ -5,7 +5,7 @@ import importlib
 import sqlite3
 import sys
 import threading
-from unittest.mock import ANY, Mock, PropertyMock, call, patch
+from unittest.mock import Mock, PropertyMock, call, patch
 
 import pytest
 from sqlalchemy import create_engine, text
@@ -57,7 +57,7 @@ async def test_schema_update_calls(hass):
     assert recorder.util.async_migration_in_progress(hass) is False
     update.assert_has_calls(
         [
-            call(hass.data[DATA_INSTANCE], ANY, version + 1, 0)
+            call(hass.data[DATA_INSTANCE], version + 1, 0)
             for version in range(0, models.SCHEMA_VERSION)
         ]
     )
@@ -309,7 +309,7 @@ async def test_schema_migrate(hass, start_version):
 def test_invalid_update():
     """Test that an invalid new version raises an exception."""
     with pytest.raises(ValueError):
-        migration._apply_update(Mock(), Mock(), -1, 0)
+        migration._apply_update(Mock(), -1, 0)
 
 
 @pytest.mark.parametrize(
@@ -324,9 +324,13 @@ def test_invalid_update():
 def test_modify_column(engine_type, substr):
     """Test that modify column generates the expected query."""
     connection = Mock()
+    session = Mock()
+    session.connection = Mock(return_value=connection)
+    instance = Mock()
+    instance.get_session = Mock(return_value=session)
     engine = Mock()
     engine.dialect.name = engine_type
-    migration._modify_columns(connection, engine, "events", ["event_type VARCHAR(64)"])
+    migration._modify_columns(instance, engine, "events", ["event_type VARCHAR(64)"])
     if substr:
         assert substr in connection.execute.call_args[0][0].text
     else:
         assert not connection.execute.called
@@ -338,8 +342,10 @@ def test_forgiving_add_column():
     """Test that add column will continue if column exists."""
     engine = create_engine("sqlite://", poolclass=StaticPool)
     with Session(engine) as session:
         session.execute(text("CREATE TABLE hello (id int)"))
-        migration._add_columns(session, "hello", ["context_id CHARACTER(36)"])
-        migration._add_columns(session, "hello", ["context_id CHARACTER(36)"])
+        instance = Mock()
+        instance.get_session = Mock(return_value=session)
+        migration._add_columns(instance, "hello", ["context_id CHARACTER(36)"])
+        migration._add_columns(instance, "hello", ["context_id CHARACTER(36)"])
 
 
 def test_forgiving_add_index():
@@ -347,7 +353,9 @@ def test_forgiving_add_index():
     """Test that add index will continue if index exists."""
     engine = create_engine("sqlite://", poolclass=StaticPool)
     models.Base.metadata.create_all(engine)
     with Session(engine) as session:
-        migration._create_index(session, "states", "ix_states_context_id")
+        instance = Mock()
+        instance.get_session = Mock(return_value=session)
+        migration._create_index(instance, "states", "ix_states_context_id")
 
 
 @pytest.mark.parametrize(