Type hint improvements (#28260)
* Add and improve core and config_entries type hints * Complete and improve config_entries type hints * More entity registry type hints * Complete helpers.event type hintspull/28310/head
parent
f7a64019b6
commit
f88ead597a
|
@ -16,7 +16,7 @@ from homeassistant.const import (
|
|||
STATE_ON,
|
||||
STATE_UNAVAILABLE,
|
||||
)
|
||||
from homeassistant.core import State, callback
|
||||
from homeassistant.core import CALLBACK_TYPE, State, callback
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.event import async_track_state_change
|
||||
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
|
||||
|
@ -96,7 +96,7 @@ class LightGroup(light.Light):
|
|||
self._effect_list: Optional[List[str]] = None
|
||||
self._effect: Optional[str] = None
|
||||
self._supported_features: int = 0
|
||||
self._async_unsub_state_changed = None
|
||||
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Register callbacks."""
|
||||
|
@ -108,6 +108,7 @@ class LightGroup(light.Light):
|
|||
"""Handle child updates."""
|
||||
self.async_schedule_update_ha_state(True)
|
||||
|
||||
assert self.hass is not None
|
||||
self._async_unsub_state_changed = async_track_state_change(
|
||||
self.hass, self._entity_ids, async_state_changed_listener
|
||||
)
|
||||
|
|
|
@ -12,7 +12,7 @@ from homeassistant.const import (
|
|||
STATE_ON,
|
||||
STATE_UNAVAILABLE,
|
||||
)
|
||||
from homeassistant.core import State, callback
|
||||
from homeassistant.core import CALLBACK_TYPE, State, callback
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.event import async_track_state_change
|
||||
|
@ -56,7 +56,7 @@ class LightSwitch(Light):
|
|||
self._switch_entity_id = switch_entity_id
|
||||
self._is_on = False
|
||||
self._available = False
|
||||
self._async_unsub_state_changed = None
|
||||
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
@ -113,6 +113,7 @@ class LightSwitch(Light):
|
|||
"""Handle child updates."""
|
||||
self.async_schedule_update_ha_state(True)
|
||||
|
||||
assert self.hass is not None
|
||||
self._async_unsub_state_changed = async_track_state_change(
|
||||
self.hass, self._switch_entity_id, async_state_changed_listener
|
||||
)
|
||||
|
|
|
@ -3,7 +3,7 @@ import asyncio
|
|||
import logging
|
||||
import functools
|
||||
import uuid
|
||||
from typing import Any, Callable, List, Optional, Set
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, cast
|
||||
import weakref
|
||||
|
||||
import attr
|
||||
|
@ -14,11 +14,11 @@ from homeassistant.exceptions import HomeAssistantError, ConfigEntryNotReady
|
|||
from homeassistant.setup import async_setup_component, async_process_deps_reqs
|
||||
from homeassistant.util.decorator import Registry
|
||||
from homeassistant.helpers import entity_registry
|
||||
from homeassistant.helpers.event import Event
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_UNDEF = object()
|
||||
_UNDEF: dict = {}
|
||||
|
||||
SOURCE_USER = "user"
|
||||
SOURCE_DISCOVERY = "discovery"
|
||||
|
@ -205,7 +205,7 @@ class ConfigEntry:
|
|||
wait_time,
|
||||
)
|
||||
|
||||
async def setup_again(now):
|
||||
async def setup_again(now: Any) -> None:
|
||||
"""Run setup again."""
|
||||
self._async_cancel_retry_setup = None
|
||||
await self.async_setup(hass, integration=integration, tries=tries)
|
||||
|
@ -357,7 +357,7 @@ class ConfigEntry:
|
|||
|
||||
return lambda: self.update_listeners.remove(weak_listener)
|
||||
|
||||
def as_dict(self):
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
"""Return dictionary version of this entry."""
|
||||
return {
|
||||
"entry_id": self.entry_id,
|
||||
|
@ -418,7 +418,7 @@ class ConfigEntries:
|
|||
return list(self._entries)
|
||||
return [entry for entry in self._entries if entry.domain == domain]
|
||||
|
||||
async def async_remove(self, entry_id):
|
||||
async def async_remove(self, entry_id: str) -> Dict[str, Any]:
|
||||
"""Remove an entry."""
|
||||
entry = self.async_get_entry(entry_id)
|
||||
|
||||
|
@ -529,8 +529,13 @@ class ConfigEntries:
|
|||
|
||||
@callback
|
||||
def async_update_entry(
|
||||
self, entry, *, data=_UNDEF, options=_UNDEF, system_options=_UNDEF
|
||||
):
|
||||
self,
|
||||
entry: ConfigEntry,
|
||||
*,
|
||||
data: dict = _UNDEF,
|
||||
options: dict = _UNDEF,
|
||||
system_options: dict = _UNDEF,
|
||||
) -> None:
|
||||
"""Update a config entry."""
|
||||
if data is not _UNDEF:
|
||||
entry.data = data
|
||||
|
@ -547,7 +552,7 @@ class ConfigEntries:
|
|||
|
||||
self._async_schedule_save()
|
||||
|
||||
async def async_forward_entry_setup(self, entry, domain):
|
||||
async def async_forward_entry_setup(self, entry: ConfigEntry, domain: str) -> bool:
|
||||
"""Forward the setup of an entry to a different component.
|
||||
|
||||
By default an entry is setup with the component it belongs to. If that
|
||||
|
@ -567,8 +572,9 @@ class ConfigEntries:
|
|||
integration = await loader.async_get_integration(self.hass, domain)
|
||||
|
||||
await entry.async_setup(self.hass, integration=integration)
|
||||
return True
|
||||
|
||||
async def async_forward_entry_unload(self, entry, domain):
|
||||
async def async_forward_entry_unload(self, entry: ConfigEntry, domain: str) -> bool:
|
||||
"""Forward the unloading of an entry to a different component."""
|
||||
# It was never loaded.
|
||||
if domain not in self.hass.config.components:
|
||||
|
@ -578,7 +584,9 @@ class ConfigEntries:
|
|||
|
||||
return await entry.async_unload(self.hass, integration=integration)
|
||||
|
||||
async def _async_finish_flow(self, flow, result):
|
||||
async def _async_finish_flow(
|
||||
self, flow: "ConfigFlow", result: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Finish a config flow and add an entry."""
|
||||
# Remove notification if no other discovery config entries in progress
|
||||
if not any(
|
||||
|
@ -611,7 +619,9 @@ class ConfigEntries:
|
|||
result["result"] = entry
|
||||
return result
|
||||
|
||||
async def _async_create_flow(self, handler_key, *, context, data):
|
||||
async def _async_create_flow(
|
||||
self, handler_key: str, *, context: Dict[str, Any], data: Dict[str, Any]
|
||||
) -> "ConfigFlow":
|
||||
"""Create a flow for specified handler.
|
||||
|
||||
Handler key is the domain of the component that we want to set up.
|
||||
|
@ -654,7 +664,7 @@ class ConfigEntries:
|
|||
notification_id=DISCOVERY_NOTIFICATION_ID,
|
||||
)
|
||||
|
||||
flow = handler()
|
||||
flow = cast(ConfigFlow, handler())
|
||||
flow.init_step = source
|
||||
return flow
|
||||
|
||||
|
@ -663,12 +673,12 @@ class ConfigEntries:
|
|||
self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
|
||||
|
||||
@callback
|
||||
def _data_to_save(self):
|
||||
def _data_to_save(self) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Return data to save."""
|
||||
return {"entries": [entry.as_dict() for entry in self._entries]}
|
||||
|
||||
|
||||
async def _old_conf_migrator(old_config):
|
||||
async def _old_conf_migrator(old_config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Migrate the pre-0.73 config format to the latest version."""
|
||||
return {"entries": old_config}
|
||||
|
||||
|
@ -686,18 +696,20 @@ class ConfigFlow(data_entry_flow.FlowHandler):
|
|||
|
||||
@staticmethod
|
||||
@callback
|
||||
def async_get_options_flow(config_entry):
|
||||
def async_get_options_flow(config_entry: ConfigEntry) -> "OptionsFlow":
|
||||
"""Get the options flow for this handler."""
|
||||
raise data_entry_flow.UnknownHandler
|
||||
|
||||
@callback
|
||||
def _async_current_entries(self):
|
||||
def _async_current_entries(self) -> List[ConfigEntry]:
|
||||
"""Return current entries."""
|
||||
assert self.hass is not None
|
||||
return self.hass.config_entries.async_entries(self.handler)
|
||||
|
||||
@callback
|
||||
def _async_in_progress(self):
|
||||
def _async_in_progress(self) -> List[Dict]:
|
||||
"""Return other in progress flows for current domain."""
|
||||
assert self.hass is not None
|
||||
return [
|
||||
flw
|
||||
for flw in self.hass.config_entries.flow.async_progress()
|
||||
|
@ -715,29 +727,33 @@ class OptionsFlowManager:
|
|||
hass, self._async_create_flow, self._async_finish_flow
|
||||
)
|
||||
|
||||
async def _async_create_flow(self, entry_id, *, context, data):
|
||||
async def _async_create_flow(
|
||||
self, entry_id: str, *, context: Dict[str, Any], data: Dict[str, Any]
|
||||
) -> Optional["OptionsFlow"]:
|
||||
"""Create an options flow for a config entry.
|
||||
|
||||
Entry_id and flow.handler is the same thing to map entry with flow.
|
||||
"""
|
||||
entry = self.hass.config_entries.async_get_entry(entry_id)
|
||||
if entry is None:
|
||||
return
|
||||
return None
|
||||
|
||||
if entry.domain not in HANDLERS:
|
||||
raise data_entry_flow.UnknownHandler
|
||||
|
||||
flow = HANDLERS[entry.domain].async_get_options_flow(entry)
|
||||
flow = cast(OptionsFlow, HANDLERS[entry.domain].async_get_options_flow(entry))
|
||||
return flow
|
||||
|
||||
async def _async_finish_flow(self, flow, result):
|
||||
async def _async_finish_flow(
|
||||
self, flow: "OptionsFlow", result: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Finish an options flow and update options for configuration entry.
|
||||
|
||||
Flow.handler and entry_id is the same thing to map flow with entry.
|
||||
"""
|
||||
entry = self.hass.config_entries.async_get_entry(flow.handler)
|
||||
if entry is None:
|
||||
return
|
||||
return None
|
||||
self.hass.config_entries.async_update_entry(entry, options=result["data"])
|
||||
|
||||
result["result"] = True
|
||||
|
@ -747,7 +763,7 @@ class OptionsFlowManager:
|
|||
class OptionsFlow(data_entry_flow.FlowHandler):
|
||||
"""Base class for config option flows."""
|
||||
|
||||
pass
|
||||
handler: str
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
|
@ -756,11 +772,11 @@ class SystemOptions:
|
|||
|
||||
disable_new_entities = attr.ib(type=bool, default=False)
|
||||
|
||||
def update(self, *, disable_new_entities):
|
||||
def update(self, *, disable_new_entities: bool) -> None:
|
||||
"""Update properties."""
|
||||
self.disable_new_entities = disable_new_entities
|
||||
|
||||
def as_dict(self):
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
"""Return dictionary version of this config entrys system options."""
|
||||
return {"disable_new_entities": self.disable_new_entities}
|
||||
|
||||
|
@ -784,7 +800,7 @@ class EntityRegistryDisabledHandler:
|
|||
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._handle_entry_updated
|
||||
)
|
||||
|
||||
async def _handle_entry_updated(self, event):
|
||||
async def _handle_entry_updated(self, event: Event) -> None:
|
||||
"""Handle entity registry entry update."""
|
||||
if (
|
||||
event.data["action"] != "update"
|
||||
|
@ -811,6 +827,7 @@ class EntityRegistryDisabledHandler:
|
|||
config_entry = self.hass.config_entries.async_get_entry(
|
||||
entity_entry.config_entry_id
|
||||
)
|
||||
assert config_entry is not None
|
||||
|
||||
if config_entry.entry_id not in self.changed and await support_entry_unload(
|
||||
self.hass, config_entry.domain
|
||||
|
@ -830,7 +847,7 @@ class EntityRegistryDisabledHandler:
|
|||
self.RELOAD_AFTER_UPDATE_DELAY, self._handle_reload
|
||||
)
|
||||
|
||||
async def _handle_reload(self, _now):
|
||||
async def _handle_reload(self, _now: Any) -> None:
|
||||
"""Handle a reload."""
|
||||
self._remove_call_later = None
|
||||
to_reload = self.changed
|
||||
|
|
|
@ -1283,7 +1283,7 @@ class Config:
|
|||
self.skip_pip: bool = False
|
||||
|
||||
# List of loaded components
|
||||
self.components: set = set()
|
||||
self.components: Set[str] = set()
|
||||
|
||||
# API (HTTP) server configuration, see components.http.ApiConfig
|
||||
self.api: Optional[Any] = None
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"""Classes to help gather user submissions."""
|
||||
import logging
|
||||
from typing import Dict, Any, Callable, Hashable, List, Optional
|
||||
from typing import Dict, Any, Callable, List, Optional
|
||||
import uuid
|
||||
import voluptuous as vol
|
||||
from .core import callback, HomeAssistant
|
||||
|
@ -58,7 +58,7 @@ class FlowManager:
|
|||
]
|
||||
|
||||
async def async_init(
|
||||
self, handler: Hashable, *, context: Optional[Dict] = None, data: Any = None
|
||||
self, handler: str, *, context: Optional[Dict] = None, data: Any = None
|
||||
) -> Any:
|
||||
"""Start a configuration flow."""
|
||||
if context is None:
|
||||
|
@ -170,7 +170,7 @@ class FlowHandler:
|
|||
# Set by flow manager
|
||||
flow_id: str = None # type: ignore
|
||||
hass: Optional[HomeAssistant] = None
|
||||
handler: Optional[Hashable] = None
|
||||
handler: Optional[str] = None
|
||||
cur_step: Optional[Dict[str, str]] = None
|
||||
context: Dict
|
||||
|
||||
|
|
|
@ -399,7 +399,7 @@ class OAuth2Session:
|
|||
|
||||
new_token = await self.implementation.async_refresh_token(token)
|
||||
|
||||
self.hass.config_entries.async_update_entry( # type: ignore
|
||||
self.hass.config_entries.async_update_entry(
|
||||
self.config_entry, data={**self.config_entry.data, "token": new_token}
|
||||
)
|
||||
|
||||
|
|
|
@ -7,15 +7,15 @@ The Entity Registry will persist itself 10 seconds after a new entity is
|
|||
registered. Registering a new entity while a timer is in progress resets the
|
||||
timer.
|
||||
"""
|
||||
from asyncio import Event
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from itertools import chain
|
||||
import logging
|
||||
from typing import List, Optional, cast
|
||||
from typing import Any, Dict, Iterable, List, Optional, cast
|
||||
|
||||
import attr
|
||||
|
||||
from homeassistant.core import callback, split_entity_id, valid_entity_id
|
||||
from homeassistant.core import Event, callback, split_entity_id, valid_entity_id
|
||||
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import ensure_unique_string, slugify
|
||||
|
@ -24,8 +24,7 @@ from homeassistant.util.yaml import load_yaml
|
|||
from .typing import HomeAssistantType
|
||||
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
# mypy: no-check-untyped-defs, no-warn-return-any
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
PATH_REGISTRY = "entity_registry.yaml"
|
||||
DATA_REGISTRY = "entity_registry"
|
||||
|
@ -51,7 +50,7 @@ class RegistryEntry:
|
|||
platform = attr.ib(type=str)
|
||||
name = attr.ib(type=str, default=None)
|
||||
device_id = attr.ib(type=str, default=None)
|
||||
config_entry_id = attr.ib(type=str, default=None)
|
||||
config_entry_id: Optional[str] = attr.ib(default=None)
|
||||
disabled_by = attr.ib(
|
||||
type=Optional[str],
|
||||
default=None,
|
||||
|
@ -68,12 +67,12 @@ class RegistryEntry:
|
|||
domain = attr.ib(type=str, init=False, repr=False)
|
||||
|
||||
@domain.default
|
||||
def _domain_default(self):
|
||||
def _domain_default(self) -> str:
|
||||
"""Compute domain value."""
|
||||
return split_entity_id(self.entity_id)[0]
|
||||
|
||||
@property
|
||||
def disabled(self):
|
||||
def disabled(self) -> bool:
|
||||
"""Return if entry is disabled."""
|
||||
return self.disabled_by is not None
|
||||
|
||||
|
@ -81,17 +80,17 @@ class RegistryEntry:
|
|||
class EntityRegistry:
|
||||
"""Class to hold a registry of entities."""
|
||||
|
||||
def __init__(self, hass):
|
||||
def __init__(self, hass: HomeAssistantType):
|
||||
"""Initialize the registry."""
|
||||
self.hass = hass
|
||||
self.entities = None
|
||||
self.entities: Dict[str, RegistryEntry]
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
self.hass.bus.async_listen(
|
||||
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_removed
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_is_registered(self, entity_id):
|
||||
def async_is_registered(self, entity_id: str) -> bool:
|
||||
"""Check if an entity_id is currently registered."""
|
||||
return entity_id in self.entities
|
||||
|
||||
|
@ -116,8 +115,11 @@ class EntityRegistry:
|
|||
|
||||
@callback
|
||||
def async_generate_entity_id(
|
||||
self, domain, suggested_object_id, known_object_ids=None
|
||||
):
|
||||
self,
|
||||
domain: str,
|
||||
suggested_object_id: str,
|
||||
known_object_ids: Optional[Iterable[str]] = None,
|
||||
) -> str:
|
||||
"""Generate an entity ID that does not conflict.
|
||||
|
||||
Conflicts checked against registered and currently existing entities.
|
||||
|
@ -195,7 +197,7 @@ class EntityRegistry:
|
|||
return entity
|
||||
|
||||
@callback
|
||||
def async_remove(self, entity_id):
|
||||
def async_remove(self, entity_id: str) -> None:
|
||||
"""Remove an entity from registry."""
|
||||
self.entities.pop(entity_id)
|
||||
self.hass.bus.async_fire(
|
||||
|
@ -204,7 +206,7 @@ class EntityRegistry:
|
|||
self.async_schedule_save()
|
||||
|
||||
@callback
|
||||
def async_device_removed(self, event):
|
||||
def async_device_removed(self, event: Event) -> None:
|
||||
"""Handle the removal of a device.
|
||||
|
||||
Remove entities from the registry that are associated to a device when
|
||||
|
@ -309,7 +311,7 @@ class EntityRegistry:
|
|||
|
||||
return new
|
||||
|
||||
async def async_load(self):
|
||||
async def async_load(self) -> None:
|
||||
"""Load the entity registry."""
|
||||
data = await self.hass.helpers.storage.async_migrator(
|
||||
self.hass.config.path(PATH_REGISTRY),
|
||||
|
@ -317,7 +319,7 @@ class EntityRegistry:
|
|||
old_conf_load_func=load_yaml,
|
||||
old_conf_migrate_func=_async_migrate,
|
||||
)
|
||||
entities = OrderedDict()
|
||||
entities: Dict[str, RegistryEntry] = OrderedDict()
|
||||
|
||||
if data is not None:
|
||||
for entity in data["entities"]:
|
||||
|
@ -334,12 +336,12 @@ class EntityRegistry:
|
|||
self.entities = entities
|
||||
|
||||
@callback
|
||||
def async_schedule_save(self):
|
||||
def async_schedule_save(self) -> None:
|
||||
"""Schedule saving the entity registry."""
|
||||
self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
|
||||
|
||||
@callback
|
||||
def _data_to_save(self):
|
||||
def _data_to_save(self) -> Dict[str, Any]:
|
||||
"""Return data of entity registry to store in a file."""
|
||||
data = {}
|
||||
|
||||
|
@ -359,7 +361,7 @@ class EntityRegistry:
|
|||
return data
|
||||
|
||||
@callback
|
||||
def async_clear_config_entry(self, config_entry):
|
||||
def async_clear_config_entry(self, config_entry: str) -> None:
|
||||
"""Clear config entry from registry entries."""
|
||||
for entity_id in [
|
||||
entity_id
|
||||
|
@ -375,7 +377,7 @@ async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
|
|||
reg_or_evt = hass.data.get(DATA_REGISTRY)
|
||||
|
||||
if not reg_or_evt:
|
||||
evt = hass.data[DATA_REGISTRY] = Event()
|
||||
evt = hass.data[DATA_REGISTRY] = asyncio.Event()
|
||||
|
||||
reg = EntityRegistry(hass)
|
||||
await reg.async_load()
|
||||
|
@ -384,7 +386,7 @@ async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
|
|||
evt.set()
|
||||
return reg
|
||||
|
||||
if isinstance(reg_or_evt, Event):
|
||||
if isinstance(reg_or_evt, asyncio.Event):
|
||||
evt = reg_or_evt
|
||||
await evt.wait()
|
||||
return cast(EntityRegistry, hass.data.get(DATA_REGISTRY))
|
||||
|
@ -402,7 +404,7 @@ def async_entries_for_device(
|
|||
]
|
||||
|
||||
|
||||
async def _async_migrate(entities):
|
||||
async def _async_migrate(entities: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Migrate the YAML config file to storage helper format."""
|
||||
return {
|
||||
"entities": [
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
"""Helpers for listening to events."""
|
||||
from datetime import datetime, timedelta
|
||||
import functools as ft
|
||||
from typing import Any, Callable, Iterable, Optional, Union
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Union, cast
|
||||
|
||||
import attr
|
||||
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.helpers.sun import get_astral_event_next
|
||||
from homeassistant.core import HomeAssistant, callback, CALLBACK_TYPE, Event
|
||||
from homeassistant.helpers.template import Template
|
||||
from homeassistant.core import HomeAssistant, callback, CALLBACK_TYPE, Event, State
|
||||
from homeassistant.const import (
|
||||
ATTR_NOW,
|
||||
EVENT_STATE_CHANGED,
|
||||
|
@ -21,16 +22,15 @@ from homeassistant.util import dt as dt_util
|
|||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
||||
# PyLint does not like the use of threaded_listener_factory
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
|
||||
def threaded_listener_factory(async_factory):
|
||||
def threaded_listener_factory(async_factory: Callable[..., Any]) -> CALLBACK_TYPE:
|
||||
"""Convert an async event helper to a threaded one."""
|
||||
|
||||
@ft.wraps(async_factory)
|
||||
def factory(*args, **kwargs):
|
||||
def factory(*args: Any, **kwargs: Any) -> CALLBACK_TYPE:
|
||||
"""Call async event helper safely."""
|
||||
hass = args[0]
|
||||
|
||||
|
@ -41,7 +41,7 @@ def threaded_listener_factory(async_factory):
|
|||
hass.loop, ft.partial(async_factory, *args, **kwargs)
|
||||
).result()
|
||||
|
||||
def remove():
|
||||
def remove() -> None:
|
||||
"""Threadsafe removal."""
|
||||
run_callback_threadsafe(hass.loop, async_remove).result()
|
||||
|
||||
|
@ -52,7 +52,13 @@ def threaded_listener_factory(async_factory):
|
|||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_track_state_change(hass, entity_ids, action, from_state=None, to_state=None):
|
||||
def async_track_state_change(
|
||||
hass: HomeAssistant,
|
||||
entity_ids: Union[str, Iterable[str]],
|
||||
action: Callable[[str, State, State], None],
|
||||
from_state: Union[None, str, Iterable[str]] = None,
|
||||
to_state: Union[None, str, Iterable[str]] = None,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Track specific state changes.
|
||||
|
||||
entity_ids, from_state and to_state can be string or list.
|
||||
|
@ -74,9 +80,12 @@ def async_track_state_change(hass, entity_ids, action, from_state=None, to_state
|
|||
entity_ids = tuple(entity_id.lower() for entity_id in entity_ids)
|
||||
|
||||
@callback
|
||||
def state_change_listener(event):
|
||||
def state_change_listener(event: Event) -> None:
|
||||
"""Handle specific state changes."""
|
||||
if entity_ids != MATCH_ALL and event.data.get("entity_id") not in entity_ids:
|
||||
if (
|
||||
entity_ids != MATCH_ALL
|
||||
and cast(str, event.data.get("entity_id")) not in entity_ids
|
||||
):
|
||||
return
|
||||
|
||||
old_state = event.data.get("old_state")
|
||||
|
@ -103,7 +112,12 @@ track_state_change = threaded_listener_factory(async_track_state_change)
|
|||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_track_template(hass, template, action, variables=None):
|
||||
def async_track_template(
|
||||
hass: HomeAssistant,
|
||||
template: Template,
|
||||
action: Callable[[str, State, State], None],
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Add a listener that track state changes with template condition."""
|
||||
from . import condition
|
||||
|
||||
|
@ -111,7 +125,7 @@ def async_track_template(hass, template, action, variables=None):
|
|||
already_triggered = False
|
||||
|
||||
@callback
|
||||
def template_condition_listener(entity_id, from_s, to_s):
|
||||
def template_condition_listener(entity_id: str, from_s: State, to_s: State) -> None:
|
||||
"""Check if condition is correct and run action."""
|
||||
nonlocal already_triggered
|
||||
template_result = condition.async_template(hass, template, variables)
|
||||
|
@ -134,18 +148,22 @@ track_template = threaded_listener_factory(async_track_template)
|
|||
@callback
|
||||
@bind_hass
|
||||
def async_track_same_state(
|
||||
hass, period, action, async_check_same_func, entity_ids=MATCH_ALL
|
||||
):
|
||||
hass: HomeAssistant,
|
||||
period: timedelta,
|
||||
action: Callable[..., None],
|
||||
async_check_same_func: Callable[[str, State, State], bool],
|
||||
entity_ids: Union[str, Iterable[str]] = MATCH_ALL,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Track the state of entities for a period and run an action.
|
||||
|
||||
If async_check_func is None it use the state of orig_value.
|
||||
Without entity_ids we track all state changes.
|
||||
"""
|
||||
async_remove_state_for_cancel = None
|
||||
async_remove_state_for_listener = None
|
||||
async_remove_state_for_cancel: Optional[CALLBACK_TYPE] = None
|
||||
async_remove_state_for_listener: Optional[CALLBACK_TYPE] = None
|
||||
|
||||
@callback
|
||||
def clear_listener():
|
||||
def clear_listener() -> None:
|
||||
"""Clear all unsub listener."""
|
||||
nonlocal async_remove_state_for_cancel, async_remove_state_for_listener
|
||||
|
||||
|
@ -157,7 +175,7 @@ def async_track_same_state(
|
|||
async_remove_state_for_cancel = None
|
||||
|
||||
@callback
|
||||
def state_for_listener(now):
|
||||
def state_for_listener(now: Any) -> None:
|
||||
"""Fire on state changes after a delay and calls action."""
|
||||
nonlocal async_remove_state_for_listener
|
||||
async_remove_state_for_listener = None
|
||||
|
@ -165,7 +183,9 @@ def async_track_same_state(
|
|||
hass.async_run_job(action)
|
||||
|
||||
@callback
|
||||
def state_for_cancel_listener(entity, from_state, to_state):
|
||||
def state_for_cancel_listener(
|
||||
entity: str, from_state: State, to_state: State
|
||||
) -> None:
|
||||
"""Fire on changes and cancel for listener if changed."""
|
||||
if not async_check_same_func(entity, from_state, to_state):
|
||||
clear_listener()
|
||||
|
@ -193,7 +213,7 @@ def async_track_point_in_time(
|
|||
utc_point_in_time = dt_util.as_utc(point_in_time)
|
||||
|
||||
@callback
|
||||
def utc_converter(utc_now):
|
||||
def utc_converter(utc_now: datetime) -> None:
|
||||
"""Convert passed in UTC now to local now."""
|
||||
hass.async_run_job(action, dt_util.as_local(utc_now))
|
||||
|
||||
|
@ -213,7 +233,7 @@ def async_track_point_in_utc_time(
|
|||
point_in_time = dt_util.as_utc(point_in_time)
|
||||
|
||||
@callback
|
||||
def point_in_time_listener(event):
|
||||
def point_in_time_listener(event: Event) -> None:
|
||||
"""Listen for matching time_changed events."""
|
||||
now = event.data[ATTR_NOW]
|
||||
|
||||
|
@ -225,7 +245,7 @@ def async_track_point_in_utc_time(
|
|||
# available to execute this listener it might occur that the
|
||||
# listener gets lined up twice to be executed. This will make
|
||||
# sure the second time it does nothing.
|
||||
point_in_time_listener.run = True
|
||||
setattr(point_in_time_listener, "run", True)
|
||||
async_unsub()
|
||||
|
||||
hass.async_run_job(action, now)
|
||||
|
@ -260,12 +280,12 @@ def async_track_time_interval(
|
|||
"""Add a listener that fires repetitively at every timedelta interval."""
|
||||
remove = None
|
||||
|
||||
def next_interval():
|
||||
def next_interval() -> datetime:
|
||||
"""Return the next interval."""
|
||||
return dt_util.utcnow() + interval
|
||||
|
||||
@callback
|
||||
def interval_listener(now):
|
||||
def interval_listener(now: datetime) -> None:
|
||||
"""Handle elapsed intervals."""
|
||||
nonlocal remove
|
||||
remove = async_track_point_in_utc_time(hass, interval_listener, next_interval())
|
||||
|
@ -273,7 +293,7 @@ def async_track_time_interval(
|
|||
|
||||
remove = async_track_point_in_utc_time(hass, interval_listener, next_interval())
|
||||
|
||||
def remove_listener():
|
||||
def remove_listener() -> None:
|
||||
"""Remove interval listener."""
|
||||
remove()
|
||||
|
||||
|
@ -387,7 +407,7 @@ def async_track_utc_time_change(
|
|||
if all(val is None for val in (hour, minute, second)):
|
||||
|
||||
@callback
|
||||
def time_change_listener(event):
|
||||
def time_change_listener(event: Event) -> None:
|
||||
"""Fire every time event that comes in."""
|
||||
hass.async_run_job(action, event.data[ATTR_NOW])
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import random
|
|||
import re
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Any, Iterable
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import jinja2
|
||||
from jinja2 import contextfilter, contextfunction
|
||||
|
@ -72,7 +72,9 @@ def render_complex(value, variables=None):
|
|||
return value.async_render(variables)
|
||||
|
||||
|
||||
def extract_entities(template, variables=None):
|
||||
def extract_entities(
|
||||
template: Optional[str], variables: Optional[Dict[str, Any]] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""Extract all entities for state_changed listener from template string."""
|
||||
if template is None or _RE_JINJA_DELIMITERS.search(template) is None:
|
||||
return []
|
||||
|
@ -86,6 +88,7 @@ def extract_entities(template, variables=None):
|
|||
for result in extraction:
|
||||
if (
|
||||
result[0] == "trigger.entity_id"
|
||||
and variables
|
||||
and "trigger" in variables
|
||||
and "entity_id" in variables["trigger"]
|
||||
):
|
||||
|
@ -163,7 +166,7 @@ class Template:
|
|||
if not isinstance(template, str):
|
||||
raise TypeError("Expected template to be a string")
|
||||
|
||||
self.template = template
|
||||
self.template: str = template
|
||||
self._compiled_code = None
|
||||
self._compiled = None
|
||||
self.hass = hass
|
||||
|
@ -187,7 +190,9 @@ class Template:
|
|||
except jinja2.exceptions.TemplateSyntaxError as err:
|
||||
raise TemplateError(err)
|
||||
|
||||
def extract_entities(self, variables=None):
|
||||
def extract_entities(
|
||||
self, variables: Dict[str, Any] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""Extract all entities for state_changed listener."""
|
||||
return extract_entities(self.template, variables)
|
||||
|
||||
|
|
Loading…
Reference in New Issue