Add default variables to script helper (#39895)

pull/39914/head
Paulus Schoutsen 2020-09-10 20:41:42 +02:00 committed by GitHub
parent b5005430be
commit aa9dff572e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 190 additions and 15 deletions

View File

@ -13,6 +13,7 @@ from homeassistant.const import (
CONF_ID,
CONF_MODE,
CONF_PLATFORM,
CONF_VARIABLES,
CONF_ZONE,
EVENT_HOMEASSISTANT_STARTED,
SERVICE_RELOAD,
@ -29,7 +30,7 @@ from homeassistant.core import (
split_entity_id,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import condition, extract_domain_configs
from homeassistant.helpers import condition, extract_domain_configs, template
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent
@ -104,6 +105,7 @@ PLATFORM_SCHEMA = vol.All(
vol.Optional(CONF_HIDE_ENTITY): cv.boolean,
vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA,
vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA,
vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA,
},
SCRIPT_MODE_SINGLE,
@ -239,6 +241,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
cond_func,
action_script,
initial_state,
variables,
):
"""Initialize an automation entity."""
self._id = automation_id
@ -253,6 +256,8 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._referenced_entities: Optional[Set[str]] = None
self._referenced_devices: Optional[Set[str]] = None
self._logger = _LOGGER
self._variables = variables
self._variables_dynamic = template.is_complex(variables)
@property
def name(self):
@ -329,6 +334,9 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
"""Startup with initial state or previous state."""
await super().async_added_to_hass()
if self._variables_dynamic:
template.attach(cast(HomeAssistant, self.hass), self._variables)
self._logger = logging.getLogger(
f"{__name__}.{split_entity_id(self.entity_id)[1]}"
)
@ -378,11 +386,22 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
else:
await self.async_disable()
async def async_trigger(self, variables, context=None, skip_condition=False):
async def async_trigger(self, run_variables, context=None, skip_condition=False):
"""Trigger automation.
This method is a coroutine.
"""
if self._variables:
if self._variables_dynamic:
variables = template.render_complex(self._variables, run_variables)
else:
variables = dict(self._variables)
else:
variables = {}
if run_variables:
variables.update(run_variables)
if (
not skip_condition
and self._cond_func is not None
@ -518,6 +537,9 @@ async def _async_process_config(hass, config, component):
max_runs=config_block[CONF_MAX],
max_exceeded=config_block[CONF_MAX_EXCEEDED],
logger=_LOGGER,
# We don't pass variables here
# Automation will already render them to use them in the condition
# and so will pass them on to the script.
)
if CONF_CONDITION in config_block:
@ -535,6 +557,7 @@ async def _async_process_config(hass, config, component):
cond_func,
action_script,
initial_state,
config_block.get(CONF_VARIABLES),
)
entities.append(entity)

View File

@ -12,6 +12,7 @@ from homeassistant.const import (
CONF_ICON,
CONF_MODE,
CONF_SEQUENCE,
CONF_VARIABLES,
SERVICE_RELOAD,
SERVICE_TOGGLE,
SERVICE_TURN_OFF,
@ -59,6 +60,7 @@ SCRIPT_ENTRY_SCHEMA = make_script_schema(
vol.Optional(CONF_ICON): cv.icon,
vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA,
vol.Optional(CONF_DESCRIPTION, default=""): cv.string,
vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
vol.Optional(CONF_FIELDS, default={}): {
cv.string: {
vol.Optional(CONF_DESCRIPTION): cv.string,
@ -75,7 +77,7 @@ CONFIG_SCHEMA = vol.Schema(
SCRIPT_SERVICE_SCHEMA = vol.Schema(dict)
SCRIPT_TURN_ONOFF_SCHEMA = make_entity_service_schema(
{vol.Optional(ATTR_VARIABLES): dict}
{vol.Optional(ATTR_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA}
)
RELOAD_SERVICE_SCHEMA = vol.Schema({})
@ -263,6 +265,7 @@ class ScriptEntity(ToggleEntity):
max_runs=cfg[CONF_MAX],
max_exceeded=cfg[CONF_MAX_EXCEEDED],
logger=logging.getLogger(f"{__name__}.{object_id}"),
variables=cfg.get(CONF_VARIABLES),
)
self._changed = asyncio.Event()

View File

@ -179,6 +179,7 @@ CONF_UNTIL = "until"
CONF_URL = "url"
CONF_USERNAME = "username"
CONF_VALUE_TEMPLATE = "value_template"
CONF_VARIABLES = "variables"
CONF_VERIFY_SSL = "verify_ssl"
CONF_WAIT_FOR_TRIGGER = "wait_for_trigger"
CONF_WAIT_TEMPLATE = "wait_template"

View File

@ -863,6 +863,9 @@ def make_entity_service_schema(
)
SCRIPT_VARIABLES_SCHEMA = vol.Schema({str: template_complex})
def script_action(value: Any) -> dict:
"""Validate a script action."""
if not isinstance(value, dict):

View File

@ -53,11 +53,7 @@ from homeassistant.const import (
SERVICE_TURN_ON,
)
from homeassistant.core import SERVICE_CALL_LIMIT, Context, HomeAssistant, callback
from homeassistant.helpers import (
condition,
config_validation as cv,
template as template,
)
from homeassistant.helpers import condition, config_validation as cv, template
from homeassistant.helpers.event import async_call_later, async_track_template
from homeassistant.helpers.service import (
CONF_SERVICE_DATA,
@ -721,6 +717,7 @@ class Script:
logger: Optional[logging.Logger] = None,
log_exceptions: bool = True,
top_level: bool = True,
variables: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the script."""
all_scripts = hass.data.get(DATA_SCRIPTS)
@ -759,6 +756,10 @@ class Script:
self._choose_data: Dict[int, Dict[str, Any]] = {}
self._referenced_entities: Optional[Set[str]] = None
self._referenced_devices: Optional[Set[str]] = None
self.variables = variables
self._variables_dynamic = template.is_complex(variables)
if self._variables_dynamic:
template.attach(hass, variables)
def _set_logger(self, logger: Optional[logging.Logger] = None) -> None:
if logger:
@ -867,7 +868,7 @@ class Script:
async def async_run(
self,
variables: Optional[_VarsType] = None,
run_variables: Optional[_VarsType] = None,
context: Optional[Context] = None,
started_action: Optional[Callable[..., Any]] = None,
) -> None:
@ -898,8 +899,19 @@ class Script:
# are read-only, but more importantly, so as not to leak any variables created
# during the run back to the caller.
if self._top_level:
variables = dict(variables) if variables is not None else {}
if self.variables:
if self._variables_dynamic:
variables = template.render_complex(self.variables, run_variables)
else:
variables = dict(self.variables)
else:
variables = {}
if run_variables:
variables.update(run_variables)
variables["context"] = context
else:
variables = cast(dict, run_variables)
if self.script_mode != SCRIPT_MODE_QUEUED:
cls = _ScriptRun

View File

@ -65,7 +65,7 @@ def attach(hass: HomeAssistantType, obj: Any) -> None:
if isinstance(obj, list):
for child in obj:
attach(hass, child)
elif isinstance(obj, dict):
elif isinstance(obj, collections.abc.Mapping):
for child_key, child_value in obj.items():
attach(hass, child_key)
attach(hass, child_value)
@ -77,7 +77,7 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
"""Recursive template creator helper function."""
if isinstance(value, list):
return [render_complex(item, variables) for item in value]
if isinstance(value, dict):
if isinstance(value, collections.abc.Mapping):
return {
render_complex(key, variables): render_complex(item, variables)
for key, item in value.items()
@ -88,6 +88,19 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
return value
def is_complex(value: Any) -> bool:
"""Test if data structure is a complex template."""
if isinstance(value, Template):
return True
if isinstance(value, list):
return any(is_complex(val) for val in value)
if isinstance(value, collections.abc.Mapping):
return any(is_complex(val) for val in value.keys()) or any(
is_complex(val) for val in value.values()
)
return False
def is_template_string(maybe_template: str) -> bool:
"""Check if the input is a Jinja2 template."""
return _RE_JINJA_DELIMITERS.search(maybe_template) is not None

View File

@ -1,5 +1,5 @@
"""Typing Helpers for Home Assistant."""
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Mapping, Optional, Tuple, Union
import homeassistant.core
@ -12,7 +12,7 @@ HomeAssistantType = homeassistant.core.HomeAssistant
ServiceCallType = homeassistant.core.ServiceCall
ServiceDataType = Dict[str, Any]
StateType = Union[None, str, int, float]
TemplateVarsType = Optional[Dict[str, Any]]
TemplateVarsType = Optional[Mapping[str, Any]]
# Custom type for recorder Queries
QueryType = Any

View File

@ -1134,3 +1134,57 @@ async def test_logbook_humanify_automation_triggered_event(hass):
assert event2["domain"] == "automation"
assert event2["message"] == "has been triggered by source of trigger"
assert event2["entity_id"] == "automation.bye"
async def test_automation_variables(hass):
"""Test automation variables."""
calls = async_mock_service(hass, "test", "automation")
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"variables": {
"test_var": "defined_in_config",
"event_type": "{{ trigger.event.event_type }}",
},
"trigger": {"platform": "event", "event_type": "test_event"},
"action": {
"service": "test.automation",
"data": {
"value": "{{ test_var }}",
"event_type": "{{ event_type }}",
},
},
},
{
"variables": {
"test_var": "defined_in_config",
},
"trigger": {"platform": "event", "event_type": "test_event_2"},
"condition": {
"condition": "template",
"value_template": "{{ trigger.event.data.pass_condition }}",
},
"action": {
"service": "test.automation",
},
},
]
},
)
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].data["value"] == "defined_in_config"
assert calls[0].data["event_type"] == "test_event"
hass.bus.async_fire("test_event_2")
await hass.async_block_till_done()
assert len(calls) == 1
hass.bus.async_fire("test_event_2", {"pass_condition": True})
await hass.async_block_till_done()
assert len(calls) == 2

View File

@ -23,7 +23,7 @@ from homeassistant.loader import bind_hass
from homeassistant.setup import async_setup_component, setup_component
from tests.async_mock import Mock, patch
from tests.common import get_test_home_assistant
from tests.common import async_mock_service, get_test_home_assistant
from tests.components.logbook.test_init import MockLazyEventPartialState
ENTITY_ID = "script.test"
@ -615,3 +615,69 @@ async def test_concurrent_script(hass, concurrently):
assert not script.is_on(hass, "script.script1")
assert not script.is_on(hass, "script.script2")
async def test_script_variables(hass):
"""Test defining scripts."""
assert await async_setup_component(
hass,
"script",
{
"script": {
"script1": {
"variables": {
"test_var": "from_config",
"templated_config_var": "{{ var_from_service | default('config-default') }}",
},
"sequence": [
{
"service": "test.script",
"data": {
"value": "{{ test_var }}",
"templated_config_var": "{{ templated_config_var }}",
},
},
],
},
"script2": {
"variables": {
"test_var": "from_config",
},
"sequence": [
{
"service": "test.script",
"data": {
"value": "{{ test_var }}",
},
},
],
},
}
},
)
mock_calls = async_mock_service(hass, "test", "script")
await hass.services.async_call(
"script", "script1", {"var_from_service": "hello"}, blocking=True
)
assert len(mock_calls) == 1
assert mock_calls[0].data["value"] == "from_config"
assert mock_calls[0].data["templated_config_var"] == "hello"
await hass.services.async_call(
"script", "script1", {"test_var": "from_service"}, blocking=True
)
assert len(mock_calls) == 2
assert mock_calls[1].data["value"] == "from_service"
assert mock_calls[1].data["templated_config_var"] == "config-default"
# Call script with vars but no templates in it
await hass.services.async_call(
"script", "script2", {"test_var": "from_service"}, blocking=True
)
assert len(mock_calls) == 3
assert mock_calls[2].data["value"] == "from_service"