Improvement typing core ()

* Add package typing

* Add util/location typing

* FIX: lint wrong order of imports

* Fix sometyping and add helpers/entity typing

* Mypy import trick

* Add asteroid to test requiremts to fix pylint issue

* Fix deprecated function isSet for is_set

* Add loader.py typing

* Improve typing bootstrap
pull/2629/merge
Fabian Heredia Montiel 2016-07-27 22:33:49 -05:00 committed by Paulus Schoutsen
parent 8c728d1b4e
commit ae97218582
9 changed files with 99 additions and 56 deletions

View File

@ -6,11 +6,12 @@ import os
import sys
from collections import defaultdict
from threading import RLock
from types import ModuleType
from typing import Any, Optional, Dict
import voluptuous as vol
import homeassistant.components as core_components
from homeassistant.components import group, persistent_notification
import homeassistant.config as conf_util
@ -32,7 +33,8 @@ ATTR_COMPONENT = 'component'
ERROR_LOG_FILENAME = 'home-assistant.log'
def setup_component(hass, domain, config=None):
def setup_component(hass: core.HomeAssistant, domain: str,
config: Optional[Dict]=None) -> bool:
"""Setup a component and all its dependencies."""
if domain in hass.config.components:
return True
@ -55,7 +57,8 @@ def setup_component(hass, domain, config=None):
return True
def _handle_requirements(hass, component, name):
def _handle_requirements(hass: core.HomeAssistant, component,
name: str) -> bool:
"""Install the requirements for a component."""
if hass.config.skip_pip or not hasattr(component, 'REQUIREMENTS'):
return True
@ -69,7 +72,7 @@ def _handle_requirements(hass, component, name):
return True
def _setup_component(hass, domain, config):
def _setup_component(hass: core.HomeAssistant, domain: str, config) -> bool:
"""Setup a component for Home Assistant."""
# pylint: disable=too-many-return-statements,too-many-branches
# pylint: disable=too-many-statements
@ -178,7 +181,8 @@ def _setup_component(hass, domain, config):
return True
def prepare_setup_platform(hass, config, domain, platform_name):
def prepare_setup_platform(hass: core.HomeAssistant, config, domain: str,
platform_name: str) -> Optional[ModuleType]:
"""Load a platform and makes sure dependencies are setup."""
_ensure_loader_prepared(hass)
@ -309,7 +313,8 @@ def from_config_file(config_path: str,
skip_pip=skip_pip)
def enable_logging(hass, verbose=False, log_rotate_days=None):
def enable_logging(hass: core.HomeAssistant, verbose: bool=False,
log_rotate_days=None) -> None:
"""Setup the logging."""
logging.basicConfig(level=logging.INFO)
fmt = ("%(log_color)s%(asctime)s %(levelname)s (%(threadName)s) "
@ -360,12 +365,12 @@ def enable_logging(hass, verbose=False, log_rotate_days=None):
'Unable to setup error log %s (access denied)', err_log_path)
def _ensure_loader_prepared(hass):
def _ensure_loader_prepared(hass: core.HomeAssistant) -> None:
"""Ensure Home Assistant loader is prepared."""
if not loader.PREPARED:
loader.prepare(hass)
def _mount_local_lib_path(config_dir):
def _mount_local_lib_path(config_dir: str) -> None:
"""Add local library to Python Path."""
sys.path.insert(0, os.path.join(config_dir, 'deps'))

View File

@ -158,14 +158,14 @@ class HomeAssistant(object):
except AttributeError:
pass
try:
while not request_shutdown.isSet():
while not request_shutdown.is_set():
time.sleep(1)
except KeyboardInterrupt:
pass
finally:
self.stop()
return RESTART_EXIT_CODE if request_restart.isSet() else 0
return RESTART_EXIT_CODE if request_restart.is_set() else 0
def stop(self) -> None:
"""Stop Home Assistant and shuts down all threads."""
@ -233,7 +233,7 @@ class Event(object):
class EventBus(object):
"""Allows firing of and listening for events."""
def __init__(self, pool: util.ThreadPool):
def __init__(self, pool: util.ThreadPool) -> None:
"""Initialize a new event bus."""
self._listeners = {}
self._lock = threading.Lock()
@ -792,7 +792,7 @@ def create_timer(hass, interval=TIMER_INTERVAL):
calc_now = dt_util.utcnow
while not stop_event.isSet():
while not stop_event.is_set():
now = calc_now()
# First check checks if we are not on a second matching the
@ -816,7 +816,7 @@ def create_timer(hass, interval=TIMER_INTERVAL):
last_fired_on_second = now.second
# Event might have been set while sleeping
if not stop_event.isSet():
if not stop_event.is_set():
try:
hass.bus.fire(EVENT_TIME_CHANGED, {ATTR_NOW: now})
except HomeAssistantError:

