diff --git a/homeassistant/components/group/light.py b/homeassistant/components/group/light.py index 85804552494..2cd65028131 100644 --- a/homeassistant/components/group/light.py +++ b/homeassistant/components/group/light.py @@ -16,7 +16,7 @@ from homeassistant.const import ( STATE_ON, STATE_UNAVAILABLE, ) -from homeassistant.core import State, callback +from homeassistant.core import CALLBACK_TYPE, State, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.event import async_track_state_change from homeassistant.helpers.typing import ConfigType, HomeAssistantType @@ -96,7 +96,7 @@ class LightGroup(light.Light): self._effect_list: Optional[List[str]] = None self._effect: Optional[str] = None self._supported_features: int = 0 - self._async_unsub_state_changed = None + self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None async def async_added_to_hass(self) -> None: """Register callbacks.""" @@ -108,6 +108,7 @@ class LightGroup(light.Light): """Handle child updates.""" self.async_schedule_update_ha_state(True) + assert self.hass is not None self._async_unsub_state_changed = async_track_state_change( self.hass, self._entity_ids, async_state_changed_listener ) diff --git a/homeassistant/components/switch/light.py b/homeassistant/components/switch/light.py index b0abf957991..1bdc1d39083 100644 --- a/homeassistant/components/switch/light.py +++ b/homeassistant/components/switch/light.py @@ -12,7 +12,7 @@ from homeassistant.const import ( STATE_ON, STATE_UNAVAILABLE, ) -from homeassistant.core import State, callback +from homeassistant.core import CALLBACK_TYPE, State, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity import Entity from homeassistant.helpers.event import async_track_state_change @@ -56,7 +56,7 @@ class LightSwitch(Light): self._switch_entity_id = switch_entity_id self._is_on = False self._available = False - self._async_unsub_state_changed = None + self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None @property def name(self) -> str: @@ -113,6 +113,7 @@ class LightSwitch(Light): """Handle child updates.""" self.async_schedule_update_ha_state(True) + assert self.hass is not None self._async_unsub_state_changed = async_track_state_change( self.hass, self._switch_entity_id, async_state_changed_listener ) diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index aee15d6c0ce..ae7c534adf8 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -3,7 +3,7 @@ import asyncio import logging import functools import uuid -from typing import Any, Callable, List, Optional, Set +from typing import Any, Callable, Dict, List, Optional, Set, cast import weakref import attr @@ -14,11 +14,11 @@ from homeassistant.exceptions import HomeAssistantError, ConfigEntryNotReady from homeassistant.setup import async_setup_component, async_process_deps_reqs from homeassistant.util.decorator import Registry from homeassistant.helpers import entity_registry +from homeassistant.helpers.event import Event -# mypy: allow-untyped-defs, no-check-untyped-defs _LOGGER = logging.getLogger(__name__) -_UNDEF = object() +_UNDEF: dict = {} SOURCE_USER = "user" SOURCE_DISCOVERY = "discovery" @@ -205,7 +205,7 @@ class ConfigEntry: wait_time, ) - async def setup_again(now): + async def setup_again(now: Any) -> None: """Run setup again.""" self._async_cancel_retry_setup = None await self.async_setup(hass, integration=integration, tries=tries) @@ -357,7 +357,7 @@ class ConfigEntry: return lambda: self.update_listeners.remove(weak_listener) - def as_dict(self): + def as_dict(self) -> Dict[str, Any]: """Return dictionary version of this entry.""" return { "entry_id": self.entry_id, @@ -418,7 +418,7 @@ class ConfigEntries: return list(self._entries) return [entry for entry in self._entries if entry.domain == domain] - async def async_remove(self, entry_id): + async def async_remove(self, entry_id: str) -> Dict[str, Any]: """Remove an entry.""" entry = self.async_get_entry(entry_id) @@ -529,8 +529,13 @@ class ConfigEntries: @callback def async_update_entry( - self, entry, *, data=_UNDEF, options=_UNDEF, system_options=_UNDEF - ): + self, + entry: ConfigEntry, + *, + data: dict = _UNDEF, + options: dict = _UNDEF, + system_options: dict = _UNDEF, + ) -> None: """Update a config entry.""" if data is not _UNDEF: entry.data = data @@ -547,7 +552,7 @@ class ConfigEntries: self._async_schedule_save() - async def async_forward_entry_setup(self, entry, domain): + async def async_forward_entry_setup(self, entry: ConfigEntry, domain: str) -> bool: """Forward the setup of an entry to a different component. By default an entry is setup with the component it belongs to. If that @@ -567,8 +572,9 @@ class ConfigEntries: integration = await loader.async_get_integration(self.hass, domain) await entry.async_setup(self.hass, integration=integration) + return True - async def async_forward_entry_unload(self, entry, domain): + async def async_forward_entry_unload(self, entry: ConfigEntry, domain: str) -> bool: """Forward the unloading of an entry to a different component.""" # It was never loaded. if domain not in self.hass.config.components: @@ -578,7 +584,9 @@ class ConfigEntries: return await entry.async_unload(self.hass, integration=integration) - async def _async_finish_flow(self, flow, result): + async def _async_finish_flow( + self, flow: "ConfigFlow", result: Dict[str, Any] + ) -> Dict[str, Any]: """Finish a config flow and add an entry.""" # Remove notification if no other discovery config entries in progress if not any( @@ -611,7 +619,9 @@ class ConfigEntries: result["result"] = entry return result - async def _async_create_flow(self, handler_key, *, context, data): + async def _async_create_flow( + self, handler_key: str, *, context: Dict[str, Any], data: Dict[str, Any] + ) -> "ConfigFlow": """Create a flow for specified handler. Handler key is the domain of the component that we want to set up. @@ -654,7 +664,7 @@ class ConfigEntries: notification_id=DISCOVERY_NOTIFICATION_ID, ) - flow = handler() + flow = cast(ConfigFlow, handler()) flow.init_step = source return flow @@ -663,12 +673,12 @@ class ConfigEntries: self._store.async_delay_save(self._data_to_save, SAVE_DELAY) @callback - def _data_to_save(self): + def _data_to_save(self) -> Dict[str, List[Dict[str, Any]]]: """Return data to save.""" return {"entries": [entry.as_dict() for entry in self._entries]} -async def _old_conf_migrator(old_config): +async def _old_conf_migrator(old_config: Dict[str, Any]) -> Dict[str, Any]: """Migrate the pre-0.73 config format to the latest version.""" return {"entries": old_config} @@ -686,18 +696,20 @@ class ConfigFlow(data_entry_flow.FlowHandler): @staticmethod @callback - def async_get_options_flow(config_entry): + def async_get_options_flow(config_entry: ConfigEntry) -> "OptionsFlow": """Get the options flow for this handler.""" raise data_entry_flow.UnknownHandler @callback - def _async_current_entries(self): + def _async_current_entries(self) -> List[ConfigEntry]: """Return current entries.""" + assert self.hass is not None return self.hass.config_entries.async_entries(self.handler) @callback - def _async_in_progress(self): + def _async_in_progress(self) -> List[Dict]: """Return other in progress flows for current domain.""" + assert self.hass is not None return [ flw for flw in self.hass.config_entries.flow.async_progress() @@ -715,29 +727,33 @@ class OptionsFlowManager: hass, self._async_create_flow, self._async_finish_flow ) - async def _async_create_flow(self, entry_id, *, context, data): + async def _async_create_flow( + self, entry_id: str, *, context: Dict[str, Any], data: Dict[str, Any] + ) -> Optional["OptionsFlow"]: """Create an options flow for a config entry. Entry_id and flow.handler is the same thing to map entry with flow. """ entry = self.hass.config_entries.async_get_entry(entry_id) if entry is None: - return + return None if entry.domain not in HANDLERS: raise data_entry_flow.UnknownHandler - flow = HANDLERS[entry.domain].async_get_options_flow(entry) + flow = cast(OptionsFlow, HANDLERS[entry.domain].async_get_options_flow(entry)) return flow - async def _async_finish_flow(self, flow, result): + async def _async_finish_flow( + self, flow: "OptionsFlow", result: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: """Finish an options flow and update options for configuration entry. Flow.handler and entry_id is the same thing to map flow with entry. """ entry = self.hass.config_entries.async_get_entry(flow.handler) if entry is None: - return + return None self.hass.config_entries.async_update_entry(entry, options=result["data"]) result["result"] = True @@ -747,7 +763,7 @@ class OptionsFlowManager: class OptionsFlow(data_entry_flow.FlowHandler): """Base class for config option flows.""" - pass + handler: str @attr.s(slots=True) @@ -756,11 +772,11 @@ class SystemOptions: disable_new_entities = attr.ib(type=bool, default=False) - def update(self, *, disable_new_entities): + def update(self, *, disable_new_entities: bool) -> None: """Update properties.""" self.disable_new_entities = disable_new_entities - def as_dict(self): + def as_dict(self) -> Dict[str, Any]: """Return dictionary version of this config entrys system options.""" return {"disable_new_entities": self.disable_new_entities} @@ -784,7 +800,7 @@ class EntityRegistryDisabledHandler: entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._handle_entry_updated ) - async def _handle_entry_updated(self, event): + async def _handle_entry_updated(self, event: Event) -> None: """Handle entity registry entry update.""" if ( event.data["action"] != "update" @@ -811,6 +827,7 @@ class EntityRegistryDisabledHandler: config_entry = self.hass.config_entries.async_get_entry( entity_entry.config_entry_id ) + assert config_entry is not None if config_entry.entry_id not in self.changed and await support_entry_unload( self.hass, config_entry.domain @@ -830,7 +847,7 @@ class EntityRegistryDisabledHandler: self.RELOAD_AFTER_UPDATE_DELAY, self._handle_reload ) - async def _handle_reload(self, _now): + async def _handle_reload(self, _now: Any) -> None: """Handle a reload.""" self._remove_call_later = None to_reload = self.changed diff --git a/homeassistant/core.py b/homeassistant/core.py index ec11b14edaa..01c5561d939 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -1283,7 +1283,7 @@ class Config: self.skip_pip: bool = False # List of loaded components - self.components: set = set() + self.components: Set[str] = set() # API (HTTP) server configuration, see components.http.ApiConfig self.api: Optional[Any] = None diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index c06c69d9213..58d8e4ea131 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -1,6 +1,6 @@ """Classes to help gather user submissions.""" import logging -from typing import Dict, Any, Callable, Hashable, List, Optional +from typing import Dict, Any, Callable, List, Optional import uuid import voluptuous as vol from .core import callback, HomeAssistant @@ -58,7 +58,7 @@ class FlowManager: ] async def async_init( - self, handler: Hashable, *, context: Optional[Dict] = None, data: Any = None + self, handler: str, *, context: Optional[Dict] = None, data: Any = None ) -> Any: """Start a configuration flow.""" if context is None: @@ -170,7 +170,7 @@ class FlowHandler: # Set by flow manager flow_id: str = None # type: ignore hass: Optional[HomeAssistant] = None - handler: Optional[Hashable] = None + handler: Optional[str] = None cur_step: Optional[Dict[str, str]] = None context: Dict diff --git a/homeassistant/helpers/config_entry_oauth2_flow.py b/homeassistant/helpers/config_entry_oauth2_flow.py index d3db8febcb2..87832f60739 100644 --- a/homeassistant/helpers/config_entry_oauth2_flow.py +++ b/homeassistant/helpers/config_entry_oauth2_flow.py @@ -399,7 +399,7 @@ class OAuth2Session: new_token = await self.implementation.async_refresh_token(token) - self.hass.config_entries.async_update_entry( # type: ignore + self.hass.config_entries.async_update_entry( self.config_entry, data={**self.config_entry.data, "token": new_token} ) diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 00671e9c776..08f29a9fb3e 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -7,15 +7,15 @@ The Entity Registry will persist itself 10 seconds after a new entity is registered. Registering a new entity while a timer is in progress resets the timer. """ -from asyncio import Event +import asyncio from collections import OrderedDict from itertools import chain import logging -from typing import List, Optional, cast +from typing import Any, Dict, Iterable, List, Optional, cast import attr -from homeassistant.core import callback, split_entity_id, valid_entity_id +from homeassistant.core import Event, callback, split_entity_id, valid_entity_id from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED from homeassistant.loader import bind_hass from homeassistant.util import ensure_unique_string, slugify @@ -24,8 +24,7 @@ from homeassistant.util.yaml import load_yaml from .typing import HomeAssistantType -# mypy: allow-untyped-calls, allow-untyped-defs -# mypy: no-check-untyped-defs, no-warn-return-any +# mypy: allow-untyped-defs, no-check-untyped-defs PATH_REGISTRY = "entity_registry.yaml" DATA_REGISTRY = "entity_registry" @@ -51,7 +50,7 @@ class RegistryEntry: platform = attr.ib(type=str) name = attr.ib(type=str, default=None) device_id = attr.ib(type=str, default=None) - config_entry_id = attr.ib(type=str, default=None) + config_entry_id: Optional[str] = attr.ib(default=None) disabled_by = attr.ib( type=Optional[str], default=None, @@ -68,12 +67,12 @@ class RegistryEntry: domain = attr.ib(type=str, init=False, repr=False) @domain.default - def _domain_default(self): + def _domain_default(self) -> str: """Compute domain value.""" return split_entity_id(self.entity_id)[0] @property - def disabled(self): + def disabled(self) -> bool: """Return if entry is disabled.""" return self.disabled_by is not None @@ -81,17 +80,17 @@ class RegistryEntry: class EntityRegistry: """Class to hold a registry of entities.""" - def __init__(self, hass): + def __init__(self, hass: HomeAssistantType): """Initialize the registry.""" self.hass = hass - self.entities = None + self.entities: Dict[str, RegistryEntry] self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self.hass.bus.async_listen( EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_removed ) @callback - def async_is_registered(self, entity_id): + def async_is_registered(self, entity_id: str) -> bool: """Check if an entity_id is currently registered.""" return entity_id in self.entities @@ -116,8 +115,11 @@ class EntityRegistry: @callback def async_generate_entity_id( - self, domain, suggested_object_id, known_object_ids=None - ): + self, + domain: str, + suggested_object_id: str, + known_object_ids: Optional[Iterable[str]] = None, + ) -> str: """Generate an entity ID that does not conflict. Conflicts checked against registered and currently existing entities. @@ -195,7 +197,7 @@ class EntityRegistry: return entity @callback - def async_remove(self, entity_id): + def async_remove(self, entity_id: str) -> None: """Remove an entity from registry.""" self.entities.pop(entity_id) self.hass.bus.async_fire( @@ -204,7 +206,7 @@ class EntityRegistry: self.async_schedule_save() @callback - def async_device_removed(self, event): + def async_device_removed(self, event: Event) -> None: """Handle the removal of a device. Remove entities from the registry that are associated to a device when @@ -309,7 +311,7 @@ class EntityRegistry: return new - async def async_load(self): + async def async_load(self) -> None: """Load the entity registry.""" data = await self.hass.helpers.storage.async_migrator( self.hass.config.path(PATH_REGISTRY), @@ -317,7 +319,7 @@ class EntityRegistry: old_conf_load_func=load_yaml, old_conf_migrate_func=_async_migrate, ) - entities = OrderedDict() + entities: Dict[str, RegistryEntry] = OrderedDict() if data is not None: for entity in data["entities"]: @@ -334,12 +336,12 @@ class EntityRegistry: self.entities = entities @callback - def async_schedule_save(self): + def async_schedule_save(self) -> None: """Schedule saving the entity registry.""" self._store.async_delay_save(self._data_to_save, SAVE_DELAY) @callback - def _data_to_save(self): + def _data_to_save(self) -> Dict[str, Any]: """Return data of entity registry to store in a file.""" data = {} @@ -359,7 +361,7 @@ class EntityRegistry: return data @callback - def async_clear_config_entry(self, config_entry): + def async_clear_config_entry(self, config_entry: str) -> None: """Clear config entry from registry entries.""" for entity_id in [ entity_id @@ -375,7 +377,7 @@ async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry: reg_or_evt = hass.data.get(DATA_REGISTRY) if not reg_or_evt: - evt = hass.data[DATA_REGISTRY] = Event() + evt = hass.data[DATA_REGISTRY] = asyncio.Event() reg = EntityRegistry(hass) await reg.async_load() @@ -384,7 +386,7 @@ async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry: evt.set() return reg - if isinstance(reg_or_evt, Event): + if isinstance(reg_or_evt, asyncio.Event): evt = reg_or_evt await evt.wait() return cast(EntityRegistry, hass.data.get(DATA_REGISTRY)) @@ -402,7 +404,7 @@ def async_entries_for_device( ] -async def _async_migrate(entities): +async def _async_migrate(entities: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]: """Migrate the YAML config file to storage helper format.""" return { "entities": [ diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index e819da9873a..715344a3969 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -1,13 +1,14 @@ """Helpers for listening to events.""" from datetime import datetime, timedelta import functools as ft -from typing import Any, Callable, Iterable, Optional, Union +from typing import Any, Callable, Dict, Iterable, Optional, Union, cast import attr from homeassistant.loader import bind_hass from homeassistant.helpers.sun import get_astral_event_next -from homeassistant.core import HomeAssistant, callback, CALLBACK_TYPE, Event +from homeassistant.helpers.template import Template +from homeassistant.core import HomeAssistant, callback, CALLBACK_TYPE, Event, State from homeassistant.const import ( ATTR_NOW, EVENT_STATE_CHANGED, @@ -21,16 +22,15 @@ from homeassistant.util import dt as dt_util from homeassistant.util.async_ import run_callback_threadsafe -# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs # PyLint does not like the use of threaded_listener_factory # pylint: disable=invalid-name -def threaded_listener_factory(async_factory): +def threaded_listener_factory(async_factory: Callable[..., Any]) -> CALLBACK_TYPE: """Convert an async event helper to a threaded one.""" @ft.wraps(async_factory) - def factory(*args, **kwargs): + def factory(*args: Any, **kwargs: Any) -> CALLBACK_TYPE: """Call async event helper safely.""" hass = args[0] @@ -41,7 +41,7 @@ def threaded_listener_factory(async_factory): hass.loop, ft.partial(async_factory, *args, **kwargs) ).result() - def remove(): + def remove() -> None: """Threadsafe removal.""" run_callback_threadsafe(hass.loop, async_remove).result() @@ -52,7 +52,13 @@ def threaded_listener_factory(async_factory): @callback @bind_hass -def async_track_state_change(hass, entity_ids, action, from_state=None, to_state=None): +def async_track_state_change( + hass: HomeAssistant, + entity_ids: Union[str, Iterable[str]], + action: Callable[[str, State, State], None], + from_state: Union[None, str, Iterable[str]] = None, + to_state: Union[None, str, Iterable[str]] = None, +) -> CALLBACK_TYPE: """Track specific state changes. entity_ids, from_state and to_state can be string or list. @@ -74,9 +80,12 @@ def async_track_state_change(hass, entity_ids, action, from_state=None, to_state entity_ids = tuple(entity_id.lower() for entity_id in entity_ids) @callback - def state_change_listener(event): + def state_change_listener(event: Event) -> None: """Handle specific state changes.""" - if entity_ids != MATCH_ALL and event.data.get("entity_id") not in entity_ids: + if ( + entity_ids != MATCH_ALL + and cast(str, event.data.get("entity_id")) not in entity_ids + ): return old_state = event.data.get("old_state") @@ -103,7 +112,12 @@ track_state_change = threaded_listener_factory(async_track_state_change) @callback @bind_hass -def async_track_template(hass, template, action, variables=None): +def async_track_template( + hass: HomeAssistant, + template: Template, + action: Callable[[str, State, State], None], + variables: Optional[Dict[str, Any]] = None, +) -> CALLBACK_TYPE: """Add a listener that track state changes with template condition.""" from . import condition @@ -111,7 +125,7 @@ def async_track_template(hass, template, action, variables=None): already_triggered = False @callback - def template_condition_listener(entity_id, from_s, to_s): + def template_condition_listener(entity_id: str, from_s: State, to_s: State) -> None: """Check if condition is correct and run action.""" nonlocal already_triggered template_result = condition.async_template(hass, template, variables) @@ -134,18 +148,22 @@ track_template = threaded_listener_factory(async_track_template) @callback @bind_hass def async_track_same_state( - hass, period, action, async_check_same_func, entity_ids=MATCH_ALL -): + hass: HomeAssistant, + period: timedelta, + action: Callable[..., None], + async_check_same_func: Callable[[str, State, State], bool], + entity_ids: Union[str, Iterable[str]] = MATCH_ALL, +) -> CALLBACK_TYPE: """Track the state of entities for a period and run an action. If async_check_func is None it use the state of orig_value. Without entity_ids we track all state changes. """ - async_remove_state_for_cancel = None - async_remove_state_for_listener = None + async_remove_state_for_cancel: Optional[CALLBACK_TYPE] = None + async_remove_state_for_listener: Optional[CALLBACK_TYPE] = None @callback - def clear_listener(): + def clear_listener() -> None: """Clear all unsub listener.""" nonlocal async_remove_state_for_cancel, async_remove_state_for_listener @@ -157,7 +175,7 @@ def async_track_same_state( async_remove_state_for_cancel = None @callback - def state_for_listener(now): + def state_for_listener(now: Any) -> None: """Fire on state changes after a delay and calls action.""" nonlocal async_remove_state_for_listener async_remove_state_for_listener = None @@ -165,7 +183,9 @@ def async_track_same_state( hass.async_run_job(action) @callback - def state_for_cancel_listener(entity, from_state, to_state): + def state_for_cancel_listener( + entity: str, from_state: State, to_state: State + ) -> None: """Fire on changes and cancel for listener if changed.""" if not async_check_same_func(entity, from_state, to_state): clear_listener() @@ -193,7 +213,7 @@ def async_track_point_in_time( utc_point_in_time = dt_util.as_utc(point_in_time) @callback - def utc_converter(utc_now): + def utc_converter(utc_now: datetime) -> None: """Convert passed in UTC now to local now.""" hass.async_run_job(action, dt_util.as_local(utc_now)) @@ -213,7 +233,7 @@ def async_track_point_in_utc_time( point_in_time = dt_util.as_utc(point_in_time) @callback - def point_in_time_listener(event): + def point_in_time_listener(event: Event) -> None: """Listen for matching time_changed events.""" now = event.data[ATTR_NOW] @@ -225,7 +245,7 @@ def async_track_point_in_utc_time( # available to execute this listener it might occur that the # listener gets lined up twice to be executed. This will make # sure the second time it does nothing. - point_in_time_listener.run = True + setattr(point_in_time_listener, "run", True) async_unsub() hass.async_run_job(action, now) @@ -260,12 +280,12 @@ def async_track_time_interval( """Add a listener that fires repetitively at every timedelta interval.""" remove = None - def next_interval(): + def next_interval() -> datetime: """Return the next interval.""" return dt_util.utcnow() + interval @callback - def interval_listener(now): + def interval_listener(now: datetime) -> None: """Handle elapsed intervals.""" nonlocal remove remove = async_track_point_in_utc_time(hass, interval_listener, next_interval()) @@ -273,7 +293,7 @@ def async_track_time_interval( remove = async_track_point_in_utc_time(hass, interval_listener, next_interval()) - def remove_listener(): + def remove_listener() -> None: """Remove interval listener.""" remove() @@ -387,7 +407,7 @@ def async_track_utc_time_change( if all(val is None for val in (hour, minute, second)): @callback - def time_change_listener(event): + def time_change_listener(event: Event) -> None: """Fire every time event that comes in.""" hass.async_run_job(action, event.data[ATTR_NOW]) diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index 1d9ca691451..aa17b2a1fba 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -7,7 +7,7 @@ import random import re from datetime import datetime from functools import wraps -from typing import Any, Iterable +from typing import Any, Dict, Iterable, List, Optional, Union import jinja2 from jinja2 import contextfilter, contextfunction @@ -72,7 +72,9 @@ def render_complex(value, variables=None): return value.async_render(variables) -def extract_entities(template, variables=None): +def extract_entities( + template: Optional[str], variables: Optional[Dict[str, Any]] = None +) -> Union[str, List[str]]: """Extract all entities for state_changed listener from template string.""" if template is None or _RE_JINJA_DELIMITERS.search(template) is None: return [] @@ -86,6 +88,7 @@ def extract_entities(template, variables=None): for result in extraction: if ( result[0] == "trigger.entity_id" + and variables and "trigger" in variables and "entity_id" in variables["trigger"] ): @@ -163,7 +166,7 @@ class Template: if not isinstance(template, str): raise TypeError("Expected template to be a string") - self.template = template + self.template: str = template self._compiled_code = None self._compiled = None self.hass = hass @@ -187,7 +190,9 @@ class Template: except jinja2.exceptions.TemplateSyntaxError as err: raise TemplateError(err) - def extract_entities(self, variables=None): + def extract_entities( + self, variables: Dict[str, Any] = None + ) -> Union[str, List[str]]: """Extract all entities for state_changed listener.""" return extract_entities(self.template, variables)