Add group entity support to ZHA (#33196)

* split entity into base and entity

* add initial light group support

* add dispatching of groups to light

* added zha group object

* add group event listener

* add and remove group members

* get group by name

* fix rebase

* fix rebase

* use group_id for unique_id

* get entities from registry

* use group name

* update entity domain

* update zha storage to handle groups

* dispatch group entities

* update light group

* fix group remove and dispatch light group entities

* allow picking the domain for group entities

* beginning - auto determine entity domain

* move methods to helpers so they can be shared

* fix rebase

* remove double init groups... again

* cleanup startup

* use asyncio create task

* group entity discovery

* add logging and fix group name

* add logging and update group after probe if needed

* test add group via gateway

* add method to get group entity ids

* update storage

* test get group by name

* update storage on remove

* test group with single member

* add light group tests

* test some light group logic

* type hints

* fix tests and cleanup

* revert init changes except for create task

* remove group entity domain changing for now

* add missing import

* tricky code saving

* review comments

* clean up class defs

* cleanup

* fix rebase because I cant read

* make pylint happy
pull/33246/head
David F. Mulcahey 2020-03-25 07:23:54 -04:00 committed by GitHub
parent 3ee05ad4bb
commit 2a3c94bad0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 1084 additions and 257 deletions

View File

@ -130,7 +130,7 @@ async def async_setup_entry(hass, config_entry):
await zha_data[DATA_ZHA_GATEWAY].async_update_device_storage()
hass.bus.async_listen_once(ha_const.EVENT_HOMEASSISTANT_STOP, async_zha_shutdown)
asyncio.create_task(async_load_entities(hass, config_entry))
asyncio.create_task(async_load_entities(hass))
return True
@ -150,11 +150,9 @@ async def async_unload_entry(hass, config_entry):
return True
async def async_load_entities(
hass: HomeAssistantType, config_entry: config_entries.ConfigEntry
) -> None:
async def async_load_entities(hass: HomeAssistantType) -> None:
"""Load entities after integration was setup."""
await hass.data[DATA_ZHA][DATA_ZHA_GATEWAY].async_prepare_entities()
await hass.data[DATA_ZHA][DATA_ZHA_GATEWAY].async_initialize_devices_and_entities()
to_setup = hass.data[DATA_ZHA][DATA_ZHA_PLATFORM_LOADED]
results = await asyncio.gather(*to_setup, return_exceptions=True)
for res in results:

View File

@ -23,6 +23,7 @@ ATTR_COMMAND_TYPE = "command_type"
ATTR_DEVICE_IEEE = "device_ieee"
ATTR_DEVICE_TYPE = "device_type"
ATTR_ENDPOINT_ID = "endpoint_id"
ATTR_ENTITY_DOMAIN = "entity_domain"
ATTR_IEEE = "ieee"
ATTR_LAST_SEEN = "last_seen"
ATTR_LEVEL = "level"
@ -207,6 +208,7 @@ SIGNAL_REMOVE = "remove"
SIGNAL_SET_LEVEL = "set_level"
SIGNAL_STATE_ATTR = "update_state_attribute"
SIGNAL_UPDATE_DEVICE = "{}_zha_update_device"
SIGNAL_REMOVE_GROUP = "remove_group"
UNKNOWN = "unknown"
UNKNOWN_MANUFACTURER = "unk_manufacturer"

View File

@ -373,7 +373,7 @@ class ZHADevice(LogMixin):
self.debug("started configuration")
await self._channels.async_configure()
self.debug("completed configuration")
entry = self.gateway.zha_storage.async_create_or_update(self)
entry = self.gateway.zha_storage.async_create_or_update_device(self)
self.debug("stored in registry: %s", entry)
if self._channels.identify_ch is not None:

View File

@ -1,10 +1,12 @@
"""Device discovery functions for Zigbee Home Automation."""
from collections import Counter
import logging
from typing import Callable, List, Tuple
from homeassistant import const as ha_const
from homeassistant.core import callback
from homeassistant.helpers.entity_registry import async_entries_for_device
from homeassistant.helpers.typing import HomeAssistantType
from . import const as zha_const, registries as zha_regs, typing as zha_typing
@ -157,4 +159,102 @@ class ProbeEndpoint:
self._device_configs.update(overrides)
class GroupProbe:
"""Determine the appropriate component for a group."""
def __init__(self):
"""Initialize instance."""
self._hass = None
def initialize(self, hass: HomeAssistantType) -> None:
"""Initialize the group probe."""
self._hass = hass
@callback
def discover_group_entities(self, group: zha_typing.ZhaGroupType) -> None:
"""Process a group and create any entities that are needed."""
# only create a group entity if there are 2 or more members in a group
if len(group.members) < 2:
_LOGGER.debug(
"Group: %s:0x%04x has less than 2 members - skipping entity discovery",
group.name,
group.group_id,
)
return
if group.entity_domain is None:
_LOGGER.debug(
"Group: %s:0x%04x has no user set entity domain - attempting entity domain discovery",
group.name,
group.group_id,
)
group.entity_domain = GroupProbe.determine_default_entity_domain(
self._hass, group
)
if group.entity_domain is None:
return
_LOGGER.debug(
"Group: %s:0x%04x has an entity domain of: %s after discovery",
group.name,
group.group_id,
group.entity_domain,
)
zha_gateway = self._hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY]
entity_class = zha_regs.ZHA_ENTITIES.get_group_entity(group.entity_domain)
if entity_class is None:
return
self._hass.data[zha_const.DATA_ZHA][group.entity_domain].append(
(
entity_class,
(
group.domain_entity_ids,
f"{group.entity_domain}_group_{group.group_id}",
group.group_id,
zha_gateway.coordinator_zha_device,
),
)
)
@staticmethod
def determine_default_entity_domain(
hass: HomeAssistantType, group: zha_typing.ZhaGroupType
):
"""Determine the default entity domain for this group."""
if len(group.members) < 2:
_LOGGER.debug(
"Group: %s:0x%04x has less than 2 members so cannot default an entity domain",
group.name,
group.group_id,
)
return None
zha_gateway = hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY]
all_domain_occurrences = []
for device in group.members:
entities = async_entries_for_device(
zha_gateway.ha_entity_registry, device.device_id
)
all_domain_occurrences.extend(
[
entity.domain
for entity in entities
if entity.domain in zha_regs.GROUP_ENTITY_DOMAINS
]
)
counts = Counter(all_domain_occurrences)
domain = counts.most_common(1)[0][0]
_LOGGER.debug(
"The default entity domain is: %s for group: %s:0x%04x",
domain,
group.name,
group.group_id,
)
return domain
PROBE = ProbeEndpoint()
GROUP_PROBE = GroupProbe()

View File

