Improve device class of utility meter (#114368)

pull/114764/head
Erik Montnemery 2024-03-28 13:24:44 +01:00 committed by Franck Nijhof
parent 42580a1113
commit b143390d88
No known key found for this signature in database
GPG Key ID: D62583BA8AB11CA3
2 changed files with 146 additions and 29 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime, timedelta
from decimal import Decimal, DecimalException, InvalidOperation
@ -13,6 +14,7 @@ import voluptuous as vol
from homeassistant.components.sensor import (
ATTR_LAST_RESET,
DEVICE_CLASS_UNITS,
RestoreSensor,
SensorDeviceClass,
SensorExtraStoredData,
@ -21,12 +23,12 @@ from homeassistant.components.sensor import (
from homeassistant.components.sensor.recorder import _suggest_report_issue
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
ATTR_DEVICE_CLASS,
ATTR_UNIT_OF_MEASUREMENT,
CONF_NAME,
CONF_UNIQUE_ID,
STATE_UNAVAILABLE,
STATE_UNKNOWN,
UnitOfEnergy,
)
from homeassistant.core import Event, HomeAssistant, State, callback
from homeassistant.helpers import (
@ -47,6 +49,7 @@ from homeassistant.helpers.template import is_number
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.util import slugify
import homeassistant.util.dt as dt_util
from homeassistant.util.enum import try_parse_enum
from .const import (
ATTR_CRON_PATTERN,
@ -97,12 +100,6 @@ ATTR_LAST_PERIOD = "last_period"
ATTR_LAST_VALID_STATE = "last_valid_state"
ATTR_TARIFF = "tariff"
DEVICE_CLASS_MAP = {
UnitOfEnergy.WATT_HOUR: SensorDeviceClass.ENERGY,
UnitOfEnergy.KILO_WATT_HOUR: SensorDeviceClass.ENERGY,
}
PRECISION = 3
PAUSED = "paused"
COLLECTING = "collecting"
@ -313,6 +310,7 @@ class UtilitySensorExtraStoredData(SensorExtraStoredData):
last_reset: datetime | None
last_valid_state: Decimal | None
status: str
input_device_class: SensorDeviceClass | None
def as_dict(self) -> dict[str, Any]:
"""Return a dict representation of the utility sensor data."""
@ -324,6 +322,7 @@ class UtilitySensorExtraStoredData(SensorExtraStoredData):
str(self.last_valid_state) if self.last_valid_state else None
)
data["status"] = self.status
data["input_device_class"] = str(self.input_device_class)
return data
@ -343,6 +342,9 @@ class UtilitySensorExtraStoredData(SensorExtraStoredData):
else None
)
status: str = restored["status"]
input_device_class = try_parse_enum(
SensorDeviceClass, restored.get("input_device_class")
)
except KeyError:
# restored is a dict, but does not have all values
return None
@ -357,6 +359,7 @@ class UtilitySensorExtraStoredData(SensorExtraStoredData):
last_reset,
last_valid_state,
status,
input_device_class,
)
@ -397,6 +400,7 @@ class UtilityMeterSensor(RestoreSensor):
self._last_valid_state = None
self._collecting = None
self._name = name
self._input_device_class = None
self._unit_of_measurement = None
self._period = meter_type
if meter_type is not None:
@ -416,9 +420,10 @@ class UtilityMeterSensor(RestoreSensor):
self._tariff = tariff
self._tariff_entity = tariff_entity
def start(self, unit):
def start(self, attributes: Mapping[str, Any]) -> None:
"""Initialize unit and state upon source initial update."""
self._unit_of_measurement = unit
self._input_device_class = attributes.get(ATTR_DEVICE_CLASS)
self._unit_of_measurement = attributes.get(ATTR_UNIT_OF_MEASUREMENT)
self._state = 0
self.async_write_ha_state()
@ -482,6 +487,7 @@ class UtilityMeterSensor(RestoreSensor):
new_state = event.data["new_state"]
if new_state is None:
return
new_state_attributes: Mapping[str, Any] = new_state.attributes or {}
# First check if the new_state is valid (see discussion in PR #88446)
if (new_state_val := self._validate_state(new_state)) is None:
@ -498,7 +504,7 @@ class UtilityMeterSensor(RestoreSensor):
for sensor in self.hass.data[DATA_UTILITY][self._parent_meter][
DATA_TARIFF_SENSORS
]:
sensor.start(new_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT))
sensor.start(new_state_attributes)
if self._unit_of_measurement is None:
_LOGGER.warning(
"Source sensor %s has no unit of measurement. Please %s",
@ -512,7 +518,8 @@ class UtilityMeterSensor(RestoreSensor):
# If net_consumption is off, the adjustment must be non-negative
self._state += adjustment # type: ignore[operator] # self._state will be set to by the start function if it is None, therefore it always has a valid Decimal value at this line
self._unit_of_measurement = new_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)
self._input_device_class = new_state_attributes.get(ATTR_DEVICE_CLASS)
self._unit_of_measurement = new_state_attributes.get(ATTR_UNIT_OF_MEASUREMENT)
self._last_valid_state = new_state_val
self.async_write_ha_state()
@ -600,6 +607,7 @@ class UtilityMeterSensor(RestoreSensor):
if (last_sensor_data := await self.async_get_last_sensor_data()) is not None:
# new introduced in 2022.04
self._state = last_sensor_data.native_value
self._input_device_class = last_sensor_data.input_device_class
self._unit_of_measurement = last_sensor_data.native_unit_of_measurement
self._last_period = last_sensor_data.last_period
self._last_reset = last_sensor_data.last_reset
@ -693,7 +701,11 @@ class UtilityMeterSensor(RestoreSensor):
@property
def device_class(self):
"""Return the device class of the sensor."""
return DEVICE_CLASS_MAP.get(self._unit_of_measurement)
if self._input_device_class is not None:
return self._input_device_class
if self._unit_of_measurement in DEVICE_CLASS_UNITS[SensorDeviceClass.ENERGY]:
return SensorDeviceClass.ENERGY
return None
@property
def state_class(self):
@ -744,6 +756,7 @@ class UtilityMeterSensor(RestoreSensor):
self._last_reset,
self._last_valid_state,
PAUSED if self._collecting is None else COLLECTING,
self._input_device_class,
)
async def async_get_last_sensor_data(self) -> UtilitySensorExtraStoredData | None:

