[recorder] Add tests for full schema migration (#5831)

* [recorder] Add tests for full schema migration

* Remove leftover code

* Fix duplicate creation of sqlalchemy Index object

* It's that kind of day...

* Improve models_original docstring
pull/5843/head
Adam Mills 2017-02-09 21:17:17 -05:00 committed by Paulus Schoutsen
parent 4c5e6399e9
commit be08bf0ef7
3 changed files with 216 additions and 12 deletions

View File

@ -345,15 +345,16 @@ class Recorder(threading.Thread):
def _apply_update(self, new_version):
"""Perform operations to bring schema up to date."""
from sqlalchemy import Index, Table
from sqlalchemy import Table
import homeassistant.components.recorder.models as models
if new_version == 1:
def create_index(table_name, column_name):
"""Create an index for the specified table and column."""
table = Table(table_name, models.Base.metadata)
index_name = "_".join(("ix", table_name, column_name))
index = Index(index_name, getattr(table.c, column_name))
name = "_".join(("ix", table_name, column_name))
# Look up the index object that was created from the models
index = next(idx for idx in table.indexes if idx.name == name)
_LOGGER.debug("Creating index for table %s column %s",
table_name, column_name)
index.create(self.engine)

View File

@ -0,0 +1,163 @@
"""Models for SQLAlchemy.
This file contains the original models definitions before schema tracking was
implemented. It is used to test the schema migration logic.
"""
import json
from datetime import datetime
import logging
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer,
String, Text, distinct)
from sqlalchemy.ext.declarative import declarative_base
import homeassistant.util.dt as dt_util
from homeassistant.core import Event, EventOrigin, State, split_entity_id
from homeassistant.remote import JSONEncoder
# SQLAlchemy Schema
# pylint: disable=invalid-name
Base = declarative_base()
_LOGGER = logging.getLogger(__name__)
class Events(Base): # type: ignore
"""Event history data."""
__tablename__ = 'events'
event_id = Column(Integer, primary_key=True)
event_type = Column(String(32), index=True)
event_data = Column(Text)
origin = Column(String(32))
time_fired = Column(DateTime(timezone=True))
created = Column(DateTime(timezone=True), default=datetime.utcnow)
@staticmethod
def from_event(event):
"""Create an event database object from a native event."""
return Events(event_type=event.event_type,
event_data=json.dumps(event.data, cls=JSONEncoder),
origin=str(event.origin),
time_fired=event.time_fired)
def to_native(self):
"""Convert to a natve HA Event."""
try:
return Event(
self.event_type,
json.loads(self.event_data),
EventOrigin(self.origin),
_process_timestamp(self.time_fired)
)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting to event: %s", self)
return None
class States(Base): # type: ignore
"""State change history."""
__tablename__ = 'states'
state_id = Column(Integer, primary_key=True)
domain = Column(String(64))
entity_id = Column(String(255))
state = Column(String(255))
attributes = Column(Text)
event_id = Column(Integer, ForeignKey('events.event_id'))
last_changed = Column(DateTime(timezone=True), default=datetime.utcnow)
last_updated = Column(DateTime(timezone=True), default=datetime.utcnow)
created = Column(DateTime(timezone=True), default=datetime.utcnow)
__table_args__ = (Index('states__state_changes',
'last_changed', 'last_updated', 'entity_id'),
Index('states__significant_changes',
'domain', 'last_updated', 'entity_id'), )
@staticmethod
def from_event(event):
"""Create object from a state_changed event."""
entity_id = event.data['entity_id']
state = event.data.get('new_state')
dbstate = States(entity_id=entity_id)
# State got deleted
if state is None:
dbstate.state = ''
dbstate.domain = split_entity_id(entity_id)[0]
dbstate.attributes = '{}'
dbstate.last_changed = event.time_fired
dbstate.last_updated = event.time_fired
else:
dbstate.domain = state.domain
dbstate.state = state.state
dbstate.attributes = json.dumps(dict(state.attributes),
cls=JSONEncoder)
dbstate.last_changed = state.last_changed
dbstate.last_updated = state.last_updated
return dbstate
def to_native(self):
"""Convert to an HA state object."""
try:
return State(
self.entity_id, self.state,
json.loads(self.attributes),
_process_timestamp(self.last_changed),
_process_timestamp(self.last_updated)
)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting row to state: %s", self)
return None
class RecorderRuns(Base): # type: ignore
"""Representation of recorder run."""
__tablename__ = 'recorder_runs'
run_id = Column(Integer, primary_key=True)
start = Column(DateTime(timezone=True), default=datetime.utcnow)
end = Column(DateTime(timezone=True))
closed_incorrect = Column(Boolean, default=False)
created = Column(DateTime(timezone=True), default=datetime.utcnow)
def entity_ids(self, point_in_time=None):
"""Return the entity ids that existed in this run.
Specify point_in_time if you want to know which existed at that point
in time inside the run.
"""
from sqlalchemy.orm.session import Session
session = Session.object_session(self)
assert session is not None, 'RecorderRuns need to be persisted'
query = session.query(distinct(States.entity_id)).filter(
States.last_updated >= self.start)
if point_in_time is not None:
query = query.filter(States.last_updated < point_in_time)
elif self.end is not None:
query = query.filter(States.last_updated < self.end)
return [row[0] for row in query]
def to_native(self):
"""Return self, native format is this model."""
return self
def _process_timestamp(ts):
"""Process a timestamp into datetime object."""
if ts is None:
return None
elif ts.tzinfo is None:
return dt_util.UTC.localize(ts)
else:
return dt_util.as_utc(ts)

