Prevent collecting states already referenced by domain or all (#41308)

The template engine would collect all the states in
a domain or all states while iterating even though
they were already included in all or the domain

This lead to the rate limit not being applied to
templates that iterated all states that also
accessed a collectable property because the engine
incorrectly believed they were specifically
referenced.
pull/41325/head
J. Nick Koston 2020-10-06 00:25:05 -05:00 committed by GitHub
parent 999eeb39b9
commit 2fabd4edb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 115 additions and 49 deletions

View File

@ -587,17 +587,18 @@ class DomainStates:
class TemplateState(State):
"""Class to represent a state object in a template."""
__slots__ = ("_hass", "_state")
__slots__ = ("_hass", "_state", "_collect")
# Inheritance is done so functions that check against State keep working
# pylint: disable=super-init-not-called
def __init__(self, hass, state):
def __init__(self, hass, state, collect=True):
"""Initialize template state."""
self._hass = hass
self._state = state
self._collect = collect
def _collect_state(self):
if _RENDER_INFO in self._hass.data:
if self._collect and _RENDER_INFO in self._hass.data:
self._hass.data[_RENDER_INFO].entities.add(self._state.entity_id)
# Jinja will try __getitem__ first and it avoids the need
@ -606,7 +607,7 @@ class TemplateState(State):
"""Return a property as an attribute for jinja."""
if item in _COLLECTABLE_STATE_ATTRIBUTES:
# _collect_state inlined here for performance
if _RENDER_INFO in self._hass.data:
if self._collect and _RENDER_INFO in self._hass.data:
self._hass.data[_RENDER_INFO].entities.add(self._state.entity_id)
return getattr(self._state, item)
if item == "entity_id":
@ -697,7 +698,7 @@ def _collect_state(hass: HomeAssistantType, entity_id: str) -> None:
def _state_generator(hass: HomeAssistantType, domain: Optional[str]) -> Generator:
"""State generator for a domain or all states."""
for state in sorted(hass.states.async_all(domain), key=attrgetter("entity_id")):
yield TemplateState(hass, state)
yield TemplateState(hass, state, collect=False)
def _get_state_if_valid(

View File

@ -1284,7 +1284,7 @@ async def test_track_template_result_iterator(hass):
assert info.listeners == {
"all": False,
"domains": {"sensor"},
"entities": {"sensor.test"},
"entities": set(),
}
hass.states.async_set("sensor.test", 6)
@ -1488,6 +1488,80 @@ async def test_track_template_rate_limit_five(hass):
assert refresh_runs == ["0", "1"]
async def test_track_template_has_default_rate_limit(hass):
"""Test template has a rate limit by default."""
hass.states.async_set("sensor.zero", "any")
template_refresh = Template("{{ states | list | count }}", hass)
refresh_runs = []
@ha.callback
def refresh_listener(event, updates):
refresh_runs.append(updates.pop().result)
info = async_track_template_result(
hass,
[TrackTemplate(template_refresh, None)],
refresh_listener,
)
await hass.async_block_till_done()
info.async_refresh()
await hass.async_block_till_done()
assert refresh_runs == ["1"]
hass.states.async_set("sensor.one", "any")
await hass.async_block_till_done()
assert refresh_runs == ["1"]
info.async_refresh()
assert refresh_runs == ["1", "2"]
hass.states.async_set("sensor.two", "any")
await hass.async_block_till_done()
assert refresh_runs == ["1", "2"]
hass.states.async_set("sensor.three", "any")
await hass.async_block_till_done()
assert refresh_runs == ["1", "2"]
async def test_track_template_unavailable_sates_has_default_rate_limit(hass):
"""Test template watching for unavailable states has a rate limit by default."""
hass.states.async_set("sensor.zero", "unknown")
template_refresh = Template(
"{{ states | selectattr('state', 'in', ['unavailable', 'unknown', 'none']) | list | count }}",
hass,
)
refresh_runs = []
@ha.callback
def refresh_listener(event, updates):
refresh_runs.append(updates.pop().result)
info = async_track_template_result(
hass,
[TrackTemplate(template_refresh, None)],
refresh_listener,
)
await hass.async_block_till_done()
info.async_refresh()
await hass.async_block_till_done()
assert refresh_runs == ["1"]
hass.states.async_set("sensor.one", "unknown")
await hass.async_block_till_done()
assert refresh_runs == ["1"]
info.async_refresh()
assert refresh_runs == ["1", "2"]
hass.states.async_set("sensor.two", "any")
await hass.async_block_till_done()
assert refresh_runs == ["1", "2"]
hass.states.async_set("sensor.three", "unknown")
await hass.async_block_till_done()
assert refresh_runs == ["1", "2"]
info.async_refresh()
await hass.async_block_till_done()
assert refresh_runs == ["1", "2", "3"]
async def test_specifically_referenced_entity_is_not_rate_limited(hass):
"""Test template rate limit of 5 seconds."""
hass.states.async_set("sensor.one", "none")

View File

@ -155,9 +155,25 @@ def test_iterating_all_states(hass):
hass.states.async_set("sensor.temperature", 10)
info = render_to_info(hass, tmpl_str)
assert_result_info(
info, "10happy", entities=["test.object", "sensor.temperature"], all_states=True
)
assert_result_info(info, "10happy", entities=[], all_states=True)
def test_iterating_all_states_unavailable(hass):
"""Test iterating all states unavailable."""
hass.states.async_set("test.object", "on")
tmpl_str = "{{ states | selectattr('state', 'in', ['unavailable', 'unknown', 'none']) | list | count }}"
info = render_to_info(hass, tmpl_str)
assert info.all_states is True
assert info.rate_limit == template.DEFAULT_RATE_LIMIT
hass.states.async_set("test.object", "unknown")
hass.states.async_set("sensor.temperature", 10)
info = render_to_info(hass, tmpl_str)
assert_result_info(info, "1", entities=[], all_states=True)
def test_iterating_domain_states(hass):
@ -176,7 +192,7 @@ def test_iterating_domain_states(hass):
assert_result_info(
info,
"open10",
entities=["sensor.back_door", "sensor.temperature"],
entities=[],
domains=["sensor"],
)
@ -1426,9 +1442,7 @@ async def test_expand(hass):
info = render_to_info(
hass, "{{ expand(states.group) | map(attribute='entity_id') | join(', ') }}"
)
assert_result_info(
info, "test.object", {"test.object", "group.new_group"}, ["group"]
)
assert_result_info(info, "test.object", {"test.object"}, ["group"])
assert info.rate_limit == template.DEFAULT_RATE_LIMIT
info = render_to_info(
@ -1587,7 +1601,7 @@ async def test_async_render_to_info_with_wildcard_matching_entity_id(hass):
"""Test tracking template with a wildcard."""
template_complex_str = r"""
{% for state in states %}
{% for state in states.cover %}
{% if state.entity_id | regex_match('.*\.office_') %}
{{ state.entity_id }}={{ state.state }}
{% endif %}
@ -1599,13 +1613,9 @@ async def test_async_render_to_info_with_wildcard_matching_entity_id(hass):
hass.states.async_set("cover.office_skylight", "open")
info = render_to_info(hass, template_complex_str)
assert not info.domains
assert info.entities == {
"cover.office_drapes",
"cover.office_window",
"cover.office_skylight",
}
assert info.all_states is True
assert info.domains == {"cover"}
assert info.entities == set()
assert info.all_states is False
assert info.rate_limit == template.DEFAULT_RATE_LIMIT
@ -1629,13 +1639,7 @@ async def test_async_render_to_info_with_wildcard_matching_state(hass):
info = render_to_info(hass, template_complex_str)
assert not info.domains
assert info.entities == {
"cover.x_skylight",
"binary_sensor.door",
"cover.office_drapes",
"cover.office_window",
"cover.office_skylight",
}
assert info.entities == set()
assert info.all_states is True
assert info.rate_limit == template.DEFAULT_RATE_LIMIT
@ -1643,13 +1647,7 @@ async def test_async_render_to_info_with_wildcard_matching_state(hass):
info = render_to_info(hass, template_complex_str)
assert not info.domains
assert info.entities == {
"cover.x_skylight",
"binary_sensor.door",
"cover.office_drapes",
"cover.office_window",
"cover.office_skylight",
}
assert info.entities == set()
assert info.all_states is True
assert info.rate_limit == template.DEFAULT_RATE_LIMIT
@ -1666,12 +1664,7 @@ async def test_async_render_to_info_with_wildcard_matching_state(hass):
info = render_to_info(hass, template_cover_str)
assert info.domains == {"cover"}
assert info.entities == {
"cover.x_skylight",
"cover.office_drapes",
"cover.office_window",
"cover.office_skylight",
}
assert info.entities == set()
assert info.all_states is False
assert info.rate_limit == template.DEFAULT_RATE_LIMIT
@ -1965,9 +1958,7 @@ def test_generate_filter_iterators(hass):
{% endfor %}
""",
)
assert_result_info(
info, "sensor.test_sensor=off,", ["sensor.test_sensor"], ["sensor"]
)
assert_result_info(info, "sensor.test_sensor=off,", [], ["sensor"])
info = render_to_info(
hass,
@ -1977,9 +1968,7 @@ def test_generate_filter_iterators(hass):
{% endfor %}
""",
)
assert_result_info(
info, "sensor.test_sensor=value,", ["sensor.test_sensor"], ["sensor"]
)
assert_result_info(info, "sensor.test_sensor=value,", [], ["sensor"])
def test_generate_select(hass):
@ -2001,7 +1990,7 @@ def test_generate_select(hass):
assert_result_info(
info,
"sensor.test_sensor",
["sensor.test_sensor", "sensor.test_sensor_on"],
[],
["sensor"],
)
assert info.domains_lifecycle == {"sensor"}
@ -2542,7 +2531,9 @@ async def test_lights(hass):
tmp = template.Template(tmpl, hass)
info = tmp.async_render_to_info()
assert info.entities == set(states)
assert info.entities == set()
assert info.domains == {"light"}
assert "lights are on" in info.result()
for i in range(10):
assert f"sensor{i}" in info.result()