From 2a3c94bad04386c5b59a45cab5232fbc647897e2 Mon Sep 17 00:00:00 2001 From: "David F. Mulcahey" Date: Wed, 25 Mar 2020 07:23:54 -0400 Subject: [PATCH] Add group entity support to ZHA (#33196) * split entity into base and entity * add initial light group support * add dispatching of groups to light * added zha group object * add group event listener * add and remove group members * get group by name * fix rebase * fix rebase * use group_id for unique_id * get entities from registry * use group name * update entity domain * update zha storage to handle groups * dispatch group entities * update light group * fix group remove and dispatch light group entities * allow picking the domain for group entities * beginning - auto determine entity domain * move methods to helpers so they can be shared * fix rebase * remove double init groups... again * cleanup startup * use asyncio create task * group entity discovery * add logging and fix group name * add logging and update group after probe if needed * test add group via gateway * add method to get group entity ids * update storage * test get group by name * update storage on remove * test group with single member * add light group tests * test some light group logic * type hints * fix tests and cleanup * revert init changes except for create task * remove group entity domain changing for now * add missing import * tricky code saving * review comments * clean up class defs * cleanup * fix rebase because I cant read * make pylint happy --- homeassistant/components/zha/__init__.py | 8 +- homeassistant/components/zha/core/const.py | 2 + homeassistant/components/zha/core/device.py | 2 +- .../components/zha/core/discovery.py | 100 ++++++ homeassistant/components/zha/core/gateway.py | 123 ++++--- homeassistant/components/zha/core/group.py | 80 ++++- homeassistant/components/zha/core/helpers.py | 42 ++- .../components/zha/core/registries.py | 20 ++ homeassistant/components/zha/core/store.py | 117 +++++- homeassistant/components/zha/core/typing.py | 4 + homeassistant/components/zha/entity.py | 153 ++++---- homeassistant/components/zha/light.py | 339 +++++++++++++----- tests/components/zha/common.py | 39 +- tests/components/zha/conftest.py | 3 +- tests/components/zha/test_gateway.py | 140 +++++++- tests/components/zha/test_light.py | 169 ++++++++- 16 files changed, 1084 insertions(+), 257 deletions(-) diff --git a/homeassistant/components/zha/__init__.py b/homeassistant/components/zha/__init__.py index 5fef586d5cf..2af35e8fb92 100644 --- a/homeassistant/components/zha/__init__.py +++ b/homeassistant/components/zha/__init__.py @@ -130,7 +130,7 @@ async def async_setup_entry(hass, config_entry): await zha_data[DATA_ZHA_GATEWAY].async_update_device_storage() hass.bus.async_listen_once(ha_const.EVENT_HOMEASSISTANT_STOP, async_zha_shutdown) - asyncio.create_task(async_load_entities(hass, config_entry)) + asyncio.create_task(async_load_entities(hass)) return True @@ -150,11 +150,9 @@ async def async_unload_entry(hass, config_entry): return True -async def async_load_entities( - hass: HomeAssistantType, config_entry: config_entries.ConfigEntry -) -> None: +async def async_load_entities(hass: HomeAssistantType) -> None: """Load entities after integration was setup.""" - await hass.data[DATA_ZHA][DATA_ZHA_GATEWAY].async_prepare_entities() + await hass.data[DATA_ZHA][DATA_ZHA_GATEWAY].async_initialize_devices_and_entities() to_setup = hass.data[DATA_ZHA][DATA_ZHA_PLATFORM_LOADED] results = await asyncio.gather(*to_setup, return_exceptions=True) for res in results: diff --git a/homeassistant/components/zha/core/const.py b/homeassistant/components/zha/core/const.py index c2813c464e5..2eb567ab4c4 100644 --- a/homeassistant/components/zha/core/const.py +++ b/homeassistant/components/zha/core/const.py @@ -23,6 +23,7 @@ ATTR_COMMAND_TYPE = "command_type" ATTR_DEVICE_IEEE = "device_ieee" ATTR_DEVICE_TYPE = "device_type" ATTR_ENDPOINT_ID = "endpoint_id" +ATTR_ENTITY_DOMAIN = "entity_domain" ATTR_IEEE = "ieee" ATTR_LAST_SEEN = "last_seen" ATTR_LEVEL = "level" @@ -207,6 +208,7 @@ SIGNAL_REMOVE = "remove" SIGNAL_SET_LEVEL = "set_level" SIGNAL_STATE_ATTR = "update_state_attribute" SIGNAL_UPDATE_DEVICE = "{}_zha_update_device" +SIGNAL_REMOVE_GROUP = "remove_group" UNKNOWN = "unknown" UNKNOWN_MANUFACTURER = "unk_manufacturer" diff --git a/homeassistant/components/zha/core/device.py b/homeassistant/components/zha/core/device.py index 47b564f1767..0a7278cb5d5 100644 --- a/homeassistant/components/zha/core/device.py +++ b/homeassistant/components/zha/core/device.py @@ -373,7 +373,7 @@ class ZHADevice(LogMixin): self.debug("started configuration") await self._channels.async_configure() self.debug("completed configuration") - entry = self.gateway.zha_storage.async_create_or_update(self) + entry = self.gateway.zha_storage.async_create_or_update_device(self) self.debug("stored in registry: %s", entry) if self._channels.identify_ch is not None: diff --git a/homeassistant/components/zha/core/discovery.py b/homeassistant/components/zha/core/discovery.py index 5f8f6b593f8..7202fd869fa 100644 --- a/homeassistant/components/zha/core/discovery.py +++ b/homeassistant/components/zha/core/discovery.py @@ -1,10 +1,12 @@ """Device discovery functions for Zigbee Home Automation.""" +from collections import Counter import logging from typing import Callable, List, Tuple from homeassistant import const as ha_const from homeassistant.core import callback +from homeassistant.helpers.entity_registry import async_entries_for_device from homeassistant.helpers.typing import HomeAssistantType from . import const as zha_const, registries as zha_regs, typing as zha_typing @@ -157,4 +159,102 @@ class ProbeEndpoint: self._device_configs.update(overrides) +class GroupProbe: + """Determine the appropriate component for a group.""" + + def __init__(self): + """Initialize instance.""" + self._hass = None + + def initialize(self, hass: HomeAssistantType) -> None: + """Initialize the group probe.""" + self._hass = hass + + @callback + def discover_group_entities(self, group: zha_typing.ZhaGroupType) -> None: + """Process a group and create any entities that are needed.""" + # only create a group entity if there are 2 or more members in a group + if len(group.members) < 2: + _LOGGER.debug( + "Group: %s:0x%04x has less than 2 members - skipping entity discovery", + group.name, + group.group_id, + ) + return + + if group.entity_domain is None: + _LOGGER.debug( + "Group: %s:0x%04x has no user set entity domain - attempting entity domain discovery", + group.name, + group.group_id, + ) + group.entity_domain = GroupProbe.determine_default_entity_domain( + self._hass, group + ) + + if group.entity_domain is None: + return + + _LOGGER.debug( + "Group: %s:0x%04x has an entity domain of: %s after discovery", + group.name, + group.group_id, + group.entity_domain, + ) + + zha_gateway = self._hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] + entity_class = zha_regs.ZHA_ENTITIES.get_group_entity(group.entity_domain) + if entity_class is None: + return + + self._hass.data[zha_const.DATA_ZHA][group.entity_domain].append( + ( + entity_class, + ( + group.domain_entity_ids, + f"{group.entity_domain}_group_{group.group_id}", + group.group_id, + zha_gateway.coordinator_zha_device, + ), + ) + ) + + @staticmethod + def determine_default_entity_domain( + hass: HomeAssistantType, group: zha_typing.ZhaGroupType + ): + """Determine the default entity domain for this group.""" + if len(group.members) < 2: + _LOGGER.debug( + "Group: %s:0x%04x has less than 2 members so cannot default an entity domain", + group.name, + group.group_id, + ) + return None + + zha_gateway = hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY] + all_domain_occurrences = [] + for device in group.members: + entities = async_entries_for_device( + zha_gateway.ha_entity_registry, device.device_id + ) + all_domain_occurrences.extend( + [ + entity.domain + for entity in entities + if entity.domain in zha_regs.GROUP_ENTITY_DOMAINS + ] + ) + counts = Counter(all_domain_occurrences) + domain = counts.most_common(1)[0][0] + _LOGGER.debug( + "The default entity domain is: %s for group: %s:0x%04x", + domain, + group.name, + group.group_id, + ) + return domain + + PROBE = ProbeEndpoint() +GROUP_PROBE = GroupProbe() diff --git a/homeassistant/components/zha/core/gateway.py b/homeassistant/components/zha/core/gateway.py index 78b5f939cae..9d5bf609ed2 100644 --- a/homeassistant/components/zha/core/gateway.py +++ b/homeassistant/components/zha/core/gateway.py @@ -6,6 +6,7 @@ import itertools import logging import os import traceback +from typing import List, Optional from serial import SerialException import zigpy.device as zigpy_dev @@ -52,6 +53,7 @@ from .const import ( DOMAIN, SIGNAL_ADD_ENTITIES, SIGNAL_REMOVE, + SIGNAL_REMOVE_GROUP, UNKNOWN_MANUFACTURER, UNKNOWN_MODEL, ZHA_GW_MSG, @@ -75,6 +77,7 @@ from .group import ZHAGroup from .patches import apply_application_controller_patch from .registries import RADIO_TYPES from .store import async_get_registry +from .typing import ZhaDeviceType, ZhaGroupType, ZigpyEndpointType, ZigpyGroupType _LOGGER = logging.getLogger(__name__) @@ -93,6 +96,7 @@ class ZHAGateway: self._config = config self._devices = {} self._groups = {} + self.coordinator_zha_device = None self._device_registry = collections.defaultdict(list) self.zha_storage = None self.ha_device_registry = None @@ -110,6 +114,7 @@ class ZHAGateway: async def async_initialize(self): """Initialize controller and connect radio.""" discovery.PROBE.initialize(self._hass) + discovery.GROUP_PROBE.initialize(self._hass) self.zha_storage = await async_get_registry(self._hass) self.ha_device_registry = await get_dev_reg(self._hass) @@ -156,17 +161,29 @@ class ZHAGateway: self._hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID] = str( self.application_controller.ieee ) - await self.async_load_devices() - self._initialize_groups() + self.async_load_devices() + self.async_load_groups() - async def async_load_devices(self) -> None: + @callback + def async_load_devices(self) -> None: """Restore ZHA devices from zigpy application state.""" zigpy_devices = self.application_controller.devices.values() for zigpy_device in zigpy_devices: - self._async_get_or_create_device(zigpy_device, restored=True) + zha_device = self._async_get_or_create_device(zigpy_device, restored=True) + if zha_device.nwk == 0x0000: + self.coordinator_zha_device = zha_device - async def async_prepare_entities(self) -> None: - """Prepare entities by initializing device channels.""" + @callback + def async_load_groups(self) -> None: + """Initialize ZHA groups.""" + for group_id in self.application_controller.groups: + group = self.application_controller.groups[group_id] + zha_group = self._async_get_or_create_group(group) + # we can do this here because the entities are in the entity registry tied to the devices + discovery.GROUP_PROBE.discover_group_entities(zha_group) + + async def async_initialize_devices_and_entities(self) -> None: + """Initialize devices and load entities.""" semaphore = asyncio.Semaphore(2) async def _throttle(zha_device: zha_typing.ZhaDeviceType, cached: bool): @@ -231,35 +248,44 @@ class ZHAGateway: """Handle device leaving the network.""" self.async_update_device(device, False) - def group_member_removed(self, zigpy_group, endpoint): + def group_member_removed( + self, zigpy_group: ZigpyGroupType, endpoint: ZigpyEndpointType + ) -> None: """Handle zigpy group member removed event.""" # need to handle endpoint correctly on groups zha_group = self._async_get_or_create_group(zigpy_group) zha_group.info("group_member_removed - endpoint: %s", endpoint) self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_REMOVED) - def group_member_added(self, zigpy_group, endpoint): + def group_member_added( + self, zigpy_group: ZigpyGroupType, endpoint: ZigpyEndpointType + ) -> None: """Handle zigpy group member added event.""" # need to handle endpoint correctly on groups zha_group = self._async_get_or_create_group(zigpy_group) zha_group.info("group_member_added - endpoint: %s", endpoint) self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_ADDED) - def group_added(self, zigpy_group): + def group_added(self, zigpy_group: ZigpyGroupType) -> None: """Handle zigpy group added event.""" zha_group = self._async_get_or_create_group(zigpy_group) zha_group.info("group_added") # need to dispatch for entity creation here self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_ADDED) - def group_removed(self, zigpy_group): + def group_removed(self, zigpy_group: ZigpyGroupType) -> None: """Handle zigpy group added event.""" self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_REMOVED) zha_group = self._groups.pop(zigpy_group.group_id, None) zha_group.info("group_removed") + async_dispatcher_send( + self._hass, f"{SIGNAL_REMOVE_GROUP}_{zigpy_group.group_id}" + ) - def _send_group_gateway_message(self, zigpy_group, gateway_message_type): - """Send the gareway event for a zigpy group event.""" + def _send_group_gateway_message( + self, zigpy_group: ZigpyGroupType, gateway_message_type: str + ) -> None: + """Send the gateway event for a zigpy group event.""" zha_group = self._groups.get(zigpy_group.group_id) if zha_group is not None: async_dispatcher_send( @@ -306,12 +332,12 @@ class ZHAGateway: """Return ZHADevice for given ieee.""" return self._devices.get(ieee) - def get_group(self, group_id): + def get_group(self, group_id: str) -> Optional[ZhaGroupType]: """Return Group for given group id.""" return self.groups.get(group_id) @callback - def async_get_group_by_name(self, group_name): + def async_get_group_by_name(self, group_name: str) -> Optional[ZhaGroupType]: """Get ZHA group by name.""" for group in self.groups.values(): if group.name == group_name: @@ -390,12 +416,6 @@ class ZHAGateway: logging.getLogger(logger_name).removeHandler(self._log_relay_handler) self.debug_enabled = False - def _initialize_groups(self): - """Initialize ZHA groups.""" - for group_id in self.application_controller.groups: - group = self.application_controller.groups[group_id] - self._async_get_or_create_group(group) - @callback def _async_get_or_create_device( self, zigpy_device: zha_typing.ZigpyDeviceType, restored: bool = False @@ -414,17 +434,19 @@ class ZHAGateway: model=zha_device.model, ) zha_device.set_device_id(device_registry_device.id) - entry = self.zha_storage.async_get_or_create(zha_device) + entry = self.zha_storage.async_get_or_create_device(zha_device) zha_device.async_update_last_seen(entry.last_seen) return zha_device @callback - def _async_get_or_create_group(self, zigpy_group): + def _async_get_or_create_group(self, zigpy_group: ZigpyGroupType) -> ZhaGroupType: """Get or create a ZHA group.""" zha_group = self._groups.get(zigpy_group.group_id) if zha_group is None: zha_group = ZHAGroup(self._hass, self, zigpy_group) self._groups[zigpy_group.group_id] = zha_group + group_entry = self.zha_storage.async_get_or_create_group(zha_group) + zha_group.entity_domain = group_entry.entity_domain return zha_group @callback @@ -446,7 +468,9 @@ class ZHAGateway: async def async_update_device_storage(self): """Update the devices in the store.""" for device in self.devices.values(): - self.zha_storage.async_update(device) + self.zha_storage.async_update_device(device) + for group in self.groups.values(): + self.zha_storage.async_update_group(group) await self.zha_storage.async_save() async def async_device_initialized(self, device: zha_typing.ZigpyDeviceType): @@ -494,25 +518,6 @@ class ZHAGateway: zha_device.update_available(True) async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES) - # only public for testing - async def async_device_restored(self, device: zha_typing.ZigpyDeviceType): - """Add an existing device to the ZHA zigbee network when ZHA first starts.""" - zha_device = self._async_get_or_create_device(device, restored=True) - - if zha_device.is_mains_powered: - # the device isn't a battery powered device so we should be able - # to update it now - _LOGGER.debug( - "attempting to request fresh state for device - %s:%s %s with power source %s", - zha_device.nwk, - zha_device.ieee, - zha_device.name, - zha_device.power_source, - ) - await zha_device.async_initialize(from_cache=False) - else: - await zha_device.async_initialize(from_cache=True) - async def _async_device_rejoined(self, zha_device): _LOGGER.debug( "skipping discovery for previously discovered device - %s:%s", @@ -524,7 +529,9 @@ class ZHAGateway: # will cause async_init to fire so don't explicitly call it zha_device.update_available(True) - async def async_create_zigpy_group(self, name, members): + async def async_create_zigpy_group( + self, name: str, members: List[ZhaDeviceType] + ) -> ZhaGroupType: """Create a new Zigpy Zigbee group.""" # we start with one to fill any gaps from a user removing existing groups group_id = 1 @@ -537,24 +544,40 @@ class ZHAGateway: if members is not None: tasks = [] for ieee in members: + _LOGGER.debug( + "Adding member with IEEE: %s to group: %s:0x%04x", + ieee, + name, + group_id, + ) tasks.append(self.devices[ieee].async_add_to_group(group_id)) await asyncio.gather(*tasks) - return self.groups.get(group_id) + zha_group = self.groups.get(group_id) + _LOGGER.debug( + "Probing group: %s:0x%04x for entity discovery", + zha_group.name, + zha_group.group_id, + ) + discovery.GROUP_PROBE.discover_group_entities(zha_group) + if zha_group.entity_domain is not None: + self.zha_storage.async_update_group(zha_group) + async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES) + return zha_group - async def async_remove_zigpy_group(self, group_id): + async def async_remove_zigpy_group(self, group_id: int) -> None: """Remove a Zigbee group from Zigpy.""" group = self.groups.get(group_id) + if not group: + _LOGGER.debug("Group: %s:0x%04x could not be found", group.name, group_id) + return if group and group.members: tasks = [] for member in group.members: tasks.append(member.async_remove_from_group(group_id)) if tasks: await asyncio.gather(*tasks) - else: - # we have members but none are tracked by ZHA for whatever reason - self.application_controller.groups.pop(group_id) - else: - self.application_controller.groups.pop(group_id) + self.application_controller.groups.pop(group_id) + self.zha_storage.async_delete_group(group) async def shutdown(self): """Stop ZHA Controller Application.""" diff --git a/homeassistant/components/zha/core/group.py b/homeassistant/components/zha/core/group.py index ca2cc0ff1d3..e6b2dee0625 100644 --- a/homeassistant/components/zha/core/group.py +++ b/homeassistant/components/zha/core/group.py @@ -1,10 +1,16 @@ """Group for Zigbee Home Automation.""" import asyncio import logging +from typing import Any, Dict, List, Optional + +from zigpy.types.named import EUI64 from homeassistant.core import callback +from homeassistant.helpers.entity_registry import async_entries_for_device +from homeassistant.helpers.typing import HomeAssistantType from .helpers import LogMixin +from .typing import ZhaDeviceType, ZhaGatewayType, ZigpyEndpointType, ZigpyGroupType _LOGGER = logging.getLogger(__name__) @@ -12,29 +18,45 @@ _LOGGER = logging.getLogger(__name__) class ZHAGroup(LogMixin): """ZHA Zigbee group object.""" - def __init__(self, hass, zha_gateway, zigpy_group): + def __init__( + self, + hass: HomeAssistantType, + zha_gateway: ZhaGatewayType, + zigpy_group: ZigpyGroupType, + ): """Initialize the group.""" - self.hass = hass - self._zigpy_group = zigpy_group - self._zha_gateway = zha_gateway + self.hass: HomeAssistantType = hass + self._zigpy_group: ZigpyGroupType = zigpy_group + self._zha_gateway: ZhaGatewayType = zha_gateway + self._entity_domain: str = None @property - def name(self): + def name(self) -> str: """Return group name.""" return self._zigpy_group.name @property - def group_id(self): + def group_id(self) -> int: """Return group name.""" return self._zigpy_group.group_id @property - def endpoint(self): + def endpoint(self) -> ZigpyEndpointType: """Return the endpoint for this group.""" return self._zigpy_group.endpoint @property - def members(self): + def entity_domain(self) -> Optional[str]: + """Return the domain that will be used for the entity representing this group.""" + return self._entity_domain + + @entity_domain.setter + def entity_domain(self, domain: Optional[str]) -> None: + """Set the domain that will be used for the entity representing this group.""" + self._entity_domain = domain + + @property + def members(self) -> List[ZhaDeviceType]: """Return the ZHA devices that are members of this group.""" return [ self._zha_gateway.devices.get(member_ieee[0]) @@ -42,7 +64,7 @@ class ZHAGroup(LogMixin): if member_ieee[0] in self._zha_gateway.devices ] - async def async_add_members(self, member_ieee_addresses): + async def async_add_members(self, member_ieee_addresses: List[EUI64]) -> None: """Add members to this group.""" if len(member_ieee_addresses) > 1: tasks = [] @@ -56,7 +78,7 @@ class ZHAGroup(LogMixin): member_ieee_addresses[0] ].async_add_to_group(self.group_id) - async def async_remove_members(self, member_ieee_addresses): + async def async_remove_members(self, member_ieee_addresses: List[EUI64]) -> None: """Remove members from this group.""" if len(member_ieee_addresses) > 1: tasks = [] @@ -72,18 +94,50 @@ class ZHAGroup(LogMixin): member_ieee_addresses[0] ].async_remove_from_group(self.group_id) + @property + def member_entity_ids(self) -> List[str]: + """Return the ZHA entity ids for all entities for the members of this group.""" + all_entity_ids: List[str] = [] + for device in self.members: + entities = async_entries_for_device( + self._zha_gateway.ha_entity_registry, device.device_id + ) + for entity in entities: + all_entity_ids.append(entity.entity_id) + return all_entity_ids + + @property + def domain_entity_ids(self) -> List[str]: + """Return entity ids from the entity domain for this group.""" + if self.entity_domain is None: + return + domain_entity_ids: List[str] = [] + for device in self.members: + entities = async_entries_for_device( + self._zha_gateway.ha_entity_registry, device.device_id + ) + domain_entity_ids.extend( + [ + entity.entity_id + for entity in entities + if entity.domain == self.entity_domain + ] + ) + return domain_entity_ids + @callback - def async_get_info(self): + def async_get_info(self) -> Dict[str, Any]: """Get ZHA group info.""" - group_info = {} + group_info: Dict[str, Any] = {} group_info["group_id"] = self.group_id + group_info["entity_domain"] = self.entity_domain group_info["name"] = self.name group_info["members"] = [ zha_device.async_get_info() for zha_device in self.members ] return group_info - def log(self, level, msg, *args): + def log(self, level: int, msg: str, *args): """Log a message.""" msg = f"[%s](%s): {msg}" args = (self.name, self.group_id) + args diff --git a/homeassistant/components/zha/core/helpers.py b/homeassistant/components/zha/core/helpers.py index ab4c7ae540c..4441ac90717 100644 --- a/homeassistant/components/zha/core/helpers.py +++ b/homeassistant/components/zha/core/helpers.py @@ -1,10 +1,11 @@ """Helpers for Zigbee Home Automation.""" import collections import logging +from typing import Any, Callable, Iterator, List, Optional import zigpy.types -from homeassistant.core import callback +from homeassistant.core import State, callback from .const import CLUSTER_TYPE_IN, CLUSTER_TYPE_OUT, DATA_ZHA, DATA_ZHA_GATEWAY from .registries import BINDABLE_CLUSTERS @@ -85,6 +86,45 @@ async def async_get_zha_device(hass, device_id): return zha_gateway.devices[ieee] +def find_state_attributes(states: List[State], key: str) -> Iterator[Any]: + """Find attributes with matching key from states.""" + for state in states: + value = state.attributes.get(key) + if value is not None: + yield value + + +def mean_int(*args): + """Return the mean of the supplied values.""" + return int(sum(args) / len(args)) + + +def mean_tuple(*args): + """Return the mean values along the columns of the supplied values.""" + return tuple(sum(l) / len(l) for l in zip(*args)) + + +def reduce_attribute( + states: List[State], + key: str, + default: Optional[Any] = None, + reduce: Callable[..., Any] = mean_int, +) -> Any: + """Find the first attribute matching key from states. + + If none are found, return default. + """ + attrs = list(find_state_attributes(states, key)) + + if not attrs: + return default + + if len(attrs) == 1: + return attrs[0] + + return reduce(*attrs) + + class LogMixin: """Log helper.""" diff --git a/homeassistant/components/zha/core/registries.py b/homeassistant/components/zha/core/registries.py index 0a1a81df5ff..34ae32c01c8 100644 --- a/homeassistant/components/zha/core/registries.py +++ b/homeassistant/components/zha/core/registries.py @@ -32,6 +32,8 @@ from .const import CONTROLLER, ZHA_GW_RADIO, ZHA_GW_RADIO_DESCRIPTION, RadioType from .decorators import CALLABLE_T, DictRegistry, SetRegistry from .typing import ChannelType +GROUP_ENTITY_DOMAINS = [LIGHT] + SMARTTHINGS_ACCELERATION_CLUSTER = 0xFC02 SMARTTHINGS_ARRIVAL_SENSOR_DEVICE_TYPE = 0x8000 SMARTTHINGS_HUMIDITY_CLUSTER = 0xFC45 @@ -275,6 +277,9 @@ RegistryDictType = Dict[ ] # pylint: disable=invalid-name +GroupRegistryDictType = Dict[str, CALLABLE_T] # pylint: disable=invalid-name + + class ZHAEntityRegistry: """Channel to ZHA Entity mapping.""" @@ -282,6 +287,7 @@ class ZHAEntityRegistry: """Initialize Registry instance.""" self._strict_registry: RegistryDictType = collections.defaultdict(dict) self._loose_registry: RegistryDictType = collections.defaultdict(dict) + self._group_registry: GroupRegistryDictType = {} def get_entity( self, @@ -300,6 +306,10 @@ class ZHAEntityRegistry: return default, [] + def get_group_entity(self, component: str) -> CALLABLE_T: + """Match a ZHA group to a ZHA Entity class.""" + return self._group_registry.get(component) + def strict_match( self, component: str, @@ -350,5 +360,15 @@ class ZHAEntityRegistry: return decorator + def group_match(self, component: str) -> Callable[[CALLABLE_T], CALLABLE_T]: + """Decorate a group match rule.""" + + def decorator(zha_ent: CALLABLE_T) -> CALLABLE_T: + """Register a group match rule.""" + self._group_registry[component] = zha_ent + return zha_ent + + return decorator + ZHA_ENTITIES = ZHAEntityRegistry() diff --git a/homeassistant/components/zha/core/store.py b/homeassistant/components/zha/core/store.py index 46fef76b656..0cd9e045cb6 100644 --- a/homeassistant/components/zha/core/store.py +++ b/homeassistant/components/zha/core/store.py @@ -10,6 +10,8 @@ from homeassistant.core import callback from homeassistant.helpers.typing import HomeAssistantType from homeassistant.loader import bind_hass +from .typing import ZhaDeviceType, ZhaGroupType + _LOGGER = logging.getLogger(__name__) DATA_REGISTRY = "zha_storage" @@ -28,52 +30,96 @@ class ZhaDeviceEntry: last_seen = attr.ib(type=float, default=None) -class ZhaDeviceStorage: +@attr.s(slots=True, frozen=True) +class ZhaGroupEntry: + """Zha Group storage Entry.""" + + name = attr.ib(type=str, default=None) + group_id = attr.ib(type=int, default=None) + entity_domain = attr.ib(type=float, default=None) + + +class ZhaStorage: """Class to hold a registry of zha devices.""" def __init__(self, hass: HomeAssistantType) -> None: """Initialize the zha device storage.""" - self.hass = hass + self.hass: HomeAssistantType = hass self.devices: MutableMapping[str, ZhaDeviceEntry] = {} + self.groups: MutableMapping[str, ZhaGroupEntry] = {} self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) @callback - def async_create(self, device) -> ZhaDeviceEntry: + def async_create_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry: """Create a new ZhaDeviceEntry.""" - device_entry = ZhaDeviceEntry( + device_entry: ZhaDeviceEntry = ZhaDeviceEntry( name=device.name, ieee=str(device.ieee), last_seen=device.last_seen ) self.devices[device_entry.ieee] = device_entry - return self.async_update(device) + return self.async_update_device(device) @callback - def async_get_or_create(self, device) -> ZhaDeviceEntry: + def async_create_group(self, group: ZhaGroupType) -> ZhaGroupEntry: + """Create a new ZhaGroupEntry.""" + group_entry: ZhaGroupEntry = ZhaGroupEntry( + name=group.name, + group_id=str(group.group_id), + entity_domain=group.entity_domain, + ) + self.groups[str(group.group_id)] = group_entry + return self.async_update_group(group) + + @callback + def async_get_or_create_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry: """Create a new ZhaDeviceEntry.""" - ieee_str = str(device.ieee) + ieee_str: str = str(device.ieee) if ieee_str in self.devices: return self.devices[ieee_str] - return self.async_create(device) + return self.async_create_device(device) @callback - def async_create_or_update(self, device) -> ZhaDeviceEntry: + def async_get_or_create_group(self, group: ZhaGroupType) -> ZhaGroupEntry: + """Create a new ZhaGroupEntry.""" + group_id: str = str(group.group_id) + if group_id in self.groups: + return self.groups[group_id] + return self.async_create_group(group) + + @callback + def async_create_or_update_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry: """Create or update a ZhaDeviceEntry.""" if str(device.ieee) in self.devices: - return self.async_update(device) - return self.async_create(device) + return self.async_update_device(device) + return self.async_create_device(device) @callback - def async_delete(self, device) -> None: + def async_create_or_update_group(self, group: ZhaGroupType) -> ZhaGroupEntry: + """Create or update a ZhaGroupEntry.""" + if str(group.group_id) in self.groups: + return self.async_update_group(group) + return self.async_create_group(group) + + @callback + def async_delete_device(self, device: ZhaDeviceType) -> None: """Delete ZhaDeviceEntry.""" - ieee_str = str(device.ieee) + ieee_str: str = str(device.ieee) if ieee_str in self.devices: del self.devices[ieee_str] self.async_schedule_save() @callback - def async_update(self, device) -> ZhaDeviceEntry: + def async_delete_group(self, group: ZhaGroupType) -> None: + """Delete ZhaGroupEntry.""" + group_id: str = str(group.group_id) + if group_id in self.groups: + del self.groups[group_id] + self.async_schedule_save() + + @callback + def async_update_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry: """Update name of ZhaDeviceEntry.""" - ieee_str = str(device.ieee) + ieee_str: str = str(device.ieee) old = self.devices[ieee_str] changes = {} @@ -83,11 +129,25 @@ class ZhaDeviceStorage: self.async_schedule_save() return new + @callback + def async_update_group(self, group: ZhaGroupType) -> ZhaGroupEntry: + """Update name of ZhaGroupEntry.""" + group_id: str = str(group.group_id) + old = self.groups[group_id] + + changes = {} + changes["entity_domain"] = group.entity_domain + + new = self.groups[group_id] = attr.evolve(old, **changes) + self.async_schedule_save() + return new + async def async_load(self) -> None: """Load the registry of zha device entries.""" data = await self._store.async_load() devices: "OrderedDict[str, ZhaDeviceEntry]" = OrderedDict() + groups: "OrderedDict[str, ZhaGroupEntry]" = OrderedDict() if data is not None: for device in data["devices"]: @@ -97,7 +157,18 @@ class ZhaDeviceStorage: last_seen=device["last_seen"] if "last_seen" in device else None, ) + if "groups" in data: + for group in data["groups"]: + groups[group["group_id"]] = ZhaGroupEntry( + name=group["name"], + group_id=group["group_id"], + entity_domain=group["entity_domain"] + if "entity_domain" in group + else None, + ) + self.devices = devices + self.groups = groups @callback def async_schedule_save(self) -> None: @@ -118,21 +189,29 @@ class ZhaDeviceStorage: for entry in self.devices.values() ] + data["groups"] = [ + { + "name": entry.name, + "group_id": entry.group_id, + "entity_domain": entry.entity_domain, + } + for entry in self.groups.values() + ] return data @bind_hass -async def async_get_registry(hass: HomeAssistantType) -> ZhaDeviceStorage: +async def async_get_registry(hass: HomeAssistantType) -> ZhaStorage: """Return zha device storage instance.""" task = hass.data.get(DATA_REGISTRY) if task is None: - async def _load_reg() -> ZhaDeviceStorage: - registry = ZhaDeviceStorage(hass) + async def _load_reg() -> ZhaStorage: + registry = ZhaStorage(hass) await registry.async_load() return registry task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg()) - return cast(ZhaDeviceStorage, await task) + return cast(ZhaStorage, await task) diff --git a/homeassistant/components/zha/core/typing.py b/homeassistant/components/zha/core/typing.py index a1cbc9f0fef..a4619d0596e 100644 --- a/homeassistant/components/zha/core/typing.py +++ b/homeassistant/components/zha/core/typing.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Callable, TypeVar import zigpy.device import zigpy.endpoint +import zigpy.group import zigpy.zcl import zigpy.zdo @@ -17,9 +18,11 @@ ZDOChannelType = "ZDOChannel" ZhaDeviceType = "ZHADevice" ZhaEntityType = "ZHAEntity" ZhaGatewayType = "ZHAGateway" +ZhaGroupType = "ZHAGroupType" ZigpyClusterType = zigpy.zcl.Cluster ZigpyDeviceType = zigpy.device.Device ZigpyEndpointType = zigpy.endpoint.Endpoint +ZigpyGroupType = zigpy.group.Group ZigpyZdoType = zigpy.zdo.ZDO if TYPE_CHECKING: @@ -38,3 +41,4 @@ if TYPE_CHECKING: ZhaDeviceType = homeassistant.components.zha.core.device.ZHADevice ZhaEntityType = homeassistant.components.zha.entity.ZhaEntity ZhaGatewayType = homeassistant.components.zha.core.gateway.ZHAGateway + ZhaGroupType = homeassistant.components.zha.core.group.ZHAGroup diff --git a/homeassistant/components/zha/entity.py b/homeassistant/components/zha/entity.py index 4dd3fea016d..63ed3a6edc7 100644 --- a/homeassistant/components/zha/entity.py +++ b/homeassistant/components/zha/entity.py @@ -3,6 +3,7 @@ import asyncio import logging import time +from typing import Any, Awaitable, Dict, List from homeassistant.core import callback from homeassistant.helpers import entity @@ -20,6 +21,7 @@ from .core.const import ( SIGNAL_REMOVE, ) from .core.helpers import LogMixin +from .core.typing import CALLABLE_T, ChannelsType, ChannelType, ZhaDeviceType _LOGGER = logging.getLogger(__name__) @@ -27,30 +29,24 @@ ENTITY_SUFFIX = "entity_suffix" RESTART_GRACE_PERIOD = 7200 # 2 hours -class ZhaEntity(RestoreEntity, LogMixin, entity.Entity): +class BaseZhaEntity(RestoreEntity, LogMixin, entity.Entity): """A base class for ZHA entities.""" - def __init__(self, unique_id, zha_device, channels, skip_entity_id=False, **kwargs): + def __init__(self, unique_id: str, zha_device: ZhaDeviceType, **kwargs): """Init ZHA entity.""" - self._force_update = False - self._should_poll = False - self._unique_id = unique_id - ieeetail = "".join([f"{o:02x}" for o in zha_device.ieee[:4]]) - ch_names = [ch.cluster.ep_attribute for ch in channels] - ch_names = ", ".join(sorted(ch_names)) - self._name = f"{zha_device.name} {ieeetail} {ch_names}" - self._state = None - self._device_state_attributes = {} - self._zha_device = zha_device - self.cluster_channels = {} - self._available = False - self._unsubs = [] - self.remove_future = None - for channel in channels: - self.cluster_channels[channel.name] = channel + self._name: str = "" + self._force_update: bool = False + self._should_poll: bool = False + self._unique_id: str = unique_id + self._state: Any = None + self._device_state_attributes: Dict[str, Any] = {} + self._zha_device: ZhaDeviceType = zha_device + self._available: bool = False + self._unsubs: List[CALLABLE_T] = [] + self.remove_future: Awaitable[None] = None @property - def name(self): + def name(self) -> str: """Return Entity's default name.""" return self._name @@ -60,12 +56,12 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity): return self._unique_id @property - def zha_device(self): + def zha_device(self) -> ZhaDeviceType: """Return the zha device this entity is attached to.""" return self._zha_device @property - def device_state_attributes(self): + def device_state_attributes(self) -> Dict[str, Any]: """Return device specific state attributes.""" return self._device_state_attributes @@ -80,7 +76,7 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity): return self._should_poll @property - def device_info(self): + def device_info(self) -> Dict[str, Any]: """Return a device description for device registry.""" zha_device_info = self._zha_device.device_info ieee = zha_device_info["ieee"] @@ -94,31 +90,94 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity): } @property - def available(self): + def available(self) -> bool: """Return entity availability.""" return self._available @callback - def async_set_available(self, available): + def async_set_available(self, available: bool) -> None: """Set entity availability.""" self._available = available self.async_write_ha_state() @callback - def async_update_state_attribute(self, key, value): + def async_update_state_attribute(self, key: str, value: Any) -> None: """Update a single device state attribute.""" self._device_state_attributes.update({key: value}) self.async_write_ha_state() @callback - def async_set_state(self, attr_id, attr_name, value): + def async_set_state(self, attr_id: int, attr_name: str, value: Any) -> None: """Set the entity state.""" pass - async def async_added_to_hass(self): + async def async_added_to_hass(self) -> None: """Run when about to be added to hass.""" await super().async_added_to_hass() self.remove_future = asyncio.Future() + await self.async_accept_signal( + None, + "{}_{}".format(SIGNAL_REMOVE, str(self.zha_device.ieee)), + self.async_remove, + signal_override=True, + ) + + async def async_will_remove_from_hass(self) -> None: + """Disconnect entity object when removed.""" + for unsub in self._unsubs[:]: + unsub() + self._unsubs.remove(unsub) + self.zha_device.gateway.remove_entity_reference(self) + self.remove_future.set_result(True) + + @callback + def async_restore_last_state(self, last_state) -> None: + """Restore previous state.""" + pass + + async def async_accept_signal( + self, channel: ChannelType, signal: str, func: CALLABLE_T, signal_override=False + ): + """Accept a signal from a channel.""" + unsub = None + if signal_override: + unsub = async_dispatcher_connect(self.hass, signal, func) + else: + unsub = async_dispatcher_connect( + self.hass, f"{channel.unique_id}_{signal}", func + ) + self._unsubs.append(unsub) + + def log(self, level: int, msg: str, *args): + """Log a message.""" + msg = f"%s: {msg}" + args = (self.entity_id,) + args + _LOGGER.log(level, msg, *args) + + +class ZhaEntity(BaseZhaEntity): + """A base class for non group ZHA entities.""" + + def __init__( + self, + unique_id: str, + zha_device: ZhaDeviceType, + channels: ChannelsType, + **kwargs, + ): + """Init ZHA entity.""" + super().__init__(unique_id, zha_device, **kwargs) + ieeetail = "".join([f"{o:02x}" for o in zha_device.ieee[:4]]) + ch_names = [ch.cluster.ep_attribute for ch in channels] + ch_names = ", ".join(sorted(ch_names)) + self._name: str = f"{zha_device.name} {ieeetail} {ch_names}" + self.cluster_channels: Dict[str, ChannelType] = {} + for channel in channels: + self.cluster_channels[channel.name] = channel + + async def async_added_to_hass(self) -> None: + """Run when about to be added to hass.""" + await super().async_added_to_hass() await self.async_check_recently_seen() await self.async_accept_signal( None, @@ -126,12 +185,6 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity): self.async_set_available, signal_override=True, ) - await self.async_accept_signal( - None, - "{}_{}".format(SIGNAL_REMOVE, str(self.zha_device.ieee)), - self.async_remove, - signal_override=True, - ) self._zha_device.gateway.register_entity_reference( self._zha_device.ieee, self.entity_id, @@ -141,7 +194,7 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity): self.remove_future, ) - async def async_check_recently_seen(self): + async def async_check_recently_seen(self) -> None: """Check if the device was seen within the last 2 hours.""" last_state = await self.async_get_last_state() if ( @@ -155,38 +208,8 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity): self.async_restore_last_state(last_state) self._zha_device.set_available(True) - async def async_will_remove_from_hass(self) -> None: - """Disconnect entity object when removed.""" - for unsub in self._unsubs[:]: - unsub() - self._unsubs.remove(unsub) - self.zha_device.gateway.remove_entity_reference(self) - self.remove_future.set_result(True) - - @callback - def async_restore_last_state(self, last_state): - """Restore previous state.""" - pass - - async def async_update(self): + async def async_update(self) -> None: """Retrieve latest state.""" for channel in self.cluster_channels.values(): if hasattr(channel, "async_update"): await channel.async_update() - - async def async_accept_signal(self, channel, signal, func, signal_override=False): - """Accept a signal from a channel.""" - unsub = None - if signal_override: - unsub = async_dispatcher_connect(self.hass, signal, func) - else: - unsub = async_dispatcher_connect( - self.hass, f"{channel.unique_id}_{signal}", func - ) - self._unsubs.append(unsub) - - def log(self, level, msg, *args): - """Log a message.""" - msg = f"%s: {msg}" - args = (self.entity_id,) + args - _LOGGER.log(level, msg, *args) diff --git a/homeassistant/components/zha/light.py b/homeassistant/components/zha/light.py index bf3a457ff68..2192ec1a909 100644 --- a/homeassistant/components/zha/light.py +++ b/homeassistant/components/zha/light.py @@ -1,19 +1,44 @@ """Lights on Zigbee Home Automation networks.""" +from collections import Counter from datetime import timedelta import functools +import itertools import logging import random +from typing import Any, Dict, List, Optional, Tuple +from zigpy.zcl.clusters.general import Identify, LevelControl, OnOff +from zigpy.zcl.clusters.lighting import Color from zigpy.zcl.foundation import Status from homeassistant.components import light -from homeassistant.const import STATE_ON -from homeassistant.core import callback +from homeassistant.components.light import ( + ATTR_BRIGHTNESS, + ATTR_COLOR_TEMP, + ATTR_EFFECT, + ATTR_EFFECT_LIST, + ATTR_HS_COLOR, + ATTR_MAX_MIREDS, + ATTR_MIN_MIREDS, + ATTR_WHITE_VALUE, + SUPPORT_BRIGHTNESS, + SUPPORT_COLOR, + SUPPORT_COLOR_TEMP, + SUPPORT_EFFECT, + SUPPORT_FLASH, + SUPPORT_TRANSITION, + SUPPORT_WHITE_VALUE, +) +from homeassistant.const import ATTR_SUPPORTED_FEATURES, STATE_ON, STATE_UNAVAILABLE +from homeassistant.core import CALLBACK_TYPE, State, callback from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.event import async_track_time_interval +from homeassistant.helpers.event import ( + async_track_state_change, + async_track_time_interval, +) import homeassistant.util.color as color_util -from .core import discovery +from .core import discovery, helpers from .core.const import ( CHANNEL_COLOR, CHANNEL_LEVEL, @@ -25,11 +50,12 @@ from .core.const import ( EFFECT_DEFAULT_VARIANT, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED, + SIGNAL_REMOVE_GROUP, SIGNAL_SET_LEVEL, ) from .core.registries import ZHA_ENTITIES from .core.typing import ZhaDeviceType -from .entity import ZhaEntity +from .entity import BaseZhaEntity, ZhaEntity _LOGGER = logging.getLogger(__name__) @@ -46,8 +72,19 @@ FLASH_EFFECTS = {light.FLASH_SHORT: EFFECT_BLINK, light.FLASH_LONG: EFFECT_BREAT UNSUPPORTED_ATTRIBUTE = 0x86 STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, light.DOMAIN) +GROUP_MATCH = functools.partial(ZHA_ENTITIES.group_match, light.DOMAIN) PARALLEL_UPDATES = 0 +SUPPORT_GROUP_LIGHT = ( + SUPPORT_BRIGHTNESS + | SUPPORT_COLOR_TEMP + | SUPPORT_EFFECT + | SUPPORT_FLASH + | SUPPORT_COLOR + | SUPPORT_TRANSITION + | SUPPORT_WHITE_VALUE +) + async def async_setup_entry(hass, config_entry, async_add_entities): """Set up the Zigbee Home Automation light from config entry.""" @@ -63,48 +100,35 @@ async def async_setup_entry(hass, config_entry, async_add_entities): hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub) -@STRICT_MATCH(channel_names=CHANNEL_ON_OFF, aux_channels={CHANNEL_COLOR, CHANNEL_LEVEL}) -class Light(ZhaEntity, light.Light): - """Representation of a ZHA or ZLL light.""" +class BaseLight(BaseZhaEntity, light.Light): + """Operations common to all light entities.""" - _REFRESH_INTERVAL = (45, 75) + def __init__(self, *args, **kwargs): + """Initialize the light.""" + super().__init__(*args, **kwargs) + self._is_on: bool = False + self._available: bool = False + self._brightness: Optional[int] = None + self._off_brightness: Optional[int] = None + self._hs_color: Optional[Tuple[float, float]] = None + self._color_temp: Optional[int] = None + self._min_mireds: Optional[int] = 154 + self._max_mireds: Optional[int] = 500 + self._white_value: Optional[int] = None + self._effect_list: Optional[List[str]] = None + self._effect: Optional[str] = None + self._supported_features: int = 0 + self._state: bool = False + self._on_off_channel = None + self._level_channel = None + self._color_channel = None + self._identify_channel = None - def __init__(self, unique_id, zha_device: ZhaDeviceType, channels, **kwargs): - """Initialize the ZHA light.""" - super().__init__(unique_id, zha_device, channels, **kwargs) - self._supported_features = 0 - self._color_temp = None - self._hs_color = None - self._brightness = None - self._off_brightness = None - self._effect_list = [] - self._effect = None - self._on_off_channel = self.cluster_channels.get(CHANNEL_ON_OFF) - self._level_channel = self.cluster_channels.get(CHANNEL_LEVEL) - self._color_channel = self.cluster_channels.get(CHANNEL_COLOR) - self._identify_channel = self.zha_device.channels.identify_ch - self._cancel_refresh_handle = None - - if self._level_channel: - self._supported_features |= light.SUPPORT_BRIGHTNESS - self._supported_features |= light.SUPPORT_TRANSITION - self._brightness = 0 - - if self._color_channel: - color_capabilities = self._color_channel.get_color_capabilities() - if color_capabilities & CAPABILITIES_COLOR_TEMP: - self._supported_features |= light.SUPPORT_COLOR_TEMP - - if color_capabilities & CAPABILITIES_COLOR_XY: - self._supported_features |= light.SUPPORT_COLOR - self._hs_color = (0, 0) - - if color_capabilities & CAPABILITIES_COLOR_LOOP: - self._supported_features |= light.SUPPORT_EFFECT - self._effect_list.append(light.EFFECT_COLORLOOP) - - if self._identify_channel: - self._supported_features |= light.SUPPORT_FLASH + @property + def device_state_attributes(self) -> Dict[str, Any]: + """Return state attributes.""" + attributes = {"off_brightness": self._off_brightness} + return attributes @property def is_on(self) -> bool: @@ -118,12 +142,6 @@ class Light(ZhaEntity, light.Light): """Return the brightness of this light.""" return self._brightness - @property - def device_state_attributes(self): - """Return state attributes.""" - attributes = {"off_brightness": self._off_brightness} - return attributes - def set_level(self, value): """Set the brightness of this light between 0..254. @@ -160,49 +178,6 @@ class Light(ZhaEntity, light.Light): """Flag supported features.""" return self._supported_features - @callback - def async_set_state(self, attr_id, attr_name, value): - """Set the state.""" - self._state = bool(value) - if value: - self._off_brightness = None - self.async_write_ha_state() - - async def async_added_to_hass(self): - """Run when about to be added to hass.""" - await super().async_added_to_hass() - await self.async_accept_signal( - self._on_off_channel, SIGNAL_ATTR_UPDATED, self.async_set_state - ) - if self._level_channel: - await self.async_accept_signal( - self._level_channel, SIGNAL_SET_LEVEL, self.set_level - ) - refresh_interval = random.randint(*[x * 60 for x in self._REFRESH_INTERVAL]) - self._cancel_refresh_handle = async_track_time_interval( - self.hass, self._refresh, timedelta(seconds=refresh_interval) - ) - - async def async_will_remove_from_hass(self) -> None: - """Disconnect entity object when removed.""" - self._cancel_refresh_handle() - await super().async_will_remove_from_hass() - - @callback - def async_restore_last_state(self, last_state): - """Restore previous state.""" - self._state = last_state.state == STATE_ON - if "brightness" in last_state.attributes: - self._brightness = last_state.attributes["brightness"] - if "off_brightness" in last_state.attributes: - self._off_brightness = last_state.attributes["off_brightness"] - if "color_temp" in last_state.attributes: - self._color_temp = last_state.attributes["color_temp"] - if "hs_color" in last_state.attributes: - self._hs_color = last_state.attributes["hs_color"] - if "effect" in last_state.attributes: - self._effect = last_state.attributes["effect"] - async def async_turn_on(self, **kwargs): """Turn the entity on.""" transition = kwargs.get(light.ATTR_TRANSITION) @@ -331,6 +306,86 @@ class Light(ZhaEntity, light.Light): self.async_write_ha_state() + +@STRICT_MATCH(channel_names=CHANNEL_ON_OFF, aux_channels={CHANNEL_COLOR, CHANNEL_LEVEL}) +class Light(ZhaEntity, BaseLight): + """Representation of a ZHA or ZLL light.""" + + _REFRESH_INTERVAL = (45, 75) + + def __init__(self, unique_id, zha_device: ZhaDeviceType, channels, **kwargs): + """Initialize the ZHA light.""" + super().__init__(unique_id, zha_device, channels, **kwargs) + self._on_off_channel = self.cluster_channels.get(CHANNEL_ON_OFF) + self._level_channel = self.cluster_channels.get(CHANNEL_LEVEL) + self._color_channel = self.cluster_channels.get(CHANNEL_COLOR) + self._identify_channel = self.zha_device.channels.identify_ch + self._cancel_refresh_handle = None + + if self._level_channel: + self._supported_features |= light.SUPPORT_BRIGHTNESS + self._supported_features |= light.SUPPORT_TRANSITION + self._brightness = 0 + + if self._color_channel: + color_capabilities = self._color_channel.get_color_capabilities() + if color_capabilities & CAPABILITIES_COLOR_TEMP: + self._supported_features |= light.SUPPORT_COLOR_TEMP + + if color_capabilities & CAPABILITIES_COLOR_XY: + self._supported_features |= light.SUPPORT_COLOR + self._hs_color = (0, 0) + + if color_capabilities & CAPABILITIES_COLOR_LOOP: + self._supported_features |= light.SUPPORT_EFFECT + self._effect_list.append(light.EFFECT_COLORLOOP) + + if self._identify_channel: + self._supported_features |= light.SUPPORT_FLASH + + @callback + def async_set_state(self, attr_id, attr_name, value): + """Set the state.""" + self._state = bool(value) + if value: + self._off_brightness = None + self.async_write_ha_state() + + async def async_added_to_hass(self): + """Run when about to be added to hass.""" + await super().async_added_to_hass() + await self.async_accept_signal( + self._on_off_channel, SIGNAL_ATTR_UPDATED, self.async_set_state + ) + if self._level_channel: + await self.async_accept_signal( + self._level_channel, SIGNAL_SET_LEVEL, self.set_level + ) + refresh_interval = random.randint(*[x * 60 for x in self._REFRESH_INTERVAL]) + self._cancel_refresh_handle = async_track_time_interval( + self.hass, self._refresh, timedelta(seconds=refresh_interval) + ) + + async def async_will_remove_from_hass(self) -> None: + """Disconnect entity object when removed.""" + self._cancel_refresh_handle() + await super().async_will_remove_from_hass() + + @callback + def async_restore_last_state(self, last_state): + """Restore previous state.""" + self._state = last_state.state == STATE_ON + if "brightness" in last_state.attributes: + self._brightness = last_state.attributes["brightness"] + if "off_brightness" in last_state.attributes: + self._off_brightness = last_state.attributes["off_brightness"] + if "color_temp" in last_state.attributes: + self._color_temp = last_state.attributes["color_temp"] + if "hs_color" in last_state.attributes: + self._hs_color = last_state.attributes["hs_color"] + if "effect" in last_state.attributes: + self._effect = last_state.attributes["effect"] + async def async_update(self): """Attempt to retrieve on off state from the light.""" await super().async_update() @@ -410,3 +465,99 @@ class HueLight(Light): """Representation of a HUE light which does not report attributes.""" _REFRESH_INTERVAL = (3, 5) + + +@GROUP_MATCH() +class LightGroup(BaseLight): + """Representation of a light group.""" + + def __init__( + self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs + ) -> None: + """Initialize a light group.""" + super().__init__(unique_id, zha_device, **kwargs) + self._name = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}" + self._group_id: int = group_id + self._entity_ids: List[str] = entity_ids + group = self.zha_device.gateway.get_group(self._group_id) + self._on_off_channel = group.endpoint[OnOff.cluster_id] + self._level_channel = group.endpoint[LevelControl.cluster_id] + self._color_channel = group.endpoint[Color.cluster_id] + self._identify_channel = group.endpoint[Identify.cluster_id] + self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None + + async def async_added_to_hass(self) -> None: + """Register callbacks.""" + await super().async_added_to_hass() + await self.async_accept_signal( + None, + f"{SIGNAL_REMOVE_GROUP}_{self._group_id}", + self.async_remove, + signal_override=True, + ) + + @callback + def async_state_changed_listener( + entity_id: str, old_state: State, new_state: State + ): + """Handle child updates.""" + self.async_schedule_update_ha_state(True) + + self._async_unsub_state_changed = async_track_state_change( + self.hass, self._entity_ids, async_state_changed_listener + ) + await self.async_update() + + async def async_will_remove_from_hass(self) -> None: + """Handle removal from Home Assistant.""" + await super().async_will_remove_from_hass() + if self._async_unsub_state_changed is not None: + self._async_unsub_state_changed() + self._async_unsub_state_changed = None + + async def async_update(self) -> None: + """Query all members and determine the light group state.""" + all_states = [self.hass.states.get(x) for x in self._entity_ids] + states: List[State] = list(filter(None, all_states)) + on_states = [state for state in states if state.state == STATE_ON] + + self._is_on = len(on_states) > 0 + self._available = any(state.state != STATE_UNAVAILABLE for state in states) + + self._brightness = helpers.reduce_attribute(on_states, ATTR_BRIGHTNESS) + + self._hs_color = helpers.reduce_attribute( + on_states, ATTR_HS_COLOR, reduce=helpers.mean_tuple + ) + + self._white_value = helpers.reduce_attribute(on_states, ATTR_WHITE_VALUE) + + self._color_temp = helpers.reduce_attribute(on_states, ATTR_COLOR_TEMP) + self._min_mireds = helpers.reduce_attribute( + states, ATTR_MIN_MIREDS, default=154, reduce=min + ) + self._max_mireds = helpers.reduce_attribute( + states, ATTR_MAX_MIREDS, default=500, reduce=max + ) + + self._effect_list = None + all_effect_lists = list(helpers.find_state_attributes(states, ATTR_EFFECT_LIST)) + if all_effect_lists: + # Merge all effects from all effect_lists with a union merge. + self._effect_list = list(set().union(*all_effect_lists)) + + self._effect = None + all_effects = list(helpers.find_state_attributes(on_states, ATTR_EFFECT)) + if all_effects: + # Report the most common effect. + effects_count = Counter(itertools.chain(all_effects)) + self._effect = effects_count.most_common(1)[0][0] + + self._supported_features = 0 + for support in helpers.find_state_attributes(states, ATTR_SUPPORTED_FEATURES): + # Merge supported features by emulating support for every feature + # we find. + self._supported_features |= support + # Bitwise-and the supported features with the GroupedLight's features + # so that we don't break in the future when a new feature is added. + self._supported_features &= SUPPORT_GROUP_LIGHT diff --git a/tests/components/zha/common.py b/tests/components/zha/common.py index 3753136d59d..9c57b57419a 100644 --- a/tests/components/zha/common.py +++ b/tests/components/zha/common.py @@ -3,6 +3,8 @@ import time from unittest.mock import Mock from asynctest import CoroutineMock +from zigpy.device import Device as zigpy_dev +from zigpy.endpoint import Endpoint as zigpy_ep import zigpy.profiles.zha import zigpy.types import zigpy.zcl @@ -24,6 +26,7 @@ class FakeEndpoint: self.in_clusters = {} self.out_clusters = {} self._cluster_attr = {} + self.member_of = {} self.status = 1 self.manufacturer = manufacturer self.model = model @@ -45,6 +48,19 @@ class FakeEndpoint: patch_cluster(cluster) self.out_clusters[cluster_id] = cluster + @property + def __class__(self): + """Fake being Zigpy endpoint.""" + return zigpy_ep + + @property + def unique_id(self): + """Return the unique id for the endpoint.""" + return self.device.ieee, self.endpoint_id + + +FakeEndpoint.add_to_group = zigpy_ep.add_to_group + def patch_cluster(cluster): """Patch a cluster for testing.""" @@ -56,17 +72,19 @@ def patch_cluster(cluster): cluster.read_attributes_raw = Mock() cluster.unbind = CoroutineMock(return_value=[0]) cluster.write_attributes = CoroutineMock(return_value=[0]) + if cluster.cluster_id == 4: + cluster.add = CoroutineMock(return_value=[0]) class FakeDevice: """Fake device for mocking zigpy.""" - def __init__(self, app, ieee, manufacturer, model, node_desc=None): + def __init__(self, app, ieee, manufacturer, model, node_desc=None, nwk=0xB79C): """Init fake device.""" self._application = app self.application = app self.ieee = zigpy.types.EUI64.convert(ieee) - self.nwk = 0xB79C + self.nwk = nwk self.zdo = Mock() self.endpoints = {0: self.zdo} self.lqi = 255 @@ -78,13 +96,15 @@ class FakeDevice: self.manufacturer = manufacturer self.model = model self.node_desc = zigpy.zdo.types.NodeDescriptor() - self.add_to_group = CoroutineMock() self.remove_from_group = CoroutineMock() if node_desc is None: node_desc = b"\x02@\x807\x10\x7fd\x00\x00*d\x00\x00" self.node_desc = zigpy.zdo.types.NodeDescriptor.deserialize(node_desc)[0] +FakeDevice.add_to_group = zigpy_dev.add_to_group + + def get_zha_gateway(hass): """Return ZHA gateway from hass.data.""" try: @@ -137,6 +157,19 @@ async def find_entity_id(domain, zha_device, hass): return None +def async_find_group_entity_id(hass, domain, group): + """Find the group entity id under test.""" + entity_id = ( + f"{domain}.{group.name.lower().replace(' ','_')}_group_0x{group.group_id:04x}" + ) + + entity_ids = hass.states.async_entity_ids(domain) + + if entity_id in entity_ids: + return entity_id + return None + + async def async_enable_traffic(hass, zha_devices): """Allow traffic to flow through the gateway and the zha device.""" for zha_device in zha_devices: diff --git a/tests/components/zha/conftest.py b/tests/components/zha/conftest.py index e6056428db6..b83db53533c 100644 --- a/tests/components/zha/conftest.py +++ b/tests/components/zha/conftest.py @@ -110,10 +110,11 @@ def zigpy_device_mock(zigpy_app_controller): manufacturer="FakeManufacturer", model="FakeModel", node_descriptor=b"\x02@\x807\x10\x7fd\x00\x00*d\x00\x00", + nwk=0xB79C, ): """Make a fake device using the specified cluster classes.""" device = FakeDevice( - zigpy_app_controller, ieee, manufacturer, model, node_descriptor + zigpy_app_controller, ieee, manufacturer, model, node_descriptor, nwk=nwk ) for epid, ep in endpoints.items(): endpoint = FakeEndpoint(manufacturer, model, epid) diff --git a/tests/components/zha/test_gateway.py b/tests/components/zha/test_gateway.py index 74aed6f5872..3bb98522814 100644 --- a/tests/components/zha/test_gateway.py +++ b/tests/components/zha/test_gateway.py @@ -1,8 +1,18 @@ """Test ZHA Gateway.""" -import pytest -import zigpy.zcl.clusters.general as general +import logging -from .common import async_enable_traffic, get_zha_gateway +import pytest +import zigpy.profiles.zha as zha +import zigpy.zcl.clusters.general as general +import zigpy.zcl.clusters.lighting as lighting + +from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN + +from .common import async_enable_traffic, async_find_group_entity_id, get_zha_gateway + +IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8" +IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8" +_LOGGER = logging.getLogger(__name__) @pytest.fixture @@ -15,7 +25,7 @@ def zigpy_dev_basic(zigpy_device_mock): "out_clusters": [], "device_type": 0, } - }, + } ) @@ -27,6 +37,74 @@ async def zha_dev_basic(hass, zha_device_restored, zigpy_dev_basic): return zha_device +@pytest.fixture +async def coordinator(hass, zigpy_device_mock, zha_device_joined): + """Test zha light platform.""" + + zigpy_device = zigpy_device_mock( + { + 1: { + "in_clusters": [], + "out_clusters": [], + "device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT, + } + }, + ieee="00:15:8d:00:02:32:4f:32", + nwk=0x0000, + ) + zha_device = await zha_device_joined(zigpy_device) + zha_device.set_available(True) + return zha_device + + +@pytest.fixture +async def device_light_1(hass, zigpy_device_mock, zha_device_joined): + """Test zha light platform.""" + + zigpy_device = zigpy_device_mock( + { + 1: { + "in_clusters": [ + general.OnOff.cluster_id, + general.LevelControl.cluster_id, + lighting.Color.cluster_id, + general.Groups.cluster_id, + ], + "out_clusters": [], + "device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT, + } + }, + ieee=IEEE_GROUPABLE_DEVICE, + ) + zha_device = await zha_device_joined(zigpy_device) + zha_device.set_available(True) + return zha_device + + +@pytest.fixture +async def device_light_2(hass, zigpy_device_mock, zha_device_joined): + """Test zha light platform.""" + + zigpy_device = zigpy_device_mock( + { + 1: { + "in_clusters": [ + general.OnOff.cluster_id, + general.LevelControl.cluster_id, + lighting.Color.cluster_id, + general.Groups.cluster_id, + ], + "out_clusters": [], + "device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT, + } + }, + ieee=IEEE_GROUPABLE_DEVICE2, + ) + zha_device = await zha_device_joined(zigpy_device) + zha_device.set_available(True) + return zha_device + + async def test_device_left(hass, zigpy_dev_basic, zha_dev_basic): """Device leaving the network should become unavailable.""" @@ -37,3 +115,57 @@ async def test_device_left(hass, zigpy_dev_basic, zha_dev_basic): get_zha_gateway(hass).device_left(zigpy_dev_basic) assert zha_dev_basic.available is False + + +async def test_gateway_group_methods(hass, device_light_1, device_light_2, coordinator): + """Test creating a group with 2 members.""" + zha_gateway = get_zha_gateway(hass) + assert zha_gateway is not None + zha_gateway.coordinator_zha_device = coordinator + coordinator._zha_gateway = zha_gateway + device_light_1._zha_gateway = zha_gateway + device_light_2._zha_gateway = zha_gateway + member_ieee_addresses = [device_light_1.ieee, device_light_2.ieee] + + # test creating a group with 2 members + zha_group = await zha_gateway.async_create_zigpy_group( + "Test Group", member_ieee_addresses + ) + await hass.async_block_till_done() + + assert zha_group is not None + assert zha_group.entity_domain == LIGHT_DOMAIN + assert len(zha_group.members) == 2 + for member in zha_group.members: + assert member.ieee in member_ieee_addresses + + entity_id = async_find_group_entity_id(hass, LIGHT_DOMAIN, zha_group) + assert hass.states.get(entity_id) is not None + + # test get group by name + assert zha_group == zha_gateway.async_get_group_by_name(zha_group.name) + + # test removing a group + await zha_gateway.async_remove_zigpy_group(zha_group.group_id) + await hass.async_block_till_done() + + # we shouldn't have the group anymore + assert zha_gateway.async_get_group_by_name(zha_group.name) is None + + # the group entity should be cleaned up + assert entity_id not in hass.states.async_entity_ids(LIGHT_DOMAIN) + + # test creating a group with 1 member + zha_group = await zha_gateway.async_create_zigpy_group( + "Test Group", [device_light_1.ieee] + ) + await hass.async_block_till_done() + + assert zha_group is not None + assert zha_group.entity_domain is None + assert len(zha_group.members) == 1 + for member in zha_group.members: + assert member.ieee in [device_light_1.ieee] + + # the group entity should not have been cleaned up + assert entity_id not in hass.states.async_entity_ids(LIGHT_DOMAIN) diff --git a/tests/components/zha/test_light.py b/tests/components/zha/test_light.py index f27bd329bdb..c6bafa45aea 100644 --- a/tests/components/zha/test_light.py +++ b/tests/components/zha/test_light.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock, call, sentinel from asynctest import CoroutineMock, patch import pytest -import zigpy.profiles.zha +import zigpy.profiles.zha as zha import zigpy.types import zigpy.zcl.clusters.general as general import zigpy.zcl.clusters.lighting as lighting @@ -17,8 +17,10 @@ import homeassistant.util.dt as dt_util from .common import ( async_enable_traffic, + async_find_group_entity_id, async_test_rejoin, find_entity_id, + get_zha_gateway, send_attributes_report, ) @@ -26,6 +28,8 @@ from tests.common import async_fire_time_changed ON = 1 OFF = 0 +IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8" +IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8" LIGHT_ON_OFF = { 1: { @@ -66,6 +70,76 @@ LIGHT_COLOR = { } +@pytest.fixture +async def coordinator(hass, zigpy_device_mock, zha_device_joined): + """Test zha light platform.""" + + zigpy_device = zigpy_device_mock( + { + 1: { + "in_clusters": [], + "out_clusters": [], + "device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT, + } + }, + ieee="00:15:8d:00:02:32:4f:32", + nwk=0x0000, + ) + zha_device = await zha_device_joined(zigpy_device) + zha_device.set_available(True) + return zha_device + + +@pytest.fixture +async def device_light_1(hass, zigpy_device_mock, zha_device_joined): + """Test zha light platform.""" + + zigpy_device = zigpy_device_mock( + { + 1: { + "in_clusters": [ + general.OnOff.cluster_id, + general.LevelControl.cluster_id, + lighting.Color.cluster_id, + general.Groups.cluster_id, + general.Identify.cluster_id, + ], + "out_clusters": [], + "device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT, + } + }, + ieee=IEEE_GROUPABLE_DEVICE, + ) + zha_device = await zha_device_joined(zigpy_device) + zha_device.set_available(True) + return zha_device + + +@pytest.fixture +async def device_light_2(hass, zigpy_device_mock, zha_device_joined): + """Test zha light platform.""" + + zigpy_device = zigpy_device_mock( + { + 1: { + "in_clusters": [ + general.OnOff.cluster_id, + general.LevelControl.cluster_id, + lighting.Color.cluster_id, + general.Groups.cluster_id, + general.Identify.cluster_id, + ], + "out_clusters": [], + "device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT, + } + }, + ieee=IEEE_GROUPABLE_DEVICE2, + ) + zha_device = await zha_device_joined(zigpy_device) + zha_device.set_available(True) + return zha_device + + @patch("zigpy.zcl.clusters.general.OnOff.read_attributes", new=MagicMock()) async def test_light_refresh(hass, zigpy_device_mock, zha_device_joined_restored): """Test zha light platform refresh.""" @@ -337,3 +411,96 @@ async def async_test_flash_from_hass(hass, cluster, entity_id, flash): manufacturer=None, tsn=None, ) + + +async def async_test_zha_group_light_entity( + hass, device_light_1, device_light_2, coordinator +): + """Test the light entity for a ZHA group.""" + zha_gateway = get_zha_gateway(hass) + assert zha_gateway is not None + zha_gateway.coordinator_zha_device = coordinator + coordinator._zha_gateway = zha_gateway + device_light_1._zha_gateway = zha_gateway + device_light_2._zha_gateway = zha_gateway + member_ieee_addresses = [device_light_1.ieee, device_light_2.ieee] + + # test creating a group with 2 members + zha_group = await zha_gateway.async_create_zigpy_group( + "Test Group", member_ieee_addresses + ) + await hass.async_block_till_done() + + assert zha_group is not None + assert zha_group.entity_domain == DOMAIN + assert len(zha_group.members) == 2 + for member in zha_group.members: + assert member.ieee in member_ieee_addresses + + entity_id = async_find_group_entity_id(hass, DOMAIN, zha_group) + assert hass.states.get(entity_id) is not None + + group_cluster_on_off = zha_group.endpoint[general.OnOff.cluster_id] + group_cluster_level = zha_group.endpoint[general.LevelControl.cluster_id] + group_cluster_identify = zha_group.endpoint[general.Identify.cluster_id] + + dev1_cluster_on_off = device_light_1.endpoints[1].on_off + dev2_cluster_on_off = device_light_2.endpoints[1].on_off + + # test that the lights were created and that they are unavailable + assert hass.states.get(entity_id).state == STATE_UNAVAILABLE + + # allow traffic to flow through the gateway and device + await async_enable_traffic(hass, zha_group.members) + + # test that the lights were created and are off + assert hass.states.get(entity_id).state == STATE_OFF + + # test turning the lights on and off from the light + await async_test_on_off_from_light(hass, group_cluster_on_off, entity_id) + + # test turning the lights on and off from the HA + await async_test_on_off_from_hass(hass, group_cluster_on_off, entity_id) + + # test short flashing the lights from the HA + await async_test_flash_from_hass( + hass, group_cluster_identify, entity_id, FLASH_SHORT + ) + + # test turning the lights on and off from the HA + await async_test_level_on_off_from_hass( + hass, group_cluster_on_off, group_cluster_level, entity_id + ) + + # test getting a brightness change from the network + await async_test_on_from_light(hass, group_cluster_on_off, entity_id) + await async_test_dimmer_from_light( + hass, group_cluster_level, entity_id, 150, STATE_ON + ) + + # test long flashing the lights from the HA + await async_test_flash_from_hass( + hass, group_cluster_identify, entity_id, FLASH_LONG + ) + + # test some of the group logic to make sure we key off states correctly + await dev1_cluster_on_off.on() + await dev2_cluster_on_off.on() + + # test that group light is on + assert hass.states.get(entity_id).state == STATE_ON + + await dev1_cluster_on_off.off() + + # test that group light is still on + assert hass.states.get(entity_id).state == STATE_ON + + await dev2_cluster_on_off.off() + + # test that group light is now off + assert hass.states.get(entity_id).state == STATE_OFF + + await dev1_cluster_on_off.on() + + # test that group light is now back on + assert hass.states.get(entity_id).state == STATE_ON