Improve type hints in script helpers (#78364)

* Improve type hints in script helpers

* Import CONF_SERVICE_DATA from homeassistant.const

* Make data optional
pull/78423/head
epenet 2022-09-13 23:11:29 +02:00 committed by GitHub
parent 4f963cfc64
commit d3be06906b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 26 additions and 15 deletions

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Sequence
from collections.abc import Callable, Mapping, Sequence
from contextlib import asynccontextmanager, suppress
from contextvars import ContextVar
from copy import copy
@ -49,6 +49,7 @@ from homeassistant.const import (
CONF_SCENE,
CONF_SEQUENCE,
CONF_SERVICE,
CONF_SERVICE_DATA,
CONF_STOP,
CONF_TARGET,
CONF_THEN,
@ -218,7 +219,9 @@ async def trace_action(hass, script_run, stop, variables):
trace_stack_pop(trace_stack_cv)
def make_script_schema(schema, default_script_mode, extra=vol.PREVENT_EXTRA):
def make_script_schema(
schema: Mapping[Any, Any], default_script_mode: str, extra: int = vol.PREVENT_EXTRA
) -> vol.Schema:
"""Make a schema for a component that uses the script helper."""
return vol.Schema(
{
@ -1109,7 +1112,9 @@ async def _async_stop_scripts_at_shutdown(hass, event):
_VarsType = Union[dict[str, Any], MappingProxyType]
def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) -> None:
def _referenced_extract_ids(
data: dict[str, Any] | None, key: str, found: set[str]
) -> None:
"""Extract referenced IDs."""
if not data:
return
@ -1275,24 +1280,26 @@ class Script:
return self.script_mode in (SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUED)
@property
def referenced_areas(self):
def referenced_areas(self) -> set[str]:
"""Return a set of referenced areas."""
if self._referenced_areas is not None:
return self._referenced_areas
self._referenced_areas: set[str] = set()
self._referenced_areas = set()
Script._find_referenced_areas(self._referenced_areas, self.sequence)
return self._referenced_areas
@staticmethod
def _find_referenced_areas(referenced, sequence):
def _find_referenced_areas(
referenced: set[str], sequence: Sequence[dict[str, Any]]
) -> None:
for step in sequence:
action = cv.determine_script_action(step)
if action == cv.SCRIPT_ACTION_CALL_SERVICE:
for data in (
step.get(CONF_TARGET),
step.get(service.CONF_SERVICE_DATA),
step.get(CONF_SERVICE_DATA),
step.get(service.CONF_SERVICE_DATA_TEMPLATE),
):
_referenced_extract_ids(data, ATTR_AREA_ID, referenced)
@ -1313,24 +1320,26 @@ class Script:
Script._find_referenced_areas(referenced, script[CONF_SEQUENCE])
@property
def referenced_devices(self):
def referenced_devices(self) -> set[str]:
"""Return a set of referenced devices."""
if self._referenced_devices is not None:
return self._referenced_devices
self._referenced_devices: set[str] = set()
self._referenced_devices = set()
Script._find_referenced_devices(self._referenced_devices, self.sequence)
return self._referenced_devices
@staticmethod
def _find_referenced_devices(referenced, sequence):
def _find_referenced_devices(
referenced: set[str], sequence: Sequence[dict[str, Any]]
) -> None:
for step in sequence:
action = cv.determine_script_action(step)
if action == cv.SCRIPT_ACTION_CALL_SERVICE:
for data in (
step.get(CONF_TARGET),
step.get(service.CONF_SERVICE_DATA),
step.get(CONF_SERVICE_DATA),
step.get(service.CONF_SERVICE_DATA_TEMPLATE),
):
_referenced_extract_ids(data, ATTR_DEVICE_ID, referenced)
@ -1361,17 +1370,19 @@ class Script:
Script._find_referenced_devices(referenced, script[CONF_SEQUENCE])
@property
def referenced_entities(self):
def referenced_entities(self) -> set[str]:
"""Return a set of referenced entities."""
if self._referenced_entities is not None:
return self._referenced_entities
self._referenced_entities: set[str] = set()
self._referenced_entities = set()
Script._find_referenced_entities(self._referenced_entities, self.sequence)
return self._referenced_entities
@staticmethod
def _find_referenced_entities(referenced, sequence):
def _find_referenced_entities(
referenced: set[str], sequence: Sequence[dict[str, Any]]
) -> None:
for step in sequence:
action = cv.determine_script_action(step)
@ -1379,7 +1390,7 @@ class Script:
for data in (
step,
step.get(CONF_TARGET),
step.get(service.CONF_SERVICE_DATA),
step.get(CONF_SERVICE_DATA),
step.get(service.CONF_SERVICE_DATA_TEMPLATE),
):
_referenced_extract_ids(data, ATTR_ENTITY_ID, referenced)