Improve script validation (#32461)

pull/32508/head
Paulus Schoutsen 2020-03-05 11:44:42 -08:00 committed by GitHub
parent da7c5518f3
commit 6a21afa2a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 141 additions and 89 deletions

View File

@ -35,9 +35,9 @@ CONF_ALIAS = "alias"
CONF_API_KEY = "api_key"
CONF_API_VERSION = "api_version"
CONF_AT = "at"
CONF_AUTHENTICATION = "authentication"
CONF_AUTH_MFA_MODULES = "auth_mfa_modules"
CONF_AUTH_PROVIDERS = "auth_providers"
CONF_AUTHENTICATION = "authentication"
CONF_BASE = "base"
CONF_BEFORE = "before"
CONF_BELOW = "below"
@ -57,11 +57,13 @@ CONF_COMMAND_OPEN = "command_open"
CONF_COMMAND_STATE = "command_state"
CONF_COMMAND_STOP = "command_stop"
CONF_CONDITION = "condition"
CONF_CONTINUE_ON_TIMEOUT = "continue_on_timeout"
CONF_COVERS = "covers"
CONF_CURRENCY = "currency"
CONF_CUSTOMIZE = "customize"
CONF_CUSTOMIZE_DOMAIN = "customize_domain"
CONF_CUSTOMIZE_GLOB = "customize_glob"
CONF_DELAY = "delay"
CONF_DELAY_TIME = "delay_time"
CONF_DEVICE = "device"
CONF_DEVICE_CLASS = "device_class"
@ -82,6 +84,8 @@ CONF_ENTITY_ID = "entity_id"
CONF_ENTITY_NAMESPACE = "entity_namespace"
CONF_ENTITY_PICTURE_TEMPLATE = "entity_picture_template"
CONF_EVENT = "event"
CONF_EVENT_DATA = "event_data"
CONF_EVENT_DATA_TEMPLATE = "event_data_template"
CONF_EXCLUDE = "exclude"
CONF_FILE_PATH = "file_path"
CONF_FILENAME = "filename"
@ -95,15 +99,15 @@ CONF_HOSTS = "hosts"
CONF_HS = "hs"
CONF_ICON = "icon"
CONF_ICON_TEMPLATE = "icon_template"
CONF_INCLUDE = "include"
CONF_ID = "id"
CONF_INCLUDE = "include"
CONF_IP_ADDRESS = "ip_address"
CONF_LATITUDE = "latitude"
CONF_LONGITUDE = "longitude"
CONF_LIGHTS = "lights"
CONF_LONGITUDE = "longitude"
CONF_MAC = "mac"
CONF_METHOD = "method"
CONF_MAXIMUM = "maximum"
CONF_METHOD = "method"
CONF_MINIMUM = "minimum"
CONF_MODE = "mode"
CONF_MONITORED_CONDITIONS = "monitored_conditions"
@ -130,14 +134,18 @@ CONF_RADIUS = "radius"
CONF_RECIPIENT = "recipient"
CONF_REGION = "region"
CONF_RESOURCE = "resource"
CONF_RESOURCES = "resources"
CONF_RESOURCE_TEMPLATE = "resource_template"
CONF_RESOURCES = "resources"
CONF_RGB = "rgb"
CONF_ROOM = "room"
CONF_SCAN_INTERVAL = "scan_interval"
CONF_SCENE = "scene"
CONF_SENDER = "sender"
CONF_SENSOR_TYPE = "sensor_type"
CONF_SENSORS = "sensors"
CONF_SERVICE = "service"
CONF_SERVICE_DATA = "data"
CONF_SERVICE_TEMPLATE = "service_template"
CONF_SHOW_ON_MAP = "show_on_map"
CONF_SLAVE = "slave"
CONF_SOURCE = "source"
@ -159,11 +167,12 @@ CONF_URL = "url"
CONF_USERNAME = "username"
CONF_VALUE_TEMPLATE = "value_template"
CONF_VERIFY_SSL = "verify_ssl"
CONF_WAIT_TEMPLATE = "wait_template"
CONF_WEBHOOK_ID = "webhook_id"
CONF_WEEKDAY = "weekday"
CONF_WHITE_VALUE = "white_value"
CONF_WHITELIST = "whitelist"
CONF_WHITELIST_EXTERNAL_DIRS = "whitelist_external_dirs"
CONF_WHITE_VALUE = "white_value"
CONF_XY = "xy"
CONF_ZONE = "zone"

View File

@ -39,18 +39,27 @@ from homeassistant.const import (
CONF_ALIAS,
CONF_BELOW,
CONF_CONDITION,
CONF_CONTINUE_ON_TIMEOUT,
CONF_DELAY,
CONF_DEVICE_ID,
CONF_DOMAIN,
CONF_ENTITY_ID,
CONF_ENTITY_NAMESPACE,
CONF_EVENT,
CONF_EVENT_DATA,
CONF_EVENT_DATA_TEMPLATE,
CONF_FOR,
CONF_PLATFORM,
CONF_SCAN_INTERVAL,
CONF_SCENE,
CONF_SERVICE,
CONF_SERVICE_TEMPLATE,
CONF_STATE,
CONF_TIMEOUT,
CONF_UNIT_SYSTEM_IMPERIAL,
CONF_UNIT_SYSTEM_METRIC,
CONF_VALUE_TEMPLATE,
CONF_WAIT_TEMPLATE,
ENTITY_MATCH_ALL,
ENTITY_MATCH_NONE,
SUN_EVENT_SUNRISE,
@ -722,7 +731,7 @@ def key_value_schemas(
if key_value not in value_schemas:
raise vol.Invalid(
f"Unexpected key {key_value}. Expected {', '.join(value_schemas)}"
f"Unexpected value for {key}: '{key_value}'. Expected {', '.join(value_schemas)}"
)
return cast(Dict[str, Any], value_schemas[key_value](value))
@ -800,9 +809,9 @@ def make_entity_service_schema(
EVENT_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
vol.Required("event"): string,
vol.Optional("event_data"): dict,
vol.Optional("event_data_template"): {match_all: template_complex},
vol.Required(CONF_EVENT): string,
vol.Optional(CONF_EVENT_DATA): dict,
vol.Optional(CONF_EVENT_DATA_TEMPLATE): {match_all: template_complex},
}
)
@ -810,14 +819,14 @@ SERVICE_SCHEMA = vol.All(
vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
vol.Exclusive("service", "service name"): service,
vol.Exclusive("service_template", "service name"): template,
vol.Exclusive(CONF_SERVICE, "service name"): service,
vol.Exclusive(CONF_SERVICE_TEMPLATE, "service name"): template,
vol.Optional("data"): dict,
vol.Optional("data_template"): {match_all: template_complex},
vol.Optional(CONF_ENTITY_ID): comp_entity_ids,
}
),
has_at_least_one_key("service", "service_template"),
has_at_least_one_key(CONF_SERVICE, CONF_SERVICE_TEMPLATE),
)
NUMERIC_STATE_CONDITION_SCHEMA = vol.All(
@ -943,7 +952,7 @@ CONDITION_SCHEMA: vol.Schema = key_value_schemas(
_SCRIPT_DELAY_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
vol.Required("delay"): vol.Any(
vol.Required(CONF_DELAY): vol.Any(
vol.All(time_period, positive_timedelta), template, template_complex
),
}
@ -952,9 +961,9 @@ _SCRIPT_DELAY_SCHEMA = vol.Schema(
_SCRIPT_WAIT_TEMPLATE_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
vol.Required("wait_template"): template,
vol.Required(CONF_WAIT_TEMPLATE): template,
vol.Optional(CONF_TIMEOUT): vol.All(time_period, positive_timedelta),
vol.Optional("continue_on_timeout"): boolean,
vol.Optional(CONF_CONTINUE_ON_TIMEOUT): boolean,
}
)
@ -964,19 +973,57 @@ DEVICE_ACTION_BASE_SCHEMA = vol.Schema(
DEVICE_ACTION_SCHEMA = DEVICE_ACTION_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
_SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required("scene"): entity_domain("scene")})
_SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required(CONF_SCENE): entity_domain("scene")})
SCRIPT_SCHEMA = vol.All(
ensure_list,
[
vol.Any(
SERVICE_SCHEMA,
_SCRIPT_DELAY_SCHEMA,
_SCRIPT_WAIT_TEMPLATE_SCHEMA,
EVENT_SCHEMA,
CONDITION_SCHEMA,
DEVICE_ACTION_SCHEMA,
_SCRIPT_SCENE_SCHEMA,
)
],
)
SCRIPT_ACTION_DELAY = "delay"
SCRIPT_ACTION_WAIT_TEMPLATE = "wait_template"
SCRIPT_ACTION_CHECK_CONDITION = "condition"
SCRIPT_ACTION_FIRE_EVENT = "event"
SCRIPT_ACTION_CALL_SERVICE = "call_service"
SCRIPT_ACTION_DEVICE_AUTOMATION = "device"
SCRIPT_ACTION_ACTIVATE_SCENE = "scene"
def determine_script_action(action: dict) -> str:
"""Determine action type."""
if CONF_DELAY in action:
return SCRIPT_ACTION_DELAY
if CONF_WAIT_TEMPLATE in action:
return SCRIPT_ACTION_WAIT_TEMPLATE
if CONF_CONDITION in action:
return SCRIPT_ACTION_CHECK_CONDITION
if CONF_EVENT in action:
return SCRIPT_ACTION_FIRE_EVENT
if CONF_DEVICE_ID in action:
return SCRIPT_ACTION_DEVICE_AUTOMATION
if CONF_SCENE in action:
return SCRIPT_ACTION_ACTIVATE_SCENE
return SCRIPT_ACTION_CALL_SERVICE
ACTION_TYPE_SCHEMAS: Dict[str, Callable[[Any], dict]] = {
SCRIPT_ACTION_CALL_SERVICE: SERVICE_SCHEMA,
SCRIPT_ACTION_DELAY: _SCRIPT_DELAY_SCHEMA,
SCRIPT_ACTION_WAIT_TEMPLATE: _SCRIPT_WAIT_TEMPLATE_SCHEMA,
SCRIPT_ACTION_FIRE_EVENT: EVENT_SCHEMA,
SCRIPT_ACTION_CHECK_CONDITION: CONDITION_SCHEMA,
SCRIPT_ACTION_DEVICE_AUTOMATION: DEVICE_ACTION_SCHEMA,
SCRIPT_ACTION_ACTIVATE_SCENE: _SCRIPT_SCENE_SCHEMA,
}
def script_action(value: Any) -> dict:
"""Validate a script action."""
if not isinstance(value, dict):
raise vol.Invalid("expected dictionary")
return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value)
SCRIPT_SCHEMA = vol.All(ensure_list, [script_action])

