Improve typing [util.logging] (#70894)
parent
b4a0345b38
commit
9a3908d21d
|
@ -26,6 +26,7 @@ homeassistant.util.async_
|
|||
homeassistant.util.color
|
||||
homeassistant.util.decorator
|
||||
homeassistant.util.location
|
||||
homeassistant.util.logging
|
||||
homeassistant.util.process
|
||||
homeassistant.util.unit_system
|
||||
|
||||
|
|
|
@ -2,18 +2,20 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable, Coroutine
|
||||
from collections.abc import Callable, Coroutine
|
||||
from functools import partial, wraps
|
||||
import inspect
|
||||
import logging
|
||||
import logging.handlers
|
||||
import queue
|
||||
import traceback
|
||||
from typing import Any, cast, overload
|
||||
from typing import Any, TypeVar, cast, overload
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE
|
||||
from homeassistant.core import HomeAssistant, callback, is_callback
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class HideSensitiveDataFilter(logging.Filter):
|
||||
"""Filter API password calls."""
|
||||
|
@ -115,22 +117,24 @@ def log_exception(format_err: Callable[..., Any], *args: Any) -> None:
|
|||
|
||||
|
||||
@overload
|
||||
def catch_log_exception( # type: ignore[misc]
|
||||
func: Callable[..., Awaitable[Any]], format_err: Callable[..., Any], *args: Any
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
"""Overload for Callables that return an Awaitable."""
|
||||
def catch_log_exception(
|
||||
func: Callable[..., Coroutine[Any, Any, Any]],
|
||||
format_err: Callable[..., Any],
|
||||
*args: Any,
|
||||
) -> Callable[..., Coroutine[Any, Any, None]]:
|
||||
"""Overload for Callables that return a Coroutine."""
|
||||
|
||||
|
||||
@overload
|
||||
def catch_log_exception(
|
||||
func: Callable[..., Any], format_err: Callable[..., Any], *args: Any
|
||||
) -> Callable[..., None]:
|
||||
) -> Callable[..., None | Coroutine[Any, Any, None]]:
|
||||
"""Overload for Callables that return Any."""
|
||||
|
||||
|
||||
def catch_log_exception(
|
||||
func: Callable[..., Any], format_err: Callable[..., Any], *args: Any
|
||||
) -> Callable[..., None] | Callable[..., Awaitable[None]]:
|
||||
) -> Callable[..., None | Coroutine[Any, Any, None]]:
|
||||
"""Decorate a callback to catch and log exceptions."""
|
||||
|
||||
# Check for partials to properly determine if coroutine function
|
||||
|
@ -138,9 +142,9 @@ def catch_log_exception(
|
|||
while isinstance(check_func, partial):
|
||||
check_func = check_func.func
|
||||
|
||||
wrapper_func: Callable[..., None] | Callable[..., Awaitable[None]]
|
||||
wrapper_func: Callable[..., None | Coroutine[Any, Any, None]]
|
||||
if asyncio.iscoroutinefunction(check_func):
|
||||
async_func = cast(Callable[..., Awaitable[None]], func)
|
||||
async_func = cast(Callable[..., Coroutine[Any, Any, None]], func)
|
||||
|
||||
@wraps(async_func)
|
||||
async def async_wrapper(*args: Any) -> None:
|
||||
|
@ -170,11 +174,11 @@ def catch_log_exception(
|
|||
|
||||
|
||||
def catch_log_coro_exception(
|
||||
target: Coroutine[Any, Any, Any], format_err: Callable[..., Any], *args: Any
|
||||
) -> Coroutine[Any, Any, Any]:
|
||||
target: Coroutine[Any, Any, _T], format_err: Callable[..., Any], *args: Any
|
||||
) -> Coroutine[Any, Any, _T | None]:
|
||||
"""Decorate a coroutine to catch and log exceptions."""
|
||||
|
||||
async def coro_wrapper(*args: Any) -> Any:
|
||||
async def coro_wrapper(*args: Any) -> _T | None:
|
||||
"""Catch and log exception."""
|
||||
try:
|
||||
return await target
|
||||
|
@ -182,10 +186,12 @@ def catch_log_coro_exception(
|
|||
log_exception(format_err, *args)
|
||||
return None
|
||||
|
||||
return coro_wrapper()
|
||||
return coro_wrapper(*args)
|
||||
|
||||
|
||||
def async_create_catching_coro(target: Coroutine) -> Coroutine:
|
||||
def async_create_catching_coro(
|
||||
target: Coroutine[Any, Any, _T]
|
||||
) -> Coroutine[Any, Any, _T | None]:
|
||||
"""Wrap a coroutine to catch and log exceptions.
|
||||
|
||||
The exception will be logged together with a stacktrace of where the
|
||||
|
@ -196,7 +202,7 @@ def async_create_catching_coro(target: Coroutine) -> Coroutine:
|
|||
trace = traceback.extract_stack()
|
||||
wrapped_target = catch_log_coro_exception(
|
||||
target,
|
||||
lambda *args: "Exception in {} called from\n {}".format(
|
||||
lambda: "Exception in {} called from\n {}".format(
|
||||
target.__name__,
|
||||
"".join(traceback.format_list(trace[:-1])),
|
||||
),
|
||||
|
|
Loading…
Reference in New Issue