Type hint improvements (#33082)
parent
f04be61f6f
commit
267d98b5eb
|
@ -486,7 +486,7 @@ class Event:
|
|||
def __init__(
|
||||
self,
|
||||
event_type: str,
|
||||
data: Optional[Dict] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
origin: EventOrigin = EventOrigin.local,
|
||||
time_fired: Optional[int] = None,
|
||||
context: Optional[Context] = None,
|
||||
|
@ -550,9 +550,7 @@ class EventBus:
|
|||
@property
|
||||
def listeners(self) -> Dict[str, int]:
|
||||
"""Return dictionary with events and the number of listeners."""
|
||||
return run_callback_threadsafe( # type: ignore
|
||||
self._hass.loop, self.async_listeners
|
||||
).result()
|
||||
return run_callback_threadsafe(self._hass.loop, self.async_listeners).result()
|
||||
|
||||
def fire(
|
||||
self,
|
||||
|
@ -852,7 +850,7 @@ class StateMachine:
|
|||
future = run_callback_threadsafe(
|
||||
self._loop, self.async_entity_ids, domain_filter
|
||||
)
|
||||
return future.result() # type: ignore
|
||||
return future.result()
|
||||
|
||||
@callback
|
||||
def async_entity_ids(self, domain_filter: Optional[str] = None) -> List[str]:
|
||||
|
@ -873,9 +871,7 @@ class StateMachine:
|
|||
|
||||
def all(self) -> List[State]:
|
||||
"""Create a list of all states."""
|
||||
return run_callback_threadsafe( # type: ignore
|
||||
self._loop, self.async_all
|
||||
).result()
|
||||
return run_callback_threadsafe(self._loop, self.async_all).result()
|
||||
|
||||
@callback
|
||||
def async_all(self) -> List[State]:
|
||||
|
@ -905,7 +901,7 @@ class StateMachine:
|
|||
|
||||
Returns boolean to indicate if an entity was removed.
|
||||
"""
|
||||
return run_callback_threadsafe( # type: ignore
|
||||
return run_callback_threadsafe(
|
||||
self._loop, self.async_remove, entity_id
|
||||
).result()
|
||||
|
||||
|
@ -1064,9 +1060,7 @@ class ServiceRegistry:
|
|||
@property
|
||||
def services(self) -> Dict[str, Dict[str, Service]]:
|
||||
"""Return dictionary with per domain a list of available services."""
|
||||
return run_callback_threadsafe( # type: ignore
|
||||
self._hass.loop, self.async_services
|
||||
).result()
|
||||
return run_callback_threadsafe(self._hass.loop, self.async_services).result()
|
||||
|
||||
@callback
|
||||
def async_services(self) -> Dict[str, Dict[str, Service]]:
|
||||
|
|
|
@ -146,19 +146,16 @@ def numeric_state(
|
|||
variables: TemplateVarsType = None,
|
||||
) -> bool:
|
||||
"""Test a numeric state condition."""
|
||||
return cast(
|
||||
bool,
|
||||
run_callback_threadsafe(
|
||||
hass.loop,
|
||||
async_numeric_state,
|
||||
hass,
|
||||
entity,
|
||||
below,
|
||||
above,
|
||||
value_template,
|
||||
variables,
|
||||
).result(),
|
||||
)
|
||||
return run_callback_threadsafe(
|
||||
hass.loop,
|
||||
async_numeric_state,
|
||||
hass,
|
||||
entity,
|
||||
below,
|
||||
above,
|
||||
value_template,
|
||||
variables,
|
||||
).result()
|
||||
|
||||
|
||||
def async_numeric_state(
|
||||
|
@ -353,12 +350,9 @@ def template(
|
|||
hass: HomeAssistant, value_template: Template, variables: TemplateVarsType = None
|
||||
) -> bool:
|
||||
"""Test if template condition matches."""
|
||||
return cast(
|
||||
bool,
|
||||
run_callback_threadsafe(
|
||||
hass.loop, async_template, hass, value_template, variables
|
||||
).result(),
|
||||
)
|
||||
return run_callback_threadsafe(
|
||||
hass.loop, async_template, hass, value_template, variables
|
||||
).result()
|
||||
|
||||
|
||||
def async_template(
|
||||
|
|
|
@ -10,11 +10,10 @@ from typing import Any, Callable, Collection, Dict, Optional, Union
|
|||
from homeassistant import core, setup
|
||||
from homeassistant.const import ATTR_DISCOVERED, ATTR_SERVICE, EVENT_PLATFORM_DISCOVERED
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.loader import DEPENDENCY_BLACKLIST, bind_hass
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
EVENT_LOAD_PLATFORM = "load_platform.{}"
|
||||
ATTR_PLATFORM = "platform"
|
||||
|
||||
|
@ -56,13 +55,29 @@ def async_listen(
|
|||
|
||||
|
||||
@bind_hass
|
||||
def discover(hass, service, discovered, component, hass_config):
|
||||
def discover(
|
||||
hass: core.HomeAssistant,
|
||||
service: str,
|
||||
discovered: DiscoveryInfoType,
|
||||
component: str,
|
||||
hass_config: ConfigType,
|
||||
) -> None:
|
||||
"""Fire discovery event. Can ensure a component is loaded."""
|
||||
hass.add_job(async_discover(hass, service, discovered, component, hass_config))
|
||||
hass.add_job(
|
||||
async_discover( # type: ignore
|
||||
hass, service, discovered, component, hass_config
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_discover(hass, service, discovered, component, hass_config):
|
||||
async def async_discover(
|
||||
hass: core.HomeAssistant,
|
||||
service: str,
|
||||
discovered: Optional[DiscoveryInfoType],
|
||||
component: Optional[str],
|
||||
hass_config: ConfigType,
|
||||
) -> None:
|
||||
"""Fire discovery event. Can ensure a component is loaded."""
|
||||
if component in DEPENDENCY_BLACKLIST:
|
||||
raise HomeAssistantError(f"Cannot discover the {component} component.")
|
||||
|
@ -70,7 +85,7 @@ async def async_discover(hass, service, discovered, component, hass_config):
|
|||
if component is not None and component not in hass.config.components:
|
||||
await setup.async_setup_component(hass, component, hass_config)
|
||||
|
||||
data = {ATTR_SERVICE: service}
|
||||
data: Dict[str, Any] = {ATTR_SERVICE: service}
|
||||
|
||||
if discovered is not None:
|
||||
data[ATTR_DISCOVERED] = discovered
|
||||
|
@ -117,7 +132,13 @@ def async_listen_platform(
|
|||
|
||||
|
||||
@bind_hass
|
||||
def load_platform(hass, component, platform, discovered, hass_config):
|
||||
def load_platform(
|
||||
hass: core.HomeAssistant,
|
||||
component: str,
|
||||
platform: str,
|
||||
discovered: DiscoveryInfoType,
|
||||
hass_config: ConfigType,
|
||||
) -> None:
|
||||
"""Load a component and platform dynamically.
|
||||
|
||||
Target components will be loaded and an EVENT_PLATFORM_DISCOVERED will be
|
||||
|
@ -129,12 +150,20 @@ def load_platform(hass, component, platform, discovered, hass_config):
|
|||
Use `listen_platform` to register a callback for these events.
|
||||
"""
|
||||
hass.add_job(
|
||||
async_load_platform(hass, component, platform, discovered, hass_config)
|
||||
async_load_platform( # type: ignore
|
||||
hass, component, platform, discovered, hass_config
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_load_platform(hass, component, platform, discovered, hass_config):
|
||||
async def async_load_platform(
|
||||
hass: core.HomeAssistant,
|
||||
component: str,
|
||||
platform: str,
|
||||
discovered: DiscoveryInfoType,
|
||||
hass_config: ConfigType,
|
||||
) -> None:
|
||||
"""Load a component and platform dynamically.
|
||||
|
||||
Target components will be loaded and an EVENT_PLATFORM_DISCOVERED will be
|
||||
|
@ -164,7 +193,7 @@ async def async_load_platform(hass, component, platform, discovered, hass_config
|
|||
if not setup_success:
|
||||
return
|
||||
|
||||
data = {
|
||||
data: Dict[str, Any] = {
|
||||
ATTR_SERVICE: EVENT_LOAD_PLATFORM.format(component),
|
||||
ATTR_PLATFORM: platform,
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ from homeassistant.helpers.entity_registry import (
|
|||
from homeassistant.util import dt as dt_util, ensure_unique_string, slugify
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs, no-warn-return-any
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
SLOW_UPDATE_WARNING = 10
|
||||
|
|
|
@ -191,7 +191,7 @@ class EntityComponent:
|
|||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
return await service.async_extract_entities( # type: ignore
|
||||
return await service.async_extract_entities(
|
||||
self.hass, self.entities, service_call, expand_group
|
||||
)
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import asyncio
|
||||
from functools import partial, wraps
|
||||
import logging
|
||||
from typing import Callable
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -22,9 +22,10 @@ from homeassistant.exceptions import (
|
|||
Unauthorized,
|
||||
UnknownUser,
|
||||
)
|
||||
from homeassistant.helpers import template, typing
|
||||
from homeassistant.helpers import template
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, TemplateVarsType
|
||||
from homeassistant.loader import async_get_integration, bind_hass
|
||||
from homeassistant.util.yaml import load_yaml
|
||||
from homeassistant.util.yaml.loader import JSON_TYPE
|
||||
|
@ -42,8 +43,12 @@ SERVICE_DESCRIPTION_CACHE = "service_description_cache"
|
|||
|
||||
@bind_hass
|
||||
def call_from_config(
|
||||
hass, config, blocking=False, variables=None, validate_config=True
|
||||
):
|
||||
hass: HomeAssistantType,
|
||||
config: ConfigType,
|
||||
blocking: bool = False,
|
||||
variables: TemplateVarsType = None,
|
||||
validate_config: bool = True,
|
||||
) -> None:
|
||||
"""Call a service based on a config hash."""
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
async_call_from_config(hass, config, blocking, variables, validate_config),
|
||||
|
@ -53,8 +58,13 @@ def call_from_config(
|
|||
|
||||
@bind_hass
|
||||
async def async_call_from_config(
|
||||
hass, config, blocking=False, variables=None, validate_config=True, context=None
|
||||
):
|
||||
hass: HomeAssistantType,
|
||||
config: ConfigType,
|
||||
blocking: bool = False,
|
||||
variables: TemplateVarsType = None,
|
||||
validate_config: bool = True,
|
||||
context: Optional[ha.Context] = None,
|
||||
) -> None:
|
||||
"""Call a service based on a config hash."""
|
||||
try:
|
||||
parms = async_prepare_call_from_config(hass, config, variables, validate_config)
|
||||
|
@ -68,7 +78,12 @@ async def async_call_from_config(
|
|||
|
||||
@ha.callback
|
||||
@bind_hass
|
||||
def async_prepare_call_from_config(hass, config, variables=None, validate_config=False):
|
||||
def async_prepare_call_from_config(
|
||||
hass: HomeAssistantType,
|
||||
config: ConfigType,
|
||||
variables: TemplateVarsType = None,
|
||||
validate_config: bool = False,
|
||||
) -> Tuple[str, str, Dict[str, Any]]:
|
||||
"""Prepare to call a service based on a config hash."""
|
||||
if validate_config:
|
||||
try:
|
||||
|
@ -113,7 +128,9 @@ def async_prepare_call_from_config(hass, config, variables=None, validate_config
|
|||
|
||||
|
||||
@bind_hass
|
||||
def extract_entity_ids(hass, service_call, expand_group=True):
|
||||
def extract_entity_ids(
|
||||
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
|
||||
) -> Set[str]:
|
||||
"""Extract a list of entity ids from a service call.
|
||||
|
||||
Will convert group entity ids to the entity ids it represents.
|
||||
|
@ -124,7 +141,12 @@ def extract_entity_ids(hass, service_call, expand_group=True):
|
|||
|
||||
|
||||
@bind_hass
|
||||
async def async_extract_entities(hass, entities, service_call, expand_group=True):
|
||||
async def async_extract_entities(
|
||||
hass: HomeAssistantType,
|
||||
entities: Iterable[Entity],
|
||||
service_call: ha.ServiceCall,
|
||||
expand_group: bool = True,
|
||||
) -> List[Entity]:
|
||||
"""Extract a list of entity objects from a service call.
|
||||
|
||||
Will convert group entity ids to the entity ids it represents.
|
||||
|
@ -158,7 +180,9 @@ async def async_extract_entities(hass, entities, service_call, expand_group=True
|
|||
|
||||
|
||||
@bind_hass
|
||||
async def async_extract_entity_ids(hass, service_call, expand_group=True):
|
||||
async def async_extract_entity_ids(
|
||||
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
|
||||
) -> Set[str]:
|
||||
"""Extract a list of entity ids from a service call.
|
||||
|
||||
Will convert group entity ids to the entity ids it represents.
|
||||
|
@ -166,7 +190,7 @@ async def async_extract_entity_ids(hass, service_call, expand_group=True):
|
|||
entity_ids = service_call.data.get(ATTR_ENTITY_ID)
|
||||
area_ids = service_call.data.get(ATTR_AREA_ID)
|
||||
|
||||
extracted = set()
|
||||
extracted: Set[str] = set()
|
||||
|
||||
if entity_ids in (None, ENTITY_MATCH_NONE) and area_ids in (
|
||||
None,
|
||||
|
@ -226,7 +250,9 @@ async def _load_services_file(hass: HomeAssistantType, domain: str) -> JSON_TYPE
|
|||
|
||||
|
||||
@bind_hass
|
||||
async def async_get_all_descriptions(hass):
|
||||
async def async_get_all_descriptions(
|
||||
hass: HomeAssistantType,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""Return descriptions (i.e. user documentation) for all service calls."""
|
||||
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
|
||||
format_cache_key = "{}.{}".format
|
||||
|
@ -253,7 +279,7 @@ async def async_get_all_descriptions(hass):
|
|||
loaded[domain] = content
|
||||
|
||||
# Build response
|
||||
descriptions = {}
|
||||
descriptions: Dict[str, Dict[str, Any]] = {}
|
||||
for domain in services:
|
||||
descriptions[domain] = {}
|
||||
|
||||
|
@ -281,7 +307,9 @@ async def async_get_all_descriptions(hass):
|
|||
|
||||
@ha.callback
|
||||
@bind_hass
|
||||
def async_set_service_schema(hass, domain, service, schema):
|
||||
def async_set_service_schema(
|
||||
hass: HomeAssistantType, domain: str, service: str, schema: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Register a description for a service."""
|
||||
hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
|
||||
|
||||
|
@ -454,7 +482,7 @@ async def _handle_entity_call(hass, entity, func, data, context):
|
|||
@bind_hass
|
||||
@ha.callback
|
||||
def async_register_admin_service(
|
||||
hass: typing.HomeAssistantType,
|
||||
hass: HomeAssistantType,
|
||||
domain: str,
|
||||
service: str,
|
||||
service_func: Callable,
|
||||
|
|
|
@ -51,7 +51,7 @@ _RE_JINJA_DELIMITERS = re.compile(r"\{%|\{\{")
|
|||
|
||||
|
||||
@bind_hass
|
||||
def attach(hass, obj):
|
||||
def attach(hass: HomeAssistantType, obj: Any) -> None:
|
||||
"""Recursively attach hass to all template instances in list and dict."""
|
||||
if isinstance(obj, list):
|
||||
for child in obj:
|
||||
|
@ -63,7 +63,7 @@ def attach(hass, obj):
|
|||
obj.hass = hass
|
||||
|
||||
|
||||
def render_complex(value, variables=None):
|
||||
def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
|
||||
"""Recursive template creator helper function."""
|
||||
if isinstance(value, list):
|
||||
return [render_complex(item, variables) for item in value]
|
||||
|
@ -307,11 +307,11 @@ class Template:
|
|||
and self.hass == other.hass
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
"""Hash code for template."""
|
||||
return hash(self.template)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
"""Representation of Template."""
|
||||
return 'Template("' + self.template + '")'
|
||||
|
||||
|
@ -333,7 +333,7 @@ class AllStates:
|
|||
raise TemplateError(f"Invalid domain name '{name}'")
|
||||
return DomainStates(self._hass, name)
|
||||
|
||||
def _collect_all(self):
|
||||
def _collect_all(self) -> None:
|
||||
render_info = self._hass.data.get(_RENDER_INFO)
|
||||
if render_info is not None:
|
||||
# pylint: disable=protected-access
|
||||
|
@ -349,7 +349,7 @@ class AllStates:
|
|||
)
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
"""Return number of states."""
|
||||
self._collect_all()
|
||||
return len(self._hass.states.async_entity_ids())
|
||||
|
@ -359,7 +359,7 @@ class AllStates:
|
|||
state = _get_state(self._hass, entity_id)
|
||||
return STATE_UNKNOWN if state is None else state.state
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
"""Representation of All States."""
|
||||
return "<template AllStates>"
|
||||
|
||||
|
@ -455,19 +455,21 @@ class TemplateState(State):
|
|||
return f"<template {rep[1:]}"
|
||||
|
||||
|
||||
def _collect_state(hass, entity_id):
|
||||
def _collect_state(hass: HomeAssistantType, entity_id: str) -> None:
|
||||
entity_collect = hass.data.get(_RENDER_INFO)
|
||||
if entity_collect is not None:
|
||||
# pylint: disable=protected-access
|
||||
entity_collect._entities.append(entity_id)
|
||||
|
||||
|
||||
def _wrap_state(hass, state):
|
||||
def _wrap_state(
|
||||
hass: HomeAssistantType, state: Optional[State]
|
||||
) -> Optional[TemplateState]:
|
||||
"""Wrap a state."""
|
||||
return None if state is None else TemplateState(hass, state)
|
||||
|
||||
|
||||
def _get_state(hass, entity_id):
|
||||
def _get_state(hass: HomeAssistantType, entity_id: str) -> Optional[TemplateState]:
|
||||
state = hass.states.get(entity_id)
|
||||
if state is None:
|
||||
# Only need to collect if none, if not none collect first actual
|
||||
|
@ -477,7 +479,9 @@ def _get_state(hass, entity_id):
|
|||
return _wrap_state(hass, state)
|
||||
|
||||
|
||||
def _resolve_state(hass, entity_id_or_state):
|
||||
def _resolve_state(
|
||||
hass: HomeAssistantType, entity_id_or_state: Any
|
||||
) -> Union[State, TemplateState, None]:
|
||||
"""Return state or entity_id if given."""
|
||||
if isinstance(entity_id_or_state, State):
|
||||
return entity_id_or_state
|
||||
|
|
|
@ -6,10 +6,12 @@ import functools
|
|||
import logging
|
||||
import threading
|
||||
from traceback import extract_stack
|
||||
from typing import Any, Callable, Coroutine
|
||||
from typing import Any, Callable, Coroutine, TypeVar
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def fire_coroutine_threadsafe(coro: Coroutine, loop: AbstractEventLoop) -> None:
|
||||
"""Submit a coroutine object to a given event loop.
|
||||
|
@ -33,8 +35,8 @@ def fire_coroutine_threadsafe(coro: Coroutine, loop: AbstractEventLoop) -> None:
|
|||
|
||||
|
||||
def run_callback_threadsafe(
|
||||
loop: AbstractEventLoop, callback: Callable, *args: Any
|
||||
) -> concurrent.futures.Future:
|
||||
loop: AbstractEventLoop, callback: Callable[..., T], *args: Any
|
||||
) -> "concurrent.futures.Future[T]":
|
||||
"""Submit a callback object to a given event loop.
|
||||
|
||||
Return a concurrent.futures.Future to access the result.
|
||||
|
|
Loading…
Reference in New Issue