Type hint improvements (#33082)

pull/34355/head
Ville Skyttä 2020-04-17 21:33:58 +03:00 committed by GitHub
parent f04be61f6f
commit 267d98b5eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 124 additions and 73 deletions

View File

@ -486,7 +486,7 @@ class Event:
def __init__(
self,
event_type: str,
data: Optional[Dict] = None,
data: Optional[Dict[str, Any]] = None,
origin: EventOrigin = EventOrigin.local,
time_fired: Optional[int] = None,
context: Optional[Context] = None,
@ -550,9 +550,7 @@ class EventBus:
@property
def listeners(self) -> Dict[str, int]:
"""Return dictionary with events and the number of listeners."""
return run_callback_threadsafe( # type: ignore
self._hass.loop, self.async_listeners
).result()
return run_callback_threadsafe(self._hass.loop, self.async_listeners).result()
def fire(
self,
@ -852,7 +850,7 @@ class StateMachine:
future = run_callback_threadsafe(
self._loop, self.async_entity_ids, domain_filter
)
return future.result() # type: ignore
return future.result()
@callback
def async_entity_ids(self, domain_filter: Optional[str] = None) -> List[str]:
@ -873,9 +871,7 @@ class StateMachine:
def all(self) -> List[State]:
"""Create a list of all states."""
return run_callback_threadsafe( # type: ignore
self._loop, self.async_all
).result()
return run_callback_threadsafe(self._loop, self.async_all).result()
@callback
def async_all(self) -> List[State]:
@ -905,7 +901,7 @@ class StateMachine:
Returns boolean to indicate if an entity was removed.
"""
return run_callback_threadsafe( # type: ignore
return run_callback_threadsafe(
self._loop, self.async_remove, entity_id
).result()
@ -1064,9 +1060,7 @@ class ServiceRegistry:
@property
def services(self) -> Dict[str, Dict[str, Service]]:
"""Return dictionary with per domain a list of available services."""
return run_callback_threadsafe( # type: ignore
self._hass.loop, self.async_services
).result()
return run_callback_threadsafe(self._hass.loop, self.async_services).result()
@callback
def async_services(self) -> Dict[str, Dict[str, Service]]:

View File

@ -146,19 +146,16 @@ def numeric_state(
variables: TemplateVarsType = None,
) -> bool:
"""Test a numeric state condition."""
return cast(
bool,
run_callback_threadsafe(
hass.loop,
async_numeric_state,
hass,
entity,
below,
above,
value_template,
variables,
).result(),
)
return run_callback_threadsafe(
hass.loop,
async_numeric_state,
hass,
entity,
below,
above,
value_template,
variables,
).result()
def async_numeric_state(
@ -353,12 +350,9 @@ def template(
hass: HomeAssistant, value_template: Template, variables: TemplateVarsType = None
) -> bool:
"""Test if template condition matches."""
return cast(
bool,
run_callback_threadsafe(
hass.loop, async_template, hass, value_template, variables
).result(),
)
return run_callback_threadsafe(
hass.loop, async_template, hass, value_template, variables
).result()
def async_template(

View File

@ -10,11 +10,10 @@ from typing import Any, Callable, Collection, Dict, Optional, Union
from homeassistant import core, setup
from homeassistant.const import ATTR_DISCOVERED, ATTR_SERVICE, EVENT_PLATFORM_DISCOVERED
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.loader import DEPENDENCY_BLACKLIST, bind_hass
from homeassistant.util.async_ import run_callback_threadsafe
# mypy: allow-untyped-defs, no-check-untyped-defs
EVENT_LOAD_PLATFORM = "load_platform.{}"
ATTR_PLATFORM = "platform"
@ -56,13 +55,29 @@ def async_listen(
@bind_hass
def discover(hass, service, discovered, component, hass_config):
def discover(
hass: core.HomeAssistant,
service: str,
discovered: DiscoveryInfoType,
component: str,
hass_config: ConfigType,
) -> None:
"""Fire discovery event. Can ensure a component is loaded."""
hass.add_job(async_discover(hass, service, discovered, component, hass_config))
hass.add_job(
async_discover( # type: ignore
hass, service, discovered, component, hass_config
)
)
@bind_hass
async def async_discover(hass, service, discovered, component, hass_config):
async def async_discover(
hass: core.HomeAssistant,
service: str,
discovered: Optional[DiscoveryInfoType],
component: Optional[str],
hass_config: ConfigType,
) -> None:
"""Fire discovery event. Can ensure a component is loaded."""
if component in DEPENDENCY_BLACKLIST:
raise HomeAssistantError(f"Cannot discover the {component} component.")
@ -70,7 +85,7 @@ async def async_discover(hass, service, discovered, component, hass_config):
if component is not None and component not in hass.config.components:
await setup.async_setup_component(hass, component, hass_config)
data = {ATTR_SERVICE: service}
data: Dict[str, Any] = {ATTR_SERVICE: service}
if discovered is not None:
data[ATTR_DISCOVERED] = discovered
@ -117,7 +132,13 @@ def async_listen_platform(
@bind_hass
def load_platform(hass, component, platform, discovered, hass_config):
def load_platform(
hass: core.HomeAssistant,
component: str,
platform: str,
discovered: DiscoveryInfoType,
hass_config: ConfigType,
) -> None:
"""Load a component and platform dynamically.
Target components will be loaded and an EVENT_PLATFORM_DISCOVERED will be
@ -129,12 +150,20 @@ def load_platform(hass, component, platform, discovered, hass_config):
Use `listen_platform` to register a callback for these events.
"""
hass.add_job(
async_load_platform(hass, component, platform, discovered, hass_config)
async_load_platform( # type: ignore
hass, component, platform, discovered, hass_config
)
)
@bind_hass
async def async_load_platform(hass, component, platform, discovered, hass_config):
async def async_load_platform(
hass: core.HomeAssistant,
component: str,
platform: str,
discovered: DiscoveryInfoType,
hass_config: ConfigType,
) -> None:
"""Load a component and platform dynamically.
Target components will be loaded and an EVENT_PLATFORM_DISCOVERED will be
@ -164,7 +193,7 @@ async def async_load_platform(hass, component, platform, discovered, hass_config
if not setup_success:
return
data = {
data: Dict[str, Any] = {
ATTR_SERVICE: EVENT_LOAD_PLATFORM.format(component),
ATTR_PLATFORM: platform,
}

View File

@ -35,7 +35,7 @@ from homeassistant.helpers.entity_registry import (
from homeassistant.util import dt as dt_util, ensure_unique_string, slugify
from homeassistant.util.async_ import run_callback_threadsafe
# mypy: allow-untyped-defs, no-check-untyped-defs, no-warn-return-any
# mypy: allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__)
SLOW_UPDATE_WARNING = 10

View File

@ -191,7 +191,7 @@ class EntityComponent:
This method must be run in the event loop.
"""
return await service.async_extract_entities( # type: ignore
return await service.async_extract_entities(
self.hass, self.entities, service_call, expand_group
)

View File

@ -2,7 +2,7 @@
import asyncio
from functools import partial, wraps
import logging
from typing import Callable
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple
import voluptuous as vol
@ -22,9 +22,10 @@ from homeassistant.exceptions import (
Unauthorized,
UnknownUser,
)
from homeassistant.helpers import template, typing
from homeassistant.helpers import template
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, TemplateVarsType
from homeassistant.loader import async_get_integration, bind_hass
from homeassistant.util.yaml import load_yaml
from homeassistant.util.yaml.loader import JSON_TYPE
@ -42,8 +43,12 @@ SERVICE_DESCRIPTION_CACHE = "service_description_cache"
@bind_hass
def call_from_config(
hass, config, blocking=False, variables=None, validate_config=True
):
hass: HomeAssistantType,
config: ConfigType,
blocking: bool = False,
variables: TemplateVarsType = None,
validate_config: bool = True,
) -> None:
"""Call a service based on a config hash."""
asyncio.run_coroutine_threadsafe(
async_call_from_config(hass, config, blocking, variables, validate_config),
@ -53,8 +58,13 @@ def call_from_config(
@bind_hass
async def async_call_from_config(
hass, config, blocking=False, variables=None, validate_config=True, context=None
):
hass: HomeAssistantType,
config: ConfigType,
blocking: bool = False,
variables: TemplateVarsType = None,
validate_config: bool = True,
context: Optional[ha.Context] = None,
) -> None:
"""Call a service based on a config hash."""
try:
parms = async_prepare_call_from_config(hass, config, variables, validate_config)
@ -68,7 +78,12 @@ async def async_call_from_config(
@ha.callback
@bind_hass
def async_prepare_call_from_config(hass, config, variables=None, validate_config=False):
def async_prepare_call_from_config(
hass: HomeAssistantType,
config: ConfigType,
variables: TemplateVarsType = None,
validate_config: bool = False,
) -> Tuple[str, str, Dict[str, Any]]:
"""Prepare to call a service based on a config hash."""
if validate_config:
try:
@ -113,7 +128,9 @@ def async_prepare_call_from_config(hass, config, variables=None, validate_config
@bind_hass
def extract_entity_ids(hass, service_call, expand_group=True):
def extract_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
) -> Set[str]:
"""Extract a list of entity ids from a service call.
Will convert group entity ids to the entity ids it represents.
@ -124,7 +141,12 @@ def extract_entity_ids(hass, service_call, expand_group=True):
@bind_hass
async def async_extract_entities(hass, entities, service_call, expand_group=True):
async def async_extract_entities(
hass: HomeAssistantType,
entities: Iterable[Entity],
service_call: ha.ServiceCall,
expand_group: bool = True,
) -> List[Entity]:
"""Extract a list of entity objects from a service call.
Will convert group entity ids to the entity ids it represents.
@ -158,7 +180,9 @@ async def async_extract_entities(hass, entities, service_call, expand_group=True
@bind_hass
async def async_extract_entity_ids(hass, service_call, expand_group=True):
async def async_extract_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
) -> Set[str]:
"""Extract a list of entity ids from a service call.
Will convert group entity ids to the entity ids it represents.
@ -166,7 +190,7 @@ async def async_extract_entity_ids(hass, service_call, expand_group=True):
entity_ids = service_call.data.get(ATTR_ENTITY_ID)
area_ids = service_call.data.get(ATTR_AREA_ID)
extracted = set()
extracted: Set[str] = set()
if entity_ids in (None, ENTITY_MATCH_NONE) and area_ids in (
None,
@ -226,7 +250,9 @@ async def _load_services_file(hass: HomeAssistantType, domain: str) -> JSON_TYPE
@bind_hass
async def async_get_all_descriptions(hass):
async def async_get_all_descriptions(
hass: HomeAssistantType,
) -> Dict[str, Dict[str, Any]]:
"""Return descriptions (i.e. user documentation) for all service calls."""
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
format_cache_key = "{}.{}".format
@ -253,7 +279,7 @@ async def async_get_all_descriptions(hass):
loaded[domain] = content
# Build response
descriptions = {}
descriptions: Dict[str, Dict[str, Any]] = {}
for domain in services:
descriptions[domain] = {}
@ -281,7 +307,9 @@ async def async_get_all_descriptions(hass):
@ha.callback
@bind_hass
def async_set_service_schema(hass, domain, service, schema):
def async_set_service_schema(
hass: HomeAssistantType, domain: str, service: str, schema: Dict[str, Any]
) -> None:
"""Register a description for a service."""
hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
@ -454,7 +482,7 @@ async def _handle_entity_call(hass, entity, func, data, context):
@bind_hass
@ha.callback
def async_register_admin_service(
hass: typing.HomeAssistantType,
hass: HomeAssistantType,
domain: str,
service: str,
service_func: Callable,

View File

@ -51,7 +51,7 @@ _RE_JINJA_DELIMITERS = re.compile(r"\{%|\{\{")
@bind_hass
def attach(hass, obj):
def attach(hass: HomeAssistantType, obj: Any) -> None:
"""Recursively attach hass to all template instances in list and dict."""
if isinstance(obj, list):
for child in obj:
@ -63,7 +63,7 @@ def attach(hass, obj):
obj.hass = hass
def render_complex(value, variables=None):
def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
"""Recursive template creator helper function."""
if isinstance(value, list):
return [render_complex(item, variables) for item in value]
@ -307,11 +307,11 @@ class Template:
and self.hass == other.hass
)
def __hash__(self):
def __hash__(self) -> int:
"""Hash code for template."""
return hash(self.template)
def __repr__(self):
def __repr__(self) -> str:
"""Representation of Template."""
return 'Template("' + self.template + '")'
@ -333,7 +333,7 @@ class AllStates:
raise TemplateError(f"Invalid domain name '{name}'")
return DomainStates(self._hass, name)
def _collect_all(self):
def _collect_all(self) -> None:
render_info = self._hass.data.get(_RENDER_INFO)
if render_info is not None:
# pylint: disable=protected-access
@ -349,7 +349,7 @@ class AllStates:
)
)
def __len__(self):
def __len__(self) -> int:
"""Return number of states."""
self._collect_all()
return len(self._hass.states.async_entity_ids())
@ -359,7 +359,7 @@ class AllStates:
state = _get_state(self._hass, entity_id)
return STATE_UNKNOWN if state is None else state.state
def __repr__(self):
def __repr__(self) -> str:
"""Representation of All States."""
return "<template AllStates>"
@ -455,19 +455,21 @@ class TemplateState(State):
return f"<template {rep[1:]}"
def _collect_state(hass, entity_id):
def _collect_state(hass: HomeAssistantType, entity_id: str) -> None:
entity_collect = hass.data.get(_RENDER_INFO)
if entity_collect is not None:
# pylint: disable=protected-access
entity_collect._entities.append(entity_id)
def _wrap_state(hass, state):
def _wrap_state(
hass: HomeAssistantType, state: Optional[State]
) -> Optional[TemplateState]:
"""Wrap a state."""
return None if state is None else TemplateState(hass, state)
def _get_state(hass, entity_id):
def _get_state(hass: HomeAssistantType, entity_id: str) -> Optional[TemplateState]:
state = hass.states.get(entity_id)
if state is None:
# Only need to collect if none, if not none collect first actual
@ -477,7 +479,9 @@ def _get_state(hass, entity_id):
return _wrap_state(hass, state)
def _resolve_state(hass, entity_id_or_state):
def _resolve_state(
hass: HomeAssistantType, entity_id_or_state: Any
) -> Union[State, TemplateState, None]:
"""Return state or entity_id if given."""
if isinstance(entity_id_or_state, State):
return entity_id_or_state

View File

@ -6,10 +6,12 @@ import functools
import logging
import threading
from traceback import extract_stack
from typing import Any, Callable, Coroutine
from typing import Any, Callable, Coroutine, TypeVar
_LOGGER = logging.getLogger(__name__)
T = TypeVar("T")
def fire_coroutine_threadsafe(coro: Coroutine, loop: AbstractEventLoop) -> None:
"""Submit a coroutine object to a given event loop.
@ -33,8 +35,8 @@ def fire_coroutine_threadsafe(coro: Coroutine, loop: AbstractEventLoop) -> None:
def run_callback_threadsafe(
loop: AbstractEventLoop, callback: Callable, *args: Any
) -> concurrent.futures.Future:
loop: AbstractEventLoop, callback: Callable[..., T], *args: Any
) -> "concurrent.futures.Future[T]":
"""Submit a callback object to a given event loop.
Return a concurrent.futures.Future to access the result.