From 0377338a81da8b1716253c8b019c44234930ea9a Mon Sep 17 00:00:00 2001 From: Fabian Heredia Montiel Date: Sun, 7 Aug 2016 18:26:35 -0500 Subject: [PATCH] Improvement typing (#2735) * Fix: Circular dependencies of internal files * Change: dt.date for Date and dt.datetime for DateTime * Use NewType if available * FIX: Wrong version test * Remove: Date and DateTime types due to error * Change to HomeAssistantType * General Improvement of Typing * Improve typing config_validation * Improve typing script * General Typing Improvements * Improve NewType check * Improve typing db_migrator * Improve util/__init__ typing * Improve helpers/location typing * Regroup imports and remove pylint: disable=ungrouped-imports * General typing improvements --- homeassistant/__main__.py | 2 +- homeassistant/bootstrap.py | 4 ++- homeassistant/config.py | 5 ++- homeassistant/core.py | 20 +++++++----- homeassistant/helpers/condition.py | 20 +++++++----- homeassistant/helpers/config_validation.py | 25 +++++++++------ homeassistant/helpers/entity.py | 8 ++--- homeassistant/helpers/event_decorators.py | 6 +++- homeassistant/helpers/location.py | 7 ++-- homeassistant/helpers/script.py | 16 +++++++--- homeassistant/helpers/service.py | 7 +++- homeassistant/helpers/typing.py | 37 +++++++++++++++++++--- homeassistant/remote.py | 11 ++++--- homeassistant/scripts/db_migrator.py | 14 +++++--- homeassistant/util/__init__.py | 12 ++++--- homeassistant/util/dt.py | 11 ++++--- 16 files changed, 139 insertions(+), 66 deletions(-) diff --git a/homeassistant/__main__.py b/homeassistant/__main__.py index fb1594d5b3f..39a18feb1f2 100644 --- a/homeassistant/__main__.py +++ b/homeassistant/__main__.py @@ -274,7 +274,7 @@ def try_to_restart() -> None: # thread left (which is us). Nothing we really do with it, but it might be # useful when debugging shutdown/restart issues. try: - nthreads = sum(thread.isAlive() and not thread.isDaemon() + nthreads = sum(thread.is_alive() and not thread.daemon for thread in threading.enumerate()) if nthreads > 1: sys.stderr.write( diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index c62fe9e7d6b..1f0ac9eb103 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -12,6 +12,8 @@ from typing import Any, Optional, Dict import voluptuous as vol +from homeassistant.helpers.typing import HomeAssistantType + import homeassistant.components as core_components from homeassistant.components import group, persistent_notification import homeassistant.config as conf_util @@ -216,7 +218,7 @@ def prepare_setup_platform(hass: core.HomeAssistant, config, domain: str, # pylint: disable=too-many-branches, too-many-statements, too-many-arguments def from_config_dict(config: Dict[str, Any], - hass: Optional[core.HomeAssistant]=None, + hass: Optional[HomeAssistantType]=None, config_dir: Optional[str]=None, enable_log: bool=True, verbose: bool=False, diff --git a/homeassistant/config.py b/homeassistant/config.py index 53c43be7c17..a614f139e2f 100644 --- a/homeassistant/config.py +++ b/homeassistant/config.py @@ -4,6 +4,9 @@ import os import shutil from types import MappingProxyType +# pylint: disable=unused-import +from typing import Any, Tuple # NOQA + import voluptuous as vol from homeassistant.const import ( @@ -37,7 +40,7 @@ DEFAULT_CORE_CONFIG = ( CONF_UNIT_SYSTEM_IMPERIAL)), (CONF_TIME_ZONE, 'UTC', 'time_zone', 'Pick yours from here: http://en.wiki' 'pedia.org/wiki/List_of_tz_database_time_zones'), -) +) # type: Tuple[Tuple[str, Any, Any, str], ...] DEFAULT_CONFIG = """ # Show links to resources in log and frontend introduction: diff --git a/homeassistant/core.py b/homeassistant/core.py index 8ce381b072f..107008216da 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -14,9 +14,13 @@ import threading import time from types import MappingProxyType -from typing import Any, Callable +# pylint: disable=unused-import +from typing import Optional, Any, Callable # NOQA + import voluptuous as vol +from homeassistant.helpers.typing import UnitSystemType # NOQA + import homeassistant.util as util import homeassistant.util.dt as dt_util import homeassistant.util.location as location @@ -713,15 +717,15 @@ class Config(object): # pylint: disable=too-many-instance-attributes def __init__(self): """Initialize a new config object.""" - self.latitude = None - self.longitude = None - self.elevation = None - self.location_name = None - self.time_zone = None - self.units = METRIC_SYSTEM + self.latitude = None # type: Optional[float] + self.longitude = None # type: Optional[float] + self.elevation = None # type: Optional[int] + self.location_name = None # type: Optional[str] + self.time_zone = None # type: Optional[str] + self.units = METRIC_SYSTEM # type: UnitSystemType # If True, pip install is skipped for requirements on startup - self.skip_pip = False + self.skip_pip = False # type: bool # List of loaded components self.components = [] diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index e4335e2f2e4..791405f6da1 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -3,6 +3,8 @@ from datetime import timedelta import logging import sys +from homeassistant.helpers.typing import ConfigType, HomeAssistantType + from homeassistant.components import ( zone as zone_cmp, sun as sun_cmp) from homeassistant.const import ( @@ -21,7 +23,7 @@ FROM_CONFIG_FORMAT = '{}_from_config' _LOGGER = logging.getLogger(__name__) -def from_config(config, config_validation=True): +def from_config(config: ConfigType, config_validation: bool=True): """Turn a condition configuration into a method.""" factory = getattr( sys.modules[__name__], @@ -34,13 +36,14 @@ def from_config(config, config_validation=True): return factory(config, config_validation) -def and_from_config(config, config_validation=True): +def and_from_config(config: ConfigType, config_validation: bool=True): """Create multi condition matcher using 'AND'.""" if config_validation: config = cv.AND_CONDITION_SCHEMA(config) checks = [from_config(entry) for entry in config['conditions']] - def if_and_condition(hass, variables=None): + def if_and_condition(hass: HomeAssistantType, + variables=None) -> bool: """Test and condition.""" for check in checks: try: @@ -55,13 +58,14 @@ def and_from_config(config, config_validation=True): return if_and_condition -def or_from_config(config, config_validation=True): +def or_from_config(config: ConfigType, config_validation: bool=True): """Create multi condition matcher using 'OR'.""" if config_validation: config = cv.OR_CONDITION_SCHEMA(config) checks = [from_config(entry) for entry in config['conditions']] - def if_or_condition(hass, variables=None): + def if_or_condition(hass: HomeAssistantType, + variables=None) -> bool: """Test and condition.""" for check in checks: try: @@ -76,8 +80,8 @@ def or_from_config(config, config_validation=True): # pylint: disable=too-many-arguments -def numeric_state(hass, entity, below=None, above=None, value_template=None, - variables=None): +def numeric_state(hass: HomeAssistantType, entity, below=None, above=None, + value_template=None, variables=None): """Test a numeric state condition.""" if isinstance(entity, str): entity = hass.states.get(entity) @@ -93,7 +97,7 @@ def numeric_state(hass, entity, below=None, above=None, value_template=None, try: value = render(hass, value_template, variables) except TemplateError as ex: - _LOGGER.error(ex) + _LOGGER.error("Template error: %s", ex) return False try: diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 314d886db70..a9b965930ae 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -1,6 +1,8 @@ """Helpers for config validation using voluptuous.""" from datetime import timedelta +from typing import Any, Union, TypeVar, Callable, Sequence, List, Dict + import jinja2 import voluptuous as vol @@ -28,12 +30,15 @@ longitude = vol.All(vol.Coerce(float), vol.Range(min=-180, max=180), msg='invalid longitude') sun_event = vol.All(vol.Lower, vol.Any(SUN_EVENT_SUNSET, SUN_EVENT_SUNRISE)) +# typing typevar +T = TypeVar('T') + # Adapted from: # https://github.com/alecthomas/voluptuous/issues/115#issuecomment-144464666 -def has_at_least_one_key(*keys): +def has_at_least_one_key(*keys: str) -> Callable: """Validator that at least one key exists.""" - def validate(obj): + def validate(obj: Dict) -> Dict: """Test keys exist in dict.""" if not isinstance(obj, dict): raise vol.Invalid('expected dictionary') @@ -46,7 +51,7 @@ def has_at_least_one_key(*keys): return validate -def boolean(value): +def boolean(value: Any) -> bool: """Validate and coerce a boolean value.""" if isinstance(value, str): value = value.lower() @@ -63,12 +68,12 @@ def isfile(value): return vol.IsFile('not a file')(value) -def ensure_list(value): +def ensure_list(value: Union[T, Sequence[T]]) -> List[T]: """Wrap value in list if it is not one.""" return value if isinstance(value, list) else [value] -def entity_id(value): +def entity_id(value: Any) -> str: """Validate Entity ID.""" value = string(value).lower() if valid_entity_id(value): @@ -76,7 +81,7 @@ def entity_id(value): raise vol.Invalid('Entity ID {} is an invalid entity id'.format(value)) -def entity_ids(value): +def entity_ids(value: Union[str, Sequence]) -> List[str]: """Validate Entity IDs.""" if value is None: raise vol.Invalid('Entity IDs can not be None') @@ -109,7 +114,7 @@ time_period_dict = vol.All( lambda value: timedelta(**value)) -def time_period_str(value): +def time_period_str(value: str) -> timedelta: """Validate and transform time offset.""" if isinstance(value, int): raise vol.Invalid('Make sure you wrap time values in quotes') @@ -182,7 +187,7 @@ def platform_validator(domain): return validator -def positive_timedelta(value): +def positive_timedelta(value: timedelta) -> timedelta: """Validate timedelta is positive.""" if value < timedelta(0): raise vol.Invalid('Time period should be positive') @@ -209,14 +214,14 @@ def slug(value): raise vol.Invalid('invalid slug {} (try {})'.format(value, slg)) -def string(value): +def string(value: Any) -> str: """Coerce value to string, except for None.""" if value is not None: return str(value) raise vol.Invalid('string value is None') -def temperature_unit(value): +def temperature_unit(value) -> str: """Validate and transform temperature unit.""" value = str(value).upper() if value == 'C': diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 4dac7f9f6d0..6e75e8ce59b 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -12,9 +12,7 @@ from homeassistant.const import ( from homeassistant.exceptions import NoEntitySpecifiedError from homeassistant.util import ensure_unique_string, slugify -# pylint: disable=using-constant-test,unused-import -if False: - from homeassistant.core import HomeAssistant # NOQA +from homeassistant.helpers.typing import HomeAssistantType # Entity attributes that we will overwrite _OVERWRITE = {} # type: Dict[str, Any] @@ -27,7 +25,7 @@ ENTITY_ID_PATTERN = re.compile(r"^(\w+)\.(\w+)$") def generate_entity_id(entity_id_format: str, name: Optional[str], current_ids: Optional[List[str]]=None, - hass: 'Optional[HomeAssistant]'=None) -> str: + hass: Optional[HomeAssistantType]=None) -> str: """Generate a unique entity ID based on given entity IDs or used IDs.""" name = (name or DEVICE_DEFAULT_NAME).lower() if current_ids is None: @@ -153,7 +151,7 @@ class Entity(object): # are used to perform a very specific function. Overwriting these may # produce undesirable effects in the entity's operation. - hass = None # type: Optional[HomeAssistant] + hass = None # type: Optional[HomeAssistantType] def update_ha_state(self, force_refresh=False): """Update Home Assistant with current state of entity. diff --git a/homeassistant/helpers/event_decorators.py b/homeassistant/helpers/event_decorators.py index d4292f20a5f..29430f7055c 100644 --- a/homeassistant/helpers/event_decorators.py +++ b/homeassistant/helpers/event_decorators.py @@ -1,9 +1,13 @@ """Event Decorators for custom components.""" import functools +# pylint: disable=unused-import +from typing import Optional # NOQA +from homeassistant.helpers.typing import HomeAssistantType # NOQA + from homeassistant.helpers import event -HASS = None +HASS = None # type: Optional[HomeAssistantType] def track_state_change(entity_ids, from_state=None, to_state=None): diff --git a/homeassistant/helpers/location.py b/homeassistant/helpers/location.py index a3cdc348a24..c84d02dcb83 100644 --- a/homeassistant/helpers/location.py +++ b/homeassistant/helpers/location.py @@ -1,18 +1,21 @@ """Location helpers for Home Assistant.""" +from typing import Sequence + from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE from homeassistant.core import State from homeassistant.util import location as loc_util -def has_location(state): +def has_location(state: State) -> bool: """Test if state contains a valid location.""" return (isinstance(state, State) and isinstance(state.attributes.get(ATTR_LATITUDE), float) and isinstance(state.attributes.get(ATTR_LONGITUDE), float)) -def closest(latitude, longitude, states): +def closest(latitude: float, longitude: float, + states: Sequence[State]) -> State: """Return closest state to point.""" with_location = [state for state in states if has_location(state)] diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index bc1382ef982..132ffb30a82 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -3,8 +3,12 @@ import logging import threading from itertools import islice +from typing import Optional, Sequence + import voluptuous as vol +from homeassistant.helpers.typing import ConfigType, HomeAssistantType + import homeassistant.util.dt as date_util from homeassistant.const import EVENT_TIME_CHANGED, CONF_CONDITION from homeassistant.helpers.event import track_point_in_utc_time @@ -22,7 +26,8 @@ CONF_EVENT_DATA = "event_data" CONF_DELAY = "delay" -def call_from_config(hass, config, variables=None): +def call_from_config(hass: HomeAssistantType, config: ConfigType, + variables: Optional[Sequence]=None) -> None: """Call a script based on a config entry.""" Script(hass, config).run(variables) @@ -31,7 +36,8 @@ class Script(): """Representation of a script.""" # pylint: disable=too-many-instance-attributes - def __init__(self, hass, sequence, name=None, change_listener=None): + def __init__(self, hass: HomeAssistantType, sequence, name: str=None, + change_listener=None) -> None: """Initialize the script.""" self.hass = hass self.sequence = cv.SCRIPT_SCHEMA(sequence) @@ -45,11 +51,11 @@ class Script(): self._delay_listener = None @property - def is_running(self): + def is_running(self) -> bool: """Return true if script is on.""" return self._cur != -1 - def run(self, variables=None): + def run(self, variables: Optional[Sequence]=None) -> None: """Run script.""" with self._lock: if self._cur == -1: @@ -101,7 +107,7 @@ class Script(): if self._change_listener: self._change_listener() - def stop(self): + def stop(self) -> None: """Stop running script.""" with self._lock: if self._cur == -1: diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 95dce9516de..ffad2997a43 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -2,15 +2,20 @@ import functools import logging +# pylint: disable=unused-import +from typing import Optional # NOQA + import voluptuous as vol +from homeassistant.helpers.typing import HomeAssistantType # NOQA + from homeassistant.const import ATTR_ENTITY_ID from homeassistant.exceptions import TemplateError from homeassistant.helpers import template from homeassistant.loader import get_component import homeassistant.helpers.config_validation as cv -HASS = None +HASS = None # type: Optional[HomeAssistantType] CONF_SERVICE = 'service' CONF_SERVICE_TEMPLATE = 'service_template' diff --git a/homeassistant/helpers/typing.py b/homeassistant/helpers/typing.py index 67bbbbf9600..473e28995b2 100644 --- a/homeassistant/helpers/typing.py +++ b/homeassistant/helpers/typing.py @@ -1,12 +1,39 @@ """Typing Helpers for Home-Assistant.""" from typing import Dict, Any -import homeassistant.core +# NOTE: NewType added to typing in 3.5.2 in June, 2016; Since 3.5.2 includes +# security fixes everyone on 3.5 should upgrade "soon" +try: + from typing import NewType +except ImportError: + NewType = None + +# HACK: mypy/pytype will import, other interpreters will not; this is to avoid +# circular dependencies where the type is needed. +# All homeassistant types should be imported this way. +# Documentation +# http://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles +# pylint: disable=using-constant-test,unused-import +if False: + from homeassistant.core import HomeAssistant # NOQA + from homeassistant.helpers.unit_system import UnitSystem # NOQA +# ENDHACK # pylint: disable=invalid-name +if NewType: + ConfigType = NewType('ConfigType', Dict[str, Any]) + HomeAssistantType = NewType('HomeAssistantType', 'HomeAssistant') + UnitSystemType = NewType('UnitSystemType', 'UnitSystem') -ConfigType = Dict[str, Any] -HomeAssistantType = homeassistant.core.HomeAssistant + # Custom type for recorder Queries + QueryType = NewType('QueryType', Any) -# Custom type for recorder Queries -QueryType = Any +# Duplicates for 3.5.1 +# pylint: disable=invalid-name +else: + ConfigType = Dict[str, Any] # type: ignore + HomeAssistantType = 'HomeAssistant' # type: ignore + UnitSystemType = 'UnitSystemType' # type: ignore + + # Custom type for recorder Queries + QueryType = Any # type: ignore diff --git a/homeassistant/remote.py b/homeassistant/remote.py index 409d276caf5..8e62cdd044a 100644 --- a/homeassistant/remote.py +++ b/homeassistant/remote.py @@ -15,6 +15,8 @@ import time import threading import urllib.parse +from typing import Optional + import requests import homeassistant.bootstrap as bootstrap @@ -42,7 +44,7 @@ class APIStatus(enum.Enum): CANNOT_CONNECT = "cannot_connect" UNKNOWN = "unknown" - def __str__(self): + def __str__(self) -> str: """Return the state.""" return self.value @@ -51,7 +53,8 @@ class API(object): """Object to pass around Home Assistant API location and credentials.""" # pylint: disable=too-few-public-methods - def __init__(self, host, api_password=None, port=None, use_ssl=False): + def __init__(self, host: str, api_password: Optional[str]=None, + port: Optional[int]=None, use_ssl: bool=False) -> None: """Initalize the API.""" self.host = host self.port = port or SERVER_PORT @@ -68,7 +71,7 @@ class API(object): if api_password is not None: self._headers[HTTP_HEADER_HA_AUTH] = api_password - def validate_api(self, force_validate=False): + def validate_api(self, force_validate: bool=False) -> bool: """Test if we can communicate with the API.""" if self.status is None or force_validate: self.status = validate_api(self) @@ -100,7 +103,7 @@ class API(object): _LOGGER.exception(error) raise HomeAssistantError(error) - def __repr__(self): + def __repr__(self) -> str: """Return the representation of the API.""" return "API({}, {}, {})".format( self.host, self.api_password, self.port) diff --git a/homeassistant/scripts/db_migrator.py b/homeassistant/scripts/db_migrator.py index 3ce26014c05..2b38ef52cb2 100644 --- a/homeassistant/scripts/db_migrator.py +++ b/homeassistant/scripts/db_migrator.py @@ -4,6 +4,10 @@ import argparse import os.path import sqlite3 import sys + +from datetime import datetime +from typing import Optional + try: from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker @@ -16,7 +20,7 @@ import homeassistant.config as config_util import homeassistant.util.dt as dt_util -def ts_to_dt(timestamp): +def ts_to_dt(timestamp: Optional[float]) -> Optional[datetime]: """Turn a datetime into an integer for in the DB.""" if timestamp is None: return None @@ -26,8 +30,8 @@ def ts_to_dt(timestamp): # Based on code at # http://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console # pylint: disable=too-many-arguments -def print_progress(iteration, total, prefix='', suffix='', decimals=2, - bar_length=68): +def print_progress(iteration: int, total: int, prefix: str='', suffix: str='', + decimals: int=2, bar_length: int=68) -> None: """Print progress bar. Call in a loop to create terminal progress bar @@ -49,7 +53,7 @@ def print_progress(iteration, total, prefix='', suffix='', decimals=2, print("\n") -def run(args): +def run(args) -> int: """The actual script body.""" # pylint: disable=too-many-locals,invalid-name,too-many-statements parser = argparse.ArgumentParser( @@ -75,7 +79,7 @@ def run(args): args = parser.parse_args() - config_dir = os.path.join(os.getcwd(), args.config) + config_dir = os.path.join(os.getcwd(), args.config) # type: str # Test if configuration directory exists if not os.path.isdir(config_dir): diff --git a/homeassistant/util/__init__.py b/homeassistant/util/__init__.py index e0f856c7444..032588f6cba 100644 --- a/homeassistant/util/__init__.py +++ b/homeassistant/util/__init__.py @@ -12,21 +12,24 @@ import string from functools import wraps from types import MappingProxyType -from typing import Any, Sequence +from typing import Any, Optional, TypeVar, Callable, Sequence from .dt import as_local, utcnow +T = TypeVar('T') +U = TypeVar('U') + RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)') RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)') RE_SLUGIFY = re.compile(r'[^a-z0-9_]+') -def sanitize_filename(filename): +def sanitize_filename(filename: str) -> str: r"""Sanitize a filename by removing .. / and \\.""" return RE_SANITIZE_FILENAME.sub("", filename) -def sanitize_path(path): +def sanitize_path(path: str) -> str: """Sanitize a path by removing ~ and ..""" return RE_SANITIZE_PATH.sub("", path) @@ -50,7 +53,8 @@ def repr_helper(inp: Any) -> str: return str(inp) -def convert(value, to_type, default=None): +def convert(value: T, to_type: Callable[[T], U], + default: Optional[U]=None) -> Optional[U]: """Convert value to to_type, returns default if fails.""" try: return default if value is None else to_type(value) diff --git a/homeassistant/util/dt.py b/homeassistant/util/dt.py index a5724ee90e1..282ddf9bb8c 100644 --- a/homeassistant/util/dt.py +++ b/homeassistant/util/dt.py @@ -8,7 +8,7 @@ from typing import Any, Union, Optional, Tuple # NOQA import pytz DATE_STR_FORMAT = "%Y-%m-%d" -UTC = DEFAULT_TIME_ZONE = pytz.utc # type: pytz.UTC +UTC = DEFAULT_TIME_ZONE = pytz.utc # type: dt.tzinfo # Copyright (c) Django Software Foundation and individual contributors. @@ -93,11 +93,10 @@ def start_of_local_day(dt_or_d: Union[dt.date, dt.datetime]=None) -> dt.datetime: """Return local datetime object of start of day from date or datetime.""" if dt_or_d is None: - dt_or_d = now().date() + date = now().date() # type: dt.date elif isinstance(dt_or_d, dt.datetime): - dt_or_d = dt_or_d.date() - - return DEFAULT_TIME_ZONE.localize(dt.datetime.combine(dt_or_d, dt.time())) + date = dt_or_d.date() + return DEFAULT_TIME_ZONE.localize(dt.datetime.combine(date, dt.time())) # Copyright (c) Django Software Foundation and individual contributors. @@ -118,6 +117,8 @@ def parse_datetime(dt_str: str) -> dt.datetime: if kws['microsecond']: kws['microsecond'] = kws['microsecond'].ljust(6, '0') tzinfo_str = kws.pop('tzinfo') + + tzinfo = None # type: Optional[dt.tzinfo] if tzinfo_str == 'Z': tzinfo = UTC elif tzinfo_str is not None: