291 lines
9.2 KiB
Python
291 lines
9.2 KiB
Python
"""Models used by multiple MQTT modules."""
|
|
from __future__ import annotations
|
|
|
|
from ast import literal_eval
|
|
import asyncio
|
|
from collections import deque
|
|
from collections.abc import Callable, Coroutine
|
|
from dataclasses import dataclass, field
|
|
import datetime as dt
|
|
import logging
|
|
from typing import TYPE_CHECKING, Any, TypedDict, Union
|
|
|
|
import attr
|
|
|
|
from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME
|
|
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
|
from homeassistant.helpers import template
|
|
from homeassistant.helpers.entity import Entity
|
|
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
|
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, TemplateVarsType
|
|
|
|
if TYPE_CHECKING:
|
|
from .client import MQTT, Subscription
|
|
from .debug_info import TimestampedPublishMessage
|
|
from .device_trigger import Trigger
|
|
from .discovery import MQTTDiscoveryPayload
|
|
from .tag import MQTTTagScanner
|
|
|
|
# Unique marker object used as the "no default supplied" value of the
# `default` parameter of MqttValueTemplate.async_render_with_possible_json_value.
_SENTINEL = object()
|
|
|
|
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
|
|
|
|
# Template variable name under which the entity's template state object is
# passed to command and value templates.
ATTR_THIS = "this"
|
|
|
|
# Types accepted as the payload of a message to be published.
PublishPayloadType = Union[str, bytes, int, float, None]
|
|
|
|
|
|
@attr.s(slots=True, frozen=True)
class PublishMessage:
    """Frozen, slotted record describing an MQTT message to publish."""

    # Topic of the message.
    topic: str = attr.ib()
    # Message body; one of the PublishPayloadType alternatives.
    payload: PublishPayloadType = attr.ib()
    # Quality-of-service level of the message.
    qos: int = attr.ib()
    # Retain flag of the message.
    retain: bool = attr.ib()
|
|
|
|
|
|
@attr.s(slots=True, frozen=True)
class ReceiveMessage:
    """Frozen, slotted record describing a received MQTT message."""

    # Topic of the message.
    topic: str = attr.ib()
    # Message body as delivered to subscribers.
    payload: ReceivePayloadType = attr.ib()
    # Quality-of-service level of the message.
    qos: int = attr.ib()
    # Retain flag of the message.
    retain: bool = attr.ib()
    # NOTE(review): the two fields below default to None although their
    # annotations are not optional; presumably the client always fills
    # them in before handing the message to callbacks - confirm.
    subscribed_topic: str = attr.ib(default=None)
    timestamp: dt.datetime = attr.ib(default=None)
|
|
|
|
|
|
# Callable taking a ReceiveMessage and returning a coroutine that
# resolves to None.
AsyncMessageCallbackType = Callable[[ReceiveMessage], Coroutine[Any, Any, None]]
|
|
# Callable taking a ReceiveMessage and returning None.
MessageCallbackType = Callable[[ReceiveMessage], None]
|
|
|
|
|
|
class SubscriptionDebugInfo(TypedDict):
    """Typed dictionary with the debug info kept for one subscription."""

    # Deque of ReceiveMessage objects.
    messages: deque[ReceiveMessage]
    # Integer counter; presumably the number of subscribers - TODO confirm.
    count: int
|
|
|
|
|
|
class EntityDebugInfo(TypedDict):
    """Typed dictionary with the debug info kept for one entity."""

    # Per-subscription debug info, keyed by a string
    # (presumably the subscribed topic - TODO confirm).
    subscriptions: dict[str, SubscriptionDebugInfo]
    # Discovery info associated with the entity.
    discovery_data: DiscoveryInfoType
    # Two-level string-keyed mapping ending in deques of
    # TimestampedPublishMessage objects.
    transmitted: dict[str, dict[str, deque[TimestampedPublishMessage]]]
|
|
|
|
|
|
class TriggerDebugInfo(TypedDict):
    """Typed dictionary with the debug info kept for one trigger."""

    # Identifier of the device, as a string.
    device_id: str
    # Discovery info associated with the trigger.
    discovery_data: DiscoveryInfoType
