WIP: [component/recorder] Refactoring & better handling of SQLAlchemy Sessions (#5607)
* Refactor recorder and Sessions
* Cover #4352
* NO_reset_on_return
* contextmanager
* coverage

parent bdebe5d53c
commit 490ef6afad
homeassistant/components/history.py
@@ -64,7 +64,7 @@ def get_significant_states(start_time, end_time=None, entity_id=None,
     """
     entity_ids = (entity_id.lower(), ) if entity_id is not None else None
     states = recorder.get_model('States')
-    query = recorder.query('States').filter(
+    query = recorder.query(states).filter(
         (states.domain.in_(SIGNIFICANT_DOMAINS) |
          (states.last_changed == states.last_updated)) &
         (states.last_updated > start_time))

homeassistant/components/recorder/__init__.py
@@ -13,6 +13,7 @@ import threading
 import time
 from datetime import timedelta, datetime
 from typing import Any, Union, Optional, List, Dict
+from contextlib import contextmanager

 import voluptuous as vol

@@ -22,7 +23,7 @@ from homeassistant.const import (
     CONF_INCLUDE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
     EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL)
 import homeassistant.helpers.config_validation as cv
-from homeassistant.helpers.event import track_point_in_utc_time
+from homeassistant.helpers.event import async_track_time_interval
 from homeassistant.helpers.typing import ConfigType, QueryType
 import homeassistant.util.dt as dt_util

@@ -39,6 +40,7 @@ CONF_PURGE_DAYS = 'purge_days'
 RETRIES = 3
 CONNECT_RETRY_WAIT = 10
 QUERY_RETRY_WAIT = 0.1
+ERROR_QUERY = "Error during query: %s"

 CONFIG_SCHEMA = vol.Schema({
     DOMAIN: vol.Schema({
@@ -62,28 +64,43 @@ _INSTANCE = None  # type: Any
 _LOGGER = logging.getLogger(__name__)

 # These classes will be populated during setup()
 # pylint: disable=invalid-name,no-member
-Session = None  # pylint: disable=no-member
+# scoped_session, in the same thread session_scope() stays the same
+_SESSION = None


+@contextmanager
+def session_scope():
+    """Provide a transactional scope around a series of operations."""
+    session = _SESSION()
+    try:
+        yield session
+        session.commit()
+    except Exception as err:  # pylint: disable=broad-except
+        _LOGGER.error(ERROR_QUERY, err)
+        session.rollback()
+        raise
+    finally:
+        session.close()
+
+
 # pylint: disable=invalid-sequence-index
-def execute(q: QueryType) -> List[Any]:
+def execute(qry: QueryType) -> List[Any]:
     """Query the database and convert the objects to HA native form.

     This method also retries a few times in the case of stale connections.
     """
     import sqlalchemy.exc
-    for _ in range(0, RETRIES):
-        try:
-            return [
-                row for row in
-                (row.to_native() for row in q)
-                if row is not None]
-        except sqlalchemy.exc.SQLAlchemyError as e:
-            log_error(e, retry_wait=QUERY_RETRY_WAIT, rollback=True)
-        finally:
-            Session.close()
+    with session_scope() as session:
+        for _ in range(0, RETRIES):
+            try:
+                return [
+                    row for row in
+                    (row.to_native() for row in qry)
+                    if row is not None]
+            except sqlalchemy.exc.SQLAlchemyError as err:
+                _LOGGER.error(ERROR_QUERY, err)
+                session.rollback()
+                time.sleep(QUERY_RETRY_WAIT)
     return []
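The new `session_scope()` is the standard transactional-scope recipe from the SQLAlchemy docs: everything inside the `with` block commits as a unit or rolls back and re-raises. A minimal standalone demonstration of the same semantics (in-memory sqlite, illustrative names, not part of the diff):

```python
# Standalone demo of the session_scope() semantics above (not HA code).
from contextlib import contextmanager

from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker

Base = declarative_base()


class Thing(Base):
    __tablename__ = 'things'
    thing_id = Column(Integer, primary_key=True)


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)
_SESSION = scoped_session(sessionmaker(bind=engine))


@contextmanager
def session_scope():
    """Commit on success, roll back and re-raise on error, always close."""
    session = _SESSION()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


with session_scope() as session:
    session.add(Thing())  # committed once the block exits cleanly

try:
    with session_scope() as session:
        session.add(Thing())
        raise RuntimeError('boom')  # nothing from this block is persisted
except RuntimeError:
    pass

assert _SESSION().query(Thing).count() == 1
```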
@@ -101,9 +118,10 @@ def run_information(point_in_time: Optional[datetime]=None):
             start=_INSTANCE.recording_start,
             closed_incorrect=False)

-    return query('RecorderRuns').filter(
-        (recorder_runs.start < point_in_time) &
-        (recorder_runs.end > point_in_time)).first()
+    with session_scope():
+        return query('RecorderRuns').filter(
+            (recorder_runs.start < point_in_time) &
+            (recorder_runs.end > point_in_time)).first()


 def setup(hass: HomeAssistant, config: ConfigType) -> bool:
@@ -132,10 +150,9 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool:
 def query(model_name: Union[str, Any], *args) -> QueryType:
     """Helper to return a query handle."""
-    _verify_instance()

     if isinstance(model_name, str):
-        return Session.query(get_model(model_name), *args)
-    return Session.query(model_name, *args)
+        return _SESSION().query(get_model(model_name), *args)
+    return _SESSION().query(model_name, *args)


 def get_model(model_name: str) -> Any:
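`query()` now builds its handle directly off the scoped session factory; both call forms end up at the same place. A hypothetical usage sketch (assumes a set-up recorder; names are from this module):

```python
# Hypothetical usage; both forms return a SQLAlchemy Query over States.
from homeassistant.components import recorder

states = recorder.get_model('States')   # model class resolved by name
q1 = recorder.query('States')           # string form, resolved via get_model()
q2 = recorder.query(states)             # model class passed directly
rows = recorder.execute(q1.limit(10))   # retried and converted to HA native objects
```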
@@ -148,22 +165,6 @@ def get_model(model_name: str) -> Any:
     return None


-def log_error(e: Exception, retry_wait: Optional[float]=0,
-              rollback: Optional[bool]=True,
-              message: Optional[str]="Error during query: %s") -> None:
-    """Log about SQLAlchemy errors in a sane manner."""
-    import sqlalchemy.exc
-    if not isinstance(e, sqlalchemy.exc.OperationalError):
-        _LOGGER.exception(str(e))
-    else:
-        _LOGGER.error(message, str(e))
-    if rollback:
-        Session.rollback()
-    if retry_wait:
-        _LOGGER.info("Retrying in %s seconds", retry_wait)
-        time.sleep(retry_wait)
-
-
 class Recorder(threading.Thread):
     """A threaded recorder class."""

@@ -204,18 +205,14 @@
                 self._setup_connection()
                 self._setup_run()
                 break
-            except sqlalchemy.exc.SQLAlchemyError as e:
-                log_error(e, retry_wait=CONNECT_RETRY_WAIT, rollback=False,
-                          message="Error during connection setup: %s")
+            except sqlalchemy.exc.SQLAlchemyError as err:
+                _LOGGER.error("Error during connection setup: %s (retrying "
+                              "in %s seconds)", err, CONNECT_RETRY_WAIT)
+                time.sleep(CONNECT_RETRY_WAIT)

         if self.purge_days is not None:
-            def purge_ticker(event):
-                """Rerun purge every second day."""
-                self._purge_old_data()
-                track_point_in_utc_time(self.hass, purge_ticker,
-                                        dt_util.utcnow() + timedelta(days=2))
-            track_point_in_utc_time(self.hass, purge_ticker,
-                                    dt_util.utcnow() + timedelta(minutes=5))
+            async_track_time_interval(
+                self.hass, self._purge_old_data, timedelta(days=2))

         while True:
             event = self.queue.get()
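Replacing the self-rescheduling `purge_ticker` with `async_track_time_interval` drops the hand-rolled recursion. For context, the pattern the old code implemented looks like this in plain `threading` (purely illustrative, not Home Assistant code):

```python
# Illustrative only: a self-rescheduling timer like the one the helper replaces.
import threading


def track_interval(action, interval_s):
    """Run action every interval_s seconds on daemon timers."""
    def _tick():
        action()
        timer = threading.Timer(interval_s, _tick)
        timer.daemon = True
        timer.start()

    first = threading.Timer(interval_s, _tick)
    first.daemon = True
    first.start()
    return first  # .cancel() only stops the currently pending tick
```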
@@ -250,16 +247,17 @@
                 self.queue.task_done()
                 continue

-            dbevent = Events.from_event(event)
-            self._commit(dbevent)
+            with session_scope() as session:
+                dbevent = Events.from_event(event)
+                self._commit(session, dbevent)

-            if event.event_type != EVENT_STATE_CHANGED:
-                self.queue.task_done()
-                continue
+                if event.event_type != EVENT_STATE_CHANGED:
+                    self.queue.task_done()
+                    continue

-            dbstate = States.from_event(event)
-            dbstate.event_id = dbevent.event_id
-            self._commit(dbstate)
+                dbstate = States.from_event(event)
+                dbstate.event_id = dbevent.event_id
+                self._commit(session, dbstate)

             self.queue.task_done()
@@ -282,11 +280,14 @@

     def block_till_db_ready(self):
         """Block until the database session is ready."""
-        self.db_ready.wait()
+        self.db_ready.wait(10)
+        while not self.db_ready.is_set():
+            _LOGGER.warning('Database not ready, waiting another 10 seconds.')
+            self.db_ready.wait(10)

     def _setup_connection(self):
         """Ensure database is ready to fly."""
-        global Session  # pylint: disable=global-statement
+        global _SESSION  # pylint: disable=invalid-name,global-statement

         import homeassistant.components.recorder.models as models
         from sqlalchemy import create_engine
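The new `block_till_db_ready` swaps an unbounded `Event.wait()` for a bounded wait plus a warning loop, so a stuck database setup becomes visible in the log instead of hanging silently. The same pattern in isolation (minimal sketch, assumed names):

```python
# Sketch: bounded wait with periodic warnings instead of a silent hang.
import logging
import threading

_LOGGER = logging.getLogger(__name__)
db_ready = threading.Event()


def block_till_db_ready():
    """Return once db_ready is set, warning every 10 seconds until then."""
    db_ready.wait(10)
    while not db_ready.is_set():
        _LOGGER.warning('Database not ready, waiting another 10 seconds.')
        db_ready.wait(10)
```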
@@ -298,40 +299,44 @@
             self.engine = create_engine(
                 'sqlite://',
                 connect_args={'check_same_thread': False},
-                poolclass=StaticPool)
+                poolclass=StaticPool,
+                pool_reset_on_return=None)
         else:
             self.engine = create_engine(self.db_url, echo=False)

         models.Base.metadata.create_all(self.engine)
         session_factory = sessionmaker(bind=self.engine)
-        Session = scoped_session(session_factory)
+        _SESSION = scoped_session(session_factory)
         self._migrate_schema()
         self.db_ready.set()

     def _migrate_schema(self):
         """Check if the schema needs to be upgraded."""
-        import homeassistant.components.recorder.models as models
-        schema_changes = models.SchemaChanges
-        current_version = getattr(Session.query(schema_changes).order_by(
-            schema_changes.change_id.desc()).first(), 'schema_version', None)
+        from homeassistant.components.recorder.models import SCHEMA_VERSION
+        schema_changes = get_model('SchemaChanges')
+        with session_scope() as session:
+            res = session.query(schema_changes).order_by(
+                schema_changes.change_id.desc()).first()
+            current_version = getattr(res, 'schema_version', None)

-        if current_version == models.SCHEMA_VERSION:
-            return
-        _LOGGER.debug("Schema version incorrect: %d", current_version)
+            if current_version == SCHEMA_VERSION:
+                return
+            _LOGGER.debug("Schema version incorrect: %s", current_version)

-        if current_version is None:
-            current_version = self._inspect_schema_version()
-            _LOGGER.debug("No schema version found. Inspected version: %d",
-                          current_version)
+            if current_version is None:
+                current_version = self._inspect_schema_version()
+                _LOGGER.debug("No schema version found. Inspected version: %s",
+                              current_version)

-        for version in range(current_version, models.SCHEMA_VERSION):
-            new_version = version + 1
-            _LOGGER.info(
-                "Upgrading recorder db schema to version %d", new_version)
-            self._apply_update(new_version)
-            self._commit(schema_changes(schema_version=new_version))
-            _LOGGER.info(
-                "Upgraded recorder db schema to version %d", new_version)
+            for version in range(current_version, SCHEMA_VERSION):
+                new_version = version + 1
+                _LOGGER.info("Upgrading recorder db schema to version %s",
+                             new_version)
+                self._apply_update(new_version)
+                self._commit(session,
+                             schema_changes(schema_version=new_version))
+                _LOGGER.info("Upgraded recorder db schema to version %s",
+                             new_version)

     def _apply_update(self, new_version):
         """Perform operations to bring schema up to date."""
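Two details in this hunk are worth calling out: `scoped_session` hands each thread its own `Session` (so `session_scope()` on the recorder thread and queries from other threads don't collide), and `pool_reset_on_return=None` stops SQLAlchemy from issuing a rollback when the in-memory sqlite connection is returned to the `StaticPool`. A small sketch of the scoped-session behavior (illustrative, not HA code):

```python
# Sketch: scoped_session returns the same Session within a thread.
import threading

from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import StaticPool

engine = create_engine(
    'sqlite://',
    connect_args={'check_same_thread': False},
    poolclass=StaticPool,
    pool_reset_on_return=None)
_SESSION = scoped_session(sessionmaker(bind=engine))

assert _SESSION() is _SESSION()  # same thread -> same session object


def other_thread():
    # A different thread gets its own session from the registry.
    assert _SESSION() is not None
    _SESSION.remove()


t = threading.Thread(target=other_thread)
t.start()
t.join()
```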
@@ -368,51 +373,54 @@
         import homeassistant.components.recorder.models as models
         inspector = reflection.Inspector.from_engine(self.engine)
         indexes = inspector.get_indexes("events")
-        for index in indexes:
-            if index['column_names'] == ["time_fired"]:
-                # Schema addition from version 1 detected. This is a new db.
-                current_version = models.SchemaChanges(
-                    schema_version=models.SCHEMA_VERSION)
-                self._commit(current_version)
-                return models.SCHEMA_VERSION
+        with session_scope() as session:
+            for index in indexes:
+                if index['column_names'] == ["time_fired"]:
+                    # Schema addition from version 1 detected. New DB.
+                    current_version = models.SchemaChanges(
+                        schema_version=models.SCHEMA_VERSION)
+                    self._commit(session, current_version)
+                    return models.SCHEMA_VERSION

-        # Version 1 schema changes not found, this db needs to be migrated.
-        current_version = models.SchemaChanges(schema_version=0)
-        self._commit(current_version)
-        return current_version.schema_version
+            # Version 1 schema changes not found, this db needs to be migrated.
+            current_version = models.SchemaChanges(schema_version=0)
+            self._commit(session, current_version)
+            return current_version.schema_version

     def _close_connection(self):
         """Close the connection."""
-        global Session  # pylint: disable=global-statement
+        global _SESSION  # pylint: disable=invalid-name,global-statement
         self.engine.dispose()
         self.engine = None
-        Session = None
+        _SESSION = None

     def _setup_run(self):
         """Log the start of the current run."""
         recorder_runs = get_model('RecorderRuns')
-        for run in query('RecorderRuns').filter_by(end=None):
-            run.closed_incorrect = True
-            run.end = self.recording_start
-            _LOGGER.warning("Ended unfinished session (id=%s from %s)",
-                            run.run_id, run.start)
-            Session.add(run)
+        with session_scope() as session:
+            for run in query('RecorderRuns').filter_by(end=None):
+                run.closed_incorrect = True
+                run.end = self.recording_start
+                _LOGGER.warning("Ended unfinished session (id=%s from %s)",
+                                run.run_id, run.start)
+                session.add(run)

-            _LOGGER.warning("Found unfinished sessions")
+                _LOGGER.warning("Found unfinished sessions")

-        self._run = recorder_runs(
-            start=self.recording_start,
-            created=dt_util.utcnow()
-        )
-        self._commit(self._run)
+            self._run = recorder_runs(
+                start=self.recording_start,
+                created=dt_util.utcnow()
+            )
+            self._commit(session, self._run)

     def _close_run(self):
         """Save end time for current run."""
         self._run.end = dt_util.utcnow()
-        self._commit(self._run)
+        with session_scope() as session:
+            self._commit(session, self._run)
         self._run = None

-    def _purge_old_data(self):
+    def _purge_old_data(self, _=None):
         """Purge events and states older than purge_days ago."""
         from homeassistant.components.recorder.models import Events, States
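`_inspect_schema_version` (top of this hunk) keys off an index that only schema version 1 creates. A sketch of the reflection call it relies on (SQLAlchemy API of this era; the sqlite URL here is illustrative):

```python
# Sketch: listing indexes on the events table via SQLAlchemy reflection.
from sqlalchemy import create_engine
from sqlalchemy.engine import reflection

engine = create_engine('sqlite://')
inspector = reflection.Inspector.from_engine(engine)
for index in inspector.get_indexes("events"):
    # Each entry is a dict with 'name', 'column_names', and 'unique'.
    print(index['name'], index['column_names'])
```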
@@ -429,8 +437,9 @@
                 .delete(synchronize_session=False)
             _LOGGER.debug("Deleted %s states", deleted_rows)

-        if self._commit(_purge_states):
-            _LOGGER.info("Purged states created before %s", purge_before)
+        with session_scope() as session:
+            if self._commit(session, _purge_states):
+                _LOGGER.info("Purged states created before %s", purge_before)

         def _purge_events(session):
             deleted_rows = session.query(Events) \
@@ -438,10 +447,9 @@
                 .delete(synchronize_session=False)
             _LOGGER.debug("Deleted %s events", deleted_rows)

-        if self._commit(_purge_events):
-            _LOGGER.info("Purged events created before %s", purge_before)
-
-        Session.expire_all()
+        with session_scope() as session:
+            if self._commit(session, _purge_events):
+                _LOGGER.info("Purged events created before %s", purge_before)

         # Execute sqlite vacuum command to free up space on disk
         if self.engine.driver == 'sqlite':
@@ -449,10 +457,9 @@
             self.engine.execute("VACUUM")

     @staticmethod
-    def _commit(work):
+    def _commit(session, work):
         """Commit & retry work: Either a model or in a function."""
         import sqlalchemy.exc
-        session = Session()
         for _ in range(0, RETRIES):
             try:
                 if callable(work):
@@ -461,8 +468,10 @@
                     session.add(work)
                 session.commit()
                 return True
-            except sqlalchemy.exc.OperationalError as e:
-                log_error(e, retry_wait=QUERY_RETRY_WAIT, rollback=True)
+            except sqlalchemy.exc.OperationalError as err:
+                _LOGGER.error(ERROR_QUERY, err)
+                session.rollback()
+                time.sleep(QUERY_RETRY_WAIT)
         return False
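For review convenience, here is `_commit` as it reads after this commit, reassembled from the two hunks above (the `callable(work)` branch and indentation are inferred from the surrounding context lines):

```python
    @staticmethod
    def _commit(session, work):
        """Commit & retry work: Either a model or in a function."""
        import sqlalchemy.exc
        for _ in range(0, RETRIES):
            try:
                if callable(work):
                    work(session)
                else:
                    session.add(work)
                session.commit()
                return True
            except sqlalchemy.exc.OperationalError as err:
                _LOGGER.error(ERROR_QUERY, err)
                session.rollback()
                time.sleep(QUERY_RETRY_WAIT)
        return False
```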

tests/components/test_recorder.py
@@ -3,7 +3,7 @@
 import json
 from datetime import datetime, timedelta
 import unittest
-from unittest.mock import patch, call
+from unittest.mock import patch, call, MagicMock

 import pytest

 from homeassistant.core import callback
@@ -24,7 +24,6 @@ class TestRecorder(unittest.TestCase):
             recorder.DOMAIN: {recorder.CONF_DB_URL: db_uri}})
         self.hass.start()
         recorder._verify_instance()
-        self.session = recorder.Session()
         recorder._INSTANCE.block_till_done()

     def tearDown(self):  # pylint: disable=invalid-name
@@ -42,26 +41,25 @@
         self.hass.block_till_done()
         recorder._INSTANCE.block_till_done()

-        for event_id in range(5):
-            if event_id < 3:
-                timestamp = five_days_ago
-                state = 'purgeme'
-            else:
-                timestamp = now
-                state = 'dontpurgeme'
-
-            self.session.add(recorder.get_model('States')(
-                entity_id='test.recorder2',
-                domain='sensor',
-                state=state,
-                attributes=json.dumps(attributes),
-                last_changed=timestamp,
-                last_updated=timestamp,
-                created=timestamp,
-                event_id=event_id + 1000
-            ))
-
-            self.session.commit()
+        with recorder.session_scope() as session:
+            for event_id in range(5):
+                if event_id < 3:
+                    timestamp = five_days_ago
+                    state = 'purgeme'
+                else:
+                    timestamp = now
+                    state = 'dontpurgeme'
+
+                session.add(recorder.get_model('States')(
+                    entity_id='test.recorder2',
+                    domain='sensor',
+                    state=state,
+                    attributes=json.dumps(attributes),
+                    last_changed=timestamp,
+                    last_updated=timestamp,
+                    created=timestamp,
+                    event_id=event_id + 1000
+                ))

     def _add_test_events(self):
         """Add a few events for testing."""
@@ -71,21 +69,23 @@

         self.hass.block_till_done()
         recorder._INSTANCE.block_till_done()
-        for event_id in range(5):
-            if event_id < 2:
-                timestamp = five_days_ago
-                event_type = 'EVENT_TEST_PURGE'
-            else:
-                timestamp = now
-                event_type = 'EVENT_TEST'
-
-            self.session.add(recorder.get_model('Events')(
-                event_type=event_type,
-                event_data=json.dumps(event_data),
-                origin='LOCAL',
-                created=timestamp,
-                time_fired=timestamp,
-            ))
+        with recorder.session_scope() as session:
+            for event_id in range(5):
+                if event_id < 2:
+                    timestamp = five_days_ago
+                    event_type = 'EVENT_TEST_PURGE'
+                else:
+                    timestamp = now
+                    event_type = 'EVENT_TEST'
+
+                session.add(recorder.get_model('Events')(
+                    event_type=event_type,
+                    event_data=json.dumps(event_data),
+                    origin='LOCAL',
+                    created=timestamp,
+                    time_fired=timestamp,
+                ))

     def test_saving_state(self):
         """Test saving and restoring a state."""
@@ -205,14 +205,15 @@
         with self.assertRaises(ValueError):
             recorder._INSTANCE._apply_update(-1)

-    def test_schema_update_calls(self):
+    def test_schema_update_calls(self):  # pylint: disable=no-self-use
         """Test that schema migrations occur in correct order."""
         test_version = recorder.models.SchemaChanges(schema_version=0)
-        self.session.add(test_version)
-        with patch.object(recorder._INSTANCE, '_apply_update') as update:
-            recorder._INSTANCE._migrate_schema()
-        update.assert_has_calls([call(version+1) for version in range(
-            0, recorder.models.SCHEMA_VERSION)])
+        with recorder.session_scope() as session:
+            session.add(test_version)
+            with patch.object(recorder._INSTANCE, '_apply_update') as update:
+                recorder._INSTANCE._migrate_schema()
+            update.assert_has_calls([call(version+1) for version in range(
+                0, recorder.models.SCHEMA_VERSION)])


 @pytest.fixture
@@ -220,7 +221,7 @@ def hass_recorder():
     """HASS fixture with in-memory recorder."""
     hass = get_test_home_assistant()

-    def setup_recorder(config):
+    def setup_recorder(config={}):
         """Setup with params."""
         db_uri = 'sqlite://'  # In memory DB
         conf = {recorder.CONF_DB_URL: db_uri}
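One reviewer-style note on this hunk (not part of the diff): a mutable default argument like `config={}` is shared across calls; the conventional form avoids that:

```python
# Safer equivalent of setup_recorder(config={}); config is fresh per call.
def setup_recorder(config=None):
    """Setup with params."""
    if config is None:
        config = {}
```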
@@ -301,3 +302,61 @@
     assert len(states) == 1
     assert hass.states.get('test.ok') == states[0]
     assert hass.states.get('test.ok').state == 'state2'
+
+
+def test_recorder_errors_exceptions(hass_recorder): \
+        # pylint: disable=redefined-outer-name
+    """Test session_scope and get_model errors."""
+    # Model cannot be resolved
+    assert recorder.get_model('dont-exist') is None
+
+    # Verify the instance fails before setup
+    with pytest.raises(RuntimeError):
+        recorder._verify_instance()
+
+    # Setup the recorder
+    hass_recorder()
+
+    recorder._verify_instance()
+
+    # Verify session scope raises (and prints) an exception
+    with patch('homeassistant.components.recorder._LOGGER.error') as e_mock, \
+            pytest.raises(Exception) as err:
+        with recorder.session_scope() as session:
+            session.execute('select * from notthere')
+    assert e_mock.call_count == 1
+    assert recorder.ERROR_QUERY[:-4] in e_mock.call_args[0][0]
+    assert 'no such table' in str(err.value)
+
+
+def test_recorder_bad_commit(hass_recorder):
+    """Bad _commit should retry 3 times."""
+    hass_recorder()
+
+    def work(session):
+        """Bad work."""
+        session.execute('select * from notthere')
+
+    with patch('homeassistant.components.recorder.time.sleep') as e_mock, \
+            recorder.session_scope() as session:
+        res = recorder._INSTANCE._commit(session, work)
+    assert res is False
+    assert e_mock.call_count == 3
+
+
+def test_recorder_bad_execute(hass_recorder):
+    """Bad execute, retry 3 times."""
+    hass_recorder()
+
+    def to_native():
+        """Raise exception."""
+        from sqlalchemy.exc import SQLAlchemyError
+        raise SQLAlchemyError()
+
+    mck1 = MagicMock()
+    mck1.to_native = to_native
+
+    with patch('homeassistant.components.recorder.time.sleep') as e_mock:
+        res = recorder.execute((mck1,))
+    assert res == []
+    assert e_mock.call_count == 3