Correct initialization of new databases (#80234)

pull/80262/head
Erik Montnemery 2022-10-13 13:01:27 +02:00 committed by GitHub
parent acb1477673
commit 04cc2ae264
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 21 deletions

View File

@ -703,7 +703,7 @@ class Recorder(threading.Thread):
while tries <= self.db_max_retries:
try:
self._setup_connection()
return True
return migration.initialize_database(self.get_session)
except UnsupportedDialect:
break
except Exception as err: # pylint: disable=broad-except

View File

@ -6,7 +6,7 @@ import contextlib
from dataclasses import dataclass
from datetime import timedelta
import logging
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING
import sqlalchemy
from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
@ -62,24 +62,17 @@ def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str])
raise ex
def _get_schema_version(session: Session) -> int | None:
"""Get the schema version."""
res = session.query(SchemaChanges).order_by(SchemaChanges.change_id.desc()).first()
return getattr(res, "schema_version", None)
def get_schema_version(session_maker: Callable[[], Session]) -> int | None:
"""Get the schema version."""
try:
with session_scope(session=session_maker()) as session:
res = (
session.query(SchemaChanges)
.order_by(SchemaChanges.change_id.desc())
.first()
)
current_version = getattr(res, "schema_version", None)
if current_version is None:
current_version = _inspect_schema_version(session)
_LOGGER.debug(
"No schema version found. Inspected version: %s", current_version
)
return cast(int, current_version)
return _get_schema_version(session)
except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Error when determining DB schema version: %s", err)
return None
@ -797,8 +790,10 @@ def _apply_update( # noqa: C901
raise ValueError(f"No schema migration defined for version {new_version}")
def _inspect_schema_version(session: Session) -> int:
"""Determine the schema version by inspecting the db structure.
def _initialize_database(session: Session) -> bool:
"""Initialize a new database, or a database created before introducing schema changes.
The function determines the schema version by inspecting the db structure.
When the schema version is not present in the db, either db was just
created with the correct schema, or this is a db created before schema
@ -814,9 +809,22 @@ def _inspect_schema_version(session: Session) -> int:
# Schema addition from version 1 detected. New DB.
session.add(StatisticsRuns(start=get_start_time()))
session.add(SchemaChanges(schema_version=SCHEMA_VERSION))
return SCHEMA_VERSION
return True
# Version 1 schema changes not found, this db needs to be migrated.
current_version = SchemaChanges(schema_version=0)
session.add(current_version)
return cast(int, current_version.schema_version)
return True
def initialize_database(session_maker: Callable[[], Session]) -> bool:
"""Initialize a new database, or a database created before introducing schema changes."""
try:
with session_scope(session=session_maker()) as session:
if _get_schema_version(session) is not None:
return True
return _initialize_database(session)
except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Error when initialise database: %s", err)
return False

View File

@ -669,7 +669,7 @@ def test_recorder_validate_schema_failure(hass):
"""Test some exceptions."""
recorder_helper.async_initialize_recorder(hass)
with patch(
"homeassistant.components.recorder.migration._inspect_schema_version"
"homeassistant.components.recorder.migration._get_schema_version"
) as inspect_schema_version, patch(
"homeassistant.components.recorder.core.time.sleep"
):