From b143390d8802b0ab69136785ebd3b5d44db78211 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Thu, 28 Mar 2024 13:24:44 +0100 Subject: [PATCH] Improve device class of utility meter (#114368) --- .../components/utility_meter/sensor.py | 37 +++-- tests/components/utility_meter/test_sensor.py | 138 +++++++++++++++--- 2 files changed, 146 insertions(+), 29 deletions(-) diff --git a/homeassistant/components/utility_meter/sensor.py b/homeassistant/components/utility_meter/sensor.py index 4e9be403cf7..26582df1b44 100644 --- a/homeassistant/components/utility_meter/sensor.py +++ b/homeassistant/components/utility_meter/sensor.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Mapping from dataclasses import dataclass from datetime import datetime, timedelta from decimal import Decimal, DecimalException, InvalidOperation @@ -13,6 +14,7 @@ import voluptuous as vol from homeassistant.components.sensor import ( ATTR_LAST_RESET, + DEVICE_CLASS_UNITS, RestoreSensor, SensorDeviceClass, SensorExtraStoredData, @@ -21,12 +23,12 @@ from homeassistant.components.sensor import ( from homeassistant.components.sensor.recorder import _suggest_report_issue from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( + ATTR_DEVICE_CLASS, ATTR_UNIT_OF_MEASUREMENT, CONF_NAME, CONF_UNIQUE_ID, STATE_UNAVAILABLE, STATE_UNKNOWN, - UnitOfEnergy, ) from homeassistant.core import Event, HomeAssistant, State, callback from homeassistant.helpers import ( @@ -47,6 +49,7 @@ from homeassistant.helpers.template import is_number from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.util import slugify import homeassistant.util.dt as dt_util +from homeassistant.util.enum import try_parse_enum from .const import ( ATTR_CRON_PATTERN, @@ -97,12 +100,6 @@ ATTR_LAST_PERIOD = "last_period" ATTR_LAST_VALID_STATE = "last_valid_state" ATTR_TARIFF = "tariff" -DEVICE_CLASS_MAP = { - UnitOfEnergy.WATT_HOUR: SensorDeviceClass.ENERGY, - UnitOfEnergy.KILO_WATT_HOUR: SensorDeviceClass.ENERGY, -} - - PRECISION = 3 PAUSED = "paused" COLLECTING = "collecting" @@ -313,6 +310,7 @@ class UtilitySensorExtraStoredData(SensorExtraStoredData): last_reset: datetime | None last_valid_state: Decimal | None status: str + input_device_class: SensorDeviceClass | None def as_dict(self) -> dict[str, Any]: """Return a dict representation of the utility sensor data.""" @@ -324,6 +322,7 @@ class UtilitySensorExtraStoredData(SensorExtraStoredData): str(self.last_valid_state) if self.last_valid_state else None ) data["status"] = self.status + data["input_device_class"] = str(self.input_device_class) return data @@ -343,6 +342,9 @@ class UtilitySensorExtraStoredData(SensorExtraStoredData): else None ) status: str = restored["status"] + input_device_class = try_parse_enum( + SensorDeviceClass, restored.get("input_device_class") + ) except KeyError: # restored is a dict, but does not have all values return None @@ -357,6 +359,7 @@ class UtilitySensorExtraStoredData(SensorExtraStoredData): last_reset, last_valid_state, status, + input_device_class, ) @@ -397,6 +400,7 @@ class UtilityMeterSensor(RestoreSensor): self._last_valid_state = None self._collecting = None self._name = name + self._input_device_class = None self._unit_of_measurement = None self._period = meter_type if meter_type is not None: @@ -416,9 +420,10 @@ class UtilityMeterSensor(RestoreSensor): self._tariff = tariff self._tariff_entity = tariff_entity - def start(self, unit): + def start(self, attributes: Mapping[str, Any]) -> None: """Initialize unit and state upon source initial update.""" - self._unit_of_measurement = unit + self._input_device_class = attributes.get(ATTR_DEVICE_CLASS) + self._unit_of_measurement = attributes.get(ATTR_UNIT_OF_MEASUREMENT) self._state = 0 self.async_write_ha_state() @@ -482,6 +487,7 @@ class UtilityMeterSensor(RestoreSensor): new_state = event.data["new_state"] if new_state is None: return + new_state_attributes: Mapping[str, Any] = new_state.attributes or {} # First check if the new_state is valid (see discussion in PR #88446) if (new_state_val := self._validate_state(new_state)) is None: @@ -498,7 +504,7 @@ class UtilityMeterSensor(RestoreSensor): for sensor in self.hass.data[DATA_UTILITY][self._parent_meter][ DATA_TARIFF_SENSORS ]: - sensor.start(new_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)) + sensor.start(new_state_attributes) if self._unit_of_measurement is None: _LOGGER.warning( "Source sensor %s has no unit of measurement. Please %s", @@ -512,7 +518,8 @@ class UtilityMeterSensor(RestoreSensor): # If net_consumption is off, the adjustment must be non-negative self._state += adjustment # type: ignore[operator] # self._state will be set to by the start function if it is None, therefore it always has a valid Decimal value at this line - self._unit_of_measurement = new_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) + self._input_device_class = new_state_attributes.get(ATTR_DEVICE_CLASS) + self._unit_of_measurement = new_state_attributes.get(ATTR_UNIT_OF_MEASUREMENT) self._last_valid_state = new_state_val self.async_write_ha_state() @@ -600,6 +607,7 @@ class UtilityMeterSensor(RestoreSensor): if (last_sensor_data := await self.async_get_last_sensor_data()) is not None: # new introduced in 2022.04 self._state = last_sensor_data.native_value + self._input_device_class = last_sensor_data.input_device_class self._unit_of_measurement = last_sensor_data.native_unit_of_measurement self._last_period = last_sensor_data.last_period self._last_reset = last_sensor_data.last_reset @@ -693,7 +701,11 @@ class UtilityMeterSensor(RestoreSensor): @property def device_class(self): """Return the device class of the sensor.""" - return DEVICE_CLASS_MAP.get(self._unit_of_measurement) + if self._input_device_class is not None: + return self._input_device_class + if self._unit_of_measurement in DEVICE_CLASS_UNITS[SensorDeviceClass.ENERGY]: + return SensorDeviceClass.ENERGY + return None @property def state_class(self): @@ -744,6 +756,7 @@ class UtilityMeterSensor(RestoreSensor): self._last_reset, self._last_valid_state, PAUSED if self._collecting is None else COLLECTING, + self._input_device_class, ) async def async_get_last_sensor_data(self) -> UtilitySensorExtraStoredData | None: diff --git a/tests/components/utility_meter/test_sensor.py b/tests/components/utility_meter/test_sensor.py index 13b367b1fb7..99a63809329 100644 --- a/tests/components/utility_meter/test_sensor.py +++ b/tests/components/utility_meter/test_sensor.py @@ -40,6 +40,7 @@ from homeassistant.const import ( STATE_UNAVAILABLE, STATE_UNKNOWN, UnitOfEnergy, + UnitOfVolume, ) from homeassistant.core import CoreState, HomeAssistant, State from homeassistant.helpers import device_registry as dr, entity_registry as er @@ -553,8 +554,66 @@ async def test_entity_name(hass: HomeAssistant, yaml_config, entity_id, name) -> ), ], ) +@pytest.mark.parametrize( + ( + "energy_sensor_attributes", + "gas_sensor_attributes", + "energy_meter_attributes", + "gas_meter_attributes", + ), + [ + ( + {ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR}, + {ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit"}, + { + ATTR_DEVICE_CLASS: SensorDeviceClass.ENERGY, + ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR, + }, + { + ATTR_DEVICE_CLASS: None, + ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit", + }, + ), + ( + {}, + {}, + { + ATTR_DEVICE_CLASS: None, + ATTR_UNIT_OF_MEASUREMENT: None, + }, + { + ATTR_DEVICE_CLASS: None, + ATTR_UNIT_OF_MEASUREMENT: None, + }, + ), + ( + { + ATTR_DEVICE_CLASS: SensorDeviceClass.GAS, + ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR, + }, + { + ATTR_DEVICE_CLASS: SensorDeviceClass.WATER, + ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit", + }, + { + ATTR_DEVICE_CLASS: SensorDeviceClass.GAS, + ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR, + }, + { + ATTR_DEVICE_CLASS: SensorDeviceClass.WATER, + ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit", + }, + ), + ], +) async def test_device_class( - hass: HomeAssistant, yaml_config, config_entry_configs + hass: HomeAssistant, + yaml_config, + config_entry_configs, + energy_sensor_attributes, + gas_sensor_attributes, + energy_meter_attributes, + gas_meter_attributes, ) -> None: """Test utility device_class.""" if yaml_config: @@ -579,27 +638,23 @@ async def test_device_class( await hass.async_block_till_done() - hass.states.async_set( - entity_id_energy, 2, {ATTR_UNIT_OF_MEASUREMENT: UnitOfEnergy.KILO_WATT_HOUR} - ) - hass.states.async_set( - entity_id_gas, 2, {ATTR_UNIT_OF_MEASUREMENT: "some_archaic_unit"} - ) + hass.states.async_set(entity_id_energy, 2, energy_sensor_attributes) + hass.states.async_set(entity_id_gas, 2, gas_sensor_attributes) await hass.async_block_till_done() state = hass.states.get("sensor.energy_meter") assert state is not None assert state.state == "0" - assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.ENERGY assert state.attributes.get(ATTR_STATE_CLASS) is SensorStateClass.TOTAL - assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.KILO_WATT_HOUR + for attr, value in energy_meter_attributes.items(): + assert state.attributes.get(attr) == value state = hass.states.get("sensor.gas_meter") assert state is not None assert state.state == "0" - assert state.attributes.get(ATTR_DEVICE_CLASS) is None assert state.attributes.get(ATTR_STATE_CLASS) is SensorStateClass.TOTAL_INCREASING - assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == "some_archaic_unit" + for attr, value in gas_meter_attributes.items(): + assert state.attributes.get(attr) == value @pytest.mark.parametrize( @@ -610,7 +665,13 @@ async def test_device_class( "utility_meter": { "energy_bill": { "source": "sensor.energy", - "tariffs": ["tariff1", "tariff2", "tariff3", "tariff4"], + "tariffs": [ + "tariff0", + "tariff1", + "tariff2", + "tariff3", + "tariff4", + ], } } }, @@ -626,7 +687,13 @@ async def test_device_class( "offset": 0, "periodically_resetting": True, "source": "sensor.energy", - "tariffs": ["tariff1", "tariff2", "tariff3", "tariff4"], + "tariffs": [ + "tariff0", + "tariff1", + "tariff2", + "tariff3", + "tariff4", + ], }, ), ], @@ -644,7 +711,33 @@ async def test_restore_state( mock_restore_cache_with_extra_data( hass, [ - # sensor.energy_bill_tariff1 is restored as expected + # sensor.energy_bill_tariff0 is restored as expected, including device + # class + ( + State( + "sensor.energy_bill_tariff0", + "0.1", + attributes={ + ATTR_STATUS: PAUSED, + ATTR_LAST_RESET: last_reset_1, + ATTR_UNIT_OF_MEASUREMENT: UnitOfVolume.CUBIC_METERS, + }, + ), + { + "native_value": { + "__type": "", + "decimal_str": "0.2", + }, + "native_unit_of_measurement": "gal", + "last_reset": last_reset_2, + "last_period": "1.3", + "last_valid_state": None, + "status": "collecting", + "input_device_class": "water", + }, + ), + # sensor.energy_bill_tariff1 is restored as expected, except device + # class ( State( "sensor.energy_bill_tariff1", @@ -743,12 +836,21 @@ async def test_restore_state( await hass.async_block_till_done() # restore from cache + state = hass.states.get("sensor.energy_bill_tariff0") + assert state.state == "0.2" + assert state.attributes.get("status") == COLLECTING + assert state.attributes.get("last_reset") == last_reset_2 + assert state.attributes.get("last_valid_state") == "None" + assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfVolume.GALLONS + assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.WATER + state = hass.states.get("sensor.energy_bill_tariff1") assert state.state == "1.2" assert state.attributes.get("status") == PAUSED assert state.attributes.get("last_reset") == last_reset_2 assert state.attributes.get("last_valid_state") == "None" assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.KILO_WATT_HOUR + assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.ENERGY state = hass.states.get("sensor.energy_bill_tariff2") assert state.state == "2.1" @@ -756,6 +858,7 @@ async def test_restore_state( assert state.attributes.get("last_reset") == last_reset_1 assert state.attributes.get("last_valid_state") == "None" assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.MEGA_WATT_HOUR + assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.ENERGY state = hass.states.get("sensor.energy_bill_tariff3") assert state.state == "3.1" @@ -763,6 +866,7 @@ async def test_restore_state( assert state.attributes.get("last_reset") == last_reset_1 assert state.attributes.get("last_valid_state") == "None" assert state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) == UnitOfEnergy.MEGA_WATT_HOUR + assert state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.ENERGY state = hass.states.get("sensor.energy_bill_tariff4") assert state.state == STATE_UNKNOWN @@ -770,16 +874,16 @@ async def test_restore_state( # utility_meter is loaded, now set sensors according to utility_meter: hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) - await hass.async_block_till_done() state = hass.states.get("select.energy_bill") - assert state.state == "tariff1" + assert state.state == "tariff0" - state = hass.states.get("sensor.energy_bill_tariff1") + state = hass.states.get("sensor.energy_bill_tariff0") assert state.attributes.get("status") == COLLECTING for entity_id in ( + "sensor.energy_bill_tariff1", "sensor.energy_bill_tariff2", "sensor.energy_bill_tariff3", "sensor.energy_bill_tariff4",