View File

@ -15,9 +15,16 @@ import homeassistant.components.scene as scene
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_CONDITION,
CONF_CONTINUE_ON_TIMEOUT,
CONF_DELAY,
CONF_DEVICE_ID,
CONF_DOMAIN,
CONF_EVENT,
CONF_EVENT_DATA,
CONF_EVENT_DATA_TEMPLATE,
CONF_SCENE,
CONF_TIMEOUT,
CONF_WAIT_TEMPLATE,
SERVICE_TURN_ON,
)
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
@ -37,24 +44,6 @@ from homeassistant.util.dt import utcnow
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
CONF_ALIAS = "alias"
CONF_SERVICE = "service"
CONF_SERVICE_DATA = "data"
CONF_SEQUENCE = "sequence"
CONF_EVENT = "event"
CONF_EVENT_DATA = "event_data"
CONF_EVENT_DATA_TEMPLATE = "event_data_template"
CONF_DELAY = "delay"
CONF_WAIT_TEMPLATE = "wait_template"
CONF_CONTINUE = "continue_on_timeout"
CONF_SCENE = "scene"
ACTION_DELAY = "delay"
ACTION_WAIT_TEMPLATE = "wait_template"
ACTION_CHECK_CONDITION = "condition"
ACTION_FIRE_EVENT = "event"
ACTION_CALL_SERVICE = "call_service"
ACTION_DEVICE_AUTOMATION = "device"
ACTION_ACTIVATE_SCENE = "scene"
IF_RUNNING_ERROR = "error"
IF_RUNNING_IGNORE = "ignore"
@ -82,41 +71,21 @@ _LOG_EXCEPTION = logging.ERROR + 1
_TIMEOUT_MSG = "Timeout reached, abort script."
def _determine_action(action):
"""Determine action type."""
if CONF_DELAY in action:
return ACTION_DELAY
if CONF_WAIT_TEMPLATE in action:
return ACTION_WAIT_TEMPLATE
if CONF_CONDITION in action:
return ACTION_CHECK_CONDITION
if CONF_EVENT in action:
return ACTION_FIRE_EVENT
if CONF_DEVICE_ID in action:
return ACTION_DEVICE_AUTOMATION
if CONF_SCENE in action:
return ACTION_ACTIVATE_SCENE
return ACTION_CALL_SERVICE
async def async_validate_action_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
action_type = _determine_action(config)
action_type = cv.determine_script_action(config)
if action_type == ACTION_DEVICE_AUTOMATION:
if action_type == cv.SCRIPT_ACTION_DEVICE_AUTOMATION:
platform = await device_automation.async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "action"
)
config = platform.ACTION_SCHEMA(config) # type: ignore
if action_type == ACTION_CHECK_CONDITION and config[CONF_CONDITION] == "device":
if (
action_type == cv.SCRIPT_ACTION_CHECK_CONDITION
and config[CONF_CONDITION] == "device"
):
platform = await device_automation.async_get_device_automation_platform(
hass, config[CONF_DOMAIN], "condition"
)
@ -165,7 +134,9 @@ class _ScriptRunBase(ABC):
async def _async_step(self, log_exceptions):
try:
await getattr(self, f"_async_{_determine_action(self._action)}_step")()
await getattr(
self, f"_async_{cv.determine_script_action(self._action)}_step"
)()
except Exception as err:
if not isinstance(err, (_SuspendScript, _StopScript)) and (
self._log_exceptions or log_exceptions
@ -178,7 +149,7 @@ class _ScriptRunBase(ABC):
"""Stop script run."""
def _log_exception(self, exception):
action_type = _determine_action(self._action)
action_type = cv.determine_script_action(self._action)
error = str(exception)
level = logging.ERROR
@ -406,7 +377,7 @@ class _ScriptRun(_ScriptRunBase):
timeout,
)
except asyncio.TimeoutError:
if not self._action.get(CONF_CONTINUE, True):
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._log(_TIMEOUT_MSG)
raise _StopScript
finally:
@ -547,7 +518,7 @@ class _LegacyScriptRun(_ScriptRunBase):
# Check if we want to continue to execute
# the script after the timeout
if self._action.get(CONF_CONTINUE, True):
if self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._hass.async_create_task(self._async_run(False))
else:
self._log(_TIMEOUT_MSG)
@ -632,12 +603,12 @@ class Script:
referenced = set()
for step in self.sequence:
action = _determine_action(step)
action = cv.determine_script_action(step)
if action == ACTION_CHECK_CONDITION:
if action == cv.SCRIPT_ACTION_CHECK_CONDITION:
referenced |= condition.async_extract_devices(step)
elif action == ACTION_DEVICE_AUTOMATION:
elif action == cv.SCRIPT_ACTION_DEVICE_AUTOMATION:
referenced.add(step[CONF_DEVICE_ID])
self._referenced_devices = referenced
@ -652,9 +623,9 @@ class Script:
referenced = set()
for step in self.sequence:
action = _determine_action(step)
action = cv.determine_script_action(step)
if action == ACTION_CALL_SERVICE:
if action == cv.SCRIPT_ACTION_CALL_SERVICE:
data = step.get(service.CONF_SERVICE_DATA)
if not data:
continue
@ -670,10 +641,10 @@ class Script:
for entity_id in entity_ids:
referenced.add(entity_id)
elif action == ACTION_CHECK_CONDITION:
elif action == cv.SCRIPT_ACTION_CHECK_CONDITION:
referenced |= condition.async_extract_entities(step)
elif action == ACTION_ACTIVATE_SCENE:
elif action == cv.SCRIPT_ACTION_ACTIVATE_SCENE:
referenced.add(step[CONF_SCENE])
self._referenced_entities = referenced

