Load quirks in ZHA unit tests (#91779)

pull/91721/head
puddly 2023-04-21 02:24:39 -04:00 committed by GitHub
parent f9416e1c34
commit 72414a5864
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 208 additions and 885 deletions

View File

@ -130,17 +130,19 @@ class ClusterHandler(LogMixin):
unique_id = endpoint.unique_id.replace("-", ":") unique_id = endpoint.unique_id.replace("-", ":")
self._unique_id = f"{unique_id}:0x{cluster.cluster_id:04x}" self._unique_id = f"{unique_id}:0x{cluster.cluster_id:04x}"
if not hasattr(self, "_value_attribute") and self.REPORT_CONFIG: if not hasattr(self, "_value_attribute") and self.REPORT_CONFIG:
attr_def: ZCLAttributeDef | None = self.cluster.attributes_by_name.get( attr_def: ZCLAttributeDef = self.cluster.attributes_by_name[
self.REPORT_CONFIG[0]["attr"] self.REPORT_CONFIG[0]["attr"]
) ]
if attr_def is not None: self.value_attribute = attr_def.id
self.value_attribute = attr_def.id
else:
self.value_attribute = None
self._status = ClusterHandlerStatus.CREATED self._status = ClusterHandlerStatus.CREATED
self._cluster.add_listener(self) self._cluster.add_listener(self)
self.data_cache: dict[str, Enum] = {} self.data_cache: dict[str, Enum] = {}
@classmethod
def matches(cls, cluster: zigpy.zcl.Cluster, endpoint: Endpoint) -> bool:
"""Filter the cluster match for specific devices."""
return True
@property @property
def id(self) -> str: def id(self) -> str:
"""Return cluster handler id unique for this device only.""" """Return cluster handler id unique for this device only."""
@ -203,7 +205,10 @@ class ClusterHandler(LogMixin):
) )
except (zigpy.exceptions.ZigbeeException, asyncio.TimeoutError) as ex: except (zigpy.exceptions.ZigbeeException, asyncio.TimeoutError) as ex:
self.debug( self.debug(
"Failed to bind '%s' cluster: %s", self.cluster.ep_attribute, str(ex) "Failed to bind '%s' cluster: %s",
self.cluster.ep_attribute,
str(ex),
exc_info=ex,
) )
async_dispatcher_send( async_dispatcher_send(
self._endpoint.device.hass, self._endpoint.device.hass,

View File

@ -534,7 +534,7 @@ class PollControl(ClusterHandler):
@registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.register( @registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.register(
general.PowerConfiguration.cluster_id general.PowerConfiguration.cluster_id
) )
class PowerConfigurationCLusterHandler(ClusterHandler): class PowerConfigurationClusterHandler(ClusterHandler):
"""Cluster handler for the zigbee power configuration cluster.""" """Cluster handler for the zigbee power configuration cluster."""
REPORT_CONFIG = ( REPORT_CONFIG = (

View File

@ -188,6 +188,15 @@ class SmartThingsAcceleration(ClusterHandler):
AttrReportConfig(attr="z_axis", config=REPORT_CONFIG_ASAP), AttrReportConfig(attr="z_axis", config=REPORT_CONFIG_ASAP),
) )
@classmethod
def matches(cls, cluster: zigpy.zcl.Cluster, endpoint: Endpoint) -> bool:
"""Filter the cluster match for specific devices."""
return cluster.endpoint.device.manufacturer in (
"CentraLite",
"Samjin",
"SmartThings",
)
@callback @callback
def attribute_updated(self, attrid, value): def attribute_updated(self, attrid, value):
"""Handle attribute updates on this cluster.""" """Handle attribute updates on this cluster."""

View File

@ -118,6 +118,11 @@ class Endpoint:
cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get( cluster_handler_class = registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.get(
cluster_id, ClusterHandler cluster_id, ClusterHandler
) )
# Allow cluster handler to filter out bad matches
if not cluster_handler_class.matches(cluster, self):
cluster_handler_class = ClusterHandler
_LOGGER.info( _LOGGER.info(
"Creating cluster handler for cluster id: %s class: %s", "Creating cluster handler for cluster id: %s class: %s",
cluster_id, cluster_id,

View File

@ -309,19 +309,19 @@ class LogMixin:
def debug(self, msg, *args, **kwargs): def debug(self, msg, *args, **kwargs):
"""Debug level log.""" """Debug level log."""
return self.log(logging.DEBUG, msg, *args) return self.log(logging.DEBUG, msg, *args, **kwargs)
def info(self, msg, *args, **kwargs): def info(self, msg, *args, **kwargs):
"""Info level log.""" """Info level log."""
return self.log(logging.INFO, msg, *args) return self.log(logging.INFO, msg, *args, **kwargs)
def warning(self, msg, *args, **kwargs): def warning(self, msg, *args, **kwargs):
"""Warning method log.""" """Warning method log."""
return self.log(logging.WARNING, msg, *args) return self.log(logging.WARNING, msg, *args, **kwargs)
def error(self, msg, *args, **kwargs): def error(self, msg, *args, **kwargs):
"""Error level log.""" """Error level log."""
return self.log(logging.ERROR, msg, *args) return self.log(logging.ERROR, msg, *args, **kwargs)
def retryable_req( def retryable_req(

View File

@ -2,7 +2,7 @@
from collections.abc import Callable from collections.abc import Callable
import itertools import itertools
import time import time
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import zigpy import zigpy
@ -13,7 +13,7 @@ from zigpy.const import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE
import zigpy.device import zigpy.device
import zigpy.group import zigpy.group
import zigpy.profiles import zigpy.profiles
from zigpy.state import State import zigpy.quirks
import zigpy.types import zigpy.types
import zigpy.zdo.types as zdo_t import zigpy.zdo.types as zdo_t
@ -44,31 +44,80 @@ def globally_load_quirks():
zhaquirks.setup() zhaquirks.setup()
class _FakeApp(ControllerApplication):
async def add_endpoint(self, descriptor: zdo_t.SimpleDescriptor):
pass
async def connect(self):
pass
async def disconnect(self):
pass
async def force_remove(self, dev: zigpy.device.Device):
pass
async def load_network_info(self, *, load_devices: bool = False):
pass
async def permit_ncp(self, time_s: int = 60):
pass
async def permit_with_key(
self, node: zigpy.types.EUI64, code: bytes, time_s: int = 60
):
pass
async def reset_network_info(self):
pass
async def send_packet(self, packet: zigpy.types.ZigbeePacket):
pass
async def start_network(self):
pass
async def write_network_info(self):
pass
async def request(
self,
device: zigpy.device.Device,
profile: zigpy.types.uint16_t,
cluster: zigpy.types.uint16_t,
src_ep: zigpy.types.uint8_t,
dst_ep: zigpy.types.uint8_t,
sequence: zigpy.types.uint8_t,
data: bytes,
*,
expect_reply: bool = True,
use_ieee: bool = False,
extended_timeout: bool = False,
):
pass
@pytest.fixture @pytest.fixture
def zigpy_app_controller(): def zigpy_app_controller():
"""Zigpy ApplicationController fixture.""" """Zigpy ApplicationController fixture."""
app = MagicMock(spec_set=ControllerApplication) app = _FakeApp(
app.startup = AsyncMock() {
app.shutdown = AsyncMock() zigpy.config.CONF_DATABASE: None,
groups = zigpy.group.Groups(app) zigpy.config.CONF_DEVICE: {zigpy.config.CONF_DEVICE_PATH: "/dev/null"},
groups.add_group(FIXTURE_GRP_ID, FIXTURE_GRP_NAME, suppress_event=True) }
app.configure_mock(groups=groups) )
type(app).ieee = PropertyMock()
app.ieee.return_value = zigpy.types.EUI64.convert("00:15:8d:00:02:32:4f:32")
type(app).nwk = PropertyMock(return_value=zigpy.types.NWK(0x0000))
type(app).devices = PropertyMock(return_value={})
type(app).backups = zigpy.backups.BackupManager(app)
type(app).topology = zigpy.topology.Topology(app)
state = State() app.groups.add_group(FIXTURE_GRP_ID, FIXTURE_GRP_NAME, suppress_event=True)
state.node_info.ieee = app.ieee.return_value
state.network_info.extended_pan_id = app.ieee.return_value
state.network_info.pan_id = 0x1234
state.network_info.channel = 15
state.network_info.network_key.key = zigpy.types.KeyData(range(16))
type(app).state = PropertyMock(return_value=state)
return app app.state.node_info.nwk = 0x0000
app.state.node_info.ieee = zigpy.types.EUI64.convert("00:15:8d:00:02:32:4f:32")
app.state.network_info.pan_id = 0x1234
app.state.network_info.extended_pan_id = app.state.node_info.ieee
app.state.network_info.channel = 15
app.state.network_info.network_key.key = zigpy.types.KeyData(range(16))
with patch("zigpy.device.Device.request"):
yield app
@pytest.fixture(name="config_entry") @pytest.fixture(name="config_entry")
@ -164,7 +213,7 @@ def zigpy_device_mock(zigpy_app_controller):
endpoint = device.add_endpoint(epid) endpoint = device.add_endpoint(epid)
endpoint.device_type = ep[SIG_EP_TYPE] endpoint.device_type = ep[SIG_EP_TYPE]
endpoint.profile_id = ep.get(SIG_EP_PROFILE, 0x0104) endpoint.profile_id = ep.get(SIG_EP_PROFILE, 0x0104)
endpoint.request = AsyncMock(return_value=[0]) endpoint.request = AsyncMock()
for cluster_id in ep.get(SIG_EP_INPUT, []): for cluster_id in ep.get(SIG_EP_INPUT, []):
endpoint.add_input_cluster(cluster_id) endpoint.add_input_cluster(cluster_id)
@ -176,6 +225,9 @@ def zigpy_device_mock(zigpy_app_controller):
if quirk: if quirk:
device = quirk(zigpy_app_controller, device.ieee, device.nwk, device) device = quirk(zigpy_app_controller, device.ieee, device.nwk, device)
else:
# Allow zigpy to apply quirks if we don't pass one explicitly
device = zigpy.quirks.get_device(device)
if patch_cluster: if patch_cluster:
for endpoint in (ep for epid, ep in device.endpoints.items() if epid): for endpoint in (ep for epid, ep in device.endpoints.items() if epid):

View File

@ -44,12 +44,17 @@ async def test_async_get_network_settings_inactive(
with patch( with patch(
"bellows.zigbee.application.ControllerApplication.__new__", "bellows.zigbee.application.ControllerApplication.__new__",
return_value=zigpy_app_controller, return_value=zigpy_app_controller,
): ), patch.object(
zigpy_app_controller, "_load_db", wraps=zigpy_app_controller._load_db
) as mock_load_db, patch.object(
zigpy_app_controller,
"start_network",
wraps=zigpy_app_controller.start_network,
) as mock_start_network:
settings = await api.async_get_network_settings(hass) settings = await api.async_get_network_settings(hass)
assert len(zigpy_app_controller._load_db.mock_calls) == 1 assert len(mock_load_db.mock_calls) == 1
assert len(zigpy_app_controller.start_network.mock_calls) == 0 assert len(mock_start_network.mock_calls) == 0
assert settings.network_info.channel == 20 assert settings.network_info.channel == 20

View File

@ -111,23 +111,30 @@ async def test_devices(
Endpoint.async_new_entity = orig_new_entity Endpoint.async_new_entity = orig_new_entity
if cluster_identify: if cluster_identify:
called = int(zha_device_joined_restored.name == "zha_device_joined") # We only identify on join
assert cluster_identify.request.call_count == called should_identify = (
assert cluster_identify.request.await_count == called zha_device_joined_restored.name == "zha_device_joined"
if called: and not zigpy_device.skip_configuration
assert cluster_identify.request.call_args == mock.call( )
False,
cluster_identify.commands_by_name["trigger_effect"].id, if should_identify:
cluster_identify.commands_by_name["trigger_effect"].schema, assert cluster_identify.request.mock_calls == [
effect_id=zigpy.zcl.clusters.general.Identify.EffectIdentifier.Okay, mock.call(
effect_variant=( False,
zigpy.zcl.clusters.general.Identify.EffectVariant.Default cluster_identify.commands_by_name["trigger_effect"].id,
), cluster_identify.commands_by_name["trigger_effect"].schema,
expect_reply=True, effect_id=zigpy.zcl.clusters.general.Identify.EffectIdentifier.Okay,
manufacturer=None, effect_variant=(
tries=1, zigpy.zcl.clusters.general.Identify.EffectVariant.Default
tsn=None, ),
) expect_reply=True,
manufacturer=None,
tries=1,
tsn=None,
)
]
else:
assert cluster_identify.request.mock_calls == []
event_cluster_handlers = { event_cluster_handlers = {
ch.id ch.id

File diff suppressed because it is too large Load Diff