"""Provides device automations for MQTT.""" from __future__ import annotations from collections.abc import Callable import logging from typing import cast import attr import voluptuous as vol from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( CONF_DEVICE, CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM, CONF_TYPE, CONF_VALUE_TEMPLATE, ) from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.typing import ConfigType from . import debug_info, trigger as mqtt_trigger from .config import MQTT_BASE_SCHEMA from .const import ( ATTR_DISCOVERY_HASH, CONF_ENCODING, CONF_PAYLOAD, CONF_QOS, CONF_TOPIC, DOMAIN, ) from .discovery import MQTT_DISCOVERY_DONE from .mixins import ( MQTT_ENTITY_DEVICE_INFO_SCHEMA, MqttDiscoveryDeviceUpdate, send_discovery_done, update_device, ) from .util import get_mqtt_data _LOGGER = logging.getLogger(__name__) CONF_AUTOMATION_TYPE = "automation_type" CONF_DISCOVERY_ID = "discovery_id" CONF_SUBTYPE = "subtype" DEFAULT_ENCODING = "utf-8" DEVICE = "device" MQTT_TRIGGER_BASE = { # Trigger when MQTT message is received CONF_PLATFORM: DEVICE, CONF_DOMAIN: DOMAIN, } TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend( { vol.Required(CONF_PLATFORM): DEVICE, vol.Required(CONF_DOMAIN): DOMAIN, vol.Required(CONF_DEVICE_ID): str, vol.Required(CONF_DISCOVERY_ID): str, vol.Required(CONF_TYPE): cv.string, vol.Required(CONF_SUBTYPE): cv.string, } ) TRIGGER_DISCOVERY_SCHEMA = MQTT_BASE_SCHEMA.extend( { vol.Required(CONF_AUTOMATION_TYPE): str, vol.Required(CONF_DEVICE): MQTT_ENTITY_DEVICE_INFO_SCHEMA, vol.Optional(CONF_PAYLOAD, default=None): vol.Any(None, cv.string), vol.Required(CONF_SUBTYPE): cv.string, vol.Required(CONF_TOPIC): cv.string, vol.Required(CONF_TYPE): cv.string, vol.Optional(CONF_VALUE_TEMPLATE, default=None): vol.Any(None, cv.string), }, extra=vol.REMOVE_EXTRA, ) LOG_NAME = "Device trigger" @attr.s(slots=True) class TriggerInstance: """Attached trigger settings.""" action: TriggerActionType = attr.ib() trigger_info: TriggerInfo = attr.ib() trigger: Trigger = attr.ib() remove: CALLBACK_TYPE | None = attr.ib(default=None) async def async_attach_trigger(self) -> None: """Attach MQTT trigger.""" mqtt_config = { CONF_PLATFORM: DOMAIN, CONF_TOPIC: self.trigger.topic, CONF_ENCODING: DEFAULT_ENCODING, CONF_QOS: self.trigger.qos, } if self.trigger.payload: mqtt_config[CONF_PAYLOAD] = self.trigger.payload if self.trigger.value_template: mqtt_config[CONF_VALUE_TEMPLATE] = self.trigger.value_template mqtt_config = mqtt_trigger.TRIGGER_SCHEMA(mqtt_config) if self.remove: self.remove() self.remove = await mqtt_trigger.async_attach_trigger( self.trigger.hass, mqtt_config, self.action, self.trigger_info, ) @attr.s(slots=True) class Trigger: """Device trigger settings.""" device_id: str = attr.ib() discovery_data: dict | None = attr.ib() hass: HomeAssistant = attr.ib() payload: str | None = attr.ib() qos: int | None = attr.ib() subtype: str = attr.ib() topic: str | None = attr.ib() type: str = attr.ib() value_template: str | None = attr.ib() trigger_instances: list[TriggerInstance] = attr.ib(factory=list) async def add_trigger( self, action: TriggerActionType, trigger_info: TriggerInfo ) -> Callable: """Add MQTT trigger.""" instance = TriggerInstance(action, trigger_info, self) self.trigger_instances.append(instance) if self.topic is not None: # If we know about the trigger, subscribe to MQTT topic await instance.async_attach_trigger() @callback def async_remove() -> None: """Remove trigger.""" if instance not in self.trigger_instances: raise HomeAssistantError("Can't remove trigger twice") if instance.remove: instance.remove() self.trigger_instances.remove(instance) return async_remove async def update_trigger(self, config: ConfigType) -> None: """Update MQTT device trigger.""" self.type = config[CONF_TYPE] self.subtype = config[CONF_SUBTYPE] self.payload = config[CONF_PAYLOAD] self.qos = config[CONF_QOS] topic_changed = self.topic != config[CONF_TOPIC] self.topic = config[CONF_TOPIC] self.value_template = config[CONF_VALUE_TEMPLATE] # Unsubscribe+subscribe if this trigger is in use and topic has changed # If topic is same unsubscribe+subscribe will execute in the wrong order # because unsubscribe is done with help of async_create_task if topic_changed: for trig in self.trigger_instances: await trig.async_attach_trigger() def detach_trigger(self) -> None: """Remove MQTT device trigger.""" # Mark trigger as unknown self.topic = None # Unsubscribe if this trigger is in use for trig in self.trigger_instances: if trig.remove: trig.remove() trig.remove = None class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate): """Setup a MQTT device trigger with auto discovery.""" def __init__( self, hass: HomeAssistant, config: ConfigType, device_id: str, discovery_data: dict, config_entry: ConfigEntry, ) -> None: """Initialize.""" self._config = config self._config_entry = config_entry self.device_id = device_id self.discovery_data = discovery_data self.hass = hass self._mqtt_data = get_mqtt_data(hass) MqttDiscoveryDeviceUpdate.__init__( self, hass, discovery_data, device_id, config_entry, LOG_NAME, ) async def async_setup(self) -> None: """Initialize the device trigger.""" discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] discovery_id = discovery_hash[1] if discovery_id not in self._mqtt_data.device_triggers: self._mqtt_data.device_triggers[discovery_id] = Trigger( hass=self.hass, device_id=self.device_id, discovery_data=self.discovery_data, type=self._config[CONF_TYPE], subtype=self._config[CONF_SUBTYPE], topic=self._config[CONF_TOPIC], payload=self._config[CONF_PAYLOAD], qos=self._config[CONF_QOS], value_template=self._config[CONF_VALUE_TEMPLATE], ) else: await self._mqtt_data.device_triggers[discovery_id].update_trigger( self._config ) debug_info.add_trigger_discovery_data( self.hass, discovery_hash, self.discovery_data, self.device_id ) async def async_update(self, discovery_data: dict) -> None: """Handle MQTT device trigger discovery updates.""" discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] discovery_id = discovery_hash[1] debug_info.update_trigger_discovery_data( self.hass, discovery_hash, discovery_data ) config = TRIGGER_DISCOVERY_SCHEMA(discovery_data) update_device(self.hass, self._config_entry, config) device_trigger: Trigger = self._mqtt_data.device_triggers[discovery_id] await device_trigger.update_trigger(config) async def async_tear_down(self) -> None: """Cleanup device trigger.""" discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] discovery_id = discovery_hash[1] if discovery_id in self._mqtt_data.device_triggers: _LOGGER.info("Removing trigger: %s", discovery_hash) trigger: Trigger = self._mqtt_data.device_triggers[discovery_id] trigger.detach_trigger() debug_info.remove_trigger_discovery_data(self.hass, discovery_hash) async def async_setup_trigger( hass, config: ConfigType, config_entry: ConfigEntry, discovery_data: dict ) -> None: """Set up the MQTT device trigger.""" config = TRIGGER_DISCOVERY_SCHEMA(config) discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] if (device_id := update_device(hass, config_entry, config)) is None: async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) return mqtt_device_trigger = MqttDeviceTrigger( hass, config, device_id, discovery_data, config_entry ) await mqtt_device_trigger.async_setup() send_discovery_done(hass, discovery_data) async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None: """Handle Mqtt removed from a device.""" mqtt_data = get_mqtt_data(hass) triggers = await async_get_triggers(hass, device_id) for trig in triggers: device_trigger: Trigger = mqtt_data.device_triggers.pop(trig[CONF_DISCOVERY_ID]) if device_trigger: device_trigger.detach_trigger() discovery_data = cast(dict, device_trigger.discovery_data) discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] debug_info.remove_trigger_discovery_data(hass, discovery_hash) async def async_get_triggers( hass: HomeAssistant, device_id: str ) -> list[dict[str, str]]: """List device triggers for MQTT devices.""" mqtt_data = get_mqtt_data(hass) triggers: list[dict[str, str]] = [] if not mqtt_data.device_triggers: return triggers for discovery_id, trig in mqtt_data.device_triggers.items(): if trig.device_id != device_id or trig.topic is None: continue trigger = { **MQTT_TRIGGER_BASE, "device_id": device_id, "type": trig.type, "subtype": trig.subtype, "discovery_id": discovery_id, } triggers.append(trigger) return triggers async def async_attach_trigger( hass: HomeAssistant, config: ConfigType, action: TriggerActionType, trigger_info: TriggerInfo, ) -> CALLBACK_TYPE: """Attach a trigger.""" mqtt_data = get_mqtt_data(hass) device_id = config[CONF_DEVICE_ID] discovery_id = config[CONF_DISCOVERY_ID] if discovery_id not in mqtt_data.device_triggers: mqtt_data.device_triggers[discovery_id] = Trigger( hass=hass, device_id=device_id, discovery_data=None, type=config[CONF_TYPE], subtype=config[CONF_SUBTYPE], topic=None, payload=None, qos=None, value_template=None, ) return await mqtt_data.device_triggers[discovery_id].add_trigger( action, trigger_info )