362 lines
13 KiB
Python
362 lines
13 KiB
Python
"""Helpers for dealing with entity targets."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
import dataclasses
|
|
import logging
|
|
from logging import Logger
|
|
from typing import Any, TypeGuard
|
|
|
|
from homeassistant.const import (
|
|
ATTR_AREA_ID,
|
|
ATTR_DEVICE_ID,
|
|
ATTR_ENTITY_ID,
|
|
ATTR_FLOOR_ID,
|
|
ATTR_LABEL_ID,
|
|
ENTITY_MATCH_NONE,
|
|
)
|
|
from homeassistant.core import (
|
|
CALLBACK_TYPE,
|
|
Event,
|
|
EventStateChangedData,
|
|
HomeAssistant,
|
|
callback,
|
|
)
|
|
from homeassistant.exceptions import HomeAssistantError
|
|
|
|
from . import (
|
|
area_registry as ar,
|
|
config_validation as cv,
|
|
device_registry as dr,
|
|
entity_registry as er,
|
|
floor_registry as fr,
|
|
group,
|
|
label_registry as lr,
|
|
)
|
|
from .event import async_track_state_change_event
|
|
from .typing import ConfigType
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclasses.dataclass(slots=True, frozen=True)
|
|
class TargetStateChangedData:
|
|
"""Data for state change events related to targets."""
|
|
|
|
state_change_event: Event[EventStateChangedData]
|
|
targeted_entity_ids: set[str]
|
|
|
|
|
|
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
    """Return True unless ids is None or the match-none sentinel."""
    if ids is None:
        return False
    return ids != ENTITY_MATCH_NONE
|
|
|
|
|
|
class TargetSelectorData:
    """Hold the ID sets named by a target selector config."""

    __slots__ = ("area_ids", "device_ids", "entity_ids", "floor_ids", "label_ids")

    def __init__(self, config: ConfigType) -> None:
        """Read every target key from the config and store it as a set of IDs."""

        def _as_id_set(key: str) -> set[str]:
            """Normalize the value stored under key to a set.

            A missing value or the match-none sentinel yields an empty set.
            """
            raw: str | list | None = config.get(key)
            if not _has_match(raw):
                return set()
            return set(cv.ensure_list(raw))

        self.entity_ids = _as_id_set(ATTR_ENTITY_ID)
        self.device_ids = _as_id_set(ATTR_DEVICE_ID)
        self.area_ids = _as_id_set(ATTR_AREA_ID)
        self.floor_ids = _as_id_set(ATTR_FLOOR_ID)
        self.label_ids = _as_id_set(ATTR_LABEL_ID)

    @property
    def has_any_selector(self) -> bool:
        """Return True if at least one of the ID sets is non-empty."""
        return any(
            (
                self.entity_ids,
                self.device_ids,
                self.area_ids,
                self.floor_ids,
                self.label_ids,
            )
        )
|
|
|
|
|
|
@dataclasses.dataclass(slots=True)
|
|
class SelectedEntities:
|
|
"""Class to hold the selected entities."""
|
|
|
|
# Entity IDs of entities that were explicitly mentioned.
|
|
referenced: set[str] = dataclasses.field(default_factory=set)
|
|
|
|
# Entity IDs of entities that were referenced via device/area/floor/label ID.
|
|
# Should not trigger a warning when they don't exist.
|
|
indirectly_referenced: set[str] = dataclasses.field(default_factory=set)
|
|
|
|
# Referenced items that could not be found.
|
|
missing_devices: set[str] = dataclasses.field(default_factory=set)
|
|
missing_areas: set[str] = dataclasses.field(default_factory=set)
|
|
missing_floors: set[str] = dataclasses.field(default_factory=set)
|
|
missing_labels: set[str] = dataclasses.field(default_factory=set)
|
|
|
|
referenced_devices: set[str] = dataclasses.field(default_factory=set)
|
|
referenced_areas: set[str] = dataclasses.field(default_factory=set)
|
|
|
|
def log_missing(self, missing_entities: set[str], logger: Logger) -> None:
|
|
"""Log about missing items."""
|
|
parts = []
|
|
for label, items in (
|
|
("floors", self.missing_floors),
|
|
("areas", self.missing_areas),
|
|
("devices", self.missing_devices),
|
|
("entities", missing_entities),
|
|
("labels", self.missing_labels),
|
|
):
|
|
if items:
|
|
parts.append(f"{label} {', '.join(sorted(items))}")
|
|
|
|
if not parts:
|
|
return
|
|
|
|
logger.warning(
|
|
"Referenced %s are missing or not currently available",
|
|
", ".join(parts),
|
|
)
|
|
|
|
|
|
def async_extract_referenced_entity_ids(
    hass: HomeAssistant, selector_data: TargetSelectorData, expand_group: bool = True
) -> SelectedEntities:
    """Resolve a target selector into the entities it references.

    Explicit entity IDs (passed through group.expand_entity_ids when
    expand_group is true) are stored in ``referenced``. Entities reached
    through labels, devices, areas and floors are stored in
    ``indirectly_referenced``. Floor, area, device and label IDs that are not
    present in their registry are recorded in the ``missing_*`` sets.
    """
    result = SelectedEntities()

    if not selector_data.has_any_selector:
        return result

    if expand_group:
        result.referenced.update(
            group.expand_entity_ids(hass, selector_data.entity_ids)
        )
    else:
        result.referenced.update(selector_data.entity_ids)

    # Only entity IDs were given: no registry lookups are needed.
    if not (
        selector_data.device_ids
        or selector_data.area_ids
        or selector_data.floor_ids
        or selector_data.label_ids
    ):
        return result

    entity_entries = er.async_get(hass).entities
    device_registry = dr.async_get(hass)
    area_registry = ar.async_get(hass)

    if selector_data.floor_ids:
        known_floors = fr.async_get(hass).floors
        result.missing_floors.update(
            floor_id
            for floor_id in selector_data.floor_ids
            if floor_id not in known_floors
        )
        # Every area on a targeted floor becomes a referenced area.
        result.referenced_areas.update(
            area.id
            for floor_id in selector_data.floor_ids
            for area in area_registry.areas.get_areas_for_floor(floor_id)
        )

    result.missing_areas.update(
        area_id
        for area_id in selector_data.area_ids
        if area_id not in area_registry.areas
    )
    result.missing_devices.update(
        device_id
        for device_id in selector_data.device_ids
        if device_id not in device_registry.devices
    )

    if selector_data.label_ids:
        known_labels = lr.async_get(hass).labels
        for label_id in selector_data.label_ids:
            if label_id not in known_labels:
                result.missing_labels.add(label_id)

            # Entities carrying the label: only hidden entities are skipped
            # here; entity_category is not checked on this path.
            result.indirectly_referenced.update(
                entry.entity_id
                for entry in entity_entries.get_entries_for_label(label_id)
                if entry.hidden_by is None
            )
            result.referenced_devices.update(
                device.id
                for device in device_registry.devices.get_devices_for_label(label_id)
            )
            result.referenced_areas.update(
                area.id
                for area in area_registry.areas.get_areas_for_label(label_id)
            )

    result.referenced_areas.update(selector_data.area_ids)
    result.referenced_devices.update(selector_data.device_ids)

    if not (result.referenced_areas or result.referenced_devices):
        return result

    # Entities of devices referenced directly or through a label. This runs
    # before devices found through areas are merged into referenced_devices,
    # because those are subject to an extra area check further down.
    for device_id in result.referenced_devices:
        for entry in entity_entries.get_entries_for_device_id(device_id):
            # Skip hidden entities and config/diagnostic entities.
            if entry.entity_category is None and entry.hidden_by is None:
                result.indirectly_referenced.add(entry.entity_id)

    # Devices located in the referenced areas.
    devices_via_area: set[str] = {
        device.id
        for area_id in result.referenced_areas
        for device in device_registry.devices.get_devices_for_area_id(area_id)
    }
    result.referenced_devices.update(devices_via_area)

    # Entities whose own area is one of the referenced areas.
    for area_id in result.referenced_areas:
        for entry in entity_entries.get_entries_for_area_id(area_id):
            # Skip hidden entities and config/diagnostic entities.
            if entry.entity_category is None and entry.hidden_by is None:
                result.indirectly_referenced.add(entry.entity_id)

    # Entities of devices found through an area, but only when the entity has
    # no explicitly set area of its own.
    for device_id in devices_via_area:
        for entry in entity_entries.get_entries_for_device_id(device_id):
            if (
                entry.entity_category is None
                and entry.hidden_by is None
                and not entry.area_id
            ):
                result.indirectly_referenced.add(entry.entity_id)

    return result
|
|
|
|
|
|
class TargetStateChangeTracker:
    """Track state changes of the entities selected by a target selector.

    The selected entities are re-resolved and the state subscription is
    replaced whenever the entity, device or area registry reports an update.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        selector_data: TargetSelectorData,
        action: Callable[[TargetStateChangedData], Any],
        entity_filter: Callable[[set[str]], set[str]],
    ) -> None:
        """Store the collaborators; nothing is subscribed until async_setup."""
        self._hass = hass
        self._selector_data = selector_data
        self._action = action
        self._entity_filter = entity_filter

        # Unsubscribe handles: one for the current state subscription and one
        # per registry listener.
        self._state_change_unsub: CALLBACK_TYPE | None = None
        self._registry_unsubs: list[CALLBACK_TYPE] = []

    def async_setup(self) -> Callable[[], None]:
        """Start tracking and return the callable that stops it again."""
        self._setup_registry_listeners()
        self._track_entities_state_change()
        return self._unsubscribe

    def _track_entities_state_change(self) -> None:
        """Resolve the target and subscribe to state changes of the result."""
        selected = async_extract_referenced_entity_ids(
            self._hass, self._selector_data, expand_group=False
        )
        candidates = selected.referenced | selected.indirectly_referenced
        tracked_entities = self._entity_filter(candidates)

        @callback
        def state_change_listener(event: Event[EventStateChangedData]) -> None:
            """Forward the event to the action if its entity is selected."""
            entity_id = event.data["entity_id"]
            if (
                entity_id not in selected.referenced
                and entity_id not in selected.indirectly_referenced
            ):
                return
            self._action(TargetStateChangedData(event, tracked_entities))

        _LOGGER.debug("Tracking state changes for entities: %s", tracked_entities)
        self._state_change_unsub = async_track_state_change_event(
            self._hass, tracked_entities, state_change_listener
        )

    def _setup_registry_listeners(self) -> None:
        """Listen for registry updates that require a new state subscription."""

        @callback
        def resubscribe_state_change_event(event: Event[Any] | None = None) -> None:
            """Drop the current state subscription and create a fresh one."""
            if self._state_change_unsub:
                self._state_change_unsub()
            self._track_entities_state_change()

        # Registry updates that can change which entities must be tracked:
        # - Entity registry: entity added/removed; entity labels or area changed.
        # - Device registry: device labels changed; device area changed.
        # - Area registry: area floor changed.
        #
        # Other registries (such as the floor or label registries) are not
        # tracked because their changes do not affect which entities are tracked.
        registry_event_types = (
            er.EVENT_ENTITY_REGISTRY_UPDATED,
            dr.EVENT_DEVICE_REGISTRY_UPDATED,
            ar.EVENT_AREA_REGISTRY_UPDATED,
        )
        self._registry_unsubs = [
            self._hass.bus.async_listen(event_type, resubscribe_state_change_event)
            for event_type in registry_event_types
        ]

    def _unsubscribe(self) -> None:
        """Remove the registry listeners and the state subscription."""
        for unsub in self._registry_unsubs:
            unsub()
        self._registry_unsubs.clear()

        if not self._state_change_unsub:
            return
        self._state_change_unsub()
        self._state_change_unsub = None
|
|
|
|
|
|
def async_track_target_selector_state_change_event(
    hass: HomeAssistant,
    target_selector_config: ConfigType,
    action: Callable[[TargetStateChangedData], Any],
    entity_filter: Callable[[set[str]], set[str]] = lambda x: x,
) -> CALLBACK_TYPE:
    """Track state changes of entities a target selector references.

    Entities referenced directly or indirectly by the selector are tracked.
    The candidate entity IDs are passed through entity_filter first; the
    default keeps them unchanged. Returns a callable that stops the tracking.

    Raises HomeAssistantError when the config contains no selector at all.
    """
    selector_data = TargetSelectorData(target_selector_config)
    if selector_data.has_any_selector:
        return TargetStateChangeTracker(
            hass, selector_data, action, entity_filter
        ).async_setup()

    raise HomeAssistantError(
        f"Target selector {target_selector_config} does not have any selectors defined"
    )
|