diff --git a/homeassistant/components/history.py b/homeassistant/components/history.py index 6875eaabcdd..69ed528f661 100644 --- a/homeassistant/components/history.py +++ b/homeassistant/components/history.py @@ -64,7 +64,7 @@ def get_significant_states(start_time, end_time=None, entity_id=None, """ entity_ids = (entity_id.lower(), ) if entity_id is not None else None states = recorder.get_model('States') - query = recorder.query('States').filter( + query = recorder.query(states).filter( (states.domain.in_(SIGNIFICANT_DOMAINS) | (states.last_changed == states.last_updated)) & (states.last_updated > start_time)) diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 9040d1f9fde..6577d4af91d 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -13,6 +13,7 @@ import threading import time from datetime import timedelta, datetime from typing import Any, Union, Optional, List, Dict +from contextlib import contextmanager import voluptuous as vol @@ -22,7 +23,7 @@ from homeassistant.const import ( CONF_INCLUDE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL) import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.event import track_point_in_utc_time +from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.typing import ConfigType, QueryType import homeassistant.util.dt as dt_util @@ -39,6 +40,7 @@ CONF_PURGE_DAYS = 'purge_days' RETRIES = 3 CONNECT_RETRY_WAIT = 10 QUERY_RETRY_WAIT = 0.1 +ERROR_QUERY = "Error during query: %s" CONFIG_SCHEMA = vol.Schema({ DOMAIN: vol.Schema({ @@ -62,28 +64,43 @@ _INSTANCE = None # type: Any _LOGGER = logging.getLogger(__name__) # These classes will be populated during setup() -# pylint: disable=invalid-name,no-member -Session = None # pylint: disable=no-member +# scoped_session, in the same thread session_scope() stays the same +_SESSION = None + + +@contextmanager +def session_scope(): + """Provide a transactional scope around a series of operations.""" + session = _SESSION() + try: + yield session + session.commit() + except Exception as err: # pylint: disable=broad-except + _LOGGER.error(ERROR_QUERY, err) + session.rollback() + raise + finally: + session.close() # pylint: disable=invalid-sequence-index -def execute(q: QueryType) -> List[Any]: +def execute(qry: QueryType) -> List[Any]: """Query the database and convert the objects to HA native form. This method also retries a few times in the case of stale connections. """ import sqlalchemy.exc - try: + with session_scope() as session: for _ in range(0, RETRIES): try: return [ row for row in - (row.to_native() for row in q) + (row.to_native() for row in qry) if row is not None] - except sqlalchemy.exc.SQLAlchemyError as e: - log_error(e, retry_wait=QUERY_RETRY_WAIT, rollback=True) - finally: - Session.close() + except sqlalchemy.exc.SQLAlchemyError as err: + _LOGGER.error(ERROR_QUERY, err) + session.rollback() + time.sleep(QUERY_RETRY_WAIT) return [] @@ -101,9 +118,10 @@ def run_information(point_in_time: Optional[datetime]=None): start=_INSTANCE.recording_start, closed_incorrect=False) - return query('RecorderRuns').filter( - (recorder_runs.start < point_in_time) & - (recorder_runs.end > point_in_time)).first() + with session_scope(): + return query('RecorderRuns').filter( + (recorder_runs.start < point_in_time) & + (recorder_runs.end > point_in_time)).first() def setup(hass: HomeAssistant, config: ConfigType) -> bool: @@ -132,10 +150,9 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool: def query(model_name: Union[str, Any], *args) -> QueryType: """Helper to return a query handle.""" _verify_instance() - if isinstance(model_name, str): - return Session.query(get_model(model_name), *args) - return Session.query(model_name, *args) + return _SESSION().query(get_model(model_name), *args) + return _SESSION().query(model_name, *args) def get_model(model_name: str) -> Any: @@ -148,22 +165,6 @@ def get_model(model_name: str) -> Any: return None -def log_error(e: Exception, retry_wait: Optional[float]=0, - rollback: Optional[bool]=True, - message: Optional[str]="Error during query: %s") -> None: - """Log about SQLAlchemy errors in a sane manner.""" - import sqlalchemy.exc - if not isinstance(e, sqlalchemy.exc.OperationalError): - _LOGGER.exception(str(e)) - else: - _LOGGER.error(message, str(e)) - if rollback: - Session.rollback() - if retry_wait: - _LOGGER.info("Retrying in %s seconds", retry_wait) - time.sleep(retry_wait) - - class Recorder(threading.Thread): """A threaded recorder class.""" @@ -204,18 +205,14 @@ class Recorder(threading.Thread): self._setup_connection() self._setup_run() break - except sqlalchemy.exc.SQLAlchemyError as e: - log_error(e, retry_wait=CONNECT_RETRY_WAIT, rollback=False, - message="Error during connection setup: %s") + except sqlalchemy.exc.SQLAlchemyError as err: + _LOGGER.error("Error during connection setup: %s (retrying " + "in %s seconds)", err, CONNECT_RETRY_WAIT) + time.sleep(CONNECT_RETRY_WAIT) if self.purge_days is not None: - def purge_ticker(event): - """Rerun purge every second day.""" - self._purge_old_data() - track_point_in_utc_time(self.hass, purge_ticker, - dt_util.utcnow() + timedelta(days=2)) - track_point_in_utc_time(self.hass, purge_ticker, - dt_util.utcnow() + timedelta(minutes=5)) + async_track_time_interval( + self.hass, self._purge_old_data, timedelta(days=2)) while True: event = self.queue.get() @@ -250,16 +247,17 @@ class Recorder(threading.Thread): self.queue.task_done() continue - dbevent = Events.from_event(event) - self._commit(dbevent) + with session_scope() as session: + dbevent = Events.from_event(event) + self._commit(session, dbevent) - if event.event_type != EVENT_STATE_CHANGED: - self.queue.task_done() - continue + if event.event_type != EVENT_STATE_CHANGED: + self.queue.task_done() + continue - dbstate = States.from_event(event) - dbstate.event_id = dbevent.event_id - self._commit(dbstate) + dbstate = States.from_event(event) + dbstate.event_id = dbevent.event_id + self._commit(session, dbstate) self.queue.task_done() @@ -282,11 +280,14 @@ class Recorder(threading.Thread): def block_till_db_ready(self): """Block until the database session is ready.""" - self.db_ready.wait() + self.db_ready.wait(10) + while not self.db_ready.is_set(): + _LOGGER.warning('Database not ready, waiting another 10 seconds.') + self.db_ready.wait(10) def _setup_connection(self): """Ensure database is ready to fly.""" - global Session # pylint: disable=global-statement + global _SESSION # pylint: disable=invalid-name,global-statement import homeassistant.components.recorder.models as models from sqlalchemy import create_engine @@ -298,40 +299,44 @@ class Recorder(threading.Thread): self.engine = create_engine( 'sqlite://', connect_args={'check_same_thread': False}, - poolclass=StaticPool) + poolclass=StaticPool, + pool_reset_on_return=None) else: self.engine = create_engine(self.db_url, echo=False) models.Base.metadata.create_all(self.engine) session_factory = sessionmaker(bind=self.engine) - Session = scoped_session(session_factory) + _SESSION = scoped_session(session_factory) self._migrate_schema() self.db_ready.set() def _migrate_schema(self): """Check if the schema needs to be upgraded.""" - import homeassistant.components.recorder.models as models - schema_changes = models.SchemaChanges - current_version = getattr(Session.query(schema_changes).order_by( - schema_changes.change_id.desc()).first(), 'schema_version', None) + from homeassistant.components.recorder.models import SCHEMA_VERSION + schema_changes = get_model('SchemaChanges') + with session_scope() as session: + res = session.query(schema_changes).order_by( + schema_changes.change_id.desc()).first() + current_version = getattr(res, 'schema_version', None) - if current_version == models.SCHEMA_VERSION: - return - _LOGGER.debug("Schema version incorrect: %d", current_version) + if current_version == SCHEMA_VERSION: + return + _LOGGER.debug("Schema version incorrect: %s", current_version) - if current_version is None: - current_version = self._inspect_schema_version() - _LOGGER.debug("No schema version found. Inspected version: %d", - current_version) + if current_version is None: + current_version = self._inspect_schema_version() + _LOGGER.debug("No schema version found. Inspected version: %s", + current_version) - for version in range(current_version, models.SCHEMA_VERSION): - new_version = version + 1 - _LOGGER.info( - "Upgrading recorder db schema to version %d", new_version) - self._apply_update(new_version) - self._commit(schema_changes(schema_version=new_version)) - _LOGGER.info( - "Upgraded recorder db schema to version %d", new_version) + for version in range(current_version, SCHEMA_VERSION): + new_version = version + 1 + _LOGGER.info("Upgrading recorder db schema to version %s", + new_version) + self._apply_update(new_version) + self._commit(session, + schema_changes(schema_version=new_version)) + _LOGGER.info("Upgraded recorder db schema to version %s", + new_version) def _apply_update(self, new_version): """Perform operations to bring schema up to date.""" @@ -368,51 +373,54 @@ class Recorder(threading.Thread): import homeassistant.components.recorder.models as models inspector = reflection.Inspector.from_engine(self.engine) indexes = inspector.get_indexes("events") - for index in indexes: - if index['column_names'] == ["time_fired"]: - # Schema addition from version 1 detected. This is a new db. - current_version = models.SchemaChanges( - schema_version=models.SCHEMA_VERSION) - self._commit(current_version) - return models.SCHEMA_VERSION + with session_scope() as session: + for index in indexes: + if index['column_names'] == ["time_fired"]: + # Schema addition from version 1 detected. New DB. + current_version = models.SchemaChanges( + schema_version=models.SCHEMA_VERSION) + self._commit(session, current_version) + return models.SCHEMA_VERSION - # Version 1 schema changes not found, this db needs to be migrated. - current_version = models.SchemaChanges(schema_version=0) - self._commit(current_version) - return current_version.schema_version + # Version 1 schema changes not found, this db needs to be migrated. + current_version = models.SchemaChanges(schema_version=0) + self._commit(session, current_version) + return current_version.schema_version def _close_connection(self): """Close the connection.""" - global Session # pylint: disable=global-statement + global _SESSION # pylint: disable=invalid-name,global-statement self.engine.dispose() self.engine = None - Session = None + _SESSION = None def _setup_run(self): """Log the start of the current run.""" recorder_runs = get_model('RecorderRuns') - for run in query('RecorderRuns').filter_by(end=None): - run.closed_incorrect = True - run.end = self.recording_start - _LOGGER.warning("Ended unfinished session (id=%s from %s)", - run.run_id, run.start) - Session.add(run) + with session_scope() as session: + for run in query('RecorderRuns').filter_by(end=None): + run.closed_incorrect = True + run.end = self.recording_start + _LOGGER.warning("Ended unfinished session (id=%s from %s)", + run.run_id, run.start) + session.add(run) - _LOGGER.warning("Found unfinished sessions") + _LOGGER.warning("Found unfinished sessions") - self._run = recorder_runs( - start=self.recording_start, - created=dt_util.utcnow() - ) - self._commit(self._run) + self._run = recorder_runs( + start=self.recording_start, + created=dt_util.utcnow() + ) + self._commit(session, self._run) def _close_run(self): """Save end time for current run.""" self._run.end = dt_util.utcnow() - self._commit(self._run) + with session_scope() as session: + self._commit(session, self._run) self._run = None - def _purge_old_data(self): + def _purge_old_data(self, _=None): """Purge events and states older than purge_days ago.""" from homeassistant.components.recorder.models import Events, States @@ -429,8 +437,9 @@ class Recorder(threading.Thread): .delete(synchronize_session=False) _LOGGER.debug("Deleted %s states", deleted_rows) - if self._commit(_purge_states): - _LOGGER.info("Purged states created before %s", purge_before) + with session_scope() as session: + if self._commit(session, _purge_states): + _LOGGER.info("Purged states created before %s", purge_before) def _purge_events(session): deleted_rows = session.query(Events) \ @@ -438,10 +447,9 @@ class Recorder(threading.Thread): .delete(synchronize_session=False) _LOGGER.debug("Deleted %s events", deleted_rows) - if self._commit(_purge_events): - _LOGGER.info("Purged events created before %s", purge_before) - - Session.expire_all() + with session_scope() as session: + if self._commit(session, _purge_events): + _LOGGER.info("Purged events created before %s", purge_before) # Execute sqlite vacuum command to free up space on disk if self.engine.driver == 'sqlite': @@ -449,10 +457,9 @@ class Recorder(threading.Thread): self.engine.execute("VACUUM") @staticmethod - def _commit(work): + def _commit(session, work): """Commit & retry work: Either a model or in a function.""" import sqlalchemy.exc - session = Session() for _ in range(0, RETRIES): try: if callable(work): @@ -461,8 +468,10 @@ class Recorder(threading.Thread): session.add(work) session.commit() return True - except sqlalchemy.exc.OperationalError as e: - log_error(e, retry_wait=QUERY_RETRY_WAIT, rollback=True) + except sqlalchemy.exc.OperationalError as err: + _LOGGER.error(ERROR_QUERY, err) + session.rollback() + time.sleep(QUERY_RETRY_WAIT) return False diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index f729303c685..ce395044d11 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -3,7 +3,7 @@ import json from datetime import datetime, timedelta import unittest -from unittest.mock import patch, call +from unittest.mock import patch, call, MagicMock import pytest from homeassistant.core import callback @@ -24,7 +24,6 @@ class TestRecorder(unittest.TestCase): recorder.DOMAIN: {recorder.CONF_DB_URL: db_uri}}) self.hass.start() recorder._verify_instance() - self.session = recorder.Session() recorder._INSTANCE.block_till_done() def tearDown(self): # pylint: disable=invalid-name @@ -42,26 +41,25 @@ class TestRecorder(unittest.TestCase): self.hass.block_till_done() recorder._INSTANCE.block_till_done() - for event_id in range(5): - if event_id < 3: - timestamp = five_days_ago - state = 'purgeme' - else: - timestamp = now - state = 'dontpurgeme' + with recorder.session_scope() as session: + for event_id in range(5): + if event_id < 3: + timestamp = five_days_ago + state = 'purgeme' + else: + timestamp = now + state = 'dontpurgeme' - self.session.add(recorder.get_model('States')( - entity_id='test.recorder2', - domain='sensor', - state=state, - attributes=json.dumps(attributes), - last_changed=timestamp, - last_updated=timestamp, - created=timestamp, - event_id=event_id + 1000 - )) - - self.session.commit() + session.add(recorder.get_model('States')( + entity_id='test.recorder2', + domain='sensor', + state=state, + attributes=json.dumps(attributes), + last_changed=timestamp, + last_updated=timestamp, + created=timestamp, + event_id=event_id + 1000 + )) def _add_test_events(self): """Add a few events for testing.""" @@ -71,21 +69,23 @@ class TestRecorder(unittest.TestCase): self.hass.block_till_done() recorder._INSTANCE.block_till_done() - for event_id in range(5): - if event_id < 2: - timestamp = five_days_ago - event_type = 'EVENT_TEST_PURGE' - else: - timestamp = now - event_type = 'EVENT_TEST' - self.session.add(recorder.get_model('Events')( - event_type=event_type, - event_data=json.dumps(event_data), - origin='LOCAL', - created=timestamp, - time_fired=timestamp, - )) + with recorder.session_scope() as session: + for event_id in range(5): + if event_id < 2: + timestamp = five_days_ago + event_type = 'EVENT_TEST_PURGE' + else: + timestamp = now + event_type = 'EVENT_TEST' + + session.add(recorder.get_model('Events')( + event_type=event_type, + event_data=json.dumps(event_data), + origin='LOCAL', + created=timestamp, + time_fired=timestamp, + )) def test_saving_state(self): """Test saving and restoring a state.""" @@ -205,14 +205,15 @@ class TestRecorder(unittest.TestCase): with self.assertRaises(ValueError): recorder._INSTANCE._apply_update(-1) - def test_schema_update_calls(self): + def test_schema_update_calls(self): # pylint: disable=no-self-use """Test that schema migrations occurr in correct order.""" test_version = recorder.models.SchemaChanges(schema_version=0) - self.session.add(test_version) - with patch.object(recorder._INSTANCE, '_apply_update') as update: - recorder._INSTANCE._migrate_schema() - update.assert_has_calls([call(version+1) for version in range( - 0, recorder.models.SCHEMA_VERSION)]) + with recorder.session_scope() as session: + session.add(test_version) + with patch.object(recorder._INSTANCE, '_apply_update') as update: + recorder._INSTANCE._migrate_schema() + update.assert_has_calls([call(version+1) for version in range( + 0, recorder.models.SCHEMA_VERSION)]) @pytest.fixture @@ -220,7 +221,7 @@ def hass_recorder(): """HASS fixture with in-memory recorder.""" hass = get_test_home_assistant() - def setup_recorder(config): + def setup_recorder(config={}): """Setup with params.""" db_uri = 'sqlite://' # In memory DB conf = {recorder.CONF_DB_URL: db_uri} @@ -301,3 +302,61 @@ def test_saving_state_include_domain_exclude_entity(hass_recorder): assert len(states) == 1 assert hass.states.get('test.ok') == states[0] assert hass.states.get('test.ok').state == 'state2' + + +def test_recorder_errors_exceptions(hass_recorder): \ + # pylint: disable=redefined-outer-name + """Test session_scope and get_model errors.""" + # Model cannot be resolved + assert recorder.get_model('dont-exist') is None + + # Verify the instance fails before setup + with pytest.raises(RuntimeError): + recorder._verify_instance() + + # Setup the recorder + hass_recorder() + + recorder._verify_instance() + + # Verify session scope raises (and prints) an exception + with patch('homeassistant.components.recorder._LOGGER.error') as e_mock, \ + pytest.raises(Exception) as err: + with recorder.session_scope() as session: + session.execute('select * from notthere') + assert e_mock.call_count == 1 + assert recorder.ERROR_QUERY[:-4] in e_mock.call_args[0][0] + assert 'no such table' in str(err.value) + + +def test_recorder_bad_commit(hass_recorder): + """Bad _commit should retry 3 times.""" + hass_recorder() + + def work(session): + """Bad work.""" + session.execute('select * from notthere') + + with patch('homeassistant.components.recorder.time.sleep') as e_mock, \ + recorder.session_scope() as session: + res = recorder._INSTANCE._commit(session, work) + assert res is False + assert e_mock.call_count == 3 + + +def test_recorder_bad_execute(hass_recorder): + """Bad execute, retry 3 times.""" + hass_recorder() + + def to_native(): + """Rasie exception.""" + from sqlalchemy.exc import SQLAlchemyError + raise SQLAlchemyError() + + mck1 = MagicMock() + mck1.to_native = to_native + + with patch('homeassistant.components.recorder.time.sleep') as e_mock: + res = recorder.execute((mck1,)) + assert res == [] + assert e_mock.call_count == 3