@ -6,6 +6,7 @@ import itertools
import logging
import os
import traceback
from typing import List, Optional
from serial import SerialException
import zigpy.device as zigpy_dev
@ -52,6 +53,7 @@ from .const import (
DOMAIN,
SIGNAL_ADD_ENTITIES,
SIGNAL_REMOVE,
SIGNAL_REMOVE_GROUP,
UNKNOWN_MANUFACTURER,
UNKNOWN_MODEL,
ZHA_GW_MSG,
@ -75,6 +77,7 @@ from .group import ZHAGroup
from .patches import apply_application_controller_patch
from .registries import RADIO_TYPES
from .store import async_get_registry
from .typing import ZhaDeviceType, ZhaGroupType, ZigpyEndpointType, ZigpyGroupType
_LOGGER = logging.getLogger(__name__)
@ -93,6 +96,7 @@ class ZHAGateway:
self._config = config
self._devices = {}
self._groups = {}
self.coordinator_zha_device = None
self._device_registry = collections.defaultdict(list)
self.zha_storage = None
self.ha_device_registry = None
@ -110,6 +114,7 @@ class ZHAGateway:
async def async_initialize(self):
"""Initialize controller and connect radio."""
discovery.PROBE.initialize(self._hass)
discovery.GROUP_PROBE.initialize(self._hass)
self.zha_storage = await async_get_registry(self._hass)
self.ha_device_registry = await get_dev_reg(self._hass)
@ -156,17 +161,29 @@ class ZHAGateway:
self._hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID] = str(
self.application_controller.ieee
)
await self.async_load_devices()
self._initialize_groups()
self.async_load_devices()
self.async_load_groups()
async def async_load_devices(self) -> None:
@callback
def async_load_devices(self) -> None:
"""Restore ZHA devices from zigpy application state."""
zigpy_devices = self.application_controller.devices.values()
for zigpy_device in zigpy_devices:
self._async_get_or_create_device(zigpy_device, restored=True)
zha_device = self._async_get_or_create_device(zigpy_device, restored=True)
if zha_device.nwk == 0x0000:
self.coordinator_zha_device = zha_device
async def async_prepare_entities(self) -> None:
"""Prepare entities by initializing device channels."""
@callback
def async_load_groups(self) -> None:
"""Initialize ZHA groups."""
for group_id in self.application_controller.groups:
group = self.application_controller.groups[group_id]
zha_group = self._async_get_or_create_group(group)
# we can do this here because the entities are in the entity registry tied to the devices
discovery.GROUP_PROBE.discover_group_entities(zha_group)
async def async_initialize_devices_and_entities(self) -> None:
"""Initialize devices and load entities."""
semaphore = asyncio.Semaphore(2)
async def _throttle(zha_device: zha_typing.ZhaDeviceType, cached: bool):
@ -231,35 +248,44 @@ class ZHAGateway:
"""Handle device leaving the network."""
self.async_update_device(device, False)
def group_member_removed(self, zigpy_group, endpoint):
def group_member_removed(
self, zigpy_group: ZigpyGroupType, endpoint: ZigpyEndpointType
) -> None:
"""Handle zigpy group member removed event."""
# need to handle endpoint correctly on groups
zha_group = self._async_get_or_create_group(zigpy_group)
zha_group.info("group_member_removed - endpoint: %s", endpoint)
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_REMOVED)
def group_member_added(self, zigpy_group, endpoint):
def group_member_added(
self, zigpy_group: ZigpyGroupType, endpoint: ZigpyEndpointType
) -> None:
"""Handle zigpy group member added event."""
# need to handle endpoint correctly on groups
zha_group = self._async_get_or_create_group(zigpy_group)
zha_group.info("group_member_added - endpoint: %s", endpoint)
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_ADDED)
def group_added(self, zigpy_group):
def group_added(self, zigpy_group: ZigpyGroupType) -> None:
"""Handle zigpy group added event."""
zha_group = self._async_get_or_create_group(zigpy_group)
zha_group.info("group_added")
# need to dispatch for entity creation here
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_ADDED)
def group_removed(self, zigpy_group):
def group_removed(self, zigpy_group: ZigpyGroupType) -> None:
"""Handle zigpy group added event."""
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_REMOVED)
zha_group = self._groups.pop(zigpy_group.group_id, None)
zha_group.info("group_removed")
async_dispatcher_send(
self._hass, f"{SIGNAL_REMOVE_GROUP}_{zigpy_group.group_id}"
)
def _send_group_gateway_message(self, zigpy_group, gateway_message_type):
"""Send the gareway event for a zigpy group event."""
def _send_group_gateway_message(
self, zigpy_group: ZigpyGroupType, gateway_message_type: str
) -> None:
"""Send the gateway event for a zigpy group event."""
zha_group = self._groups.get(zigpy_group.group_id)
if zha_group is not None:
async_dispatcher_send(
@ -306,12 +332,12 @@ class ZHAGateway:
"""Return ZHADevice for given ieee."""
return self._devices.get(ieee)
def get_group(self, group_id):
def get_group(self, group_id: str) -> Optional[ZhaGroupType]:
"""Return Group for given group id."""
return self.groups.get(group_id)
@callback
def async_get_group_by_name(self, group_name):
def async_get_group_by_name(self, group_name: str) -> Optional[ZhaGroupType]:
"""Get ZHA group by name."""
for group in self.groups.values():
if group.name == group_name:
@ -390,12 +416,6 @@ class ZHAGateway:
logging.getLogger(logger_name).removeHandler(self._log_relay_handler)
self.debug_enabled = False
def _initialize_groups(self):
"""Initialize ZHA groups."""
for group_id in self.application_controller.groups:
group = self.application_controller.groups[group_id]
self._async_get_or_create_group(group)
@callback
def _async_get_or_create_device(
self, zigpy_device: zha_typing.ZigpyDeviceType, restored: bool = False
@ -414,17 +434,19 @@ class ZHAGateway:
model=zha_device.model,
)
zha_device.set_device_id(device_registry_device.id)
entry = self.zha_storage.async_get_or_create(zha_device)
entry = self.zha_storage.async_get_or_create_device(zha_device)
zha_device.async_update_last_seen(entry.last_seen)
return zha_device
@callback
def _async_get_or_create_group(self, zigpy_group):
def _async_get_or_create_group(self, zigpy_group: ZigpyGroupType) -> ZhaGroupType:
"""Get or create a ZHA group."""
zha_group = self._groups.get(zigpy_group.group_id)
if zha_group is None:
zha_group = ZHAGroup(self._hass, self, zigpy_group)
self._groups[zigpy_group.group_id] = zha_group
group_entry = self.zha_storage.async_get_or_create_group(zha_group)
zha_group.entity_domain = group_entry.entity_domain
return zha_group
@callback
@ -446,7 +468,9 @@ class ZHAGateway:
async def async_update_device_storage(self):
"""Update the devices in the store."""
for device in self.devices.values():
self.zha_storage.async_update(device)
self.zha_storage.async_update_device(device)
for group in self.groups.values():
self.zha_storage.async_update_group(group)
await self.zha_storage.async_save()
async def async_device_initialized(self, device: zha_typing.ZigpyDeviceType):
@ -494,25 +518,6 @@ class ZHAGateway:
zha_device.update_available(True)
async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES)
# only public for testing
async def async_device_restored(self, device: zha_typing.ZigpyDeviceType):
"""Add an existing device to the ZHA zigbee network when ZHA first starts."""
zha_device = self._async_get_or_create_device(device, restored=True)
if zha_device.is_mains_powered:
# the device isn't a battery powered device so we should be able
# to update it now
_LOGGER.debug(
"attempting to request fresh state for device - %s:%s %s with power source %s",
zha_device.nwk,
zha_device.ieee,
zha_device.name,
zha_device.power_source,
)
await zha_device.async_initialize(from_cache=False)
else:
await zha_device.async_initialize(from_cache=True)
async def _async_device_rejoined(self, zha_device):
_LOGGER.debug(
"skipping discovery for previously discovered device - %s:%s",
@ -524,7 +529,9 @@ class ZHAGateway:
# will cause async_init to fire so don't explicitly call it
zha_device.update_available(True)
async def async_create_zigpy_group(self, name, members):
async def async_create_zigpy_group(
self, name: str, members: List[ZhaDeviceType]
) -> ZhaGroupType:
"""Create a new Zigpy Zigbee group."""
# we start with one to fill any gaps from a user removing existing groups
group_id = 1
@ -537,24 +544,40 @@ class ZHAGateway:
if members is not None:
tasks = []
for ieee in members:
_LOGGER.debug(
"Adding member with IEEE: %s to group: %s:0x%04x",
ieee,
name,
group_id,
)
tasks.append(self.devices[ieee].async_add_to_group(group_id))
await asyncio.gather(*tasks)
return self.groups.get(group_id)
zha_group = self.groups.get(group_id)
_LOGGER.debug(
"Probing group: %s:0x%04x for entity discovery",
zha_group.name,
zha_group.group_id,
)
discovery.GROUP_PROBE.discover_group_entities(zha_group)
if zha_group.entity_domain is not None:
self.zha_storage.async_update_group(zha_group)
async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES)
return zha_group
async def async_remove_zigpy_group(self, group_id):
async def async_remove_zigpy_group(self, group_id: int) -> None:
"""Remove a Zigbee group from Zigpy."""
group = self.groups.get(group_id)
if not group:
_LOGGER.debug("Group: %s:0x%04x could not be found", group.name, group_id)
return
if group and group.members:
tasks = []
for member in group.members:
tasks.append(member.async_remove_from_group(group_id))
if tasks:
await asyncio.gather(*tasks)
else:
# we have members but none are tracked by ZHA for whatever reason
self.application_controller.groups.pop(group_id)
else:
self.application_controller.groups.pop(group_id)
self.application_controller.groups.pop(group_id)
self.zha_storage.async_delete_group(group)
async def shutdown(self):
"""Stop ZHA Controller Application."""

View File

