Improvement typing (#2735)

* Fix: Circular dependencies of internal files

* Change: dt.date for Date and dt.datetime for DateTime

* Use NewType if available

* FIX: Wrong version test

* Remove: Date and DateTime types due to error

* Change to HomeAssistantType

* General Improvement of Typing

* Improve typing config_validation

* Improve typing script

* General Typing Improvements

* Improve NewType check

* Improve typing db_migrator

* Improve util/__init__ typing

* Improve helpers/location typing

* Regroup imports and remove pylint: disable=ungrouped-imports

* General typing improvements
pull/2745/head
Fabian Heredia Montiel 2016-08-07 18:26:35 -05:00 committed by Paulus Schoutsen
parent a3ca3e878b
commit 0377338a81
16 changed files with 139 additions and 66 deletions

View File

@ -274,7 +274,7 @@ def try_to_restart() -> None:
# thread left (which is us). Nothing we really do with it, but it might be
# useful when debugging shutdown/restart issues.
try:
nthreads = sum(thread.isAlive() and not thread.isDaemon()
nthreads = sum(thread.is_alive() and not thread.daemon
for thread in threading.enumerate())
if nthreads > 1:
sys.stderr.write(

View File

@ -12,6 +12,8 @@ from typing import Any, Optional, Dict
import voluptuous as vol
from homeassistant.helpers.typing import HomeAssistantType
import homeassistant.components as core_components
from homeassistant.components import group, persistent_notification
import homeassistant.config as conf_util
@ -216,7 +218,7 @@ def prepare_setup_platform(hass: core.HomeAssistant, config, domain: str,
# pylint: disable=too-many-branches, too-many-statements, too-many-arguments
def from_config_dict(config: Dict[str, Any],
hass: Optional[core.HomeAssistant]=None,
hass: Optional[HomeAssistantType]=None,
config_dir: Optional[str]=None,
enable_log: bool=True,
verbose: bool=False,

View File

@ -4,6 +4,9 @@ import os
import shutil
from types import MappingProxyType
# pylint: disable=unused-import
from typing import Any, Tuple # NOQA
import voluptuous as vol
from homeassistant.const import (
@ -37,7 +40,7 @@ DEFAULT_CORE_CONFIG = (
CONF_UNIT_SYSTEM_IMPERIAL)),
(CONF_TIME_ZONE, 'UTC', 'time_zone', 'Pick yours from here: http://en.wiki'
'pedia.org/wiki/List_of_tz_database_time_zones'),
)
) # type: Tuple[Tuple[str, Any, Any, str], ...]
DEFAULT_CONFIG = """
# Show links to resources in log and frontend
introduction:

View File

@ -14,9 +14,13 @@ import threading
import time
from types import MappingProxyType
from typing import Any, Callable
# pylint: disable=unused-import
from typing import Optional, Any, Callable # NOQA
import voluptuous as vol
from homeassistant.helpers.typing import UnitSystemType # NOQA
import homeassistant.util as util
import homeassistant.util.dt as dt_util
import homeassistant.util.location as location
@ -713,15 +717,15 @@ class Config(object):
# pylint: disable=too-many-instance-attributes
def __init__(self):
"""Initialize a new config object."""
self.latitude = None
self.longitude = None
self.elevation = None
self.location_name = None
self.time_zone = None
self.units = METRIC_SYSTEM
self.latitude = None # type: Optional[float]
self.longitude = None # type: Optional[float]
self.elevation = None # type: Optional[int]
self.location_name = None # type: Optional[str]
self.time_zone = None # type: Optional[str]
self.units = METRIC_SYSTEM # type: UnitSystemType
# If True, pip install is skipped for requirements on startup
self.skip_pip = False
self.skip_pip = False # type: bool
# List of loaded components
self.components = []

View File

@ -3,6 +3,8 @@ from datetime import timedelta
import logging
import sys
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
from homeassistant.components import (
zone as zone_cmp, sun as sun_cmp)
from homeassistant.const import (
@ -21,7 +23,7 @@ FROM_CONFIG_FORMAT = '{}_from_config'
_LOGGER = logging.getLogger(__name__)
def from_config(config, config_validation=True):
def from_config(config: ConfigType, config_validation: bool=True):
"""Turn a condition configuration into a method."""
factory = getattr(
sys.modules[__name__],
@ -34,13 +36,14 @@ def from_config(config, config_validation=True):
return factory(config, config_validation)
def and_from_config(config, config_validation=True):
def and_from_config(config: ConfigType, config_validation: bool=True):
"""Create multi condition matcher using 'AND'."""
if config_validation:
config = cv.AND_CONDITION_SCHEMA(config)
checks = [from_config(entry) for entry in config['conditions']]
def if_and_condition(hass, variables=None):
def if_and_condition(hass: HomeAssistantType,
variables=None) -> bool:
"""Test and condition."""
for check in checks:
try:
@ -55,13 +58,14 @@ def and_from_config(config, config_validation=True):
return if_and_condition
def or_from_config(config, config_validation=True):
def or_from_config(config: ConfigType, config_validation: bool=True):
"""Create multi condition matcher using 'OR'."""
if config_validation:
config = cv.OR_CONDITION_SCHEMA(config)
checks = [from_config(entry) for entry in config['conditions']]
def if_or_condition(hass, variables=None):
def if_or_condition(hass: HomeAssistantType,
variables=None) -> bool:
"""Test and condition."""
for check in checks:
try:
@ -76,8 +80,8 @@ def or_from_config(config, config_validation=True):
# pylint: disable=too-many-arguments
def numeric_state(hass, entity, below=None, above=None, value_template=None,
variables=None):
def numeric_state(hass: HomeAssistantType, entity, below=None, above=None,
value_template=None, variables=None):
"""Test a numeric state condition."""
if isinstance(entity, str):
entity = hass.states.get(entity)
@ -93,7 +97,7 @@ def numeric_state(hass, entity, below=None, above=None, value_template=None,
try:
value = render(hass, value_template, variables)
except TemplateError as ex:
_LOGGER.error(ex)
_LOGGER.error("Template error: %s", ex)
return False
try:

View File

@ -1,6 +1,8 @@
"""Helpers for config validation using voluptuous."""
from datetime import timedelta
from typing import Any, Union, TypeVar, Callable, Sequence, List, Dict
import jinja2
import voluptuous as vol
@ -28,12 +30,15 @@ longitude = vol.All(vol.Coerce(float), vol.Range(min=-180, max=180),
msg='invalid longitude')
sun_event = vol.All(vol.Lower, vol.Any(SUN_EVENT_SUNSET, SUN_EVENT_SUNRISE))
# typing typevar
T = TypeVar('T')
# Adapted from:
# https://github.com/alecthomas/voluptuous/issues/115#issuecomment-144464666
def has_at_least_one_key(*keys):
def has_at_least_one_key(*keys: str) -> Callable:
"""Validator that at least one key exists."""
def validate(obj):
def validate(obj: Dict) -> Dict:
"""Test keys exist in dict."""
if not isinstance(obj, dict):
raise vol.Invalid('expected dictionary')
@ -46,7 +51,7 @@ def has_at_least_one_key(*keys):
return validate
def boolean(value):
def boolean(value: Any) -> bool:
"""Validate and coerce a boolean value."""
if isinstance(value, str):
value = value.lower()
@ -63,12 +68,12 @@ def isfile(value):
return vol.IsFile('not a file')(value)
def ensure_list(value):
def ensure_list(value: Union[T, Sequence[T]]) -> List[T]:
"""Wrap value in list if it is not one."""
return value if isinstance(value, list) else [value]
def entity_id(value):
def entity_id(value: Any) -> str:
"""Validate Entity ID."""
value = string(value).lower()
if valid_entity_id(value):
@ -76,7 +81,7 @@ def entity_id(value):
raise vol.Invalid('Entity ID {} is an invalid entity id'.format(value))
def entity_ids(value):
def entity_ids(value: Union[str, Sequence]) -> List[str]:
"""Validate Entity IDs."""
if value is None:
raise vol.Invalid('Entity IDs can not be None')
@ -109,7 +114,7 @@ time_period_dict = vol.All(
lambda value: timedelta(**value))
def time_period_str(value):
def time_period_str(value: str) -> timedelta:
"""Validate and transform time offset."""
if isinstance(value, int):
raise vol.Invalid('Make sure you wrap time values in quotes')
@ -182,7 +187,7 @@ def platform_validator(domain):
return validator
def positive_timedelta(value):
def positive_timedelta(value: timedelta) -> timedelta:
"""Validate timedelta is positive."""
if value < timedelta(0):
raise vol.Invalid('Time period should be positive')
@ -209,14 +214,14 @@ def slug(value):
raise vol.Invalid('invalid slug {} (try {})'.format(value, slg))
def string(value):
def string(value: Any) -> str:
"""Coerce value to string, except for None."""
if value is not None:
return str(value)
raise vol.Invalid('string value is None')
def temperature_unit(value):
def temperature_unit(value) -> str:
"""Validate and transform temperature unit."""
value = str(value).upper()
if value == 'C':

View File

@ -12,9 +12,7 @@ from homeassistant.const import (
from homeassistant.exceptions import NoEntitySpecifiedError
from homeassistant.util import ensure_unique_string, slugify
# pylint: disable=using-constant-test,unused-import
if False:
from homeassistant.core import HomeAssistant # NOQA
from homeassistant.helpers.typing import HomeAssistantType
# Entity attributes that we will overwrite
_OVERWRITE = {} # type: Dict[str, Any]
@ -27,7 +25,7 @@ ENTITY_ID_PATTERN = re.compile(r"^(\w+)\.(\w+)$")
def generate_entity_id(entity_id_format: str, name: Optional[str],
current_ids: Optional[List[str]]=None,
hass: 'Optional[HomeAssistant]'=None) -> str:
hass: Optional[HomeAssistantType]=None) -> str:
"""Generate a unique entity ID based on given entity IDs or used IDs."""
name = (name or DEVICE_DEFAULT_NAME).lower()
if current_ids is None:
@ -153,7 +151,7 @@ class Entity(object):
# are used to perform a very specific function. Overwriting these may
# produce undesirable effects in the entity's operation.
hass = None # type: Optional[HomeAssistant]
hass = None # type: Optional[HomeAssistantType]
def update_ha_state(self, force_refresh=False):
"""Update Home Assistant with current state of entity.

View File

@ -1,9 +1,13 @@
"""Event Decorators for custom components."""
import functools
# pylint: disable=unused-import
from typing import Optional # NOQA
from homeassistant.helpers.typing import HomeAssistantType # NOQA
from homeassistant.helpers import event
HASS = None
HASS = None # type: Optional[HomeAssistantType]
def track_state_change(entity_ids, from_state=None, to_state=None):

View File

@ -1,18 +1,21 @@
"""Location helpers for Home Assistant."""
from typing import Sequence
from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE
from homeassistant.core import State
from homeassistant.util import location as loc_util
def has_location(state):
def has_location(state: State) -> bool:
"""Test if state contains a valid location."""
return (isinstance(state, State) and
isinstance(state.attributes.get(ATTR_LATITUDE), float) and
isinstance(state.attributes.get(ATTR_LONGITUDE), float))
def closest(latitude, longitude, states):
def closest(latitude: float, longitude: float,
states: Sequence[State]) -> State:
"""Return closest state to point."""
with_location = [state for state in states if has_location(state)]

View File

@ -3,8 +3,12 @@ import logging
import threading
from itertools import islice
from typing import Optional, Sequence
import voluptuous as vol
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
import homeassistant.util.dt as date_util
from homeassistant.const import EVENT_TIME_CHANGED, CONF_CONDITION
from homeassistant.helpers.event import track_point_in_utc_time
@ -22,7 +26,8 @@ CONF_EVENT_DATA = "event_data"
CONF_DELAY = "delay"
def call_from_config(hass, config, variables=None):
def call_from_config(hass: HomeAssistantType, config: ConfigType,
variables: Optional[Sequence]=None) -> None:
"""Call a script based on a config entry."""
Script(hass, config).run(variables)
@ -31,7 +36,8 @@ class Script():
"""Representation of a script."""
# pylint: disable=too-many-instance-attributes
def __init__(self, hass, sequence, name=None, change_listener=None):
def __init__(self, hass: HomeAssistantType, sequence, name: str=None,
change_listener=None) -> None:
"""Initialize the script."""
self.hass = hass
self.sequence = cv.SCRIPT_SCHEMA(sequence)
@ -45,11 +51,11 @@ class Script():
self._delay_listener = None
@property
def is_running(self):
def is_running(self) -> bool:
"""Return true if script is on."""
return self._cur != -1
def run(self, variables=None):
def run(self, variables: Optional[Sequence]=None) -> None:
"""Run script."""
with self._lock:
if self._cur == -1:
@ -101,7 +107,7 @@ class Script():
if self._change_listener:
self._change_listener()
def stop(self):
def stop(self) -> None:
"""Stop running script."""
with self._lock:
if self._cur == -1:

View File

@ -2,15 +2,20 @@
import functools
import logging
# pylint: disable=unused-import
from typing import Optional # NOQA
import voluptuous as vol
from homeassistant.helpers.typing import HomeAssistantType # NOQA
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import template
from homeassistant.loader import get_component
import homeassistant.helpers.config_validation as cv
HASS = None
HASS = None # type: Optional[HomeAssistantType]
CONF_SERVICE = 'service'
CONF_SERVICE_TEMPLATE = 'service_template'

View File

@ -1,12 +1,39 @@
"""Typing Helpers for Home-Assistant."""
from typing import Dict, Any
import homeassistant.core
# NOTE: NewType added to typing in 3.5.2 in June, 2016; Since 3.5.2 includes
# security fixes everyone on 3.5 should upgrade "soon"
try:
from typing import NewType
except ImportError:
NewType = None
# HACK: mypy/pytype will import, other interpreters will not; this is to avoid
# circular dependencies where the type is needed.
# All homeassistant types should be imported this way.
# Documentation
# http://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
# pylint: disable=using-constant-test,unused-import
if False:
from homeassistant.core import HomeAssistant # NOQA
from homeassistant.helpers.unit_system import UnitSystem # NOQA
# ENDHACK
# pylint: disable=invalid-name
if NewType:
ConfigType = NewType('ConfigType', Dict[str, Any])
HomeAssistantType = NewType('HomeAssistantType', 'HomeAssistant')
UnitSystemType = NewType('UnitSystemType', 'UnitSystem')
ConfigType = Dict[str, Any]
HomeAssistantType = homeassistant.core.HomeAssistant
# Custom type for recorder Queries
QueryType = NewType('QueryType', Any)
# Custom type for recorder Queries
QueryType = Any
# Duplicates for 3.5.1
# pylint: disable=invalid-name
else:
ConfigType = Dict[str, Any] # type: ignore
HomeAssistantType = 'HomeAssistant' # type: ignore
UnitSystemType = 'UnitSystemType' # type: ignore
# Custom type for recorder Queries
QueryType = Any # type: ignore

View File

@ -15,6 +15,8 @@ import time
import threading
import urllib.parse
from typing import Optional
import requests
import homeassistant.bootstrap as bootstrap
@ -42,7 +44,7 @@ class APIStatus(enum.Enum):
CANNOT_CONNECT = "cannot_connect"
UNKNOWN = "unknown"
def __str__(self):
def __str__(self) -> str:
"""Return the state."""
return self.value
@ -51,7 +53,8 @@ class API(object):
"""Object to pass around Home Assistant API location and credentials."""
# pylint: disable=too-few-public-methods
def __init__(self, host, api_password=None, port=None, use_ssl=False):
def __init__(self, host: str, api_password: Optional[str]=None,
port: Optional[int]=None, use_ssl: bool=False) -> None:
"""Initalize the API."""
self.host = host
self.port = port or SERVER_PORT
@ -68,7 +71,7 @@ class API(object):
if api_password is not None:
self._headers[HTTP_HEADER_HA_AUTH] = api_password
def validate_api(self, force_validate=False):
def validate_api(self, force_validate: bool=False) -> bool:
"""Test if we can communicate with the API."""
if self.status is None or force_validate:
self.status = validate_api(self)
@ -100,7 +103,7 @@ class API(object):
_LOGGER.exception(error)
raise HomeAssistantError(error)
def __repr__(self):
def __repr__(self) -> str:
"""Return the representation of the API."""
return "API({}, {}, {})".format(
self.host, self.api_password, self.port)

View File

@ -4,6 +4,10 @@ import argparse
import os.path
import sqlite3
import sys
from datetime import datetime
from typing import Optional
try:
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
@ -16,7 +20,7 @@ import homeassistant.config as config_util
import homeassistant.util.dt as dt_util
def ts_to_dt(timestamp):
def ts_to_dt(timestamp: Optional[float]) -> Optional[datetime]:
"""Turn a datetime into an integer for in the DB."""
if timestamp is None:
return None
@ -26,8 +30,8 @@ def ts_to_dt(timestamp):
# Based on code at
# http://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
# pylint: disable=too-many-arguments
def print_progress(iteration, total, prefix='', suffix='', decimals=2,
bar_length=68):
def print_progress(iteration: int, total: int, prefix: str='', suffix: str='',
decimals: int=2, bar_length: int=68) -> None:
"""Print progress bar.
Call in a loop to create terminal progress bar
@ -49,7 +53,7 @@ def print_progress(iteration, total, prefix='', suffix='', decimals=2,
print("\n")
def run(args):
def run(args) -> int:
"""The actual script body."""
# pylint: disable=too-many-locals,invalid-name,too-many-statements
parser = argparse.ArgumentParser(
@ -75,7 +79,7 @@ def run(args):
args = parser.parse_args()
config_dir = os.path.join(os.getcwd(), args.config)
config_dir = os.path.join(os.getcwd(), args.config) # type: str
# Test if configuration directory exists
if not os.path.isdir(config_dir):

View File

@ -12,21 +12,24 @@ import string
from functools import wraps
from types import MappingProxyType
from typing import Any, Sequence
from typing import Any, Optional, TypeVar, Callable, Sequence
from .dt import as_local, utcnow
T = TypeVar('T')
U = TypeVar('U')
RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)')
RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)')
RE_SLUGIFY = re.compile(r'[^a-z0-9_]+')
def sanitize_filename(filename):
def sanitize_filename(filename: str) -> str:
r"""Sanitize a filename by removing .. / and \\."""
return RE_SANITIZE_FILENAME.sub("", filename)
def sanitize_path(path):
def sanitize_path(path: str) -> str:
"""Sanitize a path by removing ~ and .."""
return RE_SANITIZE_PATH.sub("", path)
@ -50,7 +53,8 @@ def repr_helper(inp: Any) -> str:
return str(inp)
def convert(value, to_type, default=None):
def convert(value: T, to_type: Callable[[T], U],
default: Optional[U]=None) -> Optional[U]:
"""Convert value to to_type, returns default if fails."""
try:
return default if value is None else to_type(value)

View File

@ -8,7 +8,7 @@ from typing import Any, Union, Optional, Tuple # NOQA
import pytz
DATE_STR_FORMAT = "%Y-%m-%d"
UTC = DEFAULT_TIME_ZONE = pytz.utc # type: pytz.UTC
UTC = DEFAULT_TIME_ZONE = pytz.utc # type: dt.tzinfo
# Copyright (c) Django Software Foundation and individual contributors.
@ -93,11 +93,10 @@ def start_of_local_day(dt_or_d:
Union[dt.date, dt.datetime]=None) -> dt.datetime:
"""Return local datetime object of start of day from date or datetime."""
if dt_or_d is None:
dt_or_d = now().date()
date = now().date() # type: dt.date
elif isinstance(dt_or_d, dt.datetime):
dt_or_d = dt_or_d.date()
return DEFAULT_TIME_ZONE.localize(dt.datetime.combine(dt_or_d, dt.time()))
date = dt_or_d.date()
return DEFAULT_TIME_ZONE.localize(dt.datetime.combine(date, dt.time()))
# Copyright (c) Django Software Foundation and individual contributors.
@ -118,6 +117,8 @@ def parse_datetime(dt_str: str) -> dt.datetime:
if kws['microsecond']:
kws['microsecond'] = kws['microsecond'].ljust(6, '0')
tzinfo_str = kws.pop('tzinfo')
tzinfo = None # type: Optional[dt.tzinfo]
if tzinfo_str == 'Z':
tzinfo = UTC
elif tzinfo_str is not None: