Optionally update sensor units when unit system is changed (#83851)

pull/87330/head
Erik Montnemery 2023-02-03 16:30:50 +01:00 committed by GitHub
parent 4b27af6a8f
commit 4d4fb2477d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 267 additions and 12 deletions

View File

@ -6,6 +6,7 @@ import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.sensor import async_update_suggested_units
from homeassistant.config import async_check_ha_config_file
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
@ -40,17 +41,18 @@ class CheckConfigView(HomeAssistantView):
@websocket_api.websocket_command(
{
"type": "config/core/update",
vol.Optional("latitude"): cv.latitude,
vol.Optional("longitude"): cv.longitude,
vol.Optional("country"): cv.country,
vol.Optional("currency"): cv.currency,
vol.Optional("elevation"): int,
vol.Optional("unit_system"): unit_system.validate_unit_system,
vol.Optional("location_name"): str,
vol.Optional("time_zone"): cv.time_zone,
vol.Optional("external_url"): vol.Any(cv.url_no_path, None),
vol.Optional("internal_url"): vol.Any(cv.url_no_path, None),
vol.Optional("currency"): cv.currency,
vol.Optional("country"): cv.country,
vol.Optional("language"): cv.language,
vol.Optional("latitude"): cv.latitude,
vol.Optional("location_name"): str,
vol.Optional("longitude"): cv.longitude,
vol.Optional("time_zone"): cv.time_zone,
vol.Optional("update_units"): bool,
vol.Optional("unit_system"): unit_system.validate_unit_system,
}
)
@websocket_api.async_response
@ -64,8 +66,12 @@ async def websocket_update_config(
data.pop("id")
data.pop("type")
update_units = data.pop("update_units", False)
try:
await hass.config.async_update(**data)
if update_units:
async_update_suggested_units(hass)
connection.send_result(msg["id"])
except ValueError as err:
connection.send_error(msg["id"], "invalid_info", str(err))

View File

@ -730,6 +730,17 @@ class SensorEntity(Entity):
def async_registry_entry_updated(self) -> None:
"""Run when the entity registry entry has been updated."""
self._sensor_option_precision = self._custom_precision_or_none()
assert self.registry_entry
if (
sensor_options := self.registry_entry.options.get(f"{DOMAIN}.private")
) and "refresh_initial_entity_options" in sensor_options:
registry = er.async_get(self.hass)
initial_options = self.get_initial_entity_options() or {}
registry.async_update_entity_options(
self.entity_id,
f"{DOMAIN}.private",
initial_options.get(f"{DOMAIN}.private"),
)
self._sensor_option_unit_of_measurement = self._custom_unit_or_undef(
DOMAIN, CONF_UNIT_OF_MEASUREMENT
)
@ -808,3 +819,21 @@ class RestoreSensor(SensorEntity, RestoreEntity):
if (restored_last_extra_data := await self.async_get_last_extra_data()) is None:
return None
return SensorExtraStoredData.from_dict(restored_last_extra_data.as_dict())
@callback
def async_update_suggested_units(hass: HomeAssistant) -> None:
"""Update the suggested_unit_of_measurement according to the unit system."""
registry = er.async_get(hass)
for entry in registry.entities.values():
if entry.domain != DOMAIN:
continue
sensor_private_options = dict(entry.options.get(f"{DOMAIN}.private", {}))
sensor_private_options["refresh_initial_entity_options"] = True
registry.async_update_entity_options(
entry.entity_id,
f"{DOMAIN}.private",
sensor_private_options,
)

View File

@ -859,11 +859,18 @@ class EntityRegistry:
@callback
def async_update_entity_options(
self, entity_id: str, domain: str, options: dict[str, Any]
self, entity_id: str, domain: str, options: Mapping[str, Any] | None
) -> RegistryEntry:
"""Update entity options."""
"""Update entity options for a domain.
If the domain options are set to None, they will be removed.
"""
old = self.entities[entity_id]
new_options: EntityOptionsType = {**old.options, domain: options}
new_options = {
key: value for key, value in old.options.items() if key != domain
}
if options is not None:
new_options[domain] = options
return self._async_update_entity(entity_id, options=new_options)
async def async_load(self) -> None:

View File

@ -7,7 +7,11 @@ import pytest
from homeassistant.bootstrap import async_setup_component
from homeassistant.components import config
from homeassistant.components.websocket_api.const import TYPE_RESULT
from homeassistant.const import CONF_UNIT_SYSTEM, CONF_UNIT_SYSTEM_IMPERIAL
from homeassistant.const import (
CONF_UNIT_SYSTEM,
CONF_UNIT_SYSTEM_IMPERIAL,
CONF_UNIT_SYSTEM_METRIC,
)
from homeassistant.util import dt as dt_util, location
from homeassistant.util.unit_system import US_CUSTOMARY_SYSTEM
@ -64,7 +68,9 @@ async def test_websocket_core_update(hass, client):
assert hass.config.country != "SE"
assert hass.config.language != "sv"
with patch("homeassistant.util.dt.set_default_time_zone") as mock_set_tz:
with patch("homeassistant.util.dt.set_default_time_zone") as mock_set_tz, patch(
"homeassistant.components.config.core.async_update_suggested_units"
) as mock_update_sensor_units:
await client.send_json(
{
"id": 5,
@ -85,6 +91,8 @@ async def test_websocket_core_update(hass, client):
msg = await client.receive_json()
mock_update_sensor_units.assert_not_called()
assert msg["id"] == 5
assert msg["type"] == TYPE_RESULT
assert msg["success"]
@ -100,6 +108,22 @@ async def test_websocket_core_update(hass, client):
assert len(mock_set_tz.mock_calls) == 1
assert mock_set_tz.mock_calls[0][1][0] == dt_util.get_time_zone("America/New_York")
with patch("homeassistant.util.dt.set_default_time_zone") as mock_set_tz, patch(
"homeassistant.components.config.core.async_update_suggested_units"
) as mock_update_sensor_units:
await client.send_json(
{
"id": 6,
"type": "config/core/update",
CONF_UNIT_SYSTEM: CONF_UNIT_SYSTEM_METRIC,
"update_units": True,
}
)
msg = await client.receive_json()
mock_update_sensor_units.assert_called_once()
async def test_websocket_core_update_not_admin(hass, hass_ws_client, hass_admin_user):
"""Test core config fails for non admin."""

View File

@ -12,6 +12,7 @@ from homeassistant.components.sensor import (
DEVICE_CLASS_UNITS,
SensorDeviceClass,
SensorStateClass,
async_update_suggested_units,
)
from homeassistant.const import (
ATTR_UNIT_OF_MEASUREMENT,
@ -1685,3 +1686,191 @@ async def test_numeric_state_expected_helper(
assert state is not None
assert entity0._numeric_state_expected == is_numeric
@pytest.mark.parametrize(
"unit_system_1, unit_system_2, native_unit, automatic_unit_1, automatic_unit_2, suggested_unit, custom_unit, native_value, automatic_state_1, automatic_state_2, suggested_state, custom_state, device_class",
[
# Distance
(
US_CUSTOMARY_SYSTEM,
METRIC_SYSTEM,
UnitOfLength.KILOMETERS,
UnitOfLength.MILES,
UnitOfLength.KILOMETERS,
UnitOfLength.METERS,
UnitOfLength.YARDS,
1000,
"621",
"1000",
"1000000",
"1093613",
SensorDeviceClass.DISTANCE,
),
],
)
async def test_unit_conversion_update(
hass,
enable_custom_integrations,
unit_system_1,
unit_system_2,
native_unit,
automatic_unit_1,
automatic_unit_2,
suggested_unit,
custom_unit,
native_value,
automatic_state_1,
automatic_state_2,
suggested_state,
custom_state,
device_class,
):
"""Test suggested unit can be updated."""
hass.config.units = unit_system_1
entity_registry = er.async_get(hass)
platform = getattr(hass.components, "test.sensor")
platform.init(empty=True)
platform.ENTITIES["0"] = platform.MockSensor(
name="Test 0",
device_class=device_class,
native_unit_of_measurement=native_unit,
native_value=str(native_value),
unique_id="very_unique",
)
entity0 = platform.ENTITIES["0"]
platform.ENTITIES["1"] = platform.MockSensor(
name="Test 1",
device_class=device_class,
native_unit_of_measurement=native_unit,
native_value=str(native_value),
unique_id="very_unique_1",
)
entity1 = platform.ENTITIES["1"]
platform.ENTITIES["2"] = platform.MockSensor(
name="Test 2",
device_class=device_class,
native_unit_of_measurement=native_unit,
native_value=str(native_value),
suggested_unit_of_measurement=suggested_unit,
unique_id="very_unique_2",
)
entity2 = platform.ENTITIES["2"]
platform.ENTITIES["3"] = platform.MockSensor(
name="Test 3",
device_class=device_class,
native_unit_of_measurement=native_unit,
native_value=str(native_value),
suggested_unit_of_measurement=suggested_unit,
unique_id="very_unique_3",
)
entity3 = platform.ENTITIES["3"]
assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}})
await hass.async_block_till_done()
# Registered entity -> Follow automatic unit conversion
state = hass.states.get(entity0.entity_id)
assert state.state == automatic_state_1
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == automatic_unit_1
# Assert the automatic unit conversion is stored in the registry
entry = entity_registry.async_get(entity0.entity_id)
assert entry.options == {
"sensor.private": {"suggested_unit_of_measurement": automatic_unit_1}
}
state = hass.states.get(entity1.entity_id)
assert state.state == automatic_state_1
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == automatic_unit_1
# Assert the automatic unit conversion is stored in the registry
entry = entity_registry.async_get(entity1.entity_id)
assert entry.options == {
"sensor.private": {"suggested_unit_of_measurement": automatic_unit_1}
}
# Registered entity with suggested unit
state = hass.states.get(entity2.entity_id)
assert state.state == suggested_state
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == suggested_unit
# Assert the suggested unit is stored in the registry
entry = entity_registry.async_get(entity2.entity_id)
assert entry.options == {
"sensor.private": {"suggested_unit_of_measurement": suggested_unit}
}
state = hass.states.get(entity3.entity_id)
assert state.state == suggested_state
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == suggested_unit
# Assert the suggested unit is stored in the registry
entry = entity_registry.async_get(entity3.entity_id)
assert entry.options == {
"sensor.private": {"suggested_unit_of_measurement": suggested_unit}
}
# Set a custom unit, this should have priority over the automatic unit conversion
entity_registry.async_update_entity_options(
entity0.entity_id, "sensor", {"unit_of_measurement": custom_unit}
)
await hass.async_block_till_done()
state = hass.states.get(entity0.entity_id)
assert state.state == custom_state
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit
entity_registry.async_update_entity_options(
entity2.entity_id, "sensor", {"unit_of_measurement": custom_unit}
)
await hass.async_block_till_done()
state = hass.states.get(entity2.entity_id)
assert state.state == custom_state
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit
# Change unit system, states and units should be unchanged
hass.config.units = unit_system_2
await hass.async_block_till_done()
state = hass.states.get(entity0.entity_id)
assert state.state == custom_state
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit
state = hass.states.get(entity1.entity_id)
assert state.state == automatic_state_1
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == automatic_unit_1
state = hass.states.get(entity2.entity_id)
assert state.state == custom_state
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit
state = hass.states.get(entity3.entity_id)
assert state.state == suggested_state
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == suggested_unit
# Update suggested unit
async_update_suggested_units(hass)
await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done()
state = hass.states.get(entity0.entity_id)
assert state.state == custom_state
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit
state = hass.states.get(entity1.entity_id)
assert state.state == automatic_state_2
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == automatic_unit_2
state = hass.states.get(entity2.entity_id)
assert state.state == custom_state
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit
state = hass.states.get(entity3.entity_id)
assert state.state == suggested_state
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == suggested_unit