@ -1,10 +1,16 @@
"""Group for Zigbee Home Automation."""
import asyncio
import logging
from typing import Any, Dict, List, Optional
from zigpy.types.named import EUI64
from homeassistant.core import callback
from homeassistant.helpers.entity_registry import async_entries_for_device
from homeassistant.helpers.typing import HomeAssistantType
from .helpers import LogMixin
from .typing import ZhaDeviceType, ZhaGatewayType, ZigpyEndpointType, ZigpyGroupType
_LOGGER = logging.getLogger(__name__)
@ -12,29 +18,45 @@ _LOGGER = logging.getLogger(__name__)
class ZHAGroup(LogMixin):
"""ZHA Zigbee group object."""
def __init__(self, hass, zha_gateway, zigpy_group):
def __init__(
self,
hass: HomeAssistantType,
zha_gateway: ZhaGatewayType,
zigpy_group: ZigpyGroupType,
):
"""Initialize the group."""
self.hass = hass
self._zigpy_group = zigpy_group
self._zha_gateway = zha_gateway
self.hass: HomeAssistantType = hass
self._zigpy_group: ZigpyGroupType = zigpy_group
self._zha_gateway: ZhaGatewayType = zha_gateway
self._entity_domain: str = None
@property
def name(self):
def name(self) -> str:
"""Return group name."""
return self._zigpy_group.name
@property
def group_id(self):
def group_id(self) -> int:
"""Return group name."""
return self._zigpy_group.group_id
@property
def endpoint(self):
def endpoint(self) -> ZigpyEndpointType:
"""Return the endpoint for this group."""
return self._zigpy_group.endpoint
@property
def members(self):
def entity_domain(self) -> Optional[str]:
"""Return the domain that will be used for the entity representing this group."""
return self._entity_domain
@entity_domain.setter
def entity_domain(self, domain: Optional[str]) -> None:
"""Set the domain that will be used for the entity representing this group."""
self._entity_domain = domain
@property
def members(self) -> List[ZhaDeviceType]:
"""Return the ZHA devices that are members of this group."""
return [
self._zha_gateway.devices.get(member_ieee[0])
@ -42,7 +64,7 @@ class ZHAGroup(LogMixin):
if member_ieee[0] in self._zha_gateway.devices
]
async def async_add_members(self, member_ieee_addresses):
async def async_add_members(self, member_ieee_addresses: List[EUI64]) -> None:
"""Add members to this group."""
if len(member_ieee_addresses) > 1:
tasks = []
@ -56,7 +78,7 @@ class ZHAGroup(LogMixin):
member_ieee_addresses[0]
].async_add_to_group(self.group_id)
async def async_remove_members(self, member_ieee_addresses):
async def async_remove_members(self, member_ieee_addresses: List[EUI64]) -> None:
"""Remove members from this group."""
if len(member_ieee_addresses) > 1:
tasks = []
@ -72,18 +94,50 @@ class ZHAGroup(LogMixin):
member_ieee_addresses[0]
].async_remove_from_group(self.group_id)
@property
def member_entity_ids(self) -> List[str]:
"""Return the ZHA entity ids for all entities for the members of this group."""
all_entity_ids: List[str] = []
for device in self.members:
entities = async_entries_for_device(
self._zha_gateway.ha_entity_registry, device.device_id
)
for entity in entities:
all_entity_ids.append(entity.entity_id)
return all_entity_ids
@property
def domain_entity_ids(self) -> List[str]:
"""Return entity ids from the entity domain for this group."""
if self.entity_domain is None:
return
domain_entity_ids: List[str] = []
for device in self.members:
entities = async_entries_for_device(
self._zha_gateway.ha_entity_registry, device.device_id
)
domain_entity_ids.extend(
[
entity.entity_id
for entity in entities
if entity.domain == self.entity_domain
]
)
return domain_entity_ids
@callback
def async_get_info(self):
def async_get_info(self) -> Dict[str, Any]:
"""Get ZHA group info."""
group_info = {}
group_info: Dict[str, Any] = {}
group_info["group_id"] = self.group_id
group_info["entity_domain"] = self.entity_domain
group_info["name"] = self.name
group_info["members"] = [
zha_device.async_get_info() for zha_device in self.members
]
return group_info
def log(self, level, msg, *args):
def log(self, level: int, msg: str, *args):
"""Log a message."""
msg = f"[%s](%s): {msg}"
args = (self.name, self.group_id) + args

View File

@ -1,10 +1,11 @@
"""Helpers for Zigbee Home Automation."""
import collections
import logging
from typing import Any, Callable, Iterator, List, Optional
import zigpy.types
from homeassistant.core import callback
from homeassistant.core import State, callback
from .const import CLUSTER_TYPE_IN, CLUSTER_TYPE_OUT, DATA_ZHA, DATA_ZHA_GATEWAY
from .registries import BINDABLE_CLUSTERS
@ -85,6 +86,45 @@ async def async_get_zha_device(hass, device_id):
return zha_gateway.devices[ieee]
def find_state_attributes(states: List[State], key: str) -> Iterator[Any]:
"""Find attributes with matching key from states."""
for state in states:
value = state.attributes.get(key)
if value is not None:
yield value
def mean_int(*args):
"""Return the mean of the supplied values."""
return int(sum(args) / len(args))
def mean_tuple(*args):
"""Return the mean values along the columns of the supplied values."""
return tuple(sum(l) / len(l) for l in zip(*args))
def reduce_attribute(
states: List[State],
key: str,
default: Optional[Any] = None,
reduce: Callable[..., Any] = mean_int,
) -> Any:
"""Find the first attribute matching key from states.
If none are found, return default.
"""
attrs = list(find_state_attributes(states, key))
if not attrs:
return default
if len(attrs) == 1:
return attrs[0]
return reduce(*attrs)
class LogMixin:
"""Log helper."""

View File

