Allow string formatting for dispatcher SignalType (#114174)
parent
dd43947ca0
commit
eb81a4204e
|
@ -18,6 +18,7 @@ from homeassistant.core import HomeAssistant, callback
|
|||
from homeassistant.data_entry_flow import FlowResultType
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.dispatcher import (
|
||||
SignalTypeFormat,
|
||||
async_dispatcher_connect,
|
||||
async_dispatcher_send,
|
||||
)
|
||||
|
@ -79,10 +80,14 @@ SUPPORTED_COMPONENTS = {
|
|||
"water_heater",
|
||||
}
|
||||
|
||||
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
|
||||
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
|
||||
MQTT_DISCOVERY_UPDATED: SignalTypeFormat[MQTTDiscoveryPayload] = SignalTypeFormat(
|
||||
"mqtt_discovery_updated_{}"
|
||||
)
|
||||
MQTT_DISCOVERY_NEW: SignalTypeFormat[MQTTDiscoveryPayload] = SignalTypeFormat(
|
||||
"mqtt_discovery_new_{}_{}"
|
||||
)
|
||||
MQTT_DISCOVERY_NEW_COMPONENT = "mqtt_discovery_new_component"
|
||||
MQTT_DISCOVERY_DONE = "mqtt_discovery_done_{}"
|
||||
MQTT_DISCOVERY_DONE: SignalTypeFormat[Any] = SignalTypeFormat("mqtt_discovery_done_{}")
|
||||
|
||||
TOPIC_BASE = "~"
|
||||
|
||||
|
|
|
@ -15,10 +15,16 @@ from homeassistant import core, setup
|
|||
from homeassistant.const import Platform
|
||||
from homeassistant.loader import bind_hass
|
||||
|
||||
from .dispatcher import async_dispatcher_connect, async_dispatcher_send
|
||||
from .dispatcher import (
|
||||
SignalTypeFormat,
|
||||
async_dispatcher_connect,
|
||||
async_dispatcher_send,
|
||||
)
|
||||
from .typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
SIGNAL_PLATFORM_DISCOVERED = "discovery.platform_discovered_{}"
|
||||
SIGNAL_PLATFORM_DISCOVERED: SignalTypeFormat[DiscoveryDict] = SignalTypeFormat(
|
||||
"discovery.platform_discovered_{}"
|
||||
)
|
||||
EVENT_LOAD_PLATFORM = "load_platform.{}"
|
||||
ATTR_PLATFORM = "platform"
|
||||
ATTR_DISCOVERED = "discovered"
|
||||
|
|
|
@ -20,8 +20,8 @@ DATA_DISPATCHER = "dispatcher"
|
|||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SignalType(Generic[*_Ts]):
|
||||
"""Generic string class for signal to improve typing."""
|
||||
class _SignalTypeBase(Generic[*_Ts]):
|
||||
"""Generic base class for SignalType."""
|
||||
|
||||
name: str
|
||||
|
||||
|
@ -40,6 +40,20 @@ class SignalType(Generic[*_Ts]):
|
|||
return False
|
||||
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class SignalType(_SignalTypeBase[*_Ts]):
|
||||
"""Generic string class for signal to improve typing."""
|
||||
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class SignalTypeFormat(_SignalTypeBase[*_Ts]):
|
||||
"""Generic string class for signal. Requires call to 'format' before use."""
|
||||
|
||||
def format(self, *args: Any, **kwargs: Any) -> SignalType[*_Ts]:
|
||||
"""Format name and return new SignalType instance."""
|
||||
return SignalType(self.name.format(*args, **kwargs))
|
||||
|
||||
|
||||
_DispatcherDataType = dict[
|
||||
SignalType[*_Ts] | str,
|
||||
dict[
|
||||
|
|
|
@ -76,6 +76,7 @@ from homeassistant.helpers import (
|
|||
translation,
|
||||
)
|
||||
from homeassistant.helpers.dispatcher import (
|
||||
SignalType,
|
||||
async_dispatcher_connect,
|
||||
async_dispatcher_send,
|
||||
)
|
||||
|
@ -1497,7 +1498,9 @@ def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]:
|
|||
|
||||
|
||||
@callback
|
||||
def async_mock_signal(hass: HomeAssistant, signal: str) -> list[tuple[Any]]:
|
||||
def async_mock_signal(
|
||||
hass: HomeAssistant, signal: SignalType[Any] | str
|
||||
) -> list[tuple[Any]]:
|
||||
"""Catch all dispatches to a signal."""
|
||||
calls = []
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import pytest
|
|||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.dispatcher import (
|
||||
SignalType,
|
||||
SignalTypeFormat,
|
||||
async_dispatcher_connect,
|
||||
async_dispatcher_send,
|
||||
)
|
||||
|
@ -58,6 +59,27 @@ async def test_signal_type(hass: HomeAssistant) -> None:
|
|||
assert calls == [("Hello", 2), ("World", 3), ("x", 4)]
|
||||
|
||||
|
||||
async def test_signal_type_format(hass: HomeAssistant) -> None:
|
||||
"""Test dispatcher with SignalType and format."""
|
||||
signal: SignalTypeFormat[str, int] = SignalTypeFormat("test-{}")
|
||||
calls: list[tuple[str, int]] = []
|
||||
|
||||
def test_funct(data1: str, data2: int) -> None:
|
||||
calls.append((data1, data2))
|
||||
|
||||
async_dispatcher_connect(hass, signal.format("unique-id"), test_funct)
|
||||
async_dispatcher_send(hass, signal.format("unique-id"), "Hello", 2)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert calls == [("Hello", 2)]
|
||||
|
||||
# Test compatibility with string keys
|
||||
async_dispatcher_send(hass, "test-{}".format("unique-id"), "x", 4)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert calls == [("Hello", 2), ("x", 4)]
|
||||
|
||||
|
||||
async def test_simple_function_unsub(hass: HomeAssistant) -> None:
|
||||
"""Test simple function (executor) and unsub."""
|
||||
calls1 = []
|
||||
|
|
Loading…
Reference in New Issue