Teach state trigger about entity registry ids ()

* Teach state trigger about entity registry ids

* Tweak

* Add tests

* Tweak tests

* Fix tests

* Resolve entity ids during config validation

* Update device_triggers

* Fix mistake

* Tweak trigger validator to ensure we don't modify the original config

* Add index from entry id to entry

* Update scaffold

* Pre-compile UUID regex

* Address review comment

* Tweak mock_registry

* Tweak

* Apply suggestion from code review
pull/60835/head
Erik Montnemery 2021-12-02 14:26:45 +01:00 committed by GitHub
parent c0fb1bffce
commit c85bb27d0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 324 additions and 74 deletions
homeassistant
script/scaffold/templates/device_trigger/integration

View File

@ -157,7 +157,7 @@ async def async_attach_trigger(
}
if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device"
)

View File

@ -220,7 +220,7 @@ async def async_attach_trigger(hass, config, action, automation_info):
if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device"
)

View File

@ -11,8 +11,8 @@ from homeassistant.components.automation import (
)
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
from homeassistant.components.homeassistant.triggers.state import (
TRIGGER_SCHEMA as STATE_TRIGGER_SCHEMA,
async_attach_trigger as async_attach_state_trigger,
async_validate_trigger_config as async_validate_state_trigger_config,
)
from homeassistant.const import (
CONF_DEVICE_ID,
@ -67,7 +67,7 @@ async def async_attach_trigger(
CONF_ENTITY_ID: config[CONF_ENTITY_ID],
}
state_config = STATE_TRIGGER_SCHEMA(state_config)
state_config = await async_validate_state_trigger_config(hass, state_config)
return await async_attach_state_trigger(
hass, state_config, action, automation_info, platform_type="device"
)

View File

@ -131,7 +131,9 @@ async def async_attach_trigger(
}
if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
state_config = await state_trigger.async_validate_trigger_config(
hass, state_config
)
return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device"
)

View File

@ -170,7 +170,9 @@ async def async_attach_trigger(
}
if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
state_config = await state_trigger.async_validate_trigger_config(
hass, state_config
)
return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device"
)

View File

@ -164,7 +164,7 @@ async def async_attach_trigger(
if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device"
)

View File

