Add async_track_state_added_domain for tracking when states are added to a domain (#38776)
* Fire event_state_added when a state is added after start * async_track_state_added_domain * test * naming * coveragepull/38812/head
parent
716fa63e73
commit
45526f4e8a
|
@ -17,7 +17,14 @@ from homeassistant.const import (
|
|||
SUN_EVENT_SUNRISE,
|
||||
SUN_EVENT_SUNSET,
|
||||
)
|
||||
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, State, callback
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
Event,
|
||||
HomeAssistant,
|
||||
State,
|
||||
callback,
|
||||
split_entity_id,
|
||||
)
|
||||
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
|
||||
from homeassistant.helpers.sun import get_astral_event_next
|
||||
from homeassistant.helpers.template import Template
|
||||
|
@ -28,6 +35,9 @@ from homeassistant.util.async_ import run_callback_threadsafe
|
|||
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
|
||||
TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener"
|
||||
|
||||
TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks"
|
||||
TRACK_STATE_ADDED_DOMAIN_LISTENER = "track_state_added_domain_listener"
|
||||
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener"
|
||||
|
||||
|
@ -191,7 +201,7 @@ def async_track_state_change_event(
|
|||
@callback
|
||||
def remove_listener() -> None:
|
||||
"""Remove state change listener."""
|
||||
_async_remove_entity_listeners(
|
||||
_async_remove_indexed_listeners(
|
||||
hass,
|
||||
TRACK_STATE_CHANGE_CALLBACKS,
|
||||
TRACK_STATE_CHANGE_LISTENER,
|
||||
|
@ -203,23 +213,23 @@ def async_track_state_change_event(
|
|||
|
||||
|
||||
@callback
|
||||
def _async_remove_entity_listeners(
|
||||
def _async_remove_indexed_listeners(
|
||||
hass: HomeAssistant,
|
||||
storage_key: str,
|
||||
data_key: str,
|
||||
listener_key: str,
|
||||
entity_ids: Iterable[str],
|
||||
storage_keys: Iterable[str],
|
||||
action: Callable[[Event], Any],
|
||||
) -> None:
|
||||
"""Remove a listener."""
|
||||
|
||||
entity_callbacks = hass.data[storage_key]
|
||||
callbacks = hass.data[data_key]
|
||||
|
||||
for entity_id in entity_ids:
|
||||
entity_callbacks[entity_id].remove(action)
|
||||
if len(entity_callbacks[entity_id]) == 0:
|
||||
del entity_callbacks[entity_id]
|
||||
for storage_key in storage_keys:
|
||||
callbacks[storage_key].remove(action)
|
||||
if len(callbacks[storage_key]) == 0:
|
||||
del callbacks[storage_key]
|
||||
|
||||
if not entity_callbacks:
|
||||
if not callbacks:
|
||||
hass.data[listener_key]()
|
||||
del hass.data[listener_key]
|
||||
|
||||
|
@ -271,7 +281,7 @@ def async_track_entity_registry_updated_event(
|
|||
@callback
|
||||
def remove_listener() -> None:
|
||||
"""Remove state change listener."""
|
||||
_async_remove_entity_listeners(
|
||||
_async_remove_indexed_listeners(
|
||||
hass,
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS,
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER,
|
||||
|
@ -282,6 +292,63 @@ def async_track_entity_registry_updated_event(
|
|||
return remove_listener
|
||||
|
||||
|
||||
@bind_hass
|
||||
def async_track_state_added_domain(
|
||||
hass: HomeAssistant,
|
||||
domains: Union[str, Iterable[str]],
|
||||
action: Callable[[Event], Any],
|
||||
) -> Callable[[], None]:
|
||||
"""Track state change events when an entity is added to domains."""
|
||||
|
||||
domain_callbacks = hass.data.setdefault(TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {})
|
||||
|
||||
if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data:
|
||||
|
||||
@callback
|
||||
def _async_state_change_dispatcher(event: Event) -> None:
|
||||
"""Dispatch state changes by entity_id."""
|
||||
if event.data.get("old_state") is not None:
|
||||
return
|
||||
|
||||
domain = split_entity_id(event.data["entity_id"])[0]
|
||||
|
||||
if domain not in domain_callbacks:
|
||||
return
|
||||
|
||||
for action in domain_callbacks[domain][:]:
|
||||
try:
|
||||
hass.async_run_job(action, event)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception(
|
||||
"Error while processing state added for %s", domain
|
||||
)
|
||||
|
||||
hass.data[TRACK_STATE_ADDED_DOMAIN_LISTENER] = hass.bus.async_listen(
|
||||
EVENT_STATE_CHANGED, _async_state_change_dispatcher
|
||||
)
|
||||
|
||||
if isinstance(domains, str):
|
||||
domains = [domains]
|
||||
|
||||
domains = [domains.lower() for domains in domains]
|
||||
|
||||
for domain in domains:
|
||||
domain_callbacks.setdefault(domain, []).append(action)
|
||||
|
||||
@callback
|
||||
def remove_listener() -> None:
|
||||
"""Remove state change listener."""
|
||||
_async_remove_indexed_listeners(
|
||||
hass,
|
||||
TRACK_STATE_ADDED_DOMAIN_CALLBACKS,
|
||||
TRACK_STATE_ADDED_DOMAIN_LISTENER,
|
||||
domains,
|
||||
action,
|
||||
)
|
||||
|
||||
return remove_listener
|
||||
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_track_template(
|
||||
|
|
|
@ -16,6 +16,7 @@ from homeassistant.helpers.event import (
|
|||
async_track_point_in_time,
|
||||
async_track_point_in_utc_time,
|
||||
async_track_same_state,
|
||||
async_track_state_added_domain,
|
||||
async_track_state_change,
|
||||
async_track_state_change_event,
|
||||
async_track_sunrise,
|
||||
|
@ -341,6 +342,88 @@ async def test_async_track_state_change_event(hass):
|
|||
unsub_throws()
|
||||
|
||||
|
||||
async def test_async_track_state_added_domain(hass):
|
||||
"""Test async_track_state_added_domain."""
|
||||
single_entity_id_tracker = []
|
||||
multiple_entity_id_tracker = []
|
||||
|
||||
@ha.callback
|
||||
def single_run_callback(event):
|
||||
old_state = event.data.get("old_state")
|
||||
new_state = event.data.get("new_state")
|
||||
|
||||
single_entity_id_tracker.append((old_state, new_state))
|
||||
|
||||
@ha.callback
|
||||
def multiple_run_callback(event):
|
||||
old_state = event.data.get("old_state")
|
||||
new_state = event.data.get("new_state")
|
||||
|
||||
multiple_entity_id_tracker.append((old_state, new_state))
|
||||
|
||||
@ha.callback
|
||||
def callback_that_throws(event):
|
||||
raise ValueError
|
||||
|
||||
unsub_single = async_track_state_added_domain(hass, "light", single_run_callback)
|
||||
unsub_multi = async_track_state_added_domain(
|
||||
hass, ["light", "switch"], multiple_run_callback
|
||||
)
|
||||
unsub_throws = async_track_state_added_domain(
|
||||
hass, ["light", "switch"], callback_that_throws
|
||||
)
|
||||
|
||||
# Adding state to state machine
|
||||
hass.states.async_set("light.Bowl", "on")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 1
|
||||
assert single_entity_id_tracker[-1][0] is None
|
||||
assert single_entity_id_tracker[-1][1] is not None
|
||||
assert len(multiple_entity_id_tracker) == 1
|
||||
assert multiple_entity_id_tracker[-1][0] is None
|
||||
assert multiple_entity_id_tracker[-1][1] is not None
|
||||
|
||||
# Set same state should not trigger a state change/listener
|
||||
hass.states.async_set("light.Bowl", "on")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 1
|
||||
assert len(multiple_entity_id_tracker) == 1
|
||||
|
||||
# State change off -> on - nothing added so no trigger
|
||||
hass.states.async_set("light.Bowl", "off")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 1
|
||||
assert len(multiple_entity_id_tracker) == 1
|
||||
|
||||
# State change off -> off - nothing added so no trigger
|
||||
hass.states.async_set("light.Bowl", "off", {"some_attr": 1})
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 1
|
||||
assert len(multiple_entity_id_tracker) == 1
|
||||
|
||||
# Removing state does not trigger
|
||||
hass.states.async_remove("light.bowl")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 1
|
||||
assert len(multiple_entity_id_tracker) == 1
|
||||
|
||||
# Set state for different entity id
|
||||
hass.states.async_set("switch.kitchen", "on")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 1
|
||||
assert len(multiple_entity_id_tracker) == 2
|
||||
|
||||
unsub_single()
|
||||
# Ensure unsubing the listener works
|
||||
hass.states.async_set("light.new", "off")
|
||||
await hass.async_block_till_done()
|
||||
assert len(single_entity_id_tracker) == 1
|
||||
assert len(multiple_entity_id_tracker) == 3
|
||||
|
||||
unsub_multi()
|
||||
unsub_throws()
|
||||
|
||||
|
||||
async def test_track_template(hass):
|
||||
"""Test tracking template."""
|
||||
specific_runs = []
|
||||
|
|
Loading…
Reference in New Issue