diff --git a/homeassistant/components/automation/state.py b/homeassistant/components/automation/state.py index 802debbe63e..3183dab0803 100644 --- a/homeassistant/components/automation/state.py +++ b/homeassistant/components/automation/state.py @@ -92,7 +92,7 @@ def trigger(hass, config, action): def state_for_listener(now): """Fire on state changes after a delay and calls action.""" hass.bus.remove_listener( - EVENT_STATE_CHANGED, attached_state_for_cancel_listener) + EVENT_STATE_CHANGED, attached_state_for_cancel) call_action() def state_for_cancel_listener(entity, inner_from_s, inner_to_s): @@ -102,12 +102,12 @@ def trigger(hass, config, action): hass.bus.remove_listener(EVENT_TIME_CHANGED, attached_state_for_listener) hass.bus.remove_listener(EVENT_STATE_CHANGED, - attached_state_for_cancel_listener) + attached_state_for_cancel) attached_state_for_listener = track_point_in_time( hass, state_for_listener, dt_util.utcnow() + time_delta) - attached_state_for_cancel_listener = track_state_change( + attached_state_for_cancel = track_state_change( hass, entity_id, state_for_cancel_listener) track_state_change( diff --git a/homeassistant/components/automation/sun.py b/homeassistant/components/automation/sun.py index c9db88a83c2..7de43d7f5e3 100644 --- a/homeassistant/components/automation/sun.py +++ b/homeassistant/components/automation/sun.py @@ -35,7 +35,7 @@ _SUN_EVENT = vol.All(vol.Lower, vol.Any(EVENT_SUNRISE, EVENT_SUNSET)) TRIGGER_SCHEMA = vol.Schema({ vol.Required(CONF_PLATFORM): 'sun', vol.Required(CONF_EVENT): _SUN_EVENT, - vol.Required(CONF_OFFSET, default=timedelta(0)): cv.time_offset, + vol.Required(CONF_OFFSET, default=timedelta(0)): cv.time_period, }) IF_ACTION_SCHEMA = vol.All( @@ -43,8 +43,8 @@ IF_ACTION_SCHEMA = vol.All( vol.Required(CONF_PLATFORM): 'sun', CONF_BEFORE: _SUN_EVENT, CONF_AFTER: _SUN_EVENT, - vol.Required(CONF_BEFORE_OFFSET, default=timedelta(0)): cv.time_offset, - vol.Required(CONF_AFTER_OFFSET, default=timedelta(0)): cv.time_offset, + vol.Required(CONF_BEFORE_OFFSET, default=timedelta(0)): cv.time_period, + vol.Required(CONF_AFTER_OFFSET, default=timedelta(0)): cv.time_period, }), cv.has_at_least_one_key(CONF_BEFORE, CONF_AFTER), ) diff --git a/homeassistant/components/script.py b/homeassistant/components/script.py index c19e614f19d..3557179c6eb 100644 --- a/homeassistant/components/script.py +++ b/homeassistant/components/script.py @@ -8,101 +8,33 @@ For more details about this component, please refer to the documentation at https://home-assistant.io/components/script/ """ import logging -import threading -from datetime import timedelta -from itertools import islice import voluptuous as vol -import homeassistant.util.dt as date_util from homeassistant.const import ( - ATTR_ENTITY_ID, EVENT_TIME_CHANGED, SERVICE_TURN_OFF, SERVICE_TURN_ON, - SERVICE_TOGGLE, STATE_ON) + ATTR_ENTITY_ID, SERVICE_TURN_OFF, SERVICE_TURN_ON, + SERVICE_TOGGLE, STATE_ON, CONF_ALIAS) from homeassistant.helpers.entity import ToggleEntity, split_entity_id from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.event import track_point_in_utc_time -from homeassistant.helpers.service import (call_from_config, - validate_service_call) import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.script import Script + DOMAIN = "script" ENTITY_ID_FORMAT = DOMAIN + '.{}' DEPENDENCIES = ["group"] -STATE_NOT_RUNNING = 'Not Running' - -CONF_ALIAS = "alias" -CONF_SERVICE = "service" -CONF_SERVICE_DATA = "data" CONF_SEQUENCE = "sequence" -CONF_EVENT = "event" -CONF_EVENT_DATA = "event_data" -CONF_DELAY = "delay" ATTR_LAST_ACTION = 'last_action' ATTR_CAN_CANCEL = 'can_cancel' _LOGGER = logging.getLogger(__name__) -_ALIAS_VALIDATOR = vol.Schema(cv.string) - - -def _alias_stripper(validator): - """Strip alias from object for validation.""" - def validate(value): - """Validate without alias value.""" - value = value.copy() - alias = value.pop(CONF_ALIAS, None) - - if alias is not None: - alias = _ALIAS_VALIDATOR(alias) - - value = validator(value) - - if alias is not None: - value[CONF_ALIAS] = alias - - return value - - return validate - - -_TIMESPEC = vol.Schema({ - 'days': cv.positive_int, - 'hours': cv.positive_int, - 'minutes': cv.positive_int, - 'seconds': cv.positive_int, - 'milliseconds': cv.positive_int, -}) -_TIMESPEC_REQ = cv.has_at_least_one_key( - 'days', 'hours', 'minutes', 'seconds', 'milliseconds', -) - -_DELAY_SCHEMA = vol.Any( - vol.Schema({ - vol.Required(CONF_DELAY): vol.All(_TIMESPEC.extend({ - vol.Optional(CONF_ALIAS): cv.string - }), _TIMESPEC_REQ) - }), - # Alternative format in case people forgot to indent after 'delay:' - vol.All(_TIMESPEC.extend({ - vol.Required(CONF_DELAY): None, - vol.Optional(CONF_ALIAS): cv.string, - }), _TIMESPEC_REQ) -) - -_EVENT_SCHEMA = cv.EVENT_SCHEMA.extend({ - CONF_ALIAS: cv.string, -}) _SCRIPT_ENTRY_SCHEMA = vol.Schema({ CONF_ALIAS: cv.string, - vol.Required(CONF_SEQUENCE): vol.All(vol.Length(min=1), [vol.Any( - _EVENT_SCHEMA, - _DELAY_SCHEMA, - # Can't extend SERVICE_SCHEMA because it is an vol.All - _alias_stripper(cv.SERVICE_SCHEMA), - )]), + vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA, }) CONFIG_SCHEMA = vol.Schema({ @@ -152,7 +84,7 @@ def setup(hass, config): for object_id, cfg in config[DOMAIN].items(): alias = cfg.get(CONF_ALIAS, object_id) - script = Script(object_id, alias, cfg[CONF_SEQUENCE]) + script = ScriptEntity(hass, object_id, alias, cfg[CONF_SEQUENCE]) component.add_entities((script,)) hass.services.register(DOMAIN, object_id, service_handler, schema=SCRIPT_SERVICE_SCHEMA) @@ -183,21 +115,14 @@ def setup(hass, config): return True -class Script(ToggleEntity): - """Representation of a script.""" +class ScriptEntity(ToggleEntity): + """Representation of a script entity.""" # pylint: disable=too-many-instance-attributes - def __init__(self, object_id, name, sequence): + def __init__(self, hass, object_id, name, sequence): """Initialize the script.""" self.entity_id = ENTITY_ID_FORMAT.format(object_id) - self._name = name - self.sequence = sequence - self._lock = threading.Lock() - self._cur = -1 - self._last_action = None - self._listener = None - self._can_cancel = any(CONF_DELAY in action for action - in self.sequence) + self.script = Script(hass, sequence, name, self.update_ha_state) @property def should_poll(self): @@ -207,91 +132,27 @@ class Script(ToggleEntity): @property def name(self): """Return the name of the entity.""" - return self._name + return self.script.name @property def state_attributes(self): """Return the state attributes.""" attrs = {} - if self._can_cancel: - attrs[ATTR_CAN_CANCEL] = self._can_cancel - if self._last_action: - attrs[ATTR_LAST_ACTION] = self._last_action + if self.script.can_cancel: + attrs[ATTR_CAN_CANCEL] = self.script.can_cancel + if self.script.last_action: + attrs[ATTR_LAST_ACTION] = self.script.last_action return attrs @property def is_on(self): """Return true if script is on.""" - return self._cur != -1 + return self.script.is_running def turn_on(self, **kwargs): """Turn the entity on.""" - _LOGGER.info("Executing script %s", self._name) - with self._lock: - if self._cur == -1: - self._cur = 0 - - # Unregister callback if we were in a delay but turn on is called - # again. In that case we just continue execution. - self._remove_listener() - - for cur, action in islice(enumerate(self.sequence), self._cur, - None): - - if validate_service_call(action) is None: - self._call_service(action) - - elif CONF_EVENT in action: - self._fire_event(action) - - elif CONF_DELAY in action: - # Call ourselves in the future to continue work - def script_delay(now): - """Called after delay is done.""" - self._listener = None - self.turn_on() - - timespec = action[CONF_DELAY] or action.copy() - timespec.pop(CONF_DELAY, None) - delay = timedelta(**timespec) - self._listener = track_point_in_utc_time( - self.hass, script_delay, date_util.utcnow() + delay) - self._cur = cur + 1 - self.update_ha_state() - return - - self._cur = -1 - self._last_action = None - self.update_ha_state() + self.script.run() def turn_off(self, **kwargs): """Turn script off.""" - _LOGGER.info("Cancelled script %s", self._name) - with self._lock: - if self._cur == -1: - return - - self._cur = -1 - self.update_ha_state() - self._remove_listener() - - def _call_service(self, action): - """Call the service specified in the action.""" - self._last_action = action.get(CONF_ALIAS, 'call service') - _LOGGER.info("Executing script %s step %s", self._name, - self._last_action) - call_from_config(self.hass, action, True) - - def _fire_event(self, action): - """Fire an event.""" - self._last_action = action.get(CONF_ALIAS, action[CONF_EVENT]) - _LOGGER.info("Executing script %s step %s", self._name, - self._last_action) - self.hass.bus.fire(action[CONF_EVENT], action.get(CONF_EVENT_DATA)) - - def _remove_listener(self): - """Remove point in time listener, if any.""" - if self._listener: - self.hass.bus.remove_listener(EVENT_TIME_CHANGED, - self._listener) - self._listener = None + self.script.stop() diff --git a/homeassistant/const.py b/homeassistant/const.py index 77e540cd76f..b2971ab59f6 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -13,6 +13,7 @@ MATCH_ALL = '*' DEVICE_DEFAULT_NAME = "Unnamed Device" # #### CONFIG #### +CONF_ALIAS = "alias" CONF_ICON = "icon" CONF_LATITUDE = "latitude" CONF_LONGITUDE = "longitude" diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 51684e5f1cd..71e103f7dd3 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -6,7 +6,8 @@ import voluptuous as vol from homeassistant.loader import get_platform from homeassistant.const import ( - CONF_PLATFORM, CONF_SCAN_INTERVAL, TEMP_CELSIUS, TEMP_FAHRENHEIT) + CONF_PLATFORM, CONF_SCAN_INTERVAL, TEMP_CELSIUS, TEMP_FAHRENHEIT, + CONF_ALIAS) from homeassistant.helpers.entity import valid_entity_id import homeassistant.util.dt as dt_util from homeassistant.util import slugify @@ -23,6 +24,23 @@ longitude = vol.All(vol.Coerce(float), vol.Range(min=-180, max=180), msg='invalid longitude') +# Adapted from: +# https://github.com/alecthomas/voluptuous/issues/115#issuecomment-144464666 +def has_at_least_one_key(*keys): + """Validator that at least one key exists.""" + def validate(obj): + """Test keys exist in dict.""" + if not isinstance(obj, dict): + raise vol.Invalid('expected dictionary') + + for k in obj.keys(): + if k in keys: + return obj + raise vol.Invalid('must contain one of {}.'.format(', '.join(keys))) + + return validate + + def boolean(value): """Validate and coerce a boolean value.""" if isinstance(value, str): @@ -72,10 +90,24 @@ def icon(value): raise vol.Invalid('Icons should start with prefix "mdi:"') -def time_offset(value): +time_period_dict = vol.All( + dict, vol.Schema({ + 'days': vol.Coerce(int), + 'hours': vol.Coerce(int), + 'minutes': vol.Coerce(int), + 'seconds': vol.Coerce(int), + 'milliseconds': vol.Coerce(int), + }), + has_at_least_one_key('days', 'hours', 'minutes', + 'seconds', 'milliseconds'), + lambda value: timedelta(**value)) + + +def time_period_str(value): """Validate and transform time offset.""" if not isinstance(value, str): - raise vol.Invalid('offset should be a string') + raise vol.Invalid( + 'offset {} should be format HH:MM or HH:MM:SS'.format(value)) negative_offset = False if value.startswith('-'): @@ -107,6 +139,9 @@ def time_offset(value): return offset +time_period = vol.Any(time_period_str, timedelta, time_period_dict) + + def match_all(value): """Validator that matches all values.""" return value @@ -125,6 +160,13 @@ def platform_validator(domain): return validator +def positive_timedelta(value): + """Validate timedelta is positive.""" + if value < timedelta(0): + raise vol.Invalid('Time period should be positive') + return value + + def service(value): """Validate service.""" # Services use same format as entities so we can use same helper. @@ -200,23 +242,6 @@ def key_dependency(key, dependency): return validator -# Adapted from: -# https://github.com/alecthomas/voluptuous/issues/115#issuecomment-144464666 -def has_at_least_one_key(*keys): - """Validator that at least one key exists.""" - def validate(obj): - """Test keys exist in dict.""" - if not isinstance(obj, dict): - raise vol.Invalid('expected dictionary') - - for k in obj.keys(): - if k in keys: - return obj - raise vol.Invalid('must contain one of {}.'.format(', '.join(keys))) - - return validate - - # Schemas PLATFORM_SCHEMA = vol.Schema({ @@ -225,14 +250,28 @@ PLATFORM_SCHEMA = vol.Schema({ }, extra=vol.ALLOW_EXTRA) EVENT_SCHEMA = vol.Schema({ + vol.Optional(CONF_ALIAS): string, vol.Required('event'): string, - 'event_data': dict + vol.Optional('event_data'): dict, }) SERVICE_SCHEMA = vol.All(vol.Schema({ + vol.Optional(CONF_ALIAS): string, vol.Exclusive('service', 'service name'): service, vol.Exclusive('service_template', 'service name'): template, vol.Optional('data'): dict, vol.Optional('data_template'): {match_all: template}, vol.Optional('entity_id'): entity_ids, }), has_at_least_one_key('service', 'service_template')) + +# ----- SCRIPT + +_DELAY_SCHEMA = vol.Schema({ + vol.Optional(CONF_ALIAS): string, + vol.Required("delay"): vol.All(time_period, positive_timedelta) +}) + +SCRIPT_SCHEMA = vol.All( + ensure_list, + [vol.Any(SERVICE_SCHEMA, _DELAY_SCHEMA, EVENT_SCHEMA)], +) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py new file mode 100644 index 00000000000..e4cf2f6756d --- /dev/null +++ b/homeassistant/helpers/script.py @@ -0,0 +1,125 @@ +"""Helpers to execute scripts.""" +import logging +import threading +from itertools import islice + +import homeassistant.util.dt as date_util +from homeassistant.const import EVENT_TIME_CHANGED +from homeassistant.helpers.event import track_point_in_utc_time +from homeassistant.helpers import service +import homeassistant.helpers.config_validation as cv + +_LOGGER = logging.getLogger(__name__) + +CONF_ALIAS = "alias" +CONF_SERVICE = "service" +CONF_SERVICE_DATA = "data" +CONF_SEQUENCE = "sequence" +CONF_EVENT = "event" +CONF_EVENT_DATA = "event_data" +CONF_DELAY = "delay" + + +def call_from_config(hass, config): + """Call a script based on a config entry.""" + Script(hass, config).run() + + +class Script(): + """Representation of a script.""" + + # pylint: disable=too-many-instance-attributes + def __init__(self, hass, sequence, name=None, change_listener=None): + """Initialize the script.""" + self.hass = hass + self.sequence = cv.SCRIPT_SCHEMA(sequence) + self.name = name + self._change_listener = change_listener + self._cur = -1 + self.last_action = None + self.can_cancel = any(CONF_DELAY in action for action + in self.sequence) + self._lock = threading.Lock() + self._delay_listener = None + + @property + def is_running(self): + """Return true if script is on.""" + return self._cur != -1 + + def run(self): + """Run script.""" + with self._lock: + if self._cur == -1: + self._log('Running script') + self._cur = 0 + + # Unregister callback if we were in a delay but turn on is called + # again. In that case we just continue execution. + self._remove_listener() + + for cur, action in islice(enumerate(self.sequence), self._cur, + None): + + if CONF_DELAY in action: + # Call ourselves in the future to continue work + def script_delay(now): + """Called after delay is done.""" + self._delay_listener = None + self.run() + + self._delay_listener = track_point_in_utc_time( + self.hass, script_delay, + date_util.utcnow() + action[CONF_DELAY]) + self._cur = cur + 1 + if self._change_listener: + self._change_listener() + return + + elif service.validate_service_call(action) is None: + self._call_service(action) + + elif CONF_EVENT in action: + self._fire_event(action) + + self._cur = -1 + self.last_action = None + if self._change_listener: + self._change_listener() + + def stop(self): + """Stop running script.""" + with self._lock: + if self._cur == -1: + return + + self._cur = -1 + self._remove_listener() + if self._change_listener: + self._change_listener() + + def _call_service(self, action): + """Call the service specified in the action.""" + self.last_action = action.get(CONF_ALIAS, 'call service') + self._log("Executing step %s", self.last_action) + service.call_from_config(self.hass, action, True) + + def _fire_event(self, action): + """Fire an event.""" + self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT]) + self._log("Executing step %s", self.last_action) + self.hass.bus.fire(action[CONF_EVENT], action.get(CONF_EVENT_DATA)) + + def _remove_listener(self): + """Remove point in time listener, if any.""" + if self._delay_listener: + self.hass.bus.remove_listener(EVENT_TIME_CHANGED, + self._delay_listener) + self._delay_listener = None + + def _log(self, msg, *substitutes): + """Logger helper.""" + if self.name is not None: + msg = "Script {}: {}".format(self.name, msg, *substitutes) + + _LOGGER.info(msg) diff --git a/tests/components/test_script.py b/tests/components/test_script.py index 4f912dc77a0..f8b99533c18 100644 --- a/tests/components/test_script.py +++ b/tests/components/test_script.py @@ -34,13 +34,6 @@ class TestScript(unittest.TestCase): 'sequence': [{'event': 'bla'}] } }, - { - 'test': { - 'sequence': { - 'event': 'test_event' - } - } - }, { 'test': { 'sequence': { @@ -49,7 +42,6 @@ class TestScript(unittest.TestCase): } } }, - ): assert not _setup_component(self.hass, 'script', { 'script': value @@ -206,45 +198,6 @@ class TestScript(unittest.TestCase): self.assertEqual(2, len(calls)) - def test_alt_delay(self): - """Test alternative delay config format.""" - event = 'test_event' - calls = [] - - def record_event(event): - """Add recorded event to set.""" - calls.append(event) - - self.hass.bus.listen(event, record_event) - - assert _setup_component(self.hass, 'script', { - 'script': { - 'test': { - 'sequence': [{ - 'event': event, - }, { - 'delay': None, - 'seconds': 5 - }, { - 'event': event, - }] - } - } - }) - - script.turn_on(self.hass, ENTITY_ID) - self.hass.pool.block_till_done() - - self.assertTrue(script.is_on(self.hass, ENTITY_ID)) - self.assertEqual(1, len(calls)) - - future = dt_util.utcnow() + timedelta(seconds=5) - fire_time_changed(self.hass, future) - self.hass.pool.block_till_done() - - self.assertFalse(script.is_on(self.hass, ENTITY_ID)) - self.assertEqual(2, len(calls)) - def test_cancel_while_delay(self): """Test the cancelling while the delay is present.""" event = 'test_event' diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py index 3f4789eca4f..b73dc6d6f94 100644 --- a/tests/helpers/test_config_validation.py +++ b/tests/helpers/test_config_validation.py @@ -145,18 +145,19 @@ def test_icon(): schema('mdi:work') -def test_time_offset(): - """Test time_offset validation.""" - schema = vol.Schema(cv.time_offset) +def test_time_period(): + """Test time_period validation.""" + schema = vol.Schema(cv.time_period) for value in ( - None, '', 1234, 'hello:world', '12:', '12:34:56:78' + None, '', 1234, 'hello:world', '12:', '12:34:56:78', + {}, {'wrong_key': -10} ): with pytest.raises(vol.MultipleInvalid): schema(value) for value in ( - '8:20', '23:59', '-8:20', '-23:59:59', '-48:00' + '8:20', '23:59', '-8:20', '-23:59:59', '-48:00', {'minutes': 5} ): schema(value) diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index c863a46ad3b..11ace1ab5d8 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -37,7 +37,7 @@ class TestServiceHelpers(unittest.TestCase): self.assertEqual(1, len(runs)) def test_template_service_call(self): - """ Test service call with tempating. """ + """Test service call with tempating.""" config = { 'service_template': '{{ \'test_domain.test_service\' }}', 'entity_id': 'hello.world', @@ -56,6 +56,7 @@ class TestServiceHelpers(unittest.TestCase): self.assertEqual('goodbye', runs[0].data['hello']) def test_passing_variables_to_templates(self): + """Test passing variables to templates.""" config = { 'service_template': '{{ var_service }}', 'entity_id': 'hello.world', @@ -141,7 +142,7 @@ class TestServiceHelpers(unittest.TestCase): service.extract_entity_ids(self.hass, call)) def test_validate_service_call(self): - """Test is_valid_service_call method""" + """Test is_valid_service_call method.""" self.assertNotEqual( service.validate_service_call( {}),