Refactor ZHA sensor initialization ()

* Refactor ZHA sensors to use cached values after restart

* Get attr from cluster, not channel

* Run cached state through formatter method

* Use cached values for div/multiplier for SmartEnergy channel

* Restore batter voltage from cache

* Refactor sensor to use cached values only

* Update tests

* Add battery sensor test
pull/88364/head
Alexei Chetroi 2020-11-18 21:34:12 -05:00 committed by GitHub
parent 70a3489845
commit 54c4e9335f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 128 additions and 138 deletions
homeassistant/components/zha
tests/components/zha

View File

@ -208,7 +208,7 @@ class ZigbeeChannel(LogMixin):
attributes = []
for report_config in self._report_config:
attributes.append(report_config["attr"])
if len(attributes) > 0:
if attributes:
await self.get_attributes(attributes, from_cache=from_cache)
self._status = ChannelStatus.INITIALIZED

View File

@ -17,10 +17,9 @@ from ..const import (
SIGNAL_ATTR_UPDATED,
SIGNAL_MOVE_LEVEL,
SIGNAL_SET_LEVEL,
SIGNAL_STATE_ATTR,
SIGNAL_UPDATE_DEVICE,
)
from .base import ClientChannel, ZigbeeChannel, parse_and_log_command
from .base import ChannelStatus, ClientChannel, ZigbeeChannel, parse_and_log_command
@registries.ZIGBEE_CHANNEL_REGISTRY.register(general.Alarms.cluster_id)
@ -72,13 +71,6 @@ class BasicChannel(ZigbeeChannel):
6: "Emergency mains and transfer switch",
}
def __init__(
self, cluster: zha_typing.ZigpyClusterType, ch_pool: zha_typing.ChannelPoolType
) -> None:
"""Initialize BasicChannel."""
super().__init__(cluster, ch_pool)
self._power_source = None
async def async_configure(self):
"""Configure this channel."""
await super().async_configure()
@ -87,16 +79,12 @@ class BasicChannel(ZigbeeChannel):
async def async_initialize(self, from_cache):
"""Initialize channel."""
if not self._ch_pool.skip_configuration or from_cache:
power_source = await self.get_attribute_value(
"power_source", from_cache=from_cache
)
if power_source is not None:
self._power_source = power_source
await self.get_attribute_value("power_source", from_cache=from_cache)
await super().async_initialize(from_cache)
def get_power_source(self):
"""Get the power source."""
return self._power_source
return self.cluster.get("power_source")
@registries.ZIGBEE_CHANNEL_REGISTRY.register(general.BinaryInput.cluster_id)
@ -392,38 +380,8 @@ class PowerConfigurationChannel(ZigbeeChannel):
{"attr": "battery_percentage_remaining", "config": REPORT_CONFIG_BATTERY_SAVE},
)
@callback
def attribute_updated(self, attrid, value):
"""Handle attribute updates on this cluster."""
attr = self._report_config[1].get("attr")
if isinstance(attr, str):
attr_id = self.cluster.attridx.get(attr)
else:
attr_id = attr
if attrid == attr_id:
self.async_send_signal(
f"{self.unique_id}_{SIGNAL_ATTR_UPDATED}",
attrid,
self.cluster.attributes.get(attrid, [attrid])[0],
value,
)
return
attr_name = self.cluster.attributes.get(attrid, [attrid])[0]
self.async_send_signal(
f"{self.unique_id}_{SIGNAL_STATE_ATTR}", attr_name, value
)
async def async_initialize(self, from_cache):
"""Initialize channel."""
await self.async_read_state(from_cache)
await super().async_initialize(from_cache)
async def async_update(self):
"""Retrieve latest state."""
await self.async_read_state(True)
async def async_read_state(self, from_cache):
"""Read data from the cluster."""
attributes = [
"battery_size",
"battery_percentage_remaining",
@ -431,6 +389,7 @@ class PowerConfigurationChannel(ZigbeeChannel):
"battery_quantity",
]
await self.get_attributes(attributes, from_cache=from_cache)
self._status = ChannelStatus.INITIALIZED
@registries.ZIGBEE_CHANNEL_REGISTRY.register(general.PowerProfile.cluster_id)

View File

