Add default variables to script helper (#39895)
parent
b5005430be
commit
aa9dff572e
|
@ -13,6 +13,7 @@ from homeassistant.const import (
|
|||
CONF_ID,
|
||||
CONF_MODE,
|
||||
CONF_PLATFORM,
|
||||
CONF_VARIABLES,
|
||||
CONF_ZONE,
|
||||
EVENT_HOMEASSISTANT_STARTED,
|
||||
SERVICE_RELOAD,
|
||||
|
@ -29,7 +30,7 @@ from homeassistant.core import (
|
|||
split_entity_id,
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import condition, extract_domain_configs
|
||||
from homeassistant.helpers import condition, extract_domain_configs, template
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity import ToggleEntity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
@ -104,6 +105,7 @@ PLATFORM_SCHEMA = vol.All(
|
|||
vol.Optional(CONF_HIDE_ENTITY): cv.boolean,
|
||||
vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA,
|
||||
vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA,
|
||||
vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
|
||||
vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA,
|
||||
},
|
||||
SCRIPT_MODE_SINGLE,
|
||||
|
@ -239,6 +241,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
cond_func,
|
||||
action_script,
|
||||
initial_state,
|
||||
variables,
|
||||
):
|
||||
"""Initialize an automation entity."""
|
||||
self._id = automation_id
|
||||
|
@ -253,6 +256,8 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
self._referenced_entities: Optional[Set[str]] = None
|
||||
self._referenced_devices: Optional[Set[str]] = None
|
||||
self._logger = _LOGGER
|
||||
self._variables = variables
|
||||
self._variables_dynamic = template.is_complex(variables)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
|
@ -329,6 +334,9 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
"""Startup with initial state or previous state."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
if self._variables_dynamic:
|
||||
template.attach(cast(HomeAssistant, self.hass), self._variables)
|
||||
|
||||
self._logger = logging.getLogger(
|
||||
f"{__name__}.{split_entity_id(self.entity_id)[1]}"
|
||||
)
|
||||
|
@ -378,11 +386,22 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
else:
|
||||
await self.async_disable()
|
||||
|
||||
async def async_trigger(self, variables, context=None, skip_condition=False):
|
||||
async def async_trigger(self, run_variables, context=None, skip_condition=False):
|
||||
"""Trigger automation.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
if self._variables:
|
||||
if self._variables_dynamic:
|
||||
variables = template.render_complex(self._variables, run_variables)
|
||||
else:
|
||||
variables = dict(self._variables)
|
||||
else:
|
||||
variables = {}
|
||||
|
||||
if run_variables:
|
||||
variables.update(run_variables)
|
||||
|
||||
if (
|
||||
not skip_condition
|
||||
and self._cond_func is not None
|
||||
|
@ -518,6 +537,9 @@ async def _async_process_config(hass, config, component):
|
|||
max_runs=config_block[CONF_MAX],
|
||||
max_exceeded=config_block[CONF_MAX_EXCEEDED],
|
||||
logger=_LOGGER,
|
||||
# We don't pass variables here
|
||||
# Automation will already render them to use them in the condition
|
||||
# and so will pass them on to the script.
|
||||
)
|
||||
|
||||
if CONF_CONDITION in config_block:
|
||||
|
@ -535,6 +557,7 @@ async def _async_process_config(hass, config, component):
|
|||
cond_func,
|
||||
action_script,
|
||||
initial_state,
|
||||
config_block.get(CONF_VARIABLES),
|
||||
)
|
||||
|
||||
entities.append(entity)
|
||||
|
|
|
@ -12,6 +12,7 @@ from homeassistant.const import (
|
|||
CONF_ICON,
|
||||
CONF_MODE,
|
||||
CONF_SEQUENCE,
|
||||
CONF_VARIABLES,
|
||||
SERVICE_RELOAD,
|
||||
SERVICE_TOGGLE,
|
||||
SERVICE_TURN_OFF,
|
||||
|
@ -59,6 +60,7 @@ SCRIPT_ENTRY_SCHEMA = make_script_schema(
|
|||
vol.Optional(CONF_ICON): cv.icon,
|
||||
vol.Required(CONF_SEQUENCE): cv.SCRIPT_SCHEMA,
|
||||
vol.Optional(CONF_DESCRIPTION, default=""): cv.string,
|
||||
vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
|
||||
vol.Optional(CONF_FIELDS, default={}): {
|
||||
cv.string: {
|
||||
vol.Optional(CONF_DESCRIPTION): cv.string,
|
||||
|
@ -75,7 +77,7 @@ CONFIG_SCHEMA = vol.Schema(
|
|||
|
||||
SCRIPT_SERVICE_SCHEMA = vol.Schema(dict)
|
||||
SCRIPT_TURN_ONOFF_SCHEMA = make_entity_service_schema(
|
||||
{vol.Optional(ATTR_VARIABLES): dict}
|
||||
{vol.Optional(ATTR_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA}
|
||||
)
|
||||
RELOAD_SERVICE_SCHEMA = vol.Schema({})
|
||||
|
||||
|
@ -263,6 +265,7 @@ class ScriptEntity(ToggleEntity):
|
|||
max_runs=cfg[CONF_MAX],
|
||||
max_exceeded=cfg[CONF_MAX_EXCEEDED],
|
||||
logger=logging.getLogger(f"{__name__}.{object_id}"),
|
||||
variables=cfg.get(CONF_VARIABLES),
|
||||
)
|
||||
self._changed = asyncio.Event()
|
||||
|
||||
|
|
|
@ -179,6 +179,7 @@ CONF_UNTIL = "until"
|
|||
CONF_URL = "url"
|
||||
CONF_USERNAME = "username"
|
||||
CONF_VALUE_TEMPLATE = "value_template"
|
||||
CONF_VARIABLES = "variables"
|
||||
CONF_VERIFY_SSL = "verify_ssl"
|
||||
CONF_WAIT_FOR_TRIGGER = "wait_for_trigger"
|
||||
CONF_WAIT_TEMPLATE = "wait_template"
|
||||
|
|
|
@ -863,6 +863,9 @@ def make_entity_service_schema(
|
|||
)
|
||||
|
||||
|
||||
SCRIPT_VARIABLES_SCHEMA = vol.Schema({str: template_complex})
|
||||
|
||||
|
||||
def script_action(value: Any) -> dict:
|
||||
"""Validate a script action."""
|
||||
if not isinstance(value, dict):
|
||||
|
|
|
@ -53,11 +53,7 @@ from homeassistant.const import (
|
|||
SERVICE_TURN_ON,
|
||||
)
|
||||
from homeassistant.core import SERVICE_CALL_LIMIT, Context, HomeAssistant, callback
|
||||
from homeassistant.helpers import (
|
||||
condition,
|
||||
config_validation as cv,
|
||||
template as template,
|
||||
)
|
||||
from homeassistant.helpers import condition, config_validation as cv, template
|
||||
from homeassistant.helpers.event import async_call_later, async_track_template
|
||||
from homeassistant.helpers.service import (
|
||||
CONF_SERVICE_DATA,
|
||||
|
@ -721,6 +717,7 @@ class Script:
|
|||
logger: Optional[logging.Logger] = None,
|
||||
log_exceptions: bool = True,
|
||||
top_level: bool = True,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Initialize the script."""
|
||||
all_scripts = hass.data.get(DATA_SCRIPTS)
|
||||
|
@ -759,6 +756,10 @@ class Script:
|
|||
self._choose_data: Dict[int, Dict[str, Any]] = {}
|
||||
self._referenced_entities: Optional[Set[str]] = None
|
||||
self._referenced_devices: Optional[Set[str]] = None
|
||||
self.variables = variables
|
||||
self._variables_dynamic = template.is_complex(variables)
|
||||
if self._variables_dynamic:
|
||||
template.attach(hass, variables)
|
||||
|
||||
def _set_logger(self, logger: Optional[logging.Logger] = None) -> None:
|
||||
if logger:
|
||||
|
@ -867,7 +868,7 @@ class Script:
|
|||
|
||||
async def async_run(
|
||||
self,
|
||||
variables: Optional[_VarsType] = None,
|
||||
run_variables: Optional[_VarsType] = None,
|
||||
context: Optional[Context] = None,
|
||||
started_action: Optional[Callable[..., Any]] = None,
|
||||
) -> None:
|
||||
|
@ -898,8 +899,19 @@ class Script:
|
|||
# are read-only, but more importantly, so as not to leak any variables created
|
||||
# during the run back to the caller.
|
||||
if self._top_level:
|
||||
variables = dict(variables) if variables is not None else {}
|
||||
if self.variables:
|
||||
if self._variables_dynamic:
|
||||
variables = template.render_complex(self.variables, run_variables)
|
||||
else:
|
||||
variables = dict(self.variables)
|
||||
else:
|
||||
variables = {}
|
||||
|
||||
if run_variables:
|
||||
variables.update(run_variables)
|
||||
variables["context"] = context
|
||||
else:
|
||||
variables = cast(dict, run_variables)
|
||||
|
||||
if self.script_mode != SCRIPT_MODE_QUEUED:
|
||||
cls = _ScriptRun
|
||||
|
|
|
@ -65,7 +65,7 @@ def attach(hass: HomeAssistantType, obj: Any) -> None:
|
|||
if isinstance(obj, list):
|
||||
for child in obj:
|
||||
attach(hass, child)
|
||||
elif isinstance(obj, dict):
|
||||
elif isinstance(obj, collections.abc.Mapping):
|
||||
for child_key, child_value in obj.items():
|
||||
attach(hass, child_key)
|
||||
attach(hass, child_value)
|
||||
|
@ -77,7 +77,7 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
|
|||
"""Recursive template creator helper function."""
|
||||
if isinstance(value, list):
|
||||
return [render_complex(item, variables) for item in value]
|
||||
if isinstance(value, dict):
|
||||
if isinstance(value, collections.abc.Mapping):
|
||||
return {
|
||||
render_complex(key, variables): render_complex(item, variables)
|
||||
for key, item in value.items()
|
||||
|
@ -88,6 +88,19 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
|
|||
return value
|
||||
|
||||
|
||||
def is_complex(value: Any) -> bool:
|
||||
"""Test if data structure is a complex template."""
|
||||
if isinstance(value, Template):
|
||||
return True
|
||||
if isinstance(value, list):
|
||||
return any(is_complex(val) for val in value)
|
||||
if isinstance(value, collections.abc.Mapping):
|
||||
return any(is_complex(val) for val in value.keys()) or any(
|
||||
is_complex(val) for val in value.values()
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def is_template_string(maybe_template: str) -> bool:
|
||||
"""Check if the input is a Jinja2 template."""
|
||||
return _RE_JINJA_DELIMITERS.search(maybe_template) is not None
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""Typing Helpers for Home Assistant."""
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Mapping, Optional, Tuple, Union
|
||||
|
||||
import homeassistant.core
|
||||
|
||||
|
@ -12,7 +12,7 @@ HomeAssistantType = homeassistant.core.HomeAssistant
|
|||
ServiceCallType = homeassistant.core.ServiceCall
|
||||
ServiceDataType = Dict[str, Any]
|
||||
StateType = Union[None, str, int, float]
|
||||
TemplateVarsType = Optional[Dict[str, Any]]
|
||||
TemplateVarsType = Optional[Mapping[str, Any]]
|
||||
|
||||
# Custom type for recorder Queries
|
||||
QueryType = Any
|
||||
|
|
|
@ -1134,3 +1134,57 @@ async def test_logbook_humanify_automation_triggered_event(hass):
|
|||
assert event2["domain"] == "automation"
|
||||
assert event2["message"] == "has been triggered by source of trigger"
|
||||
assert event2["entity_id"] == "automation.bye"
|
||||
|
||||
|
||||
async def test_automation_variables(hass):
|
||||
"""Test automation variables."""
|
||||
calls = async_mock_service(hass, "test", "automation")
|
||||
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
automation.DOMAIN,
|
||||
{
|
||||
automation.DOMAIN: [
|
||||
{
|
||||
"variables": {
|
||||
"test_var": "defined_in_config",
|
||||
"event_type": "{{ trigger.event.event_type }}",
|
||||
},
|
||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||
"action": {
|
||||
"service": "test.automation",
|
||||
"data": {
|
||||
"value": "{{ test_var }}",
|
||||
"event_type": "{{ event_type }}",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"variables": {
|
||||
"test_var": "defined_in_config",
|
||||
},
|
||||
"trigger": {"platform": "event", "event_type": "test_event_2"},
|
||||
"condition": {
|
||||
"condition": "template",
|
||||
"value_template": "{{ trigger.event.data.pass_condition }}",
|
||||
},
|
||||
"action": {
|
||||
"service": "test.automation",
|
||||
},
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
hass.bus.async_fire("test_event")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data["value"] == "defined_in_config"
|
||||
assert calls[0].data["event_type"] == "test_event"
|
||||
|
||||
hass.bus.async_fire("test_event_2")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
|
||||
hass.bus.async_fire("test_event_2", {"pass_condition": True})
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 2
|
||||
|
|
|
@ -23,7 +23,7 @@ from homeassistant.loader import bind_hass
|
|||
from homeassistant.setup import async_setup_component, setup_component
|
||||
|
||||
from tests.async_mock import Mock, patch
|
||||
from tests.common import get_test_home_assistant
|
||||
from tests.common import async_mock_service, get_test_home_assistant
|
||||
from tests.components.logbook.test_init import MockLazyEventPartialState
|
||||
|
||||
ENTITY_ID = "script.test"
|
||||
|
@ -615,3 +615,69 @@ async def test_concurrent_script(hass, concurrently):
|
|||
|
||||
assert not script.is_on(hass, "script.script1")
|
||||
assert not script.is_on(hass, "script.script2")
|
||||
|
||||
|
||||
async def test_script_variables(hass):
|
||||
"""Test defining scripts."""
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"script",
|
||||
{
|
||||
"script": {
|
||||
"script1": {
|
||||
"variables": {
|
||||
"test_var": "from_config",
|
||||
"templated_config_var": "{{ var_from_service | default('config-default') }}",
|
||||
},
|
||||
"sequence": [
|
||||
{
|
||||
"service": "test.script",
|
||||
"data": {
|
||||
"value": "{{ test_var }}",
|
||||
"templated_config_var": "{{ templated_config_var }}",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
"script2": {
|
||||
"variables": {
|
||||
"test_var": "from_config",
|
||||
},
|
||||
"sequence": [
|
||||
{
|
||||
"service": "test.script",
|
||||
"data": {
|
||||
"value": "{{ test_var }}",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
mock_calls = async_mock_service(hass, "test", "script")
|
||||
|
||||
await hass.services.async_call(
|
||||
"script", "script1", {"var_from_service": "hello"}, blocking=True
|
||||
)
|
||||
|
||||
assert len(mock_calls) == 1
|
||||
assert mock_calls[0].data["value"] == "from_config"
|
||||
assert mock_calls[0].data["templated_config_var"] == "hello"
|
||||
|
||||
await hass.services.async_call(
|
||||
"script", "script1", {"test_var": "from_service"}, blocking=True
|
||||
)
|
||||
|
||||
assert len(mock_calls) == 2
|
||||
assert mock_calls[1].data["value"] == "from_service"
|
||||
assert mock_calls[1].data["templated_config_var"] == "config-default"
|
||||
|
||||
# Call script with vars but no templates in it
|
||||
await hass.services.async_call(
|
||||
"script", "script2", {"test_var": "from_service"}, blocking=True
|
||||
)
|
||||
|
||||
assert len(mock_calls) == 3
|
||||
assert mock_calls[2].data["value"] == "from_service"
|
||||
|
|
Loading…
Reference in New Issue