View File

@ -1,10 +1,20 @@
"""Helper methods for components within Home Assistant."""
import re
from typing import Any, Iterable, Tuple, List, Dict
from homeassistant.const import CONF_PLATFORM
# Typing Imports and TypeAlias
# pylint: disable=using-constant-test,unused-import
if False:
from logging import Logger # NOQA
def validate_config(config, items, logger):
# pylint: disable=invalid-name
ConfigType = Dict[str, Any]
def validate_config(config: ConfigType, items: Dict, logger: 'Logger') -> bool:
"""Validate if all items are available in the configuration.
config is the general dictionary with all the configurations.
@ -29,7 +39,8 @@ def validate_config(config, items, logger):
return not errors_found
def config_per_platform(config, domain):
def config_per_platform(config: ConfigType,
domain: str) -> Iterable[Tuple[Any, Any]]:
"""Generator to break a component config into different platforms.
For example, will find 'switch', 'switch 2', 'switch 3', .. etc
@ -48,7 +59,7 @@ def config_per_platform(config, domain):
yield platform, item
def extract_domain_configs(config, domain):
def extract_domain_configs(config: ConfigType, domain: str) -> List[str]:
"""Extract keys from config for given domain name."""
pattern = re.compile(r'^{}(| .+)$'.format(domain))
return [key for key in config.keys() if pattern.match(key)]

View File

