Add wait_for_trigger script action (#38075)

* Add wait_for_trigger script action

* Add tests

* Change script integration to use config validator
pull/39102/head
Phil Bruckner 2020-08-21 04:38:25 -05:00 committed by GitHub
parent c1ed584f2d
commit 76ead858cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 288 additions and 53 deletions

View File

@ -88,7 +88,6 @@ _LOGGER = logging.getLogger(__name__)
AutomationActionType = Callable[[HomeAssistant, TemplateVarsType], Awaitable[None]]
_CONDITION_SCHEMA = vol.All(cv.ensure_list, [cv.CONDITION_SCHEMA])
PLATFORM_SCHEMA = vol.All(

View File

@ -1,5 +1,6 @@
"""Provide configuration end points for scripts."""
from homeassistant.components.script import DOMAIN, SCRIPT_ENTRY_SCHEMA
from homeassistant.components.script.config import async_validate_config_item
from homeassistant.config import SCRIPT_CONFIG_PATH
from homeassistant.const import SERVICE_RELOAD
import homeassistant.helpers.config_validation as cv
@ -16,12 +17,13 @@ async def async_setup(hass):
hass.http.register_view(
EditKeyBasedConfigView(
"script",
DOMAIN,
"config",
SCRIPT_CONFIG_PATH,
cv.slug,
SCRIPT_ENTRY_SCHEMA,
post_write_hook=hook,
data_validator=async_validate_config_item,
)
)
return True

View File

@ -0,0 +1,50 @@
"""Config validation helper for the script integration."""
import asyncio
import voluptuous as vol
from homeassistant.config import async_log_exception
from homeassistant.const import CONF_SEQUENCE
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.script import async_validate_action_config
from . import DOMAIN, SCRIPT_ENTRY_SCHEMA
async def async_validate_config_item(hass, config, full_config=None):
"""Validate config item."""
config = SCRIPT_ENTRY_SCHEMA(config)
config[CONF_SEQUENCE] = await asyncio.gather(
*[
async_validate_action_config(hass, action)
for action in config[CONF_SEQUENCE]
]
)
return config
async def _try_async_validate_config_item(hass, object_id, config, full_config=None):
"""Validate config item."""
try:
cv.slug(object_id)
config = await async_validate_config_item(hass, config, full_config)
except (vol.Invalid, HomeAssistantError) as ex:
async_log_exception(ex, DOMAIN, full_config or config, hass)
return None
return config
async def async_validate_config(hass, config):
"""Validate config."""
if DOMAIN in config:
validated_config = {}
for object_id, cfg in config[DOMAIN].items():
cfg = await _try_async_validate_config_item(hass, object_id, cfg, config)
if cfg is not None:
validated_config[object_id] = cfg
config[DOMAIN] = validated_config
return config

View File

@ -86,7 +86,9 @@ async def async_attach_trigger(
value_template.extract_entities(),
)
unsub = async_track_template(hass, value_template, template_listener)
unsub = async_track_template(
hass, value_template, template_listener, automation_info["variables"]
)
@callback
def async_remove():

View File

@ -180,6 +180,7 @@ CONF_URL = "url"
CONF_USERNAME = "username"
CONF_VALUE_TEMPLATE = "value_template"
CONF_VERIFY_SSL = "verify_ssl"
CONF_WAIT_FOR_TRIGGER = "wait_for_trigger"
CONF_WAIT_TEMPLATE = "wait_template"
CONF_WEBHOOK_ID = "webhook_id"
CONF_WEEKDAY = "weekday"

View File

@ -67,6 +67,7 @@ from homeassistant.const import (
CONF_UNIT_SYSTEM_METRIC,
CONF_UNTIL,
CONF_VALUE_TEMPLATE,
CONF_WAIT_FOR_TRIGGER,
CONF_WAIT_TEMPLATE,
CONF_WHILE,
ENTITY_MATCH_ALL,
@ -1074,6 +1075,15 @@ _SCRIPT_CHOOSE_SCHEMA = vol.Schema(
}
)
_SCRIPT_WAIT_FOR_TRIGGER_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
vol.Required(CONF_WAIT_FOR_TRIGGER): TRIGGER_SCHEMA,
vol.Optional(CONF_TIMEOUT): positive_time_period_template,
vol.Optional(CONF_CONTINUE_ON_TIMEOUT): boolean,
}
)
SCRIPT_ACTION_DELAY = "delay"
SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template"
SCRIPT_ACTION_CHECK_CONDITION = "condition"
@ -1083,6 +1093,7 @@ SCRIPT_ACTION_DEVICE_AUTOMATION = "device"
SCRIPT_ACTION_ACTIVATE_SCENE = "scene"
SCRIPT_ACTION_REPEAT = "repeat"
SCRIPT_ACTION_CHOOSE = "choose"
SCRIPT_ACTION_WAIT_FOR_TRIGGER = "wait_for_trigger"
def determine_script_action(action: dict) -> str:
@ -1111,6 +1122,9 @@ def determine_script_action(action: dict) -> str:
if CONF_CHOOSE in action:
return SCRIPT_ACTION_CHOOSE
if CONF_WAIT_FOR_TRIGGER in action:
return SCRIPT_ACTION_WAIT_FOR_TRIGGER
return SCRIPT_ACTION_CALL_SERVICE
@ -1124,4 +1138,5 @@ ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = {
SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA,
SCRIPT_ACTION_REPEAT: _SCRIPT_REPEAT_SCHEMA,
SCRIPT_ACTION_CHOOSE: _SCRIPT_CHOOSE_SCHEMA,
SCRIPT_ACTION_WAIT_FOR_TRIGGER: _SCRIPT_WAIT_FOR_TRIGGER_SCHEMA,
}

