Add support for multiple event triggers in automation (#43097)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
pull/43130/head
Franck Nijhof 2020-11-12 11:58:28 +01:00 committed by GitHub
parent 673ac21de4
commit 6f326a7ea4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 131 additions and 64 deletions

View File

@ -43,40 +43,35 @@ from homeassistant.helpers.script import (
ATTR_MODE,
CONF_MAX,
CONF_MAX_EXCEEDED,
SCRIPT_MODE_SINGLE,
Script,
make_script_schema,
)
from homeassistant.helpers.script_variables import ScriptVariables
from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.singleton import singleton
from homeassistant.helpers.trigger import async_initialize_triggers
from homeassistant.helpers.typing import TemplateVarsType
from homeassistant.loader import bind_hass
from homeassistant.util.dt import parse_datetime
from .config import async_validate_config_item
from .const import (
CONF_ACTION,
CONF_CONDITION,
CONF_INITIAL_STATE,
CONF_TRIGGER,
DEFAULT_INITIAL_STATE,
DOMAIN,
LOGGER,
)
from .helpers import async_get_blueprints
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
DOMAIN = "automation"
ENTITY_ID_FORMAT = DOMAIN + ".{}"
DATA_BLUEPRINTS = "automation_blueprints"
CONF_DESCRIPTION = "description"
CONF_HIDE_ENTITY = "hide_entity"
CONF_CONDITION = "condition"
CONF_ACTION = "action"
CONF_TRIGGER = "trigger"
CONF_CONDITION_TYPE = "condition_type"
CONF_INITIAL_STATE = "initial_state"
CONF_SKIP_CONDITION = "skip_condition"
CONF_STOP_ACTIONS = "stop_actions"
CONF_BLUEPRINT = "blueprint"
CONF_INPUT = "input"
DEFAULT_INITIAL_STATE = True
DEFAULT_STOP_ACTIONS = True
EVENT_AUTOMATION_RELOADED = "automation_reloaded"
@ -87,38 +82,8 @@ ATTR_SOURCE = "source"
ATTR_VARIABLES = "variables"
SERVICE_TRIGGER = "trigger"
_LOGGER = logging.getLogger(__name__)
AutomationActionType = Callable[[HomeAssistant, TemplateVarsType], Awaitable[None]]
_CONDITION_SCHEMA = vol.All(cv.ensure_list, [cv.CONDITION_SCHEMA])
PLATFORM_SCHEMA = vol.All(
cv.deprecated(CONF_HIDE_ENTITY, invalidation_version="0.110"),
make_script_schema(
{
# str on purpose
CONF_ID: str,
CONF_ALIAS: cv.string,
vol.Optional(CONF_DESCRIPTION): cv.string,
vol.Optional(CONF_INITIAL_STATE): cv.boolean,
vol.Optional(CONF_HIDE_ENTITY): cv.boolean,
vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA,
vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA,
vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA,
},
SCRIPT_MODE_SINGLE,
),
)
@singleton(DATA_BLUEPRINTS)
@callback
def async_get_blueprints(hass: HomeAssistant) -> blueprint.DomainBlueprints: # type: ignore
"""Get automation blueprints."""
return blueprint.DomainBlueprints(hass, DOMAIN, _LOGGER) # type: ignore
@bind_hass
def is_on(hass, entity_id):
@ -194,7 +159,7 @@ def devices_in_automation(hass: HomeAssistant, entity_id: str) -> List[str]:
async def async_setup(hass, config):
"""Set up the automation."""
hass.data[DOMAIN] = component = EntityComponent(_LOGGER, DOMAIN, hass)
hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass)
await _async_process_config(hass, config, component)
@ -263,7 +228,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._is_enabled = False
self._referenced_entities: Optional[Set[str]] = None
self._referenced_devices: Optional[Set[str]] = None
self._logger = _LOGGER
self._logger = LOGGER
self._variables: ScriptVariables = variables
@property
@ -536,10 +501,12 @@ async def _async_process_config(
try:
config_block = cast(
Dict[str, Any],
PLATFORM_SCHEMA(blueprint_inputs.async_substitute()),
await async_validate_config_item(
hass, blueprint_inputs.async_substitute()
),
)
except vol.Invalid as err:
_LOGGER.error(
LOGGER.error(
"Blueprint %s generated invalid automation with inputs %s: %s",
blueprint_inputs.blueprint.name,
blueprint_inputs.inputs,
@ -561,7 +528,7 @@ async def _async_process_config(
script_mode=config_block[CONF_MODE],
max_runs=config_block[CONF_MAX],
max_exceeded=config_block[CONF_MAX_EXCEEDED],
logger=_LOGGER,
logger=LOGGER,
# We don't pass variables here
# Automation will already render them to use them in the condition
# and so will pass them on to the script.
@ -600,7 +567,7 @@ async def _async_process_if(hass, config, p_config):
try:
checks.append(await condition.async_from_config(hass, if_config, False))
except HomeAssistantError as ex:
_LOGGER.warning("Invalid condition: %s", ex)
LOGGER.warning("Invalid condition: %s", ex)
return None
def if_action(variables=None):

View File

@ -8,25 +8,48 @@ from homeassistant.components.device_automation.exceptions import (
InvalidDeviceAutomationConfig,
)
from homeassistant.config import async_log_exception, config_without_domain
from homeassistant.const import CONF_ALIAS, CONF_ID, CONF_VARIABLES
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform
from homeassistant.helpers import config_per_platform, config_validation as cv, script
from homeassistant.helpers.condition import async_validate_condition_config
from homeassistant.helpers.script import async_validate_actions_config
from homeassistant.helpers.trigger import async_validate_trigger_config
from homeassistant.loader import IntegrationNotFound
from . import (
from .const import (
CONF_ACTION,
CONF_CONDITION,
CONF_DESCRIPTION,
CONF_HIDE_ENTITY,
CONF_INITIAL_STATE,
CONF_TRIGGER,
DOMAIN,
PLATFORM_SCHEMA,
async_get_blueprints,
)
from .helpers import async_get_blueprints
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
_CONDITION_SCHEMA = vol.All(cv.ensure_list, [cv.CONDITION_SCHEMA])
PLATFORM_SCHEMA = vol.All(
cv.deprecated(CONF_HIDE_ENTITY, invalidation_version="0.110"),
script.make_script_schema(
{
# str on purpose
CONF_ID: str,
CONF_ALIAS: cv.string,
vol.Optional(CONF_DESCRIPTION): cv.string,
vol.Optional(CONF_INITIAL_STATE): cv.boolean,
vol.Optional(CONF_HIDE_ENTITY): cv.boolean,
vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA,
vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA,
vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA,
},
script.SCRIPT_MODE_SINGLE,
),
)
async def async_validate_config_item(hass, config, full_config=None):
"""Validate config item."""
@ -48,7 +71,9 @@ async def async_validate_config_item(hass, config, full_config=None):
]
)
config[CONF_ACTION] = await async_validate_actions_config(hass, config[CONF_ACTION])
config[CONF_ACTION] = await script.async_validate_actions_config(
hass, config[CONF_ACTION]
)
return config

View File

@ -0,0 +1,19 @@
"""Constants for the automation integration."""
import logging
CONF_CONDITION = "condition"
CONF_ACTION = "action"
CONF_TRIGGER = "trigger"
DOMAIN = "automation"
CONF_DESCRIPTION = "description"
CONF_HIDE_ENTITY = "hide_entity"
CONF_CONDITION_TYPE = "condition_type"
CONF_INITIAL_STATE = "initial_state"
CONF_BLUEPRINT = "blueprint"
CONF_INPUT = "input"
DEFAULT_INITIAL_STATE = True
LOGGER = logging.getLogger(__package__)

View File

@ -0,0 +1,15 @@
"""Helpers for automation integration."""
from homeassistant.components import blueprint
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.singleton import singleton
from .const import DOMAIN, LOGGER
DATA_BLUEPRINTS = "automation_blueprints"
@singleton(DATA_BLUEPRINTS)
@callback
def async_get_blueprints(hass: HomeAssistant) -> blueprint.DomainBlueprints: # type: ignore
"""Get automation blueprints."""
return blueprint.DomainBlueprints(hass, DOMAIN, LOGGER) # type: ignore

View File

@ -2,8 +2,11 @@
from collections import OrderedDict
import uuid
from homeassistant.components.automation import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.automation.config import async_validate_config_item
from homeassistant.components.automation.config import (
DOMAIN,
PLATFORM_SCHEMA,
async_validate_config_item,
)
from homeassistant.config import AUTOMATION_CONFIG_PATH
from homeassistant.const import CONF_ID, SERVICE_RELOAD
from homeassistant.helpers import config_validation as cv, entity_registry

View File

@ -14,7 +14,7 @@ CONF_EVENT_CONTEXT = "context"
TRIGGER_SCHEMA = vol.Schema(
{
vol.Required(CONF_PLATFORM): "event",
vol.Required(CONF_EVENT_TYPE): cv.string,
vol.Required(CONF_EVENT_TYPE): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_EVENT_DATA): dict,
vol.Optional(CONF_EVENT_CONTEXT): dict,
}
@ -32,7 +32,8 @@ async def async_attach_trigger(
hass, config, action, automation_info, *, platform_type="event"
):
"""Listen for events based on configuration."""
event_type = config.get(CONF_EVENT_TYPE)
event_types = config.get(CONF_EVENT_TYPE)
removes = []
event_data_schema = None
if config.get(CONF_EVENT_DATA):
@ -82,4 +83,14 @@ async def async_attach_trigger(
event.context,
)
return hass.bus.async_listen(event_type, handle_event)
removes = [
hass.bus.async_listen(event_type, handle_event) for event_type in event_types
]
@callback
def remove_listen_events():
"""Remove event listeners."""
for remove in removes:
remove()
return remove_listen_events

View File

@ -59,6 +59,33 @@ async def test_if_fires_on_event(hass, calls):
assert len(calls) == 1
async def test_if_fires_on_multiple_events(hass, calls):
"""Test the firing of events."""
context = Context()
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": {
"platform": "event",
"event_type": ["test_event", "test2_event"],
},
"action": {"service": "test.automation"},
}
},
)
hass.bus.async_fire("test_event", context=context)
await hass.async_block_till_done()
hass.bus.async_fire("test2_event", context=context)
await hass.async_block_till_done()
assert len(calls) == 2
assert calls[0].context.parent_id == context.id
assert calls[1].context.parent_id == context.id
async def test_if_fires_on_event_extra_data(hass, calls, context_with_user):
"""Test the firing of events still matches with event data and context."""
assert await async_setup_component(