core/homeassistant/util/__init__.py

287 lines
8.9 KiB
Python
Raw Normal View History

2016-03-07 22:20:48 +00:00
"""Helper methods for various modules."""
2021-03-17 20:46:07 +00:00
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
2014-04-15 06:48:00 +00:00
import enum
from functools import wraps
import random
import re
import socket
import string
import threading
2016-02-10 07:27:01 +00:00
from types import MappingProxyType
2021-03-17 20:46:07 +00:00
from typing import Any, Callable, Coroutine, Iterable, KeysView, TypeVar
import slugify as unicode_slug
from ..helpers.deprecation import deprecated_function
2016-04-16 07:55:35 +00:00
from .dt import as_local, utcnow
2019-07-30 23:59:12 +00:00
# Generic type variables used in the type hints of this module
# (T/U by convert, ENUM_T by OrderedEnum).
T = TypeVar("T")
U = TypeVar("U")  # pylint: disable=invalid-name
ENUM_T = TypeVar("ENUM_T", bound=enum.Enum)  # pylint: disable=invalid-name

# Matches "~", "..", "/" or "\"; any filename containing one of these is
# treated as unsafe by raise_if_invalid_filename / sanitize_filename.
RE_SANITIZE_FILENAME = re.compile(r"(~|\.\.|/|\\)")
# Matches "~" or a run of two or more consecutive dots; any path containing
# one of these is treated as unsafe by raise_if_invalid_path / sanitize_path.
# Note that "/" is allowed here, unlike in filenames.
RE_SANITIZE_PATH = re.compile(r"(~|\.(\.)+)")
2013-11-11 00:46:48 +00:00
def raise_if_invalid_filename(filename: str) -> None:
    """Raise ValueError when *filename* is not considered safe.

    A filename is unsafe if it contains "~", "..", "/" or "\\"
    (anything matched by RE_SANITIZE_FILENAME).
    """
    # The pattern can only produce non-empty matches, so "found a match" is
    # exactly the same test as "substituting the matches changed the string".
    if RE_SANITIZE_FILENAME.search(filename) is None:
        return
    raise ValueError(f"{filename} is not a safe filename")
def raise_if_invalid_path(path: str) -> None:
    """Raise ValueError when *path* is not considered safe.

    A path is unsafe if it contains "~" or a run of two or more consecutive
    dots (anything matched by RE_SANITIZE_PATH).
    """
    # The pattern can only produce non-empty matches, so "found a match" is
    # exactly the same test as "substituting the matches changed the string".
    if RE_SANITIZE_PATH.search(path) is None:
        return
    raise ValueError(f"{path} is not a safe path")
@deprecated_function(replacement="raise_if_invalid_filename")
def sanitize_filename(filename: str) -> str:
    """Return *filename* if it is safe, otherwise an empty string.

    Callers are expected to compare the result with their input: a changed
    value means the filename is unsafe and must not be used.

    DEPRECATED.
    """
    # Backwards compatible fix for misuse of method: instead of stripping the
    # offending characters, signal "unsafe" with an empty string.
    try:
        raise_if_invalid_filename(filename)
    except ValueError:
        return ""
    return filename
@deprecated_function(replacement="raise_if_invalid_path")
def sanitize_path(path: str) -> str:
    """Return *path* if it is safe, otherwise an empty string.

    Callers are expected to compare the result with their input: a changed
    value means the path is unsafe and must not be used.

    DEPRECATED.
    """
    # Backwards compatible fix for misuse of method: instead of stripping the
    # offending characters, signal "unsafe" with an empty string.
    try:
        raise_if_invalid_path(path)
    except ValueError:
        return ""
    return path
2014-10-22 06:52:24 +00:00
def slugify(text: str, *, separator: str = "_") -> str:
    """Slugify *text*, joining words with *separator*.

    An empty input yields an empty string. Any other input whose slug comes
    out empty yields "unknown".
    """
    if text == "":
        # Short-circuit so empty input is not mapped to "unknown" below.
        return ""
    slugged = unicode_slug.slugify(text, separator=separator)
    return slugged or "unknown"