View File

@ -1,5 +1,6 @@
"""Helpers to execute scripts."""
import asyncio
from copy import deepcopy
from datetime import datetime, timedelta
from functools import partial
import itertools
@ -45,6 +46,7 @@ from homeassistant.const import (
CONF_SEQUENCE,
CONF_TIMEOUT,
CONF_UNTIL,
CONF_WAIT_FOR_TRIGGER,
CONF_WAIT_TEMPLATE,
CONF_WHILE,
EVENT_HOMEASSISTANT_STOP,
@ -61,6 +63,10 @@ from homeassistant.helpers.service import (
CONF_SERVICE_DATA,
async_prepare_call_from_config,
)
from homeassistant.helpers.trigger import (
async_initialize_triggers,
async_validate_trigger_config,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import slugify
from homeassistant.util.dt import utcnow
@ -123,7 +129,7 @@ async def async_validate_action_config(
hass, config[CONF_DOMAIN], "action"
)
config = platform.ACTION_SCHEMA(config) # type: ignore
if (
elif (
action_type == cv.SCRIPT_ACTION_CHECK_CONDITION
and config[CONF_CONDITION] == "device"
):
@ -131,6 +137,10 @@ async def async_validate_action_config(
hass, config[CONF_DOMAIN], "condition"
)
config = platform.CONDITION_SCHEMA(config) # type: ignore
elif action_type == cv.SCRIPT_ACTION_WAIT_FOR_TRIGGER:
config[CONF_WAIT_FOR_TRIGGER] = await async_validate_trigger_config(
hass, config[CONF_WAIT_FOR_TRIGGER]
)
return config
@ -539,6 +549,64 @@ class _ScriptRun:
if choose_data["default"]:
await self._async_run_script(choose_data["default"])
async def _async_wait_for_trigger_step(self):
"""Wait for a trigger event."""
if CONF_TIMEOUT in self._action:
delay = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
else:
delay = None
self._script.last_action = self._action.get(CONF_ALIAS, "wait for trigger")
self._log(
"Executing step %s%s",
self._script.last_action,
"" if delay is None else f" (timeout: {timedelta(seconds=delay)})",
)
variables = deepcopy(self._variables)
self._variables["wait"] = {"remaining": delay, "trigger": None}
async def async_done(variables, context=None):
self._variables["wait"] = {
"remaining": to_context.remaining if to_context else delay,
"trigger": variables["trigger"],
}
done.set()
def log_cb(level, msg):
self._log(msg, level=level)
to_context = None
remove_triggers = await async_initialize_triggers(
self._hass,
self._action[CONF_WAIT_FOR_TRIGGER],
async_done,
self._script.domain,
self._script.name,
log_cb,
variables=variables,
)
if not remove_triggers:
return
self._changed()
done = asyncio.Event()
tasks = [
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
]
try:
async with timeout(delay) as to_context:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
except asyncio.TimeoutError:
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._log(_TIMEOUT_MSG)
raise _StopScript
self._variables["wait"]["remaining"] = 0.0
finally:
for task in tasks:
task.cancel()
remove_triggers()
async def _async_run_script(self, script):
"""Execute a script."""
await self._async_run_long_action(

View File

@ -1,7 +1,8 @@
"""Triggers."""
import asyncio
import logging
from typing import Any, Callable, List, Optional
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Optional, Union
import voluptuous as vol
@ -59,12 +60,14 @@ async def async_initialize_triggers(
name: str,
log_cb: Callable,
home_assistant_start: bool = False,
variables: Optional[Union[Dict[str, Any], MappingProxyType]] = None,
) -> Optional[CALLBACK_TYPE]:
"""Initialize triggers."""
info = {
"domain": domain,
"name": name,
"home_assistant_start": home_assistant_start,
"variables": variables,
}
triggers = []

View File

@ -185,7 +185,7 @@ invalid_configs = [
@pytest.mark.parametrize("value", invalid_configs)
async def test_setup_with_invalid_configs(hass, value):
"""Test setup with invalid configs."""
assert not await async_setup_component(
assert await async_setup_component(
hass, "script", {"script": value}
), f"Script loaded with wrong config {value}"
@ -418,7 +418,12 @@ async def test_extraction_functions(hass):
"service": "test.script",
"data": {"entity_id": "light.in_first"},
},
{"domain": "light", "device_id": "device-in-both"},
{
"entity_id": "light.device_in_both",
"domain": "light",
"type": "turn_on",
"device_id": "device-in-both",
},
]
},
"test2": {
@ -433,8 +438,18 @@ async def test_extraction_functions(hass):
"state": "100",
},
{"scene": "scene.hello"},
{"domain": "light", "device_id": "device-in-both"},
{"domain": "light", "device_id": "device-in-last"},
{
"entity_id": "light.device_in_both",
"domain": "light",
"type": "turn_on",
"device_id": "device-in-both",
},
{
"entity_id": "light.device_in_last",
"domain": "light",
"type": "turn_on",
"device_id": "device-in-last",
},
],
},
}

