WIP: [component/recorder] Refactoring & better handling of SQLAlchemy Sessions (#5607)

* Refactor recorder and Sessions

* Cover #4352

* NO_reset_on_return

* contextmanager

* coverage
pull/5801/head
Johann Kellerman 2017-02-08 07:47:41 +02:00 committed by Paulus Schoutsen
parent bdebe5d53c
commit 490ef6afad
3 changed files with 221 additions and 153 deletions

View File

@ -64,7 +64,7 @@ def get_significant_states(start_time, end_time=None, entity_id=None,
"""
entity_ids = (entity_id.lower(), ) if entity_id is not None else None
states = recorder.get_model('States')
query = recorder.query('States').filter(
query = recorder.query(states).filter(
(states.domain.in_(SIGNIFICANT_DOMAINS) |
(states.last_changed == states.last_updated)) &
(states.last_updated > start_time))

View File

@ -13,6 +13,7 @@ import threading
import time
from datetime import timedelta, datetime
from typing import Any, Union, Optional, List, Dict
from contextlib import contextmanager
import voluptuous as vol
@ -22,7 +23,7 @@ from homeassistant.const import (
CONF_INCLUDE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL)
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import track_point_in_utc_time
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import ConfigType, QueryType
import homeassistant.util.dt as dt_util
@ -39,6 +40,7 @@ CONF_PURGE_DAYS = 'purge_days'
RETRIES = 3
CONNECT_RETRY_WAIT = 10
QUERY_RETRY_WAIT = 0.1
ERROR_QUERY = "Error during query: %s"
CONFIG_SCHEMA = vol.Schema({
DOMAIN: vol.Schema({
@ -62,28 +64,43 @@ _INSTANCE = None # type: Any
_LOGGER = logging.getLogger(__name__)
# These classes will be populated during setup()
# pylint: disable=invalid-name,no-member
Session = None # pylint: disable=no-member
# scoped_session, in the same thread session_scope() stays the same
_SESSION = None
@contextmanager
def session_scope():
    """Provide a transactional scope around a series of operations.

    Yields a new session from the module-level scoped-session factory,
    commits on clean exit, rolls back (and re-raises) on any error, and
    always closes the session afterwards.
    """
    sess = _SESSION()
    try:
        yield sess
        sess.commit()
    except Exception as err:  # pylint: disable=broad-except
        # Log before rolling back so the failing statement is visible.
        _LOGGER.error(ERROR_QUERY, err)
        sess.rollback()
        raise
    finally:
        sess.close()
# pylint: disable=invalid-sequence-index
def execute(qry: QueryType) -> List[Any]:
    """Query the database and convert the objects to HA native form.

    Runs inside a session_scope() so failures are logged and rolled back.
    Retries up to RETRIES times on SQLAlchemyError (e.g. a stale
    connection), sleeping QUERY_RETRY_WAIT between attempts, and returns
    an empty list if every attempt fails.
    """
    import sqlalchemy.exc
    with session_scope() as session:
        for _ in range(0, RETRIES):
            try:
                # to_native() may return None for rows that cannot be
                # converted; those are filtered out of the result.
                return [
                    row for row in
                    (row.to_native() for row in qry)
                    if row is not None]
            except sqlalchemy.exc.SQLAlchemyError as err:
                _LOGGER.error(ERROR_QUERY, err)
                # Roll back the aborted transaction and wait briefly
                # before retrying — the connection may simply be stale.
                session.rollback()
                time.sleep(QUERY_RETRY_WAIT)
    return []
@ -101,9 +118,10 @@ def run_information(point_in_time: Optional[datetime]=None):
start=_INSTANCE.recording_start,
closed_incorrect=False)
return query('RecorderRuns').filter(
(recorder_runs.start < point_in_time) &
(recorder_runs.end > point_in_time)).first()
with session_scope():
return query('RecorderRuns').filter(
(recorder_runs.start < point_in_time) &
(recorder_runs.end > point_in_time)).first()
def setup(hass: HomeAssistant, config: ConfigType) -> bool:
@ -132,10 +150,9 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool:
def query(model_name: Union[str, Any], *args) -> QueryType:
"""Helper to return a query handle."""
_verify_instance()
if isinstance(model_name, str):
return Session.query(get_model(model_name), *args)
return Session.query(model_name, *args)
return _SESSION().query(get_model(model_name), *args)
return _SESSION().query(model_name, *args)
def get_model(model_name: str) -> Any:
@ -148,22 +165,6 @@ def get_model(model_name: str) -> Any:
return None
def log_error(e: Exception, retry_wait: Optional[float]=0,
rollback: Optional[bool]=True,
message: Optional[str]="Error during query: %s") -> None:
"""Log about SQLAlchemy errors in a sane manner."""
import sqlalchemy.exc
if not isinstance(e, sqlalchemy.exc.OperationalError):
_LOGGER.exception(str(e))
else:
_LOGGER.error(message, str(e))
if rollback:
Session.rollback()
if retry_wait:
_LOGGER.info("Retrying in %s seconds", retry_wait)
time.sleep(retry_wait)
class Recorder(threading.Thread):
"""A threaded recorder class."""
@ -204,18 +205,14 @@ class Recorder(threading.Thread):
self._setup_connection()
self._setup_run()
break
except sqlalchemy.exc.SQLAlchemyError as e:
log_error(e, retry_wait=CONNECT_RETRY_WAIT, rollback=False,
message="Error during connection setup: %s")
except sqlalchemy.exc.SQLAlchemyError as err:
_LOGGER.error("Error during connection setup: %s (retrying "
"in %s seconds)", err, CONNECT_RETRY_WAIT)
time.sleep(CONNECT_RETRY_WAIT)
if self.purge_days is not None:
def purge_ticker(event):
"""Rerun purge every second day."""
self._purge_old_data()
track_point_in_utc_time(self.hass, purge_ticker,
dt_util.utcnow() + timedelta(days=2))
track_point_in_utc_time(self.hass, purge_ticker,
dt_util.utcnow() + timedelta(minutes=5))
async_track_time_interval(
self.hass, self._purge_old_data, timedelta(days=2))
while True:
event = self.queue.get()
@ -250,16 +247,17 @@ class Recorder(threading.Thread):
self.queue.task_done()
continue
dbevent = Events.from_event(event)
self._commit(dbevent)
with session_scope() as session:
dbevent = Events.from_event(event)
self._commit(session, dbevent)
if event.event_type != EVENT_STATE_CHANGED:
self.queue.task_done()
continue
if event.event_type != EVENT_STATE_CHANGED:
self.queue.task_done()
continue
dbstate = States.from_event(event)
dbstate.event_id = dbevent.event_id
self._commit(dbstate)
dbstate = States.from_event(event)
dbstate.event_id = dbevent.event_id
self._commit(session, dbstate)
self.queue.task_done()
@ -282,11 +280,14 @@ class Recorder(threading.Thread):
def block_till_db_ready(self):
"""Block until the database session is ready."""
self.db_ready.wait()
self.db_ready.wait(10)
while not self.db_ready.is_set():
_LOGGER.warning('Database not ready, waiting another 10 seconds.')
self.db_ready.wait(10)
def _setup_connection(self):
"""Ensure database is ready to fly."""
global Session # pylint: disable=global-statement
global _SESSION # pylint: disable=invalid-name,global-statement
import homeassistant.components.recorder.models as models
from sqlalchemy import create_engine
@ -298,40 +299,44 @@ class Recorder(threading.Thread):
self.engine = create_engine(
'sqlite://',
connect_args={'check_same_thread': False},
poolclass=StaticPool)
poolclass=StaticPool,
pool_reset_on_return=None)
else:
self.engine = create_engine(self.db_url, echo=False)
models.Base.metadata.create_all(self.engine)
session_factory = sessionmaker(bind=self.engine)
Session = scoped_session(session_factory)
_SESSION = scoped_session(session_factory)
self._migrate_schema()
self.db_ready.set()
def _migrate_schema(self):
"""Check if the schema needs to be upgraded."""
import homeassistant.components.recorder.models as models
schema_changes = models.SchemaChanges
current_version = getattr(Session.query(schema_changes).order_by(
schema_changes.change_id.desc()).first(), 'schema_version', None)
from homeassistant.components.recorder.models import SCHEMA_VERSION
schema_changes = get_model('SchemaChanges')
with session_scope() as session:
res = session.query(schema_changes).order_by(
schema_changes.change_id.desc()).first()
current_version = getattr(res, 'schema_version', None)
if current_version == models.SCHEMA_VERSION:
return
_LOGGER.debug("Schema version incorrect: %d", current_version)
if current_version == SCHEMA_VERSION:
return
_LOGGER.debug("Schema version incorrect: %s", current_version)
if current_version is None:
current_version = self._inspect_schema_version()
_LOGGER.debug("No schema version found. Inspected version: %d",
current_version)
if current_version is None:
current_version = self._inspect_schema_version()
_LOGGER.debug("No schema version found. Inspected version: %s",
current_version)
for version in range(current_version, models.SCHEMA_VERSION):
new_version = version + 1
_LOGGER.info(
"Upgrading recorder db schema to version %d", new_version)
self._apply_update(new_version)
self._commit(schema_changes(schema_version=new_version))
_LOGGER.info(
"Upgraded recorder db schema to version %d", new_version)
for version in range(current_version, SCHEMA_VERSION):
new_version = version + 1
_LOGGER.info("Upgrading recorder db schema to version %s",
new_version)
self._apply_update(new_version)
self._commit(session,
schema_changes(schema_version=new_version))
_LOGGER.info("Upgraded recorder db schema to version %s",
new_version)
def _apply_update(self, new_version):
"""Perform operations to bring schema up to date."""
@ -368,51 +373,54 @@ class Recorder(threading.Thread):
import homeassistant.components.recorder.models as models
inspector = reflection.Inspector.from_engine(self.engine)
indexes = inspector.get_indexes("events")
for index in indexes:
if index['column_names'] == ["time_fired"]:
# Schema addition from version 1 detected. This is a new db.
current_version = models.SchemaChanges(
schema_version=models.SCHEMA_VERSION)
self._commit(current_version)
return models.SCHEMA_VERSION
with session_scope() as session:
for index in indexes:
if index['column_names'] == ["time_fired"]:
# Schema addition from version 1 detected. New DB.
current_version = models.SchemaChanges(
schema_version=models.SCHEMA_VERSION)
self._commit(session, current_version)
return models.SCHEMA_VERSION
# Version 1 schema changes not found, this db needs to be migrated.
current_version = models.SchemaChanges(schema_version=0)
self._commit(current_version)
return current_version.schema_version
# Version 1 schema changes not found, this db needs to be migrated.
current_version = models.SchemaChanges(schema_version=0)
self._commit(session, current_version)
return current_version.schema_version
def _close_connection(self):
"""Close the connection."""
global Session # pylint: disable=global-statement
global _SESSION # pylint: disable=invalid-name,global-statement
self.engine.dispose()
self.engine = None
Session = None
_SESSION = None
def _setup_run(self):
"""Log the start of the current run."""
recorder_runs = get_model('RecorderRuns')
for run in query('RecorderRuns').filter_by(end=None):
run.closed_incorrect = True
run.end = self.recording_start
_LOGGER.warning("Ended unfinished session (id=%s from %s)",
run.run_id, run.start)
Session.add(run)
with session_scope() as session:
for run in query('RecorderRuns').filter_by(end=None):
run.closed_incorrect = True
run.end = self.recording_start
_LOGGER.warning("Ended unfinished session (id=%s from %s)",
run.run_id, run.start)
session.add(run)
_LOGGER.warning("Found unfinished sessions")
_LOGGER.warning("Found unfinished sessions")
self._run = recorder_runs(
start=self.recording_start,
created=dt_util.utcnow()
)
self._commit(self._run)
self._run = recorder_runs(
start=self.recording_start,
created=dt_util.utcnow()
)
self._commit(session, self._run)
def _close_run(self):
"""Save end time for current run."""
self._run.end = dt_util.utcnow()
self._commit(self._run)
with session_scope() as session:
self._commit(session, self._run)
self._run = None
def _purge_old_data(self):
def _purge_old_data(self, _=None):
"""Purge events and states older than purge_days ago."""
from homeassistant.components.recorder.models import Events, States
@ -429,8 +437,9 @@ class Recorder(threading.Thread):
.delete(synchronize_session=False)
_LOGGER.debug("Deleted %s states", deleted_rows)
if self._commit(_purge_states):
_LOGGER.info("Purged states created before %s", purge_before)
with session_scope() as session:
if self._commit(session, _purge_states):
_LOGGER.info("Purged states created before %s", purge_before)
def _purge_events(session):
deleted_rows = session.query(Events) \
@ -438,10 +447,9 @@ class Recorder(threading.Thread):
.delete(synchronize_session=False)
_LOGGER.debug("Deleted %s events", deleted_rows)
if self._commit(_purge_events):
_LOGGER.info("Purged events created before %s", purge_before)
Session.expire_all()
with session_scope() as session:
if self._commit(session, _purge_events):
_LOGGER.info("Purged events created before %s", purge_before)
# Execute sqlite vacuum command to free up space on disk
if self.engine.driver == 'sqlite':
@ -449,10 +457,9 @@ class Recorder(threading.Thread):
self.engine.execute("VACUUM")
@staticmethod
def _commit(work):
def _commit(session, work):
"""Commit & retry work: Either a model or in a function."""
import sqlalchemy.exc
session = Session()
for _ in range(0, RETRIES):
try:
if callable(work):
@ -461,8 +468,10 @@ class Recorder(threading.Thread):
session.add(work)
session.commit()
return True
except sqlalchemy.exc.OperationalError as e:
log_error(e, retry_wait=QUERY_RETRY_WAIT, rollback=True)
except sqlalchemy.exc.OperationalError as err:
_LOGGER.error(ERROR_QUERY, err)
session.rollback()
time.sleep(QUERY_RETRY_WAIT)
return False

View File

@ -3,7 +3,7 @@
import json
from datetime import datetime, timedelta
import unittest
from unittest.mock import patch, call
from unittest.mock import patch, call, MagicMock
import pytest
from homeassistant.core import callback
@ -24,7 +24,6 @@ class TestRecorder(unittest.TestCase):
recorder.DOMAIN: {recorder.CONF_DB_URL: db_uri}})
self.hass.start()
recorder._verify_instance()
self.session = recorder.Session()
recorder._INSTANCE.block_till_done()
def tearDown(self): # pylint: disable=invalid-name
@ -42,26 +41,25 @@ class TestRecorder(unittest.TestCase):
self.hass.block_till_done()
recorder._INSTANCE.block_till_done()
for event_id in range(5):
if event_id < 3:
timestamp = five_days_ago
state = 'purgeme'
else:
timestamp = now
state = 'dontpurgeme'
with recorder.session_scope() as session:
for event_id in range(5):
if event_id < 3:
timestamp = five_days_ago
state = 'purgeme'
else:
timestamp = now
state = 'dontpurgeme'
self.session.add(recorder.get_model('States')(
entity_id='test.recorder2',
domain='sensor',
state=state,
attributes=json.dumps(attributes),
last_changed=timestamp,
last_updated=timestamp,
created=timestamp,
event_id=event_id + 1000
))
self.session.commit()
session.add(recorder.get_model('States')(
entity_id='test.recorder2',
domain='sensor',
state=state,
attributes=json.dumps(attributes),
last_changed=timestamp,
last_updated=timestamp,
created=timestamp,
event_id=event_id + 1000
))
def _add_test_events(self):
"""Add a few events for testing."""
@ -71,21 +69,23 @@ class TestRecorder(unittest.TestCase):
self.hass.block_till_done()
recorder._INSTANCE.block_till_done()
for event_id in range(5):
if event_id < 2:
timestamp = five_days_ago
event_type = 'EVENT_TEST_PURGE'
else:
timestamp = now
event_type = 'EVENT_TEST'
self.session.add(recorder.get_model('Events')(
event_type=event_type,
event_data=json.dumps(event_data),
origin='LOCAL',
created=timestamp,
time_fired=timestamp,
))
with recorder.session_scope() as session:
for event_id in range(5):
if event_id < 2:
timestamp = five_days_ago
event_type = 'EVENT_TEST_PURGE'
else:
timestamp = now
event_type = 'EVENT_TEST'
session.add(recorder.get_model('Events')(
event_type=event_type,
event_data=json.dumps(event_data),
origin='LOCAL',
created=timestamp,
time_fired=timestamp,
))
def test_saving_state(self):
"""Test saving and restoring a state."""
@ -205,14 +205,15 @@ class TestRecorder(unittest.TestCase):
with self.assertRaises(ValueError):
recorder._INSTANCE._apply_update(-1)
def test_schema_update_calls(self):
def test_schema_update_calls(self): # pylint: disable=no-self-use
"""Test that schema migrations occurr in correct order."""
test_version = recorder.models.SchemaChanges(schema_version=0)
self.session.add(test_version)
with patch.object(recorder._INSTANCE, '_apply_update') as update:
recorder._INSTANCE._migrate_schema()
update.assert_has_calls([call(version+1) for version in range(
0, recorder.models.SCHEMA_VERSION)])
with recorder.session_scope() as session:
session.add(test_version)
with patch.object(recorder._INSTANCE, '_apply_update') as update:
recorder._INSTANCE._migrate_schema()
update.assert_has_calls([call(version+1) for version in range(
0, recorder.models.SCHEMA_VERSION)])
@pytest.fixture
@ -220,7 +221,7 @@ def hass_recorder():
"""HASS fixture with in-memory recorder."""
hass = get_test_home_assistant()
def setup_recorder(config):
def setup_recorder(config={}):
"""Setup with params."""
db_uri = 'sqlite://' # In memory DB
conf = {recorder.CONF_DB_URL: db_uri}
@ -301,3 +302,61 @@ def test_saving_state_include_domain_exclude_entity(hass_recorder):
assert len(states) == 1
assert hass.states.get('test.ok') == states[0]
assert hass.states.get('test.ok').state == 'state2'
def test_recorder_errors_exceptions(hass_recorder):
    # pylint: disable=redefined-outer-name
    """Test session_scope and get_model errors."""
    # An unknown model name resolves to None rather than raising.
    assert recorder.get_model('dont-exist') is None

    # Before setup, instance verification must fail loudly.
    with pytest.raises(RuntimeError):
        recorder._verify_instance()

    # After setup it succeeds.
    hass_recorder()
    recorder._verify_instance()

    # A failing statement inside session_scope is logged once and
    # re-raised to the caller.
    with patch('homeassistant.components.recorder._LOGGER.error') as e_mock, \
            pytest.raises(Exception) as err, \
            recorder.session_scope() as session:
        session.execute('select * from notthere')

    assert e_mock.call_count == 1
    assert recorder.ERROR_QUERY[:-4] in e_mock.call_args[0][0]
    assert 'no such table' in str(err.value)
def test_recorder_bad_commit(hass_recorder):
    """Bad _commit should retry 3 times."""
    hass_recorder()

    def work(session):
        """Fail on purpose: the table does not exist."""
        session.execute('select * from notthere')

    # Patch sleep so the retry loop runs instantly, and count the sleeps
    # to verify how many attempts were made.
    with patch('homeassistant.components.recorder.time.sleep') as sleep_mock, \
            recorder.session_scope() as session:
        ok = recorder._INSTANCE._commit(session, work)

    assert ok is False
    assert sleep_mock.call_count == 3
def test_recorder_bad_execute(hass_recorder):
    """Bad execute, retry 3 times."""
    hass_recorder()

    def to_native():
        """Raise exception."""
        from sqlalchemy.exc import SQLAlchemyError
        raise SQLAlchemyError()

    # A mock "row" whose to_native() always fails, forcing execute()
    # into its retry path.
    mck1 = MagicMock()
    mck1.to_native = to_native
    # Patch sleep so the retries run instantly; its call count equals
    # the number of failed attempts.
    with patch('homeassistant.components.recorder.time.sleep') as e_mock:
        res = recorder.execute((mck1,))
    assert res == []
    assert e_mock.call_count == 3