@ -1,4 +1,6 @@
"""Smart energy channels module for Zigbee Home Automation."""
from typing import Union
import zigpy.zcl.clusters.smartenergy as smartenergy
from homeassistant.const import (
@ -82,44 +84,48 @@ class Metering(ZigbeeChannel):
) -> None:
"""Initialize Metering."""
super().__init__(cluster, ch_pool)
self._divisor = 1
self._multiplier = 1
self._unit_enum = None
self._format_spec = None
async def async_configure(self):
@property
def divisor(self) -> int:
"""Return divisor for the value."""
return self.cluster.get("divisor")
@property
def multiplier(self) -> int:
"""Return multiplier for the value."""
return self.cluster.get("multiplier")
async def async_configure(self) -> None:
"""Configure channel."""
await self.fetch_config(False)
await super().async_configure()
async def async_initialize(self, from_cache):
async def async_initialize(self, from_cache: bool) -> None:
"""Initialize channel."""
await self.fetch_config(True)
await super().async_initialize(from_cache)
@callback
def attribute_updated(self, attrid, value):
def attribute_updated(self, attrid: int, value: int) -> None:
"""Handle attribute update from Metering cluster."""
if None in (self._multiplier, self._divisor, self._format_spec):
if None in (self.multiplier, self.divisor, self._format_spec):
return
super().attribute_updated(attrid, value * self._multiplier / self._divisor)
super().attribute_updated(attrid, value)
@property
def unit_of_measurement(self):
def unit_of_measurement(self) -> str:
"""Return unit of measurement."""
return self.unit_of_measure_map.get(self._unit_enum & 0x7F, "unknown")
uom = self.cluster.get("unit_of_measure", 0x7F)
return self.unit_of_measure_map.get(uom & 0x7F, "unknown")
async def fetch_config(self, from_cache):
async def fetch_config(self, from_cache: bool) -> None:
"""Fetch config from device and updates format specifier."""
results = await self.get_attributes(
["divisor", "multiplier", "unit_of_measure", "demand_formatting"],
from_cache=from_cache,
)
self._divisor = results.get("divisor", self._divisor)
self._multiplier = results.get("multiplier", self._multiplier)
self._unit_enum = results.get("unit_of_measure", 0x7F) # default to unknown
fmting = results.get(
"demand_formatting", 0xF9
) # 1 digit to the right, 15 digits to the left
@ -135,8 +141,9 @@ class Metering(ZigbeeChannel):
else:
self._format_spec = "{:0" + str(width) + "." + str(r_digits) + "f}"
def formatter_function(self, value):
def formatter_function(self, value: int) -> Union[int, float]:
"""Return formatted value for display."""
value = value * self.multiplier / self.divisor
if self.unit_of_measurement == POWER_WATT:
# Zigbee spec power unit is kW, but we show the value in W
value_watt = value * 1000

View File

@ -39,12 +39,13 @@ async def async_add_entities(
Tuple[str, zha_typing.ZhaDeviceType, List[zha_typing.ChannelType]],
]
],
update_before_add: bool = True,
) -> None:
"""Add entities helper."""
if not entities:
return
to_add = [ent_cls(*args) for ent_cls, args in entities]
_async_add_entities(to_add, update_before_add=True)
_async_add_entities(to_add, update_before_add=update_before_add)
entities.clear()

View File

