Fix calculation of attributes in group sensor (#128601)

* Fix calculation of attributes in group sensor

* Fixes

* Fixes

* Make module level function
pull/129048/head
G Johansson 2024-10-23 20:51:18 +02:00 committed by GitHub
parent 80984c94a1
commit 6ee6a8a74f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 296 additions and 68 deletions

View File

@ -36,14 +36,7 @@ from homeassistant.const import (
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.core import (
CALLBACK_TYPE,
Event,
EventStateChangedData,
HomeAssistant,
State,
callback,
)
from homeassistant.core import HomeAssistant, State, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity import (
@ -52,7 +45,6 @@ from homeassistant.helpers.entity import (
get_unit_of_measurement,
)
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_state_change_event
from homeassistant.helpers.issue_registry import (
IssueSeverity,
async_create_issue,
@ -180,6 +172,17 @@ def async_create_preview_sensor(
)
def _has_numeric_state(hass: HomeAssistant, entity_id: str) -> bool:
"""Test if state is numeric."""
if not (state := hass.states.get(entity_id)):
return False
try:
float(state.state)
except ValueError:
return False
return True
def calc_min(
sensor_values: list[tuple[str, float, State]],
) -> tuple[dict[str, str | None], float | None]:
@ -332,12 +335,11 @@ class SensorGroup(GroupEntity, SensorEntity):
self.hass = hass
self._entity_ids = entity_ids
self._sensor_type = sensor_type
self._state_class = state_class
self._device_class = device_class
self._native_unit_of_measurement = unit_of_measurement
self._configured_state_class = state_class
self._configured_device_class = device_class
self._configured_unit_of_measurement = unit_of_measurement
self._valid_units: set[str | None] = set()
self._can_convert: bool = False
self.calculate_attributes_later: CALLBACK_TYPE | None = None
self._attr_name = name
if name == DEFAULT_NAME:
self._attr_name = f"{DEFAULT_NAME} {sensor_type}".capitalize()
@ -352,39 +354,25 @@ class SensorGroup(GroupEntity, SensorEntity):
self._state_incorrect: set[str] = set()
self._extra_state_attribute: dict[str, Any] = {}
async def async_added_to_hass(self) -> None:
"""When added to hass."""
for entity_id in self._entity_ids:
if self.hass.states.get(entity_id) is None:
self.calculate_attributes_later = async_track_state_change_event(
self.hass, self._entity_ids, self.calculate_state_attributes
)
break
if not self.calculate_attributes_later:
await self.calculate_state_attributes()
await super().async_added_to_hass()
async def calculate_state_attributes(
self, event: Event[EventStateChangedData] | None = None
) -> None:
def calculate_state_attributes(self, valid_state_entities: list[str]) -> None:
"""Calculate state attributes."""
for entity_id in self._entity_ids:
if self.hass.states.get(entity_id) is None:
return
if self.calculate_attributes_later:
self.calculate_attributes_later()
self.calculate_attributes_later = None
self._attr_state_class = self._calculate_state_class(self._state_class)
self._attr_device_class = self._calculate_device_class(self._device_class)
self._attr_state_class = self._calculate_state_class(
self._configured_state_class, valid_state_entities
)
self._attr_device_class = self._calculate_device_class(
self._configured_device_class, valid_state_entities
)
self._attr_native_unit_of_measurement = self._calculate_unit_of_measurement(
self._native_unit_of_measurement
self._configured_unit_of_measurement, valid_state_entities
)
self._valid_units = self._get_valid_units()
@callback
def async_update_group_state(self) -> None:
"""Query all members and determine the sensor group state."""
self.calculate_state_attributes(self._get_valid_entities())
states: list[StateType] = []
valid_units = self._valid_units
valid_states: list[bool] = []
sensor_values: list[tuple[str, float, State]] = []
for entity_id in self._entity_ids:
@ -392,20 +380,18 @@ class SensorGroup(GroupEntity, SensorEntity):
states.append(state.state)
try:
numeric_state = float(state.state)
if (
self._valid_units
and (uom := state.attributes["unit_of_measurement"])
in self._valid_units
and self._can_convert is True
):
uom = state.attributes.get("unit_of_measurement")
# Convert the state to the native unit of measurement when we have valid units
# and a correct device class
if valid_units and uom in valid_units and self._can_convert is True:
numeric_state = UNIT_CONVERTERS[self.device_class].convert(
numeric_state, uom, self.native_unit_of_measurement
)
if (
self._valid_units
and (uom := state.attributes["unit_of_measurement"])
not in self._valid_units
):
# If we have valid units and the entity's unit does not match
# we raise which skips the state and log a warning once
if valid_units and uom not in valid_units:
raise HomeAssistantError("Not a valid unit") # noqa: TRY301
sensor_values.append((entity_id, numeric_state, state))
@ -480,7 +466,9 @@ class SensorGroup(GroupEntity, SensorEntity):
return None
def _calculate_state_class(
self, state_class: SensorStateClass | None
self,
state_class: SensorStateClass | None,
valid_state_entities: list[str],
) -> SensorStateClass | None:
"""Calculate state class.
@ -491,8 +479,18 @@ class SensorGroup(GroupEntity, SensorEntity):
"""
if state_class:
return state_class
if not valid_state_entities:
return None
if not self._ignore_non_numeric and len(valid_state_entities) < len(
self._entity_ids
):
# Only return state class if all states are valid when not ignoring non numeric
return None
state_classes: list[SensorStateClass] = []
for entity_id in self._entity_ids:
for entity_id in valid_state_entities:
try:
_state_class = get_capability(self.hass, entity_id, "state_class")
except HomeAssistantError:
@ -523,7 +521,9 @@ class SensorGroup(GroupEntity, SensorEntity):
return None
def _calculate_device_class(
self, device_class: SensorDeviceClass | None
self,
device_class: SensorDeviceClass | None,
valid_state_entities: list[str],
) -> SensorDeviceClass | None:
"""Calculate device class.
@ -534,8 +534,18 @@ class SensorGroup(GroupEntity, SensorEntity):
"""
if device_class:
return device_class
if not valid_state_entities:
return None
if not self._ignore_non_numeric and len(valid_state_entities) < len(
self._entity_ids
):
# Only return device class if all states are valid when not ignoring non numeric
return None
device_classes: list[SensorDeviceClass] = []
for entity_id in self._entity_ids:
for entity_id in valid_state_entities:
try:
_device_class = get_device_class(self.hass, entity_id)
except HomeAssistantError:
@ -568,7 +578,9 @@ class SensorGroup(GroupEntity, SensorEntity):
return None
def _calculate_unit_of_measurement(
self, unit_of_measurement: str | None
self,
unit_of_measurement: str | None,
valid_state_entities: list[str],
) -> str | None:
"""Calculate the unit of measurement.
@ -579,8 +591,17 @@ class SensorGroup(GroupEntity, SensorEntity):
if unit_of_measurement:
return unit_of_measurement
if not valid_state_entities:
return None
if not self._ignore_non_numeric and len(valid_state_entities) < len(
self._entity_ids
):
# Only return device class if all states are valid when not ignoring non numeric
return None
unit_of_measurements: list[str] = []
for entity_id in self._entity_ids:
for entity_id in valid_state_entities:
try:
_unit_of_measurement = get_unit_of_measurement(self.hass, entity_id)
except HomeAssistantError:
@ -665,19 +686,31 @@ class SensorGroup(GroupEntity, SensorEntity):
If device class is set and compatible unit of measurements.
If device class is not set, use one unit of measurement.
Only calculate valid units if there are no valid units set.
"""
if (
device_class := self.device_class
) in UNIT_CONVERTERS and self.native_unit_of_measurement:
if (valid_units := self._valid_units) and not self._ignore_non_numeric:
# If we have valid units already and not using ignore_non_numeric
# we should not recalculate.
return valid_units
native_uom = self.native_unit_of_measurement
if (device_class := self.device_class) in UNIT_CONVERTERS and native_uom:
self._can_convert = True
return UNIT_CONVERTERS[device_class].VALID_UNITS
if (
device_class
and (device_class) in DEVICE_CLASS_UNITS
and self.native_unit_of_measurement
):
if device_class and (device_class) in DEVICE_CLASS_UNITS and native_uom:
valid_uoms: set = DEVICE_CLASS_UNITS[device_class]
return valid_uoms
if device_class is None and self.native_unit_of_measurement:
return {self.native_unit_of_measurement}
if device_class is None and native_uom:
return {native_uom}
return set()
def _get_valid_entities(
self,
) -> list[str]:
"""Return list of valid entities."""
return [
entity_id
for entity_id in self._entity_ids
if _has_numeric_state(self.hass, entity_id)
]

View File

@ -32,6 +32,7 @@ from homeassistant.const import (
SERVICE_RELOAD,
STATE_UNAVAILABLE,
STATE_UNKNOWN,
UnitOfTemperature,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import issue_registry as ir
@ -496,7 +497,7 @@ async def test_sensor_with_uoms_but_no_device_class(
state = hass.states.get("sensor.test_sum")
assert state.attributes.get("device_class") is None
assert state.attributes.get("state_class") is None
assert state.attributes.get("unit_of_measurement") == "W"
assert state.attributes.get("unit_of_measurement") is None
assert state.state == STATE_UNKNOWN
assert (
@ -650,10 +651,10 @@ async def test_sensor_calculated_result_fails_on_uom(hass: HomeAssistant) -> Non
await hass.async_block_till_done()
state = hass.states.get("sensor.test_sum")
assert state.state == STATE_UNKNOWN
assert state.state == STATE_UNAVAILABLE
assert state.attributes.get("device_class") == "energy"
assert state.attributes.get("state_class") == "total"
assert state.attributes.get("unit_of_measurement") == "kWh"
assert state.attributes.get("unit_of_measurement") is None
async def test_sensor_calculated_properties_not_convertible_device_class(
@ -730,7 +731,7 @@ async def test_sensor_calculated_properties_not_convertible_device_class(
assert state.state == STATE_UNKNOWN
assert state.attributes.get("device_class") == "humidity"
assert state.attributes.get("state_class") == "measurement"
assert state.attributes.get("unit_of_measurement") == "%"
assert state.attributes.get("unit_of_measurement") is None
assert (
"Unable to use state. Only entities with correct unit of measurement is"
@ -812,3 +813,197 @@ async def test_sensors_attributes_added_when_entity_info_available(
assert state.attributes.get(ATTR_ICON) is None
assert state.attributes.get(ATTR_STATE_CLASS) == SensorStateClass.TOTAL
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "L"
async def test_sensor_state_class_no_uom_not_available(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test when input sensors drops unit of measurement."""
# If we have a valid unit of measurement from all input sensors
# the group sensor will go unknown in the case any input sensor
# drops the unit of measurement and log a warning.
config = {
SENSOR_DOMAIN: {
"platform": GROUP_DOMAIN,
"name": "test_sum",
"type": "sum",
"entities": ["sensor.test_1", "sensor.test_2", "sensor.test_3"],
"unique_id": "very_unique_id_sum_sensor",
}
}
entity_ids = config["sensor"]["entities"]
input_attributes = {
"state_class": SensorStateClass.MEASUREMENT,
"unit_of_measurement": PERCENTAGE,
}
hass.states.async_set(entity_ids[0], VALUES[0], input_attributes)
hass.states.async_set(entity_ids[1], VALUES[1], input_attributes)
hass.states.async_set(entity_ids[2], VALUES[2], input_attributes)
await hass.async_block_till_done()
assert await async_setup_component(hass, "sensor", config)
await hass.async_block_till_done()
state = hass.states.get("sensor.test_sum")
assert state.state == str(sum(VALUES))
assert state.attributes.get("state_class") == "measurement"
assert state.attributes.get("unit_of_measurement") == "%"
assert (
"Unable to use state. Only entities with correct unit of measurement is"
" supported"
) not in caplog.text
# sensor.test_3 drops the unit of measurement
hass.states.async_set(
entity_ids[2],
VALUES[2],
{
"state_class": SensorStateClass.MEASUREMENT,
},
)
await hass.async_block_till_done()
state = hass.states.get("sensor.test_sum")
assert state.state == STATE_UNKNOWN
assert state.attributes.get("state_class") == "measurement"
assert state.attributes.get("unit_of_measurement") is None
assert (
"Unable to use state. Only entities with correct unit of measurement is"
" supported, entity sensor.test_3, value 15.3 with"
" device class None and unit of measurement None excluded from calculation"
" in sensor.test_sum"
) in caplog.text
async def test_sensor_different_attributes_ignore_non_numeric(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test the sensor handles calculating attributes when using ignore_non_numeric."""
config = {
SENSOR_DOMAIN: {
"platform": GROUP_DOMAIN,
"name": "test_sum",
"type": "sum",
"ignore_non_numeric": True,
"entities": ["sensor.test_1", "sensor.test_2", "sensor.test_3"],
"unique_id": "very_unique_id_sum_sensor",
}
}
entity_ids = config["sensor"]["entities"]
assert await async_setup_component(hass, "sensor", config)
await hass.async_block_till_done()
state = hass.states.get("sensor.test_sum")
assert state.state == STATE_UNAVAILABLE
assert state.attributes.get("state_class") is None
assert state.attributes.get("device_class") is None
assert state.attributes.get("unit_of_measurement") is None
test_cases = [
{
"entity": entity_ids[0],
"value": VALUES[0],
"attributes": {
"state_class": SensorStateClass.MEASUREMENT,
"unit_of_measurement": PERCENTAGE,
},
"expected_state": str(float(VALUES[0])),
"expected_state_class": SensorStateClass.MEASUREMENT,
"expected_device_class": None,
"expected_unit_of_measurement": PERCENTAGE,
},
{
"entity": entity_ids[1],
"value": VALUES[1],
"attributes": {
"state_class": SensorStateClass.MEASUREMENT,
"device_class": SensorDeviceClass.HUMIDITY,
"unit_of_measurement": PERCENTAGE,
},
"expected_state": str(float(sum([VALUES[0], VALUES[1]]))),
"expected_state_class": SensorStateClass.MEASUREMENT,
"expected_device_class": None,
"expected_unit_of_measurement": PERCENTAGE,
},
{
"entity": entity_ids[2],
"value": VALUES[2],
"attributes": {
"state_class": SensorStateClass.MEASUREMENT,
"device_class": SensorDeviceClass.TEMPERATURE,
"unit_of_measurement": UnitOfTemperature.CELSIUS,
},
"expected_state": str(float(sum(VALUES))),
"expected_state_class": SensorStateClass.MEASUREMENT,
"expected_device_class": None,
"expected_unit_of_measurement": None,
},
{
"entity": entity_ids[2],
"value": VALUES[2],
"attributes": {
"state_class": SensorStateClass.MEASUREMENT,
"device_class": SensorDeviceClass.HUMIDITY,
"unit_of_measurement": PERCENTAGE,
},
"expected_state": str(float(sum(VALUES))),
"expected_state_class": SensorStateClass.MEASUREMENT,
# One sensor does not have a device class
"expected_device_class": None,
"expected_unit_of_measurement": PERCENTAGE,
},
{
"entity": entity_ids[0],
"value": VALUES[0],
"attributes": {
"state_class": SensorStateClass.MEASUREMENT,
"device_class": SensorDeviceClass.HUMIDITY,
"unit_of_measurement": PERCENTAGE,
},
"expected_state": str(float(sum(VALUES))),
"expected_state_class": SensorStateClass.MEASUREMENT,
# First sensor now has a device class
"expected_device_class": SensorDeviceClass.HUMIDITY,
"expected_unit_of_measurement": PERCENTAGE,
},
{
"entity": entity_ids[0],
"value": VALUES[0],
"attributes": {
"state_class": SensorStateClass.MEASUREMENT,
},
"expected_state": str(float(sum(VALUES))),
"expected_state_class": SensorStateClass.MEASUREMENT,
"expected_device_class": None,
"expected_unit_of_measurement": None,
},
]
for test_case in test_cases:
hass.states.async_set(
test_case["entity"],
test_case["value"],
test_case["attributes"],
)
await hass.async_block_till_done()
state = hass.states.get("sensor.test_sum")
assert state.state == test_case["expected_state"]
assert state.attributes.get("state_class") == test_case["expected_state_class"]
assert (
state.attributes.get("device_class") == test_case["expected_device_class"]
)
assert (
state.attributes.get("unit_of_measurement")
== test_case["expected_unit_of_measurement"]
)