Refactor script helper actions into their own methods (#18962)
* Refactor script helper actions into their own methods * Lint * Lintpull/18741/head
parent
d0751ffd91
commit
d028236bf2
|
@ -9,7 +9,7 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.core import HomeAssistant, Context, callback
|
||||
from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant import exceptions
|
||||
from homeassistant.helpers import (
|
||||
service, condition, template as template,
|
||||
config_validation as cv)
|
||||
|
@ -34,6 +34,30 @@ CONF_WAIT_TEMPLATE = 'wait_template'
|
|||
CONF_CONTINUE = 'continue_on_timeout'
|
||||
|
||||
|
||||
ACTION_DELAY = 'delay'
|
||||
ACTION_WAIT_TEMPLATE = 'wait_template'
|
||||
ACTION_CHECK_CONDITION = 'condition'
|
||||
ACTION_FIRE_EVENT = 'event'
|
||||
ACTION_CALL_SERVICE = 'call_service'
|
||||
|
||||
|
||||
def _determine_action(action):
|
||||
"""Determine action type."""
|
||||
if CONF_DELAY in action:
|
||||
return ACTION_DELAY
|
||||
|
||||
if CONF_WAIT_TEMPLATE in action:
|
||||
return ACTION_WAIT_TEMPLATE
|
||||
|
||||
if CONF_CONDITION in action:
|
||||
return ACTION_CHECK_CONDITION
|
||||
|
||||
if CONF_EVENT in action:
|
||||
return ACTION_FIRE_EVENT
|
||||
|
||||
return ACTION_CALL_SERVICE
|
||||
|
||||
|
||||
def call_from_config(hass: HomeAssistant, config: ConfigType,
|
||||
variables: Optional[Sequence] = None,
|
||||
context: Optional[Context] = None) -> None:
|
||||
|
@ -41,6 +65,14 @@ def call_from_config(hass: HomeAssistant, config: ConfigType,
|
|||
Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables, context)
|
||||
|
||||
|
||||
class _StopScript(Exception):
|
||||
"""Throw if script needs to stop."""
|
||||
|
||||
|
||||
class _SuspendScript(Exception):
|
||||
"""Throw if script needs to suspend."""
|
||||
|
||||
|
||||
class Script():
|
||||
"""Representation of a script."""
|
||||
|
||||
|
@ -60,6 +92,13 @@ class Script():
|
|||
self._async_listener = []
|
||||
self._template_cache = {}
|
||||
self._config_cache = {}
|
||||
self._actions = {
|
||||
ACTION_DELAY: self._async_delay,
|
||||
ACTION_WAIT_TEMPLATE: self._async_wait_template,
|
||||
ACTION_CHECK_CONDITION: self._async_check_condition,
|
||||
ACTION_FIRE_EVENT: self._async_fire_event,
|
||||
ACTION_CALL_SERVICE: self._async_call_service,
|
||||
}
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
|
@ -87,98 +126,27 @@ class Script():
|
|||
self._async_remove_listener()
|
||||
|
||||
for cur, action in islice(enumerate(self.sequence), self._cur, None):
|
||||
|
||||
if CONF_DELAY in action:
|
||||
# Call ourselves in the future to continue work
|
||||
unsub = None
|
||||
|
||||
@callback
|
||||
def async_script_delay(now):
|
||||
"""Handle delay."""
|
||||
# pylint: disable=cell-var-from-loop
|
||||
with suppress(ValueError):
|
||||
self._async_listener.remove(unsub)
|
||||
|
||||
self.hass.async_create_task(
|
||||
self.async_run(variables, context))
|
||||
|
||||
delay = action[CONF_DELAY]
|
||||
|
||||
try:
|
||||
if isinstance(delay, template.Template):
|
||||
delay = vol.All(
|
||||
cv.time_period,
|
||||
cv.positive_timedelta)(
|
||||
delay.async_render(variables))
|
||||
elif isinstance(delay, dict):
|
||||
delay_data = {}
|
||||
delay_data.update(
|
||||
template.render_complex(delay, variables))
|
||||
delay = cv.time_period(delay_data)
|
||||
except (TemplateError, vol.Invalid) as ex:
|
||||
_LOGGER.error("Error rendering '%s' delay template: %s",
|
||||
self.name, ex)
|
||||
break
|
||||
|
||||
self.last_action = action.get(
|
||||
CONF_ALIAS, 'delay {}'.format(delay))
|
||||
self._log("Executing step %s" % self.last_action)
|
||||
|
||||
unsub = async_track_point_in_utc_time(
|
||||
self.hass, async_script_delay,
|
||||
date_util.utcnow() + delay
|
||||
)
|
||||
self._async_listener.append(unsub)
|
||||
|
||||
try:
|
||||
await self._handle_action(action, variables, context)
|
||||
except _SuspendScript:
|
||||
# Store next step to take and notify change listeners
|
||||
self._cur = cur + 1
|
||||
if self._change_listener:
|
||||
self.hass.async_add_job(self._change_listener)
|
||||
return
|
||||
except _StopScript:
|
||||
break
|
||||
except Exception as err:
|
||||
# Store the step that had an exception
|
||||
# pylint: disable=protected-access
|
||||
err._script_step = cur
|
||||
# Set script to not running
|
||||
self._cur = -1
|
||||
self.last_action = None
|
||||
# Pass exception on.
|
||||
raise
|
||||
|
||||
if CONF_WAIT_TEMPLATE in action:
|
||||
# Call ourselves in the future to continue work
|
||||
wait_template = action[CONF_WAIT_TEMPLATE]
|
||||
wait_template.hass = self.hass
|
||||
|
||||
self.last_action = action.get(CONF_ALIAS, 'wait template')
|
||||
self._log("Executing step %s" % self.last_action)
|
||||
|
||||
# check if condition already okay
|
||||
if condition.async_template(
|
||||
self.hass, wait_template, variables):
|
||||
continue
|
||||
|
||||
@callback
|
||||
def async_script_wait(entity_id, from_s, to_s):
|
||||
"""Handle script after template condition is true."""
|
||||
self._async_remove_listener()
|
||||
self.hass.async_create_task(
|
||||
self.async_run(variables, context))
|
||||
|
||||
self._async_listener.append(async_track_template(
|
||||
self.hass, wait_template, async_script_wait, variables))
|
||||
|
||||
self._cur = cur + 1
|
||||
if self._change_listener:
|
||||
self.hass.async_add_job(self._change_listener)
|
||||
|
||||
if CONF_TIMEOUT in action:
|
||||
self._async_set_timeout(
|
||||
action, variables, context,
|
||||
action.get(CONF_CONTINUE, True))
|
||||
|
||||
return
|
||||
|
||||
if CONF_CONDITION in action:
|
||||
if not self._async_check_condition(action, variables):
|
||||
break
|
||||
|
||||
elif CONF_EVENT in action:
|
||||
self._async_fire_event(action, variables, context)
|
||||
|
||||
else:
|
||||
await self._async_call_service(action, variables, context)
|
||||
|
||||
# Set script to not-running.
|
||||
self._cur = -1
|
||||
self.last_action = None
|
||||
if self._change_listener:
|
||||
|
@ -198,6 +166,86 @@ class Script():
|
|||
if self._change_listener:
|
||||
self.hass.async_add_job(self._change_listener)
|
||||
|
||||
async def _handle_action(self, action, variables, context):
|
||||
"""Handle an action."""
|
||||
await self._actions[_determine_action(action)](
|
||||
action, variables, context)
|
||||
|
||||
async def _async_delay(self, action, variables, context):
|
||||
"""Handle delay."""
|
||||
# Call ourselves in the future to continue work
|
||||
unsub = None
|
||||
|
||||
@callback
|
||||
def async_script_delay(now):
|
||||
"""Handle delay."""
|
||||
# pylint: disable=cell-var-from-loop
|
||||
with suppress(ValueError):
|
||||
self._async_listener.remove(unsub)
|
||||
|
||||
self.hass.async_create_task(
|
||||
self.async_run(variables, context))
|
||||
|
||||
delay = action[CONF_DELAY]
|
||||
|
||||
try:
|
||||
if isinstance(delay, template.Template):
|
||||
delay = vol.All(
|
||||
cv.time_period,
|
||||
cv.positive_timedelta)(
|
||||
delay.async_render(variables))
|
||||
elif isinstance(delay, dict):
|
||||
delay_data = {}
|
||||
delay_data.update(
|
||||
template.render_complex(delay, variables))
|
||||
delay = cv.time_period(delay_data)
|
||||
except (exceptions.TemplateError, vol.Invalid) as ex:
|
||||
_LOGGER.error("Error rendering '%s' delay template: %s",
|
||||
self.name, ex)
|
||||
raise _StopScript
|
||||
|
||||
self.last_action = action.get(
|
||||
CONF_ALIAS, 'delay {}'.format(delay))
|
||||
self._log("Executing step %s" % self.last_action)
|
||||
|
||||
unsub = async_track_point_in_utc_time(
|
||||
self.hass, async_script_delay,
|
||||
date_util.utcnow() + delay
|
||||
)
|
||||
self._async_listener.append(unsub)
|
||||
raise _SuspendScript
|
||||
|
||||
async def _async_wait_template(self, action, variables, context):
|
||||
"""Handle a wait template."""
|
||||
# Call ourselves in the future to continue work
|
||||
wait_template = action[CONF_WAIT_TEMPLATE]
|
||||
wait_template.hass = self.hass
|
||||
|
||||
self.last_action = action.get(CONF_ALIAS, 'wait template')
|
||||
self._log("Executing step %s" % self.last_action)
|
||||
|
||||
# check if condition already okay
|
||||
if condition.async_template(
|
||||
self.hass, wait_template, variables):
|
||||
return
|
||||
|
||||
@callback
|
||||
def async_script_wait(entity_id, from_s, to_s):
|
||||
"""Handle script after template condition is true."""
|
||||
self._async_remove_listener()
|
||||
self.hass.async_create_task(
|
||||
self.async_run(variables, context))
|
||||
|
||||
self._async_listener.append(async_track_template(
|
||||
self.hass, wait_template, async_script_wait, variables))
|
||||
|
||||
if CONF_TIMEOUT in action:
|
||||
self._async_set_timeout(
|
||||
action, variables, context,
|
||||
action.get(CONF_CONTINUE, True))
|
||||
|
||||
raise _SuspendScript
|
||||
|
||||
async def _async_call_service(self, action, variables, context):
|
||||
"""Call the service specified in the action.
|
||||
|
||||
|
@ -213,7 +261,7 @@ class Script():
|
|||
context=context
|
||||
)
|
||||
|
||||
def _async_fire_event(self, action, variables, context):
|
||||
async def _async_fire_event(self, action, variables, context):
|
||||
"""Fire an event."""
|
||||
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
|
||||
self._log("Executing step %s" % self.last_action)
|
||||
|
@ -222,13 +270,13 @@ class Script():
|
|||
try:
|
||||
event_data.update(template.render_complex(
|
||||
action[CONF_EVENT_DATA_TEMPLATE], variables))
|
||||
except TemplateError as ex:
|
||||
except exceptions.TemplateError as ex:
|
||||
_LOGGER.error('Error rendering event data template: %s', ex)
|
||||
|
||||
self.hass.bus.async_fire(action[CONF_EVENT],
|
||||
event_data, context=context)
|
||||
|
||||
def _async_check_condition(self, action, variables):
|
||||
async def _async_check_condition(self, action, variables, context):
|
||||
"""Test if condition is matching."""
|
||||
config_cache_key = frozenset((k, str(v)) for k, v in action.items())
|
||||
config = self._config_cache.get(config_cache_key)
|
||||
|
@ -239,7 +287,9 @@ class Script():
|
|||
self.last_action = action.get(CONF_ALIAS, action[CONF_CONDITION])
|
||||
check = config(self.hass, variables)
|
||||
self._log("Test condition {}: {}".format(self.last_action, check))
|
||||
return check
|
||||
|
||||
if not check:
|
||||
raise _StopScript
|
||||
|
||||
def _async_set_timeout(self, action, variables, context,
|
||||
continue_on_timeout):
|
||||
|
|
|
@ -4,6 +4,10 @@ from datetime import timedelta
|
|||
from unittest import mock
|
||||
import unittest
|
||||
|
||||
import voluptuous as vol
|
||||
import pytest
|
||||
|
||||
from homeassistant import exceptions
|
||||
from homeassistant.core import Context, callback
|
||||
# Otherwise can't test just this file (import order issue)
|
||||
import homeassistant.components # noqa
|
||||
|
@ -774,3 +778,84 @@ class TestScriptHelper(unittest.TestCase):
|
|||
self.hass.block_till_done()
|
||||
|
||||
assert script_obj.last_triggered == time
|
||||
|
||||
|
||||
async def test_propagate_error_service_not_found(hass):
|
||||
"""Test that a script aborts when a service is not found."""
|
||||
events = []
|
||||
|
||||
@callback
|
||||
def record_event(event):
|
||||
events.append(event)
|
||||
|
||||
hass.bus.async_listen('test_event', record_event)
|
||||
|
||||
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
|
||||
{'service': 'test.script'},
|
||||
{'event': 'test_event'}]))
|
||||
|
||||
with pytest.raises(exceptions.ServiceNotFound):
|
||||
await script_obj.async_run()
|
||||
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
async def test_propagate_error_invalid_service_data(hass):
|
||||
"""Test that a script aborts when we send invalid service data."""
|
||||
events = []
|
||||
|
||||
@callback
|
||||
def record_event(event):
|
||||
events.append(event)
|
||||
|
||||
hass.bus.async_listen('test_event', record_event)
|
||||
|
||||
calls = []
|
||||
|
||||
@callback
|
||||
def record_call(service):
|
||||
"""Add recorded event to set."""
|
||||
calls.append(service)
|
||||
|
||||
hass.services.async_register('test', 'script', record_call,
|
||||
schema=vol.Schema({'text': str}))
|
||||
|
||||
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
|
||||
{'service': 'test.script', 'data': {'text': 1}},
|
||||
{'event': 'test_event'}]))
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
await script_obj.async_run()
|
||||
|
||||
assert len(events) == 0
|
||||
assert len(calls) == 0
|
||||
|
||||
|
||||
async def test_propagate_error_service_exception(hass):
|
||||
"""Test that a script aborts when a service throws an exception."""
|
||||
events = []
|
||||
|
||||
@callback
|
||||
def record_event(event):
|
||||
events.append(event)
|
||||
|
||||
hass.bus.async_listen('test_event', record_event)
|
||||
|
||||
calls = []
|
||||
|
||||
@callback
|
||||
def record_call(service):
|
||||
"""Add recorded event to set."""
|
||||
raise ValueError("BROKEN")
|
||||
|
||||
hass.services.async_register('test', 'script', record_call)
|
||||
|
||||
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
|
||||
{'service': 'test.script'},
|
||||
{'event': 'test_event'}]))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await script_obj.async_run()
|
||||
|
||||
assert len(events) == 0
|
||||
assert len(calls) == 0
|
||||
|
|
Loading…
Reference in New Issue