def repr_helper(inp: Any) -> str:
    """Help creating a more readable string representation of objects.

    Mappings (dict / MappingProxyType) render as "key=value" pairs joined by
    ", ", with keys and values formatted recursively. Datetimes are converted
    with as_local() and rendered in ISO format. Anything else goes through
    str().
    """
    if isinstance(inp, (dict, MappingProxyType)):
        pairs = [f"{repr_helper(k)}={repr_helper(v)}" for k, v in inp.items()]
        return ", ".join(pairs)

    if isinstance(inp, datetime):
        local_dt = as_local(inp)
        return local_dt.isoformat()

    return str(inp)
2019-07-30 23:59:12 +00:00
def convert(
    value: T | None, to_type: Callable[[T], U], default: U | None = None
) -> U | None:
    """Convert *value* with *to_type*, falling back to *default*.

    *default* is returned when *value* is None, or when *to_type* raises
    ValueError or TypeError.
    """
    if value is None:
        return default
    try:
        return to_type(value)
    except (ValueError, TypeError):
        # The value could not be converted.
        return default
2019-07-30 23:59:12 +00:00
def ensure_unique_string(
    preferred_string: str, current_strings: Iterable[str] | KeysView[str]
) -> str:
    """Return a string that is not present in current_strings.

    *preferred_string* is returned unchanged when it is free. Otherwise the
    suffixes "_2", "_3", ... are tried in order, and the first one not
    already taken is returned.
    """
    taken = set(current_strings)
    candidate = preferred_string
    suffix = 2
    while candidate in taken:
        candidate = f"{preferred_string}_{suffix}"
        suffix += 1
    return candidate
# Taken from: http://stackoverflow.com/a/11735897
def get_local_ip() -> str:
    """Try to determine the local IP address of the machine.

    Strategy, in order:
    1. "Connect" a UDP socket to Google Public DNS (8.8.8.8). A datagram
       connect only selects a route and sends nothing, so the socket's local
       address is the address of the outgoing interface.
    2. Resolve the machine's own hostname.
    3. Fall back to "127.0.0.1".
    """
    try:
        # Creating the socket inside ``with`` guarantees it is closed. It also
        # means a failure in socket.socket() itself (OSError) goes straight to
        # the fallback, with no cleanup step that would touch an unbound name.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            # Use Google Public DNS server to determine own IP
            sock.connect(("8.8.8.8", 80))
            return sock.getsockname()[0]  # type: ignore
    except OSError:
        pass

    try:
        return socket.gethostbyname(socket.gethostname())
    except socket.gaierror:
        return "127.0.0.1"
# Taken from http://stackoverflow.com/a/23728630
def get_random_string(length: int = 10) -> str:
    """Return a string of *length* random ASCII letters and digits.

    Characters are drawn with random.SystemRandom (backed by os.urandom)
    rather than the default seedable PRNG.
    """
    alphabet = string.ascii_letters + string.digits
    picker = random.SystemRandom()
    picked = [picker.choice(alphabet) for _ in range(length)]
    return "".join(picked)
2014-04-15 06:48:00 +00:00
class OrderedEnum(enum.Enum):
    """Enum whose members can be ordered by their values.

    Adapted from the OrderedEnum recipe in the Python 3.4.0 docs. Only
    members of the very same enum class are compared. For anything else
    NotImplemented is returned, so Python can try the reflected operation.
    """

    # https://github.com/PyCQA/pylint/issues/2306
    # pylint: disable=comparison-with-callable

    def __ge__(self, other: ENUM_T) -> bool:
        """Return True if this member's value is >= the other's."""
        if other.__class__ is not self.__class__:
            return NotImplemented
        return bool(self.value >= other.value)

    def __gt__(self, other: ENUM_T) -> bool:
        """Return True if this member's value is > the other's."""
        if other.__class__ is not self.__class__:
            return NotImplemented
        return bool(self.value > other.value)

    def __le__(self, other: ENUM_T) -> bool:
        """Return True if this member's value is <= the other's."""
        if other.__class__ is not self.__class__:
            return NotImplemented
        return bool(self.value <= other.value)

    def __lt__(self, other: ENUM_T) -> bool:
        """Return True if this member's value is < the other's."""
        if other.__class__ is not self.__class__:
            return NotImplemented
        return bool(self.value < other.value)
