Test that variables are passed to wait_for_trigger script action (#46221)
parent
c602c619a2
commit
58b4a91a5b
|
@ -12,6 +12,9 @@ import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from time import monotonic
|
||||||
|
import types
|
||||||
|
from typing import Any, Awaitable, Collection, Optional
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
@ -43,7 +46,7 @@ from homeassistant.const import (
|
||||||
STATE_OFF,
|
STATE_OFF,
|
||||||
STATE_ON,
|
STATE_ON,
|
||||||
)
|
)
|
||||||
from homeassistant.core import State
|
from homeassistant.core import BLOCK_LOG_TIMEOUT, State
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
area_registry,
|
area_registry,
|
||||||
device_registry,
|
device_registry,
|
||||||
|
@ -190,9 +193,76 @@ async def async_test_home_assistant(loop):
|
||||||
|
|
||||||
return orig_async_create_task(coroutine)
|
return orig_async_create_task(coroutine)
|
||||||
|
|
||||||
|
async def async_wait_for_task_count(self, max_remaining_tasks: int = 0) -> None:
|
||||||
|
"""Block until at most max_remaining_tasks remain.
|
||||||
|
|
||||||
|
Based on HomeAssistant.async_block_till_done
|
||||||
|
"""
|
||||||
|
# To flush out any call_soon_threadsafe
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
start_time: Optional[float] = None
|
||||||
|
|
||||||
|
while len(self._pending_tasks) > max_remaining_tasks:
|
||||||
|
pending = [
|
||||||
|
task for task in self._pending_tasks if not task.done()
|
||||||
|
] # type: Collection[Awaitable[Any]]
|
||||||
|
self._pending_tasks.clear()
|
||||||
|
if len(pending) > max_remaining_tasks:
|
||||||
|
remaining_pending = await self._await_count_and_log_pending(
|
||||||
|
pending, max_remaining_tasks=max_remaining_tasks
|
||||||
|
)
|
||||||
|
self._pending_tasks.extend(remaining_pending)
|
||||||
|
|
||||||
|
if start_time is None:
|
||||||
|
# Avoid calling monotonic() until we know
|
||||||
|
# we may need to start logging blocked tasks.
|
||||||
|
start_time = 0
|
||||||
|
elif start_time == 0:
|
||||||
|
# If we have waited twice then we set the start
|
||||||
|
# time
|
||||||
|
start_time = monotonic()
|
||||||
|
elif monotonic() - start_time > BLOCK_LOG_TIMEOUT:
|
||||||
|
# We have waited at least three loops and new tasks
|
||||||
|
# continue to block. At this point we start
|
||||||
|
# logging all waiting tasks.
|
||||||
|
for task in pending:
|
||||||
|
_LOGGER.debug("Waiting for task: %s", task)
|
||||||
|
else:
|
||||||
|
self._pending_tasks.extend(pending)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
async def _await_count_and_log_pending(
|
||||||
|
self, pending: Collection[Awaitable[Any]], max_remaining_tasks: int = 0
|
||||||
|
) -> Collection[Awaitable[Any]]:
|
||||||
|
"""Block at most max_remaining_tasks remain and log tasks that take a long time.
|
||||||
|
|
||||||
|
Based on HomeAssistant._await_and_log_pending
|
||||||
|
"""
|
||||||
|
wait_time = 0
|
||||||
|
|
||||||
|
return_when = asyncio.ALL_COMPLETED
|
||||||
|
if max_remaining_tasks:
|
||||||
|
return_when = asyncio.FIRST_COMPLETED
|
||||||
|
|
||||||
|
while len(pending) > max_remaining_tasks:
|
||||||
|
_, pending = await asyncio.wait(
|
||||||
|
pending, timeout=BLOCK_LOG_TIMEOUT, return_when=return_when
|
||||||
|
)
|
||||||
|
if not pending or max_remaining_tasks:
|
||||||
|
return pending
|
||||||
|
wait_time += BLOCK_LOG_TIMEOUT
|
||||||
|
for task in pending:
|
||||||
|
_LOGGER.debug("Waited %s seconds for task: %s", wait_time, task)
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
hass.async_add_job = async_add_job
|
hass.async_add_job = async_add_job
|
||||||
hass.async_add_executor_job = async_add_executor_job
|
hass.async_add_executor_job = async_add_executor_job
|
||||||
hass.async_create_task = async_create_task
|
hass.async_create_task = async_create_task
|
||||||
|
hass.async_wait_for_task_count = types.MethodType(async_wait_for_task_count, hass)
|
||||||
|
hass._await_count_and_log_pending = types.MethodType(
|
||||||
|
_await_count_and_log_pending, hass
|
||||||
|
)
|
||||||
|
|
||||||
hass.data[loader.DATA_CUSTOM_COMPONENTS] = {}
|
hass.data[loader.DATA_CUSTOM_COMPONENTS] = {}
|
||||||
|
|
||||||
|
|
|
@ -545,6 +545,49 @@ async def test_wait_basic(hass, action_type):
|
||||||
assert script_obj.last_action is None
|
assert script_obj.last_action is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_wait_for_trigger_variables(hass):
|
||||||
|
"""Test variables are passed to wait_for_trigger action."""
|
||||||
|
context = Context()
|
||||||
|
wait_alias = "wait step"
|
||||||
|
actions = [
|
||||||
|
{
|
||||||
|
"alias": "variables",
|
||||||
|
"variables": {"seconds": 5},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"alias": wait_alias,
|
||||||
|
"wait_for_trigger": {
|
||||||
|
"platform": "state",
|
||||||
|
"entity_id": "switch.test",
|
||||||
|
"to": "off",
|
||||||
|
"for": {"seconds": "{{ seconds }}"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
sequence = cv.SCRIPT_SCHEMA(actions)
|
||||||
|
sequence = await script.async_validate_actions_config(hass, sequence)
|
||||||
|
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
|
||||||
|
wait_started_flag = async_watch_for_action(script_obj, wait_alias)
|
||||||
|
|
||||||
|
try:
|
||||||
|
hass.states.async_set("switch.test", "on")
|
||||||
|
hass.async_create_task(script_obj.async_run(context=context))
|
||||||
|
await asyncio.wait_for(wait_started_flag.wait(), 1)
|
||||||
|
assert script_obj.is_running
|
||||||
|
assert script_obj.last_action == wait_alias
|
||||||
|
hass.states.async_set("switch.test", "off")
|
||||||
|
# the script task + 2 tasks created by wait_for_trigger script step
|
||||||
|
await hass.async_wait_for_task_count(3)
|
||||||
|
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=10))
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
except (AssertionError, asyncio.TimeoutError):
|
||||||
|
await script_obj.async_stop()
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
assert not script_obj.is_running
|
||||||
|
assert script_obj.last_action is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("action_type", ["template", "trigger"])
|
@pytest.mark.parametrize("action_type", ["template", "trigger"])
|
||||||
async def test_wait_basic_times_out(hass, action_type):
|
async def test_wait_basic_times_out(hass, action_type):
|
||||||
"""Test wait actions times out when the action does not happen."""
|
"""Test wait actions times out when the action does not happen."""
|
||||||
|
|
Loading…
Reference in New Issue