Improve type hints in script helpers (#78364)
* Improve type hints in script helpers * Import CONF_SERVICE_DATA from homeassistant.const * Make data optionalpull/78423/head
parent
4f963cfc64
commit
d3be06906b
|
@ -2,7 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
from contextvars import ContextVar
|
||||
from copy import copy
|
||||
|
@ -49,6 +49,7 @@ from homeassistant.const import (
|
|||
CONF_SCENE,
|
||||
CONF_SEQUENCE,
|
||||
CONF_SERVICE,
|
||||
CONF_SERVICE_DATA,
|
||||
CONF_STOP,
|
||||
CONF_TARGET,
|
||||
CONF_THEN,
|
||||
|
@ -218,7 +219,9 @@ async def trace_action(hass, script_run, stop, variables):
|
|||
trace_stack_pop(trace_stack_cv)
|
||||
|
||||
|
||||
def make_script_schema(schema, default_script_mode, extra=vol.PREVENT_EXTRA):
|
||||
def make_script_schema(
|
||||
schema: Mapping[Any, Any], default_script_mode: str, extra: int = vol.PREVENT_EXTRA
|
||||
) -> vol.Schema:
|
||||
"""Make a schema for a component that uses the script helper."""
|
||||
return vol.Schema(
|
||||
{
|
||||
|
@ -1109,7 +1112,9 @@ async def _async_stop_scripts_at_shutdown(hass, event):
|
|||
_VarsType = Union[dict[str, Any], MappingProxyType]
|
||||
|
||||
|
||||
def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) -> None:
|
||||
def _referenced_extract_ids(
|
||||
data: dict[str, Any] | None, key: str, found: set[str]
|
||||
) -> None:
|
||||
"""Extract referenced IDs."""
|
||||
if not data:
|
||||
return
|
||||
|
@ -1275,24 +1280,26 @@ class Script:
|
|||
return self.script_mode in (SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUED)
|
||||
|
||||
@property
|
||||
def referenced_areas(self):
|
||||
def referenced_areas(self) -> set[str]:
|
||||
"""Return a set of referenced areas."""
|
||||
if self._referenced_areas is not None:
|
||||
return self._referenced_areas
|
||||
|
||||
self._referenced_areas: set[str] = set()
|
||||
self._referenced_areas = set()
|
||||
Script._find_referenced_areas(self._referenced_areas, self.sequence)
|
||||
return self._referenced_areas
|
||||
|
||||
@staticmethod
|
||||
def _find_referenced_areas(referenced, sequence):
|
||||
def _find_referenced_areas(
|
||||
referenced: set[str], sequence: Sequence[dict[str, Any]]
|
||||
) -> None:
|
||||
for step in sequence:
|
||||
action = cv.determine_script_action(step)
|
||||
|
||||
if action == cv.SCRIPT_ACTION_CALL_SERVICE:
|
||||
for data in (
|
||||
step.get(CONF_TARGET),
|
||||
step.get(service.CONF_SERVICE_DATA),
|
||||
step.get(CONF_SERVICE_DATA),
|
||||
step.get(service.CONF_SERVICE_DATA_TEMPLATE),
|
||||
):
|
||||
_referenced_extract_ids(data, ATTR_AREA_ID, referenced)
|
||||
|
@ -1313,24 +1320,26 @@ class Script:
|
|||
Script._find_referenced_areas(referenced, script[CONF_SEQUENCE])
|
||||
|
||||
@property
|
||||
def referenced_devices(self):
|
||||
def referenced_devices(self) -> set[str]:
|
||||
"""Return a set of referenced devices."""
|
||||
if self._referenced_devices is not None:
|
||||
return self._referenced_devices
|
||||
|
||||
self._referenced_devices: set[str] = set()
|
||||
self._referenced_devices = set()
|
||||
Script._find_referenced_devices(self._referenced_devices, self.sequence)
|
||||
return self._referenced_devices
|
||||
|
||||
@staticmethod
|
||||
def _find_referenced_devices(referenced, sequence):
|
||||
def _find_referenced_devices(
|
||||
referenced: set[str], sequence: Sequence[dict[str, Any]]
|
||||
) -> None:
|
||||
for step in sequence:
|
||||
action = cv.determine_script_action(step)
|
||||
|
||||
if action == cv.SCRIPT_ACTION_CALL_SERVICE:
|
||||
for data in (
|
||||
step.get(CONF_TARGET),
|
||||
step.get(service.CONF_SERVICE_DATA),
|
||||
step.get(CONF_SERVICE_DATA),
|
||||
step.get(service.CONF_SERVICE_DATA_TEMPLATE),
|
||||
):
|
||||
_referenced_extract_ids(data, ATTR_DEVICE_ID, referenced)
|
||||
|
@ -1361,17 +1370,19 @@ class Script:
|
|||
Script._find_referenced_devices(referenced, script[CONF_SEQUENCE])
|
||||
|
||||
@property
|
||||
def referenced_entities(self):
|
||||
def referenced_entities(self) -> set[str]:
|
||||
"""Return a set of referenced entities."""
|
||||
if self._referenced_entities is not None:
|
||||
return self._referenced_entities
|
||||
|
||||
self._referenced_entities: set[str] = set()
|
||||
self._referenced_entities = set()
|
||||
Script._find_referenced_entities(self._referenced_entities, self.sequence)
|
||||
return self._referenced_entities
|
||||
|
||||
@staticmethod
|
||||
def _find_referenced_entities(referenced, sequence):
|
||||
def _find_referenced_entities(
|
||||
referenced: set[str], sequence: Sequence[dict[str, Any]]
|
||||
) -> None:
|
||||
for step in sequence:
|
||||
action = cv.determine_script_action(step)
|
||||
|
||||
|
@ -1379,7 +1390,7 @@ class Script:
|
|||
for data in (
|
||||
step,
|
||||
step.get(CONF_TARGET),
|
||||
step.get(service.CONF_SERVICE_DATA),
|
||||
step.get(CONF_SERVICE_DATA),
|
||||
step.get(service.CONF_SERVICE_DATA_TEMPLATE),
|
||||
):
|
||||
_referenced_extract_ids(data, ATTR_ENTITY_ID, referenced)
|
||||
|
|
Loading…
Reference in New Issue