Support overriding unit of temperature number entities (#74977)

pull/74992/head
Erik Montnemery 2022-07-11 14:49:36 +02:00 committed by GitHub
parent 66e27945ac
commit 9d2c213903
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 197 additions and 2 deletions

View File

@ -14,8 +14,13 @@ import voluptuous as vol
from homeassistant.backports.enum import StrEnum
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_MODE, TEMP_CELSIUS, TEMP_FAHRENHEIT
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.const import (
ATTR_MODE,
CONF_UNIT_OF_MEASUREMENT,
TEMP_CELSIUS,
TEMP_FAHRENHEIT,
)
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.helpers.config_validation import ( # noqa: F401
PLATFORM_SCHEMA,
PLATFORM_SCHEMA_BASE,
@ -69,6 +74,10 @@ UNIT_CONVERSIONS: dict[str, Callable[[float, str, str], float]] = {
NumberDeviceClass.TEMPERATURE: temperature_util.convert,
}
VALID_UNITS: dict[str, tuple[str, ...]] = {
NumberDeviceClass.TEMPERATURE: temperature_util.VALID_UNITS,
}
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up Number entities."""
@ -193,6 +202,7 @@ class NumberEntity(Entity):
_attr_native_value: float
_attr_native_unit_of_measurement: str | None
_deprecated_number_entity_reported = False
_number_option_unit_of_measurement: str | None = None
def __init_subclass__(cls, **kwargs: Any) -> None:
"""Post initialisation processing."""
@ -226,6 +236,13 @@ class NumberEntity(Entity):
report_issue,
)
async def async_internal_added_to_hass(self) -> None:
"""Call when the number entity is added to hass."""
await super().async_internal_added_to_hass()
if not self.registry_entry:
return
self.async_registry_entry_updated()
@property
def capability_attributes(self) -> dict[str, Any]:
"""Return capability attributes."""
@ -348,6 +365,9 @@ class NumberEntity(Entity):
@final
def unit_of_measurement(self) -> str | None:
"""Return the unit of measurement of the entity, after unit conversion."""
if self._number_option_unit_of_measurement:
return self._number_option_unit_of_measurement
if hasattr(self, "_attr_unit_of_measurement"):
return self._attr_unit_of_measurement
if (
@ -467,6 +487,22 @@ class NumberEntity(Entity):
report_issue,
)
@callback
def async_registry_entry_updated(self) -> None:
"""Run when the entity registry entry has been updated."""
assert self.registry_entry
if (
(number_options := self.registry_entry.options.get(DOMAIN))
and (custom_unit := number_options.get(CONF_UNIT_OF_MEASUREMENT))
and (device_class := self.device_class) in UNIT_CONVERSIONS
and self.native_unit_of_measurement in VALID_UNITS[device_class]
and custom_unit in VALID_UNITS[device_class]
):
self._number_option_unit_of_measurement = custom_unit
return
self._number_option_unit_of_measurement = None
@dataclasses.dataclass
class NumberExtraStoredData(ExtraStoredData):

View File

@ -22,6 +22,7 @@ from homeassistant.const import (
TEMP_FAHRENHEIT,
)
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY
from homeassistant.setup import async_setup_component
from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM
@ -689,3 +690,161 @@ async def test_restore_number_restore_state(
assert entity0.native_value == native_value
assert type(entity0.native_value) == native_value_type
assert entity0.native_unit_of_measurement == uom
@pytest.mark.parametrize(
"device_class,native_unit,custom_unit,state_unit,native_value,custom_value",
[
# Not a supported temperature unit
(
NumberDeviceClass.TEMPERATURE,
TEMP_CELSIUS,
"my_temperature_unit",
TEMP_CELSIUS,
1000,
1000,
),
(
NumberDeviceClass.TEMPERATURE,
TEMP_CELSIUS,
TEMP_FAHRENHEIT,
TEMP_FAHRENHEIT,
37.5,
99.5,
),
(
NumberDeviceClass.TEMPERATURE,
TEMP_FAHRENHEIT,
TEMP_CELSIUS,
TEMP_CELSIUS,
100,
38.0,
),
],
)
async def test_custom_unit(
hass,
enable_custom_integrations,
device_class,
native_unit,
custom_unit,
state_unit,
native_value,
custom_value,
):
"""Test custom unit."""
entity_registry = er.async_get(hass)
entry = entity_registry.async_get_or_create("number", "test", "very_unique")
entity_registry.async_update_entity_options(
entry.entity_id, "number", {"unit_of_measurement": custom_unit}
)
await hass.async_block_till_done()
platform = getattr(hass.components, "test.number")
platform.init(empty=True)
platform.ENTITIES.append(
platform.MockNumberEntity(
name="Test",
native_value=native_value,
native_unit_of_measurement=native_unit,
device_class=device_class,
unique_id="very_unique",
)
)
entity0 = platform.ENTITIES[0]
assert await async_setup_component(hass, "number", {"number": {"platform": "test"}})
await hass.async_block_till_done()
state = hass.states.get(entity0.entity_id)
assert float(state.state) == pytest.approx(float(custom_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == state_unit
@pytest.mark.parametrize(
"native_unit, custom_unit, used_custom_unit, default_unit, native_value, custom_value, default_value",
[
(
TEMP_CELSIUS,
TEMP_FAHRENHEIT,
TEMP_FAHRENHEIT,
TEMP_CELSIUS,
37.5,
99.5,
37.5,
),
(
TEMP_FAHRENHEIT,
TEMP_FAHRENHEIT,
TEMP_FAHRENHEIT,
TEMP_CELSIUS,
100,
100,
38.0,
),
# Not a supported temperature unit
(TEMP_CELSIUS, "no_unit", TEMP_CELSIUS, TEMP_CELSIUS, 1000, 1000, 1000),
],
)
async def test_custom_unit_change(
hass,
enable_custom_integrations,
native_unit,
custom_unit,
used_custom_unit,
default_unit,
native_value,
custom_value,
default_value,
):
"""Test custom unit changes are picked up."""
entity_registry = er.async_get(hass)
platform = getattr(hass.components, "test.number")
platform.init(empty=True)
platform.ENTITIES.append(
platform.MockNumberEntity(
name="Test",
native_value=native_value,
native_unit_of_measurement=native_unit,
device_class=NumberDeviceClass.TEMPERATURE,
unique_id="very_unique",
)
)
entity0 = platform.ENTITIES[0]
assert await async_setup_component(hass, "number", {"number": {"platform": "test"}})
await hass.async_block_till_done()
# Default unit conversion according to unit system
state = hass.states.get(entity0.entity_id)
assert float(state.state) == pytest.approx(float(default_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == default_unit
entity_registry.async_update_entity_options(
"number.test", "number", {"unit_of_measurement": custom_unit}
)
await hass.async_block_till_done()
# Unit conversion to the custom unit
state = hass.states.get(entity0.entity_id)
assert float(state.state) == pytest.approx(float(custom_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == used_custom_unit
entity_registry.async_update_entity_options(
"number.test", "number", {"unit_of_measurement": native_unit}
)
await hass.async_block_till_done()
# Unit conversion to another custom unit
state = hass.states.get(entity0.entity_id)
assert float(state.state) == pytest.approx(float(native_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == native_unit
entity_registry.async_update_entity_options("number.test", "number", None)
await hass.async_block_till_done()
# Default unit conversion according to unit system
state = hass.states.get(entity0.entity_id)
assert float(state.state) == pytest.approx(float(default_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == default_unit