class Throttle:
    """A class for throttling the execution of tasks.

    This method decorator adds a cooldown to a method to prevent it from being
    called more than 1 time within the timedelta interval `min_time` after it
    returned its result.

    Calling a method a second time during the interval will return None (for
    coroutine functions: a coroutine that resolves to None). The same happens
    when a call arrives while another call still holds the throttle's lock.

    Pass keyword argument `no_throttle=True` to the wrapped method to make
    the call not throttled.

    Decorator takes in an optional second timedelta interval to throttle the
    'no_throttle' calls.

    Throttle state (a lock and the time of the last real call) is stored in a
    `_throttle` dict on the "host" object, keyed by id() of this Throttle
    instance. The host is the bound instance for methods, or the wrapper
    itself for plain functions.
    """

    def __init__(
        self, min_time: timedelta, limit_no_throttle: timedelta | None = None
    ) -> None:
        """Initialize the throttle.

        :param min_time: minimum interval between two real calls.
        :param limit_no_throttle: optional interval that also limits calls made
            with `no_throttle=True` (see __call__).
        """
        self.min_time = min_time
        self.limit_no_throttle = limit_no_throttle

    def __call__(self, method: Callable) -> Callable:
        """Wrap *method* and return the throttled wrapper."""
        # Make sure we return a coroutine if the method is async.
        if asyncio.iscoroutinefunction(method):

            async def throttled_value() -> None:
                """Stand-in function for when real func is being throttled."""
                return None

        else:

            def throttled_value() -> None:  # type: ignore
                """Stand-in function for when real func is being throttled."""
                return None

        # Wrap the real method in a second Throttle. Every call that reaches it,
        # including forced (`no_throttle=True`) calls, is then limited to once
        # per `limit_no_throttle`.
        if self.limit_no_throttle is not None:
            method = Throttle(self.limit_no_throttle)(method)

        # Different methods that can be passed in:
        #  - a function
        #  - an unbound function on a class
        #  - a method (bound function on a class)

        # We want to be able to differentiate between function and unbound
        # methods (which are considered functions).
        # All methods have the classname in their qualname separated by a '.'
        # Functions have a '.' in their qualname if defined inline, but will
        # be prefixed by '.<locals>.' so we strip that out.
        is_func = (
            not hasattr(method, "__self__")
            and "." not in method.__qualname__.split(".<locals>.")[-1]
        )

        @wraps(method)
        def wrapper(*args: Any, **kwargs: Any) -> Callable | Coroutine:
            """Wrap that allows wrapped to be called only once per min_time.

            If we cannot acquire the lock, it is running so return the
            throttled stand-in value (None, or a coroutine resolving to None).
            """
            # Pick the object that carries the throttle state: the bound
            # instance, the wrapper for plain functions, or the first positional
            # argument (presumably `self`) for unbound methods.
            if hasattr(method, "__self__"):
                host = getattr(method, "__self__")
            elif is_func:
                host = wrapper
            else:
                host = args[0] if args else wrapper

            # pylint: disable=protected-access # to _throttle
            if not hasattr(host, "_throttle"):
                host._throttle = {}

            # Per-Throttle entry: [lock, datetime of last real call or None].
            if id(self) not in host._throttle:
                host._throttle[id(self)] = [threading.Lock(), None]
            throttle = host._throttle[id(self)]
            # pylint: enable=protected-access

            # Non-blocking acquire: a concurrent call is throttled, not queued.
            if not throttle[0].acquire(False):
                return throttled_value()

            # Check if method is never called or no_throttle is given
            force = kwargs.pop("no_throttle", False) or not throttle[1]

            try:
                if force or utcnow() - throttle[1] > self.min_time:
                    result = method(*args, **kwargs)
                    # For coroutine functions `result` is the not-yet-awaited
                    # coroutine, so this timestamp marks when it was created.
                    throttle[1] = utcnow()
                    return result  # type: ignore

                return throttled_value()
            finally:
                throttle[0].release()

        return wrapper