Support overriding unit of temperature number entities (#74977)
parent
66e27945ac
commit
9d2c213903
|
@ -14,8 +14,13 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.backports.enum import StrEnum
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import ATTR_MODE, TEMP_CELSIUS, TEMP_FAHRENHEIT
|
||||
from homeassistant.core import HomeAssistant, ServiceCall
|
||||
from homeassistant.const import (
|
||||
ATTR_MODE,
|
||||
CONF_UNIT_OF_MEASUREMENT,
|
||||
TEMP_CELSIUS,
|
||||
TEMP_FAHRENHEIT,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, ServiceCall, callback
|
||||
from homeassistant.helpers.config_validation import ( # noqa: F401
|
||||
PLATFORM_SCHEMA,
|
||||
PLATFORM_SCHEMA_BASE,
|
||||
|
@ -69,6 +74,10 @@ UNIT_CONVERSIONS: dict[str, Callable[[float, str, str], float]] = {
|
|||
NumberDeviceClass.TEMPERATURE: temperature_util.convert,
|
||||
}
|
||||
|
||||
VALID_UNITS: dict[str, tuple[str, ...]] = {
|
||||
NumberDeviceClass.TEMPERATURE: temperature_util.VALID_UNITS,
|
||||
}
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up Number entities."""
|
||||
|
@ -193,6 +202,7 @@ class NumberEntity(Entity):
|
|||
_attr_native_value: float
|
||||
_attr_native_unit_of_measurement: str | None
|
||||
_deprecated_number_entity_reported = False
|
||||
_number_option_unit_of_measurement: str | None = None
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
"""Post initialisation processing."""
|
||||
|
@ -226,6 +236,13 @@ class NumberEntity(Entity):
|
|||
report_issue,
|
||||
)
|
||||
|
||||
async def async_internal_added_to_hass(self) -> None:
|
||||
"""Call when the number entity is added to hass."""
|
||||
await super().async_internal_added_to_hass()
|
||||
if not self.registry_entry:
|
||||
return
|
||||
self.async_registry_entry_updated()
|
||||
|
||||
@property
|
||||
def capability_attributes(self) -> dict[str, Any]:
|
||||
"""Return capability attributes."""
|
||||
|
@ -348,6 +365,9 @@ class NumberEntity(Entity):
|
|||
@final
|
||||
def unit_of_measurement(self) -> str | None:
|
||||
"""Return the unit of measurement of the entity, after unit conversion."""
|
||||
if self._number_option_unit_of_measurement:
|
||||
return self._number_option_unit_of_measurement
|
||||
|
||||
if hasattr(self, "_attr_unit_of_measurement"):
|
||||
return self._attr_unit_of_measurement
|
||||
if (
|
||||
|
@ -467,6 +487,22 @@ class NumberEntity(Entity):
|
|||
report_issue,
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_registry_entry_updated(self) -> None:
|
||||
"""Run when the entity registry entry has been updated."""
|
||||
assert self.registry_entry
|
||||
if (
|
||||
(number_options := self.registry_entry.options.get(DOMAIN))
|
||||
and (custom_unit := number_options.get(CONF_UNIT_OF_MEASUREMENT))
|
||||
and (device_class := self.device_class) in UNIT_CONVERSIONS
|
||||
and self.native_unit_of_measurement in VALID_UNITS[device_class]
|
||||
and custom_unit in VALID_UNITS[device_class]
|
||||
):
|
||||
self._number_option_unit_of_measurement = custom_unit
|
||||
return
|
||||
|
||||
self._number_option_unit_of_measurement = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class NumberExtraStoredData(ExtraStoredData):
|
||||
|
|
|
@ -22,6 +22,7 @@ from homeassistant.const import (
|
|||
TEMP_FAHRENHEIT,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM
|
||||
|
@ -689,3 +690,161 @@ async def test_restore_number_restore_state(
|
|||
assert entity0.native_value == native_value
|
||||
assert type(entity0.native_value) == native_value_type
|
||||
assert entity0.native_unit_of_measurement == uom
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"device_class,native_unit,custom_unit,state_unit,native_value,custom_value",
|
||||
[
|
||||
# Not a supported temperature unit
|
||||
(
|
||||
NumberDeviceClass.TEMPERATURE,
|
||||
TEMP_CELSIUS,
|
||||
"my_temperature_unit",
|
||||
TEMP_CELSIUS,
|
||||
1000,
|
||||
1000,
|
||||
),
|
||||
(
|
||||
NumberDeviceClass.TEMPERATURE,
|
||||
TEMP_CELSIUS,
|
||||
TEMP_FAHRENHEIT,
|
||||
TEMP_FAHRENHEIT,
|
||||
37.5,
|
||||
99.5,
|
||||
),
|
||||
(
|
||||
NumberDeviceClass.TEMPERATURE,
|
||||
TEMP_FAHRENHEIT,
|
||||
TEMP_CELSIUS,
|
||||
TEMP_CELSIUS,
|
||||
100,
|
||||
38.0,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_custom_unit(
|
||||
hass,
|
||||
enable_custom_integrations,
|
||||
device_class,
|
||||
native_unit,
|
||||
custom_unit,
|
||||
state_unit,
|
||||
native_value,
|
||||
custom_value,
|
||||
):
|
||||
"""Test custom unit."""
|
||||
entity_registry = er.async_get(hass)
|
||||
|
||||
entry = entity_registry.async_get_or_create("number", "test", "very_unique")
|
||||
entity_registry.async_update_entity_options(
|
||||
entry.entity_id, "number", {"unit_of_measurement": custom_unit}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
platform = getattr(hass.components, "test.number")
|
||||
platform.init(empty=True)
|
||||
platform.ENTITIES.append(
|
||||
platform.MockNumberEntity(
|
||||
name="Test",
|
||||
native_value=native_value,
|
||||
native_unit_of_measurement=native_unit,
|
||||
device_class=device_class,
|
||||
unique_id="very_unique",
|
||||
)
|
||||
)
|
||||
|
||||
entity0 = platform.ENTITIES[0]
|
||||
assert await async_setup_component(hass, "number", {"number": {"platform": "test"}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get(entity0.entity_id)
|
||||
assert float(state.state) == pytest.approx(float(custom_value))
|
||||
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == state_unit
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"native_unit, custom_unit, used_custom_unit, default_unit, native_value, custom_value, default_value",
|
||||
[
|
||||
(
|
||||
TEMP_CELSIUS,
|
||||
TEMP_FAHRENHEIT,
|
||||
TEMP_FAHRENHEIT,
|
||||
TEMP_CELSIUS,
|
||||
37.5,
|
||||
99.5,
|
||||
37.5,
|
||||
),
|
||||
(
|
||||
TEMP_FAHRENHEIT,
|
||||
TEMP_FAHRENHEIT,
|
||||
TEMP_FAHRENHEIT,
|
||||
TEMP_CELSIUS,
|
||||
100,
|
||||
100,
|
||||
38.0,
|
||||
),
|
||||
# Not a supported temperature unit
|
||||
(TEMP_CELSIUS, "no_unit", TEMP_CELSIUS, TEMP_CELSIUS, 1000, 1000, 1000),
|
||||
],
|
||||
)
|
||||
async def test_custom_unit_change(
|
||||
hass,
|
||||
enable_custom_integrations,
|
||||
native_unit,
|
||||
custom_unit,
|
||||
used_custom_unit,
|
||||
default_unit,
|
||||
native_value,
|
||||
custom_value,
|
||||
default_value,
|
||||
):
|
||||
"""Test custom unit changes are picked up."""
|
||||
entity_registry = er.async_get(hass)
|
||||
platform = getattr(hass.components, "test.number")
|
||||
platform.init(empty=True)
|
||||
platform.ENTITIES.append(
|
||||
platform.MockNumberEntity(
|
||||
name="Test",
|
||||
native_value=native_value,
|
||||
native_unit_of_measurement=native_unit,
|
||||
device_class=NumberDeviceClass.TEMPERATURE,
|
||||
unique_id="very_unique",
|
||||
)
|
||||
)
|
||||
|
||||
entity0 = platform.ENTITIES[0]
|
||||
assert await async_setup_component(hass, "number", {"number": {"platform": "test"}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Default unit conversion according to unit system
|
||||
state = hass.states.get(entity0.entity_id)
|
||||
assert float(state.state) == pytest.approx(float(default_value))
|
||||
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == default_unit
|
||||
|
||||
entity_registry.async_update_entity_options(
|
||||
"number.test", "number", {"unit_of_measurement": custom_unit}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Unit conversion to the custom unit
|
||||
state = hass.states.get(entity0.entity_id)
|
||||
assert float(state.state) == pytest.approx(float(custom_value))
|
||||
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == used_custom_unit
|
||||
|
||||
entity_registry.async_update_entity_options(
|
||||
"number.test", "number", {"unit_of_measurement": native_unit}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Unit conversion to another custom unit
|
||||
state = hass.states.get(entity0.entity_id)
|
||||
assert float(state.state) == pytest.approx(float(native_value))
|
||||
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == native_unit
|
||||
|
||||
entity_registry.async_update_entity_options("number.test", "number", None)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Default unit conversion according to unit system
|
||||
state = hass.states.get(entity0.entity_id)
|
||||
assert float(state.state) == pytest.approx(float(default_value))
|
||||
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == default_unit
|
||||
|
|
Loading…
Reference in New Issue