Type Hints - Core/Utils/Helpers Part 1 (#2592)

* Fix deprecated(moved) import

* Add util/dt typing

* Green on mypy util/dt

* Fix some errors

* First part of yping util/yaml

* Add more typing to util/yaml
pull/2603/head
Fabian Heredia Montiel 2016-07-23 13:07:08 -05:00 committed by Paulus Schoutsen
parent 34ca1dac7d
commit d4f78e8552
6 changed files with 96 additions and 71 deletions

View File

@ -1,5 +1,5 @@
"""Helper methods for various modules.""" """Helper methods for various modules."""
import collections from collections.abc import MutableSet
from itertools import chain from itertools import chain
import threading import threading
import queue import queue
@ -12,6 +12,8 @@ import string
from functools import wraps from functools import wraps
from types import MappingProxyType from types import MappingProxyType
from typing import Any
from .dt import as_local, utcnow from .dt import as_local, utcnow
RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)') RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)')
@ -36,7 +38,7 @@ def slugify(text):
return RE_SLUGIFY.sub("", text) return RE_SLUGIFY.sub("", text)
def repr_helper(inp): def repr_helper(inp: Any) -> str:
"""Help creating a more readable string representation of objects.""" """Help creating a more readable string representation of objects."""
if isinstance(inp, (dict, MappingProxyType)): if isinstance(inp, (dict, MappingProxyType)):
return ", ".join( return ", ".join(
@ -128,7 +130,7 @@ class OrderedEnum(enum.Enum):
return NotImplemented return NotImplemented
class OrderedSet(collections.MutableSet): class OrderedSet(MutableSet):
"""Ordered set taken from http://code.activestate.com/recipes/576694/.""" """Ordered set taken from http://code.activestate.com/recipes/576694/."""
def __init__(self, iterable=None): def __init__(self, iterable=None):

View File

@ -1,7 +1,8 @@
"""Color util methods.""" """Color util methods."""
import logging import logging
import math import math
# pylint: disable=unused-import
from typing import Tuple
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -36,14 +37,14 @@ def color_name_to_rgb(color_name):
# http://www.developers.meethue.com/documentation/color-conversions-rgb-xy # http://www.developers.meethue.com/documentation/color-conversions-rgb-xy
# License: Code is given as is. Use at your own risk and discretion. # License: Code is given as is. Use at your own risk and discretion.
# pylint: disable=invalid-name # pylint: disable=invalid-name
def color_RGB_to_xy(R, G, B): def color_RGB_to_xy(iR: int, iG: int, iB: int) -> Tuple[float, float, int]:
"""Convert from RGB color to XY color.""" """Convert from RGB color to XY color."""
if R + G + B == 0: if iR + iG + iB == 0:
return 0, 0, 0 return 0.0, 0.0, 0
R = R / 255 R = iR / 255
B = B / 255 B = iB / 255
G = G / 255 G = iG / 255
# Gamma correction # Gamma correction
R = pow((R + 0.055) / (1.0 + 0.055), R = pow((R + 0.055) / (1.0 + 0.055),
@ -72,9 +73,10 @@ def color_RGB_to_xy(R, G, B):
# taken from # taken from
# https://github.com/benknight/hue-python-rgb-converter/blob/master/rgb_cie.py # https://github.com/benknight/hue-python-rgb-converter/blob/master/rgb_cie.py
# Copyright (c) 2014 Benjamin Knight / MIT License. # Copyright (c) 2014 Benjamin Knight / MIT License.
def color_xy_brightness_to_RGB(vX, vY, brightness): def color_xy_brightness_to_RGB(vX: float, vY: float,
ibrightness: int) -> Tuple[int, int, int]:
"""Convert from XYZ to RGB.""" """Convert from XYZ to RGB."""
brightness /= 255. brightness = ibrightness / 255.
if brightness == 0: if brightness == 0:
return (0, 0, 0) return (0, 0, 0)
@ -106,17 +108,18 @@ def color_xy_brightness_to_RGB(vX, vY, brightness):
if max_component > 1: if max_component > 1:
r, g, b = map(lambda x: x / max_component, [r, g, b]) r, g, b = map(lambda x: x / max_component, [r, g, b])
r, g, b = map(lambda x: int(x * 255), [r, g, b]) ir, ig, ib = map(lambda x: int(x * 255), [r, g, b])
return (r, g, b) return (ir, ig, ib)
def _match_max_scale(input_colors, output_colors): def _match_max_scale(input_colors: Tuple[int, ...],
output_colors: Tuple[int, ...]) -> Tuple[int, ...]:
"""Match the maximum value of the output to the input.""" """Match the maximum value of the output to the input."""
max_in = max(input_colors) max_in = max(input_colors)
max_out = max(output_colors) max_out = max(output_colors)
if max_out == 0: if max_out == 0:
factor = 0 factor = 0.0
else: else:
factor = max_in / max_out factor = max_in / max_out
return tuple(int(round(i * factor)) for i in output_colors) return tuple(int(round(i * factor)) for i in output_colors)
@ -176,7 +179,8 @@ def color_temperature_to_rgb(color_temperature_kelvin):
return (red, green, blue) return (red, green, blue)
def _bound(color_component, minimum=0, maximum=255): def _bound(color_component: float, minimum: float=0,
maximum: float=255) -> float:
""" """
Bound the given color component value between the given min and max values. Bound the given color component value between the given min and max values.
@ -188,7 +192,7 @@ def _bound(color_component, minimum=0, maximum=255):
return min(color_component_out, maximum) return min(color_component_out, maximum)
def _get_red(temperature): def _get_red(temperature: float) -> float:
"""Get the red component of the temperature in RGB space.""" """Get the red component of the temperature in RGB space."""
if temperature <= 66: if temperature <= 66:
return 255 return 255
@ -196,7 +200,7 @@ def _get_red(temperature):
return _bound(tmp_red) return _bound(tmp_red)
def _get_green(temperature): def _get_green(temperature: float) -> float:
"""Get the green component of the given color temp in RGB space.""" """Get the green component of the given color temp in RGB space."""
if temperature <= 66: if temperature <= 66:
green = 99.4708025861 * math.log(temperature) - 161.1195681661 green = 99.4708025861 * math.log(temperature) - 161.1195681661
@ -205,13 +209,13 @@ def _get_green(temperature):
return _bound(green) return _bound(green)
def _get_blue(tmp_internal): def _get_blue(temperature: float) -> float:
"""Get the blue component of the given color temperature in RGB space.""" """Get the blue component of the given color temperature in RGB space."""
if tmp_internal >= 66: if temperature >= 66:
return 255 return 255
if tmp_internal <= 19: if temperature <= 19:
return 0 return 0
blue = 138.5177312231 * math.log(tmp_internal - 10) - 305.0447927307 blue = 138.5177312231 * math.log(temperature - 10) - 305.0447927307
return _bound(blue) return _bound(blue)

View File

@ -2,10 +2,13 @@
import datetime as dt import datetime as dt
import re import re
# pylint: disable=unused-import
from typing import Any, Union, Optional, Tuple # NOQA
import pytz import pytz
DATE_STR_FORMAT = "%Y-%m-%d" DATE_STR_FORMAT = "%Y-%m-%d"
UTC = DEFAULT_TIME_ZONE = pytz.utc UTC = DEFAULT_TIME_ZONE = pytz.utc # type: pytz.UTC
# Copyright (c) Django Software Foundation and individual contributors. # Copyright (c) Django Software Foundation and individual contributors.
@ -19,16 +22,17 @@ DATETIME_RE = re.compile(
) )
def set_default_time_zone(time_zone): def set_default_time_zone(time_zone: dt.tzinfo) -> None:
"""Set a default time zone to be used when none is specified.""" """Set a default time zone to be used when none is specified."""
global DEFAULT_TIME_ZONE # pylint: disable=global-statement global DEFAULT_TIME_ZONE # pylint: disable=global-statement
# NOTE: Remove in the future in favour of typing
assert isinstance(time_zone, dt.tzinfo) assert isinstance(time_zone, dt.tzinfo)
DEFAULT_TIME_ZONE = time_zone DEFAULT_TIME_ZONE = time_zone
def get_time_zone(time_zone_str): def get_time_zone(time_zone_str: str) -> Optional[dt.tzinfo]:
"""Get time zone from string. Return None if unable to determine.""" """Get time zone from string. Return None if unable to determine."""
try: try:
return pytz.timezone(time_zone_str) return pytz.timezone(time_zone_str)
@ -36,17 +40,17 @@ def get_time_zone(time_zone_str):
return None return None
def utcnow(): def utcnow() -> dt.datetime:
"""Get now in UTC time.""" """Get now in UTC time."""
return dt.datetime.now(UTC) return dt.datetime.now(UTC)
def now(time_zone=None): def now(time_zone: dt.tzinfo=None) -> dt.datetime:
"""Get now in specified time zone.""" """Get now in specified time zone."""
return dt.datetime.now(time_zone or DEFAULT_TIME_ZONE) return dt.datetime.now(time_zone or DEFAULT_TIME_ZONE)
def as_utc(dattim): def as_utc(dattim: dt.datetime) -> dt.datetime:
"""Return a datetime as UTC time. """Return a datetime as UTC time.
Assumes datetime without tzinfo to be in the DEFAULT_TIME_ZONE. Assumes datetime without tzinfo to be in the DEFAULT_TIME_ZONE.
@ -70,7 +74,7 @@ def as_timestamp(dt_value):
return parsed_dt.timestamp() return parsed_dt.timestamp()
def as_local(dattim): def as_local(dattim: dt.datetime) -> dt.datetime:
"""Convert a UTC datetime object to local time zone.""" """Convert a UTC datetime object to local time zone."""
if dattim.tzinfo == DEFAULT_TIME_ZONE: if dattim.tzinfo == DEFAULT_TIME_ZONE:
return dattim return dattim
@ -80,12 +84,13 @@ def as_local(dattim):
return dattim.astimezone(DEFAULT_TIME_ZONE) return dattim.astimezone(DEFAULT_TIME_ZONE)
def utc_from_timestamp(timestamp): def utc_from_timestamp(timestamp: float) -> dt.datetime:
"""Return a UTC time from a timestamp.""" """Return a UTC time from a timestamp."""
return dt.datetime.utcfromtimestamp(timestamp).replace(tzinfo=UTC) return dt.datetime.utcfromtimestamp(timestamp).replace(tzinfo=UTC)
def start_of_local_day(dt_or_d=None): def start_of_local_day(dt_or_d:
Union[dt.date, dt.datetime]=None) -> dt.datetime:
"""Return local datetime object of start of day from date or datetime.""" """Return local datetime object of start of day from date or datetime."""
if dt_or_d is None: if dt_or_d is None:
dt_or_d = now().date() dt_or_d = now().date()
@ -98,7 +103,7 @@ def start_of_local_day(dt_or_d=None):
# Copyright (c) Django Software Foundation and individual contributors. # Copyright (c) Django Software Foundation and individual contributors.
# All rights reserved. # All rights reserved.
# https://github.com/django/django/blob/master/LICENSE # https://github.com/django/django/blob/master/LICENSE
def parse_datetime(dt_str): def parse_datetime(dt_str: str) -> dt.datetime:
"""Parse a string and return a datetime.datetime. """Parse a string and return a datetime.datetime.
This function supports time zone offsets. When the input contains one, This function supports time zone offsets. When the input contains one,
@ -109,25 +114,27 @@ def parse_datetime(dt_str):
match = DATETIME_RE.match(dt_str) match = DATETIME_RE.match(dt_str)
if not match: if not match:
return None return None
kws = match.groupdict() kws = match.groupdict() # type: Dict[str, Any]
if kws['microsecond']: if kws['microsecond']:
kws['microsecond'] = kws['microsecond'].ljust(6, '0') kws['microsecond'] = kws['microsecond'].ljust(6, '0')
tzinfo = kws.pop('tzinfo') tzinfo_str = kws.pop('tzinfo')
if tzinfo == 'Z': if tzinfo_str == 'Z':
tzinfo = UTC tzinfo = UTC
elif tzinfo is not None: elif tzinfo_str is not None:
offset_mins = int(tzinfo[-2:]) if len(tzinfo) > 3 else 0 offset_mins = int(tzinfo_str[-2:]) if len(tzinfo_str) > 3 else 0
offset_hours = int(tzinfo[1:3]) offset_hours = int(tzinfo_str[1:3])
offset = dt.timedelta(hours=offset_hours, minutes=offset_mins) offset = dt.timedelta(hours=offset_hours, minutes=offset_mins)
if tzinfo[0] == '-': if tzinfo_str[0] == '-':
offset = -offset offset = -offset
tzinfo = dt.timezone(offset) tzinfo = dt.timezone(offset)
else:
tzinfo = None
kws = {k: int(v) for k, v in kws.items() if v is not None} kws = {k: int(v) for k, v in kws.items() if v is not None}
kws['tzinfo'] = tzinfo kws['tzinfo'] = tzinfo
return dt.datetime(**kws) return dt.datetime(**kws)
def parse_date(dt_str): def parse_date(dt_str: str) -> dt.date:
"""Convert a date string to a date object.""" """Convert a date string to a date object."""
try: try:
return dt.datetime.strptime(dt_str, DATE_STR_FORMAT).date() return dt.datetime.strptime(dt_str, DATE_STR_FORMAT).date()
@ -154,7 +161,7 @@ def parse_time(time_str):
# Found in this gist: https://gist.github.com/zhangsen/1199964 # Found in this gist: https://gist.github.com/zhangsen/1199964
def get_age(date): def get_age(date: dt.datetime) -> str:
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
""" """
Take a datetime and return its "age" as a string. Take a datetime and return its "age" as a string.
@ -164,14 +171,14 @@ def get_age(date):
be returned. be returned.
Make sure date is not in the future, or else it won't work. Make sure date is not in the future, or else it won't work.
""" """
def formatn(number, unit): def formatn(number: int, unit: str) -> str:
"""Add "unit" if it's plural.""" """Add "unit" if it's plural."""
if number == 1: if number == 1:
return "1 %s" % unit return "1 %s" % unit
elif number > 1: elif number > 1:
return "%d %ss" % (number, unit) return "%d %ss" % (number, unit)
def q_n_r(first, second): def q_n_r(first: int, second: int) -> Tuple[int, int]:
"""Return quotient and remaining.""" """Return quotient and remaining."""
return first // second, first % second return first // second, first % second
@ -196,7 +203,5 @@ def get_age(date):
minute, second = q_n_r(second, 60) minute, second = q_n_r(second, 60)
if minute > 0: if minute > 0:
return formatn(minute, 'minute') return formatn(minute, 'minute')
if second > 0:
return formatn(second, 'second')
return "0 second" return formatn(second, 'second') if second > 0 else "0 seconds"

View File

@ -3,7 +3,7 @@
import logging import logging
def fahrenheit_to_celcius(fahrenheit): def fahrenheit_to_celcius(fahrenheit: float) -> float:
"""**DEPRECATED** Convert a Fahrenheit temperature to Celsius.""" """**DEPRECATED** Convert a Fahrenheit temperature to Celsius."""
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(
'fahrenheit_to_celcius is now fahrenheit_to_celsius ' 'fahrenheit_to_celcius is now fahrenheit_to_celsius '
@ -11,12 +11,12 @@ def fahrenheit_to_celcius(fahrenheit):
return fahrenheit_to_celsius(fahrenheit) return fahrenheit_to_celsius(fahrenheit)
def fahrenheit_to_celsius(fahrenheit): def fahrenheit_to_celsius(fahrenheit: float) -> float:
"""Convert a Fahrenheit temperature to Celsius.""" """Convert a Fahrenheit temperature to Celsius."""
return (fahrenheit - 32.0) / 1.8 return (fahrenheit - 32.0) / 1.8
def celcius_to_fahrenheit(celcius): def celcius_to_fahrenheit(celcius: float) -> float:
"""**DEPRECATED** Convert a Celsius temperature to Fahrenheit.""" """**DEPRECATED** Convert a Celsius temperature to Fahrenheit."""
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(
'celcius_to_fahrenheit is now celsius_to_fahrenheit correcting ' 'celcius_to_fahrenheit is now celsius_to_fahrenheit correcting '
@ -24,6 +24,6 @@ def celcius_to_fahrenheit(celcius):
return celsius_to_fahrenheit(celcius) return celsius_to_fahrenheit(celcius)
def celsius_to_fahrenheit(celsius): def celsius_to_fahrenheit(celsius: float) -> float:
"""Convert a Celsius temperature to Fahrenheit.""" """Convert a Celsius temperature to Fahrenheit."""
return celsius * 1.8 + 32.0 return celsius * 1.8 + 32.0

View File

@ -2,6 +2,7 @@
import logging import logging
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import Union, List, Dict
import glob import glob
import yaml import yaml
@ -21,15 +22,16 @@ _SECRET_YAML = 'secrets.yaml'
class SafeLineLoader(yaml.SafeLoader): class SafeLineLoader(yaml.SafeLoader):
"""Loader class that keeps track of line numbers.""" """Loader class that keeps track of line numbers."""
def compose_node(self, parent, index): def compose_node(self, parent: yaml.nodes.Node, index) -> yaml.nodes.Node:
"""Annotate a node with the first line it was seen.""" """Annotate a node with the first line it was seen."""
last_line = self.line last_line = self.line # type: int
node = super(SafeLineLoader, self).compose_node(parent, index) node = super(SafeLineLoader,
self).compose_node(parent, index) # type: yaml.nodes.Node
node.__line__ = last_line + 1 node.__line__ = last_line + 1
return node return node
def load_yaml(fname): def load_yaml(fname: str) -> Union[List, Dict]:
"""Load a YAML file.""" """Load a YAML file."""
try: try:
with open(fname, encoding='utf-8') as conf_file: with open(fname, encoding='utf-8') as conf_file:
@ -41,7 +43,8 @@ def load_yaml(fname):
raise HomeAssistantError(exc) raise HomeAssistantError(exc)
def _include_yaml(loader, node): def _include_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node) -> Union[List, Dict]:
"""Load another YAML file and embeds it using the !include tag. """Load another YAML file and embeds it using the !include tag.
Example: Example:
@ -51,9 +54,10 @@ def _include_yaml(loader, node):
return load_yaml(fname) return load_yaml(fname)
def _include_dir_named_yaml(loader, node): def _include_dir_named_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node):
"""Load multiple files from directory as a dictionary.""" """Load multiple files from directory as a dictionary."""
mapping = OrderedDict() mapping = OrderedDict() # type: OrderedDict
files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml')
for fname in glob.glob(files): for fname in glob.glob(files):
filename = os.path.splitext(os.path.basename(fname))[0] filename = os.path.splitext(os.path.basename(fname))[0]
@ -61,9 +65,10 @@ def _include_dir_named_yaml(loader, node):
return mapping return mapping
def _include_dir_merge_named_yaml(loader, node): def _include_dir_merge_named_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node):
"""Load multiple files from directory as a merged dictionary.""" """Load multiple files from directory as a merged dictionary."""
mapping = OrderedDict() mapping = OrderedDict() # type: OrderedDict
files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml')
for fname in glob.glob(files): for fname in glob.glob(files):
if os.path.basename(fname) == _SECRET_YAML: if os.path.basename(fname) == _SECRET_YAML:
@ -74,17 +79,20 @@ def _include_dir_merge_named_yaml(loader, node):
return mapping return mapping
def _include_dir_list_yaml(loader, node): def _include_dir_list_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node):
"""Load multiple files from directory as a list.""" """Load multiple files from directory as a list."""
files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml')
return [load_yaml(f) for f in glob.glob(files) return [load_yaml(f) for f in glob.glob(files)
if os.path.basename(f) != _SECRET_YAML] if os.path.basename(f) != _SECRET_YAML]
def _include_dir_merge_list_yaml(loader, node): def _include_dir_merge_list_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node):
"""Load multiple files from directory as a merged list.""" """Load multiple files from directory as a merged list."""
files = os.path.join(os.path.dirname(loader.name), node.value, '*.yaml') files = os.path.join(os.path.dirname(loader.name),
merged_list = [] node.value, '*.yaml') # type: str
merged_list = [] # type: List
for fname in glob.glob(files): for fname in glob.glob(files):
if os.path.basename(fname) == _SECRET_YAML: if os.path.basename(fname) == _SECRET_YAML:
continue continue
@ -94,12 +102,13 @@ def _include_dir_merge_list_yaml(loader, node):
return merged_list return merged_list
def _ordered_dict(loader, node): def _ordered_dict(loader: SafeLineLoader,
node: yaml.nodes.MappingNode) -> OrderedDict:
"""Load YAML mappings into an ordered dictionary to preserve key order.""" """Load YAML mappings into an ordered dictionary to preserve key order."""
loader.flatten_mapping(node) loader.flatten_mapping(node)
nodes = loader.construct_pairs(node) nodes = loader.construct_pairs(node)
seen = {} seen = {} # type: Dict
min_line = None min_line = None
for (key, _), (node, _) in zip(nodes, node.value): for (key, _), (node, _) in zip(nodes, node.value):
line = getattr(node, '__line__', 'unknown') line = getattr(node, '__line__', 'unknown')
@ -116,12 +125,13 @@ def _ordered_dict(loader, node):
seen[key] = line seen[key] = line
processed = OrderedDict(nodes) processed = OrderedDict(nodes)
processed.__config_file__ = loader.name setattr(processed, '__config_file__', loader.name)
processed.__line__ = min_line setattr(processed, '__line__', min_line)
return processed return processed
def _env_var_yaml(loader, node): def _env_var_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node):
"""Load environment variables and embed it into the configuration YAML.""" """Load environment variables and embed it into the configuration YAML."""
if node.value in os.environ: if node.value in os.environ:
return os.environ[node.value] return os.environ[node.value]
@ -131,7 +141,8 @@ def _env_var_yaml(loader, node):
# pylint: disable=protected-access # pylint: disable=protected-access
def _secret_yaml(loader, node): def _secret_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node):
"""Load secrets and embed it into the configuration YAML.""" """Load secrets and embed it into the configuration YAML."""
# Create secret cache on loader and load secrets.yaml # Create secret cache on loader and load secrets.yaml
if not hasattr(loader, '_SECRET_CACHE'): if not hasattr(loader, '_SECRET_CACHE'):

View File

@ -137,7 +137,10 @@ class TestDateUtil(unittest.TestCase):
def test_get_age(self): def test_get_age(self):
"""Test get_age.""" """Test get_age."""
diff = dt_util.now() - timedelta(seconds=0) diff = dt_util.now() - timedelta(seconds=0)
self.assertEqual(dt_util.get_age(diff), "0 second") self.assertEqual(dt_util.get_age(diff), "0 seconds")
diff = dt_util.now() - timedelta(seconds=1)
self.assertEqual(dt_util.get_age(diff), "1 second")
diff = dt_util.now() - timedelta(seconds=30) diff = dt_util.now() - timedelta(seconds=30)
self.assertEqual(dt_util.get_age(diff), "30 seconds") self.assertEqual(dt_util.get_age(diff), "30 seconds")