Only work out job type once when setting up dispatcher (#116030)

pull/116059/head
J. Nick Koston 2024-04-23 22:24:36 +02:00 committed by GitHub
parent f1fa33483e
commit 8f1761343e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 13 deletions

View File

@ -7,7 +7,12 @@ from functools import partial
import logging
from typing import Any, TypeVarTuple, overload
from homeassistant.core import HassJob, HomeAssistant, callback
from homeassistant.core import (
HassJob,
HomeAssistant,
callback,
get_hassjob_callable_job_type,
)
from homeassistant.loader import bind_hass
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.logging import catch_log_exception
@ -161,9 +166,13 @@ def _generate_job(
signal: SignalType[*_Ts] | str, target: Callable[[*_Ts], Any] | Callable[..., Any]
) -> HassJob[..., None | Coroutine[Any, Any, None]]:
"""Generate a HassJob for a signal and target."""
job_type = get_hassjob_callable_job_type(target)
return HassJob(
catch_log_exception(target, partial(_format_err, signal, target)),
catch_log_exception(
target, partial(_format_err, signal, target), job_type=job_type
),
f"dispatcher {signal}",
job_type=job_type,
)

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine
from functools import partial, wraps
import inspect
@ -12,7 +11,12 @@ import queue
import traceback
from typing import Any, TypeVar, TypeVarTuple, cast, overload
from homeassistant.core import HomeAssistant, callback, is_callback
from homeassistant.core import (
HassJobType,
HomeAssistant,
callback,
get_hassjob_callable_job_type,
)
_T = TypeVar("_T")
_Ts = TypeVarTuple("_Ts")
@ -129,34 +133,38 @@ def _callback_wrapper(
@overload
def catch_log_exception(
func: Callable[[*_Ts], Coroutine[Any, Any, Any]], format_err: Callable[[*_Ts], Any]
func: Callable[[*_Ts], Coroutine[Any, Any, Any]],
format_err: Callable[[*_Ts], Any],
job_type: HassJobType | None = None,
) -> Callable[[*_Ts], Coroutine[Any, Any, None]]: ...
@overload
def catch_log_exception(
func: Callable[[*_Ts], Any], format_err: Callable[[*_Ts], Any]
func: Callable[[*_Ts], Any],
format_err: Callable[[*_Ts], Any],
job_type: HassJobType | None = None,
) -> Callable[[*_Ts], None] | Callable[[*_Ts], Coroutine[Any, Any, None]]: ...
def catch_log_exception(
func: Callable[[*_Ts], Any], format_err: Callable[[*_Ts], Any]
func: Callable[[*_Ts], Any],
format_err: Callable[[*_Ts], Any],
job_type: HassJobType | None = None,
) -> Callable[[*_Ts], None] | Callable[[*_Ts], Coroutine[Any, Any, None]]:
"""Decorate a function func to catch and log exceptions.
If func is a coroutine function, a coroutine function will be returned.
If func is a callback, a callback will be returned.
"""
# Check for partials to properly determine if coroutine function
check_func = func
while isinstance(check_func, partial):
check_func = check_func.func # type: ignore[unreachable] # false positive
if job_type is None:
job_type = get_hassjob_callable_job_type(func)
if asyncio.iscoroutinefunction(check_func):
if job_type is HassJobType.Coroutinefunction:
async_func = cast(Callable[[*_Ts], Coroutine[Any, Any, None]], func)
return wraps(async_func)(partial(_async_wrapper, async_func, format_err)) # type: ignore[return-value]
if is_callback(check_func):
if job_type is HassJobType.Callback:
return wraps(func)(partial(_callback_wrapper, func, format_err)) # type: ignore[return-value]
return wraps(func)(partial(_sync_wrapper, func, format_err)) # type: ignore[return-value]