|
|
|
|
|
|
class PendingDiscovered(TypedDict):
    """Typed dictionary describing discovered items that are still pending."""

    # Deque of MQTTDiscoveryPayload objects.
    pending: deque[MQTTDiscoveryPayload]
    # Callback of type CALLBACK_TYPE; by its name an unsubscribe hook.
    unsub: CALLBACK_TYPE
|
|
|
|
|
|
class MqttCommandTemplate:
    """Render outgoing MQTT payloads through an optional command template.

    When no template is configured the value handed to `async_render` is
    returned unchanged.
    """

    def __init__(
        self,
        command_template: template.Template | None,
        *,
        hass: HomeAssistant | None = None,
        entity: Entity | None = None,
    ) -> None:
        """Instantiate a command template.

        Args:
            command_template: Template used for rendering, or None to pass
                values through untouched.
            hass: Home Assistant instance assigned to the template when no
                entity is given.
            entity: Entity whose id, name and template state are exposed to
                the template; its `hass` takes precedence over `hass`.
        """
        self._template_state: template.TemplateStateFromEntityId | None = None
        self._command_template = command_template
        # Assigned before the early return so the attribute always exists.
        self._entity = entity
        if command_template is None:
            return

        command_template.hass = entity.hass if entity else hass

    @staticmethod
    def _convert_outgoing_payload(
        payload: PublishPayloadType,
    ) -> PublishPayloadType:
        """Ensure correct raw MQTT payload is passed as bytes for publishing.

        A string that is the literal representation of a bytes object
        (for example "b'\\x01'") is converted to that bytes object; every
        other payload is returned as is.
        """
        if not isinstance(payload, str):
            return payload
        try:
            native_object = literal_eval(payload)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a Python literal (or too large/deep to evaluate): the
            # rendered text itself is the payload.
            return payload
        if isinstance(native_object, bytes):
            return native_object
        return payload

    @callback
    def async_render(
        self,
        value: PublishPayloadType = None,
        variables: TemplateVarsType = None,
    ) -> PublishPayloadType:
        """Render or convert the command template with given value or variables.

        Args:
            value: Raw value, available to the template as `value`.
            variables: Extra template variables; applied last, so they
                override `value` and the entity-derived variables.

        Returns:
            `value` unchanged when no template is configured, otherwise the
            rendered result (converted to bytes when it is a bytes literal).
        """
        if self._command_template is None:
            return value

        values: dict[str, Any] = {"value": value}
        if self._entity:
            values[ATTR_ENTITY_ID] = self._entity.entity_id
            values[ATTR_NAME] = self._entity.name
            # Create the template state lazily, and only once a hass
            # instance is available (same guard as MqttValueTemplate).
            if not self._template_state and self._command_template.hass is not None:
                self._template_state = template.TemplateStateFromEntityId(
                    self._command_template.hass, self._entity.entity_id
                )
            values[ATTR_THIS] = self._template_state

        if variables is not None:
            values.update(variables)
        _LOGGER.debug(
            "Rendering outgoing payload with variables %s and %s",
            values,
            self._command_template,
        )
        return self._convert_outgoing_payload(
            self._command_template.async_render(values, parse_result=False)
        )
