From 39229ce098797257a57f7d5bf2f6af93b369480b Mon Sep 17 00:00:00 2001 From: dougiteixeira <31328123+dougiteixeira@users.noreply.github.com> Date: Mon, 26 Jun 2023 13:08:13 -0300 Subject: [PATCH] Add the device of the source entity in the helper entities for Utility Meter (#94734) Co-authored-by: Franck Nijhof --- .../components/utility_meter/select.py | 52 +++++++++++- .../components/utility_meter/sensor.py | 32 +++++++- tests/components/utility_meter/test_sensor.py | 79 ++++++++++++++++++- 3 files changed, 158 insertions(+), 5 deletions(-) diff --git a/homeassistant/components/utility_meter/select.py b/homeassistant/components/utility_meter/select.py index 55845569af0..cf0e6e91ffb 100644 --- a/homeassistant/components/utility_meter/select.py +++ b/homeassistant/components/utility_meter/select.py @@ -7,11 +7,22 @@ from homeassistant.components.select import SelectEntity from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_UNIQUE_ID from homeassistant.core import HomeAssistant +from homeassistant.helpers import ( + device_registry as dr, + entity_registry as er, +) +from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from .const import CONF_METER, CONF_TARIFFS, DATA_UTILITY, TARIFF_ICON +from .const import ( + CONF_METER, + CONF_SOURCE_SENSOR, + CONF_TARIFFS, + DATA_UTILITY, + TARIFF_ICON, +) _LOGGER = logging.getLogger(__name__) @@ -26,7 +37,35 @@ async def async_setup_entry( tariffs: list[str] = config_entry.options[CONF_TARIFFS] unique_id = config_entry.entry_id - tariff_select = TariffSelect(name, tariffs, unique_id) + + registry = er.async_get(hass) + source_entity = registry.async_get(config_entry.options[CONF_SOURCE_SENSOR]) + dev_reg = dr.async_get(hass) + # Resolve source entity device + if ( + (source_entity is not None) + and (source_entity.device_id is not None) + and ( + ( + device := dev_reg.async_get( + device_id=source_entity.device_id, + ) + ) + is not None + ) + ): + device_info = DeviceInfo( + identifiers=device.identifiers, + ) + else: + device_info = None + + tariff_select = TariffSelect( + name, + tariffs, + unique_id, + device_info=device_info, + ) async_add_entities([tariff_select]) @@ -63,10 +102,17 @@ async def async_setup_platform( class TariffSelect(SelectEntity, RestoreEntity): """Representation of a Tariff selector.""" - def __init__(self, name, tariffs, unique_id): + def __init__( + self, + name, + tariffs, + unique_id, + device_info: DeviceInfo | None = None, + ) -> None: """Initialize a tariff selector.""" self._attr_name = name self._attr_unique_id = unique_id + self._attr_device_info = device_info self._current_tariff: str | None = None self._tariffs = tariffs self._attr_icon = TARIFF_ICON diff --git a/homeassistant/components/utility_meter/sensor.py b/homeassistant/components/utility_meter/sensor.py index 7ad5afaa503..5f426fc49c5 100644 --- a/homeassistant/components/utility_meter/sensor.py +++ b/homeassistant/components/utility_meter/sensor.py @@ -28,8 +28,13 @@ from homeassistant.const import ( UnitOfEnergy, ) from homeassistant.core import Event, HomeAssistant, State, callback -from homeassistant.helpers import entity_platform, entity_registry as er +from homeassistant.helpers import ( + device_registry as dr, + entity_platform, + entity_registry as er, +) from homeassistant.helpers.dispatcher import async_dispatcher_connect +from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.event import ( async_track_point_in_time, @@ -120,6 +125,27 @@ async def async_setup_entry( registry, config_entry.options[CONF_SOURCE_SENSOR] ) + source_entity = registry.async_get(source_entity_id) + dev_reg = dr.async_get(hass) + # Resolve source entity device + if ( + (source_entity is not None) + and (source_entity.device_id is not None) + and ( + ( + device := dev_reg.async_get( + device_id=source_entity.device_id, + ) + ) + is not None + ) + ): + device_info = DeviceInfo( + identifiers=device.identifiers, + ) + else: + device_info = None + cron_pattern = None delta_values = config_entry.options[CONF_METER_DELTA_VALUES] meter_offset = timedelta(days=config_entry.options[CONF_METER_OFFSET]) @@ -149,6 +175,7 @@ async def async_setup_entry( tariff_entity=tariff_entity, tariff=None, unique_id=entry_id, + device_info=device_info, ) meters.append(meter_sensor) hass.data[DATA_UTILITY][entry_id][DATA_TARIFF_SENSORS].append(meter_sensor) @@ -168,6 +195,7 @@ async def async_setup_entry( tariff_entity=tariff_entity, tariff=tariff, unique_id=f"{entry_id}_{tariff}", + device_info=device_info, ) meters.append(meter_sensor) hass.data[DATA_UTILITY][entry_id][DATA_TARIFF_SENSORS].append(meter_sensor) @@ -341,9 +369,11 @@ class UtilityMeterSensor(RestoreSensor): tariff, unique_id, suggested_entity_id=None, + device_info=None, ): """Initialize the Utility Meter sensor.""" self._attr_unique_id = unique_id + self._attr_device_info = device_info self.entity_id = suggested_entity_id self._parent_meter = parent_meter self._sensor_source_id = source_entity diff --git a/tests/components/utility_meter/test_sensor.py b/tests/components/utility_meter/test_sensor.py index 65892ae376a..1e26d5e211a 100644 --- a/tests/components/utility_meter/test_sensor.py +++ b/tests/components/utility_meter/test_sensor.py @@ -41,7 +41,7 @@ from homeassistant.const import ( UnitOfEnergy, ) from homeassistant.core import CoreState, HomeAssistant, State -from homeassistant.helpers import entity_registry as er +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -1458,3 +1458,80 @@ def test_calculate_adjustment_invalid_new_state( new_state: State = State(entity_id="sensor.test", state="unknown") assert mock_sensor.calculate_adjustment(None, new_state) is None assert "Invalid state unknown" in caplog.text + + +async def test_device_id(hass: HomeAssistant) -> None: + """Test for source entity device for Utility Meter.""" + device_registry = dr.async_get(hass) + entity_registry = er.async_get(hass) + + source_config_entry = MockConfigEntry() + source_device_entry = device_registry.async_get_or_create( + config_entry_id=source_config_entry.entry_id, + identifiers={("sensor", "identifier_test")}, + ) + source_entity = entity_registry.async_get_or_create( + "sensor", + "test", + "source", + config_entry=source_config_entry, + device_id=source_device_entry.id, + ) + await hass.async_block_till_done() + assert entity_registry.async_get("sensor.test_source") is not None + + utility_meter_config_entry = MockConfigEntry( + data={}, + domain=DOMAIN, + options={ + "cycle": "monthly", + "delta_values": False, + "name": "Energy", + "net_consumption": False, + "offset": 0, + "periodically_resetting": True, + "source": "sensor.test_source", + "tariffs": ["peak", "offpeak"], + }, + title="Energy", + ) + + utility_meter_config_entry.add_to_hass(hass) + + assert await hass.config_entries.async_setup(utility_meter_config_entry.entry_id) + await hass.async_block_till_done() + + utility_meter_entity = entity_registry.async_get("sensor.energy_peak") + assert utility_meter_entity is not None + assert utility_meter_entity.device_id == source_entity.device_id + + utility_meter_entity = entity_registry.async_get("sensor.energy_offpeak") + assert utility_meter_entity is not None + assert utility_meter_entity.device_id == source_entity.device_id + + utility_meter_no_tariffs_config_entry = MockConfigEntry( + data={}, + domain=DOMAIN, + options={ + "cycle": "monthly", + "delta_values": False, + "name": "Energy", + "net_consumption": False, + "offset": 0, + "periodically_resetting": True, + "source": "sensor.test_source", + "tariffs": [], + }, + title="Energy", + ) + + utility_meter_no_tariffs_config_entry.add_to_hass(hass) + + assert await hass.config_entries.async_setup( + utility_meter_no_tariffs_config_entry.entry_id + ) + await hass.async_block_till_done() + + utility_meter_no_tariffs_entity = entity_registry.async_get("sensor.energy") + assert utility_meter_no_tariffs_entity is not None + assert utility_meter_no_tariffs_entity.device_id == source_entity.device_id