@ -3,20 +3,24 @@ from __future__ import annotations
from datetime import timedelta
import logging
from typing import Any
import voluptuous as vol
from homeassistant import exceptions
from homeassistant.const import CONF_ATTRIBUTE, CONF_FOR, CONF_PLATFORM, MATCH_ALL
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, State, callback
from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers import (
config_validation as cv,
entity_registry as er,
template,
)
from homeassistant.helpers.event import (
Event,
async_track_same_state,
async_track_state_change_event,
process_state_match,
)
from homeassistant.helpers.typing import ConfigType
# mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs
@ -30,7 +34,7 @@ CONF_TO = "to"
BASE_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
{
vol.Required(CONF_PLATFORM): "state",
vol.Required(CONF_ENTITY_ID): cv.entity_ids,
vol.Required(CONF_ENTITY_ID): cv.entity_ids_or_uuids,
vol.Optional(CONF_FOR): cv.positive_time_period_template,
vol.Optional(CONF_ATTRIBUTE): cv.match_all,
}
@ -52,17 +56,26 @@ TRIGGER_ATTRIBUTE_SCHEMA = BASE_SCHEMA.extend(
)
def TRIGGER_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name
"""Validate trigger."""
if not isinstance(value, dict):
async def async_validate_trigger_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate trigger config."""
if not isinstance(config, dict):
raise vol.Invalid("Expected a dictionary")
# We use this approach instead of vol.Any because
# this gives better error messages.
if CONF_ATTRIBUTE in value:
return TRIGGER_ATTRIBUTE_SCHEMA(value)
if CONF_ATTRIBUTE in config:
config = TRIGGER_ATTRIBUTE_SCHEMA(config)
else:
config = TRIGGER_STATE_SCHEMA(config)
return TRIGGER_STATE_SCHEMA(value)
registry = er.async_get(hass)
config[CONF_ENTITY_ID] = er.async_resolve_entity_ids(
registry, cv.entity_ids_or_uuids(config[CONF_ENTITY_ID])
)
return config
async def async_attach_trigger(
@ -74,7 +87,7 @@ async def async_attach_trigger(
platform_type: str = "state",
) -> CALLBACK_TYPE:
"""Listen for state changes based on configuration."""
entity_id = config.get(CONF_ENTITY_ID)
entity_ids = config[CONF_ENTITY_ID]
if (from_state := config.get(CONF_FROM)) is None:
from_state = MATCH_ALL
if (to_state := config.get(CONF_TO)) is None:
@ -196,7 +209,7 @@ async def async_attach_trigger(
entity_ids=entity,
)
unsub = async_track_state_change_event(hass, entity_id, state_automation_listener)
unsub = async_track_state_change_event(hass, entity_ids, state_automation_listener)
@callback
def async_remove():

View File

@ -104,7 +104,7 @@ async def async_attach_trigger(
}
if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device"
)

View File

@ -104,7 +104,7 @@ async def async_attach_trigger(
}
if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device"
)

View File

@ -14,8 +14,8 @@ from homeassistant.components.homeassistant.triggers.state import (
CONF_FOR,
CONF_FROM,
CONF_TO,
TRIGGER_SCHEMA as STATE_TRIGGER_SCHEMA,
async_attach_trigger as async_attach_state_trigger,
async_validate_trigger_config as async_validate_state_trigger_config,
)
from homeassistant.components.select.const import ATTR_OPTIONS
from homeassistant.const import (
@ -84,7 +84,7 @@ async def async_attach_trigger(
if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR]
state_config = STATE_TRIGGER_SCHEMA(state_config)
state_config = await async_validate_state_trigger_config(hass, state_config)
return await async_attach_state_trigger(
hass, state_config, action, automation_info, platform_type="device"
)

View File

@ -92,7 +92,7 @@ async def async_attach_trigger(
}
if CONF_FOR in config:
state_config[CONF_FOR] = config[CONF_FOR]
state_config = state_trigger.TRIGGER_SCHEMA(state_config)
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device"
)

View File

@ -415,7 +415,7 @@ async def async_attach_trigger(
else:
raise HomeAssistantError(f"Unhandled trigger type {trigger_type}")
state_config = state.TRIGGER_SCHEMA(state_config)
state_config = await state.async_validate_trigger_config(hass, state_config)
return await state.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device"
)

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from collections.abc import Callable, Hashable
import contextlib
from datetime import (
date as date_sys,
datetime as datetime_sys,
@ -262,14 +263,34 @@ def entity_id(value: Any) -> str:
raise vol.Invalid(f"Entity ID {value} is an invalid entity ID")
def entity_ids(value: str | list) -> list[str]:
"""Validate Entity IDs."""
def entity_id_or_uuid(value: Any) -> str:
"""Validate Entity specified by entity_id or uuid."""
with contextlib.suppress(vol.Invalid):
return entity_id(value)
with contextlib.suppress(vol.Invalid):
return fake_uuid4_hex(value)
raise vol.Invalid(f"Entity {value} is neither a valid entity ID nor a valid UUID")
def _entity_ids(value: str | list, allow_uuid: bool) -> list[str]:
"""Help validate entity IDs or UUIDs."""
if value is None:
raise vol.Invalid("Entity IDs can not be None")
if isinstance(value, str):
value = [ent_id.strip() for ent_id in value.split(",")]
return [entity_id(ent_id) for ent_id in value]
validator = entity_id_or_uuid if allow_uuid else entity_id
return [validator(ent_id) for ent_id in value]
def entity_ids(value: str | list) -> list[str]:
"""Validate Entity IDs."""
return _entity_ids(value, False)
def entity_ids_or_uuids(value: str | list) -> list[str]:
"""Validate entities specified by entity IDs or UUIDs."""
return _entity_ids(value, True)
comp_entity_ids = vol.Any(
@ -682,6 +703,16 @@ def uuid4_hex(value: Any) -> str:
return result.hex
_FAKE_UUID_4_HEX = re.compile(r"^[0-9a-f]{32}$")
def fake_uuid4_hex(value: Any) -> str:
"""Validate a fake v4 UUID generated by random_uuid_hex."""
if not _FAKE_UUID_4_HEX.match(value):
raise vol.Invalid("Invalid UUID")
return cast(str, value) # Pattern.match throws if input is not a string
def ensure_list_csv(value: Any) -> list:
"""Ensure that input is a list or make one from comma-separated string."""
if isinstance(value, str):

View File

@ -9,12 +9,13 @@ timer.
"""
from __future__ import annotations
from collections import OrderedDict
from collections import UserDict
from collections.abc import Callable, Iterable, Mapping
import logging
from typing import TYPE_CHECKING, Any, cast
import attr
import voluptuous as vol
from homeassistant.const import (
ATTR_DEVICE_CLASS,
@ -161,14 +162,57 @@ class EntityRegistryStore(storage.Store):
return await _async_migrate(old_major_version, old_minor_version, old_data)
class EntityRegistryItems(UserDict):
"""Container for entity registry items, maps entity_id -> entry.
Maintains two additional indexes:
- id -> entry
- (domain, platform, unique_id) -> entry
"""
def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._entry_ids: dict[str, RegistryEntry] = {}
self._index: dict[tuple[str, str, str], str] = {}
def __setitem__(self, key: str, entry: RegistryEntry) -> None:
"""Add an item."""
if key in self:
old_entry = self[key]
del self._entry_ids[old_entry.id]
del self._index[(old_entry.domain, old_entry.platform, old_entry.unique_id)]
super().__setitem__(key, entry)
self._entry_ids.__setitem__(entry.id, entry)
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
self._entry_ids.__delitem__(entry.id)
self._index.__delitem__((entry.domain, entry.platform, entry.unique_id))
super().__delitem__(key)
def __getitem__(self, key: str) -> RegistryEntry:
"""Get an item."""
return cast(RegistryEntry, super().__getitem__(key))
def get_entity_id(self, key: tuple[str, str, str]) -> str | None:
"""Get entity_id from (domain, platform, unique_id)."""
return self._index.get(key)
def get_entry(self, key: str) -> RegistryEntry | None:
"""Get entry from id."""
return self._entry_ids.get(key)
class EntityRegistry:
"""Class to hold a registry of entities."""
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the registry."""
self.hass = hass
self.entities: dict[str, RegistryEntry]
self._index: dict[tuple[str, str, str], str] = {}
self.entities: EntityRegistryItems
self._store = EntityRegistryStore(
hass,
STORAGE_VERSION_MAJOR,
@ -218,7 +262,7 @@ class EntityRegistry:
self, domain: str, platform: str, unique_id: str
) -> str | None:
"""Check if an entity_id is currently registered."""
return self._index.get((domain, platform, unique_id))
return self.entities.get_entity_id((domain, platform, unique_id))
@callback
def async_generate_entity_id(
@ -320,7 +364,7 @@ class EntityRegistry:
):
disabled_by = DISABLED_INTEGRATION
entity = RegistryEntry(
entry = RegistryEntry(
area_id=area_id,
capabilities=capabilities,
config_entry_id=config_entry_id,
@ -336,7 +380,7 @@ class EntityRegistry:
unique_id=unique_id,
unit_of_measurement=unit_of_measurement,
)
self._register_entry(entity)
self.entities[entity_id] = entry
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
self.async_schedule_save()
@ -344,12 +388,12 @@ class EntityRegistry:
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id}
)
return entity
return entry
@callback
def async_remove(self, entity_id: str) -> None:
"""Remove an entity from registry."""
self._unregister_entry(self.entities[entity_id])
self.entities.pop(entity_id)
self.hass.bus.async_fire(
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": entity_id}
)
@ -513,9 +557,7 @@ class EntityRegistry:
if not new_values:
return old
self._remove_index(old)
new = attr.evolve(old, **new_values)
self._register_entry(new)
new = self.entities[entity_id] = attr.evolve(old, **new_values)
self.async_schedule_save()
@ -539,7 +581,7 @@ class EntityRegistry:
old_conf_load_func=load_yaml,
old_conf_migrate_func=_async_migrate_yaml_to_json,
)
entities: dict[str, RegistryEntry] = OrderedDict()
entities = EntityRegistryItems()
if data is not None:
for entity in data["entities"]:
@ -571,7 +613,6 @@ class EntityRegistry:
)
self.entities = entities
self._rebuild_index()
@callback
def async_schedule_save(self) -> None:
@ -626,25 +667,6 @@ class EntityRegistry:
if area_id == entry.area_id:
self._async_update_entity(entity_id, area_id=None)
def _register_entry(self, entry: RegistryEntry) -> None:
self.entities[entry.entity_id] = entry
self._add_index(entry)
def _add_index(self, entry: RegistryEntry) -> None:
self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id
def _unregister_entry(self, entry: RegistryEntry) -> None:
self._remove_index(entry)
del self.entities[entry.entity_id]
def _remove_index(self, entry: RegistryEntry) -> None:
del self._index[(entry.domain, entry.platform, entry.unique_id)]
def _rebuild_index(self) -> None:
self._index = {}
for entry in self.entities.values():
self._add_index(entry)
@callback
def async_get(hass: HomeAssistant) -> EntityRegistry:
@ -841,3 +863,25 @@ async def async_migrate_entries(
if updates is not None:
ent_reg.async_update_entity(entry.entity_id, **updates)
@callback
def async_resolve_entity_ids(
registry: EntityRegistry, entity_ids_or_uuids: list[str]
) -> list[str]:
"""Resolve a list of entity ids or UUIDs to a list of entity ids."""
def resolve_entity(entity_id_or_uuid: str) -> str | None:
"""Resolve an entity id or UUID to an entity id or None."""
if valid_entity_id(entity_id_or_uuid):
return entity_id_or_uuid
if (entry := registry.entities.get_entry(entity_id_or_uuid)) is None:
raise vol.Invalid(f"Unknown entity registry entry {entity_id_or_uuid}")
return entry.entity_id
tmp = [
resolved_item
for item in entity_ids_or_uuids
if (resolved_item := resolve_entity(item)) is not None
]
return tmp

View File

@ -10,7 +10,7 @@ from homeassistant.components.automation import (
AutomationTriggerInfo,
)
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
from homeassistant.components.homeassistant.triggers import state
from homeassistant.components.homeassistant.triggers import state as state_trigger
from homeassistant.const import (
CONF_DEVICE_ID,
CONF_DOMAIN,
@ -86,11 +86,11 @@ async def async_attach_trigger(
to_state = STATE_OFF
state_config = {
state.CONF_PLATFORM: "state",
state_trigger.CONF_PLATFORM: "state",
CONF_ENTITY_ID: config[CONF_ENTITY_ID],
state.CONF_TO: to_state,
state_trigger.CONF_TO: to_state,
}
state_config = state.TRIGGER_SCHEMA(state_config)
return await state.async_attach_trigger(
state_config = await state_trigger.async_validate_trigger_config(hass, state_config)
return await state_trigger.async_attach_trigger(
hass, state_config, action, automation_info, platform_type="device"
)

View File

@ -440,8 +440,11 @@ def mock_component(hass, component):
def mock_registry(hass, mock_entries=None):
"""Mock the Entity Registry."""
registry = entity_registry.EntityRegistry(hass)
registry.entities = mock_entries or OrderedDict()
registry._rebuild_index()
if mock_entries is None:
mock_entries = {}
registry.entities = entity_registry.EntityRegistryItems()
for key, entry in mock_entries.items():
registry.entities[key] = entry
hass.data[entity_registry.DATA_REGISTRY] = registry
return registry

View File

@ -8,6 +8,7 @@ import homeassistant.components.automation as automation
from homeassistant.components.homeassistant.triggers import state as state_trigger
from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL, SERVICE_TURN_OFF
from homeassistant.core import Context
from homeassistant.helpers import entity_registry as er
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
@ -82,6 +83,64 @@ async def test_if_fires_on_entity_change(hass, calls):
assert len(calls) == 1
async def test_if_fires_on_entity_change_uuid(hass, calls):
"""Test for firing on entity change."""
context = Context()
registry = er.async_get(hass)
entry = registry.async_get_or_create(
"test", "hue", "1234", suggested_object_id="beer"
)
assert entry.entity_id == "test.beer"
hass.states.async_set("test.beer", "hello")
await hass.async_block_till_done()
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": {"platform": "state", "entity_id": entry.id},
"action": {
"service": "test.automation",
"data_template": {
"some": "{{ trigger.%s }}"
% "}} - {{ trigger.".join(
(
"platform",
"entity_id",
"from_state.state",
"to_state.state",
"for",
"id",
)
)
},
},
}
},
)
await hass.async_block_till_done()
hass.states.async_set("test.beer", "world", context=context)
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].context.parent_id == context.id
assert calls[0].data["some"] == "state - test.beer - hello - world - None - 0"
await hass.services.async_call(
automation.DOMAIN,
SERVICE_TURN_OFF,
{ATTR_ENTITY_ID: ENTITY_MATCH_ALL},
blocking=True,
)
hass.states.async_set("test.beer", "planet")
await hass.async_block_till_done()
assert len(calls) == 1
async def test_if_fires_on_entity_change_with_from_filter(hass, calls):
"""Test for firing on entity change with filter."""
assert await async_setup_component(

View File

@ -172,9 +172,10 @@ def test_entity_id():
assert schema("sensor.LIGHT") == "sensor.light"
def test_entity_ids():
@pytest.mark.parametrize("validator", [cv.entity_ids, cv.entity_ids_or_uuids])
def test_entity_ids(validator):
"""Test entity ID validation."""
schema = vol.Schema(cv.entity_ids)
schema = vol.Schema(validator)
options = (
"invalid_entity",
@ -194,6 +195,32 @@ def test_entity_ids():
assert schema("sensor.LIGHT, light.kitchen ") == ["sensor.light", "light.kitchen"]
def test_entity_ids_or_uuids():
"""Test entity ID validation."""
schema = vol.Schema(cv.entity_ids_or_uuids)
valid_uuid = "a266a680b608c32770e6c45bfe6b8411"
valid_uuid2 = "a266a680b608c32770e6c45bfe6b8412"
invalid_uuid_capital_letters = "A266A680B608C32770E6C45bfE6B8412"
options = (
"invalid_uuid",
invalid_uuid_capital_letters,
f"{valid_uuid},invalid_uuid",
["invalid_uuid"],
[valid_uuid, "invalid_uuid"],
[f"{valid_uuid},invalid_uuid"],
)
for value in options:
with pytest.raises(vol.MultipleInvalid):
schema(value)
options = ([], [valid_uuid], valid_uuid)
for value in options:
schema(value)
assert schema(f"{valid_uuid}, {valid_uuid2} ") == [valid_uuid, valid_uuid2]
def test_entity_domain():
"""Test entity domain validation."""
schema = vol.Schema(cv.entity_domain("sensor"))

View File

@ -2,6 +2,7 @@
from unittest.mock import patch
import pytest
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE
@ -1023,3 +1024,60 @@ async def test_entity_max_length_exceeded(hass, registry):
assert exc_info.value.property_name == "generated_entity_id"
assert exc_info.value.max_length == 255
assert exc_info.value.value == f"sensor.{long_entity_id_name}_2"
async def test_resolve_entity_ids(hass, registry):
"""Test resolving entity IDs."""
entry1 = registry.async_get_or_create(
"light", "hue", "1234", suggested_object_id="beer"
)
assert entry1.entity_id == "light.beer"
entry2 = registry.async_get_or_create(
"light", "hue", "2345", suggested_object_id="milk"
)
assert entry2.entity_id == "light.milk"
expected = ["light.beer", "light.milk"]
assert er.async_resolve_entity_ids(registry, [entry1.id, entry2.id]) == expected
expected = ["light.beer", "light.milk"]
assert er.async_resolve_entity_ids(registry, ["light.beer", entry2.id]) == expected
with pytest.raises(vol.Invalid):
er.async_resolve_entity_ids(registry, ["light.beer", "bad_uuid"])
expected = ["light.unknown"]
assert er.async_resolve_entity_ids(registry, ["light.unknown"]) == expected
with pytest.raises(vol.Invalid):
er.async_resolve_entity_ids(registry, ["unknown_uuid"])
def test_entity_registry_items():
"""Test the EntityRegistryItems container."""
entities = er.EntityRegistryItems()
assert entities.get_entity_id(("a", "b", "c")) is None
assert entities.get_entry("abc") is None
entry1 = er.RegistryEntry("test.entity1", "1234", "hue")
entry2 = er.RegistryEntry("test.entity2", "2345", "hue")
entities["test.entity1"] = entry1
entities["test.entity2"] = entry2
assert entities["test.entity1"] is entry1
assert entities["test.entity2"] is entry2
assert entities.get_entity_id(("test", "hue", "1234")) is entry1.entity_id
assert entities.get_entry(entry1.id) is entry1
assert entities.get_entity_id(("test", "hue", "2345")) is entry2.entity_id
assert entities.get_entry(entry2.id) is entry2
entities.pop("test.entity1")
del entities["test.entity2"]
assert entities.get_entity_id(("test", "hue", "1234")) is None
assert entities.get_entry(entry1.id) is None
assert entities.get_entity_id(("test", "hue", "2345")) is None
assert entities.get_entry(entry2.id) is None

View File

@ -748,6 +748,7 @@ async def test_wait_basic(hass, action_type):
"to": "off",
}
sequence = cv.SCRIPT_SCHEMA(action)
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, wait_alias)
@ -848,6 +849,7 @@ async def test_wait_basic_times_out(hass, action_type):
"to": "off",
}
sequence = cv.SCRIPT_SCHEMA(action)
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, wait_alias)
timed_out = False
@ -904,6 +906,7 @@ async def test_multiple_runs_wait(hass, action_type):
{"event": event, "event_data": {"value": 2}},
]
)
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(
hass, sequence, "Test Name", "test_domain", script_mode="parallel", max_runs=2
)
@ -952,6 +955,7 @@ async def test_cancel_wait(hass, action_type):
}
}
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait")
@ -1049,6 +1053,7 @@ async def test_wait_timeout(hass, caplog, timeout_param, action_type):
action["timeout"] = timeout_param
action["continue_on_timeout"] = True
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait")
@ -1116,6 +1121,7 @@ async def test_wait_continue_on_timeout(
if continue_on_timeout is not None:
action["continue_on_timeout"] = continue_on_timeout
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait")
@ -1287,6 +1293,7 @@ async def test_wait_variables_out(hass, mode, action_type):
},
]
sequence = cv.SCRIPT_SCHEMA(sequence)
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait")
@ -1326,11 +1333,13 @@ async def test_wait_variables_out(hass, mode, action_type):
async def test_wait_for_trigger_bad(hass, caplog):
"""Test bad wait_for_trigger."""
sequence = cv.SCRIPT_SCHEMA(
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
)
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(
hass,
cv.SCRIPT_SCHEMA(
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
),
sequence,
"Test Name",
"test_domain",
)
@ -1356,11 +1365,13 @@ async def test_wait_for_trigger_bad(hass, caplog):
async def test_wait_for_trigger_generated_exception(hass, caplog):
"""Test bad wait_for_trigger."""
sequence = cv.SCRIPT_SCHEMA(
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
)
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(
hass,
cv.SCRIPT_SCHEMA(
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
),
sequence,
"Test Name",
"test_domain",
)