@ -32,6 +32,8 @@ from .const import CONTROLLER, ZHA_GW_RADIO, ZHA_GW_RADIO_DESCRIPTION, RadioType
from .decorators import CALLABLE_T, DictRegistry, SetRegistry
from .typing import ChannelType
GROUP_ENTITY_DOMAINS = [LIGHT]
SMARTTHINGS_ACCELERATION_CLUSTER = 0xFC02
SMARTTHINGS_ARRIVAL_SENSOR_DEVICE_TYPE = 0x8000
SMARTTHINGS_HUMIDITY_CLUSTER = 0xFC45
@ -275,6 +277,9 @@ RegistryDictType = Dict[
] # pylint: disable=invalid-name
GroupRegistryDictType = Dict[str, CALLABLE_T] # pylint: disable=invalid-name
class ZHAEntityRegistry:
"""Channel to ZHA Entity mapping."""
@ -282,6 +287,7 @@ class ZHAEntityRegistry:
"""Initialize Registry instance."""
self._strict_registry: RegistryDictType = collections.defaultdict(dict)
self._loose_registry: RegistryDictType = collections.defaultdict(dict)
self._group_registry: GroupRegistryDictType = {}
def get_entity(
self,
@ -300,6 +306,10 @@ class ZHAEntityRegistry:
return default, []
def get_group_entity(self, component: str) -> CALLABLE_T:
"""Match a ZHA group to a ZHA Entity class."""
return self._group_registry.get(component)
def strict_match(
self,
component: str,
@ -350,5 +360,15 @@ class ZHAEntityRegistry:
return decorator
def group_match(self, component: str) -> Callable[[CALLABLE_T], CALLABLE_T]:
"""Decorate a group match rule."""
def decorator(zha_ent: CALLABLE_T) -> CALLABLE_T:
"""Register a group match rule."""
self._group_registry[component] = zha_ent
return zha_ent
return decorator
ZHA_ENTITIES = ZHAEntityRegistry()

View File

@ -10,6 +10,8 @@ from homeassistant.core import callback
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import bind_hass
from .typing import ZhaDeviceType, ZhaGroupType
_LOGGER = logging.getLogger(__name__)
DATA_REGISTRY = "zha_storage"
@ -28,52 +30,96 @@ class ZhaDeviceEntry:
last_seen = attr.ib(type=float, default=None)
class ZhaDeviceStorage:
@attr.s(slots=True, frozen=True)
class ZhaGroupEntry:
"""Zha Group storage Entry."""
name = attr.ib(type=str, default=None)
group_id = attr.ib(type=int, default=None)
entity_domain = attr.ib(type=float, default=None)
class ZhaStorage:
"""Class to hold a registry of zha devices."""
def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize the zha device storage."""
self.hass = hass
self.hass: HomeAssistantType = hass
self.devices: MutableMapping[str, ZhaDeviceEntry] = {}
self.groups: MutableMapping[str, ZhaGroupEntry] = {}
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
@callback
def async_create(self, device) -> ZhaDeviceEntry:
def async_create_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry:
"""Create a new ZhaDeviceEntry."""
device_entry = ZhaDeviceEntry(
device_entry: ZhaDeviceEntry = ZhaDeviceEntry(
name=device.name, ieee=str(device.ieee), last_seen=device.last_seen
)
self.devices[device_entry.ieee] = device_entry
return self.async_update(device)
return self.async_update_device(device)
@callback
def async_get_or_create(self, device) -> ZhaDeviceEntry:
def async_create_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
"""Create a new ZhaGroupEntry."""
group_entry: ZhaGroupEntry = ZhaGroupEntry(
name=group.name,
group_id=str(group.group_id),
entity_domain=group.entity_domain,
)
self.groups[str(group.group_id)] = group_entry
return self.async_update_group(group)
@callback
def async_get_or_create_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry:
"""Create a new ZhaDeviceEntry."""
ieee_str = str(device.ieee)
ieee_str: str = str(device.ieee)
if ieee_str in self.devices:
return self.devices[ieee_str]
return self.async_create(device)
return self.async_create_device(device)
@callback
def async_create_or_update(self, device) -> ZhaDeviceEntry:
def async_get_or_create_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
"""Create a new ZhaGroupEntry."""
group_id: str = str(group.group_id)
if group_id in self.groups:
return self.groups[group_id]
return self.async_create_group(group)
@callback
def async_create_or_update_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry:
"""Create or update a ZhaDeviceEntry."""
if str(device.ieee) in self.devices:
return self.async_update(device)
return self.async_create(device)
return self.async_update_device(device)
return self.async_create_device(device)
@callback
def async_delete(self, device) -> None:
def async_create_or_update_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
"""Create or update a ZhaGroupEntry."""
if str(group.group_id) in self.groups:
return self.async_update_group(group)
return self.async_create_group(group)
@callback
def async_delete_device(self, device: ZhaDeviceType) -> None:
"""Delete ZhaDeviceEntry."""
ieee_str = str(device.ieee)
ieee_str: str = str(device.ieee)
if ieee_str in self.devices:
del self.devices[ieee_str]
self.async_schedule_save()
@callback
def async_update(self, device) -> ZhaDeviceEntry:
def async_delete_group(self, group: ZhaGroupType) -> None:
"""Delete ZhaGroupEntry."""
group_id: str = str(group.group_id)
if group_id in self.groups:
del self.groups[group_id]
self.async_schedule_save()
@callback
def async_update_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry:
"""Update name of ZhaDeviceEntry."""
ieee_str = str(device.ieee)
ieee_str: str = str(device.ieee)
old = self.devices[ieee_str]
changes = {}
@ -83,11 +129,25 @@ class ZhaDeviceStorage:
self.async_schedule_save()
return new
@callback
def async_update_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
"""Update name of ZhaGroupEntry."""
group_id: str = str(group.group_id)
old = self.groups[group_id]
changes = {}
changes["entity_domain"] = group.entity_domain
new = self.groups[group_id] = attr.evolve(old, **changes)
self.async_schedule_save()
return new
async def async_load(self) -> None:
"""Load the registry of zha device entries."""
data = await self._store.async_load()
devices: "OrderedDict[str, ZhaDeviceEntry]" = OrderedDict()
groups: "OrderedDict[str, ZhaGroupEntry]" = OrderedDict()
if data is not None:
for device in data["devices"]:
@ -97,7 +157,18 @@ class ZhaDeviceStorage:
last_seen=device["last_seen"] if "last_seen" in device else None,
)
if "groups" in data:
for group in data["groups"]:
groups[group["group_id"]] = ZhaGroupEntry(
name=group["name"],
group_id=group["group_id"],
entity_domain=group["entity_domain"]
if "entity_domain" in group
else None,
)
self.devices = devices
self.groups = groups
@callback
def async_schedule_save(self) -> None:
@ -118,21 +189,29 @@ class ZhaDeviceStorage:
for entry in self.devices.values()
]
data["groups"] = [
{
"name": entry.name,
"group_id": entry.group_id,
"entity_domain": entry.entity_domain,
}
for entry in self.groups.values()
]
return data
@bind_hass
async def async_get_registry(hass: HomeAssistantType) -> ZhaDeviceStorage:
async def async_get_registry(hass: HomeAssistantType) -> ZhaStorage:
"""Return zha device storage instance."""
task = hass.data.get(DATA_REGISTRY)
if task is None:
async def _load_reg() -> ZhaDeviceStorage:
registry = ZhaDeviceStorage(hass)
async def _load_reg() -> ZhaStorage:
registry = ZhaStorage(hass)
await registry.async_load()
return registry
task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg())
return cast(ZhaDeviceStorage, await task)
return cast(ZhaStorage, await task)

View File

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Callable, TypeVar
import zigpy.device
import zigpy.endpoint
import zigpy.group
import zigpy.zcl
import zigpy.zdo
@ -17,9 +18,11 @@ ZDOChannelType = "ZDOChannel"
ZhaDeviceType = "ZHADevice"
ZhaEntityType = "ZHAEntity"
ZhaGatewayType = "ZHAGateway"
ZhaGroupType = "ZHAGroupType"
ZigpyClusterType = zigpy.zcl.Cluster
ZigpyDeviceType = zigpy.device.Device
ZigpyEndpointType = zigpy.endpoint.Endpoint
ZigpyGroupType = zigpy.group.Group
ZigpyZdoType = zigpy.zdo.ZDO
if TYPE_CHECKING:
@ -38,3 +41,4 @@ if TYPE_CHECKING:
ZhaDeviceType = homeassistant.components.zha.core.device.ZHADevice
ZhaEntityType = homeassistant.components.zha.entity.ZhaEntity
ZhaGatewayType = homeassistant.components.zha.core.gateway.ZHAGateway
ZhaGroupType = homeassistant.components.zha.core.group.ZHAGroup

View File

@ -3,6 +3,7 @@
import asyncio
import logging
import time
from typing import Any, Awaitable, Dict, List
from homeassistant.core import callback
from homeassistant.helpers import entity
@ -20,6 +21,7 @@ from .core.const import (
SIGNAL_REMOVE,
)
from .core.helpers import LogMixin
from .core.typing import CALLABLE_T, ChannelsType, ChannelType, ZhaDeviceType
_LOGGER = logging.getLogger(__name__)
@ -27,30 +29,24 @@ ENTITY_SUFFIX = "entity_suffix"
RESTART_GRACE_PERIOD = 7200 # 2 hours
class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
class BaseZhaEntity(RestoreEntity, LogMixin, entity.Entity):
"""A base class for ZHA entities."""
def __init__(self, unique_id, zha_device, channels, skip_entity_id=False, **kwargs):
def __init__(self, unique_id: str, zha_device: ZhaDeviceType, **kwargs):
"""Init ZHA entity."""
self._force_update = False
self._should_poll = False
self._unique_id = unique_id
ieeetail = "".join([f"{o:02x}" for o in zha_device.ieee[:4]])
ch_names = [ch.cluster.ep_attribute for ch in channels]
ch_names = ", ".join(sorted(ch_names))
self._name = f"{zha_device.name} {ieeetail} {ch_names}"
self._state = None
self._device_state_attributes = {}
self._zha_device = zha_device
self.cluster_channels = {}
self._available = False
self._unsubs = []
self.remove_future = None
for channel in channels:
self.cluster_channels[channel.name] = channel
self._name: str = ""
self._force_update: bool = False
self._should_poll: bool = False
self._unique_id: str = unique_id
self._state: Any = None
self._device_state_attributes: Dict[str, Any] = {}
self._zha_device: ZhaDeviceType = zha_device
self._available: bool = False
self._unsubs: List[CALLABLE_T] = []
self.remove_future: Awaitable[None] = None
@property
def name(self):
def name(self) -> str:
"""Return Entity's default name."""
return self._name
@ -60,12 +56,12 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
return self._unique_id
@property
def zha_device(self):
def zha_device(self) -> ZhaDeviceType:
"""Return the zha device this entity is attached to."""
return self._zha_device
@property
def device_state_attributes(self):
def device_state_attributes(self) -> Dict[str, Any]:
"""Return device specific state attributes."""
return self._device_state_attributes
@ -80,7 +76,7 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
return self._should_poll
@property
def device_info(self):
def device_info(self) -> Dict[str, Any]:
"""Return a device description for device registry."""
zha_device_info = self._zha_device.device_info
ieee = zha_device_info["ieee"]
@ -94,31 +90,94 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
}
@property
def available(self):
def available(self) -> bool:
"""Return entity availability."""
return self._available
@callback
def async_set_available(self, available):
def async_set_available(self, available: bool) -> None:
"""Set entity availability."""
self._available = available
self.async_write_ha_state()
@callback
def async_update_state_attribute(self, key, value):
def async_update_state_attribute(self, key: str, value: Any) -> None:
"""Update a single device state attribute."""
self._device_state_attributes.update({key: value})
self.async_write_ha_state()
@callback
def async_set_state(self, attr_id, attr_name, value):
def async_set_state(self, attr_id: int, attr_name: str, value: Any) -> None:
"""Set the entity state."""
pass
async def async_added_to_hass(self):
async def async_added_to_hass(self) -> None:
"""Run when about to be added to hass."""
await super().async_added_to_hass()
self.remove_future = asyncio.Future()
await self.async_accept_signal(
None,
"{}_{}".format(SIGNAL_REMOVE, str(self.zha_device.ieee)),
self.async_remove,
signal_override=True,
)
async def async_will_remove_from_hass(self) -> None:
"""Disconnect entity object when removed."""
for unsub in self._unsubs[:]:
unsub()
self._unsubs.remove(unsub)
self.zha_device.gateway.remove_entity_reference(self)
self.remove_future.set_result(True)
@callback
def async_restore_last_state(self, last_state) -> None:
"""Restore previous state."""
pass
async def async_accept_signal(
self, channel: ChannelType, signal: str, func: CALLABLE_T, signal_override=False
):
"""Accept a signal from a channel."""
unsub = None
if signal_override:
unsub = async_dispatcher_connect(self.hass, signal, func)
else:
unsub = async_dispatcher_connect(
self.hass, f"{channel.unique_id}_{signal}", func
)
self._unsubs.append(unsub)
def log(self, level: int, msg: str, *args):
"""Log a message."""
msg = f"%s: {msg}"
args = (self.entity_id,) + args
_LOGGER.log(level, msg, *args)
class ZhaEntity(BaseZhaEntity):
"""A base class for non group ZHA entities."""
def __init__(
self,
unique_id: str,
zha_device: ZhaDeviceType,
channels: ChannelsType,
**kwargs,
):
"""Init ZHA entity."""
super().__init__(unique_id, zha_device, **kwargs)
ieeetail = "".join([f"{o:02x}" for o in zha_device.ieee[:4]])
ch_names = [ch.cluster.ep_attribute for ch in channels]
ch_names = ", ".join(sorted(ch_names))
self._name: str = f"{zha_device.name} {ieeetail} {ch_names}"
self.cluster_channels: Dict[str, ChannelType] = {}
for channel in channels:
self.cluster_channels[channel.name] = channel
async def async_added_to_hass(self) -> None:
"""Run when about to be added to hass."""
await super().async_added_to_hass()
await self.async_check_recently_seen()
await self.async_accept_signal(
None,
@ -126,12 +185,6 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
self.async_set_available,
signal_override=True,
)
await self.async_accept_signal(
None,
"{}_{}".format(SIGNAL_REMOVE, str(self.zha_device.ieee)),
self.async_remove,
signal_override=True,
)
self._zha_device.gateway.register_entity_reference(
self._zha_device.ieee,
self.entity_id,
@ -141,7 +194,7 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
self.remove_future,
)
async def async_check_recently_seen(self):
async def async_check_recently_seen(self) -> None:
"""Check if the device was seen within the last 2 hours."""
last_state = await self.async_get_last_state()
if (
@ -155,38 +208,8 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
self.async_restore_last_state(last_state)
self._zha_device.set_available(True)
async def async_will_remove_from_hass(self) -> None:
"""Disconnect entity object when removed."""
for unsub in self._unsubs[:]:
unsub()
self._unsubs.remove(unsub)
self.zha_device.gateway.remove_entity_reference(self)
self.remove_future.set_result(True)
@callback
def async_restore_last_state(self, last_state):
"""Restore previous state."""
pass
async def async_update(self):
async def async_update(self) -> None:
"""Retrieve latest state."""
for channel in self.cluster_channels.values():
if hasattr(channel, "async_update"):
await channel.async_update()
async def async_accept_signal(self, channel, signal, func, signal_override=False):
"""Accept a signal from a channel."""
unsub = None
if signal_override:
unsub = async_dispatcher_connect(self.hass, signal, func)
else:
unsub = async_dispatcher_connect(
self.hass, f"{channel.unique_id}_{signal}", func
)
self._unsubs.append(unsub)
def log(self, level, msg, *args):
"""Log a message."""
msg = f"%s: {msg}"
args = (self.entity_id,) + args
_LOGGER.log(level, msg, *args)

View File

@ -1,19 +1,44 @@
"""Lights on Zigbee Home Automation networks."""
from collections import Counter
from datetime import timedelta
import functools
import itertools
import logging
import random
from typing import Any, Dict, List, Optional, Tuple
from zigpy.zcl.clusters.general import Identify, LevelControl, OnOff
from zigpy.zcl.clusters.lighting import Color
from zigpy.zcl.foundation import Status
from homeassistant.components import light
from homeassistant.const import STATE_ON
from homeassistant.core import callback
from homeassistant.components.light import (
ATTR_BRIGHTNESS,
ATTR_COLOR_TEMP,
ATTR_EFFECT,
ATTR_EFFECT_LIST,
ATTR_HS_COLOR,
ATTR_MAX_MIREDS,
ATTR_MIN_MIREDS,
ATTR_WHITE_VALUE,
SUPPORT_BRIGHTNESS,
SUPPORT_COLOR,
SUPPORT_COLOR_TEMP,
SUPPORT_EFFECT,
SUPPORT_FLASH,
SUPPORT_TRANSITION,
SUPPORT_WHITE_VALUE,
)
from homeassistant.const import ATTR_SUPPORTED_FEATURES, STATE_ON, STATE_UNAVAILABLE
from homeassistant.core import CALLBACK_TYPE, State, callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.event import (
async_track_state_change,
async_track_time_interval,
)
import homeassistant.util.color as color_util
from .core import discovery
from .core import discovery, helpers
from .core.const import (
CHANNEL_COLOR,
CHANNEL_LEVEL,
@ -25,11 +50,12 @@ from .core.const import (
EFFECT_DEFAULT_VARIANT,
SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED,
SIGNAL_REMOVE_GROUP,
SIGNAL_SET_LEVEL,
)
from .core.registries import ZHA_ENTITIES
from .core.typing import ZhaDeviceType
from .entity import ZhaEntity
from .entity import BaseZhaEntity, ZhaEntity
_LOGGER = logging.getLogger(__name__)
@ -46,8 +72,19 @@ FLASH_EFFECTS = {light.FLASH_SHORT: EFFECT_BLINK, light.FLASH_LONG: EFFECT_BREAT
UNSUPPORTED_ATTRIBUTE = 0x86
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, light.DOMAIN)
GROUP_MATCH = functools.partial(ZHA_ENTITIES.group_match, light.DOMAIN)
PARALLEL_UPDATES = 0
SUPPORT_GROUP_LIGHT = (
SUPPORT_BRIGHTNESS
| SUPPORT_COLOR_TEMP
| SUPPORT_EFFECT
| SUPPORT_FLASH
| SUPPORT_COLOR
| SUPPORT_TRANSITION
| SUPPORT_WHITE_VALUE
)
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Set up the Zigbee Home Automation light from config entry."""
@ -63,48 +100,35 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub)
@STRICT_MATCH(channel_names=CHANNEL_ON_OFF, aux_channels={CHANNEL_COLOR, CHANNEL_LEVEL})
class Light(ZhaEntity, light.Light):
"""Representation of a ZHA or ZLL light."""
class BaseLight(BaseZhaEntity, light.Light):
"""Operations common to all light entities."""
_REFRESH_INTERVAL = (45, 75)
def __init__(self, *args, **kwargs):
"""Initialize the light."""
super().__init__(*args, **kwargs)
self._is_on: bool = False
self._available: bool = False
self._brightness: Optional[int] = None
self._off_brightness: Optional[int] = None
self._hs_color: Optional[Tuple[float, float]] = None
self._color_temp: Optional[int] = None
self._min_mireds: Optional[int] = 154
self._max_mireds: Optional[int] = 500
self._white_value: Optional[int] = None
self._effect_list: Optional[List[str]] = None
self._effect: Optional[str] = None
self._supported_features: int = 0
self._state: bool = False
self._on_off_channel = None
self._level_channel = None
self._color_channel = None
self._identify_channel = None
def __init__(self, unique_id, zha_device: ZhaDeviceType, channels, **kwargs):
"""Initialize the ZHA light."""
super().__init__(unique_id, zha_device, channels, **kwargs)
self._supported_features = 0
self._color_temp = None
self._hs_color = None
self._brightness = None
self._off_brightness = None
self._effect_list = []
self._effect = None
self._on_off_channel = self.cluster_channels.get(CHANNEL_ON_OFF)
self._level_channel = self.cluster_channels.get(CHANNEL_LEVEL)
self._color_channel = self.cluster_channels.get(CHANNEL_COLOR)
self._identify_channel = self.zha_device.channels.identify_ch
self._cancel_refresh_handle = None
if self._level_channel:
self._supported_features |= light.SUPPORT_BRIGHTNESS
self._supported_features |= light.SUPPORT_TRANSITION
self._brightness = 0
if self._color_channel:
color_capabilities = self._color_channel.get_color_capabilities()
if color_capabilities & CAPABILITIES_COLOR_TEMP:
self._supported_features |= light.SUPPORT_COLOR_TEMP
if color_capabilities & CAPABILITIES_COLOR_XY:
self._supported_features |= light.SUPPORT_COLOR
self._hs_color = (0, 0)
if color_capabilities & CAPABILITIES_COLOR_LOOP:
self._supported_features |= light.SUPPORT_EFFECT
self._effect_list.append(light.EFFECT_COLORLOOP)
if self._identify_channel:
self._supported_features |= light.SUPPORT_FLASH
@property
def device_state_attributes(self) -> Dict[str, Any]:
"""Return state attributes."""
attributes = {"off_brightness": self._off_brightness}
return attributes
@property
def is_on(self) -> bool:
@ -118,12 +142,6 @@ class Light(ZhaEntity, light.Light):
"""Return the brightness of this light."""
return self._brightness
@property
def device_state_attributes(self):
"""Return state attributes."""
attributes = {"off_brightness": self._off_brightness}
return attributes
def set_level(self, value):
"""Set the brightness of this light between 0..254.
@ -160,49 +178,6 @@ class Light(ZhaEntity, light.Light):
"""Flag supported features."""
return self._supported_features
@callback
def async_set_state(self, attr_id, attr_name, value):
"""Set the state."""
self._state = bool(value)
if value:
self._off_brightness = None
self.async_write_ha_state()
async def async_added_to_hass(self):
"""Run when about to be added to hass."""
await super().async_added_to_hass()
await self.async_accept_signal(
self._on_off_channel, SIGNAL_ATTR_UPDATED, self.async_set_state
)
if self._level_channel:
await self.async_accept_signal(
self._level_channel, SIGNAL_SET_LEVEL, self.set_level
)
refresh_interval = random.randint(*[x * 60 for x in self._REFRESH_INTERVAL])
self._cancel_refresh_handle = async_track_time_interval(
self.hass, self._refresh, timedelta(seconds=refresh_interval)
)
async def async_will_remove_from_hass(self) -> None:
"""Disconnect entity object when removed."""
self._cancel_refresh_handle()
await super().async_will_remove_from_hass()
@callback
def async_restore_last_state(self, last_state):
"""Restore previous state."""
self._state = last_state.state == STATE_ON
if "brightness" in last_state.attributes:
self._brightness = last_state.attributes["brightness"]
if "off_brightness" in last_state.attributes:
self._off_brightness = last_state.attributes["off_brightness"]
if "color_temp" in last_state.attributes:
self._color_temp = last_state.attributes["color_temp"]
if "hs_color" in last_state.attributes:
self._hs_color = last_state.attributes["hs_color"]
if "effect" in last_state.attributes:
self._effect = last_state.attributes["effect"]
async def async_turn_on(self, **kwargs):
"""Turn the entity on."""
transition = kwargs.get(light.ATTR_TRANSITION)
@ -331,6 +306,86 @@ class Light(ZhaEntity, light.Light):
self.async_write_ha_state()
@STRICT_MATCH(channel_names=CHANNEL_ON_OFF, aux_channels={CHANNEL_COLOR, CHANNEL_LEVEL})
class Light(ZhaEntity, BaseLight):
"""Representation of a ZHA or ZLL light."""
_REFRESH_INTERVAL = (45, 75)
def __init__(self, unique_id, zha_device: ZhaDeviceType, channels, **kwargs):
"""Initialize the ZHA light."""
super().__init__(unique_id, zha_device, channels, **kwargs)
self._on_off_channel = self.cluster_channels.get(CHANNEL_ON_OFF)
self._level_channel = self.cluster_channels.get(CHANNEL_LEVEL)
self._color_channel = self.cluster_channels.get(CHANNEL_COLOR)
self._identify_channel = self.zha_device.channels.identify_ch
self._cancel_refresh_handle = None
if self._level_channel:
self._supported_features |= light.SUPPORT_BRIGHTNESS
self._supported_features |= light.SUPPORT_TRANSITION
self._brightness = 0
if self._color_channel:
color_capabilities = self._color_channel.get_color_capabilities()
if color_capabilities & CAPABILITIES_COLOR_TEMP:
self._supported_features |= light.SUPPORT_COLOR_TEMP
if color_capabilities & CAPABILITIES_COLOR_XY:
self._supported_features |= light.SUPPORT_COLOR
self._hs_color = (0, 0)
if color_capabilities & CAPABILITIES_COLOR_LOOP:
self._supported_features |= light.SUPPORT_EFFECT
self._effect_list.append(light.EFFECT_COLORLOOP)
if self._identify_channel:
self._supported_features |= light.SUPPORT_FLASH
@callback
def async_set_state(self, attr_id, attr_name, value):
"""Set the state."""
self._state = bool(value)
if value:
self._off_brightness = None
self.async_write_ha_state()
async def async_added_to_hass(self):
"""Run when about to be added to hass."""
await super().async_added_to_hass()
await self.async_accept_signal(
self._on_off_channel, SIGNAL_ATTR_UPDATED, self.async_set_state
)
if self._level_channel:
await self.async_accept_signal(
self._level_channel, SIGNAL_SET_LEVEL, self.set_level
)
refresh_interval = random.randint(*[x * 60 for x in self._REFRESH_INTERVAL])
self._cancel_refresh_handle = async_track_time_interval(
self.hass, self._refresh, timedelta(seconds=refresh_interval)
)
async def async_will_remove_from_hass(self) -> None:
"""Disconnect entity object when removed."""
self._cancel_refresh_handle()
await super().async_will_remove_from_hass()
@callback
def async_restore_last_state(self, last_state):
"""Restore previous state."""
self._state = last_state.state == STATE_ON
if "brightness" in last_state.attributes:
self._brightness = last_state.attributes["brightness"]
if "off_brightness" in last_state.attributes:
self._off_brightness = last_state.attributes["off_brightness"]
if "color_temp" in last_state.attributes:
self._color_temp = last_state.attributes["color_temp"]
if "hs_color" in last_state.attributes:
self._hs_color = last_state.attributes["hs_color"]
if "effect" in last_state.attributes:
self._effect = last_state.attributes["effect"]
async def async_update(self):
"""Attempt to retrieve on off state from the light."""
await super().async_update()
@ -410,3 +465,99 @@ class HueLight(Light):
"""Representation of a HUE light which does not report attributes."""
_REFRESH_INTERVAL = (3, 5)
@GROUP_MATCH()
class LightGroup(BaseLight):
"""Representation of a light group."""
def __init__(
self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs
) -> None:
"""Initialize a light group."""
super().__init__(unique_id, zha_device, **kwargs)
self._name = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}"
self._group_id: int = group_id
self._entity_ids: List[str] = entity_ids
group = self.zha_device.gateway.get_group(self._group_id)
self._on_off_channel = group.endpoint[OnOff.cluster_id]
self._level_channel = group.endpoint[LevelControl.cluster_id]
self._color_channel = group.endpoint[Color.cluster_id]
self._identify_channel = group.endpoint[Identify.cluster_id]
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
await super().async_added_to_hass()
await self.async_accept_signal(
None,
f"{SIGNAL_REMOVE_GROUP}_{self._group_id}",
self.async_remove,
signal_override=True,
)
@callback
def async_state_changed_listener(
entity_id: str, old_state: State, new_state: State
):
"""Handle child updates."""
self.async_schedule_update_ha_state(True)
self._async_unsub_state_changed = async_track_state_change(
self.hass, self._entity_ids, async_state_changed_listener
)
await self.async_update()
async def async_will_remove_from_hass(self) -> None:
"""Handle removal from Home Assistant."""
await super().async_will_remove_from_hass()
if self._async_unsub_state_changed is not None:
self._async_unsub_state_changed()
self._async_unsub_state_changed = None
async def async_update(self) -> None:
"""Query all members and determine the light group state."""
all_states = [self.hass.states.get(x) for x in self._entity_ids]
states: List[State] = list(filter(None, all_states))
on_states = [state for state in states if state.state == STATE_ON]
self._is_on = len(on_states) > 0
self._available = any(state.state != STATE_UNAVAILABLE for state in states)
self._brightness = helpers.reduce_attribute(on_states, ATTR_BRIGHTNESS)
self._hs_color = helpers.reduce_attribute(
on_states, ATTR_HS_COLOR, reduce=helpers.mean_tuple
)
self._white_value = helpers.reduce_attribute(on_states, ATTR_WHITE_VALUE)
self._color_temp = helpers.reduce_attribute(on_states, ATTR_COLOR_TEMP)
self._min_mireds = helpers.reduce_attribute(
states, ATTR_MIN_MIREDS, default=154, reduce=min
)
self._max_mireds = helpers.reduce_attribute(
states, ATTR_MAX_MIREDS, default=500, reduce=max
)
self._effect_list = None
all_effect_lists = list(helpers.find_state_attributes(states, ATTR_EFFECT_LIST))
if all_effect_lists:
# Merge all effects from all effect_lists with a union merge.
self._effect_list = list(set().union(*all_effect_lists))
self._effect = None
all_effects = list(helpers.find_state_attributes(on_states, ATTR_EFFECT))
if all_effects:
# Report the most common effect.
effects_count = Counter(itertools.chain(all_effects))
self._effect = effects_count.most_common(1)[0][0]
self._supported_features = 0
for support in helpers.find_state_attributes(states, ATTR_SUPPORTED_FEATURES):
# Merge supported features by emulating support for every feature
# we find.
self._supported_features |= support
# Bitwise-and the supported features with the GroupedLight's features
# so that we don't break in the future when a new feature is added.
self._supported_features &= SUPPORT_GROUP_LIGHT

View File

@ -3,6 +3,8 @@ import time
from unittest.mock import Mock
from asynctest import CoroutineMock
from zigpy.device import Device as zigpy_dev
from zigpy.endpoint import Endpoint as zigpy_ep
import zigpy.profiles.zha
import zigpy.types
import zigpy.zcl
@ -24,6 +26,7 @@ class FakeEndpoint:
self.in_clusters = {}
self.out_clusters = {}
self._cluster_attr = {}
self.member_of = {}
self.status = 1
self.manufacturer = manufacturer
self.model = model
@ -45,6 +48,19 @@ class FakeEndpoint:
patch_cluster(cluster)
self.out_clusters[cluster_id] = cluster
@property
def __class__(self):
"""Fake being Zigpy endpoint."""
return zigpy_ep
@property
def unique_id(self):
"""Return the unique id for the endpoint."""
return self.device.ieee, self.endpoint_id
FakeEndpoint.add_to_group = zigpy_ep.add_to_group
def patch_cluster(cluster):
"""Patch a cluster for testing."""
@ -56,17 +72,19 @@ def patch_cluster(cluster):
cluster.read_attributes_raw = Mock()
cluster.unbind = CoroutineMock(return_value=[0])
cluster.write_attributes = CoroutineMock(return_value=[0])
if cluster.cluster_id == 4:
cluster.add = CoroutineMock(return_value=[0])
class FakeDevice:
"""Fake device for mocking zigpy."""
def __init__(self, app, ieee, manufacturer, model, node_desc=None):
def __init__(self, app, ieee, manufacturer, model, node_desc=None, nwk=0xB79C):
"""Init fake device."""
self._application = app
self.application = app
self.ieee = zigpy.types.EUI64.convert(ieee)
self.nwk = 0xB79C
self.nwk = nwk
self.zdo = Mock()
self.endpoints = {0: self.zdo}
self.lqi = 255
@ -78,13 +96,15 @@ class FakeDevice:
self.manufacturer = manufacturer
self.model = model
self.node_desc = zigpy.zdo.types.NodeDescriptor()
self.add_to_group = CoroutineMock()
self.remove_from_group = CoroutineMock()
if node_desc is None:
node_desc = b"\x02@\x807\x10\x7fd\x00\x00*d\x00\x00"
self.node_desc = zigpy.zdo.types.NodeDescriptor.deserialize(node_desc)[0]
FakeDevice.add_to_group = zigpy_dev.add_to_group
def get_zha_gateway(hass):
"""Return ZHA gateway from hass.data."""
try:
@ -137,6 +157,19 @@ async def find_entity_id(domain, zha_device, hass):
return None
def async_find_group_entity_id(hass, domain, group):
"""Find the group entity id under test."""
entity_id = (
f"{domain}.{group.name.lower().replace(' ','_')}_group_0x{group.group_id:04x}"
)
entity_ids = hass.states.async_entity_ids(domain)
if entity_id in entity_ids:
return entity_id
return None
async def async_enable_traffic(hass, zha_devices):
"""Allow traffic to flow through the gateway and the zha device."""
for zha_device in zha_devices:

View File

@ -110,10 +110,11 @@ def zigpy_device_mock(zigpy_app_controller):
manufacturer="FakeManufacturer",
model="FakeModel",
node_descriptor=b"\x02@\x807\x10\x7fd\x00\x00*d\x00\x00",
nwk=0xB79C,
):
"""Make a fake device using the specified cluster classes."""
device = FakeDevice(
zigpy_app_controller, ieee, manufacturer, model, node_descriptor
zigpy_app_controller, ieee, manufacturer, model, node_descriptor, nwk=nwk
)
for epid, ep in endpoints.items():
endpoint = FakeEndpoint(manufacturer, model, epid)

View File

@ -1,8 +1,18 @@
"""Test ZHA Gateway."""
import pytest
import zigpy.zcl.clusters.general as general
import logging
from .common import async_enable_traffic, get_zha_gateway
import pytest
import zigpy.profiles.zha as zha
import zigpy.zcl.clusters.general as general
import zigpy.zcl.clusters.lighting as lighting
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from .common import async_enable_traffic, async_find_group_entity_id, get_zha_gateway
IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8"
IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8"
_LOGGER = logging.getLogger(__name__)
@pytest.fixture
@ -15,7 +25,7 @@ def zigpy_dev_basic(zigpy_device_mock):
"out_clusters": [],
"device_type": 0,
}
},
}
)
@ -27,6 +37,74 @@ async def zha_dev_basic(hass, zha_device_restored, zigpy_dev_basic):
return zha_device
@pytest.fixture
async def coordinator(hass, zigpy_device_mock, zha_device_joined):
"""Test zha light platform."""
zigpy_device = zigpy_device_mock(
{
1: {
"in_clusters": [],
"out_clusters": [],
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
}
},
ieee="00:15:8d:00:02:32:4f:32",
nwk=0x0000,
)
zha_device = await zha_device_joined(zigpy_device)
zha_device.set_available(True)
return zha_device
@pytest.fixture
async def device_light_1(hass, zigpy_device_mock, zha_device_joined):
"""Test zha light platform."""
zigpy_device = zigpy_device_mock(
{
1: {
"in_clusters": [
general.OnOff.cluster_id,
general.LevelControl.cluster_id,
lighting.Color.cluster_id,
general.Groups.cluster_id,
],
"out_clusters": [],
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
}
},
ieee=IEEE_GROUPABLE_DEVICE,
)
zha_device = await zha_device_joined(zigpy_device)
zha_device.set_available(True)
return zha_device
@pytest.fixture
async def device_light_2(hass, zigpy_device_mock, zha_device_joined):
"""Test zha light platform."""
zigpy_device = zigpy_device_mock(
{
1: {
"in_clusters": [
general.OnOff.cluster_id,
general.LevelControl.cluster_id,
lighting.Color.cluster_id,
general.Groups.cluster_id,
],
"out_clusters": [],
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
}
},
ieee=IEEE_GROUPABLE_DEVICE2,
)
zha_device = await zha_device_joined(zigpy_device)
zha_device.set_available(True)
return zha_device
async def test_device_left(hass, zigpy_dev_basic, zha_dev_basic):
"""Device leaving the network should become unavailable."""
@ -37,3 +115,57 @@ async def test_device_left(hass, zigpy_dev_basic, zha_dev_basic):
get_zha_gateway(hass).device_left(zigpy_dev_basic)
assert zha_dev_basic.available is False
async def test_gateway_group_methods(hass, device_light_1, device_light_2, coordinator):
"""Test creating a group with 2 members."""
zha_gateway = get_zha_gateway(hass)
assert zha_gateway is not None
zha_gateway.coordinator_zha_device = coordinator
coordinator._zha_gateway = zha_gateway
device_light_1._zha_gateway = zha_gateway
device_light_2._zha_gateway = zha_gateway
member_ieee_addresses = [device_light_1.ieee, device_light_2.ieee]
# test creating a group with 2 members
zha_group = await zha_gateway.async_create_zigpy_group(
"Test Group", member_ieee_addresses
)
await hass.async_block_till_done()
assert zha_group is not None
assert zha_group.entity_domain == LIGHT_DOMAIN
assert len(zha_group.members) == 2
for member in zha_group.members:
assert member.ieee in member_ieee_addresses
entity_id = async_find_group_entity_id(hass, LIGHT_DOMAIN, zha_group)
assert hass.states.get(entity_id) is not None
# test get group by name
assert zha_group == zha_gateway.async_get_group_by_name(zha_group.name)
# test removing a group
await zha_gateway.async_remove_zigpy_group(zha_group.group_id)
await hass.async_block_till_done()
# we shouldn't have the group anymore
assert zha_gateway.async_get_group_by_name(zha_group.name) is None
# the group entity should be cleaned up
assert entity_id not in hass.states.async_entity_ids(LIGHT_DOMAIN)
# test creating a group with 1 member
zha_group = await zha_gateway.async_create_zigpy_group(
"Test Group", [device_light_1.ieee]
)
await hass.async_block_till_done()
assert zha_group is not None
assert zha_group.entity_domain is None
assert len(zha_group.members) == 1
for member in zha_group.members:
assert member.ieee in [device_light_1.ieee]
# the group entity should not have been cleaned up
assert entity_id not in hass.states.async_entity_ids(LIGHT_DOMAIN)

