Refactor script helper actions into their own methods (#18962)

* Refactor script helper actions into their own methods

* Lint

* Lint
pull/18741/head
Paulus Schoutsen 2018-12-03 15:46:25 +01:00 committed by GitHub
parent d0751ffd91
commit d028236bf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 227 additions and 92 deletions

View File

@ -9,7 +9,7 @@ import voluptuous as vol
from homeassistant.core import HomeAssistant, Context, callback
from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT
from homeassistant.exceptions import TemplateError
from homeassistant import exceptions
from homeassistant.helpers import (
service, condition, template as template,
config_validation as cv)
@ -34,6 +34,30 @@ CONF_WAIT_TEMPLATE = 'wait_template'
CONF_CONTINUE = 'continue_on_timeout'
ACTION_DELAY = 'delay'
ACTION_WAIT_TEMPLATE = 'wait_template'
ACTION_CHECK_CONDITION = 'condition'
ACTION_FIRE_EVENT = 'event'
ACTION_CALL_SERVICE = 'call_service'
def _determine_action(action):
"""Determine action type."""
if CONF_DELAY in action:
return ACTION_DELAY
if CONF_WAIT_TEMPLATE in action:
return ACTION_WAIT_TEMPLATE
if CONF_CONDITION in action:
return ACTION_CHECK_CONDITION
if CONF_EVENT in action:
return ACTION_FIRE_EVENT
return ACTION_CALL_SERVICE
def call_from_config(hass: HomeAssistant, config: ConfigType,
variables: Optional[Sequence] = None,
context: Optional[Context] = None) -> None:
@ -41,6 +65,14 @@ def call_from_config(hass: HomeAssistant, config: ConfigType,
Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables, context)
class _StopScript(Exception):
"""Throw if script needs to stop."""
class _SuspendScript(Exception):
"""Throw if script needs to suspend."""
class Script():
"""Representation of a script."""
@ -60,6 +92,13 @@ class Script():
self._async_listener = []
self._template_cache = {}
self._config_cache = {}
self._actions = {
ACTION_DELAY: self._async_delay,
ACTION_WAIT_TEMPLATE: self._async_wait_template,
ACTION_CHECK_CONDITION: self._async_check_condition,
ACTION_FIRE_EVENT: self._async_fire_event,
ACTION_CALL_SERVICE: self._async_call_service,
}
@property
def is_running(self) -> bool:
@ -87,98 +126,27 @@ class Script():
self._async_remove_listener()
for cur, action in islice(enumerate(self.sequence), self._cur, None):
if CONF_DELAY in action:
# Call ourselves in the future to continue work
unsub = None
@callback
def async_script_delay(now):
"""Handle delay."""
# pylint: disable=cell-var-from-loop
with suppress(ValueError):
self._async_listener.remove(unsub)
self.hass.async_create_task(
self.async_run(variables, context))
delay = action[CONF_DELAY]
try:
if isinstance(delay, template.Template):
delay = vol.All(
cv.time_period,
cv.positive_timedelta)(
delay.async_render(variables))
elif isinstance(delay, dict):
delay_data = {}
delay_data.update(
template.render_complex(delay, variables))
delay = cv.time_period(delay_data)
except (TemplateError, vol.Invalid) as ex:
_LOGGER.error("Error rendering '%s' delay template: %s",
self.name, ex)
break
self.last_action = action.get(
CONF_ALIAS, 'delay {}'.format(delay))
self._log("Executing step %s" % self.last_action)
unsub = async_track_point_in_utc_time(
self.hass, async_script_delay,
date_util.utcnow() + delay
)
self._async_listener.append(unsub)
try:
await self._handle_action(action, variables, context)
except _SuspendScript:
# Store next step to take and notify change listeners
self._cur = cur + 1
if self._change_listener:
self.hass.async_add_job(self._change_listener)
return
except _StopScript:
break
except Exception as err:
# Store the step that had an exception
# pylint: disable=protected-access
err._script_step = cur
# Set script to not running
self._cur = -1
self.last_action = None
# Pass exception on.
raise
if CONF_WAIT_TEMPLATE in action:
# Call ourselves in the future to continue work
wait_template = action[CONF_WAIT_TEMPLATE]
wait_template.hass = self.hass
self.last_action = action.get(CONF_ALIAS, 'wait template')
self._log("Executing step %s" % self.last_action)
# check if condition already okay
if condition.async_template(
self.hass, wait_template, variables):
continue
@callback
def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true."""
self._async_remove_listener()
self.hass.async_create_task(
self.async_run(variables, context))
self._async_listener.append(async_track_template(
self.hass, wait_template, async_script_wait, variables))
self._cur = cur + 1
if self._change_listener:
self.hass.async_add_job(self._change_listener)
if CONF_TIMEOUT in action:
self._async_set_timeout(
action, variables, context,
action.get(CONF_CONTINUE, True))
return
if CONF_CONDITION in action:
if not self._async_check_condition(action, variables):
break
elif CONF_EVENT in action:
self._async_fire_event(action, variables, context)
else:
await self._async_call_service(action, variables, context)
# Set script to not-running.
self._cur = -1
self.last_action = None
if self._change_listener:
@ -198,6 +166,86 @@ class Script():
if self._change_listener:
self.hass.async_add_job(self._change_listener)
async def _handle_action(self, action, variables, context):
"""Handle an action."""
await self._actions[_determine_action(action)](
action, variables, context)
async def _async_delay(self, action, variables, context):
"""Handle delay."""
# Call ourselves in the future to continue work
unsub = None
@callback
def async_script_delay(now):
"""Handle delay."""
# pylint: disable=cell-var-from-loop
with suppress(ValueError):
self._async_listener.remove(unsub)
self.hass.async_create_task(
self.async_run(variables, context))
delay = action[CONF_DELAY]
try:
if isinstance(delay, template.Template):
delay = vol.All(
cv.time_period,
cv.positive_timedelta)(
delay.async_render(variables))
elif isinstance(delay, dict):
delay_data = {}
delay_data.update(
template.render_complex(delay, variables))
delay = cv.time_period(delay_data)
except (exceptions.TemplateError, vol.Invalid) as ex:
_LOGGER.error("Error rendering '%s' delay template: %s",
self.name, ex)
raise _StopScript
self.last_action = action.get(
CONF_ALIAS, 'delay {}'.format(delay))
self._log("Executing step %s" % self.last_action)
unsub = async_track_point_in_utc_time(
self.hass, async_script_delay,
date_util.utcnow() + delay
)
self._async_listener.append(unsub)
raise _SuspendScript
async def _async_wait_template(self, action, variables, context):
"""Handle a wait template."""
# Call ourselves in the future to continue work
wait_template = action[CONF_WAIT_TEMPLATE]
wait_template.hass = self.hass
self.last_action = action.get(CONF_ALIAS, 'wait template')
self._log("Executing step %s" % self.last_action)
# check if condition already okay
if condition.async_template(
self.hass, wait_template, variables):
return
@callback
def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true."""
self._async_remove_listener()
self.hass.async_create_task(
self.async_run(variables, context))
self._async_listener.append(async_track_template(
self.hass, wait_template, async_script_wait, variables))
if CONF_TIMEOUT in action:
self._async_set_timeout(
action, variables, context,
action.get(CONF_CONTINUE, True))
raise _SuspendScript
async def _async_call_service(self, action, variables, context):
"""Call the service specified in the action.
@ -213,7 +261,7 @@ class Script():
context=context
)
def _async_fire_event(self, action, variables, context):
async def _async_fire_event(self, action, variables, context):
"""Fire an event."""
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
self._log("Executing step %s" % self.last_action)
@ -222,13 +270,13 @@ class Script():
try:
event_data.update(template.render_complex(
action[CONF_EVENT_DATA_TEMPLATE], variables))
except TemplateError as ex:
except exceptions.TemplateError as ex:
_LOGGER.error('Error rendering event data template: %s', ex)
self.hass.bus.async_fire(action[CONF_EVENT],
event_data, context=context)
def _async_check_condition(self, action, variables):
async def _async_check_condition(self, action, variables, context):
"""Test if condition is matching."""
config_cache_key = frozenset((k, str(v)) for k, v in action.items())
config = self._config_cache.get(config_cache_key)
@ -239,7 +287,9 @@ class Script():
self.last_action = action.get(CONF_ALIAS, action[CONF_CONDITION])
check = config(self.hass, variables)
self._log("Test condition {}: {}".format(self.last_action, check))
return check
if not check:
raise _StopScript
def _async_set_timeout(self, action, variables, context,
continue_on_timeout):

View File

@ -4,6 +4,10 @@ from datetime import timedelta
from unittest import mock
import unittest
import voluptuous as vol
import pytest
from homeassistant import exceptions
from homeassistant.core import Context, callback
# Otherwise can't test just this file (import order issue)
import homeassistant.components # noqa
@ -774,3 +778,84 @@ class TestScriptHelper(unittest.TestCase):
self.hass.block_till_done()
assert script_obj.last_triggered == time
async def test_propagate_error_service_not_found(hass):
"""Test that a script aborts when a service is not found."""
events = []
@callback
def record_event(event):
events.append(event)
hass.bus.async_listen('test_event', record_event)
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'service': 'test.script'},
{'event': 'test_event'}]))
with pytest.raises(exceptions.ServiceNotFound):
await script_obj.async_run()
assert len(events) == 0
async def test_propagate_error_invalid_service_data(hass):
"""Test that a script aborts when we send invalid service data."""
events = []
@callback
def record_event(event):
events.append(event)
hass.bus.async_listen('test_event', record_event)
calls = []
@callback
def record_call(service):
"""Add recorded event to set."""
calls.append(service)
hass.services.async_register('test', 'script', record_call,
schema=vol.Schema({'text': str}))
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'service': 'test.script', 'data': {'text': 1}},
{'event': 'test_event'}]))
with pytest.raises(vol.Invalid):
await script_obj.async_run()
assert len(events) == 0
assert len(calls) == 0
async def test_propagate_error_service_exception(hass):
"""Test that a script aborts when a service throws an exception."""
events = []
@callback
def record_event(event):
events.append(event)
hass.bus.async_listen('test_event', record_event)
calls = []
@callback
def record_call(service):
"""Add recorded event to set."""
raise ValueError("BROKEN")
hass.services.async_register('test', 'script', record_call)
script_obj = script.Script(hass, cv.SCRIPT_SCHEMA([
{'service': 'test.script'},
{'event': 'test_event'}]))
with pytest.raises(ValueError):
await script_obj.async_run()
assert len(events) == 0
assert len(calls) == 0