Test that variables are passed to wait_for_trigger script action (#46221)

pull/46262/head
Erik Montnemery 2021-02-09 00:34:18 +01:00 committed by GitHub
parent c602c619a2
commit 58b4a91a5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 114 additions and 1 deletions

View File

@ -12,6 +12,9 @@ import os
import pathlib
import threading
import time
from time import monotonic
import types
from typing import Any, Awaitable, Collection, Optional
from unittest.mock import AsyncMock, Mock, patch
import uuid
@ -43,7 +46,7 @@ from homeassistant.const import (
STATE_OFF,
STATE_ON,
)
from homeassistant.core import State
from homeassistant.core import BLOCK_LOG_TIMEOUT, State
from homeassistant.helpers import (
area_registry,
device_registry,
@ -190,9 +193,76 @@ async def async_test_home_assistant(loop):
return orig_async_create_task(coroutine)
async def async_wait_for_task_count(self, max_remaining_tasks: int = 0) -> None:
"""Block until at most max_remaining_tasks remain.
Based on HomeAssistant.async_block_till_done
"""
# To flush out any call_soon_threadsafe
await asyncio.sleep(0)
start_time: Optional[float] = None
while len(self._pending_tasks) > max_remaining_tasks:
pending = [
task for task in self._pending_tasks if not task.done()
] # type: Collection[Awaitable[Any]]
self._pending_tasks.clear()
if len(pending) > max_remaining_tasks:
remaining_pending = await self._await_count_and_log_pending(
pending, max_remaining_tasks=max_remaining_tasks
)
self._pending_tasks.extend(remaining_pending)
if start_time is None:
# Avoid calling monotonic() until we know
# we may need to start logging blocked tasks.
start_time = 0
elif start_time == 0:
# If we have waited twice then we set the start
# time
start_time = monotonic()
elif monotonic() - start_time > BLOCK_LOG_TIMEOUT:
# We have waited at least three loops and new tasks
# continue to block. At this point we start
# logging all waiting tasks.
for task in pending:
_LOGGER.debug("Waiting for task: %s", task)
else:
self._pending_tasks.extend(pending)
await asyncio.sleep(0)
async def _await_count_and_log_pending(
self, pending: Collection[Awaitable[Any]], max_remaining_tasks: int = 0
) -> Collection[Awaitable[Any]]:
"""Block at most max_remaining_tasks remain and log tasks that take a long time.
Based on HomeAssistant._await_and_log_pending
"""
wait_time = 0
return_when = asyncio.ALL_COMPLETED
if max_remaining_tasks:
return_when = asyncio.FIRST_COMPLETED
while len(pending) > max_remaining_tasks:
_, pending = await asyncio.wait(
pending, timeout=BLOCK_LOG_TIMEOUT, return_when=return_when
)
if not pending or max_remaining_tasks:
return pending
wait_time += BLOCK_LOG_TIMEOUT
for task in pending:
_LOGGER.debug("Waited %s seconds for task: %s", wait_time, task)
return []
hass.async_add_job = async_add_job
hass.async_add_executor_job = async_add_executor_job
hass.async_create_task = async_create_task
hass.async_wait_for_task_count = types.MethodType(async_wait_for_task_count, hass)
hass._await_count_and_log_pending = types.MethodType(
_await_count_and_log_pending, hass
)
hass.data[loader.DATA_CUSTOM_COMPONENTS] = {}

View File

@ -545,6 +545,49 @@ async def test_wait_basic(hass, action_type):
assert script_obj.last_action is None
async def test_wait_for_trigger_variables(hass):
"""Test variables are passed to wait_for_trigger action."""
context = Context()
wait_alias = "wait step"
actions = [
{
"alias": "variables",
"variables": {"seconds": 5},
},
{
"alias": wait_alias,
"wait_for_trigger": {
"platform": "state",
"entity_id": "switch.test",
"to": "off",
"for": {"seconds": "{{ seconds }}"},
},
},
]
sequence = cv.SCRIPT_SCHEMA(actions)
sequence = await script.async_validate_actions_config(hass, sequence)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
wait_started_flag = async_watch_for_action(script_obj, wait_alias)
try:
hass.states.async_set("switch.test", "on")
hass.async_create_task(script_obj.async_run(context=context))
await asyncio.wait_for(wait_started_flag.wait(), 1)
assert script_obj.is_running
assert script_obj.last_action == wait_alias
hass.states.async_set("switch.test", "off")
# the script task + 2 tasks created by wait_for_trigger script step
await hass.async_wait_for_task_count(3)
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=10))
await hass.async_block_till_done()
except (AssertionError, asyncio.TimeoutError):
await script_obj.async_stop()
raise
else:
assert not script_obj.is_running
assert script_obj.last_action is None
@pytest.mark.parametrize("action_type", ["template", "trigger"])
async def test_wait_basic_times_out(hass, action_type):
"""Test wait actions times out when the action does not happen."""