Test that variables are passed to wait_for_trigger script action (#46221)

pull/46262/head
Erik Montnemery 2021-02-09 00:34:18 +01:00 committed by GitHub
parent c602c619a2
commit 58b4a91a5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 114 additions and 1 deletions

View File

@ -12,6 +12,9 @@ import os
import pathlib import pathlib
import threading import threading
import time import time
from time import monotonic
import types
from typing import Any, Awaitable, Collection, Optional
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import uuid import uuid
@ -43,7 +46,7 @@ from homeassistant.const import (
STATE_OFF, STATE_OFF,
STATE_ON, STATE_ON,
) )
from homeassistant.core import State from homeassistant.core import BLOCK_LOG_TIMEOUT, State
from homeassistant.helpers import ( from homeassistant.helpers import (
area_registry, area_registry,
device_registry, device_registry,
@ -190,9 +193,76 @@ async def async_test_home_assistant(loop):
return orig_async_create_task(coroutine) return orig_async_create_task(coroutine)
async def async_wait_for_task_count(self, max_remaining_tasks: int = 0) -> None:
"""Block until at most max_remaining_tasks remain.
Based on HomeAssistant.async_block_till_done
"""
# To flush out any call_soon_threadsafe
await asyncio.sleep(0)
start_time: Optional[float] = None
while len(self._pending_tasks) > max_remaining_tasks:
pending = [
task for task in self._pending_tasks if not task.done()
] # type: Collection[Awaitable[Any]]
self._pending_tasks.clear()
if len(pending) > max_remaining_tasks:
remaining_pending = await self._await_count_and_log_pending(
pending, max_remaining_tasks=max_remaining_tasks
)
self._pending_tasks.extend(remaining_pending)
if start_time is None:
# Avoid calling monotonic() until we know
# we may need to start logging blocked tasks.
start_time = 0
elif start_time == 0:
# If we have waited twice then we set the start
# time
start_time = monotonic()
elif monotonic() - start_time > BLOCK_LOG_TIMEOUT:
# We have waited at least three loops and new tasks
# continue to block. At this point we start
# logging all waiting tasks.
for task in pending:
_LOGGER.debug("Waiting for task: %s", task)
else:
self._pending_tasks.extend(pending)
await asyncio.sleep(0)
async def _await_count_and_log_pending(
self, pending: Collection[Awaitable[Any]], max_remaining_tasks: int = 0
) -> Collection[Awaitable[Any]]:
"""Block at most max_remaining_tasks remain and log tasks that take a long time.
Based on HomeAssistant._await_and_log_pending
"""
wait_time = 0
return_when = asyncio.ALL_COMPLETED
if max_remaining_tasks:
return_when = asyncio.FIRST_COMPLETED
while len(pending) > max_remaining_tasks:
_, pending = await asyncio.wait(
pending, timeout=BLOCK_LOG_TIMEOUT, return_when=return_when
)
if not pending or max_remaining_tasks:
return pending
wait_time += BLOCK_LOG_TIMEOUT
for task in pending:
_LOGGER.debug("Waited %s seconds for task: %s", wait_time, task)
return []
hass.async_add_job = async_add_job hass.async_add_job = async_add_job
hass.async_add_executor_job = async_add_executor_job hass.async_add_executor_job = async_add_executor_job
hass.async_create_task = async_create_task hass.async_create_task = async_create_task
hass.async_wait_for_task_count = types.MethodType(async_wait_for_task_count, hass)
hass._await_count_and_log_pending = types.MethodType(
_await_count_and_log_pending, hass
)
hass.data[loader.DATA_CUSTOM_COMPONENTS] = {} hass.data[loader.DATA_CUSTOM_COMPONENTS] = {}

View File

@ -545,6 +545,49 @@ async def test_wait_basic(hass, action_type):
assert script_obj.last_action is None assert script_obj.last_action is None
async def test_wait_for_trigger_variables(hass):
"""Test variables are passed to wait_for_trigger action."""
context = Context()
wait_alias = "wait step"
actions = [
{
"alias": "variables",
"variables": {"seconds": 5},
},
{
"alias": wait_alias,
"wait_for_trigger": {
"platform": "state",
"entity_id": "switch.test",
"to": "off",
"for": {"seconds": "{{ seconds }}"},
},
},
]
sequence = cv.SCRIPT_SCHEMA(actions)
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, wait_alias)
try:
hass.states.async_set("switch.test", "on")
hass.async_create_task(script_obj.async_run(context=context))
await asyncio.wait_for(wait_started_flag.wait(), 1)
assert script_obj.is_running
assert script_obj.last_action == wait_alias
hass.states.async_set("switch.test", "off")
# the script task + 2 tasks created by wait_for_trigger script step
await hass.async_wait_for_task_count(3)
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=10))
await hass.async_block_till_done()
except (AssertionError, asyncio.TimeoutError):
await script_obj.async_stop()
raise
else:
assert not script_obj.is_running
assert script_obj.last_action is None
@pytest.mark.parametrize("action_type", ["template", "trigger"]) @pytest.mark.parametrize("action_type", ["template", "trigger"])
async def test_wait_basic_times_out(hass, action_type): async def test_wait_basic_times_out(hass, action_type):
"""Test wait actions times out when the action does not happen.""" """Test wait actions times out when the action does not happen."""