Improvement typing (#2735)
* Fix: Circular dependencies of internal files * Change: dt.date for Date and dt.datetime for DateTime * Use NewType if available * FIX: Wrong version test * Remove: Date and DateTime types due to error * Change to HomeAssistantType * General Improvement of Typing * Improve typing config_validation * Improve typing script * General Typing Improvements * Improve NewType check * Improve typing db_migrator * Improve util/__init__ typing * Improve helpers/location typing * Regroup imports and remove pylint: disable=ungrouped-imports * General typing improvementspull/2745/head
parent
a3ca3e878b
commit
0377338a81
|
@ -274,7 +274,7 @@ def try_to_restart() -> None:
|
|||
# thread left (which is us). Nothing we really do with it, but it might be
|
||||
# useful when debugging shutdown/restart issues.
|
||||
try:
|
||||
nthreads = sum(thread.isAlive() and not thread.isDaemon()
|
||||
nthreads = sum(thread.is_alive() and not thread.daemon
|
||||
for thread in threading.enumerate())
|
||||
if nthreads > 1:
|
||||
sys.stderr.write(
|
||||
|
|
|
@ -12,6 +12,8 @@ from typing import Any, Optional, Dict
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
|
||||
import homeassistant.components as core_components
|
||||
from homeassistant.components import group, persistent_notification
|
||||
import homeassistant.config as conf_util
|
||||
|
@ -216,7 +218,7 @@ def prepare_setup_platform(hass: core.HomeAssistant, config, domain: str,
|
|||
|
||||
# pylint: disable=too-many-branches, too-many-statements, too-many-arguments
|
||||
def from_config_dict(config: Dict[str, Any],
|
||||
hass: Optional[core.HomeAssistant]=None,
|
||||
hass: Optional[HomeAssistantType]=None,
|
||||
config_dir: Optional[str]=None,
|
||||
enable_log: bool=True,
|
||||
verbose: bool=False,
|
||||
|
|
|
@ -4,6 +4,9 @@ import os
|
|||
import shutil
|
||||
from types import MappingProxyType
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from typing import Any, Tuple # NOQA
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import (
|
||||
|
@ -37,7 +40,7 @@ DEFAULT_CORE_CONFIG = (
|
|||
CONF_UNIT_SYSTEM_IMPERIAL)),
|
||||
(CONF_TIME_ZONE, 'UTC', 'time_zone', 'Pick yours from here: http://en.wiki'
|
||||
'pedia.org/wiki/List_of_tz_database_time_zones'),
|
||||
)
|
||||
) # type: Tuple[Tuple[str, Any, Any, str], ...]
|
||||
DEFAULT_CONFIG = """
|
||||
# Show links to resources in log and frontend
|
||||
introduction:
|
||||
|
|
|
@ -14,9 +14,13 @@ import threading
|
|||
import time
|
||||
from types import MappingProxyType
|
||||
|
||||
from typing import Any, Callable
|
||||
# pylint: disable=unused-import
|
||||
from typing import Optional, Any, Callable # NOQA
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.helpers.typing import UnitSystemType # NOQA
|
||||
|
||||
import homeassistant.util as util
|
||||
import homeassistant.util.dt as dt_util
|
||||
import homeassistant.util.location as location
|
||||
|
@ -713,15 +717,15 @@ class Config(object):
|
|||
# pylint: disable=too-many-instance-attributes
|
||||
def __init__(self):
|
||||
"""Initialize a new config object."""
|
||||
self.latitude = None
|
||||
self.longitude = None
|
||||
self.elevation = None
|
||||
self.location_name = None
|
||||
self.time_zone = None
|
||||
self.units = METRIC_SYSTEM
|
||||
self.latitude = None # type: Optional[float]
|
||||
self.longitude = None # type: Optional[float]
|
||||
self.elevation = None # type: Optional[int]
|
||||
self.location_name = None # type: Optional[str]
|
||||
self.time_zone = None # type: Optional[str]
|
||||
self.units = METRIC_SYSTEM # type: UnitSystemType
|
||||
|
||||
# If True, pip install is skipped for requirements on startup
|
||||
self.skip_pip = False
|
||||
self.skip_pip = False # type: bool
|
||||
|
||||
# List of loaded components
|
||||
self.components = []
|
||||
|
|
|
@ -3,6 +3,8 @@ from datetime import timedelta
|
|||
import logging
|
||||
import sys
|
||||
|
||||
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
|
||||
|
||||
from homeassistant.components import (
|
||||
zone as zone_cmp, sun as sun_cmp)
|
||||
from homeassistant.const import (
|
||||
|
@ -21,7 +23,7 @@ FROM_CONFIG_FORMAT = '{}_from_config'
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def from_config(config, config_validation=True):
|
||||
def from_config(config: ConfigType, config_validation: bool=True):
|
||||
"""Turn a condition configuration into a method."""
|
||||
factory = getattr(
|
||||
sys.modules[__name__],
|
||||
|
@ -34,13 +36,14 @@ def from_config(config, config_validation=True):
|
|||
return factory(config, config_validation)
|
||||
|
||||
|
||||
def and_from_config(config, config_validation=True):
|
||||
def and_from_config(config: ConfigType, config_validation: bool=True):
|
||||
"""Create multi condition matcher using 'AND'."""
|
||||
if config_validation:
|
||||
config = cv.AND_CONDITION_SCHEMA(config)
|
||||
checks = [from_config(entry) for entry in config['conditions']]
|
||||
|
||||
def if_and_condition(hass, variables=None):
|
||||
def if_and_condition(hass: HomeAssistantType,
|
||||
variables=None) -> bool:
|
||||
"""Test and condition."""
|
||||
for check in checks:
|
||||
try:
|
||||
|
@ -55,13 +58,14 @@ def and_from_config(config, config_validation=True):
|
|||
return if_and_condition
|
||||
|
||||
|
||||
def or_from_config(config, config_validation=True):
|
||||
def or_from_config(config: ConfigType, config_validation: bool=True):
|
||||
"""Create multi condition matcher using 'OR'."""
|
||||
if config_validation:
|
||||
config = cv.OR_CONDITION_SCHEMA(config)
|
||||
checks = [from_config(entry) for entry in config['conditions']]
|
||||
|
||||
def if_or_condition(hass, variables=None):
|
||||
def if_or_condition(hass: HomeAssistantType,
|
||||
variables=None) -> bool:
|
||||
"""Test and condition."""
|
||||
for check in checks:
|
||||
try:
|
||||
|
@ -76,8 +80,8 @@ def or_from_config(config, config_validation=True):
|
|||
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
def numeric_state(hass, entity, below=None, above=None, value_template=None,
|
||||
variables=None):
|
||||
def numeric_state(hass: HomeAssistantType, entity, below=None, above=None,
|
||||
value_template=None, variables=None):
|
||||
"""Test a numeric state condition."""
|
||||
if isinstance(entity, str):
|
||||
entity = hass.states.get(entity)
|
||||
|
@ -93,7 +97,7 @@ def numeric_state(hass, entity, below=None, above=None, value_template=None,
|
|||
try:
|
||||
value = render(hass, value_template, variables)
|
||||
except TemplateError as ex:
|
||||
_LOGGER.error(ex)
|
||||
_LOGGER.error("Template error: %s", ex)
|
||||
return False
|
||||
|
||||
try:
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
"""Helpers for config validation using voluptuous."""
|
||||
from datetime import timedelta
|
||||
|
||||
from typing import Any, Union, TypeVar, Callable, Sequence, List, Dict
|
||||
|
||||
import jinja2
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -28,12 +30,15 @@ longitude = vol.All(vol.Coerce(float), vol.Range(min=-180, max=180),
|
|||
msg='invalid longitude')
|
||||
sun_event = vol.All(vol.Lower, vol.Any(SUN_EVENT_SUNSET, SUN_EVENT_SUNRISE))
|
||||
|
||||
# typing typevar
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
# Adapted from:
|
||||
# https://github.com/alecthomas/voluptuous/issues/115#issuecomment-144464666
|
||||
def has_at_least_one_key(*keys):
|
||||
def has_at_least_one_key(*keys: str) -> Callable:
|
||||
"""Validator that at least one key exists."""
|
||||
def validate(obj):
|
||||
def validate(obj: Dict) -> Dict:
|
||||
"""Test keys exist in dict."""
|
||||
if not isinstance(obj, dict):
|
||||
raise vol.Invalid('expected dictionary')
|
||||
|
@ -46,7 +51,7 @@ def has_at_least_one_key(*keys):
|
|||
return validate
|
||||
|
||||
|
||||
def boolean(value):
|
||||
def boolean(value: Any) -> bool:
|
||||
"""Validate and coerce a boolean value."""
|
||||
if isinstance(value, str):
|
||||
value = value.lower()
|
||||
|
@ -63,12 +68,12 @@ def isfile(value):
|
|||
return vol.IsFile('not a file')(value)
|
||||
|
||||
|
||||
def ensure_list(value):
|
||||
def ensure_list(value: Union[T, Sequence[T]]) -> List[T]:
|
||||
"""Wrap value in list if it is not one."""
|
||||
return value if isinstance(value, list) else [value]
|
||||
|
||||
|
||||
def entity_id(value):
|
||||
def entity_id(value: Any) -> str:
|
||||
"""Validate Entity ID."""
|
||||
value = string(value).lower()
|
||||
if valid_entity_id(value):
|
||||
|
@ -76,7 +81,7 @@ def entity_id(value):
|
|||
raise vol.Invalid('Entity ID {} is an invalid entity id'.format(value))
|
||||
|
||||
|
||||
def entity_ids(value):
|
||||
def entity_ids(value: Union[str, Sequence]) -> List[str]:
|
||||
"""Validate Entity IDs."""
|
||||
if value is None:
|
||||
raise vol.Invalid('Entity IDs can not be None')
|
||||
|
@ -109,7 +114,7 @@ time_period_dict = vol.All(
|
|||
lambda value: timedelta(**value))
|
||||
|
||||
|
||||
def time_period_str(value):
|
||||
def time_period_str(value: str) -> timedelta:
|
||||
"""Validate and transform time offset."""
|
||||
if isinstance(value, int):
|
||||
raise vol.Invalid('Make sure you wrap time values in quotes')
|
||||
|
@ -182,7 +187,7 @@ def platform_validator(domain):
|
|||
return validator
|
||||
|
||||
|
||||
def positive_timedelta(value):
|
||||
def positive_timedelta(value: timedelta) -> timedelta:
|
||||
"""Validate timedelta is positive."""
|
||||
if value < timedelta(0):
|
||||
raise vol.Invalid('Time period should be positive')
|
||||
|
@ -209,14 +214,14 @@ def slug(value):
|
|||
raise vol.Invalid('invalid slug {} (try {})'.format(value, slg))
|
||||
|
||||
|
||||
def string(value):
|
||||
def string(value: Any) -> str:
|
||||
"""Coerce value to string, except for None."""
|
||||
if value is not None:
|
||||
return str(value)
|
||||
raise vol.Invalid('string value is None')
|
||||
|
||||
|
||||
def temperature_unit(value):
|
||||
def temperature_unit(value) -> str:
|
||||
"""Validate and transform temperature unit."""
|
||||
value = str(value).upper()
|
||||
if value == 'C':
|
||||
|
|
|
@ -12,9 +12,7 @@ from homeassistant.const import (
|
|||
from homeassistant.exceptions import NoEntitySpecifiedError
|
||||
from homeassistant.util import ensure_unique_string, slugify
|
||||
|
||||
# pylint: disable=using-constant-test,unused-import
|
||||
if False:
|
||||
from homeassistant.core import HomeAssistant # NOQA
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
|
||||
# Entity attributes that we will overwrite
|
||||
_OVERWRITE = {} # type: Dict[str, Any]
|
||||
|
@ -27,7 +25,7 @@ ENTITY_ID_PATTERN = re.compile(r"^(\w+)\.(\w+)$")
|
|||
|
||||
def generate_entity_id(entity_id_format: str, name: Optional[str],
|
||||
current_ids: Optional[List[str]]=None,
|
||||
hass: 'Optional[HomeAssistant]'=None) -> str:
|
||||
hass: Optional[HomeAssistantType]=None) -> str:
|
||||
"""Generate a unique entity ID based on given entity IDs or used IDs."""
|
||||
name = (name or DEVICE_DEFAULT_NAME).lower()
|
||||
if current_ids is None:
|
||||
|
@ -153,7 +151,7 @@ class Entity(object):
|
|||
# are used to perform a very specific function. Overwriting these may
|
||||
# produce undesirable effects in the entity's operation.
|
||||
|
||||
hass = None # type: Optional[HomeAssistant]
|
||||
hass = None # type: Optional[HomeAssistantType]
|
||||
|
||||
def update_ha_state(self, force_refresh=False):
|
||||
"""Update Home Assistant with current state of entity.
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
"""Event Decorators for custom components."""
|
||||
import functools
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from typing import Optional # NOQA
|
||||
from homeassistant.helpers.typing import HomeAssistantType # NOQA
|
||||
|
||||
from homeassistant.helpers import event
|
||||
|
||||
HASS = None
|
||||
HASS = None # type: Optional[HomeAssistantType]
|
||||
|
||||
|
||||
def track_state_change(entity_ids, from_state=None, to_state=None):
|
||||
|
|
|
@ -1,18 +1,21 @@
|
|||
"""Location helpers for Home Assistant."""
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE
|
||||
from homeassistant.core import State
|
||||
from homeassistant.util import location as loc_util
|
||||
|
||||
|
||||
def has_location(state):
|
||||
def has_location(state: State) -> bool:
|
||||
"""Test if state contains a valid location."""
|
||||
return (isinstance(state, State) and
|
||||
isinstance(state.attributes.get(ATTR_LATITUDE), float) and
|
||||
isinstance(state.attributes.get(ATTR_LONGITUDE), float))
|
||||
|
||||
|
||||
def closest(latitude, longitude, states):
|
||||
def closest(latitude: float, longitude: float,
|
||||
states: Sequence[State]) -> State:
|
||||
"""Return closest state to point."""
|
||||
with_location = [state for state in states if has_location(state)]
|
||||
|
||||
|
|
|
@ -3,8 +3,12 @@ import logging
|
|||
import threading
|
||||
from itertools import islice
|
||||
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
|
||||
|
||||
import homeassistant.util.dt as date_util
|
||||
from homeassistant.const import EVENT_TIME_CHANGED, CONF_CONDITION
|
||||
from homeassistant.helpers.event import track_point_in_utc_time
|
||||
|
@ -22,7 +26,8 @@ CONF_EVENT_DATA = "event_data"
|
|||
CONF_DELAY = "delay"
|
||||
|
||||
|
||||
def call_from_config(hass, config, variables=None):
|
||||
def call_from_config(hass: HomeAssistantType, config: ConfigType,
|
||||
variables: Optional[Sequence]=None) -> None:
|
||||
"""Call a script based on a config entry."""
|
||||
Script(hass, config).run(variables)
|
||||
|
||||
|
@ -31,7 +36,8 @@ class Script():
|
|||
"""Representation of a script."""
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
def __init__(self, hass, sequence, name=None, change_listener=None):
|
||||
def __init__(self, hass: HomeAssistantType, sequence, name: str=None,
|
||||
change_listener=None) -> None:
|
||||
"""Initialize the script."""
|
||||
self.hass = hass
|
||||
self.sequence = cv.SCRIPT_SCHEMA(sequence)
|
||||
|
@ -45,11 +51,11 @@ class Script():
|
|||
self._delay_listener = None
|
||||
|
||||
@property
|
||||
def is_running(self):
|
||||
def is_running(self) -> bool:
|
||||
"""Return true if script is on."""
|
||||
return self._cur != -1
|
||||
|
||||
def run(self, variables=None):
|
||||
def run(self, variables: Optional[Sequence]=None) -> None:
|
||||
"""Run script."""
|
||||
with self._lock:
|
||||
if self._cur == -1:
|
||||
|
@ -101,7 +107,7 @@ class Script():
|
|||
if self._change_listener:
|
||||
self._change_listener()
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
"""Stop running script."""
|
||||
with self._lock:
|
||||
if self._cur == -1:
|
||||
|
|
|
@ -2,15 +2,20 @@
|
|||
import functools
|
||||
import logging
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from typing import Optional # NOQA
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.helpers.typing import HomeAssistantType # NOQA
|
||||
|
||||
from homeassistant.const import ATTR_ENTITY_ID
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.helpers import template
|
||||
from homeassistant.loader import get_component
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
||||
HASS = None
|
||||
HASS = None # type: Optional[HomeAssistantType]
|
||||
|
||||
CONF_SERVICE = 'service'
|
||||
CONF_SERVICE_TEMPLATE = 'service_template'
|
||||
|
|
|
@ -1,12 +1,39 @@
|
|||
"""Typing Helpers for Home-Assistant."""
|
||||
from typing import Dict, Any
|
||||
|
||||
import homeassistant.core
|
||||
# NOTE: NewType added to typing in 3.5.2 in June, 2016; Since 3.5.2 includes
|
||||
# security fixes everyone on 3.5 should upgrade "soon"
|
||||
try:
|
||||
from typing import NewType
|
||||
except ImportError:
|
||||
NewType = None
|
||||
|
||||
# HACK: mypy/pytype will import, other interpreters will not; this is to avoid
|
||||
# circular dependencies where the type is needed.
|
||||
# All homeassistant types should be imported this way.
|
||||
# Documentation
|
||||
# http://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
|
||||
# pylint: disable=using-constant-test,unused-import
|
||||
if False:
|
||||
from homeassistant.core import HomeAssistant # NOQA
|
||||
from homeassistant.helpers.unit_system import UnitSystem # NOQA
|
||||
# ENDHACK
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
if NewType:
|
||||
ConfigType = NewType('ConfigType', Dict[str, Any])
|
||||
HomeAssistantType = NewType('HomeAssistantType', 'HomeAssistant')
|
||||
UnitSystemType = NewType('UnitSystemType', 'UnitSystem')
|
||||
|
||||
ConfigType = Dict[str, Any]
|
||||
HomeAssistantType = homeassistant.core.HomeAssistant
|
||||
# Custom type for recorder Queries
|
||||
QueryType = NewType('QueryType', Any)
|
||||
|
||||
# Custom type for recorder Queries
|
||||
QueryType = Any
|
||||
# Duplicates for 3.5.1
|
||||
# pylint: disable=invalid-name
|
||||
else:
|
||||
ConfigType = Dict[str, Any] # type: ignore
|
||||
HomeAssistantType = 'HomeAssistant' # type: ignore
|
||||
UnitSystemType = 'UnitSystemType' # type: ignore
|
||||
|
||||
# Custom type for recorder Queries
|
||||
QueryType = Any # type: ignore
|
||||
|
|
|
@ -15,6 +15,8 @@ import time
|
|||
import threading
|
||||
import urllib.parse
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
import homeassistant.bootstrap as bootstrap
|
||||
|
@ -42,7 +44,7 @@ class APIStatus(enum.Enum):
|
|||
CANNOT_CONNECT = "cannot_connect"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
"""Return the state."""
|
||||
return self.value
|
||||
|
||||
|
@ -51,7 +53,8 @@ class API(object):
|
|||
"""Object to pass around Home Assistant API location and credentials."""
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
def __init__(self, host, api_password=None, port=None, use_ssl=False):
|
||||
def __init__(self, host: str, api_password: Optional[str]=None,
|
||||
port: Optional[int]=None, use_ssl: bool=False) -> None:
|
||||
"""Initalize the API."""
|
||||
self.host = host
|
||||
self.port = port or SERVER_PORT
|
||||
|
@ -68,7 +71,7 @@ class API(object):
|
|||
if api_password is not None:
|
||||
self._headers[HTTP_HEADER_HA_AUTH] = api_password
|
||||
|
||||
def validate_api(self, force_validate=False):
|
||||
def validate_api(self, force_validate: bool=False) -> bool:
|
||||
"""Test if we can communicate with the API."""
|
||||
if self.status is None or force_validate:
|
||||
self.status = validate_api(self)
|
||||
|
@ -100,7 +103,7 @@ class API(object):
|
|||
_LOGGER.exception(error)
|
||||
raise HomeAssistantError(error)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
"""Return the representation of the API."""
|
||||
return "API({}, {}, {})".format(
|
||||
self.host, self.api_password, self.port)
|
||||
|
|
|
@ -4,6 +4,10 @@ import argparse
|
|||
import os.path
|
||||
import sqlite3
|
||||
import sys
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
@ -16,7 +20,7 @@ import homeassistant.config as config_util
|
|||
import homeassistant.util.dt as dt_util
|
||||
|
||||
|
||||
def ts_to_dt(timestamp):
|
||||
def ts_to_dt(timestamp: Optional[float]) -> Optional[datetime]:
|
||||
"""Turn a datetime into an integer for in the DB."""
|
||||
if timestamp is None:
|
||||
return None
|
||||
|
@ -26,8 +30,8 @@ def ts_to_dt(timestamp):
|
|||
# Based on code at
|
||||
# http://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
|
||||
# pylint: disable=too-many-arguments
|
||||
def print_progress(iteration, total, prefix='', suffix='', decimals=2,
|
||||
bar_length=68):
|
||||
def print_progress(iteration: int, total: int, prefix: str='', suffix: str='',
|
||||
decimals: int=2, bar_length: int=68) -> None:
|
||||
"""Print progress bar.
|
||||
|
||||
Call in a loop to create terminal progress bar
|
||||
|
@ -49,7 +53,7 @@ def print_progress(iteration, total, prefix='', suffix='', decimals=2,
|
|||
print("\n")
|
||||
|
||||
|
||||
def run(args):
|
||||
def run(args) -> int:
|
||||
"""The actual script body."""
|
||||
# pylint: disable=too-many-locals,invalid-name,too-many-statements
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -75,7 +79,7 @@ def run(args):
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
config_dir = os.path.join(os.getcwd(), args.config)
|
||||
config_dir = os.path.join(os.getcwd(), args.config) # type: str
|
||||
|
||||
# Test if configuration directory exists
|
||||
if not os.path.isdir(config_dir):
|
||||
|
|
|
@ -12,21 +12,24 @@ import string
|
|||
from functools import wraps
|
||||
from types import MappingProxyType
|
||||
|
||||
from typing import Any, Sequence
|
||||
from typing import Any, Optional, TypeVar, Callable, Sequence
|
||||
|
||||
from .dt import as_local, utcnow
|
||||
|
||||
T = TypeVar('T')
|
||||
U = TypeVar('U')
|
||||
|
||||
RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)')
|
||||
RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)')
|
||||
RE_SLUGIFY = re.compile(r'[^a-z0-9_]+')
|
||||
|
||||
|
||||
def sanitize_filename(filename):
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
r"""Sanitize a filename by removing .. / and \\."""
|
||||
return RE_SANITIZE_FILENAME.sub("", filename)
|
||||
|
||||
|
||||
def sanitize_path(path):
|
||||
def sanitize_path(path: str) -> str:
|
||||
"""Sanitize a path by removing ~ and .."""
|
||||
return RE_SANITIZE_PATH.sub("", path)
|
||||
|
||||
|
@ -50,7 +53,8 @@ def repr_helper(inp: Any) -> str:
|
|||
return str(inp)
|
||||
|
||||
|
||||
def convert(value, to_type, default=None):
|
||||
def convert(value: T, to_type: Callable[[T], U],
|
||||
default: Optional[U]=None) -> Optional[U]:
|
||||
"""Convert value to to_type, returns default if fails."""
|
||||
try:
|
||||
return default if value is None else to_type(value)
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any, Union, Optional, Tuple # NOQA
|
|||
import pytz
|
||||
|
||||
DATE_STR_FORMAT = "%Y-%m-%d"
|
||||
UTC = DEFAULT_TIME_ZONE = pytz.utc # type: pytz.UTC
|
||||
UTC = DEFAULT_TIME_ZONE = pytz.utc # type: dt.tzinfo
|
||||
|
||||
|
||||
# Copyright (c) Django Software Foundation and individual contributors.
|
||||
|
@ -93,11 +93,10 @@ def start_of_local_day(dt_or_d:
|
|||
Union[dt.date, dt.datetime]=None) -> dt.datetime:
|
||||
"""Return local datetime object of start of day from date or datetime."""
|
||||
if dt_or_d is None:
|
||||
dt_or_d = now().date()
|
||||
date = now().date() # type: dt.date
|
||||
elif isinstance(dt_or_d, dt.datetime):
|
||||
dt_or_d = dt_or_d.date()
|
||||
|
||||
return DEFAULT_TIME_ZONE.localize(dt.datetime.combine(dt_or_d, dt.time()))
|
||||
date = dt_or_d.date()
|
||||
return DEFAULT_TIME_ZONE.localize(dt.datetime.combine(date, dt.time()))
|
||||
|
||||
|
||||
# Copyright (c) Django Software Foundation and individual contributors.
|
||||
|
@ -118,6 +117,8 @@ def parse_datetime(dt_str: str) -> dt.datetime:
|
|||
if kws['microsecond']:
|
||||
kws['microsecond'] = kws['microsecond'].ljust(6, '0')
|
||||
tzinfo_str = kws.pop('tzinfo')
|
||||
|
||||
tzinfo = None # type: Optional[dt.tzinfo]
|
||||
if tzinfo_str == 'Z':
|
||||
tzinfo = UTC
|
||||
elif tzinfo_str is not None:
|
||||
|
|
Loading…
Reference in New Issue