Don't use shared session during recorder migration (#65672)

Erik Montnemery 2022-02-04 18:55:11 +01:00 committed by Paulus Schoutsen
parent 4e3cd1471a
commit 9cd6bb7335
2 changed files with 176 additions and 143 deletions
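Every hunk below follows the same pattern: a migration helper that used to receive a shared connection or session now receives the recorder instance and opens its own short-lived session through the recorder's session_scope helper, deriving a raw connection from it with session.connection() where it needs to execute SQL. As a rough sketch of what session_scope provides (an illustration of the transactional-scope pattern, not the recorder's exact implementation):

from contextlib import contextmanager

@contextmanager
def session_scope(*, session):
    """Provide a transactional scope around a block of operations."""
    try:
        yield session
        session.commit()  # persist the block's work if it ran cleanly
    except Exception:
        session.rollback()  # discard partial work on any error
        raise
    finally:
        session.close()

Because each operation now commits or rolls back on its own, a failure partway through the migration no longer leaves a long-lived shared session in an inconsistent state.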

@@ -68,20 +68,18 @@ def schema_is_current(current_version):

 def migrate_schema(instance, current_version):
     """Check if the schema needs to be upgraded."""
-    with session_scope(session=instance.get_session()) as session:
-        _LOGGER.warning(
-            "Database is about to upgrade. Schema version: %s", current_version
-        )
-        for version in range(current_version, SCHEMA_VERSION):
-            new_version = version + 1
-            _LOGGER.info("Upgrading recorder db schema to version %s", new_version)
-            _apply_update(instance, session, new_version, current_version)
+    _LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
+    for version in range(current_version, SCHEMA_VERSION):
+        new_version = version + 1
+        _LOGGER.info("Upgrading recorder db schema to version %s", new_version)
+        _apply_update(instance, new_version, current_version)
+        with session_scope(session=instance.get_session()) as session:
             session.add(SchemaChanges(schema_version=new_version))
-            _LOGGER.info("Upgrade to version %s done", new_version)
+        _LOGGER.info("Upgrade to version %s done", new_version)


-def _create_index(connection, table_name, index_name):
+def _create_index(instance, table_name, index_name):
     """Create an index for the specified table.

     The index name should match the name given for the index
@@ -103,7 +101,9 @@ def _create_index(connection, table_name, index_name):
         index_name,
     )
     try:
-        index.create(connection)
+        with session_scope(session=instance.get_session()) as session:
+            connection = session.connection()
+            index.create(connection)
     except (InternalError, ProgrammingError, OperationalError) as err:
         raise_if_exception_missing_str(err, ["already exists", "duplicate"])
         _LOGGER.warning(
@@ -113,7 +113,7 @@ def _create_index(connection, table_name, index_name):
     _LOGGER.debug("Finished creating %s", index_name)


-def _drop_index(connection, table_name, index_name):
+def _drop_index(instance, table_name, index_name):
     """Drop an index from a specified table.

     There is no universal way to do something like `DROP INDEX IF EXISTS`
@@ -129,7 +129,9 @@ def _drop_index(connection, table_name, index_name):
     # Engines like DB2/Oracle
     try:
-        connection.execute(text(f"DROP INDEX {index_name}"))
+        with session_scope(session=instance.get_session()) as session:
+            connection = session.connection()
+            connection.execute(text(f"DROP INDEX {index_name}"))
     except SQLAlchemyError:
         pass
     else:
@@ -138,13 +140,15 @@ def _drop_index(connection, table_name, index_name):
     # Engines like SQLite, SQL Server
     if not success:
         try:
-            connection.execute(
-                text(
-                    "DROP INDEX {table}.{index}".format(
-                        index=index_name, table=table_name
-                    )
-                )
-            )
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(
+                    text(
+                        "DROP INDEX {table}.{index}".format(
+                            index=index_name, table=table_name
+                        )
+                    )
+                )
         except SQLAlchemyError:
             pass
         else:
@@ -153,13 +157,15 @@ def _drop_index(connection, table_name, index_name):
     if not success:
         # Engines like MySQL, MS Access
        try:
-            connection.execute(
-                text(
-                    "DROP INDEX {index} ON {table}".format(
-                        index=index_name, table=table_name
-                    )
-                )
-            )
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(
+                    text(
+                        "DROP INDEX {index} ON {table}".format(
+                            index=index_name, table=table_name
+                        )
+                    )
+                )
         except SQLAlchemyError:
             pass
         else:
@@ -184,7 +190,7 @@ def _drop_index(connection, table_name, index_name):
     )


-def _add_columns(connection, table_name, columns_def):
+def _add_columns(instance, table_name, columns_def):
     """Add columns to a table."""
     _LOGGER.warning(
         "Adding columns %s to table %s. Note: this can take several "
@@ -197,14 +203,16 @@ def _add_columns(connection, table_name, columns_def):
     columns_def = [f"ADD {col_def}" for col_def in columns_def]

     try:
-        connection.execute(
-            text(
-                "ALTER TABLE {table} {columns_def}".format(
-                    table=table_name, columns_def=", ".join(columns_def)
-                )
-            )
-        )
-        return
+        with session_scope(session=instance.get_session()) as session:
+            connection = session.connection()
+            connection.execute(
+                text(
+                    "ALTER TABLE {table} {columns_def}".format(
+                        table=table_name, columns_def=", ".join(columns_def)
+                    )
+                )
+            )
+            return
     except (InternalError, OperationalError):
         # Some engines support adding all columns at once,
         # this error is when they don't
@@ -212,13 +220,15 @@ def _add_columns(connection, table_name, columns_def):
     for column_def in columns_def:
         try:
-            connection.execute(
-                text(
-                    "ALTER TABLE {table} {column_def}".format(
-                        table=table_name, column_def=column_def
-                    )
-                )
-            )
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(
+                    text(
+                        "ALTER TABLE {table} {column_def}".format(
+                            table=table_name, column_def=column_def
+                        )
+                    )
+                )
         except (InternalError, OperationalError) as err:
             raise_if_exception_missing_str(err, ["already exists", "duplicate"])
             _LOGGER.warning(
@@ -228,7 +238,7 @@ def _add_columns(connection, table_name, columns_def):
     )


-def _modify_columns(connection, engine, table_name, columns_def):
+def _modify_columns(instance, engine, table_name, columns_def):
     """Modify columns in a table."""
     if engine.dialect.name == "sqlite":
         _LOGGER.debug(
@@ -261,33 +271,37 @@ def _modify_columns(connection, engine, table_name, columns_def):
     columns_def = [f"MODIFY {col_def}" for col_def in columns_def]

     try:
-        connection.execute(
-            text(
-                "ALTER TABLE {table} {columns_def}".format(
-                    table=table_name, columns_def=", ".join(columns_def)
-                )
-            )
-        )
-        return
+        with session_scope(session=instance.get_session()) as session:
+            connection = session.connection()
+            connection.execute(
+                text(
+                    "ALTER TABLE {table} {columns_def}".format(
+                        table=table_name, columns_def=", ".join(columns_def)
+                    )
+                )
+            )
+            return
     except (InternalError, OperationalError):
         _LOGGER.info("Unable to use quick column modify. Modifying 1 by 1")

     for column_def in columns_def:
         try:
-            connection.execute(
-                text(
-                    "ALTER TABLE {table} {column_def}".format(
-                        table=table_name, column_def=column_def
-                    )
-                )
-            )
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(
+                    text(
+                        "ALTER TABLE {table} {column_def}".format(
+                            table=table_name, column_def=column_def
+                        )
+                    )
+                )
         except (InternalError, OperationalError):
             _LOGGER.exception(
                 "Could not modify column %s in table %s", column_def, table_name
             )


-def _update_states_table_with_foreign_key_options(connection, engine):
+def _update_states_table_with_foreign_key_options(instance, engine):
     """Add the options to foreign key constraints."""
     inspector = sqlalchemy.inspect(engine)
     alters = []
@@ -316,17 +330,19 @@ def _update_states_table_with_foreign_key_options(connection, engine):
     for alter in alters:
         try:
-            connection.execute(DropConstraint(alter["old_fk"]))
-            for fkc in states_key_constraints:
-                if fkc.column_keys == alter["columns"]:
-                    connection.execute(AddConstraint(fkc))
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(DropConstraint(alter["old_fk"]))
+                for fkc in states_key_constraints:
+                    if fkc.column_keys == alter["columns"]:
+                        connection.execute(AddConstraint(fkc))
         except (InternalError, OperationalError):
             _LOGGER.exception(
                 "Could not update foreign options in %s table", TABLE_STATES
             )


-def _drop_foreign_key_constraints(connection, engine, table, columns):
+def _drop_foreign_key_constraints(instance, engine, table, columns):
     """Drop foreign key constraints for a table on specific columns."""
     inspector = sqlalchemy.inspect(engine)
     drops = []
@@ -345,7 +361,9 @@ def _drop_foreign_key_constraints(connection, engine, table, columns):
     for drop in drops:
         try:
-            connection.execute(DropConstraint(drop))
+            with session_scope(session=instance.get_session()) as session:
+                connection = session.connection()
+                connection.execute(DropConstraint(drop))
         except (InternalError, OperationalError):
             _LOGGER.exception(
                 "Could not drop foreign constraints in %s table on %s",
@@ -354,17 +372,16 @@ def _drop_foreign_key_constraints(connection, engine, table, columns):
     )

-def _apply_update(instance, session, new_version, old_version):  # noqa: C901
+def _apply_update(instance, new_version, old_version):  # noqa: C901
     """Perform operations to bring schema up to date."""
     engine = instance.engine
-    connection = session.connection()
     if new_version == 1:
-        _create_index(connection, "events", "ix_events_time_fired")
+        _create_index(instance, "events", "ix_events_time_fired")
     elif new_version == 2:
         # Create compound start/end index for recorder_runs
-        _create_index(connection, "recorder_runs", "ix_recorder_runs_start_end")
+        _create_index(instance, "recorder_runs", "ix_recorder_runs_start_end")
         # Create indexes for states
-        _create_index(connection, "states", "ix_states_last_updated")
+        _create_index(instance, "states", "ix_states_last_updated")
     elif new_version == 3:
         # There used to be a new index here, but it was removed in version 4.
         pass
@@ -374,41 +391,41 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
         if old_version == 3:
             # Remove index that was added in version 3
-            _drop_index(connection, "states", "ix_states_created_domain")
+            _drop_index(instance, "states", "ix_states_created_domain")
         if old_version == 2:
             # Remove index that was added in version 2
-            _drop_index(connection, "states", "ix_states_entity_id_created")
+            _drop_index(instance, "states", "ix_states_entity_id_created")

         # Remove indexes that were added in version 0
-        _drop_index(connection, "states", "states__state_changes")
-        _drop_index(connection, "states", "states__significant_changes")
-        _drop_index(connection, "states", "ix_states_entity_id_created")
+        _drop_index(instance, "states", "states__state_changes")
+        _drop_index(instance, "states", "states__significant_changes")
+        _drop_index(instance, "states", "ix_states_entity_id_created")

-        _create_index(connection, "states", "ix_states_entity_id_last_updated")
+        _create_index(instance, "states", "ix_states_entity_id_last_updated")
     elif new_version == 5:
         # Create supporting index for States.event_id foreign key
-        _create_index(connection, "states", "ix_states_event_id")
+        _create_index(instance, "states", "ix_states_event_id")
     elif new_version == 6:
         _add_columns(
-            session,
+            instance,
             "events",
             ["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"],
         )
-        _create_index(connection, "events", "ix_events_context_id")
-        _create_index(connection, "events", "ix_events_context_user_id")
+        _create_index(instance, "events", "ix_events_context_id")
+        _create_index(instance, "events", "ix_events_context_user_id")
         _add_columns(
-            connection,
+            instance,
             "states",
             ["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"],
         )
-        _create_index(connection, "states", "ix_states_context_id")
-        _create_index(connection, "states", "ix_states_context_user_id")
+        _create_index(instance, "states", "ix_states_context_id")
+        _create_index(instance, "states", "ix_states_context_user_id")
     elif new_version == 7:
-        _create_index(connection, "states", "ix_states_entity_id")
+        _create_index(instance, "states", "ix_states_entity_id")
     elif new_version == 8:
-        _add_columns(connection, "events", ["context_parent_id CHARACTER(36)"])
-        _add_columns(connection, "states", ["old_state_id INTEGER"])
-        _create_index(connection, "events", "ix_events_context_parent_id")
+        _add_columns(instance, "events", ["context_parent_id CHARACTER(36)"])
+        _add_columns(instance, "states", ["old_state_id INTEGER"])
+        _create_index(instance, "events", "ix_events_context_parent_id")
     elif new_version == 9:
         # We now get the context from events with a join
         # since its always there on state_changed events
@@ -418,36 +435,36 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
         # and we would have to move to something like
         # sqlalchemy alembic to make that work
         #
-        _drop_index(connection, "states", "ix_states_context_id")
-        _drop_index(connection, "states", "ix_states_context_user_id")
+        _drop_index(instance, "states", "ix_states_context_id")
+        _drop_index(instance, "states", "ix_states_context_user_id")
         # This index won't be there if they were not running
         # nightly but we don't treat that as a critical issue
-        _drop_index(connection, "states", "ix_states_context_parent_id")
+        _drop_index(instance, "states", "ix_states_context_parent_id")

         # Redundant keys on composite index:
         # We already have ix_states_entity_id_last_updated
-        _drop_index(connection, "states", "ix_states_entity_id")
-        _create_index(connection, "events", "ix_events_event_type_time_fired")
-        _drop_index(connection, "events", "ix_events_event_type")
+        _drop_index(instance, "states", "ix_states_entity_id")
+        _create_index(instance, "events", "ix_events_event_type_time_fired")
+        _drop_index(instance, "events", "ix_events_event_type")
     elif new_version == 10:
         # Now done in step 11
         pass
     elif new_version == 11:
-        _create_index(connection, "states", "ix_states_old_state_id")
-        _update_states_table_with_foreign_key_options(connection, engine)
+        _create_index(instance, "states", "ix_states_old_state_id")
+        _update_states_table_with_foreign_key_options(instance, engine)
     elif new_version == 12:
         if engine.dialect.name == "mysql":
-            _modify_columns(connection, engine, "events", ["event_data LONGTEXT"])
-            _modify_columns(connection, engine, "states", ["attributes LONGTEXT"])
+            _modify_columns(instance, engine, "events", ["event_data LONGTEXT"])
+            _modify_columns(instance, engine, "states", ["attributes LONGTEXT"])
     elif new_version == 13:
         if engine.dialect.name == "mysql":
             _modify_columns(
-                connection,
+                instance,
                 engine,
                 "events",
                 ["time_fired DATETIME(6)", "created DATETIME(6)"],
             )
             _modify_columns(
-                connection,
+                instance,
                 engine,
                 "states",
                 [
@@ -457,14 +474,12 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
                 ],
             )
     elif new_version == 14:
-        _modify_columns(connection, engine, "events", ["event_type VARCHAR(64)"])
+        _modify_columns(instance, engine, "events", ["event_type VARCHAR(64)"])
     elif new_version == 15:
         # This dropped the statistics table, done again in version 18.
         pass
     elif new_version == 16:
-        _drop_foreign_key_constraints(
-            connection, engine, TABLE_STATES, ["old_state_id"]
-        )
+        _drop_foreign_key_constraints(instance, engine, TABLE_STATES, ["old_state_id"])
     elif new_version == 17:
         # This dropped the statistics table, done again in version 18.
         pass
@@ -489,12 +504,13 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
     elif new_version == 19:
         # This adds the statistic runs table, insert a fake run to prevent duplicating
         # statistics.
-        session.add(StatisticsRuns(start=get_start_time()))
+        with session_scope(session=instance.get_session()) as session:
+            session.add(StatisticsRuns(start=get_start_time()))
     elif new_version == 20:
         # This changed the precision of statistics from float to double
         if engine.dialect.name in ["mysql", "postgresql"]:
             _modify_columns(
-                connection,
+                instance,
                 engine,
                 "statistics",
                 [
@@ -516,14 +532,16 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
                 table,
             )
             with contextlib.suppress(SQLAlchemyError):
-                connection.execute(
-                    # Using LOCK=EXCLUSIVE to prevent the database from corrupting
-                    # https://github.com/home-assistant/core/issues/56104
-                    text(
-                        f"ALTER TABLE {table} CONVERT TO "
-                        "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci LOCK=EXCLUSIVE"
-                    )
-                )
+                with session_scope(session=instance.get_session()) as session:
+                    connection = session.connection()
+                    connection.execute(
+                        # Using LOCK=EXCLUSIVE to prevent the database from corrupting
+                        # https://github.com/home-assistant/core/issues/56104
+                        text(
+                            f"ALTER TABLE {table} CONVERT TO "
+                            "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci LOCK=EXCLUSIVE"
+                        )
+                    )
     elif new_version == 22:
         # Recreate the all statistics tables for Oracle DB with Identity columns
         #
@@ -549,57 +567,64 @@ def _apply_update(instance, session, new_version, old_version):  # noqa: C901
         # Block 5-minute statistics for one hour from the last run, or it will overlap
         # with existing hourly statistics. Don't block on a database with no existing
         # statistics.
-        if session.query(Statistics.id).count() and (
-            last_run_string := session.query(func.max(StatisticsRuns.start)).scalar()
-        ):
-            last_run_start_time = process_timestamp(last_run_string)
-            if last_run_start_time:
-                fake_start_time = last_run_start_time + timedelta(minutes=5)
-                while fake_start_time < last_run_start_time + timedelta(hours=1):
-                    session.add(StatisticsRuns(start=fake_start_time))
-                    fake_start_time += timedelta(minutes=5)
+        with session_scope(session=instance.get_session()) as session:
+            if session.query(Statistics.id).count() and (
+                last_run_string := session.query(
+                    func.max(StatisticsRuns.start)
+                ).scalar()
+            ):
+                last_run_start_time = process_timestamp(last_run_string)
+                if last_run_start_time:
+                    fake_start_time = last_run_start_time + timedelta(minutes=5)
+                    while fake_start_time < last_run_start_time + timedelta(hours=1):
+                        session.add(StatisticsRuns(start=fake_start_time))
+                        fake_start_time += timedelta(minutes=5)

         # When querying the database, be careful to only explicitly query for columns
         # which were present in schema version 21. If querying the table, SQLAlchemy
         # will refer to future columns.
-        for sum_statistic in session.query(StatisticsMeta.id).filter_by(has_sum=true()):
-            last_statistic = (
-                session.query(
-                    Statistics.start,
-                    Statistics.last_reset,
-                    Statistics.state,
-                    Statistics.sum,
-                )
-                .filter_by(metadata_id=sum_statistic.id)
-                .order_by(Statistics.start.desc())
-                .first()
-            )
-            if last_statistic:
-                session.add(
-                    StatisticsShortTerm(
-                        metadata_id=sum_statistic.id,
-                        start=last_statistic.start,
-                        last_reset=last_statistic.last_reset,
-                        state=last_statistic.state,
-                        sum=last_statistic.sum,
-                    )
-                )
+        with session_scope(session=instance.get_session()) as session:
+            for sum_statistic in session.query(StatisticsMeta.id).filter_by(
+                has_sum=true()
+            ):
+                last_statistic = (
+                    session.query(
+                        Statistics.start,
+                        Statistics.last_reset,
+                        Statistics.state,
+                        Statistics.sum,
+                    )
+                    .filter_by(metadata_id=sum_statistic.id)
+                    .order_by(Statistics.start.desc())
+                    .first()
+                )
+                if last_statistic:
+                    session.add(
+                        StatisticsShortTerm(
+                            metadata_id=sum_statistic.id,
+                            start=last_statistic.start,
+                            last_reset=last_statistic.last_reset,
+                            state=last_statistic.state,
+                            sum=last_statistic.sum,
+                        )
+                    )
     elif new_version == 23:
         # Add name column to StatisticsMeta
-        _add_columns(session, "statistics_meta", ["name VARCHAR(255)"])
+        _add_columns(instance, "statistics_meta", ["name VARCHAR(255)"])
     elif new_version == 24:
         # Delete duplicated statistics
-        delete_duplicates(instance, session)
+        with session_scope(session=instance.get_session()) as session:
+            delete_duplicates(instance, session)
         # Recreate statistics indices to block duplicated statistics
-        _drop_index(connection, "statistics", "ix_statistics_statistic_id_start")
-        _create_index(connection, "statistics", "ix_statistics_statistic_id_start")
+        _drop_index(instance, "statistics", "ix_statistics_statistic_id_start")
+        _create_index(instance, "statistics", "ix_statistics_statistic_id_start")
         _drop_index(
-            connection,
+            instance,
             "statistics_short_term",
             "ix_statistics_short_term_statistic_id_start",
         )
         _create_index(
-            connection,
+            instance,
             "statistics_short_term",
             "ix_statistics_short_term_statistic_id_start",
         )

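The second file updates the migration tests to the new signatures. Since the helpers now walk an instance → session → connection chain internally, the tests satisfy that chain with mocks; a minimal sketch of the wiring used below, assuming only the standard library:

from unittest.mock import Mock

connection = Mock()
session = Mock()
session.connection = Mock(return_value=connection)  # helpers call session.connection() for raw SQL
instance = Mock()
instance.get_session = Mock(return_value=session)   # helpers open sessions via instance.get_session()

Mock auto-creates attributes on access, so the explicit return_value wiring only matters where a test asserts against the final connection object afterwards (e.g. connection.execute.call_args).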
@@ -5,7 +5,7 @@ import importlib
 import sqlite3
 import sys
 import threading
-from unittest.mock import ANY, Mock, PropertyMock, call, patch
+from unittest.mock import Mock, PropertyMock, call, patch

 import pytest
 from sqlalchemy import create_engine, text
@@ -57,7 +57,7 @@ async def test_schema_update_calls(hass):
     assert recorder.util.async_migration_in_progress(hass) is False
     update.assert_has_calls(
         [
-            call(hass.data[DATA_INSTANCE], ANY, version + 1, 0)
+            call(hass.data[DATA_INSTANCE], version + 1, 0)
             for version in range(0, models.SCHEMA_VERSION)
         ]
     )
@@ -309,7 +309,7 @@ async def test_schema_migrate(hass, start_version):
 def test_invalid_update():
     """Test that an invalid new version raises an exception."""
     with pytest.raises(ValueError):
-        migration._apply_update(Mock(), Mock(), -1, 0)
+        migration._apply_update(Mock(), -1, 0)


 @pytest.mark.parametrize(
@@ -324,9 +324,13 @@ def test_invalid_update():
 def test_modify_column(engine_type, substr):
     """Test that modify column generates the expected query."""
     connection = Mock()
+    session = Mock()
+    session.connection = Mock(return_value=connection)
+    instance = Mock()
+    instance.get_session = Mock(return_value=session)
     engine = Mock()
     engine.dialect.name = engine_type
-    migration._modify_columns(connection, engine, "events", ["event_type VARCHAR(64)"])
+    migration._modify_columns(instance, engine, "events", ["event_type VARCHAR(64)"])
     if substr:
         assert substr in connection.execute.call_args[0][0].text
     else:
@@ -338,8 +342,10 @@ def test_forgiving_add_column():
     engine = create_engine("sqlite://", poolclass=StaticPool)
     with Session(engine) as session:
         session.execute(text("CREATE TABLE hello (id int)"))
-        migration._add_columns(session, "hello", ["context_id CHARACTER(36)"])
-        migration._add_columns(session, "hello", ["context_id CHARACTER(36)"])
+        instance = Mock()
+        instance.get_session = Mock(return_value=session)
+        migration._add_columns(instance, "hello", ["context_id CHARACTER(36)"])
+        migration._add_columns(instance, "hello", ["context_id CHARACTER(36)"])


 def test_forgiving_add_index():
@@ -347,7 +353,9 @@ def test_forgiving_add_index():
     engine = create_engine("sqlite://", poolclass=StaticPool)
     models.Base.metadata.create_all(engine)
     with Session(engine) as session:
-        migration._create_index(session, "states", "ix_states_context_id")
+        instance = Mock()
+        instance.get_session = Mock(return_value=session)
+        migration._create_index(instance, "states", "ix_states_context_id")


 @pytest.mark.parametrize(