Add group entity support to ZHA (#33196)
* split entity into base and entity * add initial light group support * add dispatching of groups to light * added zha group object * add group event listener * add and remove group members * get group by name * fix rebase * fix rebase * use group_id for unique_id * get entities from registry * use group name * update entity domain * update zha storage to handle groups * dispatch group entities * update light group * fix group remove and dispatch light group entities * allow picking the domain for group entities * beginning - auto determine entity domain * move methods to helpers so they can be shared * fix rebase * remove double init groups... again * cleanup startup * use asyncio create task * group entity discovery * add logging and fix group name * add logging and update group after probe if needed * test add group via gateway * add method to get group entity ids * update storage * test get group by name * update storage on remove * test group with single member * add light group tests * test some light group logic * type hints * fix tests and cleanup * revert init changes except for create task * remove group entity domain changing for now * add missing import * tricky code saving * review comments * clean up class defs * cleanup * fix rebase because I cant read * make pylint happypull/33246/head
parent
3ee05ad4bb
commit
2a3c94bad0
|
@ -130,7 +130,7 @@ async def async_setup_entry(hass, config_entry):
|
|||
await zha_data[DATA_ZHA_GATEWAY].async_update_device_storage()
|
||||
|
||||
hass.bus.async_listen_once(ha_const.EVENT_HOMEASSISTANT_STOP, async_zha_shutdown)
|
||||
asyncio.create_task(async_load_entities(hass, config_entry))
|
||||
asyncio.create_task(async_load_entities(hass))
|
||||
return True
|
||||
|
||||
|
||||
|
@ -150,11 +150,9 @@ async def async_unload_entry(hass, config_entry):
|
|||
return True
|
||||
|
||||
|
||||
async def async_load_entities(
|
||||
hass: HomeAssistantType, config_entry: config_entries.ConfigEntry
|
||||
) -> None:
|
||||
async def async_load_entities(hass: HomeAssistantType) -> None:
|
||||
"""Load entities after integration was setup."""
|
||||
await hass.data[DATA_ZHA][DATA_ZHA_GATEWAY].async_prepare_entities()
|
||||
await hass.data[DATA_ZHA][DATA_ZHA_GATEWAY].async_initialize_devices_and_entities()
|
||||
to_setup = hass.data[DATA_ZHA][DATA_ZHA_PLATFORM_LOADED]
|
||||
results = await asyncio.gather(*to_setup, return_exceptions=True)
|
||||
for res in results:
|
||||
|
|
|
@ -23,6 +23,7 @@ ATTR_COMMAND_TYPE = "command_type"
|
|||
ATTR_DEVICE_IEEE = "device_ieee"
|
||||
ATTR_DEVICE_TYPE = "device_type"
|
||||
ATTR_ENDPOINT_ID = "endpoint_id"
|
||||
ATTR_ENTITY_DOMAIN = "entity_domain"
|
||||
ATTR_IEEE = "ieee"
|
||||
ATTR_LAST_SEEN = "last_seen"
|
||||
ATTR_LEVEL = "level"
|
||||
|
@ -207,6 +208,7 @@ SIGNAL_REMOVE = "remove"
|
|||
SIGNAL_SET_LEVEL = "set_level"
|
||||
SIGNAL_STATE_ATTR = "update_state_attribute"
|
||||
SIGNAL_UPDATE_DEVICE = "{}_zha_update_device"
|
||||
SIGNAL_REMOVE_GROUP = "remove_group"
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
UNKNOWN_MANUFACTURER = "unk_manufacturer"
|
||||
|
|
|
@ -373,7 +373,7 @@ class ZHADevice(LogMixin):
|
|||
self.debug("started configuration")
|
||||
await self._channels.async_configure()
|
||||
self.debug("completed configuration")
|
||||
entry = self.gateway.zha_storage.async_create_or_update(self)
|
||||
entry = self.gateway.zha_storage.async_create_or_update_device(self)
|
||||
self.debug("stored in registry: %s", entry)
|
||||
|
||||
if self._channels.identify_ch is not None:
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
"""Device discovery functions for Zigbee Home Automation."""
|
||||
|
||||
from collections import Counter
|
||||
import logging
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
from homeassistant import const as ha_const
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.entity_registry import async_entries_for_device
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
|
||||
from . import const as zha_const, registries as zha_regs, typing as zha_typing
|
||||
|
@ -157,4 +159,102 @@ class ProbeEndpoint:
|
|||
self._device_configs.update(overrides)
|
||||
|
||||
|
||||
class GroupProbe:
|
||||
"""Determine the appropriate component for a group."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize instance."""
|
||||
self._hass = None
|
||||
|
||||
def initialize(self, hass: HomeAssistantType) -> None:
|
||||
"""Initialize the group probe."""
|
||||
self._hass = hass
|
||||
|
||||
@callback
|
||||
def discover_group_entities(self, group: zha_typing.ZhaGroupType) -> None:
|
||||
"""Process a group and create any entities that are needed."""
|
||||
# only create a group entity if there are 2 or more members in a group
|
||||
if len(group.members) < 2:
|
||||
_LOGGER.debug(
|
||||
"Group: %s:0x%04x has less than 2 members - skipping entity discovery",
|
||||
group.name,
|
||||
group.group_id,
|
||||
)
|
||||
return
|
||||
|
||||
if group.entity_domain is None:
|
||||
_LOGGER.debug(
|
||||
"Group: %s:0x%04x has no user set entity domain - attempting entity domain discovery",
|
||||
group.name,
|
||||
group.group_id,
|
||||
)
|
||||
group.entity_domain = GroupProbe.determine_default_entity_domain(
|
||||
self._hass, group
|
||||
)
|
||||
|
||||
if group.entity_domain is None:
|
||||
return
|
||||
|
||||
_LOGGER.debug(
|
||||
"Group: %s:0x%04x has an entity domain of: %s after discovery",
|
||||
group.name,
|
||||
group.group_id,
|
||||
group.entity_domain,
|
||||
)
|
||||
|
||||
zha_gateway = self._hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY]
|
||||
entity_class = zha_regs.ZHA_ENTITIES.get_group_entity(group.entity_domain)
|
||||
if entity_class is None:
|
||||
return
|
||||
|
||||
self._hass.data[zha_const.DATA_ZHA][group.entity_domain].append(
|
||||
(
|
||||
entity_class,
|
||||
(
|
||||
group.domain_entity_ids,
|
||||
f"{group.entity_domain}_group_{group.group_id}",
|
||||
group.group_id,
|
||||
zha_gateway.coordinator_zha_device,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def determine_default_entity_domain(
|
||||
hass: HomeAssistantType, group: zha_typing.ZhaGroupType
|
||||
):
|
||||
"""Determine the default entity domain for this group."""
|
||||
if len(group.members) < 2:
|
||||
_LOGGER.debug(
|
||||
"Group: %s:0x%04x has less than 2 members so cannot default an entity domain",
|
||||
group.name,
|
||||
group.group_id,
|
||||
)
|
||||
return None
|
||||
|
||||
zha_gateway = hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY]
|
||||
all_domain_occurrences = []
|
||||
for device in group.members:
|
||||
entities = async_entries_for_device(
|
||||
zha_gateway.ha_entity_registry, device.device_id
|
||||
)
|
||||
all_domain_occurrences.extend(
|
||||
[
|
||||
entity.domain
|
||||
for entity in entities
|
||||
if entity.domain in zha_regs.GROUP_ENTITY_DOMAINS
|
||||
]
|
||||
)
|
||||
counts = Counter(all_domain_occurrences)
|
||||
domain = counts.most_common(1)[0][0]
|
||||
_LOGGER.debug(
|
||||
"The default entity domain is: %s for group: %s:0x%04x",
|
||||
domain,
|
||||
group.name,
|
||||
group.group_id,
|
||||
)
|
||||
return domain
|
||||
|
||||
|
||||
PROBE = ProbeEndpoint()
|
||||
GROUP_PROBE = GroupProbe()
|
||||
|
|
|
@ -6,6 +6,7 @@ import itertools
|
|||
import logging
|
||||
import os
|
||||
import traceback
|
||||
from typing import List, Optional
|
||||
|
||||
from serial import SerialException
|
||||
import zigpy.device as zigpy_dev
|
||||
|
@ -52,6 +53,7 @@ from .const import (
|
|||
DOMAIN,
|
||||
SIGNAL_ADD_ENTITIES,
|
||||
SIGNAL_REMOVE,
|
||||
SIGNAL_REMOVE_GROUP,
|
||||
UNKNOWN_MANUFACTURER,
|
||||
UNKNOWN_MODEL,
|
||||
ZHA_GW_MSG,
|
||||
|
@ -75,6 +77,7 @@ from .group import ZHAGroup
|
|||
from .patches import apply_application_controller_patch
|
||||
from .registries import RADIO_TYPES
|
||||
from .store import async_get_registry
|
||||
from .typing import ZhaDeviceType, ZhaGroupType, ZigpyEndpointType, ZigpyGroupType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -93,6 +96,7 @@ class ZHAGateway:
|
|||
self._config = config
|
||||
self._devices = {}
|
||||
self._groups = {}
|
||||
self.coordinator_zha_device = None
|
||||
self._device_registry = collections.defaultdict(list)
|
||||
self.zha_storage = None
|
||||
self.ha_device_registry = None
|
||||
|
@ -110,6 +114,7 @@ class ZHAGateway:
|
|||
async def async_initialize(self):
|
||||
"""Initialize controller and connect radio."""
|
||||
discovery.PROBE.initialize(self._hass)
|
||||
discovery.GROUP_PROBE.initialize(self._hass)
|
||||
|
||||
self.zha_storage = await async_get_registry(self._hass)
|
||||
self.ha_device_registry = await get_dev_reg(self._hass)
|
||||
|
@ -156,17 +161,29 @@ class ZHAGateway:
|
|||
self._hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID] = str(
|
||||
self.application_controller.ieee
|
||||
)
|
||||
await self.async_load_devices()
|
||||
self._initialize_groups()
|
||||
self.async_load_devices()
|
||||
self.async_load_groups()
|
||||
|
||||
async def async_load_devices(self) -> None:
|
||||
@callback
|
||||
def async_load_devices(self) -> None:
|
||||
"""Restore ZHA devices from zigpy application state."""
|
||||
zigpy_devices = self.application_controller.devices.values()
|
||||
for zigpy_device in zigpy_devices:
|
||||
self._async_get_or_create_device(zigpy_device, restored=True)
|
||||
zha_device = self._async_get_or_create_device(zigpy_device, restored=True)
|
||||
if zha_device.nwk == 0x0000:
|
||||
self.coordinator_zha_device = zha_device
|
||||
|
||||
async def async_prepare_entities(self) -> None:
|
||||
"""Prepare entities by initializing device channels."""
|
||||
@callback
|
||||
def async_load_groups(self) -> None:
|
||||
"""Initialize ZHA groups."""
|
||||
for group_id in self.application_controller.groups:
|
||||
group = self.application_controller.groups[group_id]
|
||||
zha_group = self._async_get_or_create_group(group)
|
||||
# we can do this here because the entities are in the entity registry tied to the devices
|
||||
discovery.GROUP_PROBE.discover_group_entities(zha_group)
|
||||
|
||||
async def async_initialize_devices_and_entities(self) -> None:
|
||||
"""Initialize devices and load entities."""
|
||||
semaphore = asyncio.Semaphore(2)
|
||||
|
||||
async def _throttle(zha_device: zha_typing.ZhaDeviceType, cached: bool):
|
||||
|
@ -231,35 +248,44 @@ class ZHAGateway:
|
|||
"""Handle device leaving the network."""
|
||||
self.async_update_device(device, False)
|
||||
|
||||
def group_member_removed(self, zigpy_group, endpoint):
|
||||
def group_member_removed(
|
||||
self, zigpy_group: ZigpyGroupType, endpoint: ZigpyEndpointType
|
||||
) -> None:
|
||||
"""Handle zigpy group member removed event."""
|
||||
# need to handle endpoint correctly on groups
|
||||
zha_group = self._async_get_or_create_group(zigpy_group)
|
||||
zha_group.info("group_member_removed - endpoint: %s", endpoint)
|
||||
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_REMOVED)
|
||||
|
||||
def group_member_added(self, zigpy_group, endpoint):
|
||||
def group_member_added(
|
||||
self, zigpy_group: ZigpyGroupType, endpoint: ZigpyEndpointType
|
||||
) -> None:
|
||||
"""Handle zigpy group member added event."""
|
||||
# need to handle endpoint correctly on groups
|
||||
zha_group = self._async_get_or_create_group(zigpy_group)
|
||||
zha_group.info("group_member_added - endpoint: %s", endpoint)
|
||||
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_ADDED)
|
||||
|
||||
def group_added(self, zigpy_group):
|
||||
def group_added(self, zigpy_group: ZigpyGroupType) -> None:
|
||||
"""Handle zigpy group added event."""
|
||||
zha_group = self._async_get_or_create_group(zigpy_group)
|
||||
zha_group.info("group_added")
|
||||
# need to dispatch for entity creation here
|
||||
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_ADDED)
|
||||
|
||||
def group_removed(self, zigpy_group):
|
||||
def group_removed(self, zigpy_group: ZigpyGroupType) -> None:
|
||||
"""Handle zigpy group added event."""
|
||||
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_REMOVED)
|
||||
zha_group = self._groups.pop(zigpy_group.group_id, None)
|
||||
zha_group.info("group_removed")
|
||||
async_dispatcher_send(
|
||||
self._hass, f"{SIGNAL_REMOVE_GROUP}_{zigpy_group.group_id}"
|
||||
)
|
||||
|
||||
def _send_group_gateway_message(self, zigpy_group, gateway_message_type):
|
||||
"""Send the gareway event for a zigpy group event."""
|
||||
def _send_group_gateway_message(
|
||||
self, zigpy_group: ZigpyGroupType, gateway_message_type: str
|
||||
) -> None:
|
||||
"""Send the gateway event for a zigpy group event."""
|
||||
zha_group = self._groups.get(zigpy_group.group_id)
|
||||
if zha_group is not None:
|
||||
async_dispatcher_send(
|
||||
|
@ -306,12 +332,12 @@ class ZHAGateway:
|
|||
"""Return ZHADevice for given ieee."""
|
||||
return self._devices.get(ieee)
|
||||
|
||||
def get_group(self, group_id):
|
||||
def get_group(self, group_id: str) -> Optional[ZhaGroupType]:
|
||||
"""Return Group for given group id."""
|
||||
return self.groups.get(group_id)
|
||||
|
||||
@callback
|
||||
def async_get_group_by_name(self, group_name):
|
||||
def async_get_group_by_name(self, group_name: str) -> Optional[ZhaGroupType]:
|
||||
"""Get ZHA group by name."""
|
||||
for group in self.groups.values():
|
||||
if group.name == group_name:
|
||||
|
@ -390,12 +416,6 @@ class ZHAGateway:
|
|||
logging.getLogger(logger_name).removeHandler(self._log_relay_handler)
|
||||
self.debug_enabled = False
|
||||
|
||||
def _initialize_groups(self):
|
||||
"""Initialize ZHA groups."""
|
||||
for group_id in self.application_controller.groups:
|
||||
group = self.application_controller.groups[group_id]
|
||||
self._async_get_or_create_group(group)
|
||||
|
||||
@callback
|
||||
def _async_get_or_create_device(
|
||||
self, zigpy_device: zha_typing.ZigpyDeviceType, restored: bool = False
|
||||
|
@ -414,17 +434,19 @@ class ZHAGateway:
|
|||
model=zha_device.model,
|
||||
)
|
||||
zha_device.set_device_id(device_registry_device.id)
|
||||
entry = self.zha_storage.async_get_or_create(zha_device)
|
||||
entry = self.zha_storage.async_get_or_create_device(zha_device)
|
||||
zha_device.async_update_last_seen(entry.last_seen)
|
||||
return zha_device
|
||||
|
||||
@callback
|
||||
def _async_get_or_create_group(self, zigpy_group):
|
||||
def _async_get_or_create_group(self, zigpy_group: ZigpyGroupType) -> ZhaGroupType:
|
||||
"""Get or create a ZHA group."""
|
||||
zha_group = self._groups.get(zigpy_group.group_id)
|
||||
if zha_group is None:
|
||||
zha_group = ZHAGroup(self._hass, self, zigpy_group)
|
||||
self._groups[zigpy_group.group_id] = zha_group
|
||||
group_entry = self.zha_storage.async_get_or_create_group(zha_group)
|
||||
zha_group.entity_domain = group_entry.entity_domain
|
||||
return zha_group
|
||||
|
||||
@callback
|
||||
|
@ -446,7 +468,9 @@ class ZHAGateway:
|
|||
async def async_update_device_storage(self):
|
||||
"""Update the devices in the store."""
|
||||
for device in self.devices.values():
|
||||
self.zha_storage.async_update(device)
|
||||
self.zha_storage.async_update_device(device)
|
||||
for group in self.groups.values():
|
||||
self.zha_storage.async_update_group(group)
|
||||
await self.zha_storage.async_save()
|
||||
|
||||
async def async_device_initialized(self, device: zha_typing.ZigpyDeviceType):
|
||||
|
@ -494,25 +518,6 @@ class ZHAGateway:
|
|||
zha_device.update_available(True)
|
||||
async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES)
|
||||
|
||||
# only public for testing
|
||||
async def async_device_restored(self, device: zha_typing.ZigpyDeviceType):
|
||||
"""Add an existing device to the ZHA zigbee network when ZHA first starts."""
|
||||
zha_device = self._async_get_or_create_device(device, restored=True)
|
||||
|
||||
if zha_device.is_mains_powered:
|
||||
# the device isn't a battery powered device so we should be able
|
||||
# to update it now
|
||||
_LOGGER.debug(
|
||||
"attempting to request fresh state for device - %s:%s %s with power source %s",
|
||||
zha_device.nwk,
|
||||
zha_device.ieee,
|
||||
zha_device.name,
|
||||
zha_device.power_source,
|
||||
)
|
||||
await zha_device.async_initialize(from_cache=False)
|
||||
else:
|
||||
await zha_device.async_initialize(from_cache=True)
|
||||
|
||||
async def _async_device_rejoined(self, zha_device):
|
||||
_LOGGER.debug(
|
||||
"skipping discovery for previously discovered device - %s:%s",
|
||||
|
@ -524,7 +529,9 @@ class ZHAGateway:
|
|||
# will cause async_init to fire so don't explicitly call it
|
||||
zha_device.update_available(True)
|
||||
|
||||
async def async_create_zigpy_group(self, name, members):
|
||||
async def async_create_zigpy_group(
|
||||
self, name: str, members: List[ZhaDeviceType]
|
||||
) -> ZhaGroupType:
|
||||
"""Create a new Zigpy Zigbee group."""
|
||||
# we start with one to fill any gaps from a user removing existing groups
|
||||
group_id = 1
|
||||
|
@ -537,24 +544,40 @@ class ZHAGateway:
|
|||
if members is not None:
|
||||
tasks = []
|
||||
for ieee in members:
|
||||
_LOGGER.debug(
|
||||
"Adding member with IEEE: %s to group: %s:0x%04x",
|
||||
ieee,
|
||||
name,
|
||||
group_id,
|
||||
)
|
||||
tasks.append(self.devices[ieee].async_add_to_group(group_id))
|
||||
await asyncio.gather(*tasks)
|
||||
return self.groups.get(group_id)
|
||||
zha_group = self.groups.get(group_id)
|
||||
_LOGGER.debug(
|
||||
"Probing group: %s:0x%04x for entity discovery",
|
||||
zha_group.name,
|
||||
zha_group.group_id,
|
||||
)
|
||||
discovery.GROUP_PROBE.discover_group_entities(zha_group)
|
||||
if zha_group.entity_domain is not None:
|
||||
self.zha_storage.async_update_group(zha_group)
|
||||
async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES)
|
||||
return zha_group
|
||||
|
||||
async def async_remove_zigpy_group(self, group_id):
|
||||
async def async_remove_zigpy_group(self, group_id: int) -> None:
|
||||
"""Remove a Zigbee group from Zigpy."""
|
||||
group = self.groups.get(group_id)
|
||||
if not group:
|
||||
_LOGGER.debug("Group: %s:0x%04x could not be found", group.name, group_id)
|
||||
return
|
||||
if group and group.members:
|
||||
tasks = []
|
||||
for member in group.members:
|
||||
tasks.append(member.async_remove_from_group(group_id))
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
else:
|
||||
# we have members but none are tracked by ZHA for whatever reason
|
||||
self.application_controller.groups.pop(group_id)
|
||||
else:
|
||||
self.application_controller.groups.pop(group_id)
|
||||
self.application_controller.groups.pop(group_id)
|
||||
self.zha_storage.async_delete_group(group)
|
||||
|
||||
async def shutdown(self):
|
||||
"""Stop ZHA Controller Application."""
|
||||
|
|
|
@ -1,10 +1,16 @@
|
|||
"""Group for Zigbee Home Automation."""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from zigpy.types.named import EUI64
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.entity_registry import async_entries_for_device
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
|
||||
from .helpers import LogMixin
|
||||
from .typing import ZhaDeviceType, ZhaGatewayType, ZigpyEndpointType, ZigpyGroupType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -12,29 +18,45 @@ _LOGGER = logging.getLogger(__name__)
|
|||
class ZHAGroup(LogMixin):
|
||||
"""ZHA Zigbee group object."""
|
||||
|
||||
def __init__(self, hass, zha_gateway, zigpy_group):
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistantType,
|
||||
zha_gateway: ZhaGatewayType,
|
||||
zigpy_group: ZigpyGroupType,
|
||||
):
|
||||
"""Initialize the group."""
|
||||
self.hass = hass
|
||||
self._zigpy_group = zigpy_group
|
||||
self._zha_gateway = zha_gateway
|
||||
self.hass: HomeAssistantType = hass
|
||||
self._zigpy_group: ZigpyGroupType = zigpy_group
|
||||
self._zha_gateway: ZhaGatewayType = zha_gateway
|
||||
self._entity_domain: str = None
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
"""Return group name."""
|
||||
return self._zigpy_group.name
|
||||
|
||||
@property
|
||||
def group_id(self):
|
||||
def group_id(self) -> int:
|
||||
"""Return group name."""
|
||||
return self._zigpy_group.group_id
|
||||
|
||||
@property
|
||||
def endpoint(self):
|
||||
def endpoint(self) -> ZigpyEndpointType:
|
||||
"""Return the endpoint for this group."""
|
||||
return self._zigpy_group.endpoint
|
||||
|
||||
@property
|
||||
def members(self):
|
||||
def entity_domain(self) -> Optional[str]:
|
||||
"""Return the domain that will be used for the entity representing this group."""
|
||||
return self._entity_domain
|
||||
|
||||
@entity_domain.setter
|
||||
def entity_domain(self, domain: Optional[str]) -> None:
|
||||
"""Set the domain that will be used for the entity representing this group."""
|
||||
self._entity_domain = domain
|
||||
|
||||
@property
|
||||
def members(self) -> List[ZhaDeviceType]:
|
||||
"""Return the ZHA devices that are members of this group."""
|
||||
return [
|
||||
self._zha_gateway.devices.get(member_ieee[0])
|
||||
|
@ -42,7 +64,7 @@ class ZHAGroup(LogMixin):
|
|||
if member_ieee[0] in self._zha_gateway.devices
|
||||
]
|
||||
|
||||
async def async_add_members(self, member_ieee_addresses):
|
||||
async def async_add_members(self, member_ieee_addresses: List[EUI64]) -> None:
|
||||
"""Add members to this group."""
|
||||
if len(member_ieee_addresses) > 1:
|
||||
tasks = []
|
||||
|
@ -56,7 +78,7 @@ class ZHAGroup(LogMixin):
|
|||
member_ieee_addresses[0]
|
||||
].async_add_to_group(self.group_id)
|
||||
|
||||
async def async_remove_members(self, member_ieee_addresses):
|
||||
async def async_remove_members(self, member_ieee_addresses: List[EUI64]) -> None:
|
||||
"""Remove members from this group."""
|
||||
if len(member_ieee_addresses) > 1:
|
||||
tasks = []
|
||||
|
@ -72,18 +94,50 @@ class ZHAGroup(LogMixin):
|
|||
member_ieee_addresses[0]
|
||||
].async_remove_from_group(self.group_id)
|
||||
|
||||
@property
|
||||
def member_entity_ids(self) -> List[str]:
|
||||
"""Return the ZHA entity ids for all entities for the members of this group."""
|
||||
all_entity_ids: List[str] = []
|
||||
for device in self.members:
|
||||
entities = async_entries_for_device(
|
||||
self._zha_gateway.ha_entity_registry, device.device_id
|
||||
)
|
||||
for entity in entities:
|
||||
all_entity_ids.append(entity.entity_id)
|
||||
return all_entity_ids
|
||||
|
||||
@property
|
||||
def domain_entity_ids(self) -> List[str]:
|
||||
"""Return entity ids from the entity domain for this group."""
|
||||
if self.entity_domain is None:
|
||||
return
|
||||
domain_entity_ids: List[str] = []
|
||||
for device in self.members:
|
||||
entities = async_entries_for_device(
|
||||
self._zha_gateway.ha_entity_registry, device.device_id
|
||||
)
|
||||
domain_entity_ids.extend(
|
||||
[
|
||||
entity.entity_id
|
||||
for entity in entities
|
||||
if entity.domain == self.entity_domain
|
||||
]
|
||||
)
|
||||
return domain_entity_ids
|
||||
|
||||
@callback
|
||||
def async_get_info(self):
|
||||
def async_get_info(self) -> Dict[str, Any]:
|
||||
"""Get ZHA group info."""
|
||||
group_info = {}
|
||||
group_info: Dict[str, Any] = {}
|
||||
group_info["group_id"] = self.group_id
|
||||
group_info["entity_domain"] = self.entity_domain
|
||||
group_info["name"] = self.name
|
||||
group_info["members"] = [
|
||||
zha_device.async_get_info() for zha_device in self.members
|
||||
]
|
||||
return group_info
|
||||
|
||||
def log(self, level, msg, *args):
|
||||
def log(self, level: int, msg: str, *args):
|
||||
"""Log a message."""
|
||||
msg = f"[%s](%s): {msg}"
|
||||
args = (self.name, self.group_id) + args
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
"""Helpers for Zigbee Home Automation."""
|
||||
import collections
|
||||
import logging
|
||||
from typing import Any, Callable, Iterator, List, Optional
|
||||
|
||||
import zigpy.types
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import State, callback
|
||||
|
||||
from .const import CLUSTER_TYPE_IN, CLUSTER_TYPE_OUT, DATA_ZHA, DATA_ZHA_GATEWAY
|
||||
from .registries import BINDABLE_CLUSTERS
|
||||
|
@ -85,6 +86,45 @@ async def async_get_zha_device(hass, device_id):
|
|||
return zha_gateway.devices[ieee]
|
||||
|
||||
|
||||
def find_state_attributes(states: List[State], key: str) -> Iterator[Any]:
|
||||
"""Find attributes with matching key from states."""
|
||||
for state in states:
|
||||
value = state.attributes.get(key)
|
||||
if value is not None:
|
||||
yield value
|
||||
|
||||
|
||||
def mean_int(*args):
|
||||
"""Return the mean of the supplied values."""
|
||||
return int(sum(args) / len(args))
|
||||
|
||||
|
||||
def mean_tuple(*args):
|
||||
"""Return the mean values along the columns of the supplied values."""
|
||||
return tuple(sum(l) / len(l) for l in zip(*args))
|
||||
|
||||
|
||||
def reduce_attribute(
|
||||
states: List[State],
|
||||
key: str,
|
||||
default: Optional[Any] = None,
|
||||
reduce: Callable[..., Any] = mean_int,
|
||||
) -> Any:
|
||||
"""Find the first attribute matching key from states.
|
||||
|
||||
If none are found, return default.
|
||||
"""
|
||||
attrs = list(find_state_attributes(states, key))
|
||||
|
||||
if not attrs:
|
||||
return default
|
||||
|
||||
if len(attrs) == 1:
|
||||
return attrs[0]
|
||||
|
||||
return reduce(*attrs)
|
||||
|
||||
|
||||
class LogMixin:
|
||||
"""Log helper."""
|
||||
|
||||
|
|
|
@ -32,6 +32,8 @@ from .const import CONTROLLER, ZHA_GW_RADIO, ZHA_GW_RADIO_DESCRIPTION, RadioType
|
|||
from .decorators import CALLABLE_T, DictRegistry, SetRegistry
|
||||
from .typing import ChannelType
|
||||
|
||||
GROUP_ENTITY_DOMAINS = [LIGHT]
|
||||
|
||||
SMARTTHINGS_ACCELERATION_CLUSTER = 0xFC02
|
||||
SMARTTHINGS_ARRIVAL_SENSOR_DEVICE_TYPE = 0x8000
|
||||
SMARTTHINGS_HUMIDITY_CLUSTER = 0xFC45
|
||||
|
@ -275,6 +277,9 @@ RegistryDictType = Dict[
|
|||
] # pylint: disable=invalid-name
|
||||
|
||||
|
||||
GroupRegistryDictType = Dict[str, CALLABLE_T] # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ZHAEntityRegistry:
|
||||
"""Channel to ZHA Entity mapping."""
|
||||
|
||||
|
@ -282,6 +287,7 @@ class ZHAEntityRegistry:
|
|||
"""Initialize Registry instance."""
|
||||
self._strict_registry: RegistryDictType = collections.defaultdict(dict)
|
||||
self._loose_registry: RegistryDictType = collections.defaultdict(dict)
|
||||
self._group_registry: GroupRegistryDictType = {}
|
||||
|
||||
def get_entity(
|
||||
self,
|
||||
|
@ -300,6 +306,10 @@ class ZHAEntityRegistry:
|
|||
|
||||
return default, []
|
||||
|
||||
def get_group_entity(self, component: str) -> CALLABLE_T:
|
||||
"""Match a ZHA group to a ZHA Entity class."""
|
||||
return self._group_registry.get(component)
|
||||
|
||||
def strict_match(
|
||||
self,
|
||||
component: str,
|
||||
|
@ -350,5 +360,15 @@ class ZHAEntityRegistry:
|
|||
|
||||
return decorator
|
||||
|
||||
def group_match(self, component: str) -> Callable[[CALLABLE_T], CALLABLE_T]:
|
||||
"""Decorate a group match rule."""
|
||||
|
||||
def decorator(zha_ent: CALLABLE_T) -> CALLABLE_T:
|
||||
"""Register a group match rule."""
|
||||
self._group_registry[component] = zha_ent
|
||||
return zha_ent
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
ZHA_ENTITIES = ZHAEntityRegistry()
|
||||
|
|
|
@ -10,6 +10,8 @@ from homeassistant.core import callback
|
|||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
from homeassistant.loader import bind_hass
|
||||
|
||||
from .typing import ZhaDeviceType, ZhaGroupType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DATA_REGISTRY = "zha_storage"
|
||||
|
@ -28,52 +30,96 @@ class ZhaDeviceEntry:
|
|||
last_seen = attr.ib(type=float, default=None)
|
||||
|
||||
|
||||
class ZhaDeviceStorage:
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class ZhaGroupEntry:
|
||||
"""Zha Group storage Entry."""
|
||||
|
||||
name = attr.ib(type=str, default=None)
|
||||
group_id = attr.ib(type=int, default=None)
|
||||
entity_domain = attr.ib(type=float, default=None)
|
||||
|
||||
|
||||
class ZhaStorage:
|
||||
"""Class to hold a registry of zha devices."""
|
||||
|
||||
def __init__(self, hass: HomeAssistantType) -> None:
|
||||
"""Initialize the zha device storage."""
|
||||
self.hass = hass
|
||||
self.hass: HomeAssistantType = hass
|
||||
self.devices: MutableMapping[str, ZhaDeviceEntry] = {}
|
||||
self.groups: MutableMapping[str, ZhaGroupEntry] = {}
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
@callback
|
||||
def async_create(self, device) -> ZhaDeviceEntry:
|
||||
def async_create_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry:
|
||||
"""Create a new ZhaDeviceEntry."""
|
||||
device_entry = ZhaDeviceEntry(
|
||||
device_entry: ZhaDeviceEntry = ZhaDeviceEntry(
|
||||
name=device.name, ieee=str(device.ieee), last_seen=device.last_seen
|
||||
)
|
||||
self.devices[device_entry.ieee] = device_entry
|
||||
|
||||
return self.async_update(device)
|
||||
return self.async_update_device(device)
|
||||
|
||||
@callback
|
||||
def async_get_or_create(self, device) -> ZhaDeviceEntry:
|
||||
def async_create_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
|
||||
"""Create a new ZhaGroupEntry."""
|
||||
group_entry: ZhaGroupEntry = ZhaGroupEntry(
|
||||
name=group.name,
|
||||
group_id=str(group.group_id),
|
||||
entity_domain=group.entity_domain,
|
||||
)
|
||||
self.groups[str(group.group_id)] = group_entry
|
||||
return self.async_update_group(group)
|
||||
|
||||
@callback
|
||||
def async_get_or_create_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry:
|
||||
"""Create a new ZhaDeviceEntry."""
|
||||
ieee_str = str(device.ieee)
|
||||
ieee_str: str = str(device.ieee)
|
||||
if ieee_str in self.devices:
|
||||
return self.devices[ieee_str]
|
||||
return self.async_create(device)
|
||||
return self.async_create_device(device)
|
||||
|
||||
@callback
|
||||
def async_create_or_update(self, device) -> ZhaDeviceEntry:
|
||||
def async_get_or_create_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
|
||||
"""Create a new ZhaGroupEntry."""
|
||||
group_id: str = str(group.group_id)
|
||||
if group_id in self.groups:
|
||||
return self.groups[group_id]
|
||||
return self.async_create_group(group)
|
||||
|
||||
@callback
|
||||
def async_create_or_update_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry:
|
||||
"""Create or update a ZhaDeviceEntry."""
|
||||
if str(device.ieee) in self.devices:
|
||||
return self.async_update(device)
|
||||
return self.async_create(device)
|
||||
return self.async_update_device(device)
|
||||
return self.async_create_device(device)
|
||||
|
||||
@callback
|
||||
def async_delete(self, device) -> None:
|
||||
def async_create_or_update_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
|
||||
"""Create or update a ZhaGroupEntry."""
|
||||
if str(group.group_id) in self.groups:
|
||||
return self.async_update_group(group)
|
||||
return self.async_create_group(group)
|
||||
|
||||
@callback
|
||||
def async_delete_device(self, device: ZhaDeviceType) -> None:
|
||||
"""Delete ZhaDeviceEntry."""
|
||||
ieee_str = str(device.ieee)
|
||||
ieee_str: str = str(device.ieee)
|
||||
if ieee_str in self.devices:
|
||||
del self.devices[ieee_str]
|
||||
self.async_schedule_save()
|
||||
|
||||
@callback
|
||||
def async_update(self, device) -> ZhaDeviceEntry:
|
||||
def async_delete_group(self, group: ZhaGroupType) -> None:
|
||||
"""Delete ZhaGroupEntry."""
|
||||
group_id: str = str(group.group_id)
|
||||
if group_id in self.groups:
|
||||
del self.groups[group_id]
|
||||
self.async_schedule_save()
|
||||
|
||||
@callback
|
||||
def async_update_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry:
|
||||
"""Update name of ZhaDeviceEntry."""
|
||||
ieee_str = str(device.ieee)
|
||||
ieee_str: str = str(device.ieee)
|
||||
old = self.devices[ieee_str]
|
||||
|
||||
changes = {}
|
||||
|
@ -83,11 +129,25 @@ class ZhaDeviceStorage:
|
|||
self.async_schedule_save()
|
||||
return new
|
||||
|
||||
@callback
|
||||
def async_update_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
|
||||
"""Update name of ZhaGroupEntry."""
|
||||
group_id: str = str(group.group_id)
|
||||
old = self.groups[group_id]
|
||||
|
||||
changes = {}
|
||||
changes["entity_domain"] = group.entity_domain
|
||||
|
||||
new = self.groups[group_id] = attr.evolve(old, **changes)
|
||||
self.async_schedule_save()
|
||||
return new
|
||||
|
||||
async def async_load(self) -> None:
|
||||
"""Load the registry of zha device entries."""
|
||||
data = await self._store.async_load()
|
||||
|
||||
devices: "OrderedDict[str, ZhaDeviceEntry]" = OrderedDict()
|
||||
groups: "OrderedDict[str, ZhaGroupEntry]" = OrderedDict()
|
||||
|
||||
if data is not None:
|
||||
for device in data["devices"]:
|
||||
|
@ -97,7 +157,18 @@ class ZhaDeviceStorage:
|
|||
last_seen=device["last_seen"] if "last_seen" in device else None,
|
||||
)
|
||||
|
||||
if "groups" in data:
|
||||
for group in data["groups"]:
|
||||
groups[group["group_id"]] = ZhaGroupEntry(
|
||||
name=group["name"],
|
||||
group_id=group["group_id"],
|
||||
entity_domain=group["entity_domain"]
|
||||
if "entity_domain" in group
|
||||
else None,
|
||||
)
|
||||
|
||||
self.devices = devices
|
||||
self.groups = groups
|
||||
|
||||
@callback
|
||||
def async_schedule_save(self) -> None:
|
||||
|
@ -118,21 +189,29 @@ class ZhaDeviceStorage:
|
|||
for entry in self.devices.values()
|
||||
]
|
||||
|
||||
data["groups"] = [
|
||||
{
|
||||
"name": entry.name,
|
||||
"group_id": entry.group_id,
|
||||
"entity_domain": entry.entity_domain,
|
||||
}
|
||||
for entry in self.groups.values()
|
||||
]
|
||||
return data
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_get_registry(hass: HomeAssistantType) -> ZhaDeviceStorage:
|
||||
async def async_get_registry(hass: HomeAssistantType) -> ZhaStorage:
|
||||
"""Return zha device storage instance."""
|
||||
task = hass.data.get(DATA_REGISTRY)
|
||||
|
||||
if task is None:
|
||||
|
||||
async def _load_reg() -> ZhaDeviceStorage:
|
||||
registry = ZhaDeviceStorage(hass)
|
||||
async def _load_reg() -> ZhaStorage:
|
||||
registry = ZhaStorage(hass)
|
||||
await registry.async_load()
|
||||
return registry
|
||||
|
||||
task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg())
|
||||
|
||||
return cast(ZhaDeviceStorage, await task)
|
||||
return cast(ZhaStorage, await task)
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Callable, TypeVar
|
|||
|
||||
import zigpy.device
|
||||
import zigpy.endpoint
|
||||
import zigpy.group
|
||||
import zigpy.zcl
|
||||
import zigpy.zdo
|
||||
|
||||
|
@ -17,9 +18,11 @@ ZDOChannelType = "ZDOChannel"
|
|||
ZhaDeviceType = "ZHADevice"
|
||||
ZhaEntityType = "ZHAEntity"
|
||||
ZhaGatewayType = "ZHAGateway"
|
||||
ZhaGroupType = "ZHAGroupType"
|
||||
ZigpyClusterType = zigpy.zcl.Cluster
|
||||
ZigpyDeviceType = zigpy.device.Device
|
||||
ZigpyEndpointType = zigpy.endpoint.Endpoint
|
||||
ZigpyGroupType = zigpy.group.Group
|
||||
ZigpyZdoType = zigpy.zdo.ZDO
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -38,3 +41,4 @@ if TYPE_CHECKING:
|
|||
ZhaDeviceType = homeassistant.components.zha.core.device.ZHADevice
|
||||
ZhaEntityType = homeassistant.components.zha.entity.ZhaEntity
|
||||
ZhaGatewayType = homeassistant.components.zha.core.gateway.ZHAGateway
|
||||
ZhaGroupType = homeassistant.components.zha.core.group.ZHAGroup
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Awaitable, Dict, List
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers import entity
|
||||
|
@ -20,6 +21,7 @@ from .core.const import (
|
|||
SIGNAL_REMOVE,
|
||||
)
|
||||
from .core.helpers import LogMixin
|
||||
from .core.typing import CALLABLE_T, ChannelsType, ChannelType, ZhaDeviceType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -27,30 +29,24 @@ ENTITY_SUFFIX = "entity_suffix"
|
|||
RESTART_GRACE_PERIOD = 7200 # 2 hours
|
||||
|
||||
|
||||
class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
|
||||
class BaseZhaEntity(RestoreEntity, LogMixin, entity.Entity):
|
||||
"""A base class for ZHA entities."""
|
||||
|
||||
def __init__(self, unique_id, zha_device, channels, skip_entity_id=False, **kwargs):
|
||||
def __init__(self, unique_id: str, zha_device: ZhaDeviceType, **kwargs):
|
||||
"""Init ZHA entity."""
|
||||
self._force_update = False
|
||||
self._should_poll = False
|
||||
self._unique_id = unique_id
|
||||
ieeetail = "".join([f"{o:02x}" for o in zha_device.ieee[:4]])
|
||||
ch_names = [ch.cluster.ep_attribute for ch in channels]
|
||||
ch_names = ", ".join(sorted(ch_names))
|
||||
self._name = f"{zha_device.name} {ieeetail} {ch_names}"
|
||||
self._state = None
|
||||
self._device_state_attributes = {}
|
||||
self._zha_device = zha_device
|
||||
self.cluster_channels = {}
|
||||
self._available = False
|
||||
self._unsubs = []
|
||||
self.remove_future = None
|
||||
for channel in channels:
|
||||
self.cluster_channels[channel.name] = channel
|
||||
self._name: str = ""
|
||||
self._force_update: bool = False
|
||||
self._should_poll: bool = False
|
||||
self._unique_id: str = unique_id
|
||||
self._state: Any = None
|
||||
self._device_state_attributes: Dict[str, Any] = {}
|
||||
self._zha_device: ZhaDeviceType = zha_device
|
||||
self._available: bool = False
|
||||
self._unsubs: List[CALLABLE_T] = []
|
||||
self.remove_future: Awaitable[None] = None
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
"""Return Entity's default name."""
|
||||
return self._name
|
||||
|
||||
|
@ -60,12 +56,12 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
|
|||
return self._unique_id
|
||||
|
||||
@property
|
||||
def zha_device(self):
|
||||
def zha_device(self) -> ZhaDeviceType:
|
||||
"""Return the zha device this entity is attached to."""
|
||||
return self._zha_device
|
||||
|
||||
@property
|
||||
def device_state_attributes(self):
|
||||
def device_state_attributes(self) -> Dict[str, Any]:
|
||||
"""Return device specific state attributes."""
|
||||
return self._device_state_attributes
|
||||
|
||||
|
@ -80,7 +76,7 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
|
|||
return self._should_poll
|
||||
|
||||
@property
|
||||
def device_info(self):
|
||||
def device_info(self) -> Dict[str, Any]:
|
||||
"""Return a device description for device registry."""
|
||||
zha_device_info = self._zha_device.device_info
|
||||
ieee = zha_device_info["ieee"]
|
||||
|
@ -94,31 +90,94 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
|
|||
}
|
||||
|
||||
@property
|
||||
def available(self):
|
||||
def available(self) -> bool:
|
||||
"""Return entity availability."""
|
||||
return self._available
|
||||
|
||||
@callback
|
||||
def async_set_available(self, available):
|
||||
def async_set_available(self, available: bool) -> None:
|
||||
"""Set entity availability."""
|
||||
self._available = available
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def async_update_state_attribute(self, key, value):
|
||||
def async_update_state_attribute(self, key: str, value: Any) -> None:
|
||||
"""Update a single device state attribute."""
|
||||
self._device_state_attributes.update({key: value})
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def async_set_state(self, attr_id, attr_name, value):
|
||||
def async_set_state(self, attr_id: int, attr_name: str, value: Any) -> None:
|
||||
"""Set the entity state."""
|
||||
pass
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
self.remove_future = asyncio.Future()
|
||||
await self.async_accept_signal(
|
||||
None,
|
||||
"{}_{}".format(SIGNAL_REMOVE, str(self.zha_device.ieee)),
|
||||
self.async_remove,
|
||||
signal_override=True,
|
||||
)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Disconnect entity object when removed."""
|
||||
for unsub in self._unsubs[:]:
|
||||
unsub()
|
||||
self._unsubs.remove(unsub)
|
||||
self.zha_device.gateway.remove_entity_reference(self)
|
||||
self.remove_future.set_result(True)
|
||||
|
||||
@callback
|
||||
def async_restore_last_state(self, last_state) -> None:
|
||||
"""Restore previous state."""
|
||||
pass
|
||||
|
||||
async def async_accept_signal(
|
||||
self, channel: ChannelType, signal: str, func: CALLABLE_T, signal_override=False
|
||||
):
|
||||
"""Accept a signal from a channel."""
|
||||
unsub = None
|
||||
if signal_override:
|
||||
unsub = async_dispatcher_connect(self.hass, signal, func)
|
||||
else:
|
||||
unsub = async_dispatcher_connect(
|
||||
self.hass, f"{channel.unique_id}_{signal}", func
|
||||
)
|
||||
self._unsubs.append(unsub)
|
||||
|
||||
def log(self, level: int, msg: str, *args):
|
||||
"""Log a message."""
|
||||
msg = f"%s: {msg}"
|
||||
args = (self.entity_id,) + args
|
||||
_LOGGER.log(level, msg, *args)
|
||||
|
||||
|
||||
class ZhaEntity(BaseZhaEntity):
|
||||
"""A base class for non group ZHA entities."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unique_id: str,
|
||||
zha_device: ZhaDeviceType,
|
||||
channels: ChannelsType,
|
||||
**kwargs,
|
||||
):
|
||||
"""Init ZHA entity."""
|
||||
super().__init__(unique_id, zha_device, **kwargs)
|
||||
ieeetail = "".join([f"{o:02x}" for o in zha_device.ieee[:4]])
|
||||
ch_names = [ch.cluster.ep_attribute for ch in channels]
|
||||
ch_names = ", ".join(sorted(ch_names))
|
||||
self._name: str = f"{zha_device.name} {ieeetail} {ch_names}"
|
||||
self.cluster_channels: Dict[str, ChannelType] = {}
|
||||
for channel in channels:
|
||||
self.cluster_channels[channel.name] = channel
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
await self.async_check_recently_seen()
|
||||
await self.async_accept_signal(
|
||||
None,
|
||||
|
@ -126,12 +185,6 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
|
|||
self.async_set_available,
|
||||
signal_override=True,
|
||||
)
|
||||
await self.async_accept_signal(
|
||||
None,
|
||||
"{}_{}".format(SIGNAL_REMOVE, str(self.zha_device.ieee)),
|
||||
self.async_remove,
|
||||
signal_override=True,
|
||||
)
|
||||
self._zha_device.gateway.register_entity_reference(
|
||||
self._zha_device.ieee,
|
||||
self.entity_id,
|
||||
|
@ -141,7 +194,7 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
|
|||
self.remove_future,
|
||||
)
|
||||
|
||||
async def async_check_recently_seen(self):
|
||||
async def async_check_recently_seen(self) -> None:
|
||||
"""Check if the device was seen within the last 2 hours."""
|
||||
last_state = await self.async_get_last_state()
|
||||
if (
|
||||
|
@ -155,38 +208,8 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
|
|||
self.async_restore_last_state(last_state)
|
||||
self._zha_device.set_available(True)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Disconnect entity object when removed."""
|
||||
for unsub in self._unsubs[:]:
|
||||
unsub()
|
||||
self._unsubs.remove(unsub)
|
||||
self.zha_device.gateway.remove_entity_reference(self)
|
||||
self.remove_future.set_result(True)
|
||||
|
||||
@callback
|
||||
def async_restore_last_state(self, last_state):
|
||||
"""Restore previous state."""
|
||||
pass
|
||||
|
||||
async def async_update(self):
|
||||
async def async_update(self) -> None:
|
||||
"""Retrieve latest state."""
|
||||
for channel in self.cluster_channels.values():
|
||||
if hasattr(channel, "async_update"):
|
||||
await channel.async_update()
|
||||
|
||||
async def async_accept_signal(self, channel, signal, func, signal_override=False):
|
||||
"""Accept a signal from a channel."""
|
||||
unsub = None
|
||||
if signal_override:
|
||||
unsub = async_dispatcher_connect(self.hass, signal, func)
|
||||
else:
|
||||
unsub = async_dispatcher_connect(
|
||||
self.hass, f"{channel.unique_id}_{signal}", func
|
||||
)
|
||||
self._unsubs.append(unsub)
|
||||
|
||||
def log(self, level, msg, *args):
|
||||
"""Log a message."""
|
||||
msg = f"%s: {msg}"
|
||||
args = (self.entity_id,) + args
|
||||
_LOGGER.log(level, msg, *args)
|
||||
|
|
|
@ -1,19 +1,44 @@
|
|||
"""Lights on Zigbee Home Automation networks."""
|
||||
from collections import Counter
|
||||
from datetime import timedelta
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from zigpy.zcl.clusters.general import Identify, LevelControl, OnOff
|
||||
from zigpy.zcl.clusters.lighting import Color
|
||||
from zigpy.zcl.foundation import Status
|
||||
|
||||
from homeassistant.components import light
|
||||
from homeassistant.const import STATE_ON
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components.light import (
|
||||
ATTR_BRIGHTNESS,
|
||||
ATTR_COLOR_TEMP,
|
||||
ATTR_EFFECT,
|
||||
ATTR_EFFECT_LIST,
|
||||
ATTR_HS_COLOR,
|
||||
ATTR_MAX_MIREDS,
|
||||
ATTR_MIN_MIREDS,
|
||||
ATTR_WHITE_VALUE,
|
||||
SUPPORT_BRIGHTNESS,
|
||||
SUPPORT_COLOR,
|
||||
SUPPORT_COLOR_TEMP,
|
||||
SUPPORT_EFFECT,
|
||||
SUPPORT_FLASH,
|
||||
SUPPORT_TRANSITION,
|
||||
SUPPORT_WHITE_VALUE,
|
||||
)
|
||||
from homeassistant.const import ATTR_SUPPORTED_FEATURES, STATE_ON, STATE_UNAVAILABLE
|
||||
from homeassistant.core import CALLBACK_TYPE, State, callback
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.event import async_track_time_interval
|
||||
from homeassistant.helpers.event import (
|
||||
async_track_state_change,
|
||||
async_track_time_interval,
|
||||
)
|
||||
import homeassistant.util.color as color_util
|
||||
|
||||
from .core import discovery
|
||||
from .core import discovery, helpers
|
||||
from .core.const import (
|
||||
CHANNEL_COLOR,
|
||||
CHANNEL_LEVEL,
|
||||
|
@ -25,11 +50,12 @@ from .core.const import (
|
|||
EFFECT_DEFAULT_VARIANT,
|
||||
SIGNAL_ADD_ENTITIES,
|
||||
SIGNAL_ATTR_UPDATED,
|
||||
SIGNAL_REMOVE_GROUP,
|
||||
SIGNAL_SET_LEVEL,
|
||||
)
|
||||
from .core.registries import ZHA_ENTITIES
|
||||
from .core.typing import ZhaDeviceType
|
||||
from .entity import ZhaEntity
|
||||
from .entity import BaseZhaEntity, ZhaEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -46,8 +72,19 @@ FLASH_EFFECTS = {light.FLASH_SHORT: EFFECT_BLINK, light.FLASH_LONG: EFFECT_BREAT
|
|||
|
||||
UNSUPPORTED_ATTRIBUTE = 0x86
|
||||
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, light.DOMAIN)
|
||||
GROUP_MATCH = functools.partial(ZHA_ENTITIES.group_match, light.DOMAIN)
|
||||
PARALLEL_UPDATES = 0
|
||||
|
||||
SUPPORT_GROUP_LIGHT = (
|
||||
SUPPORT_BRIGHTNESS
|
||||
| SUPPORT_COLOR_TEMP
|
||||
| SUPPORT_EFFECT
|
||||
| SUPPORT_FLASH
|
||||
| SUPPORT_COLOR
|
||||
| SUPPORT_TRANSITION
|
||||
| SUPPORT_WHITE_VALUE
|
||||
)
|
||||
|
||||
|
||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
"""Set up the Zigbee Home Automation light from config entry."""
|
||||
|
@ -63,48 +100,35 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
|
|||
hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub)
|
||||
|
||||
|
||||
@STRICT_MATCH(channel_names=CHANNEL_ON_OFF, aux_channels={CHANNEL_COLOR, CHANNEL_LEVEL})
|
||||
class Light(ZhaEntity, light.Light):
|
||||
"""Representation of a ZHA or ZLL light."""
|
||||
class BaseLight(BaseZhaEntity, light.Light):
|
||||
"""Operations common to all light entities."""
|
||||
|
||||
_REFRESH_INTERVAL = (45, 75)
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the light."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._is_on: bool = False
|
||||
self._available: bool = False
|
||||
self._brightness: Optional[int] = None
|
||||
self._off_brightness: Optional[int] = None
|
||||
self._hs_color: Optional[Tuple[float, float]] = None
|
||||
self._color_temp: Optional[int] = None
|
||||
self._min_mireds: Optional[int] = 154
|
||||
self._max_mireds: Optional[int] = 500
|
||||
self._white_value: Optional[int] = None
|
||||
self._effect_list: Optional[List[str]] = None
|
||||
self._effect: Optional[str] = None
|
||||
self._supported_features: int = 0
|
||||
self._state: bool = False
|
||||
self._on_off_channel = None
|
||||
self._level_channel = None
|
||||
self._color_channel = None
|
||||
self._identify_channel = None
|
||||
|
||||
def __init__(self, unique_id, zha_device: ZhaDeviceType, channels, **kwargs):
|
||||
"""Initialize the ZHA light."""
|
||||
super().__init__(unique_id, zha_device, channels, **kwargs)
|
||||
self._supported_features = 0
|
||||
self._color_temp = None
|
||||
self._hs_color = None
|
||||
self._brightness = None
|
||||
self._off_brightness = None
|
||||
self._effect_list = []
|
||||
self._effect = None
|
||||
self._on_off_channel = self.cluster_channels.get(CHANNEL_ON_OFF)
|
||||
self._level_channel = self.cluster_channels.get(CHANNEL_LEVEL)
|
||||
self._color_channel = self.cluster_channels.get(CHANNEL_COLOR)
|
||||
self._identify_channel = self.zha_device.channels.identify_ch
|
||||
self._cancel_refresh_handle = None
|
||||
|
||||
if self._level_channel:
|
||||
self._supported_features |= light.SUPPORT_BRIGHTNESS
|
||||
self._supported_features |= light.SUPPORT_TRANSITION
|
||||
self._brightness = 0
|
||||
|
||||
if self._color_channel:
|
||||
color_capabilities = self._color_channel.get_color_capabilities()
|
||||
if color_capabilities & CAPABILITIES_COLOR_TEMP:
|
||||
self._supported_features |= light.SUPPORT_COLOR_TEMP
|
||||
|
||||
if color_capabilities & CAPABILITIES_COLOR_XY:
|
||||
self._supported_features |= light.SUPPORT_COLOR
|
||||
self._hs_color = (0, 0)
|
||||
|
||||
if color_capabilities & CAPABILITIES_COLOR_LOOP:
|
||||
self._supported_features |= light.SUPPORT_EFFECT
|
||||
self._effect_list.append(light.EFFECT_COLORLOOP)
|
||||
|
||||
if self._identify_channel:
|
||||
self._supported_features |= light.SUPPORT_FLASH
|
||||
@property
|
||||
def device_state_attributes(self) -> Dict[str, Any]:
|
||||
"""Return state attributes."""
|
||||
attributes = {"off_brightness": self._off_brightness}
|
||||
return attributes
|
||||
|
||||
@property
|
||||
def is_on(self) -> bool:
|
||||
|
@ -118,12 +142,6 @@ class Light(ZhaEntity, light.Light):
|
|||
"""Return the brightness of this light."""
|
||||
return self._brightness
|
||||
|
||||
@property
|
||||
def device_state_attributes(self):
|
||||
"""Return state attributes."""
|
||||
attributes = {"off_brightness": self._off_brightness}
|
||||
return attributes
|
||||
|
||||
def set_level(self, value):
|
||||
"""Set the brightness of this light between 0..254.
|
||||
|
||||
|
@ -160,49 +178,6 @@ class Light(ZhaEntity, light.Light):
|
|||
"""Flag supported features."""
|
||||
return self._supported_features
|
||||
|
||||
@callback
|
||||
def async_set_state(self, attr_id, attr_name, value):
|
||||
"""Set the state."""
|
||||
self._state = bool(value)
|
||||
if value:
|
||||
self._off_brightness = None
|
||||
self.async_write_ha_state()
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
"""Run when about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
await self.async_accept_signal(
|
||||
self._on_off_channel, SIGNAL_ATTR_UPDATED, self.async_set_state
|
||||
)
|
||||
if self._level_channel:
|
||||
await self.async_accept_signal(
|
||||
self._level_channel, SIGNAL_SET_LEVEL, self.set_level
|
||||
)
|
||||
refresh_interval = random.randint(*[x * 60 for x in self._REFRESH_INTERVAL])
|
||||
self._cancel_refresh_handle = async_track_time_interval(
|
||||
self.hass, self._refresh, timedelta(seconds=refresh_interval)
|
||||
)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Disconnect entity object when removed."""
|
||||
self._cancel_refresh_handle()
|
||||
await super().async_will_remove_from_hass()
|
||||
|
||||
@callback
|
||||
def async_restore_last_state(self, last_state):
|
||||
"""Restore previous state."""
|
||||
self._state = last_state.state == STATE_ON
|
||||
if "brightness" in last_state.attributes:
|
||||
self._brightness = last_state.attributes["brightness"]
|
||||
if "off_brightness" in last_state.attributes:
|
||||
self._off_brightness = last_state.attributes["off_brightness"]
|
||||
if "color_temp" in last_state.attributes:
|
||||
self._color_temp = last_state.attributes["color_temp"]
|
||||
if "hs_color" in last_state.attributes:
|
||||
self._hs_color = last_state.attributes["hs_color"]
|
||||
if "effect" in last_state.attributes:
|
||||
self._effect = last_state.attributes["effect"]
|
||||
|
||||
async def async_turn_on(self, **kwargs):
|
||||
"""Turn the entity on."""
|
||||
transition = kwargs.get(light.ATTR_TRANSITION)
|
||||
|
@ -331,6 +306,86 @@ class Light(ZhaEntity, light.Light):
|
|||
|
||||
self.async_write_ha_state()
|
||||
|
||||
|
||||
@STRICT_MATCH(channel_names=CHANNEL_ON_OFF, aux_channels={CHANNEL_COLOR, CHANNEL_LEVEL})
|
||||
class Light(ZhaEntity, BaseLight):
|
||||
"""Representation of a ZHA or ZLL light."""
|
||||
|
||||
_REFRESH_INTERVAL = (45, 75)
|
||||
|
||||
def __init__(self, unique_id, zha_device: ZhaDeviceType, channels, **kwargs):
|
||||
"""Initialize the ZHA light."""
|
||||
super().__init__(unique_id, zha_device, channels, **kwargs)
|
||||
self._on_off_channel = self.cluster_channels.get(CHANNEL_ON_OFF)
|
||||
self._level_channel = self.cluster_channels.get(CHANNEL_LEVEL)
|
||||
self._color_channel = self.cluster_channels.get(CHANNEL_COLOR)
|
||||
self._identify_channel = self.zha_device.channels.identify_ch
|
||||
self._cancel_refresh_handle = None
|
||||
|
||||
if self._level_channel:
|
||||
self._supported_features |= light.SUPPORT_BRIGHTNESS
|
||||
self._supported_features |= light.SUPPORT_TRANSITION
|
||||
self._brightness = 0
|
||||
|
||||
if self._color_channel:
|
||||
color_capabilities = self._color_channel.get_color_capabilities()
|
||||
if color_capabilities & CAPABILITIES_COLOR_TEMP:
|
||||
self._supported_features |= light.SUPPORT_COLOR_TEMP
|
||||
|
||||
if color_capabilities & CAPABILITIES_COLOR_XY:
|
||||
self._supported_features |= light.SUPPORT_COLOR
|
||||
self._hs_color = (0, 0)
|
||||
|
||||
if color_capabilities & CAPABILITIES_COLOR_LOOP:
|
||||
self._supported_features |= light.SUPPORT_EFFECT
|
||||
self._effect_list.append(light.EFFECT_COLORLOOP)
|
||||
|
||||
if self._identify_channel:
|
||||
self._supported_features |= light.SUPPORT_FLASH
|
||||
|
||||
@callback
|
||||
def async_set_state(self, attr_id, attr_name, value):
|
||||
"""Set the state."""
|
||||
self._state = bool(value)
|
||||
if value:
|
||||
self._off_brightness = None
|
||||
self.async_write_ha_state()
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
"""Run when about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
await self.async_accept_signal(
|
||||
self._on_off_channel, SIGNAL_ATTR_UPDATED, self.async_set_state
|
||||
)
|
||||
if self._level_channel:
|
||||
await self.async_accept_signal(
|
||||
self._level_channel, SIGNAL_SET_LEVEL, self.set_level
|
||||
)
|
||||
refresh_interval = random.randint(*[x * 60 for x in self._REFRESH_INTERVAL])
|
||||
self._cancel_refresh_handle = async_track_time_interval(
|
||||
self.hass, self._refresh, timedelta(seconds=refresh_interval)
|
||||
)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Disconnect entity object when removed."""
|
||||
self._cancel_refresh_handle()
|
||||
await super().async_will_remove_from_hass()
|
||||
|
||||
@callback
|
||||
def async_restore_last_state(self, last_state):
|
||||
"""Restore previous state."""
|
||||
self._state = last_state.state == STATE_ON
|
||||
if "brightness" in last_state.attributes:
|
||||
self._brightness = last_state.attributes["brightness"]
|
||||
if "off_brightness" in last_state.attributes:
|
||||
self._off_brightness = last_state.attributes["off_brightness"]
|
||||
if "color_temp" in last_state.attributes:
|
||||
self._color_temp = last_state.attributes["color_temp"]
|
||||
if "hs_color" in last_state.attributes:
|
||||
self._hs_color = last_state.attributes["hs_color"]
|
||||
if "effect" in last_state.attributes:
|
||||
self._effect = last_state.attributes["effect"]
|
||||
|
||||
async def async_update(self):
|
||||
"""Attempt to retrieve on off state from the light."""
|
||||
await super().async_update()
|
||||
|
@ -410,3 +465,99 @@ class HueLight(Light):
|
|||
"""Representation of a HUE light which does not report attributes."""
|
||||
|
||||
_REFRESH_INTERVAL = (3, 5)
|
||||
|
||||
|
||||
@GROUP_MATCH()
|
||||
class LightGroup(BaseLight):
|
||||
"""Representation of a light group."""
|
||||
|
||||
def __init__(
|
||||
self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs
|
||||
) -> None:
|
||||
"""Initialize a light group."""
|
||||
super().__init__(unique_id, zha_device, **kwargs)
|
||||
self._name = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}"
|
||||
self._group_id: int = group_id
|
||||
self._entity_ids: List[str] = entity_ids
|
||||
group = self.zha_device.gateway.get_group(self._group_id)
|
||||
self._on_off_channel = group.endpoint[OnOff.cluster_id]
|
||||
self._level_channel = group.endpoint[LevelControl.cluster_id]
|
||||
self._color_channel = group.endpoint[Color.cluster_id]
|
||||
self._identify_channel = group.endpoint[Identify.cluster_id]
|
||||
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Register callbacks."""
|
||||
await super().async_added_to_hass()
|
||||
await self.async_accept_signal(
|
||||
None,
|
||||
f"{SIGNAL_REMOVE_GROUP}_{self._group_id}",
|
||||
self.async_remove,
|
||||
signal_override=True,
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_state_changed_listener(
|
||||
entity_id: str, old_state: State, new_state: State
|
||||
):
|
||||
"""Handle child updates."""
|
||||
self.async_schedule_update_ha_state(True)
|
||||
|
||||
self._async_unsub_state_changed = async_track_state_change(
|
||||
self.hass, self._entity_ids, async_state_changed_listener
|
||||
)
|
||||
await self.async_update()
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Handle removal from Home Assistant."""
|
||||
await super().async_will_remove_from_hass()
|
||||
if self._async_unsub_state_changed is not None:
|
||||
self._async_unsub_state_changed()
|
||||
self._async_unsub_state_changed = None
|
||||
|
||||
async def async_update(self) -> None:
|
||||
"""Query all members and determine the light group state."""
|
||||
all_states = [self.hass.states.get(x) for x in self._entity_ids]
|
||||
states: List[State] = list(filter(None, all_states))
|
||||
on_states = [state for state in states if state.state == STATE_ON]
|
||||
|
||||
self._is_on = len(on_states) > 0
|
||||
self._available = any(state.state != STATE_UNAVAILABLE for state in states)
|
||||
|
||||
self._brightness = helpers.reduce_attribute(on_states, ATTR_BRIGHTNESS)
|
||||
|
||||
self._hs_color = helpers.reduce_attribute(
|
||||
on_states, ATTR_HS_COLOR, reduce=helpers.mean_tuple
|
||||
)
|
||||
|
||||
self._white_value = helpers.reduce_attribute(on_states, ATTR_WHITE_VALUE)
|
||||
|
||||
self._color_temp = helpers.reduce_attribute(on_states, ATTR_COLOR_TEMP)
|
||||
self._min_mireds = helpers.reduce_attribute(
|
||||
states, ATTR_MIN_MIREDS, default=154, reduce=min
|
||||
)
|
||||
self._max_mireds = helpers.reduce_attribute(
|
||||
states, ATTR_MAX_MIREDS, default=500, reduce=max
|
||||
)
|
||||
|
||||
self._effect_list = None
|
||||
all_effect_lists = list(helpers.find_state_attributes(states, ATTR_EFFECT_LIST))
|
||||
if all_effect_lists:
|
||||
# Merge all effects from all effect_lists with a union merge.
|
||||
self._effect_list = list(set().union(*all_effect_lists))
|
||||
|
||||
self._effect = None
|
||||
all_effects = list(helpers.find_state_attributes(on_states, ATTR_EFFECT))
|
||||
if all_effects:
|
||||
# Report the most common effect.
|
||||
effects_count = Counter(itertools.chain(all_effects))
|
||||
self._effect = effects_count.most_common(1)[0][0]
|
||||
|
||||
self._supported_features = 0
|
||||
for support in helpers.find_state_attributes(states, ATTR_SUPPORTED_FEATURES):
|
||||
# Merge supported features by emulating support for every feature
|
||||
# we find.
|
||||
self._supported_features |= support
|
||||
# Bitwise-and the supported features with the GroupedLight's features
|
||||
# so that we don't break in the future when a new feature is added.
|
||||
self._supported_features &= SUPPORT_GROUP_LIGHT
|
||||
|
|
|
@ -3,6 +3,8 @@ import time
|
|||
from unittest.mock import Mock
|
||||
|
||||
from asynctest import CoroutineMock
|
||||
from zigpy.device import Device as zigpy_dev
|
||||
from zigpy.endpoint import Endpoint as zigpy_ep
|
||||
import zigpy.profiles.zha
|
||||
import zigpy.types
|
||||
import zigpy.zcl
|
||||
|
@ -24,6 +26,7 @@ class FakeEndpoint:
|
|||
self.in_clusters = {}
|
||||
self.out_clusters = {}
|
||||
self._cluster_attr = {}
|
||||
self.member_of = {}
|
||||
self.status = 1
|
||||
self.manufacturer = manufacturer
|
||||
self.model = model
|
||||
|
@ -45,6 +48,19 @@ class FakeEndpoint:
|
|||
patch_cluster(cluster)
|
||||
self.out_clusters[cluster_id] = cluster
|
||||
|
||||
@property
|
||||
def __class__(self):
|
||||
"""Fake being Zigpy endpoint."""
|
||||
return zigpy_ep
|
||||
|
||||
@property
|
||||
def unique_id(self):
|
||||
"""Return the unique id for the endpoint."""
|
||||
return self.device.ieee, self.endpoint_id
|
||||
|
||||
|
||||
FakeEndpoint.add_to_group = zigpy_ep.add_to_group
|
||||
|
||||
|
||||
def patch_cluster(cluster):
|
||||
"""Patch a cluster for testing."""
|
||||
|
@ -56,17 +72,19 @@ def patch_cluster(cluster):
|
|||
cluster.read_attributes_raw = Mock()
|
||||
cluster.unbind = CoroutineMock(return_value=[0])
|
||||
cluster.write_attributes = CoroutineMock(return_value=[0])
|
||||
if cluster.cluster_id == 4:
|
||||
cluster.add = CoroutineMock(return_value=[0])
|
||||
|
||||
|
||||
class FakeDevice:
|
||||
"""Fake device for mocking zigpy."""
|
||||
|
||||
def __init__(self, app, ieee, manufacturer, model, node_desc=None):
|
||||
def __init__(self, app, ieee, manufacturer, model, node_desc=None, nwk=0xB79C):
|
||||
"""Init fake device."""
|
||||
self._application = app
|
||||
self.application = app
|
||||
self.ieee = zigpy.types.EUI64.convert(ieee)
|
||||
self.nwk = 0xB79C
|
||||
self.nwk = nwk
|
||||
self.zdo = Mock()
|
||||
self.endpoints = {0: self.zdo}
|
||||
self.lqi = 255
|
||||
|
@ -78,13 +96,15 @@ class FakeDevice:
|
|||
self.manufacturer = manufacturer
|
||||
self.model = model
|
||||
self.node_desc = zigpy.zdo.types.NodeDescriptor()
|
||||
self.add_to_group = CoroutineMock()
|
||||
self.remove_from_group = CoroutineMock()
|
||||
if node_desc is None:
|
||||
node_desc = b"\x02@\x807\x10\x7fd\x00\x00*d\x00\x00"
|
||||
self.node_desc = zigpy.zdo.types.NodeDescriptor.deserialize(node_desc)[0]
|
||||
|
||||
|
||||
FakeDevice.add_to_group = zigpy_dev.add_to_group
|
||||
|
||||
|
||||
def get_zha_gateway(hass):
|
||||
"""Return ZHA gateway from hass.data."""
|
||||
try:
|
||||
|
@ -137,6 +157,19 @@ async def find_entity_id(domain, zha_device, hass):
|
|||
return None
|
||||
|
||||
|
||||
def async_find_group_entity_id(hass, domain, group):
|
||||
"""Find the group entity id under test."""
|
||||
entity_id = (
|
||||
f"{domain}.{group.name.lower().replace(' ','_')}_group_0x{group.group_id:04x}"
|
||||
)
|
||||
|
||||
entity_ids = hass.states.async_entity_ids(domain)
|
||||
|
||||
if entity_id in entity_ids:
|
||||
return entity_id
|
||||
return None
|
||||
|
||||
|
||||
async def async_enable_traffic(hass, zha_devices):
|
||||
"""Allow traffic to flow through the gateway and the zha device."""
|
||||
for zha_device in zha_devices:
|
||||
|
|
|
@ -110,10 +110,11 @@ def zigpy_device_mock(zigpy_app_controller):
|
|||
manufacturer="FakeManufacturer",
|
||||
model="FakeModel",
|
||||
node_descriptor=b"\x02@\x807\x10\x7fd\x00\x00*d\x00\x00",
|
||||
nwk=0xB79C,
|
||||
):
|
||||
"""Make a fake device using the specified cluster classes."""
|
||||
device = FakeDevice(
|
||||
zigpy_app_controller, ieee, manufacturer, model, node_descriptor
|
||||
zigpy_app_controller, ieee, manufacturer, model, node_descriptor, nwk=nwk
|
||||
)
|
||||
for epid, ep in endpoints.items():
|
||||
endpoint = FakeEndpoint(manufacturer, model, epid)
|
||||
|
|
|
@ -1,8 +1,18 @@
|
|||
"""Test ZHA Gateway."""
|
||||
import pytest
|
||||
import zigpy.zcl.clusters.general as general
|
||||
import logging
|
||||
|
||||
from .common import async_enable_traffic, get_zha_gateway
|
||||
import pytest
|
||||
import zigpy.profiles.zha as zha
|
||||
import zigpy.zcl.clusters.general as general
|
||||
import zigpy.zcl.clusters.lighting as lighting
|
||||
|
||||
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
||||
|
||||
from .common import async_enable_traffic, async_find_group_entity_id, get_zha_gateway
|
||||
|
||||
IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8"
|
||||
IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8"
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -15,7 +25,7 @@ def zigpy_dev_basic(zigpy_device_mock):
|
|||
"out_clusters": [],
|
||||
"device_type": 0,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
@ -27,6 +37,74 @@ async def zha_dev_basic(hass, zha_device_restored, zigpy_dev_basic):
|
|||
return zha_device
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def coordinator(hass, zigpy_device_mock, zha_device_joined):
|
||||
"""Test zha light platform."""
|
||||
|
||||
zigpy_device = zigpy_device_mock(
|
||||
{
|
||||
1: {
|
||||
"in_clusters": [],
|
||||
"out_clusters": [],
|
||||
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
|
||||
}
|
||||
},
|
||||
ieee="00:15:8d:00:02:32:4f:32",
|
||||
nwk=0x0000,
|
||||
)
|
||||
zha_device = await zha_device_joined(zigpy_device)
|
||||
zha_device.set_available(True)
|
||||
return zha_device
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def device_light_1(hass, zigpy_device_mock, zha_device_joined):
|
||||
"""Test zha light platform."""
|
||||
|
||||
zigpy_device = zigpy_device_mock(
|
||||
{
|
||||
1: {
|
||||
"in_clusters": [
|
||||
general.OnOff.cluster_id,
|
||||
general.LevelControl.cluster_id,
|
||||
lighting.Color.cluster_id,
|
||||
general.Groups.cluster_id,
|
||||
],
|
||||
"out_clusters": [],
|
||||
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
|
||||
}
|
||||
},
|
||||
ieee=IEEE_GROUPABLE_DEVICE,
|
||||
)
|
||||
zha_device = await zha_device_joined(zigpy_device)
|
||||
zha_device.set_available(True)
|
||||
return zha_device
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def device_light_2(hass, zigpy_device_mock, zha_device_joined):
|
||||
"""Test zha light platform."""
|
||||
|
||||
zigpy_device = zigpy_device_mock(
|
||||
{
|
||||
1: {
|
||||
"in_clusters": [
|
||||
general.OnOff.cluster_id,
|
||||
general.LevelControl.cluster_id,
|
||||
lighting.Color.cluster_id,
|
||||
general.Groups.cluster_id,
|
||||
],
|
||||
"out_clusters": [],
|
||||
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
|
||||
}
|
||||
},
|
||||
ieee=IEEE_GROUPABLE_DEVICE2,
|
||||
)
|
||||
zha_device = await zha_device_joined(zigpy_device)
|
||||
zha_device.set_available(True)
|
||||
return zha_device
|
||||
|
||||
|
||||
async def test_device_left(hass, zigpy_dev_basic, zha_dev_basic):
|
||||
"""Device leaving the network should become unavailable."""
|
||||
|
||||
|
@ -37,3 +115,57 @@ async def test_device_left(hass, zigpy_dev_basic, zha_dev_basic):
|
|||
|
||||
get_zha_gateway(hass).device_left(zigpy_dev_basic)
|
||||
assert zha_dev_basic.available is False
|
||||
|
||||
|
||||
async def test_gateway_group_methods(hass, device_light_1, device_light_2, coordinator):
|
||||
"""Test creating a group with 2 members."""
|
||||
zha_gateway = get_zha_gateway(hass)
|
||||
assert zha_gateway is not None
|
||||
zha_gateway.coordinator_zha_device = coordinator
|
||||
coordinator._zha_gateway = zha_gateway
|
||||
device_light_1._zha_gateway = zha_gateway
|
||||
device_light_2._zha_gateway = zha_gateway
|
||||
member_ieee_addresses = [device_light_1.ieee, device_light_2.ieee]
|
||||
|
||||
# test creating a group with 2 members
|
||||
zha_group = await zha_gateway.async_create_zigpy_group(
|
||||
"Test Group", member_ieee_addresses
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert zha_group is not None
|
||||
assert zha_group.entity_domain == LIGHT_DOMAIN
|
||||
assert len(zha_group.members) == 2
|
||||
for member in zha_group.members:
|
||||
assert member.ieee in member_ieee_addresses
|
||||
|
||||
entity_id = async_find_group_entity_id(hass, LIGHT_DOMAIN, zha_group)
|
||||
assert hass.states.get(entity_id) is not None
|
||||
|
||||
# test get group by name
|
||||
assert zha_group == zha_gateway.async_get_group_by_name(zha_group.name)
|
||||
|
||||
# test removing a group
|
||||
await zha_gateway.async_remove_zigpy_group(zha_group.group_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# we shouldn't have the group anymore
|
||||
assert zha_gateway.async_get_group_by_name(zha_group.name) is None
|
||||
|
||||
# the group entity should be cleaned up
|
||||
assert entity_id not in hass.states.async_entity_ids(LIGHT_DOMAIN)
|
||||
|
||||
# test creating a group with 1 member
|
||||
zha_group = await zha_gateway.async_create_zigpy_group(
|
||||
"Test Group", [device_light_1.ieee]
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert zha_group is not None
|
||||
assert zha_group.entity_domain is None
|
||||
assert len(zha_group.members) == 1
|
||||
for member in zha_group.members:
|
||||
assert member.ieee in [device_light_1.ieee]
|
||||
|
||||
# the group entity should not have been cleaned up
|
||||
assert entity_id not in hass.states.async_entity_ids(LIGHT_DOMAIN)
|
||||
|
|
|
@ -4,7 +4,7 @@ from unittest.mock import MagicMock, call, sentinel
|
|||
|
||||
from asynctest import CoroutineMock, patch
|
||||
import pytest
|
||||
import zigpy.profiles.zha
|
||||
import zigpy.profiles.zha as zha
|
||||
import zigpy.types
|
||||
import zigpy.zcl.clusters.general as general
|
||||
import zigpy.zcl.clusters.lighting as lighting
|
||||
|
@ -17,8 +17,10 @@ import homeassistant.util.dt as dt_util
|
|||
|
||||
from .common import (
|
||||
async_enable_traffic,
|
||||
async_find_group_entity_id,
|
||||
async_test_rejoin,
|
||||
find_entity_id,
|
||||
get_zha_gateway,
|
||||
send_attributes_report,
|
||||
)
|
||||
|
||||
|
@ -26,6 +28,8 @@ from tests.common import async_fire_time_changed
|
|||
|
||||
ON = 1
|
||||
OFF = 0
|
||||
IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8"
|
||||
IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8"
|
||||
|
||||
LIGHT_ON_OFF = {
|
||||
1: {
|
||||
|
@ -66,6 +70,76 @@ LIGHT_COLOR = {
|
|||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def coordinator(hass, zigpy_device_mock, zha_device_joined):
|
||||
"""Test zha light platform."""
|
||||
|
||||
zigpy_device = zigpy_device_mock(
|
||||
{
|
||||
1: {
|
||||
"in_clusters": [],
|
||||
"out_clusters": [],
|
||||
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
|
||||
}
|
||||
},
|
||||
ieee="00:15:8d:00:02:32:4f:32",
|
||||
nwk=0x0000,
|
||||
)
|
||||
zha_device = await zha_device_joined(zigpy_device)
|
||||
zha_device.set_available(True)
|
||||
return zha_device
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def device_light_1(hass, zigpy_device_mock, zha_device_joined):
|
||||
"""Test zha light platform."""
|
||||
|
||||
zigpy_device = zigpy_device_mock(
|
||||
{
|
||||
1: {
|
||||
"in_clusters": [
|
||||
general.OnOff.cluster_id,
|
||||
general.LevelControl.cluster_id,
|
||||
lighting.Color.cluster_id,
|
||||
general.Groups.cluster_id,
|
||||
general.Identify.cluster_id,
|
||||
],
|
||||
"out_clusters": [],
|
||||
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
|
||||
}
|
||||
},
|
||||
ieee=IEEE_GROUPABLE_DEVICE,
|
||||
)
|
||||
zha_device = await zha_device_joined(zigpy_device)
|
||||
zha_device.set_available(True)
|
||||
return zha_device
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def device_light_2(hass, zigpy_device_mock, zha_device_joined):
|
||||
"""Test zha light platform."""
|
||||
|
||||
zigpy_device = zigpy_device_mock(
|
||||
{
|
||||
1: {
|
||||
"in_clusters": [
|
||||
general.OnOff.cluster_id,
|
||||
general.LevelControl.cluster_id,
|
||||
lighting.Color.cluster_id,
|
||||
general.Groups.cluster_id,
|
||||
general.Identify.cluster_id,
|
||||
],
|
||||
"out_clusters": [],
|
||||
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
|
||||
}
|
||||
},
|
||||
ieee=IEEE_GROUPABLE_DEVICE2,
|
||||
)
|
||||
zha_device = await zha_device_joined(zigpy_device)
|
||||
zha_device.set_available(True)
|
||||
return zha_device
|
||||
|
||||
|
||||
@patch("zigpy.zcl.clusters.general.OnOff.read_attributes", new=MagicMock())
|
||||
async def test_light_refresh(hass, zigpy_device_mock, zha_device_joined_restored):
|
||||
"""Test zha light platform refresh."""
|
||||
|
@ -337,3 +411,96 @@ async def async_test_flash_from_hass(hass, cluster, entity_id, flash):
|
|||
manufacturer=None,
|
||||
tsn=None,
|
||||
)
|
||||
|
||||
|
||||
async def async_test_zha_group_light_entity(
|
||||
hass, device_light_1, device_light_2, coordinator
|
||||
):
|
||||
"""Test the light entity for a ZHA group."""
|
||||
zha_gateway = get_zha_gateway(hass)
|
||||
assert zha_gateway is not None
|
||||
zha_gateway.coordinator_zha_device = coordinator
|
||||
coordinator._zha_gateway = zha_gateway
|
||||
device_light_1._zha_gateway = zha_gateway
|
||||
device_light_2._zha_gateway = zha_gateway
|
||||
member_ieee_addresses = [device_light_1.ieee, device_light_2.ieee]
|
||||
|
||||
# test creating a group with 2 members
|
||||
zha_group = await zha_gateway.async_create_zigpy_group(
|
||||
"Test Group", member_ieee_addresses
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert zha_group is not None
|
||||
assert zha_group.entity_domain == DOMAIN
|
||||
assert len(zha_group.members) == 2
|
||||
for member in zha_group.members:
|
||||
assert member.ieee in member_ieee_addresses
|
||||
|
||||
entity_id = async_find_group_entity_id(hass, DOMAIN, zha_group)
|
||||
assert hass.states.get(entity_id) is not None
|
||||
|
||||
group_cluster_on_off = zha_group.endpoint[general.OnOff.cluster_id]
|
||||
group_cluster_level = zha_group.endpoint[general.LevelControl.cluster_id]
|
||||
group_cluster_identify = zha_group.endpoint[general.Identify.cluster_id]
|
||||
|
||||
dev1_cluster_on_off = device_light_1.endpoints[1].on_off
|
||||
dev2_cluster_on_off = device_light_2.endpoints[1].on_off
|
||||
|
||||
# test that the lights were created and that they are unavailable
|
||||
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE
|
||||
|
||||
# allow traffic to flow through the gateway and device
|
||||
await async_enable_traffic(hass, zha_group.members)
|
||||
|
||||
# test that the lights were created and are off
|
||||
assert hass.states.get(entity_id).state == STATE_OFF
|
||||
|
||||
# test turning the lights on and off from the light
|
||||
await async_test_on_off_from_light(hass, group_cluster_on_off, entity_id)
|
||||
|
||||
# test turning the lights on and off from the HA
|
||||
await async_test_on_off_from_hass(hass, group_cluster_on_off, entity_id)
|
||||
|
||||
# test short flashing the lights from the HA
|
||||
await async_test_flash_from_hass(
|
||||
hass, group_cluster_identify, entity_id, FLASH_SHORT
|
||||
)
|
||||
|
||||
# test turning the lights on and off from the HA
|
||||
await async_test_level_on_off_from_hass(
|
||||
hass, group_cluster_on_off, group_cluster_level, entity_id
|
||||
)
|
||||
|
||||
# test getting a brightness change from the network
|
||||
await async_test_on_from_light(hass, group_cluster_on_off, entity_id)
|
||||
await async_test_dimmer_from_light(
|
||||
hass, group_cluster_level, entity_id, 150, STATE_ON
|
||||
)
|
||||
|
||||
# test long flashing the lights from the HA
|
||||
await async_test_flash_from_hass(
|
||||
hass, group_cluster_identify, entity_id, FLASH_LONG
|
||||
)
|
||||
|
||||
# test some of the group logic to make sure we key off states correctly
|
||||
await dev1_cluster_on_off.on()
|
||||
await dev2_cluster_on_off.on()
|
||||
|
||||
# test that group light is on
|
||||
assert hass.states.get(entity_id).state == STATE_ON
|
||||
|
||||
await dev1_cluster_on_off.off()
|
||||
|
||||
# test that group light is still on
|
||||
assert hass.states.get(entity_id).state == STATE_ON
|
||||
|
||||
await dev2_cluster_on_off.off()
|
||||
|
||||
# test that group light is now off
|
||||
assert hass.states.get(entity_id).state == STATE_OFF
|
||||
|
||||
await dev1_cluster_on_off.on()
|
||||
|
||||
# test that group light is now back on
|
||||
assert hass.states.get(entity_id).state == STATE_ON
|
||||
|
|
Loading…
Reference in New Issue