Fix state_automation_listener when new state is None (#32985)

* Fix state_automation_listener when new state is None (fix #32984)

* Listen to EVENT_STATE_CHANGED instead of using async_track_state_change

and use the event context on automation trigger.

* Share `process_state_match` with helpers/event

* Add test for state change automation on entity removal
pull/33194/head
Eugenio Panadero 2020-03-24 00:05:21 +01:00 committed by GitHub
parent c2a9aba467
commit cd57b764ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 15 deletions

View File

@ -6,10 +6,14 @@ from typing import Dict
import voluptuous as vol
from homeassistant import exceptions
from homeassistant.const import CONF_FOR, CONF_PLATFORM, MATCH_ALL
from homeassistant.const import CONF_FOR, CONF_PLATFORM, EVENT_STATE_CHANGED, MATCH_ALL
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers.event import async_track_same_state, async_track_state_change
from homeassistant.helpers.event import (
Event,
async_track_same_state,
process_state_match,
)
# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs
@ -56,10 +60,30 @@ async def async_attach_trigger(
match_all = from_state == MATCH_ALL and to_state == MATCH_ALL
unsub_track_same = {}
period: Dict[str, timedelta] = {}
match_from_state = process_state_match(from_state)
match_to_state = process_state_match(to_state)
@callback
def state_automation_listener(entity, from_s, to_s):
def state_automation_listener(event: Event):
"""Listen for state changes and calls action."""
entity: str = event.data["entity_id"]
if entity not in entity_id:
return
from_s = event.data.get("old_state")
to_s = event.data.get("new_state")
if (
(from_s is not None and not match_from_state(from_s.state))
or (to_s is not None and not match_to_state(to_s.state))
or (
not match_all
and from_s is not None
and to_s is not None
and from_s.state == to_s.state
)
):
return
@callback
def call_action():
@ -75,7 +99,7 @@ async def async_attach_trigger(
"for": time_delta if not time_delta else period[entity],
}
},
context=to_s.context,
context=event.context,
)
)
@ -120,17 +144,16 @@ async def async_attach_trigger(
)
return
def _check_same_state(_, _2, new_st):
if new_st is None:
return False
return new_st.state == to_s.state
unsub_track_same[entity] = async_track_same_state(
hass,
period[entity],
call_action,
lambda _, _2, to_state: to_state.state == to_s.state,
entity_ids=entity,
hass, period[entity], call_action, _check_same_state, entity_ids=entity,
)
unsub = async_track_state_change(
hass, entity_id, state_automation_listener, from_state, to_state
)
unsub = hass.bus.async_listen(EVENT_STATE_CHANGED, state_automation_listener)
@callback
def async_remove():

View File

@ -67,8 +67,8 @@ def async_track_state_change(
Must be run within the event loop.
"""
match_from_state = _process_state_match(from_state)
match_to_state = _process_state_match(to_state)
match_from_state = process_state_match(from_state)
match_to_state = process_state_match(to_state)
# Ensure it is a lowercase list with entity ids we want to match on
if entity_ids == MATCH_ALL:
@ -473,7 +473,7 @@ def async_track_time_change(
track_time_change = threaded_listener_factory(async_track_time_change)
def _process_state_match(
def process_state_match(
parameter: Union[None, str, Iterable[str]]
) -> Callable[[str], bool]:
"""Convert parameter to function that matches input against parameter."""

View File

@ -519,6 +519,28 @@ async def test_if_fires_on_entity_change_with_for(hass, calls):
assert 1 == len(calls)
async def test_if_fires_on_entity_removal(hass, calls):
"""Test for firing on entity removal, when new_state is None."""
hass.states.async_set("test.entity", "hello")
await hass.async_block_till_done()
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": {"platform": "state", "entity_id": "test.entity"},
"action": {"service": "test.automation"},
}
},
)
await hass.async_block_till_done()
assert hass.states.async_remove("test.entity")
await hass.async_block_till_done()
assert 1 == len(calls)
async def test_if_fires_on_for_condition(hass, calls):
"""Test for firing if condition is on."""
point1 = dt_util.utcnow()