Type hint improvements (#28260)

* Add and improve core and config_entries type hints

* Complete and improve config_entries type hints

* More entity registry type hints

* Complete helpers.event type hints
pull/28310/head
Ville Skyttä 2019-10-28 22:36:26 +02:00 committed by Paulus Schoutsen
parent f7a64019b6
commit f88ead597a
9 changed files with 135 additions and 89 deletions

View File

@ -16,7 +16,7 @@ from homeassistant.const import (
STATE_ON,
STATE_UNAVAILABLE,
)
from homeassistant.core import State, callback
from homeassistant.core import CALLBACK_TYPE, State, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import async_track_state_change
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
@ -96,7 +96,7 @@ class LightGroup(light.Light):
self._effect_list: Optional[List[str]] = None
self._effect: Optional[str] = None
self._supported_features: int = 0
self._async_unsub_state_changed = None
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
@ -108,6 +108,7 @@ class LightGroup(light.Light):
"""Handle child updates."""
self.async_schedule_update_ha_state(True)
assert self.hass is not None
self._async_unsub_state_changed = async_track_state_change(
self.hass, self._entity_ids, async_state_changed_listener
)

View File

@ -12,7 +12,7 @@ from homeassistant.const import (
STATE_ON,
STATE_UNAVAILABLE,
)
from homeassistant.core import State, callback
from homeassistant.core import CALLBACK_TYPE, State, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_track_state_change
@ -56,7 +56,7 @@ class LightSwitch(Light):
self._switch_entity_id = switch_entity_id
self._is_on = False
self._available = False
self._async_unsub_state_changed = None
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
@property
def name(self) -> str:
@ -113,6 +113,7 @@ class LightSwitch(Light):
"""Handle child updates."""
self.async_schedule_update_ha_state(True)
assert self.hass is not None
self._async_unsub_state_changed = async_track_state_change(
self.hass, self._switch_entity_id, async_state_changed_listener
)

View File

@ -3,7 +3,7 @@ import asyncio
import logging
import functools
import uuid
from typing import Any, Callable, List, Optional, Set
from typing import Any, Callable, Dict, List, Optional, Set, cast
import weakref
import attr
@ -14,11 +14,11 @@ from homeassistant.exceptions import HomeAssistantError, ConfigEntryNotReady
from homeassistant.setup import async_setup_component, async_process_deps_reqs
from homeassistant.util.decorator import Registry
from homeassistant.helpers import entity_registry
from homeassistant.helpers.event import Event
# mypy: allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__)
_UNDEF = object()
_UNDEF: dict = {}
SOURCE_USER = "user"
SOURCE_DISCOVERY = "discovery"
@ -205,7 +205,7 @@ class ConfigEntry:
wait_time,
)
async def setup_again(now):
async def setup_again(now: Any) -> None:
"""Run setup again."""
self._async_cancel_retry_setup = None
await self.async_setup(hass, integration=integration, tries=tries)
@ -357,7 +357,7 @@ class ConfigEntry:
return lambda: self.update_listeners.remove(weak_listener)
def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
"""Return dictionary version of this entry."""
return {
"entry_id": self.entry_id,
@ -418,7 +418,7 @@ class ConfigEntries:
return list(self._entries)
return [entry for entry in self._entries if entry.domain == domain]
async def async_remove(self, entry_id):
async def async_remove(self, entry_id: str) -> Dict[str, Any]:
"""Remove an entry."""
entry = self.async_get_entry(entry_id)
@ -529,8 +529,13 @@ class ConfigEntries:
@callback
def async_update_entry(
self, entry, *, data=_UNDEF, options=_UNDEF, system_options=_UNDEF
):
self,
entry: ConfigEntry,
*,
data: dict = _UNDEF,
options: dict = _UNDEF,
system_options: dict = _UNDEF,
) -> None:
"""Update a config entry."""
if data is not _UNDEF:
entry.data = data
@ -547,7 +552,7 @@ class ConfigEntries:
self._async_schedule_save()
async def async_forward_entry_setup(self, entry, domain):
async def async_forward_entry_setup(self, entry: ConfigEntry, domain: str) -> bool:
"""Forward the setup of an entry to a different component.
By default an entry is setup with the component it belongs to. If that
@ -567,8 +572,9 @@ class ConfigEntries:
integration = await loader.async_get_integration(self.hass, domain)
await entry.async_setup(self.hass, integration=integration)
return True
async def async_forward_entry_unload(self, entry, domain):
async def async_forward_entry_unload(self, entry: ConfigEntry, domain: str) -> bool:
"""Forward the unloading of an entry to a different component."""
# It was never loaded.
if domain not in self.hass.config.components:
@ -578,7 +584,9 @@ class ConfigEntries:
return await entry.async_unload(self.hass, integration=integration)
async def _async_finish_flow(self, flow, result):
async def _async_finish_flow(
self, flow: "ConfigFlow", result: Dict[str, Any]
) -> Dict[str, Any]:
"""Finish a config flow and add an entry."""
# Remove notification if no other discovery config entries in progress
if not any(
@ -611,7 +619,9 @@ class ConfigEntries:
result["result"] = entry
return result
async def _async_create_flow(self, handler_key, *, context, data):
async def _async_create_flow(
self, handler_key: str, *, context: Dict[str, Any], data: Dict[str, Any]
) -> "ConfigFlow":
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to set up.
@ -654,7 +664,7 @@ class ConfigEntries:
notification_id=DISCOVERY_NOTIFICATION_ID,
)
flow = handler()
flow = cast(ConfigFlow, handler())
flow.init_step = source
return flow
@ -663,12 +673,12 @@ class ConfigEntries:
self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback
def _data_to_save(self):
def _data_to_save(self) -> Dict[str, List[Dict[str, Any]]]:
"""Return data to save."""
return {"entries": [entry.as_dict() for entry in self._entries]}
async def _old_conf_migrator(old_config):
async def _old_conf_migrator(old_config: Dict[str, Any]) -> Dict[str, Any]:
"""Migrate the pre-0.73 config format to the latest version."""
return {"entries": old_config}
@ -686,18 +696,20 @@ class ConfigFlow(data_entry_flow.FlowHandler):
@staticmethod
@callback
def async_get_options_flow(config_entry):
def async_get_options_flow(config_entry: ConfigEntry) -> "OptionsFlow":
"""Get the options flow for this handler."""
raise data_entry_flow.UnknownHandler
@callback
def _async_current_entries(self):
def _async_current_entries(self) -> List[ConfigEntry]:
"""Return current entries."""
assert self.hass is not None
return self.hass.config_entries.async_entries(self.handler)
@callback
def _async_in_progress(self):
def _async_in_progress(self) -> List[Dict]:
"""Return other in progress flows for current domain."""
assert self.hass is not None
return [
flw
for flw in self.hass.config_entries.flow.async_progress()
@ -715,29 +727,33 @@ class OptionsFlowManager:
hass, self._async_create_flow, self._async_finish_flow
)
async def _async_create_flow(self, entry_id, *, context, data):
async def _async_create_flow(
self, entry_id: str, *, context: Dict[str, Any], data: Dict[str, Any]
) -> Optional["OptionsFlow"]:
"""Create an options flow for a config entry.
Entry_id and flow.handler is the same thing to map entry with flow.
"""
entry = self.hass.config_entries.async_get_entry(entry_id)
if entry is None:
return
return None
if entry.domain not in HANDLERS:
raise data_entry_flow.UnknownHandler
flow = HANDLERS[entry.domain].async_get_options_flow(entry)
flow = cast(OptionsFlow, HANDLERS[entry.domain].async_get_options_flow(entry))
return flow
async def _async_finish_flow(self, flow, result):
async def _async_finish_flow(
self, flow: "OptionsFlow", result: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""Finish an options flow and update options for configuration entry.
Flow.handler and entry_id is the same thing to map flow with entry.
"""
entry = self.hass.config_entries.async_get_entry(flow.handler)
if entry is None:
return
return None
self.hass.config_entries.async_update_entry(entry, options=result["data"])
result["result"] = True
@ -747,7 +763,7 @@ class OptionsFlowManager:
class OptionsFlow(data_entry_flow.FlowHandler):
"""Base class for config option flows."""
pass
handler: str
@attr.s(slots=True)
@ -756,11 +772,11 @@ class SystemOptions:
disable_new_entities = attr.ib(type=bool, default=False)
def update(self, *, disable_new_entities):
def update(self, *, disable_new_entities: bool) -> None:
"""Update properties."""
self.disable_new_entities = disable_new_entities
def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
"""Return dictionary version of this config entrys system options."""
return {"disable_new_entities": self.disable_new_entities}
@ -784,7 +800,7 @@ class EntityRegistryDisabledHandler:
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._handle_entry_updated
)
async def _handle_entry_updated(self, event):
async def _handle_entry_updated(self, event: Event) -> None:
"""Handle entity registry entry update."""
if (
event.data["action"] != "update"
@ -811,6 +827,7 @@ class EntityRegistryDisabledHandler:
config_entry = self.hass.config_entries.async_get_entry(
entity_entry.config_entry_id
)
assert config_entry is not None
if config_entry.entry_id not in self.changed and await support_entry_unload(
self.hass, config_entry.domain
@ -830,7 +847,7 @@ class EntityRegistryDisabledHandler:
self.RELOAD_AFTER_UPDATE_DELAY, self._handle_reload
)
async def _handle_reload(self, _now):
async def _handle_reload(self, _now: Any) -> None:
"""Handle a reload."""
self._remove_call_later = None
to_reload = self.changed

View File

@ -1283,7 +1283,7 @@ class Config:
self.skip_pip: bool = False
# List of loaded components
self.components: set = set()
self.components: Set[str] = set()
# API (HTTP) server configuration, see components.http.ApiConfig
self.api: Optional[Any] = None

View File

@ -1,6 +1,6 @@
"""Classes to help gather user submissions."""
import logging
from typing import Dict, Any, Callable, Hashable, List, Optional
from typing import Dict, Any, Callable, List, Optional
import uuid
import voluptuous as vol
from .core import callback, HomeAssistant
@ -58,7 +58,7 @@ class FlowManager:
]
async def async_init(
self, handler: Hashable, *, context: Optional[Dict] = None, data: Any = None
self, handler: str, *, context: Optional[Dict] = None, data: Any = None
) -> Any:
"""Start a configuration flow."""
if context is None:
@ -170,7 +170,7 @@ class FlowHandler:
# Set by flow manager
flow_id: str = None # type: ignore
hass: Optional[HomeAssistant] = None
handler: Optional[Hashable] = None
handler: Optional[str] = None
cur_step: Optional[Dict[str, str]] = None
context: Dict