View File

@ -10,6 +10,8 @@ from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_CONTROL
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_ENTITY_ID,
CONF_SERVICE,
CONF_SERVICE_TEMPLATE,
ENTITY_MATCH_ALL,
ENTITY_MATCH_NONE,
)
@ -29,8 +31,6 @@ from homeassistant.util.yaml.loader import JSON_TYPE
# mypy: allow-untyped-defs, no-check-untyped-defs
CONF_SERVICE = "service"
CONF_SERVICE_TEMPLATE = "service_template"
CONF_SERVICE_ENTITY_ID = "entity_id"
CONF_SERVICE_DATA = "data"
CONF_SERVICE_DATA_TEMPLATE = "data_template"

View File

@ -1008,7 +1008,10 @@ def test_key_value_schemas():
for mode in None, "invalid":
with pytest.raises(vol.Invalid) as excinfo:
schema({"mode": mode})
assert str(excinfo.value) == f"Unexpected key {mode}. Expected number, string"
assert (
str(excinfo.value)
== f"Unexpected value for mode: '{mode}'. Expected number, string"
)
with pytest.raises(vol.Invalid) as excinfo:
schema({"mode": "number", "data": "string-value"})
@ -1020,3 +1023,25 @@ def test_key_value_schemas():
for mode, data in (("number", 1), ("string", "hello")):
schema({"mode": mode, "data": data})
def test_script(caplog):
"""Test script validation is user friendly."""
for data, msg in (
({"delay": "{{ invalid"}, "should be format 'HH:MM'"),
({"wait_template": "{{ invalid"}, "invalid template"),
({"condition": "invalid"}, "Unexpected value for condition: 'invalid'"),
({"event": None}, "string value is None for dictionary value @ data['event']"),
(
{"device_id": None},
"string value is None for dictionary value @ data['device_id']",
),
(
{"scene": "light.kitchen"},
"Entity ID 'light.kitchen' does not belong to domain 'scene'",
),
):
with pytest.raises(vol.Invalid) as excinfo:
cv.script_action(data)
assert msg in str(excinfo.value)