diff --git a/homeassistant/components/history/__init__.py b/homeassistant/components/history/__init__.py index fe9a6c82825..4267127e209 100644 --- a/homeassistant/components/history/__init__.py +++ b/homeassistant/components/history/__init__.py @@ -8,7 +8,8 @@ import time from typing import Optional, cast from aiohttp import web -from sqlalchemy import and_, func +from sqlalchemy import and_, bindparam, func +from sqlalchemy.ext import baked import voluptuous as vol from homeassistant.components import recorder @@ -88,6 +89,8 @@ QUERY_STATES = [ States.last_updated, ] +HISTORY_BAKERY = "history_bakery" + def get_significant_states(hass, *args, **kwargs): """Wrap _get_significant_states with a sql session.""" @@ -115,26 +118,34 @@ def _get_significant_states( """ timer_start = time.perf_counter() + baked_query = hass.data[HISTORY_BAKERY]( + lambda session: session.query(*QUERY_STATES) + ) + if significant_changes_only: - query = session.query(*QUERY_STATES).filter( + baked_query += lambda q: q.filter( ( States.domain.in_(SIGNIFICANT_DOMAINS) | (States.last_changed == States.last_updated) ) - & (States.last_updated > start_time) + & (States.last_updated > bindparam("start_time")) ) else: - query = session.query(*QUERY_STATES).filter(States.last_updated > start_time) + baked_query += lambda q: q.filter(States.last_updated > bindparam("start_time")) if filters: - query = filters.apply(query, entity_ids) + filters.bake(baked_query, entity_ids) if end_time is not None: - query = query.filter(States.last_updated < end_time) + baked_query += lambda q: q.filter(States.last_updated < bindparam("end_time")) - query = query.order_by(States.entity_id, States.last_updated) + baked_query += lambda q: q.order_by(States.entity_id, States.last_updated) - states = execute(query) + states = execute( + baked_query(session).params( + start_time=start_time, end_time=end_time, entity_ids=entity_ids + ) + ) if _LOGGER.isEnabledFor(logging.DEBUG): elapsed = time.perf_counter() - timer_start @@ -155,21 +166,34 @@ def _get_significant_states( def state_changes_during_period(hass, start_time, end_time=None, entity_id=None): """Return states changes during UTC period start_time - end_time.""" with session_scope(hass=hass) as session: - query = session.query(*QUERY_STATES).filter( + baked_query = hass.data[HISTORY_BAKERY]( + lambda session: session.query(*QUERY_STATES) + ) + + baked_query += lambda q: q.filter( (States.last_changed == States.last_updated) - & (States.last_updated > start_time) + & (States.last_updated > bindparam("start_time")) ) if end_time is not None: - query = query.filter(States.last_updated < end_time) + baked_query += lambda q: q.filter( + States.last_updated < bindparam("end_time") + ) if entity_id is not None: - query = query.filter_by(entity_id=entity_id.lower()) + baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id")) + entity_id = entity_id.lower() + + baked_query += lambda q: q.order_by(States.entity_id, States.last_updated) + + states = execute( + baked_query(session).params( + start_time=start_time, end_time=end_time, entity_id=entity_id + ) + ) entity_ids = [entity_id] if entity_id is not None else None - states = execute(query.order_by(States.entity_id, States.last_updated)) - return _sorted_states_to_json(hass, session, states, start_time, entity_ids) @@ -178,21 +202,29 @@ def get_last_state_changes(hass, number_of_states, entity_id): start_time = dt_util.utcnow() with session_scope(hass=hass) as session: - query = session.query(*QUERY_STATES).filter( - States.last_changed == States.last_updated + baked_query = hass.data[HISTORY_BAKERY]( + lambda session: session.query(*QUERY_STATES) ) + baked_query += lambda q: q.filter(States.last_changed == States.last_updated) if entity_id is not None: - query = query.filter_by(entity_id=entity_id.lower()) + baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id")) + entity_id = entity_id.lower() - entity_ids = [entity_id] if entity_id is not None else None + baked_query += lambda q: q.order_by( + States.entity_id, States.last_updated.desc() + ) + + baked_query += lambda q: q.limit(bindparam("number_of_states")) states = execute( - query.order_by(States.entity_id, States.last_updated.desc()).limit( - number_of_states + baked_query(session).params( + number_of_states=number_of_states, entity_id=entity_id ) ) + entity_ids = [entity_id] if entity_id is not None else None + return _sorted_states_to_json( hass, session, @@ -214,28 +246,18 @@ def get_states(hass, utc_point_in_time, entity_ids=None, run=None, filters=None) with session_scope(hass=hass) as session: return _get_states_with_session( - session, utc_point_in_time, entity_ids, run, filters + hass, session, utc_point_in_time, entity_ids, run, filters ) def _get_states_with_session( - session, utc_point_in_time, entity_ids=None, run=None, filters=None + hass, session, utc_point_in_time, entity_ids=None, run=None, filters=None ): """Return the states at a specific point in time.""" - query = session.query(*QUERY_STATES) - if entity_ids and len(entity_ids) == 1: - # Use an entirely different (and extremely fast) query if we only - # have a single entity id - query = ( - query.filter( - States.last_updated < utc_point_in_time, - States.entity_id.in_(entity_ids), - ) - .order_by(States.last_updated.desc()) - .limit(1) + return _get_single_entity_states_with_session( + hass, session, utc_point_in_time, entity_ids[0] ) - return [LazyState(row) for row in execute(query)] if run is None: run = recorder.run_information_with_session(session, utc_point_in_time) @@ -247,6 +269,7 @@ def _get_states_with_session( # We have more than one entity to look at (most commonly we want # all entities,) so we need to do a search on all states since the # last recorder run started. + query = session.query(*QUERY_STATES) most_recent_states_by_date = session.query( States.entity_id.label("max_entity_id"), @@ -286,6 +309,26 @@ def _get_states_with_session( return [LazyState(row) for row in execute(query)] +def _get_single_entity_states_with_session(hass, session, utc_point_in_time, entity_id): + # Use an entirely different (and extremely fast) query if we only + # have a single entity id + baked_query = hass.data[HISTORY_BAKERY]( + lambda session: session.query(*QUERY_STATES) + ) + baked_query += lambda q: q.filter( + States.last_updated < bindparam("utc_point_in_time"), + States.entity_id == bindparam("entity_id"), + ) + baked_query += lambda q: q.order_by(States.last_updated.desc()) + baked_query += lambda q: q.limit(1) + + query = baked_query(session).params( + utc_point_in_time=utc_point_in_time, entity_id=entity_id + ) + + return [LazyState(row) for row in execute(query)] + + def _sorted_states_to_json( hass, session, @@ -318,7 +361,7 @@ def _sorted_states_to_json( if include_start_time_state: run = recorder.run_information_from_instance(hass, start_time) for state in _get_states_with_session( - session, start_time, entity_ids, run=run, filters=filters + hass, session, start_time, entity_ids, run=run, filters=filters ): state.last_changed = start_time state.last_updated = start_time @@ -337,16 +380,16 @@ def _sorted_states_to_json( domain = split_entity_id(ent_id)[0] ent_results = result[ent_id] if not minimal_response or domain in NEED_ATTRIBUTE_DOMAINS: - ent_results.extend( - [ - native_state - for native_state in (LazyState(db_state) for db_state in group) - if ( - domain != SCRIPT_DOMAIN - or native_state.attributes.get(ATTR_CAN_CANCEL) - ) - ] - ) + if domain == SCRIPT_DOMAIN: + ent_results.extend( + [ + native_state + for native_state in (LazyState(db_state) for db_state in group) + if native_state.attributes.get(ATTR_CAN_CANCEL) + ] + ) + else: + ent_results.extend(LazyState(db_state) for db_state in group) continue # With minimal response we only provide a native @@ -387,7 +430,7 @@ def _sorted_states_to_json( def get_state(hass, utc_point_in_time, entity_id, run=None): """Return a state at a specific point in time.""" - states = list(get_states(hass, utc_point_in_time, (entity_id,), run)) + states = get_states(hass, utc_point_in_time, (entity_id,), run) return states[0] if states else None @@ -396,6 +439,9 @@ async def async_setup(hass, config): conf = config.get(DOMAIN, {}) filters = sqlalchemy_filter_from_include_exclude_conf(conf) + + hass.data[HISTORY_BAKERY] = baked.bakery() + use_include_order = conf.get(CONF_ORDER) hass.http.register_view(HistoryPeriodView(filters, use_include_order)) @@ -560,6 +606,7 @@ class Filters: # specific entities requested - do not in/exclude anything if entity_ids is not None: return query.filter(States.entity_id.in_(entity_ids)) + query = query.filter(~States.domain.in_(IGNORE_DOMAINS)) entity_filter = self.entity_filter() @@ -568,6 +615,27 @@ class Filters: return query + def bake(self, baked_query, entity_ids=None): + """Update a baked query. + + Works the same as apply on a baked_query. + """ + if entity_ids is not None: + baked_query += lambda q: q.filter( + States.entity_id.in_(bindparam("entity_ids", expanding=True)) + ) + return + + baked_query += lambda q: q.filter(~States.domain.in_(IGNORE_DOMAINS)) + + if ( + self.excluded_entities + or self.excluded_domains + or self.included_entities + or self.included_domains + ): + baked_query += lambda q: q.filter(self.entity_filter()) + def entity_filter(self): """Generate the entity filter query.""" entity_filter = None diff --git a/tests/components/history/test_init.py b/tests/components/history/test_init.py index 34b22481400..56318f3e9fb 100644 --- a/tests/components/history/test_init.py +++ b/tests/components/history/test_init.py @@ -61,7 +61,7 @@ class TestComponentHistory(unittest.TestCase): def test_get_states(self): """Test getting states at a specific point in time.""" - self.init_recorder() + self.test_setup() states = [] now = dt_util.utcnow() @@ -115,7 +115,7 @@ class TestComponentHistory(unittest.TestCase): def test_state_changes_during_period(self): """Test state change during period.""" - self.init_recorder() + self.test_setup() entity_id = "media_player.test" def set_state(state): @@ -156,7 +156,7 @@ class TestComponentHistory(unittest.TestCase): def test_get_last_state_changes(self): """Test number of state changes.""" - self.init_recorder() + self.test_setup() entity_id = "sensor.test" def set_state(state): @@ -195,7 +195,7 @@ class TestComponentHistory(unittest.TestCase): The filter integration uses copy() on states from history. """ - self.init_recorder() + self.test_setup() entity_id = "sensor.test" def set_state(state): @@ -608,7 +608,7 @@ class TestComponentHistory(unittest.TestCase): def test_get_significant_states_only(self): """Test significant states when significant_states_only is set.""" - self.init_recorder() + self.test_setup() entity_id = "sensor.test" def set_state(state, **kwargs): @@ -683,7 +683,7 @@ class TestComponentHistory(unittest.TestCase): We inject a bunch of state updates from media player, zone and thermostat. """ - self.init_recorder() + self.test_setup() mp = "media_player.test" mp2 = "media_player.test2" mp3 = "media_player.test3"