View File

@ -399,7 +399,7 @@ class OAuth2Session:
new_token = await self.implementation.async_refresh_token(token)
self.hass.config_entries.async_update_entry( # type: ignore
self.hass.config_entries.async_update_entry(
self.config_entry, data={**self.config_entry.data, "token": new_token}
)

View File

@ -7,15 +7,15 @@ The Entity Registry will persist itself 10 seconds after a new entity is
registered. Registering a new entity while a timer is in progress resets the
timer.
"""
from asyncio import Event
import asyncio
from collections import OrderedDict
from itertools import chain
import logging
from typing import List, Optional, cast
from typing import Any, Dict, Iterable, List, Optional, cast
import attr
from homeassistant.core import callback, split_entity_id, valid_entity_id
from homeassistant.core import Event, callback, split_entity_id, valid_entity_id
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.loader import bind_hass
from homeassistant.util import ensure_unique_string, slugify
@ -24,8 +24,7 @@ from homeassistant.util.yaml import load_yaml
from .typing import HomeAssistantType
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
# mypy: allow-untyped-defs, no-check-untyped-defs
PATH_REGISTRY = "entity_registry.yaml"
DATA_REGISTRY = "entity_registry"
@ -51,7 +50,7 @@ class RegistryEntry:
platform = attr.ib(type=str)
name = attr.ib(type=str, default=None)
device_id = attr.ib(type=str, default=None)
config_entry_id = attr.ib(type=str, default=None)
config_entry_id: Optional[str] = attr.ib(default=None)
disabled_by = attr.ib(
type=Optional[str],
default=None,
@ -68,12 +67,12 @@ class RegistryEntry:
domain = attr.ib(type=str, init=False, repr=False)
@domain.default
def _domain_default(self):
def _domain_default(self) -> str:
"""Compute domain value."""
return split_entity_id(self.entity_id)[0]
@property
def disabled(self):
def disabled(self) -> bool:
"""Return if entry is disabled."""
return self.disabled_by is not None
@ -81,17 +80,17 @@ class RegistryEntry:
class EntityRegistry:
"""Class to hold a registry of entities."""
def __init__(self, hass):
def __init__(self, hass: HomeAssistantType):
"""Initialize the registry."""
self.hass = hass
self.entities = None
self.entities: Dict[str, RegistryEntry]
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self.hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_removed
)
@callback
def async_is_registered(self, entity_id):
def async_is_registered(self, entity_id: str) -> bool:
"""Check if an entity_id is currently registered."""
return entity_id in self.entities
@ -116,8 +115,11 @@ class EntityRegistry:
@callback
def async_generate_entity_id(
self, domain, suggested_object_id, known_object_ids=None
):
self,
domain: str,
suggested_object_id: str,
known_object_ids: Optional[Iterable[str]] = None,
) -> str:
"""Generate an entity ID that does not conflict.
Conflicts checked against registered and currently existing entities.
@ -195,7 +197,7 @@ class EntityRegistry:
return entity
@callback
def async_remove(self, entity_id):
def async_remove(self, entity_id: str) -> None:
"""Remove an entity from registry."""
self.entities.pop(entity_id)
self.hass.bus.async_fire(
@ -204,7 +206,7 @@ class EntityRegistry:
self.async_schedule_save()
@callback
def async_device_removed(self, event):
def async_device_removed(self, event: Event) -> None:
"""Handle the removal of a device.
Remove entities from the registry that are associated to a device when
@ -309,7 +311,7 @@ class EntityRegistry:
return new
async def async_load(self):
async def async_load(self) -> None:
"""Load the entity registry."""
data = await self.hass.helpers.storage.async_migrator(
self.hass.config.path(PATH_REGISTRY),
@ -317,7 +319,7 @@ class EntityRegistry:
old_conf_load_func=load_yaml,
old_conf_migrate_func=_async_migrate,
)
entities = OrderedDict()
entities: Dict[str, RegistryEntry] = OrderedDict()
if data is not None:
for entity in data["entities"]:
@ -334,12 +336,12 @@ class EntityRegistry:
self.entities = entities
@callback
def async_schedule_save(self):
def async_schedule_save(self) -> None:
"""Schedule saving the entity registry."""
self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback
def _data_to_save(self):
def _data_to_save(self) -> Dict[str, Any]:
"""Return data of entity registry to store in a file."""
data = {}
@ -359,7 +361,7 @@ class EntityRegistry:
return data
@callback
def async_clear_config_entry(self, config_entry):
def async_clear_config_entry(self, config_entry: str) -> None:
"""Clear config entry from registry entries."""
for entity_id in [
entity_id
@ -375,7 +377,7 @@ async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
reg_or_evt = hass.data.get(DATA_REGISTRY)
if not reg_or_evt:
evt = hass.data[DATA_REGISTRY] = Event()
evt = hass.data[DATA_REGISTRY] = asyncio.Event()
reg = EntityRegistry(hass)
await reg.async_load()
@ -384,7 +386,7 @@ async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
evt.set()
return reg
if isinstance(reg_or_evt, Event):
if isinstance(reg_or_evt, asyncio.Event):
evt = reg_or_evt
await evt.wait()
return cast(EntityRegistry, hass.data.get(DATA_REGISTRY))
@ -402,7 +404,7 @@ def async_entries_for_device(
]
async def _async_migrate(entities):
async def _async_migrate(entities: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
"""Migrate the YAML config file to storage helper format."""
return {
"entities": [

View File

@ -1,13 +1,14 @@
"""Helpers for listening to events."""
from datetime import datetime, timedelta
import functools as ft
from typing import Any, Callable, Iterable, Optional, Union
from typing import Any, Callable, Dict, Iterable, Optional, Union, cast
import attr
from homeassistant.loader import bind_hass
from homeassistant.helpers.sun import get_astral_event_next
from homeassistant.core import HomeAssistant, callback, CALLBACK_TYPE, Event
from homeassistant.helpers.template import Template
from homeassistant.core import HomeAssistant, callback, CALLBACK_TYPE, Event, State
from homeassistant.const import (
ATTR_NOW,
EVENT_STATE_CHANGED,
@ -21,16 +22,15 @@ from homeassistant.util import dt as dt_util
from homeassistant.util.async_ import run_callback_threadsafe
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
# PyLint does not like the use of threaded_listener_factory
# pylint: disable=invalid-name
def threaded_listener_factory(async_factory):
def threaded_listener_factory(async_factory: Callable[..., Any]) -> CALLBACK_TYPE:
"""Convert an async event helper to a threaded one."""
@ft.wraps(async_factory)
def factory(*args, **kwargs):
def factory(*args: Any, **kwargs: Any) -> CALLBACK_TYPE:
"""Call async event helper safely."""
hass = args[0]
@ -41,7 +41,7 @@ def threaded_listener_factory(async_factory):
hass.loop, ft.partial(async_factory, *args, **kwargs)
).result()
def remove():
def remove() -> None:
"""Threadsafe removal."""
run_callback_threadsafe(hass.loop, async_remove).result()
@ -52,7 +52,13 @@ def threaded_listener_factory(async_factory):
@callback
@bind_hass
def async_track_state_change(hass, entity_ids, action, from_state=None, to_state=None):
def async_track_state_change(
hass: HomeAssistant,
entity_ids: Union[str, Iterable[str]],
action: Callable[[str, State, State], None],
from_state: Union[None, str, Iterable[str]] = None,
to_state: Union[None, str, Iterable[str]] = None,
) -> CALLBACK_TYPE:
"""Track specific state changes.
entity_ids, from_state and to_state can be string or list.
@ -74,9 +80,12 @@ def async_track_state_change(hass, entity_ids, action, from_state=None, to_state
entity_ids = tuple(entity_id.lower() for entity_id in entity_ids)
@callback
def state_change_listener(event):
def state_change_listener(event: Event) -> None:
"""Handle specific state changes."""
if entity_ids != MATCH_ALL and event.data.get("entity_id") not in entity_ids:
if (
entity_ids != MATCH_ALL
and cast(str, event.data.get("entity_id")) not in entity_ids
):
return
old_state = event.data.get("old_state")
@ -103,7 +112,12 @@ track_state_change = threaded_listener_factory(async_track_state_change)
@callback
@bind_hass
def async_track_template(hass, template, action, variables=None):
def async_track_template(
hass: HomeAssistant,
template: Template,
action: Callable[[str, State, State], None],
variables: Optional[Dict[str, Any]] = None,
) -> CALLBACK_TYPE:
"""Add a listener that track state changes with template condition."""
from . import condition
@ -111,7 +125,7 @@ def async_track_template(hass, template, action, variables=None):
already_triggered = False
@callback
def template_condition_listener(entity_id, from_s, to_s):
def template_condition_listener(entity_id: str, from_s: State, to_s: State) -> None:
"""Check if condition is correct and run action."""
nonlocal already_triggered
template_result = condition.async_template(hass, template, variables)
@ -134,18 +148,22 @@ track_template = threaded_listener_factory(async_track_template)
@callback
@bind_hass
def async_track_same_state(
hass, period, action, async_check_same_func, entity_ids=MATCH_ALL
):
hass: HomeAssistant,
period: timedelta,
action: Callable[..., None],
async_check_same_func: Callable[[str, State, State], bool],
entity_ids: Union[str, Iterable[str]] = MATCH_ALL,
) -> CALLBACK_TYPE:
"""Track the state of entities for a period and run an action.
If async_check_func is None it use the state of orig_value.
Without entity_ids we track all state changes.
"""
async_remove_state_for_cancel = None
async_remove_state_for_listener = None
async_remove_state_for_cancel: Optional[CALLBACK_TYPE] = None
async_remove_state_for_listener: Optional[CALLBACK_TYPE] = None
@callback
def clear_listener():
def clear_listener() -> None:
"""Clear all unsub listener."""
nonlocal async_remove_state_for_cancel, async_remove_state_for_listener
@ -157,7 +175,7 @@ def async_track_same_state(
async_remove_state_for_cancel = None
@callback
def state_for_listener(now):
def state_for_listener(now: Any) -> None:
"""Fire on state changes after a delay and calls action."""
nonlocal async_remove_state_for_listener
async_remove_state_for_listener = None
@ -165,7 +183,9 @@ def async_track_same_state(
hass.async_run_job(action)
@callback
def state_for_cancel_listener(entity, from_state, to_state):
def state_for_cancel_listener(
entity: str, from_state: State, to_state: State
) -> None:
"""Fire on changes and cancel for listener if changed."""
if not async_check_same_func(entity, from_state, to_state):
clear_listener()
@ -193,7 +213,7 @@ def async_track_point_in_time(
utc_point_in_time = dt_util.as_utc(point_in_time)
@callback
def utc_converter(utc_now):
def utc_converter(utc_now: datetime) -> None:
"""Convert passed in UTC now to local now."""
hass.async_run_job(action, dt_util.as_local(utc_now))
@ -213,7 +233,7 @@ def async_track_point_in_utc_time(
point_in_time = dt_util.as_utc(point_in_time)
@callback
def point_in_time_listener(event):
def point_in_time_listener(event: Event) -> None:
"""Listen for matching time_changed events."""
now = event.data[ATTR_NOW]
@ -225,7 +245,7 @@ def async_track_point_in_utc_time(
# available to execute this listener it might occur that the
# listener gets lined up twice to be executed. This will make
# sure the second time it does nothing.
point_in_time_listener.run = True
setattr(point_in_time_listener, "run", True)
async_unsub()
hass.async_run_job(action, now)
@ -260,12 +280,12 @@ def async_track_time_interval(
"""Add a listener that fires repetitively at every timedelta interval."""
remove = None
def next_interval():
def next_interval() -> datetime:
"""Return the next interval."""
return dt_util.utcnow() + interval
@callback
def interval_listener(now):
def interval_listener(now: datetime) -> None:
"""Handle elapsed intervals."""
nonlocal remove
remove = async_track_point_in_utc_time(hass, interval_listener, next_interval())
@ -273,7 +293,7 @@ def async_track_time_interval(
remove = async_track_point_in_utc_time(hass, interval_listener, next_interval())
def remove_listener():
def remove_listener() -> None:
"""Remove interval listener."""
remove()
@ -387,7 +407,7 @@ def async_track_utc_time_change(
if all(val is None for val in (hour, minute, second)):
@callback
def time_change_listener(event):
def time_change_listener(event: Event) -> None:
"""Fire every time event that comes in."""
hass.async_run_job(action, event.data[ATTR_NOW])

View File

@ -7,7 +7,7 @@ import random
import re
from datetime import datetime
from functools import wraps
from typing import Any, Iterable
from typing import Any, Dict, Iterable, List, Optional, Union
import jinja2
from jinja2 import contextfilter, contextfunction
@ -72,7 +72,9 @@ def render_complex(value, variables=None):
return value.async_render(variables)
def extract_entities(template, variables=None):
def extract_entities(
template: Optional[str], variables: Optional[Dict[str, Any]] = None
) -> Union[str, List[str]]:
"""Extract all entities for state_changed listener from template string."""
if template is None or _RE_JINJA_DELIMITERS.search(template) is None:
return []
@ -86,6 +88,7 @@ def extract_entities(template, variables=None):
for result in extraction:
if (
result[0] == "trigger.entity_id"
and variables
and "trigger" in variables
and "entity_id" in variables["trigger"]
):
@ -163,7 +166,7 @@ class Template:
if not isinstance(template, str):
raise TypeError("Expected template to be a string")
self.template = template
self.template: str = template
self._compiled_code = None
self._compiled = None
self.hass = hass
@ -187,7 +190,9 @@ class Template:
except jinja2.exceptions.TemplateSyntaxError as err:
raise TemplateError(err)
def extract_entities(self, variables=None):
def extract_entities(
self, variables: Dict[str, Any] = None
) -> Union[str, List[str]]:
"""Extract all entities for state_changed listener."""
return extract_entities(self.template, variables)