diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index b4caf9bcd49..c550aa4f2c7 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -88,7 +88,6 @@ _LOGGER = logging.getLogger(__name__) AutomationActionType = Callable[[HomeAssistant, TemplateVarsType], Awaitable[None]] - _CONDITION_SCHEMA = vol.All(cv.ensure_list, [cv.CONDITION_SCHEMA]) PLATFORM_SCHEMA = vol.All( diff --git a/homeassistant/components/config/script.py b/homeassistant/components/config/script.py index de9c25b223f..a5d1bb2037b 100644 --- a/homeassistant/components/config/script.py +++ b/homeassistant/components/config/script.py @@ -1,5 +1,6 @@ """Provide configuration end points for scripts.""" from homeassistant.components.script import DOMAIN, SCRIPT_ENTRY_SCHEMA +from homeassistant.components.script.config import async_validate_config_item from homeassistant.config import SCRIPT_CONFIG_PATH from homeassistant.const import SERVICE_RELOAD import homeassistant.helpers.config_validation as cv @@ -16,12 +17,13 @@ async def async_setup(hass): hass.http.register_view( EditKeyBasedConfigView( - "script", + DOMAIN, "config", SCRIPT_CONFIG_PATH, cv.slug, SCRIPT_ENTRY_SCHEMA, post_write_hook=hook, + data_validator=async_validate_config_item, ) ) return True diff --git a/homeassistant/components/script/config.py b/homeassistant/components/script/config.py new file mode 100644 index 00000000000..3860a4d0119 --- /dev/null +++ b/homeassistant/components/script/config.py @@ -0,0 +1,50 @@ +"""Config validation helper for the script integration.""" +import asyncio + +import voluptuous as vol + +from homeassistant.config import async_log_exception +from homeassistant.const import CONF_SEQUENCE +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.script import async_validate_action_config + +from . import DOMAIN, SCRIPT_ENTRY_SCHEMA + + +async def async_validate_config_item(hass, config, full_config=None): + """Validate config item.""" + config = SCRIPT_ENTRY_SCHEMA(config) + config[CONF_SEQUENCE] = await asyncio.gather( + *[ + async_validate_action_config(hass, action) + for action in config[CONF_SEQUENCE] + ] + ) + + return config + + +async def _try_async_validate_config_item(hass, object_id, config, full_config=None): + """Validate config item.""" + try: + cv.slug(object_id) + config = await async_validate_config_item(hass, config, full_config) + except (vol.Invalid, HomeAssistantError) as ex: + async_log_exception(ex, DOMAIN, full_config or config, hass) + return None + + return config + + +async def async_validate_config(hass, config): + """Validate config.""" + if DOMAIN in config: + validated_config = {} + for object_id, cfg in config[DOMAIN].items(): + cfg = await _try_async_validate_config_item(hass, object_id, cfg, config) + if cfg is not None: + validated_config[object_id] = cfg + config[DOMAIN] = validated_config + + return config diff --git a/homeassistant/components/template/trigger.py b/homeassistant/components/template/trigger.py index 7cbc1a8ffd4..6a21206c80b 100644 --- a/homeassistant/components/template/trigger.py +++ b/homeassistant/components/template/trigger.py @@ -86,7 +86,9 @@ async def async_attach_trigger( value_template.extract_entities(), ) - unsub = async_track_template(hass, value_template, template_listener) + unsub = async_track_template( + hass, value_template, template_listener, automation_info["variables"] + ) @callback def async_remove(): diff --git a/homeassistant/const.py b/homeassistant/const.py index c72974cf478..09d749c2113 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -180,6 +180,7 @@ CONF_URL = "url" CONF_USERNAME = "username" CONF_VALUE_TEMPLATE = "value_template" CONF_VERIFY_SSL = "verify_ssl" +CONF_WAIT_FOR_TRIGGER = "wait_for_trigger" CONF_WAIT_TEMPLATE = "wait_template" CONF_WEBHOOK_ID = "webhook_id" CONF_WEEKDAY = "weekday" diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index f3327e23222..f70812d5a4f 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -67,6 +67,7 @@ from homeassistant.const import ( CONF_UNIT_SYSTEM_METRIC, CONF_UNTIL, CONF_VALUE_TEMPLATE, + CONF_WAIT_FOR_TRIGGER, CONF_WAIT_TEMPLATE, CONF_WHILE, ENTITY_MATCH_ALL, @@ -1074,6 +1075,15 @@ _SCRIPT_CHOOSE_SCHEMA = vol.Schema( } ) +_SCRIPT_WAIT_FOR_TRIGGER_SCHEMA = vol.Schema( + { + vol.Optional(CONF_ALIAS): string, + vol.Required(CONF_WAIT_FOR_TRIGGER): TRIGGER_SCHEMA, + vol.Optional(CONF_TIMEOUT): positive_time_period_template, + vol.Optional(CONF_CONTINUE_ON_TIMEOUT): boolean, + } +) + SCRIPT_ACTION_DELAY = "delay" SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template" SCRIPT_ACTION_CHECK_CONDITION = "condition" @@ -1083,6 +1093,7 @@ SCRIPT_ACTION_DEVICE_AUTOMATION = "device" SCRIPT_ACTION_ACTIVATE_SCENE = "scene" SCRIPT_ACTION_REPEAT = "repeat" SCRIPT_ACTION_CHOOSE = "choose" +SCRIPT_ACTION_WAIT_FOR_TRIGGER = "wait_for_trigger" def determine_script_action(action: dict) -> str: @@ -1111,6 +1122,9 @@ def determine_script_action(action: dict) -> str: if CONF_CHOOSE in action: return SCRIPT_ACTION_CHOOSE + if CONF_WAIT_FOR_TRIGGER in action: + return SCRIPT_ACTION_WAIT_FOR_TRIGGER + return SCRIPT_ACTION_CALL_SERVICE @@ -1124,4 +1138,5 @@ ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = { SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA, SCRIPT_ACTION_REPEAT: _SCRIPT_REPEAT_SCHEMA, SCRIPT_ACTION_CHOOSE: _SCRIPT_CHOOSE_SCHEMA, + SCRIPT_ACTION_WAIT_FOR_TRIGGER: _SCRIPT_WAIT_FOR_TRIGGER_SCHEMA, } diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 9f415b10300..e45d00c91d4 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -1,5 +1,6 @@ """Helpers to execute scripts.""" import asyncio +from copy import deepcopy from datetime import datetime, timedelta from functools import partial import itertools @@ -45,6 +46,7 @@ from homeassistant.const import ( CONF_SEQUENCE, CONF_TIMEOUT, CONF_UNTIL, + CONF_WAIT_FOR_TRIGGER, CONF_WAIT_TEMPLATE, CONF_WHILE, EVENT_HOMEASSISTANT_STOP, @@ -61,6 +63,10 @@ from homeassistant.helpers.service import ( CONF_SERVICE_DATA, async_prepare_call_from_config, ) +from homeassistant.helpers.trigger import ( + async_initialize_triggers, + async_validate_trigger_config, +) from homeassistant.helpers.typing import ConfigType from homeassistant.util import slugify from homeassistant.util.dt import utcnow @@ -123,7 +129,7 @@ async def async_validate_action_config( hass, config[CONF_DOMAIN], "action" ) config = platform.ACTION_SCHEMA(config) # type: ignore - if ( + elif ( action_type == cv.SCRIPT_ACTION_CHECK_CONDITION and config[CONF_CONDITION] == "device" ): @@ -131,6 +137,10 @@ async def async_validate_action_config( hass, config[CONF_DOMAIN], "condition" ) config = platform.CONDITION_SCHEMA(config) # type: ignore + elif action_type == cv.SCRIPT_ACTION_WAIT_FOR_TRIGGER: + config[CONF_WAIT_FOR_TRIGGER] = await async_validate_trigger_config( + hass, config[CONF_WAIT_FOR_TRIGGER] + ) return config @@ -539,6 +549,64 @@ class _ScriptRun: if choose_data["default"]: await self._async_run_script(choose_data["default"]) + async def _async_wait_for_trigger_step(self): + """Wait for a trigger event.""" + if CONF_TIMEOUT in self._action: + delay = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds() + else: + delay = None + + self._script.last_action = self._action.get(CONF_ALIAS, "wait for trigger") + self._log( + "Executing step %s%s", + self._script.last_action, + "" if delay is None else f" (timeout: {timedelta(seconds=delay)})", + ) + + variables = deepcopy(self._variables) + self._variables["wait"] = {"remaining": delay, "trigger": None} + + async def async_done(variables, context=None): + self._variables["wait"] = { + "remaining": to_context.remaining if to_context else delay, + "trigger": variables["trigger"], + } + done.set() + + def log_cb(level, msg): + self._log(msg, level=level) + + to_context = None + remove_triggers = await async_initialize_triggers( + self._hass, + self._action[CONF_WAIT_FOR_TRIGGER], + async_done, + self._script.domain, + self._script.name, + log_cb, + variables=variables, + ) + if not remove_triggers: + return + + self._changed() + done = asyncio.Event() + tasks = [ + self._hass.async_create_task(flag.wait()) for flag in (self._stop, done) + ] + try: + async with timeout(delay) as to_context: + await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + except asyncio.TimeoutError: + if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): + self._log(_TIMEOUT_MSG) + raise _StopScript + self._variables["wait"]["remaining"] = 0.0 + finally: + for task in tasks: + task.cancel() + remove_triggers() + async def _async_run_script(self, script): """Execute a script.""" await self._async_run_long_action( diff --git a/homeassistant/helpers/trigger.py b/homeassistant/helpers/trigger.py index 9918ad37fe9..f9dd91dc2f1 100644 --- a/homeassistant/helpers/trigger.py +++ b/homeassistant/helpers/trigger.py @@ -1,7 +1,8 @@ """Triggers.""" import asyncio import logging -from typing import Any, Callable, List, Optional +from types import MappingProxyType +from typing import Any, Callable, Dict, List, Optional, Union import voluptuous as vol @@ -59,12 +60,14 @@ async def async_initialize_triggers( name: str, log_cb: Callable, home_assistant_start: bool = False, + variables: Optional[Union[Dict[str, Any], MappingProxyType]] = None, ) -> Optional[CALLBACK_TYPE]: """Initialize triggers.""" info = { "domain": domain, "name": name, "home_assistant_start": home_assistant_start, + "variables": variables, } triggers = [] diff --git a/tests/components/script/test_init.py b/tests/components/script/test_init.py index ec1b7ecb6e0..584c44916c7 100644 --- a/tests/components/script/test_init.py +++ b/tests/components/script/test_init.py @@ -185,7 +185,7 @@ invalid_configs = [ @pytest.mark.parametrize("value", invalid_configs) async def test_setup_with_invalid_configs(hass, value): """Test setup with invalid configs.""" - assert not await async_setup_component( + assert await async_setup_component( hass, "script", {"script": value} ), f"Script loaded with wrong config {value}" @@ -418,7 +418,12 @@ async def test_extraction_functions(hass): "service": "test.script", "data": {"entity_id": "light.in_first"}, }, - {"domain": "light", "device_id": "device-in-both"}, + { + "entity_id": "light.device_in_both", + "domain": "light", + "type": "turn_on", + "device_id": "device-in-both", + }, ] }, "test2": { @@ -433,8 +438,18 @@ async def test_extraction_functions(hass): "state": "100", }, {"scene": "scene.hello"}, - {"domain": "light", "device_id": "device-in-both"}, - {"domain": "light", "device_id": "device-in-last"}, + { + "entity_id": "light.device_in_both", + "domain": "light", + "type": "turn_on", + "device_id": "device-in-both", + }, + { + "entity_id": "light.device_in_last", + "domain": "light", + "type": "turn_on", + "device_id": "device-in-last", + }, ], }, } diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 305f11b0258..6223954ab4b 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -475,15 +475,20 @@ async def test_cancel_delay(hass): assert len(events) == 0 -async def test_wait_template_basic(hass): - """Test the wait template.""" +@pytest.mark.parametrize("action_type", ["template", "trigger"]) +async def test_wait_basic(hass, action_type): + """Test wait actions.""" wait_alias = "wait step" - sequence = cv.SCRIPT_SCHEMA( - { - "wait_template": "{{ states.switch.test.state == 'off' }}", - "alias": wait_alias, + action = {"alias": wait_alias} + if action_type == "template": + action["wait_template"] = "{{ states.switch.test.state == 'off' }}" + else: + action["wait_for_trigger"] = { + "platform": "state", + "entity_id": "switch.test", + "to": "off", } - ) + sequence = cv.SCRIPT_SCHEMA(action) script_obj = script.Script(hass, sequence, "Test Name", "test_domain") wait_started_flag = async_watch_for_action(script_obj, wait_alias) @@ -505,14 +510,25 @@ async def test_wait_template_basic(hass): assert script_obj.last_action is None -async def test_multiple_runs_wait_template(hass): - """Test multiple runs with wait_template in script.""" +@pytest.mark.parametrize("action_type", ["template", "trigger"]) +async def test_multiple_runs_wait(hass, action_type): + """Test multiple runs with wait in script.""" event = "test_event" events = async_capture_events(hass, event) + if action_type == "template": + action = {"wait_template": "{{ states.switch.test.state == 'off' }}"} + else: + action = { + "wait_for_trigger": { + "platform": "state", + "entity_id": "switch.test", + "to": "off", + } + } sequence = cv.SCRIPT_SCHEMA( [ {"event": event, "event_data": {"value": 1}}, - {"wait_template": "{{ states.switch.test.state == 'off' }}"}, + action, {"event": event, "event_data": {"value": 2}}, ] ) @@ -529,12 +545,15 @@ async def test_multiple_runs_wait_template(hass): assert script_obj.is_running assert len(events) == 1 assert events[-1].data["value"] == 1 + + # Start second run of script while first run is in wait_template. + wait_started_flag.clear() + hass.async_create_task(script_obj.async_run()) + await asyncio.wait_for(wait_started_flag.wait(), 1) except (AssertionError, asyncio.TimeoutError): await script_obj.async_stop() raise else: - # Start second run of script while first run is in wait_template. - hass.async_create_task(script_obj.async_run()) hass.states.async_set("switch.test", "off") await hass.async_block_till_done() @@ -545,16 +564,22 @@ async def test_multiple_runs_wait_template(hass): assert events[-1].data["value"] == 2 -async def test_cancel_wait_template(hass): - """Test the cancelling while wait_template is present.""" +@pytest.mark.parametrize("action_type", ["template", "trigger"]) +async def test_cancel_wait(hass, action_type): + """Test the cancelling while wait is present.""" event = "test_event" events = async_capture_events(hass, event) - sequence = cv.SCRIPT_SCHEMA( - [ - {"wait_template": "{{ states.switch.test.state == 'off' }}"}, - {"event": event}, - ] - ) + if action_type == "template": + action = {"wait_template": "{{ states.switch.test.state == 'off' }}"} + else: + action = { + "wait_for_trigger": { + "platform": "state", + "entity_id": "switch.test", + "to": "off", + } + } + sequence = cv.SCRIPT_SCHEMA([action, {"event": event}]) script_obj = script.Script(hass, sequence, "Test Name", "test_domain") wait_started_flag = async_watch_for_action(script_obj, "wait") @@ -606,20 +631,24 @@ async def test_wait_template_not_schedule(hass): @pytest.mark.parametrize( "timeout_param", [5, "{{ 5 }}", {"seconds": 5}, {"seconds": "{{ 5 }}"}] ) -async def test_wait_template_timeout(hass, caplog, timeout_param): +@pytest.mark.parametrize("action_type", ["template", "trigger"]) +async def test_wait_timeout(hass, caplog, timeout_param, action_type): """Test the wait timeout option.""" event = "test_event" events = async_capture_events(hass, event) - sequence = cv.SCRIPT_SCHEMA( - [ - { - "wait_template": "{{ states.switch.test.state == 'off' }}", - "timeout": timeout_param, - "continue_on_timeout": True, - }, - {"event": event}, - ] - ) + if action_type == "template": + action = {"wait_template": "{{ states.switch.test.state == 'off' }}"} + else: + action = { + "wait_for_trigger": { + "platform": "state", + "entity_id": "switch.test", + "to": "off", + } + } + action["timeout"] = timeout_param + action["continue_on_timeout"] = True + sequence = cv.SCRIPT_SCHEMA([action, {"event": event}]) script_obj = script.Script(hass, sequence, "Test Name", "test_domain") wait_started_flag = async_watch_for_action(script_obj, "wait") @@ -651,17 +680,27 @@ async def test_wait_template_timeout(hass, caplog, timeout_param): @pytest.mark.parametrize( "continue_on_timeout,n_events", [(False, 0), (True, 1), (None, 1)] ) -async def test_wait_template_continue_on_timeout(hass, continue_on_timeout, n_events): - """Test the wait template continue_on_timeout option.""" +@pytest.mark.parametrize("action_type", ["template", "trigger"]) +async def test_wait_continue_on_timeout( + hass, continue_on_timeout, n_events, action_type +): + """Test the wait continue_on_timeout option.""" event = "test_event" events = async_capture_events(hass, event) - sequence = [ - {"wait_template": "{{ states.switch.test.state == 'off' }}", "timeout": 5}, - {"event": event}, - ] + if action_type == "template": + action = {"wait_template": "{{ states.switch.test.state == 'off' }}"} + else: + action = { + "wait_for_trigger": { + "platform": "state", + "entity_id": "switch.test", + "to": "off", + } + } + action["timeout"] = 5 if continue_on_timeout is not None: - sequence[0]["continue_on_timeout"] = continue_on_timeout - sequence = cv.SCRIPT_SCHEMA(sequence) + action["continue_on_timeout"] = continue_on_timeout + sequence = cv.SCRIPT_SCHEMA([action, {"event": event}]) script_obj = script.Script(hass, sequence, "Test Name", "test_domain") wait_started_flag = async_watch_for_action(script_obj, "wait") @@ -708,11 +747,23 @@ async def test_wait_template_variables_in(hass): @pytest.mark.parametrize("mode", ["no_timeout", "timeout_finish", "timeout_not_finish"]) -async def test_wait_template_variables_out(hass, mode): - """Test the wait template output variable.""" +@pytest.mark.parametrize("action_type", ["template", "trigger"]) +async def test_wait_variables_out(hass, mode, action_type): + """Test the wait output variable.""" event = "test_event" events = async_capture_events(hass, event) - action = {"wait_template": "{{ states.switch.test.state == 'off' }}"} + if action_type == "template": + action = {"wait_template": "{{ states.switch.test.state == 'off' }}"} + event_key = "completed" + else: + action = { + "wait_for_trigger": { + "platform": "state", + "entity_id": "switch.test", + "to": "off", + } + } + event_key = "trigger" if mode != "no_timeout": action["timeout"] = 5 action["continue_on_timeout"] = True @@ -721,7 +772,7 @@ async def test_wait_template_variables_out(hass, mode): { "event": event, "event_data_template": { - "completed": "{{ wait.completed }}", + event_key: f"{{{{ wait.{event_key} }}}}", "remaining": "{{ wait.remaining }}", }, }, @@ -749,7 +800,12 @@ async def test_wait_template_variables_out(hass, mode): assert not script_obj.is_running assert len(events) == 1 - assert events[0].data["completed"] == str(mode != "timeout_not_finish") + if action_type == "template": + assert events[0].data["completed"] == str(mode != "timeout_not_finish") + elif mode != "timeout_not_finish": + assert "'to_state':