diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 14b9fe11574..8d3213a4805 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -345,15 +345,16 @@ class Recorder(threading.Thread): def _apply_update(self, new_version): """Perform operations to bring schema up to date.""" - from sqlalchemy import Index, Table + from sqlalchemy import Table import homeassistant.components.recorder.models as models if new_version == 1: def create_index(table_name, column_name): """Create an index for the specified table and column.""" table = Table(table_name, models.Base.metadata) - index_name = "_".join(("ix", table_name, column_name)) - index = Index(index_name, getattr(table.c, column_name)) + name = "_".join(("ix", table_name, column_name)) + # Look up the index object that was created from the models + index = next(idx for idx in table.indexes if idx.name == name) _LOGGER.debug("Creating index for table %s column %s", table_name, column_name) index.create(self.engine) diff --git a/tests/components/recorder/models_original.py b/tests/components/recorder/models_original.py new file mode 100644 index 00000000000..31ec5ee7ed7 --- /dev/null +++ b/tests/components/recorder/models_original.py @@ -0,0 +1,163 @@ +"""Models for SQLAlchemy. + +This file contains the original models definitions before schema tracking was +implemented. It is used to test the schema migration logic. +""" + +import json +from datetime import datetime +import logging + +from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer, + String, Text, distinct) +from sqlalchemy.ext.declarative import declarative_base + +import homeassistant.util.dt as dt_util +from homeassistant.core import Event, EventOrigin, State, split_entity_id +from homeassistant.remote import JSONEncoder + +# SQLAlchemy Schema +# pylint: disable=invalid-name +Base = declarative_base() + +_LOGGER = logging.getLogger(__name__) + + +class Events(Base): # type: ignore + """Event history data.""" + + __tablename__ = 'events' + event_id = Column(Integer, primary_key=True) + event_type = Column(String(32), index=True) + event_data = Column(Text) + origin = Column(String(32)) + time_fired = Column(DateTime(timezone=True)) + created = Column(DateTime(timezone=True), default=datetime.utcnow) + + @staticmethod + def from_event(event): + """Create an event database object from a native event.""" + return Events(event_type=event.event_type, + event_data=json.dumps(event.data, cls=JSONEncoder), + origin=str(event.origin), + time_fired=event.time_fired) + + def to_native(self): + """Convert to a natve HA Event.""" + try: + return Event( + self.event_type, + json.loads(self.event_data), + EventOrigin(self.origin), + _process_timestamp(self.time_fired) + ) + except ValueError: + # When json.loads fails + _LOGGER.exception("Error converting to event: %s", self) + return None + + +class States(Base): # type: ignore + """State change history.""" + + __tablename__ = 'states' + state_id = Column(Integer, primary_key=True) + domain = Column(String(64)) + entity_id = Column(String(255)) + state = Column(String(255)) + attributes = Column(Text) + event_id = Column(Integer, ForeignKey('events.event_id')) + last_changed = Column(DateTime(timezone=True), default=datetime.utcnow) + last_updated = Column(DateTime(timezone=True), default=datetime.utcnow) + created = Column(DateTime(timezone=True), default=datetime.utcnow) + + __table_args__ = (Index('states__state_changes', + 'last_changed', 'last_updated', 'entity_id'), + Index('states__significant_changes', + 'domain', 'last_updated', 'entity_id'), ) + + @staticmethod + def from_event(event): + """Create object from a state_changed event.""" + entity_id = event.data['entity_id'] + state = event.data.get('new_state') + + dbstate = States(entity_id=entity_id) + + # State got deleted + if state is None: + dbstate.state = '' + dbstate.domain = split_entity_id(entity_id)[0] + dbstate.attributes = '{}' + dbstate.last_changed = event.time_fired + dbstate.last_updated = event.time_fired + else: + dbstate.domain = state.domain + dbstate.state = state.state + dbstate.attributes = json.dumps(dict(state.attributes), + cls=JSONEncoder) + dbstate.last_changed = state.last_changed + dbstate.last_updated = state.last_updated + + return dbstate + + def to_native(self): + """Convert to an HA state object.""" + try: + return State( + self.entity_id, self.state, + json.loads(self.attributes), + _process_timestamp(self.last_changed), + _process_timestamp(self.last_updated) + ) + except ValueError: + # When json.loads fails + _LOGGER.exception("Error converting row to state: %s", self) + return None + + +class RecorderRuns(Base): # type: ignore + """Representation of recorder run.""" + + __tablename__ = 'recorder_runs' + run_id = Column(Integer, primary_key=True) + start = Column(DateTime(timezone=True), default=datetime.utcnow) + end = Column(DateTime(timezone=True)) + closed_incorrect = Column(Boolean, default=False) + created = Column(DateTime(timezone=True), default=datetime.utcnow) + + def entity_ids(self, point_in_time=None): + """Return the entity ids that existed in this run. + + Specify point_in_time if you want to know which existed at that point + in time inside the run. + """ + from sqlalchemy.orm.session import Session + + session = Session.object_session(self) + + assert session is not None, 'RecorderRuns need to be persisted' + + query = session.query(distinct(States.entity_id)).filter( + States.last_updated >= self.start) + + if point_in_time is not None: + query = query.filter(States.last_updated < point_in_time) + elif self.end is not None: + query = query.filter(States.last_updated < self.end) + + return [row[0] for row in query] + + def to_native(self): + """Return self, native format is this model.""" + return self + + +def _process_timestamp(ts): + """Process a timestamp into datetime object.""" + if ts is None: + return None + elif ts.tzinfo is None: + return dt_util.UTC.localize(ts) + else: + return dt_util.as_utc(ts) diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index ce395044d11..0bfa3a20997 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -6,15 +6,18 @@ import unittest from unittest.mock import patch, call, MagicMock import pytest +from sqlalchemy import create_engine + from homeassistant.core import callback from homeassistant.const import MATCH_ALL from homeassistant.components import recorder from homeassistant.bootstrap import setup_component from tests.common import get_test_home_assistant +from tests.components.recorder import models_original -class TestRecorder(unittest.TestCase): - """Test the recorder module.""" +class BaseTestRecorder(unittest.TestCase): + """Base class for common recorder tests.""" def setUp(self): # pylint: disable=invalid-name """Setup things to be run when tests are started.""" @@ -87,6 +90,10 @@ class TestRecorder(unittest.TestCase): time_fired=timestamp, )) + +class TestRecorder(BaseTestRecorder): + """Test the recorder module.""" + def test_saving_state(self): """Test saving and restoring a state.""" entity_id = 'test.recorder' @@ -205,15 +212,48 @@ class TestRecorder(unittest.TestCase): with self.assertRaises(ValueError): recorder._INSTANCE._apply_update(-1) + +def create_engine_test(*args, **kwargs): + """Test version of create_engine that initializes with old schema. + + This simulates an existing db with the old schema. + """ + engine = create_engine(*args, **kwargs) + models_original.Base.metadata.create_all(engine) + return engine + + +class TestMigrateRecorder(BaseTestRecorder): + """Test recorder class that starts with an original schema db.""" + + @patch('sqlalchemy.create_engine', new=create_engine_test) + @patch('homeassistant.components.recorder.Recorder._migrate_schema') + def setUp(self, migrate): # pylint: disable=invalid-name + """Setup things to be run when tests are started. + + create_engine is patched to create a db that starts with the old + schema. + + _migrate_schema is mocked to ensure it isn't run, so we can test it + below. + """ + super().setUp() + def test_schema_update_calls(self): # pylint: disable=no-self-use """Test that schema migrations occurr in correct order.""" - test_version = recorder.models.SchemaChanges(schema_version=0) - with recorder.session_scope() as session: - session.add(test_version) - with patch.object(recorder._INSTANCE, '_apply_update') as update: - recorder._INSTANCE._migrate_schema() - update.assert_has_calls([call(version+1) for version in range( - 0, recorder.models.SCHEMA_VERSION)]) + with patch.object(recorder._INSTANCE, '_apply_update') as update: + recorder._INSTANCE._migrate_schema() + update.assert_has_calls([call(version+1) for version in range( + 0, recorder.models.SCHEMA_VERSION)]) + + def test_schema_migrate(self): # pylint: disable=no-self-use + """Test the full schema migration logic. + + We're just testing that the logic can execute successfully here without + throwing exceptions. Maintaining a set of assertions based on schema + inspection could quickly become quite cumbersome. + """ + recorder._INSTANCE._migrate_schema() @pytest.fixture