Restore_state helper to restore entity states from the DB on startup (#4614)
* Restore states * feedback * Remove component move into recorder * space * helper * Address my own comments * Improve test coverage * Add test for light restore statepull/6142/head
parent
2b9fb73032
commit
fdc373f27e
homeassistant
components
tests
components
helpers
|
@ -15,7 +15,6 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.const import (
|
||||
HTTP_BAD_REQUEST, CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE)
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.components import recorder, script
|
||||
from homeassistant.components.frontend import register_built_in_panel
|
||||
|
@ -28,34 +27,22 @@ DOMAIN = 'history'
|
|||
DEPENDENCIES = ['recorder', 'http']
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema({
|
||||
DOMAIN: vol.Schema({
|
||||
CONF_EXCLUDE: vol.Schema({
|
||||
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
|
||||
vol.Optional(CONF_DOMAINS, default=[]):
|
||||
vol.All(cv.ensure_list, [cv.string])
|
||||
}),
|
||||
CONF_INCLUDE: vol.Schema({
|
||||
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
|
||||
vol.Optional(CONF_DOMAINS, default=[]):
|
||||
vol.All(cv.ensure_list, [cv.string])
|
||||
})
|
||||
}),
|
||||
DOMAIN: recorder.FILTER_SCHEMA,
|
||||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
SIGNIFICANT_DOMAINS = ('thermostat', 'climate')
|
||||
IGNORE_DOMAINS = ('zone', 'scene',)
|
||||
|
||||
|
||||
def last_5_states(entity_id):
|
||||
"""Return the last 5 states for entity_id."""
|
||||
entity_id = entity_id.lower()
|
||||
|
||||
states = recorder.get_model('States')
|
||||
return recorder.execute(
|
||||
recorder.query('States').filter(
|
||||
(states.entity_id == entity_id) &
|
||||
(states.last_changed == states.last_updated)
|
||||
).order_by(states.state_id.desc()).limit(5))
|
||||
def last_recorder_run():
|
||||
"""Retireve the last closed recorder run from the DB."""
|
||||
rec_runs = recorder.get_model('RecorderRuns')
|
||||
with recorder.session_scope() as session:
|
||||
res = recorder.query(rec_runs).order_by(rec_runs.end.desc()).first()
|
||||
if res is None:
|
||||
return None
|
||||
session.expunge(res)
|
||||
return res
|
||||
|
||||
|
||||
def get_significant_states(start_time, end_time=None, entity_id=None,
|
||||
|
@ -91,7 +78,7 @@ def get_significant_states(start_time, end_time=None, entity_id=None,
|
|||
def state_changes_during_period(start_time, end_time=None, entity_id=None):
|
||||
"""Return states changes during UTC period start_time - end_time."""
|
||||
states = recorder.get_model('States')
|
||||
query = recorder.query('States').filter(
|
||||
query = recorder.query(states).filter(
|
||||
(states.last_changed == states.last_updated) &
|
||||
(states.last_changed > start_time))
|
||||
|
||||
|
@ -132,7 +119,7 @@ def get_states(utc_point_in_time, entity_ids=None, run=None, filters=None):
|
|||
most_recent_state_ids = most_recent_state_ids.group_by(
|
||||
states.entity_id).subquery()
|
||||
|
||||
query = recorder.query('States').join(most_recent_state_ids, and_(
|
||||
query = recorder.query(states).join(most_recent_state_ids, and_(
|
||||
states.state_id == most_recent_state_ids.c.max_state_id))
|
||||
|
||||
for state in recorder.execute(query):
|
||||
|
@ -185,27 +172,13 @@ def setup(hass, config):
|
|||
filters.included_entities = include[CONF_ENTITIES]
|
||||
filters.included_domains = include[CONF_DOMAINS]
|
||||
|
||||
hass.http.register_view(Last5StatesView)
|
||||
recorder.get_instance()
|
||||
hass.http.register_view(HistoryPeriodView(filters))
|
||||
register_built_in_panel(hass, 'history', 'History', 'mdi:poll-box')
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class Last5StatesView(HomeAssistantView):
|
||||
"""Handle last 5 state view requests."""
|
||||
|
||||
url = '/api/history/entity/{entity_id}/recent_states'
|
||||
name = 'api:history:entity-recent-states'
|
||||
|
||||
@asyncio.coroutine
|
||||
def get(self, request, entity_id):
|
||||
"""Retrieve last 5 states of entity."""
|
||||
result = yield from request.app['hass'].loop.run_in_executor(
|
||||
None, last_5_states, entity_id)
|
||||
return self.json(result)
|
||||
|
||||
|
||||
class HistoryPeriodView(HomeAssistantView):
|
||||
"""Handle history period requests."""
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ from homeassistant.const import (
|
|||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity import ToggleEntity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
|
||||
DOMAIN = 'input_boolean'
|
||||
|
||||
|
@ -139,6 +140,14 @@ class InputBoolean(ToggleEntity):
|
|||
"""Return true if entity is on."""
|
||||
return self._state
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_added_to_hass(self):
|
||||
"""Called when entity about to be added to hass."""
|
||||
state = yield from async_get_last_state(self.hass, self.entity_id)
|
||||
if not state:
|
||||
return
|
||||
self._state = state.state == 'on'
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_turn_on(self, **kwargs):
|
||||
"""Turn the entity on."""
|
||||
|
|
|
@ -22,6 +22,7 @@ from homeassistant.helpers.entity import ToggleEntity
|
|||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.restore_state import async_restore_state
|
||||
import homeassistant.util.color as color_util
|
||||
from homeassistant.util.async import run_callback_threadsafe
|
||||
|
||||
|
@ -126,6 +127,14 @@ PROFILE_SCHEMA = vol.Schema(
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_info(state):
|
||||
"""Extract light parameters from a state object."""
|
||||
params = {key: state.attributes[key] for key in PROP_TO_ATTR
|
||||
if key in state.attributes}
|
||||
params['is_on'] = state.state == STATE_ON
|
||||
return params
|
||||
|
||||
|
||||
def is_on(hass, entity_id=None):
|
||||
"""Return if the lights are on based on the statemachine."""
|
||||
entity_id = entity_id or ENTITY_ID_ALL_LIGHTS
|
||||
|
@ -369,3 +378,9 @@ class Light(ToggleEntity):
|
|||
def supported_features(self):
|
||||
"""Flag supported features."""
|
||||
return 0
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_added_to_hass(self):
|
||||
"""Component added, restore_state using platforms."""
|
||||
if hasattr(self, 'async_restore_state'):
|
||||
yield from async_restore_state(self, extract_info)
|
||||
|
|
|
@ -4,6 +4,7 @@ Demo light platform that implements lights.
|
|||
For more details about this platform, please refer to the documentation
|
||||
https://home-assistant.io/components/demo/
|
||||
"""
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
from homeassistant.components.light import (
|
||||
|
@ -149,3 +150,26 @@ class DemoLight(Light):
|
|||
# As we have disabled polling, we need to inform
|
||||
# Home Assistant about updates in our state ourselves.
|
||||
self.schedule_update_ha_state()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_restore_state(self, is_on, **kwargs):
|
||||
"""Restore the demo state."""
|
||||
self._state = is_on
|
||||
|
||||
if 'brightness' in kwargs:
|
||||
self._brightness = kwargs['brightness']
|
||||
|
||||
if 'color_temp' in kwargs:
|
||||
self._ct = kwargs['color_temp']
|
||||
|
||||
if 'rgb_color' in kwargs:
|
||||
self._rgb = kwargs['rgb_color']
|
||||
|
||||
if 'xy_color' in kwargs:
|
||||
self._xy_color = kwargs['xy_color']
|
||||
|
||||
if 'white_value' in kwargs:
|
||||
self._white = kwargs['white_value']
|
||||
|
||||
if 'effect' in kwargs:
|
||||
self._effect = kwargs['effect']
|
||||
|
|
|
@ -22,6 +22,7 @@ from homeassistant.const import (
|
|||
ATTR_ENTITY_ID, CONF_ENTITIES, CONF_EXCLUDE, CONF_DOMAINS,
|
||||
CONF_INCLUDE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
|
||||
EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL)
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.event import async_track_time_interval
|
||||
from homeassistant.helpers.typing import ConfigType, QueryType
|
||||
|
@ -42,36 +43,35 @@ CONNECT_RETRY_WAIT = 10
|
|||
QUERY_RETRY_WAIT = 0.1
|
||||
ERROR_QUERY = "Error during query: %s"
|
||||
|
||||
FILTER_SCHEMA = vol.Schema({
|
||||
vol.Optional(CONF_EXCLUDE, default={}): vol.Schema({
|
||||
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
|
||||
vol.Optional(CONF_DOMAINS, default=[]):
|
||||
vol.All(cv.ensure_list, [cv.string])
|
||||
}),
|
||||
vol.Optional(CONF_INCLUDE, default={}): vol.Schema({
|
||||
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
|
||||
vol.Optional(CONF_DOMAINS, default=[]):
|
||||
vol.All(cv.ensure_list, [cv.string])
|
||||
})
|
||||
})
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema({
|
||||
DOMAIN: vol.Schema({
|
||||
DOMAIN: FILTER_SCHEMA.extend({
|
||||
vol.Optional(CONF_PURGE_DAYS):
|
||||
vol.All(vol.Coerce(int), vol.Range(min=1)),
|
||||
vol.Optional(CONF_DB_URL): cv.string,
|
||||
vol.Optional(CONF_EXCLUDE, default={}): vol.Schema({
|
||||
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
|
||||
vol.Optional(CONF_DOMAINS, default=[]):
|
||||
vol.All(cv.ensure_list, [cv.string])
|
||||
}),
|
||||
vol.Optional(CONF_INCLUDE, default={}): vol.Schema({
|
||||
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
|
||||
vol.Optional(CONF_DOMAINS, default=[]):
|
||||
vol.All(cv.ensure_list, [cv.string])
|
||||
})
|
||||
})
|
||||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
_INSTANCE = None # type: Any
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# These classes will be populated during setup()
|
||||
# scoped_session, in the same thread session_scope() stays the same
|
||||
_SESSION = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_scope():
|
||||
"""Provide a transactional scope around a series of operations."""
|
||||
session = _SESSION()
|
||||
session = _INSTANCE.get_session()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
|
@ -83,15 +83,28 @@ def session_scope():
|
|||
session.close()
|
||||
|
||||
|
||||
def get_instance() -> None:
|
||||
"""Throw error if recorder not initialized."""
|
||||
if _INSTANCE is None:
|
||||
raise RuntimeError("Recorder not initialized.")
|
||||
|
||||
ident = _INSTANCE.hass.loop.__dict__.get("_thread_ident")
|
||||
if ident is not None and ident == threading.get_ident():
|
||||
raise RuntimeError('Cannot be called from within the event loop')
|
||||
|
||||
_wait(_INSTANCE.db_ready, "Database not ready")
|
||||
|
||||
return _INSTANCE
|
||||
|
||||
|
||||
# pylint: disable=invalid-sequence-index
|
||||
def execute(qry: QueryType) -> List[Any]:
|
||||
"""Query the database and convert the objects to HA native form.
|
||||
|
||||
This method also retries a few times in the case of stale connections.
|
||||
"""
|
||||
_verify_instance()
|
||||
|
||||
import sqlalchemy.exc
|
||||
get_instance()
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
with session_scope() as session:
|
||||
for _ in range(0, RETRIES):
|
||||
try:
|
||||
|
@ -99,7 +112,7 @@ def execute(qry: QueryType) -> List[Any]:
|
|||
row for row in
|
||||
(row.to_native() for row in qry)
|
||||
if row is not None]
|
||||
except sqlalchemy.exc.SQLAlchemyError as err:
|
||||
except SQLAlchemyError as err:
|
||||
_LOGGER.error(ERROR_QUERY, err)
|
||||
session.rollback()
|
||||
time.sleep(QUERY_RETRY_WAIT)
|
||||
|
@ -111,13 +124,13 @@ def run_information(point_in_time: Optional[datetime]=None):
|
|||
|
||||
There is also the run that covers point_in_time.
|
||||
"""
|
||||
_verify_instance()
|
||||
ins = get_instance()
|
||||
|
||||
recorder_runs = get_model('RecorderRuns')
|
||||
if point_in_time is None or point_in_time > _INSTANCE.recording_start:
|
||||
if point_in_time is None or point_in_time > ins.recording_start:
|
||||
return recorder_runs(
|
||||
end=None,
|
||||
start=_INSTANCE.recording_start,
|
||||
start=ins.recording_start,
|
||||
closed_incorrect=False)
|
||||
|
||||
with session_scope() as session:
|
||||
|
@ -148,17 +161,19 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
exclude = config.get(DOMAIN, {}).get(CONF_EXCLUDE, {})
|
||||
_INSTANCE = Recorder(hass, purge_days=purge_days, uri=db_url,
|
||||
include=include, exclude=exclude)
|
||||
_INSTANCE.start()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def query(model_name: Union[str, Any], *args) -> QueryType:
|
||||
def query(model_name: Union[str, Any], session=None, *args) -> QueryType:
|
||||
"""Helper to return a query handle."""
|
||||
_verify_instance()
|
||||
if session is None:
|
||||
session = get_instance().get_session()
|
||||
|
||||
if isinstance(model_name, str):
|
||||
return _SESSION().query(get_model(model_name), *args)
|
||||
return _SESSION().query(model_name, *args)
|
||||
return session.query(get_model(model_name), *args)
|
||||
return session.query(model_name, *args)
|
||||
|
||||
|
||||
def get_model(model_name: str) -> Any:
|
||||
|
@ -185,6 +200,7 @@ class Recorder(threading.Thread):
|
|||
self.recording_start = dt_util.utcnow()
|
||||
self.db_url = uri
|
||||
self.db_ready = threading.Event()
|
||||
self.start_recording = threading.Event()
|
||||
self.engine = None # type: Any
|
||||
self._run = None # type: Any
|
||||
|
||||
|
@ -195,23 +211,26 @@ class Recorder(threading.Thread):
|
|||
|
||||
def start_recording(event):
|
||||
"""Start recording."""
|
||||
self.start()
|
||||
self.start_recording.set()
|
||||
|
||||
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_recording)
|
||||
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, self.shutdown)
|
||||
hass.bus.listen(MATCH_ALL, self.event_listener)
|
||||
|
||||
self.get_session = None
|
||||
|
||||
def run(self):
|
||||
"""Start processing events to save."""
|
||||
from homeassistant.components.recorder.models import Events, States
|
||||
import sqlalchemy.exc
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
while True:
|
||||
try:
|
||||
self._setup_connection()
|
||||
self._setup_run()
|
||||
self.db_ready.set()
|
||||
break
|
||||
except sqlalchemy.exc.SQLAlchemyError as err:
|
||||
except SQLAlchemyError as err:
|
||||
_LOGGER.error("Error during connection setup: %s (retrying "
|
||||
"in %s seconds)", err, CONNECT_RETRY_WAIT)
|
||||
time.sleep(CONNECT_RETRY_WAIT)
|
||||
|
@ -220,6 +239,8 @@ class Recorder(threading.Thread):
|
|||
async_track_time_interval(
|
||||
self.hass, self._purge_old_data, timedelta(days=2))
|
||||
|
||||
_wait(self.start_recording, "Waiting to start recording")
|
||||
|
||||
while True:
|
||||
event = self.queue.get()
|
||||
|
||||
|
@ -275,10 +296,9 @@ class Recorder(threading.Thread):
|
|||
def shutdown(self, event):
|
||||
"""Tell the recorder to shut down."""
|
||||
global _INSTANCE # pylint: disable=global-statement
|
||||
_INSTANCE = None
|
||||
|
||||
self.queue.put(None)
|
||||
self.join()
|
||||
_INSTANCE = None
|
||||
|
||||
def block_till_done(self):
|
||||
"""Block till all events processed."""
|
||||
|
@ -286,15 +306,10 @@ class Recorder(threading.Thread):
|
|||
|
||||
def block_till_db_ready(self):
|
||||
"""Block until the database session is ready."""
|
||||
self.db_ready.wait(10)
|
||||
while not self.db_ready.is_set():
|
||||
_LOGGER.warning('Database not ready, waiting another 10 seconds.')
|
||||
self.db_ready.wait(10)
|
||||
_wait(self.db_ready, "Database not ready")
|
||||
|
||||
def _setup_connection(self):
|
||||
"""Ensure database is ready to fly."""
|
||||
global _SESSION # pylint: disable=invalid-name,global-statement
|
||||
|
||||
import homeassistant.components.recorder.models as models
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import scoped_session
|
||||
|
@ -312,9 +327,8 @@ class Recorder(threading.Thread):
|
|||
|
||||
models.Base.metadata.create_all(self.engine)
|
||||
session_factory = sessionmaker(bind=self.engine)
|
||||
_SESSION = scoped_session(session_factory)
|
||||
self.get_session = scoped_session(session_factory)
|
||||
self._migrate_schema()
|
||||
self.db_ready.set()
|
||||
|
||||
def _migrate_schema(self):
|
||||
"""Check if the schema needs to be upgraded."""
|
||||
|
@ -396,16 +410,16 @@ class Recorder(threading.Thread):
|
|||
|
||||
def _close_connection(self):
|
||||
"""Close the connection."""
|
||||
global _SESSION # pylint: disable=invalid-name,global-statement
|
||||
self.engine.dispose()
|
||||
self.engine = None
|
||||
_SESSION = None
|
||||
self.get_session = None
|
||||
|
||||
def _setup_run(self):
|
||||
"""Log the start of the current run."""
|
||||
recorder_runs = get_model('RecorderRuns')
|
||||
with session_scope() as session:
|
||||
for run in query('RecorderRuns').filter_by(end=None):
|
||||
for run in query(
|
||||
recorder_runs, session=session).filter_by(end=None):
|
||||
run.closed_incorrect = True
|
||||
run.end = self.recording_start
|
||||
_LOGGER.warning("Ended unfinished session (id=%s from %s)",
|
||||
|
@ -482,13 +496,13 @@ class Recorder(threading.Thread):
|
|||
return False
|
||||
|
||||
|
||||
def _verify_instance() -> None:
|
||||
"""Throw error if recorder not initialized."""
|
||||
if _INSTANCE is None:
|
||||
raise RuntimeError("Recorder not initialized.")
|
||||
|
||||
ident = _INSTANCE.hass.loop.__dict__.get("_thread_ident")
|
||||
if ident is not None and ident == threading.get_ident():
|
||||
raise RuntimeError('Cannot be called from within the event loop')
|
||||
|
||||
_INSTANCE.block_till_db_ready()
|
||||
def _wait(event, message):
|
||||
"""Event wait helper."""
|
||||
for retry in (10, 20, 30):
|
||||
event.wait(10)
|
||||
if event.is_set():
|
||||
return
|
||||
msg = message + " ({} seconds)".format(retry)
|
||||
_LOGGER.warning(msg)
|
||||
if not event.is_set():
|
||||
raise HomeAssistantError(msg)
|
||||
|
|
|
@ -199,7 +199,7 @@ class HistoryStatsSensor(Entity):
|
|||
if self._start is not None:
|
||||
try:
|
||||
start_rendered = self._start.render()
|
||||
except TemplateError as ex:
|
||||
except (TemplateError, TypeError) as ex:
|
||||
HistoryStatsHelper.handle_template_exception(ex, 'start')
|
||||
return
|
||||
start = dt_util.parse_datetime(start_rendered)
|
||||
|
@ -216,7 +216,7 @@ class HistoryStatsSensor(Entity):
|
|||
if self._end is not None:
|
||||
try:
|
||||
end_rendered = self._end.render()
|
||||
except TemplateError as ex:
|
||||
except (TemplateError, TypeError) as ex:
|
||||
HistoryStatsHelper.handle_template_exception(ex, 'end')
|
||||
return
|
||||
end = dt_util.parse_datetime(end_rendered)
|
||||
|
|
|
@ -288,7 +288,7 @@ class Entity(object):
|
|||
self.hass.add_job(self.async_update_ha_state(force_refresh))
|
||||
|
||||
def remove(self) -> None:
|
||||
"""Remove entitiy from HASS."""
|
||||
"""Remove entity from HASS."""
|
||||
run_coroutine_threadsafe(
|
||||
self.async_remove(), self.hass.loop
|
||||
).result()
|
||||
|
|
|
@ -202,6 +202,10 @@ class EntityComponent(object):
|
|||
'Invalid entity id: {}'.format(entity.entity_id))
|
||||
|
||||
self.entities[entity.entity_id] = entity
|
||||
|
||||
if hasattr(entity, 'async_added_to_hass'):
|
||||
yield from entity.async_added_to_hass()
|
||||
|
||||
yield from entity.async_update_ha_state()
|
||||
|
||||
return True
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
"""Support for restoring entity states on startup."""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from homeassistant.core import HomeAssistant, CoreState, callback
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_START
|
||||
from homeassistant.components.history import get_states, last_recorder_run
|
||||
from homeassistant.components.recorder import DOMAIN as _RECORDER
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DATA_RESTORE_CACHE = 'restore_state_cache'
|
||||
_LOCK = 'restore_lock'
|
||||
|
||||
|
||||
def _load_restore_cache(hass: HomeAssistant):
|
||||
"""Load the restore cache to be used by other components."""
|
||||
@callback
|
||||
def remove_cache(event):
|
||||
"""Remove the states cache."""
|
||||
hass.data.pop(DATA_RESTORE_CACHE, None)
|
||||
|
||||
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, remove_cache)
|
||||
|
||||
last_run = last_recorder_run()
|
||||
|
||||
if last_run is None or last_run.end is None:
|
||||
_LOGGER.debug('Not creating cache - no suitable last run found: %s',
|
||||
last_run)
|
||||
hass.data[DATA_RESTORE_CACHE] = {}
|
||||
return
|
||||
|
||||
last_end_time = last_run.end - timedelta(seconds=1)
|
||||
# Unfortunately the recorder_run model do not return offset-aware time
|
||||
last_end_time = last_end_time.replace(tzinfo=dt_util.UTC)
|
||||
_LOGGER.debug("Last run: %s - %s", last_run.start, last_end_time)
|
||||
|
||||
states = get_states(last_end_time, run=last_run)
|
||||
|
||||
# Cache the states
|
||||
hass.data[DATA_RESTORE_CACHE] = {
|
||||
state.entity_id: state for state in states}
|
||||
_LOGGER.debug('Created cache with %s', list(hass.data[DATA_RESTORE_CACHE]))
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_get_last_state(hass, entity_id: str):
|
||||
"""Helper to restore state."""
|
||||
if (_RECORDER not in hass.config.components or
|
||||
hass.state != CoreState.starting):
|
||||
return None
|
||||
|
||||
if DATA_RESTORE_CACHE in hass.data:
|
||||
return hass.data[DATA_RESTORE_CACHE].get(entity_id)
|
||||
|
||||
if _LOCK not in hass.data:
|
||||
hass.data[_LOCK] = asyncio.Lock(loop=hass.loop)
|
||||
|
||||
with (yield from hass.data[_LOCK]):
|
||||
if DATA_RESTORE_CACHE not in hass.data:
|
||||
yield from hass.loop.run_in_executor(
|
||||
None, _load_restore_cache, hass)
|
||||
|
||||
return hass.data[DATA_RESTORE_CACHE].get(entity_id)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_restore_state(entity, extract_info):
|
||||
"""Helper to call entity.async_restore_state with cached info."""
|
||||
if entity.hass.state != CoreState.starting:
|
||||
_LOGGER.debug("Not restoring state: State is not starting: %s",
|
||||
entity.hass.state)
|
||||
return
|
||||
|
||||
state = yield from async_get_last_state(entity.hass, entity.entity_id)
|
||||
|
||||
if not state:
|
||||
return
|
||||
|
||||
yield from entity.async_restore_state(**extract_info(state))
|
|
@ -197,8 +197,8 @@ def load_order_components(components: Sequence[str]) -> OrderedSet:
|
|||
load_order.update(comp_load_order)
|
||||
|
||||
# Push some to first place in load order
|
||||
for comp in ('mqtt_eventstream', 'mqtt', 'logger',
|
||||
'recorder', 'introduction'):
|
||||
for comp in ('mqtt_eventstream', 'mqtt', 'recorder',
|
||||
'introduction', 'logger'):
|
||||
if comp in load_order:
|
||||
load_order.promote(comp)
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ from homeassistant.const import (
|
|||
STATE_ON, STATE_OFF, DEVICE_DEFAULT_NAME, EVENT_TIME_CHANGED,
|
||||
EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE,
|
||||
ATTR_DISCOVERED, SERVER_PORT)
|
||||
from homeassistant.components import sun, mqtt
|
||||
from homeassistant.components import sun, mqtt, recorder
|
||||
from homeassistant.components.http.auth import auth_middleware
|
||||
from homeassistant.components.http.const import (
|
||||
KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED, KEY_TRUSTED_NETWORKS)
|
||||
|
@ -452,3 +452,31 @@ def assert_setup_component(count, domain=None):
|
|||
res_len = 0 if res is None else len(res)
|
||||
assert res_len == count, 'setup_component failed, expected {} got {}: {}' \
|
||||
.format(count, res_len, res)
|
||||
|
||||
|
||||
def init_recorder_component(hass, add_config=None, db_ready_callback=None):
|
||||
"""Initialize the recorder."""
|
||||
config = dict(add_config) if add_config else {}
|
||||
config[recorder.CONF_DB_URL] = 'sqlite://' # In memory DB
|
||||
|
||||
saved_recorder = recorder.Recorder
|
||||
|
||||
class Recorder2(saved_recorder):
|
||||
"""Recorder with a callback after db_ready."""
|
||||
|
||||
def _setup_connection(self):
|
||||
"""Setup the connection and run the callback."""
|
||||
super(Recorder2, self)._setup_connection()
|
||||
if db_ready_callback:
|
||||
_LOGGER.debug('db_ready_callback start (db_ready not set,'
|
||||
'never use get_instance in the callback)')
|
||||
db_ready_callback()
|
||||
_LOGGER.debug('db_ready_callback completed')
|
||||
|
||||
with patch('homeassistant.components.recorder.Recorder',
|
||||
side_effect=Recorder2):
|
||||
assert setup_component(hass, recorder.DOMAIN,
|
||||
{recorder.DOMAIN: config})
|
||||
assert recorder.DOMAIN in hass.config.components
|
||||
recorder.get_instance().block_till_db_ready()
|
||||
_LOGGER.info("In-memory recorder successfully started")
|
||||
|
|
|
@ -1,17 +1,20 @@
|
|||
"""The tests for the demo light component."""
|
||||
# pylint: disable=protected-access
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from homeassistant.bootstrap import setup_component
|
||||
from homeassistant.core import State, CoreState
|
||||
from homeassistant.bootstrap import setup_component, async_setup_component
|
||||
import homeassistant.components.light as light
|
||||
from homeassistant.helpers.restore_state import DATA_RESTORE_CACHE
|
||||
|
||||
from tests.common import get_test_home_assistant
|
||||
|
||||
ENTITY_LIGHT = 'light.bed_light'
|
||||
|
||||
|
||||
class TestDemoClimate(unittest.TestCase):
|
||||
"""Test the demo climate hvac."""
|
||||
class TestDemoLight(unittest.TestCase):
|
||||
"""Test the demo light."""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def setUp(self):
|
||||
|
@ -60,3 +63,36 @@ class TestDemoClimate(unittest.TestCase):
|
|||
light.turn_off(self.hass, ENTITY_LIGHT)
|
||||
self.hass.block_till_done()
|
||||
self.assertFalse(light.is_on(self.hass, ENTITY_LIGHT))
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_restore_state(hass):
|
||||
"""Test state gets restored."""
|
||||
hass.config.components.add('recorder')
|
||||
hass.state = CoreState.starting
|
||||
hass.data[DATA_RESTORE_CACHE] = {
|
||||
'light.bed_light': State('light.bed_light', 'on', {
|
||||
'brightness': 'value-brightness',
|
||||
'color_temp': 'value-color_temp',
|
||||
'rgb_color': 'value-rgb_color',
|
||||
'xy_color': 'value-xy_color',
|
||||
'white_value': 'value-white_value',
|
||||
'effect': 'value-effect',
|
||||
}),
|
||||
}
|
||||
|
||||
yield from async_setup_component(hass, 'light', {
|
||||
'light': {
|
||||
'platform': 'demo',
|
||||
}})
|
||||
|
||||
state = hass.states.get('light.bed_light')
|
||||
assert state is not None
|
||||
assert state.entity_id == 'light.bed_light'
|
||||
assert state.state == 'on'
|
||||
assert state.attributes.get('brightness') == 'value-brightness'
|
||||
assert state.attributes.get('color_temp') == 'value-color_temp'
|
||||
assert state.attributes.get('rgb_color') == 'value-rgb_color'
|
||||
assert state.attributes.get('xy_color') == 'value-xy_color'
|
||||
assert state.attributes.get('white_value') == 'value-white_value'
|
||||
assert state.attributes.get('effect') == 'value-effect'
|
||||
|
|
|
@ -11,8 +11,7 @@ from sqlalchemy import create_engine
|
|||
from homeassistant.core import callback
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.components import recorder
|
||||
from homeassistant.bootstrap import setup_component
|
||||
from tests.common import get_test_home_assistant
|
||||
from tests.common import get_test_home_assistant, init_recorder_component
|
||||
from tests.components.recorder import models_original
|
||||
|
||||
|
||||
|
@ -22,18 +21,15 @@ class BaseTestRecorder(unittest.TestCase):
|
|||
def setUp(self): # pylint: disable=invalid-name
|
||||
"""Setup things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
db_uri = 'sqlite://' # In memory DB
|
||||
setup_component(self.hass, recorder.DOMAIN, {
|
||||
recorder.DOMAIN: {recorder.CONF_DB_URL: db_uri}})
|
||||
init_recorder_component(self.hass)
|
||||
self.hass.start()
|
||||
recorder._verify_instance()
|
||||
recorder._INSTANCE.block_till_done()
|
||||
recorder.get_instance().block_till_done()
|
||||
|
||||
def tearDown(self): # pylint: disable=invalid-name
|
||||
"""Stop everything that was started."""
|
||||
recorder._INSTANCE.shutdown(None)
|
||||
self.hass.stop()
|
||||
assert recorder._INSTANCE is None
|
||||
with self.assertRaises(RuntimeError):
|
||||
recorder.get_instance()
|
||||
|
||||
def _add_test_states(self):
|
||||
"""Add multiple states to the db for testing."""
|
||||
|
@ -228,7 +224,7 @@ class TestMigrateRecorder(BaseTestRecorder):
|
|||
|
||||
@patch('sqlalchemy.create_engine', new=create_engine_test)
|
||||
@patch('homeassistant.components.recorder.Recorder._migrate_schema')
|
||||
def setUp(self, migrate): # pylint: disable=invalid-name
|
||||
def setUp(self, migrate): # pylint: disable=invalid-name,arguments-differ
|
||||
"""Setup things to be run when tests are started.
|
||||
|
||||
create_engine is patched to create a db that starts with the old
|
||||
|
@ -261,16 +257,12 @@ def hass_recorder():
|
|||
"""HASS fixture with in-memory recorder."""
|
||||
hass = get_test_home_assistant()
|
||||
|
||||
def setup_recorder(config={}):
|
||||
def setup_recorder(config=None):
|
||||
"""Setup with params."""
|
||||
db_uri = 'sqlite://' # In memory DB
|
||||
conf = {recorder.CONF_DB_URL: db_uri}
|
||||
conf.update(config)
|
||||
assert setup_component(hass, recorder.DOMAIN, {recorder.DOMAIN: conf})
|
||||
init_recorder_component(hass, config)
|
||||
hass.start()
|
||||
hass.block_till_done()
|
||||
recorder._verify_instance()
|
||||
recorder._INSTANCE.block_till_done()
|
||||
recorder.get_instance().block_till_done()
|
||||
return hass
|
||||
|
||||
yield setup_recorder
|
||||
|
@ -352,12 +344,12 @@ def test_recorder_errors_exceptions(hass_recorder): \
|
|||
|
||||
# Verify the instance fails before setup
|
||||
with pytest.raises(RuntimeError):
|
||||
recorder._verify_instance()
|
||||
recorder.get_instance()
|
||||
|
||||
# Setup the recorder
|
||||
hass_recorder()
|
||||
|
||||
recorder._verify_instance()
|
||||
recorder.get_instance()
|
||||
|
||||
# Verify session scope raises (and prints) an exception
|
||||
with patch('homeassistant.components.recorder._LOGGER.error') as e_mock, \
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
"""The test for the History Statistics sensor platform."""
|
||||
# pylint: disable=protected-access
|
||||
import unittest
|
||||
from datetime import timedelta
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import homeassistant.components.recorder as recorder
|
||||
import homeassistant.core as ha
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.bootstrap import setup_component
|
||||
import homeassistant.components.recorder as recorder
|
||||
from homeassistant.components.sensor.history_stats import HistoryStatsSensor
|
||||
import homeassistant.core as ha
|
||||
from homeassistant.helpers.template import Template
|
||||
from tests.common import get_test_home_assistant
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from tests.common import init_recorder_component, get_test_home_assistant
|
||||
|
||||
|
||||
class TestHistoryStatsSensor(unittest.TestCase):
|
||||
|
@ -204,12 +205,8 @@ class TestHistoryStatsSensor(unittest.TestCase):
|
|||
|
||||
def init_recorder(self):
|
||||
"""Initialize the recorder."""
|
||||
db_uri = 'sqlite://'
|
||||
with patch('homeassistant.core.Config.path', return_value=db_uri):
|
||||
setup_component(self.hass, recorder.DOMAIN, {
|
||||
"recorder": {
|
||||
"db_url": db_uri}})
|
||||
init_recorder_component(self.hass)
|
||||
self.hass.start()
|
||||
recorder._INSTANCE.block_till_db_ready()
|
||||
recorder.get_instance().block_till_db_ready()
|
||||
self.hass.block_till_done()
|
||||
recorder._INSTANCE.block_till_done()
|
||||
recorder.get_instance().block_till_done()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""The tests the History component."""
|
||||
# pylint: disable=protected-access
|
||||
# pylint: disable=protected-access,invalid-name
|
||||
from datetime import timedelta
|
||||
import unittest
|
||||
from unittest.mock import patch, sentinel
|
||||
|
@ -10,68 +10,47 @@ import homeassistant.util.dt as dt_util
|
|||
from homeassistant.components import history, recorder
|
||||
|
||||
from tests.common import (
|
||||
mock_http_component, mock_state_change_event, get_test_home_assistant)
|
||||
init_recorder_component, mock_http_component, mock_state_change_event,
|
||||
get_test_home_assistant)
|
||||
|
||||
|
||||
class TestComponentHistory(unittest.TestCase):
|
||||
"""Test History component."""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def setUp(self):
|
||||
def setUp(self): # pylint: disable=invalid-name
|
||||
"""Setup things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def tearDown(self):
|
||||
def tearDown(self): # pylint: disable=invalid-name
|
||||
"""Stop everything that was started."""
|
||||
self.hass.stop()
|
||||
|
||||
def init_recorder(self):
|
||||
"""Initialize the recorder."""
|
||||
db_uri = 'sqlite://'
|
||||
with patch('homeassistant.core.Config.path', return_value=db_uri):
|
||||
setup_component(self.hass, recorder.DOMAIN, {
|
||||
"recorder": {
|
||||
"db_url": db_uri}})
|
||||
init_recorder_component(self.hass)
|
||||
self.hass.start()
|
||||
recorder._INSTANCE.block_till_db_ready()
|
||||
recorder.get_instance().block_till_db_ready()
|
||||
self.wait_recording_done()
|
||||
|
||||
def wait_recording_done(self):
|
||||
"""Block till recording is done."""
|
||||
self.hass.block_till_done()
|
||||
recorder._INSTANCE.block_till_done()
|
||||
recorder.get_instance().block_till_done()
|
||||
|
||||
def test_setup(self):
|
||||
"""Test setup method of history."""
|
||||
mock_http_component(self.hass)
|
||||
config = history.CONFIG_SCHEMA({
|
||||
ha.DOMAIN: {},
|
||||
history.DOMAIN: {history.CONF_INCLUDE: {
|
||||
# ha.DOMAIN: {},
|
||||
history.DOMAIN: {
|
||||
history.CONF_INCLUDE: {
|
||||
history.CONF_DOMAINS: ['media_player'],
|
||||
history.CONF_ENTITIES: ['thermostat.test']},
|
||||
history.CONF_EXCLUDE: {
|
||||
history.CONF_DOMAINS: ['thermostat'],
|
||||
history.CONF_ENTITIES: ['media_player.test']}}})
|
||||
self.assertTrue(setup_component(self.hass, history.DOMAIN, config))
|
||||
|
||||
def test_last_5_states(self):
|
||||
"""Test retrieving the last 5 states."""
|
||||
self.init_recorder()
|
||||
states = []
|
||||
|
||||
entity_id = 'test.last_5_states'
|
||||
|
||||
for i in range(7):
|
||||
self.hass.states.set(entity_id, "State {}".format(i))
|
||||
|
||||
self.wait_recording_done()
|
||||
|
||||
if i > 1:
|
||||
states.append(self.hass.states.get(entity_id))
|
||||
|
||||
self.assertEqual(
|
||||
list(reversed(states)), history.last_5_states(entity_id))
|
||||
self.assertTrue(setup_component(self.hass, history.DOMAIN, config))
|
||||
|
||||
def test_get_states(self):
|
||||
"""Test getting states at a specific point in time."""
|
||||
|
@ -121,6 +100,7 @@ class TestComponentHistory(unittest.TestCase):
|
|||
entity_id = 'media_player.test'
|
||||
|
||||
def set_state(state):
|
||||
"""Set the state."""
|
||||
self.hass.states.set(entity_id, state)
|
||||
self.wait_recording_done()
|
||||
return self.hass.states.get(entity_id)
|
||||
|
@ -311,7 +291,8 @@ class TestComponentHistory(unittest.TestCase):
|
|||
|
||||
config = history.CONFIG_SCHEMA({
|
||||
ha.DOMAIN: {},
|
||||
history.DOMAIN: {history.CONF_INCLUDE: {
|
||||
history.DOMAIN: {
|
||||
history.CONF_INCLUDE: {
|
||||
history.CONF_DOMAINS: ['media_player']},
|
||||
history.CONF_EXCLUDE: {
|
||||
history.CONF_DOMAINS: ['media_player']}}})
|
||||
|
@ -332,7 +313,8 @@ class TestComponentHistory(unittest.TestCase):
|
|||
|
||||
config = history.CONFIG_SCHEMA({
|
||||
ha.DOMAIN: {},
|
||||
history.DOMAIN: {history.CONF_INCLUDE: {
|
||||
history.DOMAIN: {
|
||||
history.CONF_INCLUDE: {
|
||||
history.CONF_ENTITIES: ['media_player.test']},
|
||||
history.CONF_EXCLUDE: {
|
||||
history.CONF_ENTITIES: ['media_player.test']}}})
|
||||
|
@ -351,7 +333,8 @@ class TestComponentHistory(unittest.TestCase):
|
|||
|
||||
config = history.CONFIG_SCHEMA({
|
||||
ha.DOMAIN: {},
|
||||
history.DOMAIN: {history.CONF_INCLUDE: {
|
||||
history.DOMAIN: {
|
||||
history.CONF_INCLUDE: {
|
||||
history.CONF_DOMAINS: ['media_player'],
|
||||
history.CONF_ENTITIES: ['thermostat.test']},
|
||||
history.CONF_EXCLUDE: {
|
||||
|
@ -359,7 +342,8 @@ class TestComponentHistory(unittest.TestCase):
|
|||
history.CONF_ENTITIES: ['media_player.test']}}})
|
||||
self.check_significant_states(zero, four, states, config)
|
||||
|
||||
def check_significant_states(self, zero, four, states, config):
|
||||
def check_significant_states(self, zero, four, states, config): \
|
||||
# pylint: disable=no-self-use
|
||||
"""Check if significant states are retrieved."""
|
||||
filters = history.Filters()
|
||||
exclude = config[history.DOMAIN].get(history.CONF_EXCLUDE)
|
||||
|
@ -390,6 +374,7 @@ class TestComponentHistory(unittest.TestCase):
|
|||
script_c = 'script.can_cancel_this_one'
|
||||
|
||||
def set_state(entity_id, state, **kwargs):
|
||||
"""Set the state."""
|
||||
self.hass.states.set(entity_id, state, **kwargs)
|
||||
self.wait_recording_done()
|
||||
return self.hass.states.get(entity_id)
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
"""The tests for the input_boolean component."""
|
||||
# pylint: disable=protected-access
|
||||
import asyncio
|
||||
import unittest
|
||||
import logging
|
||||
|
||||
from tests.common import get_test_home_assistant
|
||||
|
||||
from homeassistant.bootstrap import setup_component
|
||||
from homeassistant.core import CoreState, State
|
||||
from homeassistant.bootstrap import setup_component, async_setup_component
|
||||
from homeassistant.components.input_boolean import (
|
||||
DOMAIN, is_on, toggle, turn_off, turn_on)
|
||||
from homeassistant.const import (
|
||||
STATE_ON, STATE_OFF, ATTR_ICON, ATTR_FRIENDLY_NAME)
|
||||
from homeassistant.helpers.restore_state import DATA_RESTORE_CACHE
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -103,3 +106,30 @@ class TestInputBoolean(unittest.TestCase):
|
|||
self.assertEqual('Hello World',
|
||||
state_2.attributes.get(ATTR_FRIENDLY_NAME))
|
||||
self.assertEqual('mdi:work', state_2.attributes.get(ATTR_ICON))
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_restore_state(hass):
|
||||
"""Ensure states are restored on startup."""
|
||||
hass.data[DATA_RESTORE_CACHE] = {
|
||||
'input_boolean.b1': State('input_boolean.b1', 'on'),
|
||||
'input_boolean.b2': State('input_boolean.b2', 'off'),
|
||||
'input_boolean.b3': State('input_boolean.b3', 'on'),
|
||||
}
|
||||
|
||||
hass.state = CoreState.starting
|
||||
hass.config.components.add('recorder')
|
||||
|
||||
yield from async_setup_component(hass, DOMAIN, {
|
||||
DOMAIN: {
|
||||
'b1': None,
|
||||
'b2': None,
|
||||
}})
|
||||
|
||||
state = hass.states.get('input_boolean.b1')
|
||||
assert state
|
||||
assert state.state == 'on'
|
||||
|
||||
state = hass.states.get('input_boolean.b2')
|
||||
assert state
|
||||
assert state.state == 'off'
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""The tests for the logbook component."""
|
||||
# pylint: disable=protected-access
|
||||
# pylint: disable=protected-access,invalid-name
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
@ -13,7 +14,11 @@ import homeassistant.util.dt as dt_util
|
|||
from homeassistant.components import logbook
|
||||
from homeassistant.bootstrap import setup_component
|
||||
|
||||
from tests.common import mock_http_component, get_test_home_assistant
|
||||
from tests.common import (
|
||||
mock_http_component, init_recorder_component, get_test_home_assistant)
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestComponentLogbook(unittest.TestCase):
|
||||
|
@ -24,12 +29,14 @@ class TestComponentLogbook(unittest.TestCase):
|
|||
def setUp(self):
|
||||
"""Setup things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
init_recorder_component(self.hass) # Force an in memory DB
|
||||
mock_http_component(self.hass)
|
||||
self.hass.config.components |= set(['frontend', 'recorder', 'api'])
|
||||
with patch('homeassistant.components.logbook.'
|
||||
'register_built_in_panel'):
|
||||
assert setup_component(self.hass, logbook.DOMAIN,
|
||||
self.EMPTY_CONFIG)
|
||||
self.hass.start()
|
||||
|
||||
def tearDown(self):
|
||||
"""Stop everything that was started."""
|
||||
|
@ -41,6 +48,7 @@ class TestComponentLogbook(unittest.TestCase):
|
|||
|
||||
@ha.callback
|
||||
def event_listener(event):
|
||||
"""Append on event."""
|
||||
calls.append(event)
|
||||
|
||||
self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener)
|
||||
|
@ -72,6 +80,7 @@ class TestComponentLogbook(unittest.TestCase):
|
|||
|
||||
@ha.callback
|
||||
def event_listener(event):
|
||||
"""Append on event."""
|
||||
calls.append(event)
|
||||
|
||||
self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener)
|
||||
|
@ -242,17 +251,17 @@ class TestComponentLogbook(unittest.TestCase):
|
|||
entity_id2 = 'sensor.blu'
|
||||
|
||||
eventA = ha.Event(logbook.EVENT_LOGBOOK_ENTRY, {
|
||||
logbook.ATTR_NAME: name,
|
||||
logbook.ATTR_MESSAGE: message,
|
||||
logbook.ATTR_DOMAIN: domain,
|
||||
logbook.ATTR_ENTITY_ID: entity_id,
|
||||
})
|
||||
logbook.ATTR_NAME: name,
|
||||
logbook.ATTR_MESSAGE: message,
|
||||
logbook.ATTR_DOMAIN: domain,
|
||||
logbook.ATTR_ENTITY_ID: entity_id,
|
||||
})
|
||||
eventB = ha.Event(logbook.EVENT_LOGBOOK_ENTRY, {
|
||||
logbook.ATTR_NAME: name,
|
||||
logbook.ATTR_MESSAGE: message,
|
||||
logbook.ATTR_DOMAIN: domain,
|
||||
logbook.ATTR_ENTITY_ID: entity_id2,
|
||||
})
|
||||
logbook.ATTR_NAME: name,
|
||||
logbook.ATTR_MESSAGE: message,
|
||||
logbook.ATTR_DOMAIN: domain,
|
||||
logbook.ATTR_ENTITY_ID: entity_id2,
|
||||
})
|
||||
|
||||
config = logbook.CONFIG_SCHEMA({
|
||||
ha.DOMAIN: {},
|
||||
|
@ -532,7 +541,8 @@ class TestComponentLogbook(unittest.TestCase):
|
|||
|
||||
def create_state_changed_event(self, event_time_fired, entity_id, state,
|
||||
attributes=None, last_changed=None,
|
||||
last_updated=None):
|
||||
last_updated=None): \
|
||||
# pylint: disable=no-self-use
|
||||
"""Create state changed event."""
|
||||
# Logbook only cares about state change events that
|
||||
# contain an old state but will not actually act on it.
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
"""The tests for the Restore component."""
|
||||
import asyncio
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_START
|
||||
from homeassistant.core import CoreState, State
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from homeassistant.helpers.restore_state import (
|
||||
async_get_last_state, DATA_RESTORE_CACHE)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_caching_data(hass):
|
||||
"""Test that we cache data."""
|
||||
hass.config.components.add('recorder')
|
||||
hass.state = CoreState.starting
|
||||
|
||||
states = [
|
||||
State('input_boolean.b0', 'on'),
|
||||
State('input_boolean.b1', 'on'),
|
||||
State('input_boolean.b2', 'on'),
|
||||
]
|
||||
|
||||
with patch('homeassistant.helpers.restore_state.last_recorder_run',
|
||||
return_value=MagicMock(end=dt_util.utcnow())), \
|
||||
patch('homeassistant.helpers.restore_state.get_states',
|
||||
return_value=states):
|
||||
state = yield from async_get_last_state(hass, 'input_boolean.b1')
|
||||
|
||||
assert DATA_RESTORE_CACHE in hass.data
|
||||
assert hass.data[DATA_RESTORE_CACHE] == {st.entity_id: st for st in states}
|
||||
|
||||
assert state is not None
|
||||
assert state.entity_id == 'input_boolean.b1'
|
||||
assert state.state == 'on'
|
||||
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_START)
|
||||
|
||||
yield from hass.async_block_till_done()
|
||||
|
||||
assert DATA_RESTORE_CACHE not in hass.data
|
Loading…
Reference in New Issue