Add async_track_state_added_domain for tracking when states are added to a domain (#38776)

* Fire event_state_added when a state is added after start

* async_track_state_added_domain

* test

* naming

* coverage
pull/38812/head
J. Nick Koston 2020-08-12 13:30:40 -05:00 committed by GitHub
parent 716fa63e73
commit 45526f4e8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 162 additions and 12 deletions

View File

@ -17,7 +17,14 @@ from homeassistant.const import (
SUN_EVENT_SUNRISE,
SUN_EVENT_SUNSET,
)
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, State, callback
from homeassistant.core import (
CALLBACK_TYPE,
Event,
HomeAssistant,
State,
callback,
split_entity_id,
)
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
from homeassistant.helpers.sun import get_astral_event_next
from homeassistant.helpers.template import Template
@ -28,6 +35,9 @@ from homeassistant.util.async_ import run_callback_threadsafe
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener"
TRACK_STATE_ADDED_DOMAIN_CALLBACKS = "track_state_added_domain_callbacks"
TRACK_STATE_ADDED_DOMAIN_LISTENER = "track_state_added_domain_listener"
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener"
@ -191,7 +201,7 @@ def async_track_state_change_event(
@callback
def remove_listener() -> None:
"""Remove state change listener."""
_async_remove_entity_listeners(
_async_remove_indexed_listeners(
hass,
TRACK_STATE_CHANGE_CALLBACKS,
TRACK_STATE_CHANGE_LISTENER,
@ -203,23 +213,23 @@ def async_track_state_change_event(
@callback
def _async_remove_entity_listeners(
def _async_remove_indexed_listeners(
hass: HomeAssistant,
storage_key: str,
data_key: str,
listener_key: str,
entity_ids: Iterable[str],
storage_keys: Iterable[str],
action: Callable[[Event], Any],
) -> None:
"""Remove a listener."""
entity_callbacks = hass.data[storage_key]
callbacks = hass.data[data_key]
for entity_id in entity_ids:
entity_callbacks[entity_id].remove(action)
if len(entity_callbacks[entity_id]) == 0:
del entity_callbacks[entity_id]
for storage_key in storage_keys:
callbacks[storage_key].remove(action)
if len(callbacks[storage_key]) == 0:
del callbacks[storage_key]
if not entity_callbacks:
if not callbacks:
hass.data[listener_key]()
del hass.data[listener_key]
@ -271,7 +281,7 @@ def async_track_entity_registry_updated_event(
@callback
def remove_listener() -> None:
"""Remove state change listener."""
_async_remove_entity_listeners(
_async_remove_indexed_listeners(
hass,
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS,
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER,
@ -282,6 +292,63 @@ def async_track_entity_registry_updated_event(
return remove_listener
@bind_hass
def async_track_state_added_domain(
hass: HomeAssistant,
domains: Union[str, Iterable[str]],
action: Callable[[Event], Any],
) -> Callable[[], None]:
"""Track state change events when an entity is added to domains."""
domain_callbacks = hass.data.setdefault(TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {})
if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data:
@callback
def _async_state_change_dispatcher(event: Event) -> None:
"""Dispatch state changes by entity_id."""
if event.data.get("old_state") is not None:
return
domain = split_entity_id(event.data["entity_id"])[0]
if domain not in domain_callbacks:
return
for action in domain_callbacks[domain][:]:
try:
hass.async_run_job(action, event)
except Exception: # pylint: disable=broad-except
_LOGGER.exception(
"Error while processing state added for %s", domain
)
hass.data[TRACK_STATE_ADDED_DOMAIN_LISTENER] = hass.bus.async_listen(
EVENT_STATE_CHANGED, _async_state_change_dispatcher
)
if isinstance(domains, str):
domains = [domains]
domains = [domains.lower() for domains in domains]
for domain in domains:
domain_callbacks.setdefault(domain, []).append(action)
@callback
def remove_listener() -> None:
"""Remove state change listener."""
_async_remove_indexed_listeners(
hass,
TRACK_STATE_ADDED_DOMAIN_CALLBACKS,
TRACK_STATE_ADDED_DOMAIN_LISTENER,
domains,
action,
)
return remove_listener
@callback
@bind_hass
def async_track_template(

View File

@ -16,6 +16,7 @@ from homeassistant.helpers.event import (
async_track_point_in_time,
async_track_point_in_utc_time,
async_track_same_state,
async_track_state_added_domain,
async_track_state_change,
async_track_state_change_event,
async_track_sunrise,
@ -341,6 +342,88 @@ async def test_async_track_state_change_event(hass):
unsub_throws()
async def test_async_track_state_added_domain(hass):
"""Test async_track_state_added_domain."""
single_entity_id_tracker = []
multiple_entity_id_tracker = []
@ha.callback
def single_run_callback(event):
old_state = event.data.get("old_state")
new_state = event.data.get("new_state")
single_entity_id_tracker.append((old_state, new_state))
@ha.callback
def multiple_run_callback(event):
old_state = event.data.get("old_state")
new_state = event.data.get("new_state")
multiple_entity_id_tracker.append((old_state, new_state))
@ha.callback
def callback_that_throws(event):
raise ValueError
unsub_single = async_track_state_added_domain(hass, "light", single_run_callback)
unsub_multi = async_track_state_added_domain(
hass, ["light", "switch"], multiple_run_callback
)
unsub_throws = async_track_state_added_domain(
hass, ["light", "switch"], callback_that_throws
)
# Adding state to state machine
hass.states.async_set("light.Bowl", "on")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert single_entity_id_tracker[-1][0] is None
assert single_entity_id_tracker[-1][1] is not None
assert len(multiple_entity_id_tracker) == 1
assert multiple_entity_id_tracker[-1][0] is None
assert multiple_entity_id_tracker[-1][1] is not None
# Set same state should not trigger a state change/listener
hass.states.async_set("light.Bowl", "on")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 1
# State change off -> on - nothing added so no trigger
hass.states.async_set("light.Bowl", "off")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 1
# State change off -> off - nothing added so no trigger
hass.states.async_set("light.Bowl", "off", {"some_attr": 1})
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 1
# Removing state does not trigger
hass.states.async_remove("light.bowl")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 1
# Set state for different entity id
hass.states.async_set("switch.kitchen", "on")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 2
unsub_single()
# Ensure unsubing the listener works
hass.states.async_set("light.new", "off")
await hass.async_block_till_done()
assert len(single_entity_id_tracker) == 1
assert len(multiple_entity_id_tracker) == 3
unsub_multi()
unsub_throws()
async def test_track_template(hass):
"""Test tracking template."""
specific_runs = []