View File

@ -4,7 +4,7 @@ from unittest.mock import MagicMock, call, sentinel
from asynctest import CoroutineMock, patch
import pytest
import zigpy.profiles.zha
import zigpy.profiles.zha as zha
import zigpy.types
import zigpy.zcl.clusters.general as general
import zigpy.zcl.clusters.lighting as lighting
@ -17,8 +17,10 @@ import homeassistant.util.dt as dt_util
from .common import (
async_enable_traffic,
async_find_group_entity_id,
async_test_rejoin,
find_entity_id,
get_zha_gateway,
send_attributes_report,
)
@ -26,6 +28,8 @@ from tests.common import async_fire_time_changed
ON = 1
OFF = 0
IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8"
IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8"
LIGHT_ON_OFF = {
1: {
@ -66,6 +70,76 @@ LIGHT_COLOR = {
}
@pytest.fixture
async def coordinator(hass, zigpy_device_mock, zha_device_joined):
"""Test zha light platform."""
zigpy_device = zigpy_device_mock(
{
1: {
"in_clusters": [],
"out_clusters": [],
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
}
},
ieee="00:15:8d:00:02:32:4f:32",
nwk=0x0000,
)
zha_device = await zha_device_joined(zigpy_device)
zha_device.set_available(True)
return zha_device
@pytest.fixture
async def device_light_1(hass, zigpy_device_mock, zha_device_joined):
"""Test zha light platform."""
zigpy_device = zigpy_device_mock(
{
1: {
"in_clusters": [
general.OnOff.cluster_id,
general.LevelControl.cluster_id,
lighting.Color.cluster_id,
general.Groups.cluster_id,
general.Identify.cluster_id,
],
"out_clusters": [],
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
}
},
ieee=IEEE_GROUPABLE_DEVICE,
)
zha_device = await zha_device_joined(zigpy_device)
zha_device.set_available(True)
return zha_device
@pytest.fixture
async def device_light_2(hass, zigpy_device_mock, zha_device_joined):
"""Test zha light platform."""
zigpy_device = zigpy_device_mock(
{
1: {
"in_clusters": [
general.OnOff.cluster_id,
general.LevelControl.cluster_id,
lighting.Color.cluster_id,
general.Groups.cluster_id,
general.Identify.cluster_id,
],
"out_clusters": [],
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
}
},
ieee=IEEE_GROUPABLE_DEVICE2,
)
zha_device = await zha_device_joined(zigpy_device)
zha_device.set_available(True)
return zha_device
@patch("zigpy.zcl.clusters.general.OnOff.read_attributes", new=MagicMock())
async def test_light_refresh(hass, zigpy_device_mock, zha_device_joined_restored):
"""Test zha light platform refresh."""
@ -337,3 +411,96 @@ async def async_test_flash_from_hass(hass, cluster, entity_id, flash):
manufacturer=None,
tsn=None,
)
async def async_test_zha_group_light_entity(
hass, device_light_1, device_light_2, coordinator
):
"""Test the light entity for a ZHA group."""
zha_gateway = get_zha_gateway(hass)
assert zha_gateway is not None
zha_gateway.coordinator_zha_device = coordinator
coordinator._zha_gateway = zha_gateway
device_light_1._zha_gateway = zha_gateway
device_light_2._zha_gateway = zha_gateway
member_ieee_addresses = [device_light_1.ieee, device_light_2.ieee]
# test creating a group with 2 members
zha_group = await zha_gateway.async_create_zigpy_group(
"Test Group", member_ieee_addresses
)
await hass.async_block_till_done()
assert zha_group is not None
assert zha_group.entity_domain == DOMAIN
assert len(zha_group.members) == 2
for member in zha_group.members:
assert member.ieee in member_ieee_addresses
entity_id = async_find_group_entity_id(hass, DOMAIN, zha_group)
assert hass.states.get(entity_id) is not None
group_cluster_on_off = zha_group.endpoint[general.OnOff.cluster_id]
group_cluster_level = zha_group.endpoint[general.LevelControl.cluster_id]
group_cluster_identify = zha_group.endpoint[general.Identify.cluster_id]
dev1_cluster_on_off = device_light_1.endpoints[1].on_off
dev2_cluster_on_off = device_light_2.endpoints[1].on_off
# test that the lights were created and that they are unavailable
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE
# allow traffic to flow through the gateway and device
await async_enable_traffic(hass, zha_group.members)
# test that the lights were created and are off
assert hass.states.get(entity_id).state == STATE_OFF
# test turning the lights on and off from the light
await async_test_on_off_from_light(hass, group_cluster_on_off, entity_id)
# test turning the lights on and off from the HA
await async_test_on_off_from_hass(hass, group_cluster_on_off, entity_id)
# test short flashing the lights from the HA
await async_test_flash_from_hass(
hass, group_cluster_identify, entity_id, FLASH_SHORT
)
# test turning the lights on and off from the HA
await async_test_level_on_off_from_hass(
hass, group_cluster_on_off, group_cluster_level, entity_id
)
# test getting a brightness change from the network
await async_test_on_from_light(hass, group_cluster_on_off, entity_id)
await async_test_dimmer_from_light(
hass, group_cluster_level, entity_id, 150, STATE_ON
)
# test long flashing the lights from the HA
await async_test_flash_from_hass(
hass, group_cluster_identify, entity_id, FLASH_LONG
)
# test some of the group logic to make sure we key off states correctly
await dev1_cluster_on_off.on()
await dev2_cluster_on_off.on()
# test that group light is on
assert hass.states.get(entity_id).state == STATE_ON
await dev1_cluster_on_off.off()
# test that group light is still on
assert hass.states.get(entity_id).state == STATE_ON
await dev2_cluster_on_off.off()
# test that group light is now off
assert hass.states.get(entity_id).state == STATE_OFF
await dev1_cluster_on_off.on()
# test that group light is now back on
assert hass.states.get(entity_id).state == STATE_ON