From 2fabd4edb891f33929b16411e2767fee7692e559 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 6 Oct 2020 00:25:05 -0500 Subject: [PATCH] Prevent collecting states already referenced by domain or all (#41308) The template engine would collect all the states in a domain or all states while iterating even though they were already included in all or the domain This lead to the rate limit not being applied to templates that iterated all states that also accessed a collectable property because the engine incorrectly believed they were specifically referenced. --- homeassistant/helpers/template.py | 11 +++-- tests/helpers/test_event.py | 76 +++++++++++++++++++++++++++++- tests/helpers/test_template.py | 77 ++++++++++++++----------------- 3 files changed, 115 insertions(+), 49 deletions(-) diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index c0ceb17fdd7..9c849bee22e 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -587,17 +587,18 @@ class DomainStates: class TemplateState(State): """Class to represent a state object in a template.""" - __slots__ = ("_hass", "_state") + __slots__ = ("_hass", "_state", "_collect") # Inheritance is done so functions that check against State keep working # pylint: disable=super-init-not-called - def __init__(self, hass, state): + def __init__(self, hass, state, collect=True): """Initialize template state.""" self._hass = hass self._state = state + self._collect = collect def _collect_state(self): - if _RENDER_INFO in self._hass.data: + if self._collect and _RENDER_INFO in self._hass.data: self._hass.data[_RENDER_INFO].entities.add(self._state.entity_id) # Jinja will try __getitem__ first and it avoids the need @@ -606,7 +607,7 @@ class TemplateState(State): """Return a property as an attribute for jinja.""" if item in _COLLECTABLE_STATE_ATTRIBUTES: # _collect_state inlined here for performance - if _RENDER_INFO in self._hass.data: + if self._collect and _RENDER_INFO in self._hass.data: self._hass.data[_RENDER_INFO].entities.add(self._state.entity_id) return getattr(self._state, item) if item == "entity_id": @@ -697,7 +698,7 @@ def _collect_state(hass: HomeAssistantType, entity_id: str) -> None: def _state_generator(hass: HomeAssistantType, domain: Optional[str]) -> Generator: """State generator for a domain or all states.""" for state in sorted(hass.states.async_all(domain), key=attrgetter("entity_id")): - yield TemplateState(hass, state) + yield TemplateState(hass, state, collect=False) def _get_state_if_valid( diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index bf0efaf8cf6..bb0d17d7b0e 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -1284,7 +1284,7 @@ async def test_track_template_result_iterator(hass): assert info.listeners == { "all": False, "domains": {"sensor"}, - "entities": {"sensor.test"}, + "entities": set(), } hass.states.async_set("sensor.test", 6) @@ -1488,6 +1488,80 @@ async def test_track_template_rate_limit_five(hass): assert refresh_runs == ["0", "1"] +async def test_track_template_has_default_rate_limit(hass): + """Test template has a rate limit by default.""" + hass.states.async_set("sensor.zero", "any") + template_refresh = Template("{{ states | list | count }}", hass) + + refresh_runs = [] + + @ha.callback + def refresh_listener(event, updates): + refresh_runs.append(updates.pop().result) + + info = async_track_template_result( + hass, + [TrackTemplate(template_refresh, None)], + refresh_listener, + ) + await hass.async_block_till_done() + info.async_refresh() + await hass.async_block_till_done() + + assert refresh_runs == ["1"] + hass.states.async_set("sensor.one", "any") + await hass.async_block_till_done() + assert refresh_runs == ["1"] + info.async_refresh() + assert refresh_runs == ["1", "2"] + hass.states.async_set("sensor.two", "any") + await hass.async_block_till_done() + assert refresh_runs == ["1", "2"] + hass.states.async_set("sensor.three", "any") + await hass.async_block_till_done() + assert refresh_runs == ["1", "2"] + + +async def test_track_template_unavailable_sates_has_default_rate_limit(hass): + """Test template watching for unavailable states has a rate limit by default.""" + hass.states.async_set("sensor.zero", "unknown") + template_refresh = Template( + "{{ states | selectattr('state', 'in', ['unavailable', 'unknown', 'none']) | list | count }}", + hass, + ) + + refresh_runs = [] + + @ha.callback + def refresh_listener(event, updates): + refresh_runs.append(updates.pop().result) + + info = async_track_template_result( + hass, + [TrackTemplate(template_refresh, None)], + refresh_listener, + ) + await hass.async_block_till_done() + info.async_refresh() + await hass.async_block_till_done() + + assert refresh_runs == ["1"] + hass.states.async_set("sensor.one", "unknown") + await hass.async_block_till_done() + assert refresh_runs == ["1"] + info.async_refresh() + assert refresh_runs == ["1", "2"] + hass.states.async_set("sensor.two", "any") + await hass.async_block_till_done() + assert refresh_runs == ["1", "2"] + hass.states.async_set("sensor.three", "unknown") + await hass.async_block_till_done() + assert refresh_runs == ["1", "2"] + info.async_refresh() + await hass.async_block_till_done() + assert refresh_runs == ["1", "2", "3"] + + async def test_specifically_referenced_entity_is_not_rate_limited(hass): """Test template rate limit of 5 seconds.""" hass.states.async_set("sensor.one", "none") diff --git a/tests/helpers/test_template.py b/tests/helpers/test_template.py index 7d39675bdc0..5535fa53993 100644 --- a/tests/helpers/test_template.py +++ b/tests/helpers/test_template.py @@ -155,9 +155,25 @@ def test_iterating_all_states(hass): hass.states.async_set("sensor.temperature", 10) info = render_to_info(hass, tmpl_str) - assert_result_info( - info, "10happy", entities=["test.object", "sensor.temperature"], all_states=True - ) + assert_result_info(info, "10happy", entities=[], all_states=True) + + +def test_iterating_all_states_unavailable(hass): + """Test iterating all states unavailable.""" + hass.states.async_set("test.object", "on") + + tmpl_str = "{{ states | selectattr('state', 'in', ['unavailable', 'unknown', 'none']) | list | count }}" + + info = render_to_info(hass, tmpl_str) + + assert info.all_states is True + assert info.rate_limit == template.DEFAULT_RATE_LIMIT + + hass.states.async_set("test.object", "unknown") + hass.states.async_set("sensor.temperature", 10) + + info = render_to_info(hass, tmpl_str) + assert_result_info(info, "1", entities=[], all_states=True) def test_iterating_domain_states(hass): @@ -176,7 +192,7 @@ def test_iterating_domain_states(hass): assert_result_info( info, "open10", - entities=["sensor.back_door", "sensor.temperature"], + entities=[], domains=["sensor"], ) @@ -1426,9 +1442,7 @@ async def test_expand(hass): info = render_to_info( hass, "{{ expand(states.group) | map(attribute='entity_id') | join(', ') }}" ) - assert_result_info( - info, "test.object", {"test.object", "group.new_group"}, ["group"] - ) + assert_result_info(info, "test.object", {"test.object"}, ["group"]) assert info.rate_limit == template.DEFAULT_RATE_LIMIT info = render_to_info( @@ -1587,7 +1601,7 @@ async def test_async_render_to_info_with_wildcard_matching_entity_id(hass): """Test tracking template with a wildcard.""" template_complex_str = r""" -{% for state in states %} +{% for state in states.cover %} {% if state.entity_id | regex_match('.*\.office_') %} {{ state.entity_id }}={{ state.state }} {% endif %} @@ -1599,13 +1613,9 @@ async def test_async_render_to_info_with_wildcard_matching_entity_id(hass): hass.states.async_set("cover.office_skylight", "open") info = render_to_info(hass, template_complex_str) - assert not info.domains - assert info.entities == { - "cover.office_drapes", - "cover.office_window", - "cover.office_skylight", - } - assert info.all_states is True + assert info.domains == {"cover"} + assert info.entities == set() + assert info.all_states is False assert info.rate_limit == template.DEFAULT_RATE_LIMIT @@ -1629,13 +1639,7 @@ async def test_async_render_to_info_with_wildcard_matching_state(hass): info = render_to_info(hass, template_complex_str) assert not info.domains - assert info.entities == { - "cover.x_skylight", - "binary_sensor.door", - "cover.office_drapes", - "cover.office_window", - "cover.office_skylight", - } + assert info.entities == set() assert info.all_states is True assert info.rate_limit == template.DEFAULT_RATE_LIMIT @@ -1643,13 +1647,7 @@ async def test_async_render_to_info_with_wildcard_matching_state(hass): info = render_to_info(hass, template_complex_str) assert not info.domains - assert info.entities == { - "cover.x_skylight", - "binary_sensor.door", - "cover.office_drapes", - "cover.office_window", - "cover.office_skylight", - } + assert info.entities == set() assert info.all_states is True assert info.rate_limit == template.DEFAULT_RATE_LIMIT @@ -1666,12 +1664,7 @@ async def test_async_render_to_info_with_wildcard_matching_state(hass): info = render_to_info(hass, template_cover_str) assert info.domains == {"cover"} - assert info.entities == { - "cover.x_skylight", - "cover.office_drapes", - "cover.office_window", - "cover.office_skylight", - } + assert info.entities == set() assert info.all_states is False assert info.rate_limit == template.DEFAULT_RATE_LIMIT @@ -1965,9 +1958,7 @@ def test_generate_filter_iterators(hass): {% endfor %} """, ) - assert_result_info( - info, "sensor.test_sensor=off,", ["sensor.test_sensor"], ["sensor"] - ) + assert_result_info(info, "sensor.test_sensor=off,", [], ["sensor"]) info = render_to_info( hass, @@ -1977,9 +1968,7 @@ def test_generate_filter_iterators(hass): {% endfor %} """, ) - assert_result_info( - info, "sensor.test_sensor=value,", ["sensor.test_sensor"], ["sensor"] - ) + assert_result_info(info, "sensor.test_sensor=value,", [], ["sensor"]) def test_generate_select(hass): @@ -2001,7 +1990,7 @@ def test_generate_select(hass): assert_result_info( info, "sensor.test_sensor", - ["sensor.test_sensor", "sensor.test_sensor_on"], + [], ["sensor"], ) assert info.domains_lifecycle == {"sensor"} @@ -2542,7 +2531,9 @@ async def test_lights(hass): tmp = template.Template(tmpl, hass) info = tmp.async_render_to_info() - assert info.entities == set(states) + assert info.entities == set() + assert info.domains == {"light"} + assert "lights are on" in info.result() for i in range(10): assert f"sensor{i}" in info.result()