Teach state trigger about entity registry ids (#60271)
* Teach state trigger about entity registry ids * Tweak * Add tests * Tweak tests * Fix tests * Resolve entity ids during config validation * Update device_triggers * Fix mistake * Tweak trigger validator to ensure we don't modify the original config * Add index from entry id to entry * Update scaffold * Pre-compile UUID regex * Address review comment * Tweak mock_registry * Tweak * Apply suggestion from code reviewpull/60835/head
parent
c0fb1bffce
commit
c85bb27d0d
homeassistant
components
alarm_control_panel
binary_sensor
button
climate
cover
device_automation
homeassistant/triggers
media_player
select
vacuum
zwave_js
script/scaffold/templates/device_trigger/integration
tests
components/homeassistant/triggers
|
@ -157,7 +157,7 @@ async def async_attach_trigger(
|
|||
}
|
||||
if CONF_FOR in config:
|
||||
state_config[CONF_FOR] = config[CONF_FOR]
|
||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
||||
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
|
||||
return await state_trigger.async_attach_trigger(
|
||||
hass, state_config, action, automation_info, platform_type="device"
|
||||
)
|
||||
|
|
|
@ -220,7 +220,7 @@ async def async_attach_trigger(hass, config, action, automation_info):
|
|||
if CONF_FOR in config:
|
||||
state_config[CONF_FOR] = config[CONF_FOR]
|
||||
|
||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
||||
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
|
||||
return await state_trigger.async_attach_trigger(
|
||||
hass, state_config, action, automation_info, platform_type="device"
|
||||
)
|
||||
|
|
|
@ -11,8 +11,8 @@ from homeassistant.components.automation import (
|
|||
)
|
||||
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
|
||||
from homeassistant.components.homeassistant.triggers.state import (
|
||||
TRIGGER_SCHEMA as STATE_TRIGGER_SCHEMA,
|
||||
async_attach_trigger as async_attach_state_trigger,
|
||||
async_validate_trigger_config as async_validate_state_trigger_config,
|
||||
)
|
||||
from homeassistant.const import (
|
||||
CONF_DEVICE_ID,
|
||||
|
@ -67,7 +67,7 @@ async def async_attach_trigger(
|
|||
CONF_ENTITY_ID: config[CONF_ENTITY_ID],
|
||||
}
|
||||
|
||||
state_config = STATE_TRIGGER_SCHEMA(state_config)
|
||||
state_config = await async_validate_state_trigger_config(hass, state_config)
|
||||
return await async_attach_state_trigger(
|
||||
hass, state_config, action, automation_info, platform_type="device"
|
||||
)
|
||||
|
|
|
@ -131,7 +131,9 @@ async def async_attach_trigger(
|
|||
}
|
||||
if CONF_FOR in config:
|
||||
state_config[CONF_FOR] = config[CONF_FOR]
|
||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
||||
state_config = await state_trigger.async_validate_trigger_config(
|
||||
hass, state_config
|
||||
)
|
||||
return await state_trigger.async_attach_trigger(
|
||||
hass, state_config, action, automation_info, platform_type="device"
|
||||
)
|
||||
|
|
|
@ -170,7 +170,9 @@ async def async_attach_trigger(
|
|||
}
|
||||
if CONF_FOR in config:
|
||||
state_config[CONF_FOR] = config[CONF_FOR]
|
||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
||||
state_config = await state_trigger.async_validate_trigger_config(
|
||||
hass, state_config
|
||||
)
|
||||
return await state_trigger.async_attach_trigger(
|
||||
hass, state_config, action, automation_info, platform_type="device"
|
||||
)
|
||||
|
|
|
@ -164,7 +164,7 @@ async def async_attach_trigger(
|
|||
if CONF_FOR in config:
|
||||
state_config[CONF_FOR] = config[CONF_FOR]
|
||||
|
||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
||||
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
|
||||
return await state_trigger.async_attach_trigger(
|
||||
hass, state_config, action, automation_info, platform_type="device"
|
||||
)
|
||||
|
|
|
@ -3,20 +3,24 @@ from __future__ import annotations
|
|||
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import exceptions
|
||||
from homeassistant.const import CONF_ATTRIBUTE, CONF_FOR, CONF_PLATFORM, MATCH_ALL
|
||||
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, State, callback
|
||||
from homeassistant.helpers import config_validation as cv, template
|
||||
from homeassistant.helpers import (
|
||||
config_validation as cv,
|
||||
entity_registry as er,
|
||||
template,
|
||||
)
|
||||
from homeassistant.helpers.event import (
|
||||
Event,
|
||||
async_track_same_state,
|
||||
async_track_state_change_event,
|
||||
process_state_match,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs
|
||||
# mypy: no-check-untyped-defs
|
||||
|
@ -30,7 +34,7 @@ CONF_TO = "to"
|
|||
BASE_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
|
||||
{
|
||||
vol.Required(CONF_PLATFORM): "state",
|
||||
vol.Required(CONF_ENTITY_ID): cv.entity_ids,
|
||||
vol.Required(CONF_ENTITY_ID): cv.entity_ids_or_uuids,
|
||||
vol.Optional(CONF_FOR): cv.positive_time_period_template,
|
||||
vol.Optional(CONF_ATTRIBUTE): cv.match_all,
|
||||
}
|
||||
|
@ -52,17 +56,26 @@ TRIGGER_ATTRIBUTE_SCHEMA = BASE_SCHEMA.extend(
|
|||
)
|
||||
|
||||
|
||||
def TRIGGER_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name
|
||||
"""Validate trigger."""
|
||||
if not isinstance(value, dict):
|
||||
async def async_validate_trigger_config(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> ConfigType:
|
||||
"""Validate trigger config."""
|
||||
if not isinstance(config, dict):
|
||||
raise vol.Invalid("Expected a dictionary")
|
||||
|
||||
# We use this approach instead of vol.Any because
|
||||
# this gives better error messages.
|
||||
if CONF_ATTRIBUTE in value:
|
||||
return TRIGGER_ATTRIBUTE_SCHEMA(value)
|
||||
if CONF_ATTRIBUTE in config:
|
||||
config = TRIGGER_ATTRIBUTE_SCHEMA(config)
|
||||
else:
|
||||
config = TRIGGER_STATE_SCHEMA(config)
|
||||
|
||||
return TRIGGER_STATE_SCHEMA(value)
|
||||
registry = er.async_get(hass)
|
||||
config[CONF_ENTITY_ID] = er.async_resolve_entity_ids(
|
||||
registry, cv.entity_ids_or_uuids(config[CONF_ENTITY_ID])
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
async def async_attach_trigger(
|
||||
|
@ -74,7 +87,7 @@ async def async_attach_trigger(
|
|||
platform_type: str = "state",
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Listen for state changes based on configuration."""
|
||||
entity_id = config.get(CONF_ENTITY_ID)
|
||||
entity_ids = config[CONF_ENTITY_ID]
|
||||
if (from_state := config.get(CONF_FROM)) is None:
|
||||
from_state = MATCH_ALL
|
||||
if (to_state := config.get(CONF_TO)) is None:
|
||||
|
@ -196,7 +209,7 @@ async def async_attach_trigger(
|
|||
entity_ids=entity,
|
||||
)
|
||||
|
||||
unsub = async_track_state_change_event(hass, entity_id, state_automation_listener)
|
||||
unsub = async_track_state_change_event(hass, entity_ids, state_automation_listener)
|
||||
|
||||
@callback
|
||||
def async_remove():
|
||||
|
|
|
@ -104,7 +104,7 @@ async def async_attach_trigger(
|
|||
}
|
||||
if CONF_FOR in config:
|
||||
state_config[CONF_FOR] = config[CONF_FOR]
|
||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
||||
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
|
||||
return await state_trigger.async_attach_trigger(
|
||||
hass, state_config, action, automation_info, platform_type="device"
|
||||
)
|
||||
|
|
|
@ -104,7 +104,7 @@ async def async_attach_trigger(
|
|||
}
|
||||
if CONF_FOR in config:
|
||||
state_config[CONF_FOR] = config[CONF_FOR]
|
||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
||||
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
|
||||
return await state_trigger.async_attach_trigger(
|
||||
hass, state_config, action, automation_info, platform_type="device"
|
||||
)
|
||||
|
|
|
@ -14,8 +14,8 @@ from homeassistant.components.homeassistant.triggers.state import (
|
|||
CONF_FOR,
|
||||
CONF_FROM,
|
||||
CONF_TO,
|
||||
TRIGGER_SCHEMA as STATE_TRIGGER_SCHEMA,
|
||||
async_attach_trigger as async_attach_state_trigger,
|
||||
async_validate_trigger_config as async_validate_state_trigger_config,
|
||||
)
|
||||
from homeassistant.components.select.const import ATTR_OPTIONS
|
||||
from homeassistant.const import (
|
||||
|
@ -84,7 +84,7 @@ async def async_attach_trigger(
|
|||
if CONF_FOR in config:
|
||||
state_config[CONF_FOR] = config[CONF_FOR]
|
||||
|
||||
state_config = STATE_TRIGGER_SCHEMA(state_config)
|
||||
state_config = await async_validate_state_trigger_config(hass, state_config)
|
||||
return await async_attach_state_trigger(
|
||||
hass, state_config, action, automation_info, platform_type="device"
|
||||
)
|
||||
|
|
|
@ -92,7 +92,7 @@ async def async_attach_trigger(
|
|||
}
|
||||
if CONF_FOR in config:
|
||||
state_config[CONF_FOR] = config[CONF_FOR]
|
||||
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
|
||||
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
|
||||
return await state_trigger.async_attach_trigger(
|
||||
hass, state_config, action, automation_info, platform_type="device"
|
||||
)
|
||||
|
|
|
@ -415,7 +415,7 @@ async def async_attach_trigger(
|
|||
else:
|
||||
raise HomeAssistantError(f"Unhandled trigger type {trigger_type}")
|
||||
|
||||
state_config = state.TRIGGER_SCHEMA(state_config)
|
||||
state_config = await state.async_validate_trigger_config(hass, state_config)
|
||||
return await state.async_attach_trigger(
|
||||
hass, state_config, action, automation_info, platform_type="device"
|
||||
)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Hashable
|
||||
import contextlib
|
||||
from datetime import (
|
||||
date as date_sys,
|
||||
datetime as datetime_sys,
|
||||
|
@ -262,14 +263,34 @@ def entity_id(value: Any) -> str:
|
|||
raise vol.Invalid(f"Entity ID {value} is an invalid entity ID")
|
||||
|
||||
|
||||
def entity_ids(value: str | list) -> list[str]:
|
||||
"""Validate Entity IDs."""
|
||||
def entity_id_or_uuid(value: Any) -> str:
|
||||
"""Validate Entity specified by entity_id or uuid."""
|
||||
with contextlib.suppress(vol.Invalid):
|
||||
return entity_id(value)
|
||||
with contextlib.suppress(vol.Invalid):
|
||||
return fake_uuid4_hex(value)
|
||||
raise vol.Invalid(f"Entity {value} is neither a valid entity ID nor a valid UUID")
|
||||
|
||||
|
||||
def _entity_ids(value: str | list, allow_uuid: bool) -> list[str]:
|
||||
"""Help validate entity IDs or UUIDs."""
|
||||
if value is None:
|
||||
raise vol.Invalid("Entity IDs can not be None")
|
||||
if isinstance(value, str):
|
||||
value = [ent_id.strip() for ent_id in value.split(",")]
|
||||
|
||||
return [entity_id(ent_id) for ent_id in value]
|
||||
validator = entity_id_or_uuid if allow_uuid else entity_id
|
||||
return [validator(ent_id) for ent_id in value]
|
||||
|
||||
|
||||
def entity_ids(value: str | list) -> list[str]:
|
||||
"""Validate Entity IDs."""
|
||||
return _entity_ids(value, False)
|
||||
|
||||
|
||||
def entity_ids_or_uuids(value: str | list) -> list[str]:
|
||||
"""Validate entities specified by entity IDs or UUIDs."""
|
||||
return _entity_ids(value, True)
|
||||
|
||||
|
||||
comp_entity_ids = vol.Any(
|
||||
|
@ -682,6 +703,16 @@ def uuid4_hex(value: Any) -> str:
|
|||
return result.hex
|
||||
|
||||
|
||||
_FAKE_UUID_4_HEX = re.compile(r"^[0-9a-f]{32}$")
|
||||
|
||||
|
||||
def fake_uuid4_hex(value: Any) -> str:
|
||||
"""Validate a fake v4 UUID generated by random_uuid_hex."""
|
||||
if not _FAKE_UUID_4_HEX.match(value):
|
||||
raise vol.Invalid("Invalid UUID")
|
||||
return cast(str, value) # Pattern.match throws if input is not a string
|
||||
|
||||
|
||||
def ensure_list_csv(value: Any) -> list:
|
||||
"""Ensure that input is a list or make one from comma-separated string."""
|
||||
if isinstance(value, str):
|
||||
|
|
|
@ -9,12 +9,13 @@ timer.
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections import UserDict
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import attr
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import (
|
||||
ATTR_DEVICE_CLASS,
|
||||
|
@ -161,14 +162,57 @@ class EntityRegistryStore(storage.Store):
|
|||
return await _async_migrate(old_major_version, old_minor_version, old_data)
|
||||
|
||||
|
||||
class EntityRegistryItems(UserDict):
|
||||
"""Container for entity registry items, maps entity_id -> entry.
|
||||
|
||||
Maintains two additional indexes:
|
||||
- id -> entry
|
||||
- (domain, platform, unique_id) -> entry
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the container."""
|
||||
super().__init__()
|
||||
self._entry_ids: dict[str, RegistryEntry] = {}
|
||||
self._index: dict[tuple[str, str, str], str] = {}
|
||||
|
||||
def __setitem__(self, key: str, entry: RegistryEntry) -> None:
|
||||
"""Add an item."""
|
||||
if key in self:
|
||||
old_entry = self[key]
|
||||
del self._entry_ids[old_entry.id]
|
||||
del self._index[(old_entry.domain, old_entry.platform, old_entry.unique_id)]
|
||||
super().__setitem__(key, entry)
|
||||
self._entry_ids.__setitem__(entry.id, entry)
|
||||
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""Remove an item."""
|
||||
entry = self[key]
|
||||
self._entry_ids.__delitem__(entry.id)
|
||||
self._index.__delitem__((entry.domain, entry.platform, entry.unique_id))
|
||||
super().__delitem__(key)
|
||||
|
||||
def __getitem__(self, key: str) -> RegistryEntry:
|
||||
"""Get an item."""
|
||||
return cast(RegistryEntry, super().__getitem__(key))
|
||||
|
||||
def get_entity_id(self, key: tuple[str, str, str]) -> str | None:
|
||||
"""Get entity_id from (domain, platform, unique_id)."""
|
||||
return self._index.get(key)
|
||||
|
||||
def get_entry(self, key: str) -> RegistryEntry | None:
|
||||
"""Get entry from id."""
|
||||
return self._entry_ids.get(key)
|
||||
|
||||
|
||||
class EntityRegistry:
|
||||
"""Class to hold a registry of entities."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the registry."""
|
||||
self.hass = hass
|
||||
self.entities: dict[str, RegistryEntry]
|
||||
self._index: dict[tuple[str, str, str], str] = {}
|
||||
self.entities: EntityRegistryItems
|
||||
self._store = EntityRegistryStore(
|
||||
hass,
|
||||
STORAGE_VERSION_MAJOR,
|
||||
|
@ -218,7 +262,7 @@ class EntityRegistry:
|
|||
self, domain: str, platform: str, unique_id: str
|
||||
) -> str | None:
|
||||
"""Check if an entity_id is currently registered."""
|
||||
return self._index.get((domain, platform, unique_id))
|
||||
return self.entities.get_entity_id((domain, platform, unique_id))
|
||||
|
||||
@callback
|
||||
def async_generate_entity_id(
|
||||
|
@ -320,7 +364,7 @@ class EntityRegistry:
|
|||
):
|
||||
disabled_by = DISABLED_INTEGRATION
|
||||
|
||||
entity = RegistryEntry(
|
||||
entry = RegistryEntry(
|
||||
area_id=area_id,
|
||||
capabilities=capabilities,
|
||||
config_entry_id=config_entry_id,
|
||||
|
@ -336,7 +380,7 @@ class EntityRegistry:
|
|||
unique_id=unique_id,
|
||||
unit_of_measurement=unit_of_measurement,
|
||||
)
|
||||
self._register_entry(entity)
|
||||
self.entities[entity_id] = entry
|
||||
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
|
||||
self.async_schedule_save()
|
||||
|
||||
|
@ -344,12 +388,12 @@ class EntityRegistry:
|
|||
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id}
|
||||
)
|
||||
|
||||
return entity
|
||||
return entry
|
||||
|
||||
@callback
|
||||
def async_remove(self, entity_id: str) -> None:
|
||||
"""Remove an entity from registry."""
|
||||
self._unregister_entry(self.entities[entity_id])
|
||||
self.entities.pop(entity_id)
|
||||
self.hass.bus.async_fire(
|
||||
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id}
|
||||
)
|
||||
|
@ -513,9 +557,7 @@ class EntityRegistry:
|
|||
if not new_values:
|
||||
return old
|
||||
|
||||
self._remove_index(old)
|
||||
new = attr.evolve(old, **new_values)
|
||||
self._register_entry(new)
|
||||
new = self.entities[entity_id] = attr.evolve(old, **new_values)
|
||||
|
||||
self.async_schedule_save()
|
||||
|
||||
|
@ -539,7 +581,7 @@ class EntityRegistry:
|
|||
old_conf_load_func=load_yaml,
|
||||
old_conf_migrate_func=_async_migrate_yaml_to_json,
|
||||
)
|
||||
entities: dict[str, RegistryEntry] = OrderedDict()
|
||||
entities = EntityRegistryItems()
|
||||
|
||||
if data is not None:
|
||||
for entity in data["entities"]:
|
||||
|
@ -571,7 +613,6 @@ class EntityRegistry:
|
|||
)
|
||||
|
||||
self.entities = entities
|
||||
self._rebuild_index()
|
||||
|
||||
@callback
|
||||
def async_schedule_save(self) -> None:
|
||||
|
@ -626,25 +667,6 @@ class EntityRegistry:
|
|||
if area_id == entry.area_id:
|
||||
self._async_update_entity(entity_id, area_id=None)
|
||||
|
||||
def _register_entry(self, entry: RegistryEntry) -> None:
|
||||
self.entities[entry.entity_id] = entry
|
||||
self._add_index(entry)
|
||||
|
||||
def _add_index(self, entry: RegistryEntry) -> None:
|
||||
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
|
||||
|
||||
def _unregister_entry(self, entry: RegistryEntry) -> None:
|
||||
self._remove_index(entry)
|
||||
del self.entities[entry.entity_id]
|
||||
|
||||
def _remove_index(self, entry: RegistryEntry) -> None:
|
||||
del self._index[(entry.domain, entry.platform, entry.unique_id)]
|
||||
|
||||
def _rebuild_index(self) -> None:
|
||||
self._index = {}
|
||||
for entry in self.entities.values():
|
||||
self._add_index(entry)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get(hass: HomeAssistant) -> EntityRegistry:
|
||||
|
@ -841,3 +863,25 @@ async def async_migrate_entries(
|
|||
|
||||
if updates is not None:
|
||||
ent_reg.async_update_entity(entry.entity_id, **updates)
|
||||
|
||||
|
||||
@callback
|
||||
def async_resolve_entity_ids(
|
||||
registry: EntityRegistry, entity_ids_or_uuids: list[str]
|
||||
) -> list[str]:
|
||||
"""Resolve a list of entity ids or UUIDs to a list of entity ids."""
|
||||
|
||||
def resolve_entity(entity_id_or_uuid: str) -> str | None:
|
||||
"""Resolve an entity id or UUID to an entity id or None."""
|
||||
if valid_entity_id(entity_id_or_uuid):
|
||||
return entity_id_or_uuid
|
||||
if (entry := registry.entities.get_entry(entity_id_or_uuid)) is None:
|
||||
raise vol.Invalid(f"Unknown entity registry entry {entity_id_or_uuid}")
|
||||
return entry.entity_id
|
||||
|
||||
tmp = [
|
||||
resolved_item
|
||||
for item in entity_ids_or_uuids
|
||||
if (resolved_item := resolve_entity(item)) is not None
|
||||
]
|
||||
return tmp
|
||||
|
|
|
@ -10,7 +10,7 @@ from homeassistant.components.automation import (
|
|||
AutomationTriggerInfo,
|
||||
)
|
||||
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
|
||||
from homeassistant.components.homeassistant.triggers import state
|
||||
from homeassistant.components.homeassistant.triggers import state as state_trigger
|
||||
from homeassistant.const import (
|
||||
CONF_DEVICE_ID,
|
||||
CONF_DOMAIN,
|
||||
|
@ -86,11 +86,11 @@ async def async_attach_trigger(
|
|||
to_state = STATE_OFF
|
||||
|
||||
state_config = {
|
||||
state.CONF_PLATFORM: "state",
|
||||
state_trigger.CONF_PLATFORM: "state",
|
||||
CONF_ENTITY_ID: config[CONF_ENTITY_ID],
|
||||
state.CONF_TO: to_state,
|
||||
state_trigger.CONF_TO: to_state,
|
||||
}
|
||||
state_config = state.TRIGGER_SCHEMA(state_config)
|
||||
return await state.async_attach_trigger(
|
||||
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
|
||||
return await state_trigger.async_attach_trigger(
|
||||
hass, state_config, action, automation_info, platform_type="device"
|
||||
)
|
||||
|
|
|
@ -440,8 +440,11 @@ def mock_component(hass, component):
|
|||
def mock_registry(hass, mock_entries=None):
|
||||
"""Mock the Entity Registry."""
|
||||
registry = entity_registry.EntityRegistry(hass)
|
||||
registry.entities = mock_entries or OrderedDict()
|
||||
registry._rebuild_index()
|
||||
if mock_entries is None:
|
||||
mock_entries = {}
|
||||
registry.entities = entity_registry.EntityRegistryItems()
|
||||
for key, entry in mock_entries.items():
|
||||
registry.entities[key] = entry
|
||||
|
||||
hass.data[entity_registry.DATA_REGISTRY] = registry
|
||||
return registry
|
||||
|
|
|
@ -8,6 +8,7 @@ import homeassistant.components.automation as automation
|
|||
from homeassistant.components.homeassistant.triggers import state as state_trigger
|
||||
from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL, SERVICE_TURN_OFF
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.setup import async_setup_component
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
|
@ -82,6 +83,64 @@ async def test_if_fires_on_entity_change(hass, calls):
|
|||
assert len(calls) == 1
|
||||
|
||||
|
||||
async def test_if_fires_on_entity_change_uuid(hass, calls):
|
||||
"""Test for firing on entity change."""
|
||||
context = Context()
|
||||
|
||||
registry = er.async_get(hass)
|
||||
entry = registry.async_get_or_create(
|
||||
"test", "hue", "1234", suggested_object_id="beer"
|
||||
)
|
||||
|
||||
assert entry.entity_id == "test.beer"
|
||||
|
||||
hass.states.async_set("test.beer", "hello")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
automation.DOMAIN,
|
||||
{
|
||||
automation.DOMAIN: {
|
||||
"trigger": {"platform": "state", "entity_id": entry.id},
|
||||
"action": {
|
||||
"service": "test.automation",
|
||||
"data_template": {
|
||||
"some": "{{ trigger.%s }}"
|
||||
% "}} - {{ trigger.".join(
|
||||
(
|
||||
"platform",
|
||||
"entity_id",
|
||||
"from_state.state",
|
||||
"to_state.state",
|
||||
"for",
|
||||
"id",
|
||||
)
|
||||
)
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
hass.states.async_set("test.beer", "world", context=context)
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
assert calls[0].context.parent_id == context.id
|
||||
assert calls[0].data["some"] == "state - test.beer - hello - world - None - 0"
|
||||
|
||||
await hass.services.async_call(
|
||||
automation.DOMAIN,
|
||||
SERVICE_TURN_OFF,
|
||||
{ATTR_ENTITY_ID: ENTITY_MATCH_ALL},
|
||||
blocking=True,
|
||||
)
|
||||
hass.states.async_set("test.beer", "planet")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
async def test_if_fires_on_entity_change_with_from_filter(hass, calls):
|
||||
"""Test for firing on entity change with filter."""
|
||||
assert await async_setup_component(
|
||||
|
|
|
@ -172,9 +172,10 @@ def test_entity_id():
|
|||
assert schema("sensor.LIGHT") == "sensor.light"
|
||||
|
||||
|
||||
def test_entity_ids():
|
||||
@pytest.mark.parametrize("validator", [cv.entity_ids, cv.entity_ids_or_uuids])
|
||||
def test_entity_ids(validator):
|
||||
"""Test entity ID validation."""
|
||||
schema = vol.Schema(cv.entity_ids)
|
||||
schema = vol.Schema(validator)
|
||||
|
||||
options = (
|
||||
"invalid_entity",
|
||||
|
@ -194,6 +195,32 @@ def test_entity_ids():
|
|||
assert schema("sensor.LIGHT, light.kitchen ") == ["sensor.light", "light.kitchen"]
|
||||
|
||||
|
||||
def test_entity_ids_or_uuids():
|
||||
"""Test entity ID validation."""
|
||||
schema = vol.Schema(cv.entity_ids_or_uuids)
|
||||
|
||||
valid_uuid = "a266a680b608c32770e6c45bfe6b8411"
|
||||
valid_uuid2 = "a266a680b608c32770e6c45bfe6b8412"
|
||||
invalid_uuid_capital_letters = "A266A680B608C32770E6C45bfE6B8412"
|
||||
options = (
|
||||
"invalid_uuid",
|
||||
invalid_uuid_capital_letters,
|
||||
f"{valid_uuid},invalid_uuid",
|
||||
["invalid_uuid"],
|
||||
[valid_uuid, "invalid_uuid"],
|
||||
[f"{valid_uuid},invalid_uuid"],
|
||||
)
|
||||
for value in options:
|
||||
with pytest.raises(vol.MultipleInvalid):
|
||||
schema(value)
|
||||
|
||||
options = ([], [valid_uuid], valid_uuid)
|
||||
for value in options:
|
||||
schema(value)
|
||||
|
||||
assert schema(f"{valid_uuid}, {valid_uuid2} ") == [valid_uuid, valid_uuid2]
|
||||
|
||||
|
||||
def test_entity_domain():
|
||||
"""Test entity domain validation."""
|
||||
schema = vol.Schema(cv.entity_domain("sensor"))
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE
|
||||
|
@ -1023,3 +1024,60 @@ async def test_entity_max_length_exceeded(hass, registry):
|
|||
assert exc_info.value.property_name == "generated_entity_id"
|
||||
assert exc_info.value.max_length == 255
|
||||
assert exc_info.value.value == f"sensor.{long_entity_id_name}_2"
|
||||
|
||||
|
||||
async def test_resolve_entity_ids(hass, registry):
|
||||
"""Test resolving entity IDs."""
|
||||
|
||||
entry1 = registry.async_get_or_create(
|
||||
"light", "hue", "1234", suggested_object_id="beer"
|
||||
)
|
||||
assert entry1.entity_id == "light.beer"
|
||||
|
||||
entry2 = registry.async_get_or_create(
|
||||
"light", "hue", "2345", suggested_object_id="milk"
|
||||
)
|
||||
assert entry2.entity_id == "light.milk"
|
||||
|
||||
expected = ["light.beer", "light.milk"]
|
||||
assert er.async_resolve_entity_ids(registry, [entry1.id, entry2.id]) == expected
|
||||
|
||||
expected = ["light.beer", "light.milk"]
|
||||
assert er.async_resolve_entity_ids(registry, ["light.beer", entry2.id]) == expected
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
er.async_resolve_entity_ids(registry, ["light.beer", "bad_uuid"])
|
||||
|
||||
expected = ["light.unknown"]
|
||||
assert er.async_resolve_entity_ids(registry, ["light.unknown"]) == expected
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
er.async_resolve_entity_ids(registry, ["unknown_uuid"])
|
||||
|
||||
|
||||
def test_entity_registry_items():
|
||||
"""Test the EntityRegistryItems container."""
|
||||
entities = er.EntityRegistryItems()
|
||||
assert entities.get_entity_id(("a", "b", "c")) is None
|
||||
assert entities.get_entry("abc") is None
|
||||
|
||||
entry1 = er.RegistryEntry("test.entity1", "1234", "hue")
|
||||
entry2 = er.RegistryEntry("test.entity2", "2345", "hue")
|
||||
entities["test.entity1"] = entry1
|
||||
entities["test.entity2"] = entry2
|
||||
|
||||
assert entities["test.entity1"] is entry1
|
||||
assert entities["test.entity2"] is entry2
|
||||
|
||||
assert entities.get_entity_id(("test", "hue", "1234")) is entry1.entity_id
|
||||
assert entities.get_entry(entry1.id) is entry1
|
||||
assert entities.get_entity_id(("test", "hue", "2345")) is entry2.entity_id
|
||||
assert entities.get_entry(entry2.id) is entry2
|
||||
|
||||
entities.pop("test.entity1")
|
||||
del entities["test.entity2"]
|
||||
|
||||
assert entities.get_entity_id(("test", "hue", "1234")) is None
|
||||
assert entities.get_entry(entry1.id) is None
|
||||
assert entities.get_entity_id(("test", "hue", "2345")) is None
|
||||
assert entities.get_entry(entry2.id) is None
|
||||
|
|
|
@ -748,6 +748,7 @@ async def test_wait_basic(hass, action_type):
|
|||
"to": "off",
|
||||
}
|
||||
sequence = cv.SCRIPT_SCHEMA(action)
|
||||
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||
wait_started_flag = async_watch_for_action(script_obj, wait_alias)
|
||||
|
||||
|
@ -848,6 +849,7 @@ async def test_wait_basic_times_out(hass, action_type):
|
|||
"to": "off",
|
||||
}
|
||||
sequence = cv.SCRIPT_SCHEMA(action)
|
||||
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||
wait_started_flag = async_watch_for_action(script_obj, wait_alias)
|
||||
timed_out = False
|
||||
|
@ -904,6 +906,7 @@ async def test_multiple_runs_wait(hass, action_type):
|
|||
{"event": event, "event_data": {"value": 2}},
|
||||
]
|
||||
)
|
||||
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||
script_obj = script.Script(
|
||||
hass, sequence, "Test Name", "test_domain", script_mode="parallel", max_runs=2
|
||||
)
|
||||
|
@ -952,6 +955,7 @@ async def test_cancel_wait(hass, action_type):
|
|||
}
|
||||
}
|
||||
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
|
||||
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||
|
||||
|
@ -1049,6 +1053,7 @@ async def test_wait_timeout(hass, caplog, timeout_param, action_type):
|
|||
action["timeout"] = timeout_param
|
||||
action["continue_on_timeout"] = True
|
||||
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
|
||||
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||
|
||||
|
@ -1116,6 +1121,7 @@ async def test_wait_continue_on_timeout(
|
|||
if continue_on_timeout is not None:
|
||||
action["continue_on_timeout"] = continue_on_timeout
|
||||
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
|
||||
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||
|
||||
|
@ -1287,6 +1293,7 @@ async def test_wait_variables_out(hass, mode, action_type):
|
|||
},
|
||||
]
|
||||
sequence = cv.SCRIPT_SCHEMA(sequence)
|
||||
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||
wait_started_flag = async_watch_for_action(script_obj, "wait")
|
||||
|
||||
|
@ -1326,11 +1333,13 @@ async def test_wait_variables_out(hass, mode, action_type):
|
|||
|
||||
async def test_wait_for_trigger_bad(hass, caplog):
|
||||
"""Test bad wait_for_trigger."""
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
|
||||
)
|
||||
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||
script_obj = script.Script(
|
||||
hass,
|
||||
cv.SCRIPT_SCHEMA(
|
||||
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
|
||||
),
|
||||
sequence,
|
||||
"Test Name",
|
||||
"test_domain",
|
||||
)
|
||||
|
@ -1356,11 +1365,13 @@ async def test_wait_for_trigger_bad(hass, caplog):
|
|||
|
||||
async def test_wait_for_trigger_generated_exception(hass, caplog):
|
||||
"""Test bad wait_for_trigger."""
|
||||
sequence = cv.SCRIPT_SCHEMA(
|
||||
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
|
||||
)
|
||||
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||
script_obj = script.Script(
|
||||
hass,
|
||||
cv.SCRIPT_SCHEMA(
|
||||
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
|
||||
),
|
||||
sequence,
|
||||
"Test Name",
|
||||
"test_domain",
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue