Improve type hints in automation ()

* Improve type hints in automation

* Apply suggestion

* Apply suggestion

* Apply suggestion

* Add Protocol for IfAction

* Use ConfigType for IfAction

* Rename variable
pull/78457/head
epenet 2022-09-14 13:04:09 +02:00 committed by GitHub
parent b7e9fcb9fe
commit 5e338d2166
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 108 additions and 80 deletions
homeassistant/components/automation

View File

@ -1,9 +1,9 @@
"""Allow to set up simple automation rules via the config file."""
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Callable, Mapping
import logging
from typing import Any, cast
from typing import Any, Protocol, cast
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -31,9 +31,12 @@ from homeassistant.const import (
STATE_ON,
)
from homeassistant.core import (
CALLBACK_TYPE,
Context,
CoreState,
Event,
HomeAssistant,
ServiceCall,
callback,
split_entity_id,
valid_entity_id,
@ -99,9 +102,6 @@ from .const import (
from .helpers import async_get_blueprints
from .trace import trace_automation
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
ENTITY_ID_FORMAT = DOMAIN + ".{}"
@ -120,6 +120,15 @@ SERVICE_TRIGGER = "trigger"
_LOGGER = logging.getLogger(__name__)
class IfAction(Protocol):
"""Define the format of if_action."""
config: list[ConfigType]
def __call__(self, variables: Mapping[str, Any] | None = None) -> bool:
"""AND all conditions."""
# AutomationActionType, AutomationTriggerData,
# and AutomationTriggerInfo are deprecated as of 2022.9.
AutomationActionType = TriggerActionType
@ -128,7 +137,7 @@ AutomationTriggerInfo = TriggerInfo
@bind_hass
def is_on(hass, entity_id):
def is_on(hass: HomeAssistant, entity_id: str) -> bool:
"""
Return true if specified automation entity_id is on.
@ -143,12 +152,12 @@ def automations_with_entity(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
component: EntityComponent = hass.data[DOMAIN]
return [
automation_entity.entity_id
for automation_entity in component.entities
if entity_id in automation_entity.referenced_entities
if entity_id in cast(AutomationEntity, automation_entity).referenced_entities
]
@ -158,12 +167,12 @@ def entities_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
component: EntityComponent = hass.data[DOMAIN]
if (automation_entity := component.get_entity(entity_id)) is None:
return []
return list(automation_entity.referenced_entities)
return list(cast(AutomationEntity, automation_entity).referenced_entities)
@callback
@ -172,12 +181,12 @@ def automations_with_device(hass: HomeAssistant, device_id: str) -> list[str]:
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
component: EntityComponent = hass.data[DOMAIN]
return [
automation_entity.entity_id
for automation_entity in component.entities
if device_id in automation_entity.referenced_devices
if device_id in cast(AutomationEntity, automation_entity).referenced_devices
]
@ -187,12 +196,12 @@ def devices_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
component: EntityComponent = hass.data[DOMAIN]
if (automation_entity := component.get_entity(entity_id)) is None:
return []
return list(automation_entity.referenced_devices)
return list(cast(AutomationEntity, automation_entity).referenced_devices)
@callback
@ -201,12 +210,12 @@ def automations_with_area(hass: HomeAssistant, area_id: str) -> list[str]:
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
component: EntityComponent = hass.data[DOMAIN]
return [
automation_entity.entity_id
for automation_entity in component.entities
if area_id in automation_entity.referenced_areas
if area_id in cast(AutomationEntity, automation_entity).referenced_areas
]
@ -216,12 +225,12 @@ def areas_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
component: EntityComponent = hass.data[DOMAIN]
if (automation_entity := component.get_entity(entity_id)) is None:
return []
return list(automation_entity.referenced_areas)
return list(cast(AutomationEntity, automation_entity).referenced_areas)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
@ -238,7 +247,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
if not await _async_process_config(hass, config, component):
await async_get_blueprints(hass).async_populate()
async def trigger_service_handler(entity, service_call):
async def trigger_service_handler(
entity: AutomationEntity, service_call: ServiceCall
) -> None:
"""Handle forced automation trigger, e.g. from frontend."""
await entity.async_trigger(
{**service_call.data[ATTR_VARIABLES], "trigger": {"platform": None}},
@ -262,7 +273,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"async_turn_off",
)
async def reload_service_handler(service_call):
async def reload_service_handler(service_call: ServiceCall) -> None:
"""Remove all automations and load new ones from config."""
if (conf := await component.async_prepare_reload()) is None:
return
@ -290,22 +301,22 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
def __init__(
self,
automation_id,
name,
trigger_config,
cond_func,
action_script,
initial_state,
variables,
trigger_variables,
raw_config,
blueprint_inputs,
trace_config,
):
automation_id: str | None,
name: str,
trigger_config: list[ConfigType],
cond_func: IfAction | None,
action_script: Script,
initial_state: bool | None,
variables: ScriptVariables | None,
trigger_variables: ScriptVariables | None,
raw_config: ConfigType | None,
blueprint_inputs: ConfigType | None,
trace_config: ConfigType,
) -> None:
"""Initialize an automation entity."""
self._attr_name = name
self._trigger_config = trigger_config
self._async_detach_triggers = None
self._async_detach_triggers: CALLBACK_TYPE | None = None
self._cond_func = cond_func
self.action_script = action_script
self.action_script.change_listener = self.async_write_ha_state
@ -314,15 +325,15 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._referenced_entities: set[str] | None = None
self._referenced_devices: set[str] | None = None
self._logger = LOGGER
self._variables: ScriptVariables = variables
self._trigger_variables: ScriptVariables = trigger_variables
self._variables = variables
self._trigger_variables = trigger_variables
self._raw_config = raw_config
self._blueprint_inputs = blueprint_inputs
self._trace_config = trace_config
self._attr_unique_id = automation_id
@property
def extra_state_attributes(self):
def extra_state_attributes(self) -> dict[str, Any]:
"""Return the entity state attributes."""
attrs = {
ATTR_LAST_TRIGGERED: self.action_script.last_triggered,
@ -341,12 +352,12 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
return self._async_detach_triggers is not None or self._is_enabled
@property
def referenced_areas(self):
def referenced_areas(self) -> set[str]:
"""Return a set of referenced areas."""
return self.action_script.referenced_areas
@property
def referenced_devices(self):
def referenced_devices(self) -> set[str]:
"""Return a set of referenced devices."""
if self._referenced_devices is not None:
return self._referenced_devices
@ -364,7 +375,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
return referenced
@property
def referenced_entities(self):
def referenced_entities(self) -> set[str]:
"""Return a set of referenced entities."""
if self._referenced_entities is not None:
return self._referenced_entities
@ -513,7 +524,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
event_data[ATTR_SOURCE] = variables["trigger"]["description"]
@callback
def started_action():
def started_action() -> None:
self.hass.bus.async_fire(
EVENT_AUTOMATION_TRIGGERED, event_data, context=trigger_context
)
@ -555,12 +566,12 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._logger.exception("While executing automation %s", self.entity_id)
automation_trace.set_error(err)
async def async_will_remove_from_hass(self):
async def async_will_remove_from_hass(self) -> None:
"""Remove listeners when removing automation from Home Assistant."""
await super().async_will_remove_from_hass()
await self.async_disable()
async def async_enable(self):
async def async_enable(self) -> None:
"""Enable this automation entity.
This method is a coroutine.
@ -576,7 +587,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self.async_write_ha_state()
return
async def async_enable_automation(event):
async def async_enable_automation(event: Event) -> None:
"""Start automation on startup."""
# Don't do anything if no longer enabled or already attached
if not self._is_enabled or self._async_detach_triggers is not None:
@ -589,7 +600,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
)
self.async_write_ha_state()
async def async_disable(self, stop_actions=DEFAULT_STOP_ACTIONS):
async def async_disable(self, stop_actions: bool = DEFAULT_STOP_ACTIONS) -> None:
"""Disable the automation entity."""
if not self._is_enabled and not self.action_script.runs:
return
@ -610,7 +621,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
) -> Callable[[], None] | None:
"""Set up the triggers."""
def log_cb(level, msg, **kwargs):
def log_cb(level: int, msg: str, **kwargs: Any) -> None:
self._logger.log(level, "%s %s", msg, self.name, **kwargs)
this = None
@ -650,7 +661,7 @@ async def _async_process_config(
Returns if blueprints were used.
"""
entities = []
entities: list[AutomationEntity] = []
blueprints_used = False
for config_key in extract_domain_configs(config, DOMAIN):
@ -681,10 +692,10 @@ async def _async_process_config(
else:
raw_config = cast(AutomationConfig, config_block).raw_config
automation_id = config_block.get(CONF_ID)
name = config_block.get(CONF_ALIAS) or f"{config_key} {list_no}"
automation_id: str | None = config_block.get(CONF_ID)
name: str = config_block.get(CONF_ALIAS) or f"{config_key} {list_no}"
initial_state = config_block.get(CONF_INITIAL_STATE)
initial_state: bool | None = config_block.get(CONF_INITIAL_STATE)
action_script = Script(
hass,
@ -743,11 +754,13 @@ async def _async_process_config(
return blueprints_used
async def _async_process_if(hass, name, config, p_config):
async def _async_process_if(
hass: HomeAssistant, name: str, config: dict[str, Any], p_config: dict[str, Any]
) -> IfAction | None:
"""Process if checks."""
if_configs = p_config[CONF_CONDITION]
checks = []
checks: list[condition.ConditionCheckerType] = []
for if_config in if_configs:
try:
checks.append(await condition.async_from_config(hass, if_config))
@ -755,9 +768,9 @@ async def _async_process_if(hass, name, config, p_config):
LOGGER.warning("Invalid condition: %s", ex)
return None
def if_action(variables=None):
def if_action(variables: Mapping[str, Any] | None = None) -> bool:
"""AND all conditions."""
errors = []
errors: list[ConditionErrorIndex] = []
for index, check in enumerate(checks):
try:
with trace_path(["condition", str(index)]):
@ -780,9 +793,10 @@ async def _async_process_if(hass, name, config, p_config):
return True
if_action.config = if_configs
result: IfAction = if_action # type: ignore[assignment]
result.config = if_configs
return if_action
return result
@callback
@ -800,7 +814,7 @@ def _trigger_extract_devices(trigger_conf: dict) -> list[str]:
return [trigger_conf[CONF_EVENT_DATA][CONF_DEVICE_ID]]
if trigger_conf[CONF_PLATFORM] == "tag" and CONF_DEVICE_ID in trigger_conf:
return trigger_conf[CONF_DEVICE_ID]
return trigger_conf[CONF_DEVICE_ID] # type: ignore[no-any-return]
return []
@ -809,13 +823,13 @@ def _trigger_extract_devices(trigger_conf: dict) -> list[str]:
def _trigger_extract_entities(trigger_conf: dict) -> list[str]:
"""Extract entities from a trigger config."""
if trigger_conf[CONF_PLATFORM] in ("state", "numeric_state"):
return trigger_conf[CONF_ENTITY_ID]
return trigger_conf[CONF_ENTITY_ID] # type: ignore[no-any-return]
if trigger_conf[CONF_PLATFORM] == "calendar":
return [trigger_conf[CONF_ENTITY_ID]]
if trigger_conf[CONF_PLATFORM] == "zone":
return trigger_conf[CONF_ENTITY_ID] + [trigger_conf[CONF_ZONE]]
return trigger_conf[CONF_ENTITY_ID] + [trigger_conf[CONF_ZONE]] # type: ignore[no-any-return]
if trigger_conf[CONF_PLATFORM] == "geo_location":
return [trigger_conf[CONF_ZONE]]

View File

@ -1,6 +1,9 @@
"""Config validation helper for the automation integration."""
from __future__ import annotations
import asyncio
from contextlib import suppress
from typing import Any
import voluptuous as vol
@ -17,10 +20,12 @@ from homeassistant.const import (
CONF_ID,
CONF_VARIABLES,
)
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform, config_validation as cv, script
from homeassistant.helpers.condition import async_validate_conditions_config
from homeassistant.helpers.trigger import async_validate_trigger_config
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import IntegrationNotFound
from .const import (
@ -34,9 +39,6 @@ from .const import (
)
from .helpers import async_get_blueprints
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
PACKAGE_MERGE_HINT = "list"
_CONDITION_SCHEMA = vol.All(cv.ensure_list, [cv.CONDITION_SCHEMA])
@ -63,7 +65,11 @@ PLATFORM_SCHEMA = vol.All(
)
async def async_validate_config_item(hass, config, full_config=None):
async def async_validate_config_item(
hass: HomeAssistant,
config: ConfigType,
full_config: ConfigType | None = None,
) -> blueprint.BlueprintInputs | dict[str, Any]:
"""Validate config item."""
if blueprint.is_blueprint_instance_config(config):
blueprints = async_get_blueprints(hass)
@ -90,17 +96,21 @@ async def async_validate_config_item(hass, config, full_config=None):
class AutomationConfig(dict):
"""Dummy class to allow adding attributes."""
raw_config = None
raw_config: dict[str, Any] | None = None
async def _try_async_validate_config_item(hass, config, full_config=None):
async def _try_async_validate_config_item(
hass: HomeAssistant,
config: dict[str, Any],
full_config: dict[str, Any] | None = None,
) -> AutomationConfig | blueprint.BlueprintInputs | None:
"""Validate config item."""
raw_config = None
with suppress(ValueError):
raw_config = dict(config)
try:
config = await async_validate_config_item(hass, config, full_config)
validated_config = await async_validate_config_item(hass, config, full_config)
except (
vol.Invalid,
HomeAssistantError,
@ -110,15 +120,15 @@ async def _try_async_validate_config_item(hass, config, full_config=None):
async_log_exception(ex, DOMAIN, full_config or config, hass)
return None
if isinstance(config, blueprint.BlueprintInputs):
return config
if isinstance(validated_config, blueprint.BlueprintInputs):
return validated_config
config = AutomationConfig(config)
config.raw_config = raw_config
return config
automation_config = AutomationConfig(validated_config)
automation_config.raw_config = raw_config
return automation_config
async def async_validate_config(hass, config):
async def async_validate_config(hass: HomeAssistant, config: ConfigType) -> ConfigType:
"""Validate config."""
automations = list(
filter(

View File

@ -1,6 +1,7 @@
"""Trace support for automation."""
from __future__ import annotations
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
@ -9,13 +10,11 @@ from homeassistant.components.trace import (
ActionTrace,
async_store_trace,
)
from homeassistant.core import Context
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
class AutomationTrace(ActionTrace):
"""Container for automation trace."""
@ -24,9 +23,9 @@ class AutomationTrace(ActionTrace):
def __init__(
self,
item_id: str,
config: dict[str, Any],
blueprint_inputs: dict[str, Any],
item_id: str | None,
config: ConfigType | None,
blueprint_inputs: ConfigType | None,
context: Context,
) -> None:
"""Container for automation trace."""
@ -49,8 +48,13 @@ class AutomationTrace(ActionTrace):
@contextmanager
def trace_automation(
hass, automation_id, config, blueprint_inputs, context, trace_config
):
hass: HomeAssistant,
automation_id: str | None,
config: ConfigType | None,
blueprint_inputs: ConfigType | None,
context: Context,
trace_config: ConfigType,
) -> Generator[AutomationTrace, None, None]:
"""Trace action execution of automation with automation_id."""
trace = AutomationTrace(automation_id, config, blueprint_inputs, context)
async_store_trace(hass, trace, trace_config[CONF_STORED_TRACES])