View File

@ -475,15 +475,20 @@ async def test_cancel_delay(hass):
assert len(events) == 0
async def test_wait_template_basic(hass):
"""Test the wait template."""
@pytest.mark.parametrize("action_type", ["template", "trigger"])
async def test_wait_basic(hass, action_type):
"""Test wait actions."""
wait_alias = "wait step"
sequence = cv.SCRIPT_SCHEMA(
{
"wait_template": "{{ states.switch.test.state == 'off' }}",
"alias": wait_alias,
action = {"alias": wait_alias}
if action_type == "template":
action["wait_template"] = "{{ states.switch.test.state == 'off' }}"
else:
action["wait_for_trigger"] = {
"platform": "state",
"entity_id": "switch.test",
"to": "off",
}
)
sequence = cv.SCRIPT_SCHEMA(action)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, wait_alias)
@ -505,14 +510,25 @@ async def test_wait_template_basic(hass):
assert script_obj.last_action is None
async def test_multiple_runs_wait_template(hass):
"""Test multiple runs with wait_template in script."""
@pytest.mark.parametrize("action_type", ["template", "trigger"])
async def test_multiple_runs_wait(hass, action_type):
"""Test multiple runs with wait in script."""
event = "test_event"
events = async_capture_events(hass, event)
if action_type == "template":
action = {"wait_template": "{{ states.switch.test.state == 'off' }}"}
else:
action = {
"wait_for_trigger": {
"platform": "state",
"entity_id": "switch.test",
"to": "off",
}
}
sequence = cv.SCRIPT_SCHEMA(
[
{"event": event, "event_data": {"value": 1}},
{"wait_template": "{{ states.switch.test.state == 'off' }}"},
action,
{"event": event, "event_data": {"value": 2}},
]
)
@ -529,12 +545,15 @@ async def test_multiple_runs_wait_template(hass):
assert script_obj.is_running
assert len(events) == 1
assert events[-1].data["value"] == 1
# Start second run of script while first run is in wait_template.
wait_started_flag.clear()
hass.async_create_task(script_obj.async_run())
await asyncio.wait_for(wait_started_flag.wait(), 1)
except (AssertionError, asyncio.TimeoutError):
await script_obj.async_stop()
raise
else:
# Start second run of script while first run is in wait_template.
hass.async_create_task(script_obj.async_run())
hass.states.async_set("switch.test", "off")
await hass.async_block_till_done()
@ -545,16 +564,22 @@ async def test_multiple_runs_wait_template(hass):
assert events[-1].data["value"] == 2
async def test_cancel_wait_template(hass):
"""Test the cancelling while wait_template is present."""
@pytest.mark.parametrize("action_type", ["template", "trigger"])
async def test_cancel_wait(hass, action_type):
"""Test the cancelling while wait is present."""
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA(
[
{"wait_template": "{{ states.switch.test.state == 'off' }}"},
{"event": event},
]
)
if action_type == "template":
action = {"wait_template": "{{ states.switch.test.state == 'off' }}"}
else:
action = {
"wait_for_trigger": {
"platform": "state",
"entity_id": "switch.test",
"to": "off",
}
}
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait")
@ -606,20 +631,24 @@ async def test_wait_template_not_schedule(hass):
@pytest.mark.parametrize(
"timeout_param", [5, "{{ 5 }}", {"seconds": 5}, {"seconds": "{{ 5 }}"}]
)
async def test_wait_template_timeout(hass, caplog, timeout_param):
@pytest.mark.parametrize("action_type", ["template", "trigger"])
async def test_wait_timeout(hass, caplog, timeout_param, action_type):
"""Test the wait timeout option."""
event = "test_event"
events = async_capture_events(hass, event)
sequence = cv.SCRIPT_SCHEMA(
[
{
"wait_template": "{{ states.switch.test.state == 'off' }}",
"timeout": timeout_param,
"continue_on_timeout": True,
},
{"event": event},
]
)
if action_type == "template":
action = {"wait_template": "{{ states.switch.test.state == 'off' }}"}
else:
action = {
"wait_for_trigger": {
"platform": "state",
"entity_id": "switch.test",
"to": "off",
}
}
action["timeout"] = timeout_param
action["continue_on_timeout"] = True
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait")
@ -651,17 +680,27 @@ async def test_wait_template_timeout(hass, caplog, timeout_param):
@pytest.mark.parametrize(
"continue_on_timeout,n_events", [(False, 0), (True, 1), (None, 1)]
)
async def test_wait_template_continue_on_timeout(hass, continue_on_timeout, n_events):
"""Test the wait template continue_on_timeout option."""
@pytest.mark.parametrize("action_type", ["template", "trigger"])
async def test_wait_continue_on_timeout(
hass, continue_on_timeout, n_events, action_type
):
"""Test the wait continue_on_timeout option."""
event = "test_event"
events = async_capture_events(hass, event)
sequence = [
{"wait_template": "{{ states.switch.test.state == 'off' }}", "timeout": 5},
{"event": event},
]
if action_type == "template":
action = {"wait_template": "{{ states.switch.test.state == 'off' }}"}
else:
action = {
"wait_for_trigger": {
"platform": "state",
"entity_id": "switch.test",
"to": "off",
}
}
action["timeout"] = 5
if continue_on_timeout is not None:
sequence[0]["continue_on_timeout"] = continue_on_timeout
sequence = cv.SCRIPT_SCHEMA(sequence)
action["continue_on_timeout"] = continue_on_timeout
sequence = cv.SCRIPT_SCHEMA([action, {"event": event}])
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, "wait")
@ -708,11 +747,23 @@ async def test_wait_template_variables_in(hass):
@pytest.mark.parametrize("mode", ["no_timeout", "timeout_finish", "timeout_not_finish"])
async def test_wait_template_variables_out(hass, mode):
"""Test the wait template output variable."""
@pytest.mark.parametrize("action_type", ["template", "trigger"])
async def test_wait_variables_out(hass, mode, action_type):
"""Test the wait output variable."""
event = "test_event"
events = async_capture_events(hass, event)
action = {"wait_template": "{{ states.switch.test.state == 'off' }}"}
if action_type == "template":
action = {"wait_template": "{{ states.switch.test.state == 'off' }}"}
event_key = "completed"
else:
action = {
"wait_for_trigger": {
"platform": "state",
"entity_id": "switch.test",
"to": "off",
}
}
event_key = "trigger"
if mode != "no_timeout":
action["timeout"] = 5
action["continue_on_timeout"] = True
@ -721,7 +772,7 @@ async def test_wait_template_variables_out(hass, mode):
{
"event": event,
"event_data_template": {
"completed": "{{ wait.completed }}",
event_key: f"{{{{ wait.{event_key} }}}}",
"remaining": "{{ wait.remaining }}",
},
},
@ -749,7 +800,12 @@ async def test_wait_template_variables_out(hass, mode):
assert not script_obj.is_running
assert len(events) == 1
assert events[0].data["completed"] == str(mode != "timeout_not_finish")
if action_type == "template":
assert events[0].data["completed"] == str(mode != "timeout_not_finish")
elif mode != "timeout_not_finish":
assert "'to_state': <state switch.test=off" in events[0].data["trigger"]
else:
assert events[0].data["trigger"] == "None"
remaining = events[0].data["remaining"]
if mode == "no_timeout":
assert remaining == "None"
@ -759,6 +815,30 @@ async def test_wait_template_variables_out(hass, mode):
assert float(remaining) == 0.0
async def test_wait_for_trigger_bad(hass, caplog):
"""Test bad wait_for_trigger."""
script_obj = script.Script(
hass,
cv.SCRIPT_SCHEMA(
{"wait_for_trigger": {"platform": "state", "entity_id": "sensor.abc"}}
),
"Test Name",
"test_domain",
)
async def async_attach_trigger_mock(*args, **kwargs):
return None
with mock.patch(
"homeassistant.components.homeassistant.triggers.state.async_attach_trigger",
wraps=async_attach_trigger_mock,
):
hass.async_create_task(script_obj.async_run())
await hass.async_block_till_done()
assert "Error setting up trigger" in caplog.text
async def test_condition_basic(hass):
"""Test if we can use conditions in a script."""
event = "test_event"