Improve type hints in automation (#78368)
* Improve type hints in automation * Apply suggestion * Apply suggestion * Apply suggestion * Add Protocol for IfAction * Use ConfigType for IfAction * Rename variablepull/78457/head
parent
b7e9fcb9fe
commit
5e338d2166
homeassistant/components/automation
|
@ -1,9 +1,9 @@
|
|||
"""Allow to set up simple automation rules via the config file."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Mapping
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
@ -31,9 +31,12 @@ from homeassistant.const import (
|
|||
STATE_ON,
|
||||
)
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
Context,
|
||||
CoreState,
|
||||
Event,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
callback,
|
||||
split_entity_id,
|
||||
valid_entity_id,
|
||||
|
@ -99,9 +102,6 @@ from .const import (
|
|||
from .helpers import async_get_blueprints
|
||||
from .trace import trace_automation
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
# mypy: no-check-untyped-defs, no-warn-return-any
|
||||
|
||||
ENTITY_ID_FORMAT = DOMAIN + ".{}"
|
||||
|
||||
|
||||
|
@ -120,6 +120,15 @@ SERVICE_TRIGGER = "trigger"
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IfAction(Protocol):
|
||||
"""Define the format of if_action."""
|
||||
|
||||
config: list[ConfigType]
|
||||
|
||||
def __call__(self, variables: Mapping[str, Any] | None = None) -> bool:
|
||||
"""AND all conditions."""
|
||||
|
||||
|
||||
# AutomationActionType, AutomationTriggerData,
|
||||
# and AutomationTriggerInfo are deprecated as of 2022.9.
|
||||
AutomationActionType = TriggerActionType
|
||||
|
@ -128,7 +137,7 @@ AutomationTriggerInfo = TriggerInfo
|
|||
|
||||
|
||||
@bind_hass
|
||||
def is_on(hass, entity_id):
|
||||
def is_on(hass: HomeAssistant, entity_id: str) -> bool:
|
||||
"""
|
||||
Return true if specified automation entity_id is on.
|
||||
|
||||
|
@ -143,12 +152,12 @@ def automations_with_entity(hass: HomeAssistant, entity_id: str) -> list[str]:
|
|||
if DOMAIN not in hass.data:
|
||||
return []
|
||||
|
||||
component = hass.data[DOMAIN]
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
|
||||
return [
|
||||
automation_entity.entity_id
|
||||
for automation_entity in component.entities
|
||||
if entity_id in automation_entity.referenced_entities
|
||||
if entity_id in cast(AutomationEntity, automation_entity).referenced_entities
|
||||
]
|
||||
|
||||
|
||||
|
@ -158,12 +167,12 @@ def entities_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
|
|||
if DOMAIN not in hass.data:
|
||||
return []
|
||||
|
||||
component = hass.data[DOMAIN]
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
|
||||
if (automation_entity := component.get_entity(entity_id)) is None:
|
||||
return []
|
||||
|
||||
return list(automation_entity.referenced_entities)
|
||||
return list(cast(AutomationEntity, automation_entity).referenced_entities)
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -172,12 +181,12 @@ def automations_with_device(hass: HomeAssistant, device_id: str) -> list[str]:
|
|||
if DOMAIN not in hass.data:
|
||||
return []
|
||||
|
||||
component = hass.data[DOMAIN]
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
|
||||
return [
|
||||
automation_entity.entity_id
|
||||
for automation_entity in component.entities
|
||||
if device_id in automation_entity.referenced_devices
|
||||
if device_id in cast(AutomationEntity, automation_entity).referenced_devices
|
||||
]
|
||||
|
||||
|
||||
|
@ -187,12 +196,12 @@ def devices_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
|
|||
if DOMAIN not in hass.data:
|
||||
return []
|
||||
|
||||
component = hass.data[DOMAIN]
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
|
||||
if (automation_entity := component.get_entity(entity_id)) is None:
|
||||
return []
|
||||
|
||||
return list(automation_entity.referenced_devices)
|
||||
return list(cast(AutomationEntity, automation_entity).referenced_devices)
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -201,12 +210,12 @@ def automations_with_area(hass: HomeAssistant, area_id: str) -> list[str]:
|
|||
if DOMAIN not in hass.data:
|
||||
return []
|
||||
|
||||
component = hass.data[DOMAIN]
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
|
||||
return [
|
||||
automation_entity.entity_id
|
||||
for automation_entity in component.entities
|
||||
if area_id in automation_entity.referenced_areas
|
||||
if area_id in cast(AutomationEntity, automation_entity).referenced_areas
|
||||
]
|
||||
|
||||
|
||||
|
@ -216,12 +225,12 @@ def areas_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
|
|||
if DOMAIN not in hass.data:
|
||||
return []
|
||||
|
||||
component = hass.data[DOMAIN]
|
||||
component: EntityComponent = hass.data[DOMAIN]
|
||||
|
||||
if (automation_entity := component.get_entity(entity_id)) is None:
|
||||
return []
|
||||
|
||||
return list(automation_entity.referenced_areas)
|
||||
return list(cast(AutomationEntity, automation_entity).referenced_areas)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
|
@ -238,7 +247,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
if not await _async_process_config(hass, config, component):
|
||||
await async_get_blueprints(hass).async_populate()
|
||||
|
||||
async def trigger_service_handler(entity, service_call):
|
||||
async def trigger_service_handler(
|
||||
entity: AutomationEntity, service_call: ServiceCall
|
||||
) -> None:
|
||||
"""Handle forced automation trigger, e.g. from frontend."""
|
||||
await entity.async_trigger(
|
||||
{**service_call.data[ATTR_VARIABLES], "trigger": {"platform": None}},
|
||||
|
@ -262,7 +273,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
"async_turn_off",
|
||||
)
|
||||
|
||||
async def reload_service_handler(service_call):
|
||||
async def reload_service_handler(service_call: ServiceCall) -> None:
|
||||
"""Remove all automations and load new ones from config."""
|
||||
if (conf := await component.async_prepare_reload()) is None:
|
||||
return
|
||||
|
@ -290,22 +301,22 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
automation_id,
|
||||
name,
|
||||
trigger_config,
|
||||
cond_func,
|
||||
action_script,
|
||||
initial_state,
|
||||
variables,
|
||||
trigger_variables,
|
||||
raw_config,
|
||||
blueprint_inputs,
|
||||
trace_config,
|
||||
):
|
||||
automation_id: str | None,
|
||||
name: str,
|
||||
trigger_config: list[ConfigType],
|
||||
cond_func: IfAction | None,
|
||||
action_script: Script,
|
||||
initial_state: bool | None,
|
||||
variables: ScriptVariables | None,
|
||||
trigger_variables: ScriptVariables | None,
|
||||
raw_config: ConfigType | None,
|
||||
blueprint_inputs: ConfigType | None,
|
||||
trace_config: ConfigType,
|
||||
) -> None:
|
||||
"""Initialize an automation entity."""
|
||||
self._attr_name = name
|
||||
self._trigger_config = trigger_config
|
||||
self._async_detach_triggers = None
|
||||
self._async_detach_triggers: CALLBACK_TYPE | None = None
|
||||
self._cond_func = cond_func
|
||||
self.action_script = action_script
|
||||
self.action_script.change_listener = self.async_write_ha_state
|
||||
|
@ -314,15 +325,15 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
self._referenced_entities: set[str] | None = None
|
||||
self._referenced_devices: set[str] | None = None
|
||||
self._logger = LOGGER
|
||||
self._variables: ScriptVariables = variables
|
||||
self._trigger_variables: ScriptVariables = trigger_variables
|
||||
self._variables = variables
|
||||
self._trigger_variables = trigger_variables
|
||||
self._raw_config = raw_config
|
||||
self._blueprint_inputs = blueprint_inputs
|
||||
self._trace_config = trace_config
|
||||
self._attr_unique_id = automation_id
|
||||
|
||||
@property
|
||||
def extra_state_attributes(self):
|
||||
def extra_state_attributes(self) -> dict[str, Any]:
|
||||
"""Return the entity state attributes."""
|
||||
attrs = {
|
||||
ATTR_LAST_TRIGGERED: self.action_script.last_triggered,
|
||||
|
@ -341,12 +352,12 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
return self._async_detach_triggers is not None or self._is_enabled
|
||||
|
||||
@property
|
||||
def referenced_areas(self):
|
||||
def referenced_areas(self) -> set[str]:
|
||||
"""Return a set of referenced areas."""
|
||||
return self.action_script.referenced_areas
|
||||
|
||||
@property
|
||||
def referenced_devices(self):
|
||||
def referenced_devices(self) -> set[str]:
|
||||
"""Return a set of referenced devices."""
|
||||
if self._referenced_devices is not None:
|
||||
return self._referenced_devices
|
||||
|
@ -364,7 +375,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
return referenced
|
||||
|
||||
@property
|
||||
def referenced_entities(self):
|
||||
def referenced_entities(self) -> set[str]:
|
||||
"""Return a set of referenced entities."""
|
||||
if self._referenced_entities is not None:
|
||||
return self._referenced_entities
|
||||
|
@ -513,7 +524,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
event_data[ATTR_SOURCE] = variables["trigger"]["description"]
|
||||
|
||||
@callback
|
||||
def started_action():
|
||||
def started_action() -> None:
|
||||
self.hass.bus.async_fire(
|
||||
EVENT_AUTOMATION_TRIGGERED, event_data, context=trigger_context
|
||||
)
|
||||
|
@ -555,12 +566,12 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
self._logger.exception("While executing automation %s", self.entity_id)
|
||||
automation_trace.set_error(err)
|
||||
|
||||
async def async_will_remove_from_hass(self):
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Remove listeners when removing automation from Home Assistant."""
|
||||
await super().async_will_remove_from_hass()
|
||||
await self.async_disable()
|
||||
|
||||
async def async_enable(self):
|
||||
async def async_enable(self) -> None:
|
||||
"""Enable this automation entity.
|
||||
|
||||
This method is a coroutine.
|
||||
|
@ -576,7 +587,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
self.async_write_ha_state()
|
||||
return
|
||||
|
||||
async def async_enable_automation(event):
|
||||
async def async_enable_automation(event: Event) -> None:
|
||||
"""Start automation on startup."""
|
||||
# Don't do anything if no longer enabled or already attached
|
||||
if not self._is_enabled or self._async_detach_triggers is not None:
|
||||
|
@ -589,7 +600,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
)
|
||||
self.async_write_ha_state()
|
||||
|
||||
async def async_disable(self, stop_actions=DEFAULT_STOP_ACTIONS):
|
||||
async def async_disable(self, stop_actions: bool = DEFAULT_STOP_ACTIONS) -> None:
|
||||
"""Disable the automation entity."""
|
||||
if not self._is_enabled and not self.action_script.runs:
|
||||
return
|
||||
|
@ -610,7 +621,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
) -> Callable[[], None] | None:
|
||||
"""Set up the triggers."""
|
||||
|
||||
def log_cb(level, msg, **kwargs):
|
||||
def log_cb(level: int, msg: str, **kwargs: Any) -> None:
|
||||
self._logger.log(level, "%s %s", msg, self.name, **kwargs)
|
||||
|
||||
this = None
|
||||
|
@ -650,7 +661,7 @@ async def _async_process_config(
|
|||
|
||||
Returns if blueprints were used.
|
||||
"""
|
||||
entities = []
|
||||
entities: list[AutomationEntity] = []
|
||||
blueprints_used = False
|
||||
|
||||
for config_key in extract_domain_configs(config, DOMAIN):
|
||||
|
@ -681,10 +692,10 @@ async def _async_process_config(
|
|||
else:
|
||||
raw_config = cast(AutomationConfig, config_block).raw_config
|
||||
|
||||
automation_id = config_block.get(CONF_ID)
|
||||
name = config_block.get(CONF_ALIAS) or f"{config_key} {list_no}"
|
||||
automation_id: str | None = config_block.get(CONF_ID)
|
||||
name: str = config_block.get(CONF_ALIAS) or f"{config_key} {list_no}"
|
||||
|
||||
initial_state = config_block.get(CONF_INITIAL_STATE)
|
||||
initial_state: bool | None = config_block.get(CONF_INITIAL_STATE)
|
||||
|
||||
action_script = Script(
|
||||
hass,
|
||||
|
@ -743,11 +754,13 @@ async def _async_process_config(
|
|||
return blueprints_used
|
||||
|
||||
|
||||
async def _async_process_if(hass, name, config, p_config):
|
||||
async def _async_process_if(
|
||||
hass: HomeAssistant, name: str, config: dict[str, Any], p_config: dict[str, Any]
|
||||
) -> IfAction | None:
|
||||
"""Process if checks."""
|
||||
if_configs = p_config[CONF_CONDITION]
|
||||
|
||||
checks = []
|
||||
checks: list[condition.ConditionCheckerType] = []
|
||||
for if_config in if_configs:
|
||||
try:
|
||||
checks.append(await condition.async_from_config(hass, if_config))
|
||||
|
@ -755,9 +768,9 @@ async def _async_process_if(hass, name, config, p_config):
|
|||
LOGGER.warning("Invalid condition: %s", ex)
|
||||
return None
|
||||
|
||||
def if_action(variables=None):
|
||||
def if_action(variables: Mapping[str, Any] | None = None) -> bool:
|
||||
"""AND all conditions."""
|
||||
errors = []
|
||||
errors: list[ConditionErrorIndex] = []
|
||||
for index, check in enumerate(checks):
|
||||
try:
|
||||
with trace_path(["condition", str(index)]):
|
||||
|
@ -780,9 +793,10 @@ async def _async_process_if(hass, name, config, p_config):
|
|||
|
||||
return True
|
||||
|
||||
if_action.config = if_configs
|
||||
result: IfAction = if_action # type: ignore[assignment]
|
||||
result.config = if_configs
|
||||
|
||||
return if_action
|
||||
return result
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -800,7 +814,7 @@ def _trigger_extract_devices(trigger_conf: dict) -> list[str]:
|
|||
return [trigger_conf[CONF_EVENT_DATA][CONF_DEVICE_ID]]
|
||||
|
||||
if trigger_conf[CONF_PLATFORM] == "tag" and CONF_DEVICE_ID in trigger_conf:
|
||||
return trigger_conf[CONF_DEVICE_ID]
|
||||
return trigger_conf[CONF_DEVICE_ID] # type: ignore[no-any-return]
|
||||
|
||||
return []
|
||||
|
||||
|
@ -809,13 +823,13 @@ def _trigger_extract_devices(trigger_conf: dict) -> list[str]:
|
|||
def _trigger_extract_entities(trigger_conf: dict) -> list[str]:
|
||||
"""Extract entities from a trigger config."""
|
||||
if trigger_conf[CONF_PLATFORM] in ("state", "numeric_state"):
|
||||
return trigger_conf[CONF_ENTITY_ID]
|
||||
return trigger_conf[CONF_ENTITY_ID] # type: ignore[no-any-return]
|
||||
|
||||
if trigger_conf[CONF_PLATFORM] == "calendar":
|
||||
return [trigger_conf[CONF_ENTITY_ID]]
|
||||
|
||||
if trigger_conf[CONF_PLATFORM] == "zone":
|
||||
return trigger_conf[CONF_ENTITY_ID] + [trigger_conf[CONF_ZONE]]
|
||||
return trigger_conf[CONF_ENTITY_ID] + [trigger_conf[CONF_ZONE]] # type: ignore[no-any-return]
|
||||
|
||||
if trigger_conf[CONF_PLATFORM] == "geo_location":
|
||||
return [trigger_conf[CONF_ZONE]]
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
"""Config validation helper for the automation integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -17,10 +20,12 @@ from homeassistant.const import (
|
|||
CONF_ID,
|
||||
CONF_VARIABLES,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_per_platform, config_validation as cv, script
|
||||
from homeassistant.helpers.condition import async_validate_conditions_config
|
||||
from homeassistant.helpers.trigger import async_validate_trigger_config
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import IntegrationNotFound
|
||||
|
||||
from .const import (
|
||||
|
@ -34,9 +39,6 @@ from .const import (
|
|||
)
|
||||
from .helpers import async_get_blueprints
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
# mypy: no-check-untyped-defs, no-warn-return-any
|
||||
|
||||
PACKAGE_MERGE_HINT = "list"
|
||||
|
||||
_CONDITION_SCHEMA = vol.All(cv.ensure_list, [cv.CONDITION_SCHEMA])
|
||||
|
@ -63,7 +65,11 @@ PLATFORM_SCHEMA = vol.All(
|
|||
)
|
||||
|
||||
|
||||
async def async_validate_config_item(hass, config, full_config=None):
|
||||
async def async_validate_config_item(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
full_config: ConfigType | None = None,
|
||||
) -> blueprint.BlueprintInputs | dict[str, Any]:
|
||||
"""Validate config item."""
|
||||
if blueprint.is_blueprint_instance_config(config):
|
||||
blueprints = async_get_blueprints(hass)
|
||||
|
@ -90,17 +96,21 @@ async def async_validate_config_item(hass, config, full_config=None):
|
|||
class AutomationConfig(dict):
|
||||
"""Dummy class to allow adding attributes."""
|
||||
|
||||
raw_config = None
|
||||
raw_config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
async def _try_async_validate_config_item(hass, config, full_config=None):
|
||||
async def _try_async_validate_config_item(
|
||||
hass: HomeAssistant,
|
||||
config: dict[str, Any],
|
||||
full_config: dict[str, Any] | None = None,
|
||||
) -> AutomationConfig | blueprint.BlueprintInputs | None:
|
||||
"""Validate config item."""
|
||||
raw_config = None
|
||||
with suppress(ValueError):
|
||||
raw_config = dict(config)
|
||||
|
||||
try:
|
||||
config = await async_validate_config_item(hass, config, full_config)
|
||||
validated_config = await async_validate_config_item(hass, config, full_config)
|
||||
except (
|
||||
vol.Invalid,
|
||||
HomeAssistantError,
|
||||
|
@ -110,15 +120,15 @@ async def _try_async_validate_config_item(hass, config, full_config=None):
|
|||
async_log_exception(ex, DOMAIN, full_config or config, hass)
|
||||
return None
|
||||
|
||||
if isinstance(config, blueprint.BlueprintInputs):
|
||||
return config
|
||||
if isinstance(validated_config, blueprint.BlueprintInputs):
|
||||
return validated_config
|
||||
|
||||
config = AutomationConfig(config)
|
||||
config.raw_config = raw_config
|
||||
return config
|
||||
automation_config = AutomationConfig(validated_config)
|
||||
automation_config.raw_config = raw_config
|
||||
return automation_config
|
||||
|
||||
|
||||
async def async_validate_config(hass, config):
|
||||
async def async_validate_config(hass: HomeAssistant, config: ConfigType) -> ConfigType:
|
||||
"""Validate config."""
|
||||
automations = list(
|
||||
filter(
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Trace support for automation."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
|
@ -9,13 +10,11 @@ from homeassistant.components.trace import (
|
|||
ActionTrace,
|
||||
async_store_trace,
|
||||
)
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
# mypy: no-check-untyped-defs, no-warn-return-any
|
||||
|
||||
|
||||
class AutomationTrace(ActionTrace):
|
||||
"""Container for automation trace."""
|
||||
|
@ -24,9 +23,9 @@ class AutomationTrace(ActionTrace):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
item_id: str,
|
||||
config: dict[str, Any],
|
||||
blueprint_inputs: dict[str, Any],
|
||||
item_id: str | None,
|
||||
config: ConfigType | None,
|
||||
blueprint_inputs: ConfigType | None,
|
||||
context: Context,
|
||||
) -> None:
|
||||
"""Container for automation trace."""
|
||||
|
@ -49,8 +48,13 @@ class AutomationTrace(ActionTrace):
|
|||
|
||||
@contextmanager
|
||||
def trace_automation(
|
||||
hass, automation_id, config, blueprint_inputs, context, trace_config
|
||||
):
|
||||
hass: HomeAssistant,
|
||||
automation_id: str | None,
|
||||
config: ConfigType | None,
|
||||
blueprint_inputs: ConfigType | None,
|
||||
context: Context,
|
||||
trace_config: ConfigType,
|
||||
) -> Generator[AutomationTrace, None, None]:
|
||||
"""Trace action execution of automation with automation_id."""
|
||||
trace = AutomationTrace(automation_id, config, blueprint_inputs, context)
|
||||
async_store_trace(hass, trace, trace_config[CONF_STORED_TRACES])
|
||||
|
|
Loading…
Reference in New Issue