"""This platform allows several switches to be grouped into one switch."""
from __future__ import annotations

import logging
from typing import Any

import voluptuous as vol

from homeassistant.components.switch import DOMAIN, PLATFORM_SCHEMA, SwitchEntity
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
    ATTR_ENTITY_ID,
    CONF_ENTITIES,
    CONF_NAME,
    CONF_UNIQUE_ID,
    SERVICE_TURN_OFF,
    SERVICE_TURN_ON,
    STATE_ON,
    STATE_UNAVAILABLE,
    STATE_UNKNOWN,
)
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_state_change_event
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType

from . import GroupEntity

DEFAULT_NAME = "Switch Group"
# Config key selecting the aggregation mode: True -> "all", False -> "any".
CONF_ALL = "all"

# No limit on parallel updates to enable a group calling another group
PARALLEL_UPDATES = 0

PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
    {
        vol.Required(CONF_ENTITIES): cv.entities_domain(DOMAIN),
        vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
        vol.Optional(CONF_UNIQUE_ID): cv.string,
        vol.Optional(CONF_ALL, default=False): cv.boolean,
    }
)

_LOGGER = logging.getLogger(__name__)


async def async_setup_platform(
    hass: HomeAssistant,
    config: ConfigType,
    async_add_entities: AddEntitiesCallback,
    discovery_info: DiscoveryInfoType | None = None,
) -> None:
    """Set up the Switch Group platform.

    Creates a single ``SwitchGroup`` from a platform config validated by
    ``PLATFORM_SCHEMA`` (entities, name, optional unique id, "all" flag).
    """
    async_add_entities(
        [
            SwitchGroup(
                config.get(CONF_UNIQUE_ID),
                config[CONF_NAME],
                config[CONF_ENTITIES],
                config.get(CONF_ALL, False),
            )
        ]
    )


async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Initialize Switch Group config entry.

    Member entity ids are taken from the entry's options and passed through
    the entity registry's validation before the group is created. The
    entry id is used as the group's unique id and the entry title as its name.
    """
    registry = er.async_get(hass)
    entities = er.async_validate_entity_ids(
        registry, config_entry.options[CONF_ENTITIES]
    )

    async_add_entities(
        [
            SwitchGroup(
                config_entry.entry_id,
                config_entry.title,
                entities,
                # May be None when the option was never set; treated as "any".
                config_entry.options.get(CONF_ALL),
            )
        ]
    )


class SwitchGroup(GroupEntity, SwitchEntity):
    """Representation of a switch group.

    The group's on/off state is derived from its members' states, and
    turn_on/turn_off commands are forwarded to every member switch.
    """

    _attr_available = False
    _attr_should_poll = False

    def __init__(
        self,
        unique_id: str | None,
        name: str,
        entity_ids: list[str],
        mode: bool | None,
    ) -> None:
        """Initialize a switch group.

        Args:
            unique_id: Unique id for the group entity, or None.
            name: Display name of the group.
            entity_ids: Entity ids of the member switches.
            mode: Truthy -> the group is ON only when all members are ON
                ("all" mode); falsy/None -> ON when any member is ON
                ("any" mode).
        """
        self._entity_ids = entity_ids
        self._attr_name = name
        self._attr_extra_state_attributes = {ATTR_ENTITY_ID: entity_ids}
        self._attr_unique_id = unique_id
        # Aggregation function applied to member states: the builtin ``all``
        # in "all" mode, otherwise the builtin ``any``.
        self.mode = all if mode else any

    async def async_added_to_hass(self) -> None:
        """Register callbacks."""

        @callback
        def async_state_changed_listener(event: Event) -> None:
            """Handle child updates."""
            # Propagate the triggering context so the group's resulting state
            # change is attributed to the same origin as the member change.
            self.async_set_context(event.context)
            self.async_defer_or_update_ha_state()

        self.async_on_remove(
            async_track_state_change_event(
                self.hass, self._entity_ids, async_state_changed_listener
            )
        )

        await super().async_added_to_hass()

    async def async_turn_on(self, **kwargs: Any) -> None:
        """Forward the turn_on command to all switches in the group."""
        data = {ATTR_ENTITY_ID: self._entity_ids}
        _LOGGER.debug("Forwarded turn_on command: %s", data)

        await self.hass.services.async_call(
            DOMAIN,
            SERVICE_TURN_ON,
            data,
            blocking=True,
            context=self._context,
        )

    async def async_turn_off(self, **kwargs: Any) -> None:
        """Forward the turn_off command to all switches in the group."""
        data = {ATTR_ENTITY_ID: self._entity_ids}
        _LOGGER.debug("Forwarded turn_off command: %s", data)

        await self.hass.services.async_call(
            DOMAIN,
            SERVICE_TURN_OFF,
            data,
            blocking=True,
            context=self._context,
        )

    @callback
    def async_update_group_state(self) -> None:
        """Query all members and determine the switch group state.

        Sets ``_attr_is_on`` (True/False, or None for unknown) and
        ``_attr_available`` from the current member states. Members with no
        state object in the state machine are ignored.
        """
        states = [
            state.state
            for entity_id in self._entity_ids
            if (state := self.hass.states.get(entity_id)) is not None
        ]

        if not states:
            # No member has a state at all. Without this guard "all" mode
            # would evaluate the vacuously-true ``all([])`` and report ON.
            self._attr_is_on = None
        elif not self.mode(
            state not in (STATE_UNKNOWN, STATE_UNAVAILABLE) for state in states
        ):
            # Unknown when, in "all" mode, any member is unknown/unavailable,
            # or, in "any" mode, every member is unknown/unavailable.
            self._attr_is_on = None
        else:
            # ON when all members ("all" mode) / any member ("any" mode) is ON.
            self._attr_is_on = self.mode(state == STATE_ON for state in states)

        # Set group as unavailable if all members are unavailable or missing
        self._attr_available = any(state != STATE_UNAVAILABLE for state in states)