"""Models used by multiple MQTT modules.""" from __future__ import annotations from ast import literal_eval import asyncio from collections import deque from collections.abc import Callable from dataclasses import dataclass, field from enum import StrEnum import logging from typing import TYPE_CHECKING, Any, TypedDict from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME, Platform from homeassistant.core import CALLBACK_TYPE, callback from homeassistant.exceptions import ServiceValidationError, TemplateError from homeassistant.helpers import template from homeassistant.helpers.entity import Entity from homeassistant.helpers.service_info.mqtt import ReceivePayloadType from homeassistant.helpers.typing import ( ConfigType, DiscoveryInfoType, TemplateVarsType, VolSchemaType, ) from homeassistant.util.hass_dict import HassKey if TYPE_CHECKING: from paho.mqtt.client import MQTTMessage from .client import MQTT, Subscription from .debug_info import TimestampedPublishMessage from .device_trigger import Trigger from .discovery import MQTTDiscoveryPayload from .tag import MQTTTagScanner from .const import DOMAIN, TEMPLATE_ERRORS class PayloadSentinel(StrEnum): """Sentinel for `async_render_with_possible_json_value`.""" NONE = "none" DEFAULT = "default" _LOGGER = logging.getLogger(__name__) ATTR_THIS = "this" type PublishPayloadType = str | bytes | int | float | None def convert_outgoing_mqtt_payload( payload: PublishPayloadType, ) -> PublishPayloadType: """Ensure correct raw MQTT payload is passed as bytes for publishing.""" if isinstance(payload, str) and payload.startswith(("b'", 'b"')): try: native_object = literal_eval(payload) except (ValueError, TypeError, SyntaxError, MemoryError): pass else: if isinstance(native_object, bytes): return native_object return payload @dataclass class PublishMessage: """MQTT Message for publishing.""" topic: str payload: PublishPayloadType qos: int retain: bool # eq=False so we use the id() of the object for comparison # since client will only generate one instance of this object # per messages/subscribed_topic. @dataclass(slots=True, frozen=True, eq=False) class ReceiveMessage: """MQTT Message received.""" topic: str payload: ReceivePayloadType qos: int retain: bool subscribed_topic: str timestamp: float type MessageCallbackType = Callable[[ReceiveMessage], None] class SubscriptionDebugInfo(TypedDict): """Class for holding subscription debug info.""" messages: deque[ReceiveMessage] count: int class EntityDebugInfo(TypedDict): """Class for holding entity based debug info.""" subscriptions: dict[str, SubscriptionDebugInfo] discovery_data: DiscoveryInfoType transmitted: dict[str, dict[str, deque[TimestampedPublishMessage]]] class TriggerDebugInfo(TypedDict): """Class for holding trigger based debug info.""" device_id: str discovery_data: DiscoveryInfoType class PendingDiscovered(TypedDict): """Pending discovered items.""" pending: deque[MQTTDiscoveryPayload] unsub: CALLBACK_TYPE class MqttOriginInfo(TypedDict, total=False): """Integration info of discovered entity.""" name: str manufacturer: str sw_version: str hw_version: str support_url: str class MqttCommandTemplateException(ServiceValidationError): """Handle MqttCommandTemplate exceptions.""" _message: str def __init__( self, *args: object, base_exception: Exception, command_template: str, value: PublishPayloadType, entity_id: str | None = None, ) -> None: """Initialize exception.""" super().__init__(base_exception, *args) value_log = str(value) self.translation_domain = DOMAIN self.translation_key = "command_template_error" self.translation_placeholders = { "error": str(base_exception), "entity_id": str(entity_id), "command_template": command_template, } entity_id_log = "" if entity_id is None else f" for entity '{entity_id}'" self._message = ( f"{type(base_exception).__name__}: {base_exception} rendering template{entity_id_log}" f", template: '{command_template}' and payload: {value_log}" ) def __str__(self) -> str: """Return exception message string.""" return self._message class MqttCommandTemplate: """Class for rendering MQTT payload with command templates.""" def __init__( self, command_template: template.Template | None, *, entity: Entity | None = None, ) -> None: """Instantiate a command template.""" self._template_state: template.TemplateStateFromEntityId | None = None self._command_template = command_template self._entity = entity @callback def async_render( self, value: PublishPayloadType = None, variables: TemplateVarsType = None, ) -> PublishPayloadType: """Render or convert the command template with given value or variables.""" if self._command_template is None: return value values: dict[str, Any] = {"value": value} if self._entity: values[ATTR_ENTITY_ID] = self._entity.entity_id values[ATTR_NAME] = self._entity.name if not self._template_state and self._command_template.hass is not None: self._template_state = template.TemplateStateFromEntityId( self._entity.hass, self._entity.entity_id ) values[ATTR_THIS] = self._template_state if variables is not None: values.update(variables) _LOGGER.debug( "Rendering outgoing payload with variables %s and %s", values, self._command_template, ) try: return convert_outgoing_mqtt_payload( self._command_template.async_render(values, parse_result=False) ) except TemplateError as exc: raise MqttCommandTemplateException( base_exception=exc, command_template=self._command_template.template, value=value, entity_id=self._entity.entity_id if self._entity is not None else None, ) from exc class MqttValueTemplateException(TemplateError): """Handle MqttValueTemplate exceptions.""" _message: str def __init__( self, *args: object, base_exception: Exception, value_template: str, default: ReceivePayloadType | PayloadSentinel, payload: ReceivePayloadType, entity_id: str | None = None, ) -> None: """Initialize exception.""" super().__init__(base_exception, *args) entity_id_log = "" if entity_id is None else f" for entity '{entity_id}'" default_log = str(default) default_payload_log = ( "" if default is PayloadSentinel.NONE else f", default value: {default_log}" ) payload_log = str(payload) self._message = ( f"{type(base_exception).__name__}: {base_exception} rendering template{entity_id_log}" f", template: '{value_template}'{default_payload_log} and payload: {payload_log}" ) def __str__(self) -> str: """Return exception message string.""" return self._message class MqttValueTemplate: """Class for rendering MQTT value template with possible json values.""" def __init__( self, value_template: template.Template | None, *, entity: Entity | None = None, config_attributes: TemplateVarsType = None, ) -> None: """Instantiate a value template.""" self._template_state: template.TemplateStateFromEntityId | None = None self._value_template = value_template self._config_attributes = config_attributes self._entity = entity @callback def async_render_with_possible_json_value( self, payload: ReceivePayloadType, default: ReceivePayloadType | PayloadSentinel = PayloadSentinel.NONE, variables: TemplateVarsType = None, ) -> ReceivePayloadType: """Render with possible json value or pass-though a received MQTT value.""" rendered_payload: ReceivePayloadType if self._value_template is None: return payload values: dict[str, Any] = {} if variables is not None: values.update(variables) if self._config_attributes is not None: values.update(self._config_attributes) if self._entity: values[ATTR_ENTITY_ID] = self._entity.entity_id values[ATTR_NAME] = self._entity.name if not self._template_state and self._value_template.hass: self._template_state = template.TemplateStateFromEntityId( self._value_template.hass, self._entity.entity_id ) values[ATTR_THIS] = self._template_state if default is PayloadSentinel.NONE: _LOGGER.debug( "Rendering incoming payload '%s' with variables %s and %s", payload, values, self._value_template, ) try: rendered_payload = ( self._value_template.async_render_with_possible_json_value( payload, variables=values ) ) except TEMPLATE_ERRORS as exc: raise MqttValueTemplateException( base_exception=exc, value_template=self._value_template.template, default=default, payload=payload, entity_id=self._entity.entity_id if self._entity else None, ) from exc return rendered_payload _LOGGER.debug( ( "Rendering incoming payload '%s' with variables %s with default value" " '%s' and %s" ), payload, values, default, self._value_template, ) try: rendered_payload = ( self._value_template.async_render_with_possible_json_value( payload, default, variables=values ) ) except TEMPLATE_ERRORS as exc: raise MqttValueTemplateException( base_exception=exc, value_template=self._value_template.template, default=default, payload=payload, entity_id=self._entity.entity_id if self._entity else None, ) from exc return rendered_payload class EntityTopicState: """Manage entity state write requests for subscribed topics.""" def __init__(self) -> None: """Register topic.""" self.subscribe_calls: dict[str, Entity] = {} @callback def process_write_state_requests(self, msg: MQTTMessage) -> None: """Process the write state requests.""" while self.subscribe_calls: entity_id, entity = self.subscribe_calls.popitem() try: entity.async_write_ha_state() except Exception: _LOGGER.exception( "Exception raised while updating state of %s, topic: " "'%s' with payload: %s", entity_id, msg.topic, msg.payload, ) @callback def write_state_request(self, entity: Entity) -> None: """Register write state request.""" self.subscribe_calls[entity.entity_id] = entity @dataclass class MqttData: """Keep the MQTT entry data.""" client: MQTT config: list[ConfigType] debug_info_entities: dict[str, EntityDebugInfo] = field(default_factory=dict) debug_info_triggers: dict[tuple[str, str], TriggerDebugInfo] = field( default_factory=dict ) device_triggers: dict[str, Trigger] = field(default_factory=dict) data_config_flow_lock: asyncio.Lock = field(default_factory=asyncio.Lock) discovery_already_discovered: set[tuple[str, str]] = field(default_factory=set) discovery_pending_discovered: dict[tuple[str, str], PendingDiscovered] = field( default_factory=dict ) discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field( default_factory=dict ) discovery_unsubscribe: list[CALLBACK_TYPE] = field(default_factory=list) integration_unsubscribe: dict[str, CALLBACK_TYPE] = field(default_factory=dict) last_discovery: float = 0.0 platforms_loaded: set[Platform | str] = field(default_factory=set) reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list) reload_handlers: dict[str, CALLBACK_TYPE] = field(default_factory=dict) reload_schema: dict[str, VolSchemaType] = field(default_factory=dict) state_write_requests: EntityTopicState = field(default_factory=EntityTopicState) subscriptions_to_restore: set[Subscription] = field(default_factory=set) tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict) DATA_MQTT: HassKey[MqttData] = HassKey("mqtt") DATA_MQTT_AVAILABLE: HassKey[asyncio.Future[bool]] = HassKey("mqtt_client_available")