@ -2,6 +2,8 @@
import logging
import re
from typing import Any, Optional, List, Dict
from homeassistant.const import (
ATTR_ASSUMED_STATE, ATTR_FRIENDLY_NAME, ATTR_HIDDEN, ATTR_ICON,
ATTR_UNIT_OF_MEASUREMENT, DEVICE_DEFAULT_NAME, STATE_OFF, STATE_ON,
@ -10,8 +12,12 @@ from homeassistant.const import (
from homeassistant.exceptions import NoEntitySpecifiedError
from homeassistant.util import ensure_unique_string, slugify
# pylint: disable=using-constant-test,unused-import
if False:
from homeassistant.core import HomeAssistant # NOQA
# Entity attributes that we will overwrite
_OVERWRITE = {}
_OVERWRITE = {} # type: Dict[str, Any]
_LOGGER = logging.getLogger(__name__)
@ -19,7 +25,9 @@ _LOGGER = logging.getLogger(__name__)
ENTITY_ID_PATTERN = re.compile(r"^(\w+)\.(\w+)$")
def generate_entity_id(entity_id_format, name, current_ids=None, hass=None):
def generate_entity_id(entity_id_format: str, name: Optional[str],
current_ids: Optional[List[str]]=None,
hass: 'Optional[HomeAssistant]'=None) -> str:
"""Generate a unique entity ID based on given entity IDs or used IDs."""
name = (name or DEVICE_DEFAULT_NAME).lower()
if current_ids is None:
@ -32,19 +40,19 @@ def generate_entity_id(entity_id_format, name, current_ids=None, hass=None):
entity_id_format.format(slugify(name)), current_ids)
def set_customize(customize):
def set_customize(customize: Dict[str, Any]) -> None:
"""Overwrite all current customize settings."""
global _OVERWRITE
_OVERWRITE = {key.lower(): val for key, val in customize.items()}
def split_entity_id(entity_id):
def split_entity_id(entity_id: str) -> List[str]:
"""Split a state entity_id into domain, object_id."""
return entity_id.split(".", 1)
def valid_entity_id(entity_id):
def valid_entity_id(entity_id: str) -> bool:
"""Test if an entity ID is a valid format."""
return ENTITY_ID_PATTERN.match(entity_id) is not None
@ -57,7 +65,7 @@ class Entity(object):
# The properties and methods here are safe to overwrite when inheriting
# this class. These may be used to customize the behavior of the entity.
@property
def should_poll(self):
def should_poll(self) -> bool:
"""Return True if entity has to be polled for state.
False if entity pushes its state to HA.
@ -65,17 +73,17 @@ class Entity(object):
return True
@property
def unique_id(self):
def unique_id(self) -> str:
"""Return an unique ID."""
return "{}.{}".format(self.__class__, id(self))
@property
def name(self):
def name(self) -> Optional[str]:
"""Return the name of the entity."""
return None
@property
def state(self):
def state(self) -> str:
"""Return the state of the entity."""
return STATE_UNKNOWN
@ -111,22 +119,22 @@ class Entity(object):
return None
@property
def hidden(self):
def hidden(self) -> bool:
"""Return True if the entity should be hidden from UIs."""
return False
@property
def available(self):
def available(self) -> bool:
"""Return True if entity is available."""
return True
@property
def assumed_state(self):
def assumed_state(self) -> bool:
"""Return True if unable to access real state of the entity."""
return False
@property
def force_update(self):
def force_update(self) -> bool:
"""Return True if state updates should be forced.
If True, a state change will be triggered anytime the state property is
@ -138,14 +146,14 @@ class Entity(object):
"""Retrieve latest state."""
pass
entity_id = None
entity_id = None # type: str
# DO NOT OVERWRITE
# These properties and methods are either managed by Home Assistant or they
# are used to perform a very specific function. Overwriting these may
# produce undesirable effects in the entity's operation.
hass = None
hass = None # type: Optional[HomeAssistant]
def update_ha_state(self, force_refresh=False):
"""Update Home Assistant with current state of entity.
@ -232,24 +240,24 @@ class ToggleEntity(Entity):
# pylint: disable=no-self-use
@property
def state(self):
def state(self) -> str:
"""Return the state."""
return STATE_ON if self.is_on else STATE_OFF
@property
def is_on(self):
def is_on(self) -> bool:
"""Return True if entity is on."""
raise NotImplementedError()
def turn_on(self, **kwargs):
def turn_on(self, **kwargs) -> None:
"""Turn the entity on."""
raise NotImplementedError()
def turn_off(self, **kwargs):
def turn_off(self, **kwargs) -> None:
"""Turn the entity off."""
raise NotImplementedError()
def toggle(self, **kwargs):
def toggle(self, **kwargs) -> None:
"""Toggle the entity off."""
if self.is_on:
self.turn_off(**kwargs)

View File

@ -16,21 +16,30 @@ import os
import pkgutil
import sys
from types import ModuleType
# pylint: disable=unused-import
from typing import Optional, Sequence, Set, Dict # NOQA
from homeassistant.const import PLATFORM_FORMAT
from homeassistant.util import OrderedSet
# Typing imports
# pylint: disable=using-constant-test,unused-import
if False:
from homeassistant.core import HomeAssistant # NOQA
PREPARED = False
# List of available components
AVAILABLE_COMPONENTS = []
AVAILABLE_COMPONENTS = [] # type: List[str]
# Dict of loaded components mapped name => module
_COMPONENT_CACHE = {}
_COMPONENT_CACHE = {} # type: Dict[str, ModuleType]
_LOGGER = logging.getLogger(__name__)
def prepare(hass):
def prepare(hass: 'HomeAssistant'):
"""Prepare the loading of components."""
global PREPARED # pylint: disable=global-statement
@ -71,19 +80,19 @@ def prepare(hass):
PREPARED = True
def set_component(comp_name, component):
def set_component(comp_name: str, component: ModuleType) -> None:
"""Set a component in the cache."""
_check_prepared()
_COMPONENT_CACHE[comp_name] = component
def get_platform(domain, platform):
def get_platform(domain: str, platform: str) -> Optional[ModuleType]:
"""Try to load specified platform."""
return get_component(PLATFORM_FORMAT.format(domain, platform))
def get_component(comp_name):
def get_component(comp_name) -> Optional[ModuleType]:
"""Try to load specified component.
Looks in config dir first, then built-in components.
@ -148,7 +157,7 @@ def get_component(comp_name):
return None
def load_order_components(components):
def load_order_components(components: Sequence[str]) -> OrderedSet:
"""Take in a list of components we want to load.
- filters out components we cannot load
@ -178,7 +187,7 @@ def load_order_components(components):
return load_order
def load_order_component(comp_name):
def load_order_component(comp_name: str) -> OrderedSet:
"""Return an OrderedSet of components in the correct order of loading.
Raises HomeAssistantError if a circular dependency is detected.
@ -187,7 +196,8 @@ def load_order_component(comp_name):
return _load_order_component(comp_name, OrderedSet(), set())
def _load_order_component(comp_name, load_order, loading):
def _load_order_component(comp_name: str, load_order: OrderedSet,
loading: Set) -> OrderedSet:
"""Recursive function to get load order of components."""
component = get_component(comp_name)
@ -224,7 +234,7 @@ def _load_order_component(comp_name, load_order, loading):
return load_order
def _check_prepared():
def _check_prepared() -> None:
"""Issue a warning if loader.prepare() has never been called."""
if not PREPARED:
_LOGGER.warning((

View File

@ -12,7 +12,7 @@ import string
from functools import wraps
from types import MappingProxyType
from typing import Any
from typing import Any, Sequence
from .dt import as_local, utcnow
@ -31,7 +31,7 @@ def sanitize_path(path):
return RE_SANITIZE_PATH.sub("", path)
def slugify(text):
def slugify(text: str) -> str:
"""Slugify a given text."""
text = text.lower().replace(" ", "_")
@ -59,17 +59,18 @@ def convert(value, to_type, default=None):
return default
def ensure_unique_string(preferred_string, current_strings):
def ensure_unique_string(preferred_string: str,
current_strings: Sequence[str]) -> str:
"""Return a string that is not present in current_strings.
If preferred string exists will append _2, _3, ..
"""
test_string = preferred_string
current_strings = set(current_strings)
current_strings_set = set(current_strings)
tries = 1
while test_string in current_strings:
while test_string in current_strings_set:
tries += 1
test_string = "{}_{}".format(preferred_string, tries)

View File

@ -5,8 +5,11 @@ detect_location_info and elevation are mocked by default during tests.
"""
import collections
import math
from typing import Any, Optional, Tuple, Dict
import requests
ELEVATION_URL = 'http://maps.googleapis.com/maps/api/elevation/json'
FREEGEO_API = 'https://freegeoip.io/json/'
IP_API = 'http://ip-api.com/json'
@ -81,7 +84,8 @@ def elevation(latitude, longitude):
# Source: https://github.com/maurycyp/vincenty
# License: https://github.com/maurycyp/vincenty/blob/master/LICENSE
# pylint: disable=too-many-locals, invalid-name, unused-variable
def vincenty(point1, point2, miles=False):
def vincenty(point1: Tuple[float, float], point2: Tuple[float, float],
miles: bool=False) -> Optional[float]:
"""
Vincenty formula (inverse method) to calculate the distance.
@ -148,7 +152,7 @@ def vincenty(point1, point2, miles=False):
return round(s, 6)
def _get_freegeoip():
def _get_freegeoip() -> Optional[Dict[str, Any]]:
"""Query freegeoip.io for location data."""
try:
raw_info = requests.get(FREEGEO_API, timeout=5).json()
@ -169,7 +173,7 @@ def _get_freegeoip():
}
def _get_ip_api():
def _get_ip_api() -> Optional[Dict[str, Any]]:
"""Query ip-api.com for location data."""
try:
raw_info = requests.get(IP_API, timeout=5).json()

View File

@ -6,13 +6,16 @@ import sys
import threading
from urllib.parse import urlparse
from typing import Optional
import pkg_resources
_LOGGER = logging.getLogger(__name__)
INSTALL_LOCK = threading.Lock()
def install_package(package, upgrade=True, target=None):
def install_package(package: str, upgrade: bool=True,
target: Optional[str]=None) -> bool:
"""Install a package on PyPi. Accepts pip compatible package strings.
Return boolean if install successful.
@ -36,7 +39,7 @@ def install_package(package, upgrade=True, target=None):
return False
def check_package_exists(package, lib_dir):
def check_package_exists(package: str, lib_dir: str) -> bool:
"""Check if a package is installed globally or in lib_dir.
Returns True when the requirement is met.

View File

@ -1,5 +1,6 @@
flake8>=2.6.0
pylint>=1.5.6
astroid>=1.4.8
coveralls>=1.1
pytest>=2.9.2
pytest-cov>=2.2.1