From 58b4a91a5b4fcf03bb9236ad85a8383c6f4cb47a Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 9 Feb 2021 00:34:18 +0100 Subject: [PATCH] Test that variables are passed to wait_for_trigger script action (#46221) --- tests/common.py | 72 +++++++++++++++++++++++++++++++++++- tests/helpers/test_script.py | 43 +++++++++++++++++++++ 2 files changed, 114 insertions(+), 1 deletion(-) diff --git a/tests/common.py b/tests/common.py index 2621f2f4b15..ab5da25e38d 100644 --- a/tests/common.py +++ b/tests/common.py @@ -12,6 +12,9 @@ import os import pathlib import threading import time +from time import monotonic +import types +from typing import Any, Awaitable, Collection, Optional from unittest.mock import AsyncMock, Mock, patch import uuid @@ -43,7 +46,7 @@ from homeassistant.const import ( STATE_OFF, STATE_ON, ) -from homeassistant.core import State +from homeassistant.core import BLOCK_LOG_TIMEOUT, State from homeassistant.helpers import ( area_registry, device_registry, @@ -190,9 +193,76 @@ async def async_test_home_assistant(loop): return orig_async_create_task(coroutine) + async def async_wait_for_task_count(self, max_remaining_tasks: int = 0) -> None: + """Block until at most max_remaining_tasks remain. + + Based on HomeAssistant.async_block_till_done + """ + # To flush out any call_soon_threadsafe + await asyncio.sleep(0) + start_time: Optional[float] = None + + while len(self._pending_tasks) > max_remaining_tasks: + pending = [ + task for task in self._pending_tasks if not task.done() + ] # type: Collection[Awaitable[Any]] + self._pending_tasks.clear() + if len(pending) > max_remaining_tasks: + remaining_pending = await self._await_count_and_log_pending( + pending, max_remaining_tasks=max_remaining_tasks + ) + self._pending_tasks.extend(remaining_pending) + + if start_time is None: + # Avoid calling monotonic() until we know + # we may need to start logging blocked tasks. + start_time = 0 + elif start_time == 0: + # If we have waited twice then we set the start + # time + start_time = monotonic() + elif monotonic() - start_time > BLOCK_LOG_TIMEOUT: + # We have waited at least three loops and new tasks + # continue to block. At this point we start + # logging all waiting tasks. + for task in pending: + _LOGGER.debug("Waiting for task: %s", task) + else: + self._pending_tasks.extend(pending) + await asyncio.sleep(0) + + async def _await_count_and_log_pending( + self, pending: Collection[Awaitable[Any]], max_remaining_tasks: int = 0 + ) -> Collection[Awaitable[Any]]: + """Block at most max_remaining_tasks remain and log tasks that take a long time. + + Based on HomeAssistant._await_and_log_pending + """ + wait_time = 0 + + return_when = asyncio.ALL_COMPLETED + if max_remaining_tasks: + return_when = asyncio.FIRST_COMPLETED + + while len(pending) > max_remaining_tasks: + _, pending = await asyncio.wait( + pending, timeout=BLOCK_LOG_TIMEOUT, return_when=return_when + ) + if not pending or max_remaining_tasks: + return pending + wait_time += BLOCK_LOG_TIMEOUT + for task in pending: + _LOGGER.debug("Waited %s seconds for task: %s", wait_time, task) + + return [] + hass.async_add_job = async_add_job hass.async_add_executor_job = async_add_executor_job hass.async_create_task = async_create_task + hass.async_wait_for_task_count = types.MethodType(async_wait_for_task_count, hass) + hass._await_count_and_log_pending = types.MethodType( + _await_count_and_log_pending, hass + ) hass.data[loader.DATA_CUSTOM_COMPONENTS] = {} diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 5cd9a9d2449..a22cf27acdc 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -545,6 +545,49 @@ async def test_wait_basic(hass, action_type): assert script_obj.last_action is None +async def test_wait_for_trigger_variables(hass): + """Test variables are passed to wait_for_trigger action.""" + context = Context() + wait_alias = "wait step" + actions = [ + { + "alias": "variables", + "variables": {"seconds": 5}, + }, + { + "alias": wait_alias, + "wait_for_trigger": { + "platform": "state", + "entity_id": "switch.test", + "to": "off", + "for": {"seconds": "{{ seconds }}"}, + }, + }, + ] + sequence = cv.SCRIPT_SCHEMA(actions) + sequence = await script.async_validate_actions_config(hass, sequence) + script_obj = script.Script(hass, sequence, "Test Name", "test_domain") + wait_started_flag = async_watch_for_action(script_obj, wait_alias) + + try: + hass.states.async_set("switch.test", "on") + hass.async_create_task(script_obj.async_run(context=context)) + await asyncio.wait_for(wait_started_flag.wait(), 1) + assert script_obj.is_running + assert script_obj.last_action == wait_alias + hass.states.async_set("switch.test", "off") + # the script task + 2 tasks created by wait_for_trigger script step + await hass.async_wait_for_task_count(3) + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=10)) + await hass.async_block_till_done() + except (AssertionError, asyncio.TimeoutError): + await script_obj.async_stop() + raise + else: + assert not script_obj.is_running + assert script_obj.last_action is None + + @pytest.mark.parametrize("action_type", ["template", "trigger"]) async def test_wait_basic_times_out(hass, action_type): """Test wait actions times out when the action does not happen."""