@ -1,6 +1,7 @@
"""Sensors on Zigbee Home Automation networks."""
import functools
import numbers
from typing import Any, Callable, Dict, List, Optional, Union
from homeassistant.components.sensor import (
DEVICE_CLASS_BATTERY,
@ -11,18 +12,17 @@ from homeassistant.components.sensor import (
DEVICE_CLASS_TEMPERATURE,
DOMAIN,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
ATTR_UNIT_OF_MEASUREMENT,
LIGHT_LUX,
PERCENTAGE,
POWER_WATT,
PRESSURE_HPA,
STATE_UNKNOWN,
TEMP_CELSIUS,
)
from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.util.temperature import fahrenheit_to_celsius
from homeassistant.helpers.typing import HomeAssistantType, StateType
from .core import discovery
from .core.const import (
@ -38,9 +38,9 @@ from .core.const import (
DATA_ZHA_DISPATCHERS,
SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED,
SIGNAL_STATE_ATTR,
)
from .core.registries import SMARTTHINGS_HUMIDITY_CLUSTER, ZHA_ENTITIES
from .core.typing import ChannelType, ZhaDeviceType
from .entity import ZhaEntity
PARALLEL_UPDATES = 5
@ -65,7 +65,9 @@ CHANNEL_ST_HUMIDITY_CLUSTER = f"channel_0x{SMARTTHINGS_HUMIDITY_CLUSTER:04x}"
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN)
async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_setup_entry(
hass: HomeAssistantType, config_entry: ConfigEntry, async_add_entities: Callable
) -> None:
"""Set up the Zigbee Home Automation sensor from config entry."""
entities_to_create = hass.data[DATA_ZHA][DOMAIN]
@ -73,7 +75,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
hass,
SIGNAL_ADD_ENTITIES,
functools.partial(
discovery.async_add_entities, async_add_entities, entities_to_create
discovery.async_add_entities,
async_add_entities,
entities_to_create,
update_before_add=False,
),
)
hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub)
@ -82,29 +87,30 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
class Sensor(ZhaEntity):
"""Base ZHA sensor."""
SENSOR_ATTR = None
_decimals = 1
_device_class = None
_divisor = 1
_multiplier = 1
_unit = None
SENSOR_ATTR: Optional[Union[int, str]] = None
_decimals: int = 1
_device_class: Optional[str] = None
_divisor: int = 1
_multiplier: int = 1
_unit: Optional[str] = None
def __init__(self, unique_id, zha_device, channels, **kwargs):
def __init__(
self,
unique_id: str,
zha_device: ZhaDeviceType,
channels: List[ChannelType],
**kwargs,
):
"""Init this sensor."""
super().__init__(unique_id, zha_device, channels, **kwargs)
self._channel = channels[0]
self._channel: ChannelType = channels[0]
async def async_added_to_hass(self):
async def async_added_to_hass(self) -> None:
"""Run when about to be added to hass."""
await super().async_added_to_hass()
self._device_state_attributes.update(await self.async_state_attr_provider())
self.async_accept_signal(
self._channel, SIGNAL_ATTR_UPDATED, self.async_set_state
)
self.async_accept_signal(
self._channel, SIGNAL_STATE_ATTR, self.async_update_state_attribute
)
@property
def device_class(self) -> str:
@ -112,37 +118,25 @@ class Sensor(ZhaEntity):
return self._device_class
@property
def unit_of_measurement(self):
def unit_of_measurement(self) -> Optional[str]:
"""Return the unit of measurement of this entity."""
return self._unit
@property
def state(self) -> str:
def state(self) -> StateType:
"""Return the state of the entity."""
if self._state is None:
assert self.SENSOR_ATTR is not None
raw_state = self._channel.cluster.get(self.SENSOR_ATTR)
if raw_state is None:
return None
return self._state
return self.formatter(raw_state)
@callback
def async_set_state(self, attr_id, attr_name, value):
def async_set_state(self, attr_id: int, attr_name: str, value: Any) -> None:
"""Handle state update from channel."""
if self.SENSOR_ATTR is None or self.SENSOR_ATTR != attr_name:
return
if value is not None:
value = self.formatter(value)
self._state = value
self.async_write_ha_state()
@callback
def async_restore_last_state(self, last_state):
"""Restore previous state."""
self._state = last_state.state
async def async_state_attr_provider(self):
"""Initialize device state attributes."""
return {}
def formatter(self, value):
def formatter(self, value: int) -> Union[int, float]:
"""Numeric pass-through formatter."""
if self._decimals > 0:
return round(
@ -167,7 +161,7 @@ class Battery(Sensor):
_unit = PERCENTAGE
@staticmethod
def formatter(value):
def formatter(value: int) -> int:
"""Return the state of the entity."""
# per zcl specs battery percent is reported at 200% ¯\_(ツ)_/¯
if not isinstance(value, numbers.Number) or value == -1:
@ -175,26 +169,21 @@ class Battery(Sensor):
value = round(value / 2)
return value
async def async_state_attr_provider(self):
@property
def device_state_attributes(self) -> Dict[str, Any]:
"""Return device state attrs for battery sensors."""
state_attrs = {}
attributes = ["battery_size", "battery_quantity"]
results = await self._channel.get_attributes(attributes)
battery_size = results.get("battery_size")
battery_size = self._channel.cluster.get("battery_size")
if battery_size is not None:
state_attrs["battery_size"] = BATTERY_SIZES.get(battery_size, "Unknown")
battery_quantity = results.get("battery_quantity")
battery_quantity = self._channel.cluster.get("battery_quantity")
if battery_quantity is not None:
state_attrs["battery_quantity"] = battery_quantity
battery_voltage = self._channel.cluster.get("battery_voltage")
if battery_voltage is not None:
state_attrs["battery_voltage"] = round(battery_voltage / 10, 1)
return state_attrs
@callback
def async_update_state_attribute(self, key, value):
"""Update a single device state attribute."""
if key == "battery_voltage":
self._device_state_attributes[key] = round(value / 10, 1)
self.async_write_ha_state()
@STRICT_MATCH(channel_names=CHANNEL_ELECTRICAL_MEASUREMENT)
class ElectricalMeasurement(Sensor):
@ -202,7 +191,6 @@ class ElectricalMeasurement(Sensor):
SENSOR_ATTR = "active_power"
_device_class = DEVICE_CLASS_POWER
_divisor = 10
_unit = POWER_WATT
@property
@ -210,7 +198,7 @@ class ElectricalMeasurement(Sensor):
"""Return True if HA needs to poll for state changes."""
return True
def formatter(self, value) -> int:
def formatter(self, value: int) -> Union[int, float]:
"""Return 'normalized' value."""
value = value * self._channel.multiplier / self._channel.divisor
if value < 100 and self._channel.divisor > 1:
@ -244,7 +232,7 @@ class Illuminance(Sensor):
_unit = LIGHT_LUX
@staticmethod
def formatter(value):
def formatter(value: int) -> float:
"""Convert illumination data."""
return round(pow(10, ((value - 1) / 10000)), 1)
@ -256,12 +244,12 @@ class SmartEnergyMetering(Sensor):
SENSOR_ATTR = "instantaneous_demand"
_device_class = DEVICE_CLASS_POWER
def formatter(self, value):
def formatter(self, value: int) -> Union[int, float]:
"""Pass through channel formatter."""
return self._channel.formatter_function(value)
@property
def unit_of_measurement(self):
def unit_of_measurement(self) -> str:
"""Return Unit of measurement."""
return self._channel.unit_of_measurement
@ -284,14 +272,3 @@ class Temperature(Sensor):
_device_class = DEVICE_CLASS_TEMPERATURE
_divisor = 100
_unit = TEMP_CELSIUS
@callback
def async_restore_last_state(self, last_state):
"""Restore previous state."""
if last_state.state == STATE_UNKNOWN:
return
if last_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) != TEMP_CELSIUS:
ftemp = float(last_state.state)
self._state = round(fahrenheit_to_celsius(ftemp), 1)
return
self._state = last_state.state