|
|
|
|
|
|
class MqttValueTemplate:
    """Render received MQTT payloads through an optional value template.

    When no template is configured the received payload is passed through
    unchanged.
    """

    def __init__(
        self,
        value_template: template.Template | None,
        *,
        hass: HomeAssistant | None = None,
        entity: Entity | None = None,
        config_attributes: TemplateVarsType = None,
    ) -> None:
        """Instantiate a value template.

        Args:
            value_template: Template used for rendering, or None to pass
                payloads through untouched.
            hass: Home Assistant instance assigned to the template when no
                entity is given.
            entity: Entity whose id, name and template state are exposed to
                the template; its `hass` takes precedence over `hass`.
            config_attributes: Variables added to every render.
        """
        self._template_state: template.TemplateStateFromEntityId | None = None
        self._value_template = value_template
        self._config_attributes = config_attributes
        # Assigned before the early return so the attribute always exists.
        self._entity = entity
        if value_template is None:
            return

        value_template.hass = entity.hass if entity else hass

    @callback
    def async_render_with_possible_json_value(
        self,
        payload: ReceivePayloadType,
        default: ReceivePayloadType | object = _SENTINEL,
        variables: TemplateVarsType = None,
    ) -> ReceivePayloadType:
        """Render with possible json value or pass-through a received MQTT value.

        Template variables are merged in this order, later entries winning:
        `variables`, the configured attributes, then the entity id, name
        and template state.

        Args:
            payload: Received payload to render.
            default: Default forwarded to the template render call; omitted
                from that call when left at the sentinel.
            variables: Extra template variables for this render.

        Returns:
            `payload` unchanged when no template is configured, otherwise
            the result of the template render call.
        """
        if self._value_template is None:
            return payload

        values: dict[str, Any] = {}

        if variables is not None:
            values.update(variables)

        if self._config_attributes is not None:
            values.update(self._config_attributes)

        if self._entity:
            values[ATTR_ENTITY_ID] = self._entity.entity_id
            values[ATTR_NAME] = self._entity.name
            # Create the template state lazily, once hass is available.
            if not self._template_state and self._value_template.hass:
                self._template_state = template.TemplateStateFromEntityId(
                    self._value_template.hass, self._entity.entity_id
                )
            values[ATTR_THIS] = self._template_state

        # Identity check: `==` would call an arbitrary __eq__ on the
        # caller-supplied default.
        if default is _SENTINEL:
            _LOGGER.debug(
                "Rendering incoming payload '%s' with variables %s and %s",
                payload,
                values,
                self._value_template,
            )
            return self._value_template.async_render_with_possible_json_value(
                payload, variables=values
            )

        _LOGGER.debug(
            "Rendering incoming payload '%s' with variables %s with default value '%s' and %s",
            payload,
            values,
            default,
            self._value_template,
        )
        return self._value_template.async_render_with_possible_json_value(
            payload, default, variables=values
        )
|
|
|
|
|
|
class EntityTopicState:
    """Collect and flush entity state write requests for subscribed topics."""

    def __init__(self) -> None:
        """Start with no pending write requests."""
        # Pending requests keyed by entity id, so an entity that asks
        # several times before a flush is written only once.
        self.subscribe_calls: dict[str, Entity] = {}

    @callback
    def process_write_state_requests(self) -> None:
        """Write the state of every pending entity and empty the queue."""
        pending = self.subscribe_calls
        while pending:
            # popitem() removes the most recently inserted entry first.
            entity = pending.popitem()[1]
            entity.async_write_ha_state()

    @callback
    def write_state_request(self, entity: Entity) -> None:
        """Queue a state write for `entity`."""
        self.subscribe_calls[entity.entity_id] = entity
|
|
|
|
|
|
@dataclass
class MqttData:
    """Container for the data kept for the MQTT config entry.

    Every container-typed field uses `default_factory`, so each instance
    gets its own fresh object instead of sharing one default.
    """

    # Client and configuration; both unset on construction.
    client: MQTT | None = None
    config: ConfigType | None = None

    # Debug info for entities and triggers.
    debug_info_entities: dict[str, EntityDebugInfo] = field(default_factory=dict)
    debug_info_triggers: dict[tuple[str, str], TriggerDebugInfo] = field(
        default_factory=dict
    )

    device_triggers: dict[str, Trigger] = field(default_factory=dict)
    data_config_flow_lock: asyncio.Lock = field(default_factory=asyncio.Lock)

    # Discovery bookkeeping.
    discovery_already_discovered: set[tuple[str, str]] = field(default_factory=set)
    discovery_pending_discovered: dict[tuple[str, str], PendingDiscovered] = field(
        default_factory=dict
    )
    discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field(
        default_factory=dict
    )
    discovery_unsubscribe: list[CALLBACK_TYPE] = field(default_factory=list)
    integration_unsubscribe: dict[str, CALLBACK_TYPE] = field(default_factory=dict)
    last_discovery: float = 0.0

    # Reload bookkeeping.
    reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
    reload_entry: bool = False
    reload_handlers: dict[str, Callable[[], Coroutine[Any, Any, None]]] = field(
        default_factory=dict
    )
    reload_needed: bool = False

    # Queue of pending entity state writes.
    state_write_requests: EntityTopicState = field(default_factory=EntityTopicState)
    subscriptions_to_restore: list[Subscription] = field(default_factory=list)
    tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict)
    updated_config: ConfigType = field(default_factory=dict)
|