Add generic parameters to HassJob (#70973)

pull/72719/head
Marc Mueller 2022-05-30 09:22:37 +02:00 committed by GitHub
parent 6bc09741c7
commit b417ae72e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 22 deletions

View File

@ -37,6 +37,7 @@ from typing import (
)
from urllib.parse import urlparse
from typing_extensions import ParamSpec
import voluptuous as vol
import yarl
@ -98,6 +99,7 @@ block_async_io.enable()
_T = TypeVar("_T")
_R = TypeVar("_R")
_R_co = TypeVar("_R_co", covariant=True)
_P = ParamSpec("_P")
# Internal; not helpers.typing.UNDEFINED due to circular dependency
_UNDEF: dict[Any, Any] = {}
_CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
@ -182,7 +184,7 @@ class HassJobType(enum.Enum):
Executor = 3
class HassJob(Generic[_R_co]):
class HassJob(Generic[_P, _R_co]):
"""Represent a job to be run later.
We check the callable type in advance
@ -192,7 +194,7 @@ class HassJob(Generic[_R_co]):
__slots__ = ("job_type", "target")
def __init__(self, target: Callable[..., _R_co]) -> None:
def __init__(self, target: Callable[_P, _R_co]) -> None:
"""Create a job object."""
self.target = target
self.job_type = _get_hassjob_callable_job_type(target)
@ -416,20 +418,20 @@ class HomeAssistant:
@overload
@callback
def async_add_hass_job(
self, hassjob: HassJob[Coroutine[Any, Any, _R]], *args: Any
self, hassjob: HassJob[..., Coroutine[Any, Any, _R]], *args: Any
) -> asyncio.Future[_R] | None:
...
@overload
@callback
def async_add_hass_job(
self, hassjob: HassJob[Coroutine[Any, Any, _R] | _R], *args: Any
self, hassjob: HassJob[..., Coroutine[Any, Any, _R] | _R], *args: Any
) -> asyncio.Future[_R] | None:
...
@callback
def async_add_hass_job(
self, hassjob: HassJob[Coroutine[Any, Any, _R] | _R], *args: Any
self, hassjob: HassJob[..., Coroutine[Any, Any, _R] | _R], *args: Any
) -> asyncio.Future[_R] | None:
"""Add a HassJob from within the event loop.
@ -512,20 +514,20 @@ class HomeAssistant:
@overload
@callback
def async_run_hass_job(
self, hassjob: HassJob[Coroutine[Any, Any, _R]], *args: Any
self, hassjob: HassJob[..., Coroutine[Any, Any, _R]], *args: Any
) -> asyncio.Future[_R] | None:
...
@overload
@callback
def async_run_hass_job(
self, hassjob: HassJob[Coroutine[Any, Any, _R] | _R], *args: Any
self, hassjob: HassJob[..., Coroutine[Any, Any, _R] | _R], *args: Any
) -> asyncio.Future[_R] | None:
...
@callback
def async_run_hass_job(
self, hassjob: HassJob[Coroutine[Any, Any, _R] | _R], *args: Any
self, hassjob: HassJob[..., Coroutine[Any, Any, _R] | _R], *args: Any
) -> asyncio.Future[_R] | None:
"""Run a HassJob from within the event loop.
@ -814,7 +816,7 @@ class Event:
class _FilterableJob(NamedTuple):
"""Event listener job to be executed with optional filter."""
job: HassJob[None | Awaitable[None]]
job: HassJob[[Event], None | Awaitable[None]]
event_filter: Callable[[Event], bool] | None
run_immediately: bool

View File

@ -258,7 +258,9 @@ def _async_track_state_change_event(
action: Callable[[Event], Any],
) -> CALLBACK_TYPE:
"""async_track_state_change_event without lowercasing."""
entity_callbacks = hass.data.setdefault(TRACK_STATE_CHANGE_CALLBACKS, {})
entity_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault(
TRACK_STATE_CHANGE_CALLBACKS, {}
)
if TRACK_STATE_CHANGE_LISTENER not in hass.data:
@ -319,10 +321,10 @@ def _async_remove_indexed_listeners(
data_key: str,
listener_key: str,
storage_keys: Iterable[str],
job: HassJob[Any],
job: HassJob[[Event], Any],
) -> None:
"""Remove a listener."""
callbacks = hass.data[data_key]
callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data[data_key]
for storage_key in storage_keys:
callbacks[storage_key].remove(job)
@ -347,7 +349,9 @@ def async_track_entity_registry_updated_event(
if not (entity_ids := _async_string_to_lower_list(entity_ids)):
return _remove_empty_listener
entity_callbacks = hass.data.setdefault(TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, {})
entity_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault(
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, {}
)
if TRACK_ENTITY_REGISTRY_UPDATED_LISTENER not in hass.data:
@ -401,7 +405,7 @@ def async_track_entity_registry_updated_event(
@callback
def _async_dispatch_domain_event(
hass: HomeAssistant, event: Event, callbacks: dict[str, list[HassJob[Any]]]
hass: HomeAssistant, event: Event, callbacks: dict[str, list[HassJob[[Event], Any]]]
) -> None:
domain = split_entity_id(event.data["entity_id"])[0]
@ -438,7 +442,9 @@ def _async_track_state_added_domain(
action: Callable[[Event], Any],
) -> CALLBACK_TYPE:
"""async_track_state_added_domain without lowercasing."""
domain_callbacks = hass.data.setdefault(TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {})
domain_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault(
TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {}
)
if TRACK_STATE_ADDED_DOMAIN_LISTENER not in hass.data:
@ -490,7 +496,9 @@ def async_track_state_removed_domain(
if not (domains := _async_string_to_lower_list(domains)):
return _remove_empty_listener
domain_callbacks = hass.data.setdefault(TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, {})
domain_callbacks: dict[str, list[HassJob[[Event], Any]]] = hass.data.setdefault(
TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, {}
)
if TRACK_STATE_REMOVED_DOMAIN_LISTENER not in hass.data:
@ -1249,7 +1257,7 @@ track_same_state = threaded_listener_factory(async_track_same_state)
@bind_hass
def async_track_point_in_time(
hass: HomeAssistant,
action: HassJob[Awaitable[None] | None]
action: HassJob[[datetime], Awaitable[None] | None]
| Callable[[datetime], Awaitable[None] | None],
point_in_time: datetime,
) -> CALLBACK_TYPE:
@ -1271,7 +1279,7 @@ track_point_in_time = threaded_listener_factory(async_track_point_in_time)
@bind_hass
def async_track_point_in_utc_time(
hass: HomeAssistant,
action: HassJob[Awaitable[None] | None]
action: HassJob[[datetime], Awaitable[None] | None]
| Callable[[datetime], Awaitable[None] | None],
point_in_time: datetime,
) -> CALLBACK_TYPE:
@ -1284,7 +1292,7 @@ def async_track_point_in_utc_time(
cancel_callback: asyncio.TimerHandle | None = None
@callback
def run_action(job: HassJob[Awaitable[None] | None]) -> None:
def run_action(job: HassJob[[datetime], Awaitable[None] | None]) -> None:
"""Call the action."""
nonlocal cancel_callback
@ -1324,7 +1332,7 @@ track_point_in_utc_time = threaded_listener_factory(async_track_point_in_utc_tim
def async_call_later(
hass: HomeAssistant,
delay: float | timedelta,
action: HassJob[Awaitable[None] | None]
action: HassJob[[datetime], Awaitable[None] | None]
| Callable[[datetime], Awaitable[None] | None],
) -> CALLBACK_TYPE:
"""Add a listener that is called in <delay>."""
@ -1345,7 +1353,7 @@ def async_track_time_interval(
) -> CALLBACK_TYPE:
"""Add a listener that fires repetitively at every timedelta interval."""
remove: CALLBACK_TYPE
interval_listener_job: HassJob[None]
interval_listener_job: HassJob[[datetime], None]
job = HassJob(action)
@ -1382,7 +1390,7 @@ class SunListener:
"""Helper class to help listen to sun events."""
hass: HomeAssistant = attr.ib()
job: HassJob[Awaitable[None] | None] = attr.ib()
job: HassJob[[], Awaitable[None] | None] = attr.ib()
event: str = attr.ib()
offset: timedelta | None = attr.ib()
_unsub_sun: CALLBACK_TYPE | None = attr.ib(default=None)