View File

@ -93,18 +93,59 @@ async def async_test_electrical_measurement(hass, cluster, entity_id):
assert_state(hass, entity_id, "9.9", POWER_WATT)
async def async_test_powerconfiguration(hass, cluster, entity_id):
"""Test powerconfiguration/battery sensor."""
await send_attributes_report(hass, cluster, {33: 98})
assert_state(hass, entity_id, "49", "%")
assert hass.states.get(entity_id).attributes["battery_voltage"] == 2.9
assert hass.states.get(entity_id).attributes["battery_quantity"] == 3
assert hass.states.get(entity_id).attributes["battery_size"] == "AAA"
await send_attributes_report(hass, cluster, {32: 20})
assert hass.states.get(entity_id).attributes["battery_voltage"] == 2.0
@pytest.mark.parametrize(
"cluster_id, test_func, report_count",
"cluster_id, test_func, report_count, read_plug",
(
(measurement.RelativeHumidity.cluster_id, async_test_humidity, 1),
(measurement.TemperatureMeasurement.cluster_id, async_test_temperature, 1),
(measurement.PressureMeasurement.cluster_id, async_test_pressure, 1),
(measurement.IlluminanceMeasurement.cluster_id, async_test_illuminance, 1),
(smartenergy.Metering.cluster_id, async_test_metering, 1),
(measurement.RelativeHumidity.cluster_id, async_test_humidity, 1, None),
(
measurement.TemperatureMeasurement.cluster_id,
async_test_temperature,
1,
None,
),
(measurement.PressureMeasurement.cluster_id, async_test_pressure, 1, None),
(
measurement.IlluminanceMeasurement.cluster_id,
async_test_illuminance,
1,
None,
),
(
smartenergy.Metering.cluster_id,
async_test_metering,
1,
{
"demand_formatting": 0xF9,
"divisor": 1,
"multiplier": 1,
},
),
(
homeautomation.ElectricalMeasurement.cluster_id,
async_test_electrical_measurement,
1,
None,
),
(
general.PowerConfiguration.cluster_id,
async_test_powerconfiguration,
2,
{
"battery_size": 4, # AAA
"battery_voltage": 29,
"battery_quantity": 3,
},
),
),
)
@ -115,6 +156,7 @@ async def test_sensor(
cluster_id,
test_func,
report_count,
read_plug,
):
"""Test zha sensor platform."""
@ -128,6 +170,10 @@ async def test_sensor(
}
)
cluster = zigpy_device.endpoints[1].in_clusters[cluster_id]
if cluster_id == smartenergy.Metering.cluster_id:
# this one is mains powered
zigpy_device.node_desc.mac_capability_flags |= 0b_0000_0100
cluster.PLUGGED_ATTR_READS = read_plug
zha_device = await zha_device_joined_restored(zigpy_device)
entity_id = await find_entity_id(DOMAIN, zha_device, hass)