View File

@ -6,15 +6,18 @@ import unittest
from unittest.mock import patch, call, MagicMock
import pytest
from sqlalchemy import create_engine
from homeassistant.core import callback
from homeassistant.const import MATCH_ALL
from homeassistant.components import recorder
from homeassistant.bootstrap import setup_component
from tests.common import get_test_home_assistant
from tests.components.recorder import models_original
class TestRecorder(unittest.TestCase):
"""Test the recorder module."""
class BaseTestRecorder(unittest.TestCase):
"""Base class for common recorder tests."""
def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started."""
@ -87,6 +90,10 @@ class TestRecorder(unittest.TestCase):
time_fired=timestamp,
))
class TestRecorder(BaseTestRecorder):
"""Test the recorder module."""
def test_saving_state(self):
"""Test saving and restoring a state."""
entity_id = 'test.recorder'
@ -205,15 +212,48 @@ class TestRecorder(unittest.TestCase):
with self.assertRaises(ValueError):
recorder._INSTANCE._apply_update(-1)
def create_engine_test(*args, **kwargs):
"""Test version of create_engine that initializes with old schema.
This simulates an existing db with the old schema.
"""
engine = create_engine(*args, **kwargs)
models_original.Base.metadata.create_all(engine)
return engine
class TestMigrateRecorder(BaseTestRecorder):
"""Test recorder class that starts with an original schema db."""
@patch('sqlalchemy.create_engine', new=create_engine_test)
@patch('homeassistant.components.recorder.Recorder._migrate_schema')
def setUp(self, migrate): # pylint: disable=invalid-name
"""Setup things to be run when tests are started.
create_engine is patched to create a db that starts with the old
schema.
_migrate_schema is mocked to ensure it isn't run, so we can test it
below.
"""
super().setUp()
def test_schema_update_calls(self): # pylint: disable=no-self-use
"""Test that schema migrations occurr in correct order."""
test_version = recorder.models.SchemaChanges(schema_version=0)
with recorder.session_scope() as session:
session.add(test_version)
with patch.object(recorder._INSTANCE, '_apply_update') as update:
recorder._INSTANCE._migrate_schema()
update.assert_has_calls([call(version+1) for version in range(
0, recorder.models.SCHEMA_VERSION)])
with patch.object(recorder._INSTANCE, '_apply_update') as update:
recorder._INSTANCE._migrate_schema()
update.assert_has_calls([call(version+1) for version in range(
0, recorder.models.SCHEMA_VERSION)])
def test_schema_migrate(self): # pylint: disable=no-self-use
"""Test the full schema migration logic.
We're just testing that the logic can execute successfully here without
throwing exceptions. Maintaining a set of assertions based on schema
inspection could quickly become quite cumbersome.
"""
recorder._INSTANCE._migrate_schema()
@pytest.fixture