Fix state_automation_listener when new state is None (#32985)
* Fix state_automation_listener when new state is None (fix #32984) * Listen to EVENT_STATE_CHANGED instead of using async_track_state_change and use the event context on automation trigger. * Share `process_state_match` with helpers/event * Add test for state change automation on entity removalpull/33194/head
parent
c2a9aba467
commit
cd57b764ce
|
@ -6,10 +6,14 @@ from typing import Dict
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant import exceptions
|
||||
from homeassistant.const import CONF_FOR, CONF_PLATFORM, MATCH_ALL
|
||||
from homeassistant.const import CONF_FOR, CONF_PLATFORM, EVENT_STATE_CHANGED, MATCH_ALL
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv, template
|
||||
from homeassistant.helpers.event import async_track_same_state, async_track_state_change
|
||||
from homeassistant.helpers.event import (
|
||||
Event,
|
||||
async_track_same_state,
|
||||
process_state_match,
|
||||
)
|
||||
|
||||
# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs
|
||||
# mypy: no-check-untyped-defs
|
||||
|
@ -56,10 +60,30 @@ async def async_attach_trigger(
|
|||
match_all = from_state == MATCH_ALL and to_state == MATCH_ALL
|
||||
unsub_track_same = {}
|
||||
period: Dict[str, timedelta] = {}
|
||||
match_from_state = process_state_match(from_state)
|
||||
match_to_state = process_state_match(to_state)
|
||||
|
||||
@callback
|
||||
def state_automation_listener(entity, from_s, to_s):
|
||||
def state_automation_listener(event: Event):
|
||||
"""Listen for state changes and calls action."""
|
||||
entity: str = event.data["entity_id"]
|
||||
if entity not in entity_id:
|
||||
return
|
||||
|
||||
from_s = event.data.get("old_state")
|
||||
to_s = event.data.get("new_state")
|
||||
|
||||
if (
|
||||
(from_s is not None and not match_from_state(from_s.state))
|
||||
or (to_s is not None and not match_to_state(to_s.state))
|
||||
or (
|
||||
not match_all
|
||||
and from_s is not None
|
||||
and to_s is not None
|
||||
and from_s.state == to_s.state
|
||||
)
|
||||
):
|
||||
return
|
||||
|
||||
@callback
|
||||
def call_action():
|
||||
|
@ -75,7 +99,7 @@ async def async_attach_trigger(
|
|||
"for": time_delta if not time_delta else period[entity],
|
||||
}
|
||||
},
|
||||
context=to_s.context,
|
||||
context=event.context,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -120,17 +144,16 @@ async def async_attach_trigger(
|
|||
)
|
||||
return
|
||||
|
||||
def _check_same_state(_, _2, new_st):
|
||||
if new_st is None:
|
||||
return False
|
||||
return new_st.state == to_s.state
|
||||
|
||||
unsub_track_same[entity] = async_track_same_state(
|
||||
hass,
|
||||
period[entity],
|
||||
call_action,
|
||||
lambda _, _2, to_state: to_state.state == to_s.state,
|
||||
entity_ids=entity,
|
||||
hass, period[entity], call_action, _check_same_state, entity_ids=entity,
|
||||
)
|
||||
|
||||
unsub = async_track_state_change(
|
||||
hass, entity_id, state_automation_listener, from_state, to_state
|
||||
)
|
||||
unsub = hass.bus.async_listen(EVENT_STATE_CHANGED, state_automation_listener)
|
||||
|
||||
@callback
|
||||
def async_remove():
|
||||
|
|
|
@ -67,8 +67,8 @@ def async_track_state_change(
|
|||
|
||||
Must be run within the event loop.
|
||||
"""
|
||||
match_from_state = _process_state_match(from_state)
|
||||
match_to_state = _process_state_match(to_state)
|
||||
match_from_state = process_state_match(from_state)
|
||||
match_to_state = process_state_match(to_state)
|
||||
|
||||
# Ensure it is a lowercase list with entity ids we want to match on
|
||||
if entity_ids == MATCH_ALL:
|
||||
|
@ -473,7 +473,7 @@ def async_track_time_change(
|
|||
track_time_change = threaded_listener_factory(async_track_time_change)
|
||||
|
||||
|
||||
def _process_state_match(
|
||||
def process_state_match(
|
||||
parameter: Union[None, str, Iterable[str]]
|
||||
) -> Callable[[str], bool]:
|
||||
"""Convert parameter to function that matches input against parameter."""
|
||||
|
|
|
@ -519,6 +519,28 @@ async def test_if_fires_on_entity_change_with_for(hass, calls):
|
|||
assert 1 == len(calls)
|
||||
|
||||
|
||||
async def test_if_fires_on_entity_removal(hass, calls):
|
||||
"""Test for firing on entity removal, when new_state is None."""
|
||||
hass.states.async_set("test.entity", "hello")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
automation.DOMAIN,
|
||||
{
|
||||
automation.DOMAIN: {
|
||||
"trigger": {"platform": "state", "entity_id": "test.entity"},
|
||||
"action": {"service": "test.automation"},
|
||||
}
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert hass.states.async_remove("test.entity")
|
||||
await hass.async_block_till_done()
|
||||
assert 1 == len(calls)
|
||||
|
||||
|
||||
async def test_if_fires_on_for_condition(hass, calls):
|
||||
"""Test for firing if condition is on."""
|
||||
point1 = dt_util.utcnow()
|
||||
|
|
Loading…
Reference in New Issue