From 6a21afa2a8a6fa5ccaa59e2be518b9d7df8312cc Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Thu, 5 Mar 2020 11:44:42 -0800 Subject: [PATCH] Improve script validation (#32461) --- homeassistant/const.py | 21 +++-- homeassistant/helpers/config_validation.py | 97 ++++++++++++++++------ homeassistant/helpers/script.py | 81 ++++++------------ homeassistant/helpers/service.py | 4 +- tests/helpers/test_config_validation.py | 27 +++++- 5 files changed, 141 insertions(+), 89 deletions(-) diff --git a/homeassistant/const.py b/homeassistant/const.py index 4e4be408b40..66db936669b 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -35,9 +35,9 @@ CONF_ALIAS = "alias" CONF_API_KEY = "api_key" CONF_API_VERSION = "api_version" CONF_AT = "at" -CONF_AUTHENTICATION = "authentication" CONF_AUTH_MFA_MODULES = "auth_mfa_modules" CONF_AUTH_PROVIDERS = "auth_providers" +CONF_AUTHENTICATION = "authentication" CONF_BASE = "base" CONF_BEFORE = "before" CONF_BELOW = "below" @@ -57,11 +57,13 @@ CONF_COMMAND_OPEN = "command_open" CONF_COMMAND_STATE = "command_state" CONF_COMMAND_STOP = "command_stop" CONF_CONDITION = "condition" +CONF_CONTINUE_ON_TIMEOUT = "continue_on_timeout" CONF_COVERS = "covers" CONF_CURRENCY = "currency" CONF_CUSTOMIZE = "customize" CONF_CUSTOMIZE_DOMAIN = "customize_domain" CONF_CUSTOMIZE_GLOB = "customize_glob" +CONF_DELAY = "delay" CONF_DELAY_TIME = "delay_time" CONF_DEVICE = "device" CONF_DEVICE_CLASS = "device_class" @@ -82,6 +84,8 @@ CONF_ENTITY_ID = "entity_id" CONF_ENTITY_NAMESPACE = "entity_namespace" CONF_ENTITY_PICTURE_TEMPLATE = "entity_picture_template" CONF_EVENT = "event" +CONF_EVENT_DATA = "event_data" +CONF_EVENT_DATA_TEMPLATE = "event_data_template" CONF_EXCLUDE = "exclude" CONF_FILE_PATH = "file_path" CONF_FILENAME = "filename" @@ -95,15 +99,15 @@ CONF_HOSTS = "hosts" CONF_HS = "hs" CONF_ICON = "icon" CONF_ICON_TEMPLATE = "icon_template" -CONF_INCLUDE = "include" CONF_ID = "id" +CONF_INCLUDE = "include" CONF_IP_ADDRESS = "ip_address" CONF_LATITUDE = "latitude" -CONF_LONGITUDE = "longitude" CONF_LIGHTS = "lights" +CONF_LONGITUDE = "longitude" CONF_MAC = "mac" -CONF_METHOD = "method" CONF_MAXIMUM = "maximum" +CONF_METHOD = "method" CONF_MINIMUM = "minimum" CONF_MODE = "mode" CONF_MONITORED_CONDITIONS = "monitored_conditions" @@ -130,14 +134,18 @@ CONF_RADIUS = "radius" CONF_RECIPIENT = "recipient" CONF_REGION = "region" CONF_RESOURCE = "resource" -CONF_RESOURCES = "resources" CONF_RESOURCE_TEMPLATE = "resource_template" +CONF_RESOURCES = "resources" CONF_RGB = "rgb" CONF_ROOM = "room" CONF_SCAN_INTERVAL = "scan_interval" +CONF_SCENE = "scene" CONF_SENDER = "sender" CONF_SENSOR_TYPE = "sensor_type" CONF_SENSORS = "sensors" +CONF_SERVICE = "service" +CONF_SERVICE_DATA = "data" +CONF_SERVICE_TEMPLATE = "service_template" CONF_SHOW_ON_MAP = "show_on_map" CONF_SLAVE = "slave" CONF_SOURCE = "source" @@ -159,11 +167,12 @@ CONF_URL = "url" CONF_USERNAME = "username" CONF_VALUE_TEMPLATE = "value_template" CONF_VERIFY_SSL = "verify_ssl" +CONF_WAIT_TEMPLATE = "wait_template" CONF_WEBHOOK_ID = "webhook_id" CONF_WEEKDAY = "weekday" +CONF_WHITE_VALUE = "white_value" CONF_WHITELIST = "whitelist" CONF_WHITELIST_EXTERNAL_DIRS = "whitelist_external_dirs" -CONF_WHITE_VALUE = "white_value" CONF_XY = "xy" CONF_ZONE = "zone" diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 565cac4058c..db966d93412 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -39,18 +39,27 @@ from homeassistant.const import ( CONF_ALIAS, CONF_BELOW, CONF_CONDITION, + CONF_CONTINUE_ON_TIMEOUT, + CONF_DELAY, CONF_DEVICE_ID, CONF_DOMAIN, CONF_ENTITY_ID, CONF_ENTITY_NAMESPACE, + CONF_EVENT, + CONF_EVENT_DATA, + CONF_EVENT_DATA_TEMPLATE, CONF_FOR, CONF_PLATFORM, CONF_SCAN_INTERVAL, + CONF_SCENE, + CONF_SERVICE, + CONF_SERVICE_TEMPLATE, CONF_STATE, CONF_TIMEOUT, CONF_UNIT_SYSTEM_IMPERIAL, CONF_UNIT_SYSTEM_METRIC, CONF_VALUE_TEMPLATE, + CONF_WAIT_TEMPLATE, ENTITY_MATCH_ALL, ENTITY_MATCH_NONE, SUN_EVENT_SUNRISE, @@ -722,7 +731,7 @@ def key_value_schemas( if key_value not in value_schemas: raise vol.Invalid( - f"Unexpected key {key_value}. Expected {', '.join(value_schemas)}" + f"Unexpected value for {key}: '{key_value}'. Expected {', '.join(value_schemas)}" ) return cast(Dict[str, Any], value_schemas[key_value](value)) @@ -800,9 +809,9 @@ def make_entity_service_schema( EVENT_SCHEMA = vol.Schema( { vol.Optional(CONF_ALIAS): string, - vol.Required("event"): string, - vol.Optional("event_data"): dict, - vol.Optional("event_data_template"): {match_all: template_complex}, + vol.Required(CONF_EVENT): string, + vol.Optional(CONF_EVENT_DATA): dict, + vol.Optional(CONF_EVENT_DATA_TEMPLATE): {match_all: template_complex}, } ) @@ -810,14 +819,14 @@ SERVICE_SCHEMA = vol.All( vol.Schema( { vol.Optional(CONF_ALIAS): string, - vol.Exclusive("service", "service name"): service, - vol.Exclusive("service_template", "service name"): template, + vol.Exclusive(CONF_SERVICE, "service name"): service, + vol.Exclusive(CONF_SERVICE_TEMPLATE, "service name"): template, vol.Optional("data"): dict, vol.Optional("data_template"): {match_all: template_complex}, vol.Optional(CONF_ENTITY_ID): comp_entity_ids, } ), - has_at_least_one_key("service", "service_template"), + has_at_least_one_key(CONF_SERVICE, CONF_SERVICE_TEMPLATE), ) NUMERIC_STATE_CONDITION_SCHEMA = vol.All( @@ -943,7 +952,7 @@ CONDITION_SCHEMA: vol.Schema = key_value_schemas( _SCRIPT_DELAY_SCHEMA = vol.Schema( { vol.Optional(CONF_ALIAS): string, - vol.Required("delay"): vol.Any( + vol.Required(CONF_DELAY): vol.Any( vol.All(time_period, positive_timedelta), template, template_complex ), } @@ -952,9 +961,9 @@ _SCRIPT_DELAY_SCHEMA = vol.Schema( _SCRIPT_WAIT_TEMPLATE_SCHEMA = vol.Schema( { vol.Optional(CONF_ALIAS): string, - vol.Required("wait_template"): template, + vol.Required(CONF_WAIT_TEMPLATE): template, vol.Optional(CONF_TIMEOUT): vol.All(time_period, positive_timedelta), - vol.Optional("continue_on_timeout"): boolean, + vol.Optional(CONF_CONTINUE_ON_TIMEOUT): boolean, } ) @@ -964,19 +973,57 @@ DEVICE_ACTION_BASE_SCHEMA = vol.Schema( DEVICE_ACTION_SCHEMA = DEVICE_ACTION_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA) -_SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required("scene"): entity_domain("scene")}) +_SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required(CONF_SCENE): entity_domain("scene")}) -SCRIPT_SCHEMA = vol.All( - ensure_list, - [ - vol.Any( - SERVICE_SCHEMA, - _SCRIPT_DELAY_SCHEMA, - _SCRIPT_WAIT_TEMPLATE_SCHEMA, - EVENT_SCHEMA, - CONDITION_SCHEMA, - DEVICE_ACTION_SCHEMA, - _SCRIPT_SCENE_SCHEMA, - ) - ], -) +SCRIPT_ACTION_DELAY = "delay" +SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template" +SCRIPT_ACTION_CHECK_CONDITION = "condition" +SCRIPT_ACTION_FIRE_EVENT = "event" +SCRIPT_ACTION_CALL_SERVICE = "call_service" +SCRIPT_ACTION_DEVICE_AUTOMATION = "device" +SCRIPT_ACTION_ACTIVATE_SCENE = "scene" + + +def determine_script_action(action: dict) -> str: + """Determine action type.""" + if CONF_DELAY in action: + return SCRIPT_ACTION_DELAY + + if CONF_WAIT_TEMPLATE in action: + return SCRIPT_ACTION_WAIT_TEMPLATE + + if CONF_CONDITION in action: + return SCRIPT_ACTION_CHECK_CONDITION + + if CONF_EVENT in action: + return SCRIPT_ACTION_FIRE_EVENT + + if CONF_DEVICE_ID in action: + return SCRIPT_ACTION_DEVICE_AUTOMATION + + if CONF_SCENE in action: + return SCRIPT_ACTION_ACTIVATE_SCENE + + return SCRIPT_ACTION_CALL_SERVICE + + +ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = { + SCRIPT_ACTION_CALL_SERVICE: SERVICE_SCHEMA, + SCRIPT_ACTION_DELAY: _SCRIPT_DELAY_SCHEMA, + SCRIPT_ACTION_WAIT_TEMPLATE: _SCRIPT_WAIT_TEMPLATE_SCHEMA, + SCRIPT_ACTION_FIRE_EVENT: EVENT_SCHEMA, + SCRIPT_ACTION_CHECK_CONDITION: CONDITION_SCHEMA, + SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA, + SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA, +} + + +def script_action(value: Any) -> dict: + """Validate a script action.""" + if not isinstance(value, dict): + raise vol.Invalid("expected dictionary") + + return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value) + + +SCRIPT_SCHEMA = vol.All(ensure_list, [script_action]) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 1ce9d2b87bb..937a675aada 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -15,9 +15,16 @@ import homeassistant.components.scene as scene from homeassistant.const import ( ATTR_ENTITY_ID, CONF_CONDITION, + CONF_CONTINUE_ON_TIMEOUT, + CONF_DELAY, CONF_DEVICE_ID, CONF_DOMAIN, + CONF_EVENT, + CONF_EVENT_DATA, + CONF_EVENT_DATA_TEMPLATE, + CONF_SCENE, CONF_TIMEOUT, + CONF_WAIT_TEMPLATE, SERVICE_TURN_ON, ) from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback @@ -37,24 +44,6 @@ from homeassistant.util.dt import utcnow # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs CONF_ALIAS = "alias" -CONF_SERVICE = "service" -CONF_SERVICE_DATA = "data" -CONF_SEQUENCE = "sequence" -CONF_EVENT = "event" -CONF_EVENT_DATA = "event_data" -CONF_EVENT_DATA_TEMPLATE = "event_data_template" -CONF_DELAY = "delay" -CONF_WAIT_TEMPLATE = "wait_template" -CONF_CONTINUE = "continue_on_timeout" -CONF_SCENE = "scene" - -ACTION_DELAY = "delay" -ACTION_WAIT_TEMPLATE = "wait_template" -ACTION_CHECK_CONDITION = "condition" -ACTION_FIRE_EVENT = "event" -ACTION_CALL_SERVICE = "call_service" -ACTION_DEVICE_AUTOMATION = "device" -ACTION_ACTIVATE_SCENE = "scene" IF_RUNNING_ERROR = "error" IF_RUNNING_IGNORE = "ignore" @@ -82,41 +71,21 @@ _LOG_EXCEPTION = logging.ERROR + 1 _TIMEOUT_MSG = "Timeout reached, abort script." -def _determine_action(action): - """Determine action type.""" - if CONF_DELAY in action: - return ACTION_DELAY - - if CONF_WAIT_TEMPLATE in action: - return ACTION_WAIT_TEMPLATE - - if CONF_CONDITION in action: - return ACTION_CHECK_CONDITION - - if CONF_EVENT in action: - return ACTION_FIRE_EVENT - - if CONF_DEVICE_ID in action: - return ACTION_DEVICE_AUTOMATION - - if CONF_SCENE in action: - return ACTION_ACTIVATE_SCENE - - return ACTION_CALL_SERVICE - - async def async_validate_action_config( hass: HomeAssistant, config: ConfigType ) -> ConfigType: """Validate config.""" - action_type = _determine_action(config) + action_type = cv.determine_script_action(config) - if action_type == ACTION_DEVICE_AUTOMATION: + if action_type == cv.SCRIPT_ACTION_DEVICE_AUTOMATION: platform = await device_automation.async_get_device_automation_platform( hass, config[CONF_DOMAIN], "action" ) config = platform.ACTION_SCHEMA(config) # type: ignore - if action_type == ACTION_CHECK_CONDITION and config[CONF_CONDITION] == "device": + if ( + action_type == cv.SCRIPT_ACTION_CHECK_CONDITION + and config[CONF_CONDITION] == "device" + ): platform = await device_automation.async_get_device_automation_platform( hass, config[CONF_DOMAIN], "condition" ) @@ -165,7 +134,9 @@ class _ScriptRunBase(ABC): async def _async_step(self, log_exceptions): try: - await getattr(self, f"_async_{_determine_action(self._action)}_step")() + await getattr( + self, f"_async_{cv.determine_script_action(self._action)}_step" + )() except Exception as err: if not isinstance(err, (_SuspendScript, _StopScript)) and ( self._log_exceptions or log_exceptions @@ -178,7 +149,7 @@ class _ScriptRunBase(ABC): """Stop script run.""" def _log_exception(self, exception): - action_type = _determine_action(self._action) + action_type = cv.determine_script_action(self._action) error = str(exception) level = logging.ERROR @@ -406,7 +377,7 @@ class _ScriptRun(_ScriptRunBase): timeout, ) except asyncio.TimeoutError: - if not self._action.get(CONF_CONTINUE, True): + if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): self._log(_TIMEOUT_MSG) raise _StopScript finally: @@ -547,7 +518,7 @@ class _LegacyScriptRun(_ScriptRunBase): # Check if we want to continue to execute # the script after the timeout - if self._action.get(CONF_CONTINUE, True): + if self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): self._hass.async_create_task(self._async_run(False)) else: self._log(_TIMEOUT_MSG) @@ -632,12 +603,12 @@ class Script: referenced = set() for step in self.sequence: - action = _determine_action(step) + action = cv.determine_script_action(step) - if action == ACTION_CHECK_CONDITION: + if action == cv.SCRIPT_ACTION_CHECK_CONDITION: referenced |= condition.async_extract_devices(step) - elif action == ACTION_DEVICE_AUTOMATION: + elif action == cv.SCRIPT_ACTION_DEVICE_AUTOMATION: referenced.add(step[CONF_DEVICE_ID]) self._referenced_devices = referenced @@ -652,9 +623,9 @@ class Script: referenced = set() for step in self.sequence: - action = _determine_action(step) + action = cv.determine_script_action(step) - if action == ACTION_CALL_SERVICE: + if action == cv.SCRIPT_ACTION_CALL_SERVICE: data = step.get(service.CONF_SERVICE_DATA) if not data: continue @@ -670,10 +641,10 @@ class Script: for entity_id in entity_ids: referenced.add(entity_id) - elif action == ACTION_CHECK_CONDITION: + elif action == cv.SCRIPT_ACTION_CHECK_CONDITION: referenced |= condition.async_extract_entities(step) - elif action == ACTION_ACTIVATE_SCENE: + elif action == cv.SCRIPT_ACTION_ACTIVATE_SCENE: referenced.add(step[CONF_SCENE]) self._referenced_entities = referenced diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 9085c929651..578d5368314 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -10,6 +10,8 @@ from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL from homeassistant.const import ( ATTR_AREA_ID, ATTR_ENTITY_ID, + CONF_SERVICE, + CONF_SERVICE_TEMPLATE, ENTITY_MATCH_ALL, ENTITY_MATCH_NONE, ) @@ -29,8 +31,6 @@ from homeassistant.util.yaml.loader import JSON_TYPE # mypy: allow-untyped-defs, no-check-untyped-defs -CONF_SERVICE = "service" -CONF_SERVICE_TEMPLATE = "service_template" CONF_SERVICE_ENTITY_ID = "entity_id" CONF_SERVICE_DATA = "data" CONF_SERVICE_DATA_TEMPLATE = "data_template" diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py index 71d845ac637..ff269d2b8c6 100644 --- a/tests/helpers/test_config_validation.py +++ b/tests/helpers/test_config_validation.py @@ -1008,7 +1008,10 @@ def test_key_value_schemas(): for mode in None, "invalid": with pytest.raises(vol.Invalid) as excinfo: schema({"mode": mode}) - assert str(excinfo.value) == f"Unexpected key {mode}. Expected number, string" + assert ( + str(excinfo.value) + == f"Unexpected value for mode: '{mode}'. Expected number, string" + ) with pytest.raises(vol.Invalid) as excinfo: schema({"mode": "number", "data": "string-value"}) @@ -1020,3 +1023,25 @@ def test_key_value_schemas(): for mode, data in (("number", 1), ("string", "hello")): schema({"mode": mode, "data": data}) + + +def test_script(caplog): + """Test script validation is user friendly.""" + for data, msg in ( + ({"delay": "{{ invalid"}, "should be format 'HH:MM'"), + ({"wait_template": "{{ invalid"}, "invalid template"), + ({"condition": "invalid"}, "Unexpected value for condition: 'invalid'"), + ({"event": None}, "string value is None for dictionary value @ data['event']"), + ( + {"device_id": None}, + "string value is None for dictionary value @ data['device_id']", + ), + ( + {"scene": "light.kitchen"}, + "Entity ID 'light.kitchen' does not belong to domain 'scene'", + ), + ): + with pytest.raises(vol.Invalid) as excinfo: + cv.script_action(data) + + assert msg in str(excinfo.value)