From d3be06906bec9ec61c01958893c02ca1b17b4099 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Tue, 13 Sep 2022 23:11:29 +0200 Subject: [PATCH] Improve type hints in script helpers (#78364) * Improve type hints in script helpers * Import CONF_SERVICE_DATA from homeassistant.const * Make data optional --- homeassistant/helpers/script.py | 41 +++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 54ae4f456ab..e472934fc76 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from contextlib import asynccontextmanager, suppress from contextvars import ContextVar from copy import copy @@ -49,6 +49,7 @@ from homeassistant.const import ( CONF_SCENE, CONF_SEQUENCE, CONF_SERVICE, + CONF_SERVICE_DATA, CONF_STOP, CONF_TARGET, CONF_THEN, @@ -218,7 +219,9 @@ async def trace_action(hass, script_run, stop, variables): trace_stack_pop(trace_stack_cv) -def make_script_schema(schema, default_script_mode, extra=vol.PREVENT_EXTRA): +def make_script_schema( + schema: Mapping[Any, Any], default_script_mode: str, extra: int = vol.PREVENT_EXTRA +) -> vol.Schema: """Make a schema for a component that uses the script helper.""" return vol.Schema( { @@ -1109,7 +1112,9 @@ async def _async_stop_scripts_at_shutdown(hass, event): _VarsType = Union[dict[str, Any], MappingProxyType] -def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) -> None: +def _referenced_extract_ids( + data: dict[str, Any] | None, key: str, found: set[str] +) -> None: """Extract referenced IDs.""" if not data: return @@ -1275,24 +1280,26 @@ class Script: return self.script_mode in (SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUED) @property - def referenced_areas(self): + def referenced_areas(self) -> set[str]: """Return a set of referenced areas.""" if self._referenced_areas is not None: return self._referenced_areas - self._referenced_areas: set[str] = set() + self._referenced_areas = set() Script._find_referenced_areas(self._referenced_areas, self.sequence) return self._referenced_areas @staticmethod - def _find_referenced_areas(referenced, sequence): + def _find_referenced_areas( + referenced: set[str], sequence: Sequence[dict[str, Any]] + ) -> None: for step in sequence: action = cv.determine_script_action(step) if action == cv.SCRIPT_ACTION_CALL_SERVICE: for data in ( step.get(CONF_TARGET), - step.get(service.CONF_SERVICE_DATA), + step.get(CONF_SERVICE_DATA), step.get(service.CONF_SERVICE_DATA_TEMPLATE), ): _referenced_extract_ids(data, ATTR_AREA_ID, referenced) @@ -1313,24 +1320,26 @@ class Script: Script._find_referenced_areas(referenced, script[CONF_SEQUENCE]) @property - def referenced_devices(self): + def referenced_devices(self) -> set[str]: """Return a set of referenced devices.""" if self._referenced_devices is not None: return self._referenced_devices - self._referenced_devices: set[str] = set() + self._referenced_devices = set() Script._find_referenced_devices(self._referenced_devices, self.sequence) return self._referenced_devices @staticmethod - def _find_referenced_devices(referenced, sequence): + def _find_referenced_devices( + referenced: set[str], sequence: Sequence[dict[str, Any]] + ) -> None: for step in sequence: action = cv.determine_script_action(step) if action == cv.SCRIPT_ACTION_CALL_SERVICE: for data in ( step.get(CONF_TARGET), - step.get(service.CONF_SERVICE_DATA), + step.get(CONF_SERVICE_DATA), step.get(service.CONF_SERVICE_DATA_TEMPLATE), ): _referenced_extract_ids(data, ATTR_DEVICE_ID, referenced) @@ -1361,17 +1370,19 @@ class Script: Script._find_referenced_devices(referenced, script[CONF_SEQUENCE]) @property - def referenced_entities(self): + def referenced_entities(self) -> set[str]: """Return a set of referenced entities.""" if self._referenced_entities is not None: return self._referenced_entities - self._referenced_entities: set[str] = set() + self._referenced_entities = set() Script._find_referenced_entities(self._referenced_entities, self.sequence) return self._referenced_entities @staticmethod - def _find_referenced_entities(referenced, sequence): + def _find_referenced_entities( + referenced: set[str], sequence: Sequence[dict[str, Any]] + ) -> None: for step in sequence: action = cv.determine_script_action(step) @@ -1379,7 +1390,7 @@ class Script: for data in ( step, step.get(CONF_TARGET), - step.get(service.CONF_SERVICE_DATA), + step.get(CONF_SERVICE_DATA), step.get(service.CONF_SERVICE_DATA_TEMPLATE), ): _referenced_extract_ids(data, ATTR_ENTITY_ID, referenced)