View File

@ -40,6 +40,7 @@ from homeassistant.const import (
STATE_UNAVAILABLE,
STATE_UNKNOWN,
UnitOfEnergy,
UnitOfVolume,
)
from homeassistant.core import CoreState, HomeAssistant, State
from homeassistant.helpers import device_registry as dr, entity_registry as er
@ -553,8 +554,66 @@ async def test_entity_name(hass: HomeAssistant, yaml_config, entity_id, name) ->
),
],
)
@pytest.mark.parametrize(
(
"energy_sensor_attributes",
"gas_sensor_attributes",
"energy_meter_attributes",
"gas_meter_attributes",
),
[
(
{ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR},
{ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit"},
{
ATTR_DEVICE_CLASS: SensorDeviceClass.ENERGY,
ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR,
},
{
ATTR_DEVICE_CLASS: None,
ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit",
},
),
(
{},
{},
{
ATTR_DEVICE_CLASS: None,
ATTR_UNIT_OF_MEASUREMENT: None,
},
{
ATTR_DEVICE_CLASS: None,
ATTR_UNIT_OF_MEASUREMENT: None,
},
),
(
{
ATTR_DEVICE_CLASS: SensorDeviceClass.GAS,
ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR,
},
{
ATTR_DEVICE_CLASS: SensorDeviceClass.WATER,
ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit",
},
{
ATTR_DEVICE_CLASS: SensorDeviceClass.GAS,
ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR,
},
{
ATTR_DEVICE_CLASS: SensorDeviceClass.WATER,
ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit",
},
),
],
)
async def test_device_class(
hass: HomeAssistant, yaml_config, config_entry_configs
hass: HomeAssistant,
yaml_config,
config_entry_configs,
energy_sensor_attributes,
gas_sensor_attributes,
energy_meter_attributes,
gas_meter_attributes,
) -> None:
"""Test utility device_class."""
if yaml_config:
@ -579,27 +638,23 @@ async def test_device_class(
await hass.async_block_till_done()
hass.states.async_set(
entity_id_energy, 2, {ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR}
)
hass.states.async_set(
entity_id_gas, 2, {ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit"}
)
hass.states.async_set(entity_id_energy, 2, energy_sensor_attributes)
hass.states.async_set(entity_id_gas, 2, gas_sensor_attributes)
await hass.async_block_till_done()
state = hass.states.get("sensor.energy_meter")
assert state is not None
assert state.state == "0"
assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.ENERGY
assert state.attributes.get(ATTR_STATE_CLASS) is SensorStateClass.TOTAL
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.KILO_WATT_HOUR
for attr, value in energy_meter_attributes.items():
assert state.attributes.get(attr) == value
state = hass.states.get("sensor.gas_meter")
assert state is not None
assert state.state == "0"
assert state.attributes.get(ATTR_DEVICE_CLASS) is None
assert state.attributes.get(ATTR_STATE_CLASS) is SensorStateClass.TOTAL_INCREASING
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "some_archaic_unit"
for attr, value in gas_meter_attributes.items():
assert state.attributes.get(attr) == value
@pytest.mark.parametrize(
@ -610,7 +665,13 @@ async def test_device_class(
"utility_meter": {
"energy_bill": {
"source": "sensor.energy",
"tariffs": ["tariff1", "tariff2", "tariff3", "tariff4"],
"tariffs": [
"tariff0",
"tariff1",
"tariff2",
"tariff3",
"tariff4",
],
}
}
},
@ -626,7 +687,13 @@ async def test_device_class(
"offset": 0,
"periodically_resetting": True,
"source": "sensor.energy",
"tariffs": ["tariff1", "tariff2", "tariff3", "tariff4"],
"tariffs": [
"tariff0",
"tariff1",
"tariff2",
"tariff3",
"tariff4",
],
},
),
],
@ -644,7 +711,33 @@ async def test_restore_state(
mock_restore_cache_with_extra_data(
hass,
[
# sensor.energy_bill_tariff1 is restored as expected
# sensor.energy_bill_tariff0 is restored as expected, including device
# class
(
State(
"sensor.energy_bill_tariff0",
"0.1",
attributes={
ATTR_STATUS: PAUSED,
ATTR_LAST_RESET: last_reset_1,
ATTR_UNIT_OF_MEASUREMENT: UnitOfVolume.CUBIC_METERS,
},
),
{
"native_value": {
"__type": "<class 'decimal.Decimal'>",
"decimal_str": "0.2",
},
"native_unit_of_measurement": "gal",
"last_reset": last_reset_2,
"last_period": "1.3",
"last_valid_state": None,
"status": "collecting",
"input_device_class": "water",
},
),
# sensor.energy_bill_tariff1 is restored as expected, except device
# class
(
State(
"sensor.energy_bill_tariff1",
@ -743,12 +836,21 @@ async def test_restore_state(
await hass.async_block_till_done()
# restore from cache
state = hass.states.get("sensor.energy_bill_tariff0")
assert state.state == "0.2"
assert state.attributes.get("status") == COLLECTING
assert state.attributes.get("last_reset") == last_reset_2
assert state.attributes.get("last_valid_state") == "None"
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfVolume.GALLONS
assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.WATER
state = hass.states.get("sensor.energy_bill_tariff1")
assert state.state == "1.2"
assert state.attributes.get("status") == PAUSED
assert state.attributes.get("last_reset") == last_reset_2
assert state.attributes.get("last_valid_state") == "None"
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.KILO_WATT_HOUR
assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.ENERGY
state = hass.states.get("sensor.energy_bill_tariff2")
assert state.state == "2.1"
@ -756,6 +858,7 @@ async def test_restore_state(
assert state.attributes.get("last_reset") == last_reset_1
assert state.attributes.get("last_valid_state") == "None"
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.MEGA_WATT_HOUR
assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.ENERGY
state = hass.states.get("sensor.energy_bill_tariff3")
assert state.state == "3.1"
@ -763,6 +866,7 @@ async def test_restore_state(
assert state.attributes.get("last_reset") == last_reset_1
assert state.attributes.get("last_valid_state") == "None"
assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.MEGA_WATT_HOUR
assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.ENERGY
state = hass.states.get("sensor.energy_bill_tariff4")
assert state.state == STATE_UNKNOWN
@ -770,16 +874,16 @@ async def test_restore_state(
# utility_meter is loaded, now set sensors according to utility_meter:
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
state = hass.states.get("select.energy_bill")
assert state.state == "tariff1"
assert state.state == "tariff0"
state = hass.states.get("sensor.energy_bill_tariff1")
state = hass.states.get("sensor.energy_bill_tariff0")
assert state.attributes.get("status") == COLLECTING
for entity_id in (
"sensor.energy_bill_tariff1",
"sensor.energy_bill_tariff2",
"sensor.energy_bill_tariff3",
"sensor.energy_bill_tariff4",