Add progress support to MQTT update platform (#129468)
* Add progress support to MQTT update platform and add validation on state updates * Clean up cast to type class * Add support for display_precision attributepull/125950/head^2
parent
1773f2aadc
commit
9fbd484dfe
|
@ -46,6 +46,7 @@ ABBREVIATIONS = {
|
|||
"dir_cmd_tpl": "direction_command_template",
|
||||
"dir_stat_t": "direction_state_topic",
|
||||
"dir_val_tpl": "direction_value_template",
|
||||
"dsp_prc": "display_precision",
|
||||
"dock_cmd_t": "dock_command_topic",
|
||||
"dock_cmd_tpl": "dock_command_template",
|
||||
"e": "encoding",
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, TypedDict, cast
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -34,6 +34,7 @@ _LOGGER = logging.getLogger(__name__)
|
|||
|
||||
DEFAULT_NAME = "MQTT Update"
|
||||
|
||||
CONF_DISPLAY_PRECISION = "display_precision"
|
||||
CONF_LATEST_VERSION_TEMPLATE = "latest_version_template"
|
||||
CONF_LATEST_VERSION_TOPIC = "latest_version_topic"
|
||||
CONF_PAYLOAD_INSTALL = "payload_install"
|
||||
|
@ -46,6 +47,7 @@ PLATFORM_SCHEMA_MODERN = MQTT_RO_SCHEMA.extend(
|
|||
{
|
||||
vol.Optional(CONF_COMMAND_TOPIC): valid_publish_topic,
|
||||
vol.Optional(CONF_DEVICE_CLASS): vol.Any(DEVICE_CLASSES_SCHEMA, None),
|
||||
vol.Optional(CONF_DISPLAY_PRECISION, default=0): cv.positive_int,
|
||||
vol.Optional(CONF_LATEST_VERSION_TEMPLATE): cv.template,
|
||||
vol.Optional(CONF_LATEST_VERSION_TOPIC): valid_subscribe_topic,
|
||||
vol.Optional(CONF_NAME): vol.Any(cv.string, None),
|
||||
|
@ -61,15 +63,18 @@ PLATFORM_SCHEMA_MODERN = MQTT_RO_SCHEMA.extend(
|
|||
DISCOVERY_SCHEMA = vol.All(PLATFORM_SCHEMA_MODERN.extend({}, extra=vol.REMOVE_EXTRA))
|
||||
|
||||
|
||||
class _MqttUpdatePayloadType(TypedDict, total=False):
|
||||
"""Presentation of supported JSON payload to process state updates."""
|
||||
|
||||
installed_version: str
|
||||
latest_version: str
|
||||
title: str
|
||||
release_summary: str
|
||||
release_url: str
|
||||
entity_picture: str
|
||||
MQTT_JSON_UPDATE_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Optional("installed_version"): cv.string,
|
||||
vol.Optional("latest_version"): cv.string,
|
||||
vol.Optional("title"): cv.string,
|
||||
vol.Optional("release_summary"): cv.string,
|
||||
vol.Optional("release_url"): cv.url,
|
||||
vol.Optional("entity_picture"): cv.url,
|
||||
vol.Optional("in_progress"): cv.boolean,
|
||||
vol.Optional("update_percentage"): vol.Any(vol.Range(min=0, max=100), None),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
|
@ -111,6 +116,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
|
|||
def _setup_from_config(self, config: ConfigType) -> None:
|
||||
"""(Re)Setup the entity."""
|
||||
self._attr_device_class = self._config.get(CONF_DEVICE_CLASS)
|
||||
self._attr_display_precision = self._config[CONF_DISPLAY_PRECISION]
|
||||
self._attr_release_summary = self._config.get(CONF_RELEASE_SUMMARY)
|
||||
self._attr_release_url = self._config.get(CONF_RELEASE_URL)
|
||||
self._attr_title = self._config.get(CONF_TITLE)
|
||||
|
@ -138,7 +144,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
|
|||
)
|
||||
return
|
||||
|
||||
json_payload: _MqttUpdatePayloadType = {}
|
||||
json_payload: dict[str, Any] = {}
|
||||
try:
|
||||
rendered_json_payload = json_loads(payload)
|
||||
if isinstance(rendered_json_payload, dict):
|
||||
|
@ -150,7 +156,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
|
|||
rendered_json_payload,
|
||||
msg.topic,
|
||||
)
|
||||
json_payload = cast(_MqttUpdatePayloadType, rendered_json_payload)
|
||||
json_payload = MQTT_JSON_UPDATE_SCHEMA(rendered_json_payload)
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
|
@ -161,14 +167,27 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
|
|||
msg.topic,
|
||||
)
|
||||
json_payload = {"installed_version": str(payload)}
|
||||
except vol.MultipleInvalid as exc:
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"Schema violation after processing payload '%s'"
|
||||
" on topic '%s' for entity '%s': %s"
|
||||
),
|
||||
payload,
|
||||
msg.topic,
|
||||
self.entity_id,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
except JSON_DECODE_EXCEPTIONS:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
"No valid (JSON) payload detected after processing payload '%s'"
|
||||
" on topic %s"
|
||||
" on topic '%s' for entity '%s'"
|
||||
),
|
||||
payload,
|
||||
msg.topic,
|
||||
self.entity_id,
|
||||
)
|
||||
json_payload["installed_version"] = str(payload)
|
||||
|
||||
|
@ -190,6 +209,13 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
|
|||
if "entity_picture" in json_payload:
|
||||
self._attr_entity_picture = json_payload["entity_picture"]
|
||||
|
||||
if "update_percentage" in json_payload:
|
||||
self._attr_update_percentage = json_payload["update_percentage"]
|
||||
self._attr_in_progress = self._attr_update_percentage is not None
|
||||
|
||||
if "in_progress" in json_payload:
|
||||
self._attr_in_progress = json_payload["in_progress"]
|
||||
|
||||
@callback
|
||||
def _handle_latest_version_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle receiving latest version via MQTT."""
|
||||
|
@ -206,11 +232,13 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
|
|||
self._handle_state_message_received,
|
||||
{
|
||||
"_attr_entity_picture",
|
||||
"_attr_in_progress",
|
||||
"_attr_installed_version",
|
||||
"_attr_latest_version",
|
||||
"_attr_title",
|
||||
"_attr_release_summary",
|
||||
"_attr_release_url",
|
||||
"_attr_update_percentage",
|
||||
},
|
||||
)
|
||||
self.add_subscription(
|
||||
|
@ -233,7 +261,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
|
|||
@property
|
||||
def supported_features(self) -> UpdateEntityFeature:
|
||||
"""Return the list of supported features."""
|
||||
support = UpdateEntityFeature(0)
|
||||
support = UpdateEntityFeature(UpdateEntityFeature.PROGRESS)
|
||||
|
||||
if self._config.get(CONF_COMMAND_TOPIC) is not None:
|
||||
support |= UpdateEntityFeature.INSTALL
|
||||
|
|
|
@ -314,6 +314,60 @@ async def test_empty_json_state_message(
|
|||
}
|
||||
],
|
||||
)
|
||||
async def test_invalid_json_state_message(
|
||||
hass: HomeAssistant,
|
||||
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test an empty JSON payload."""
|
||||
state_topic = "test/state-topic"
|
||||
await mqtt_mock_entry()
|
||||
|
||||
async_fire_mqtt_message(
|
||||
hass,
|
||||
state_topic,
|
||||
'{"installed_version":"1.9.0","latest_version":"1.9.0",'
|
||||
'"title":"Test Update 1 Title","release_url":"https://example.com/release1",'
|
||||
'"release_summary":"Test release summary 1",'
|
||||
'"entity_picture": "https://example.com/icon1.png"}',
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get("update.test_update")
|
||||
assert state.state == STATE_OFF
|
||||
assert state.attributes.get("installed_version") == "1.9.0"
|
||||
assert state.attributes.get("latest_version") == "1.9.0"
|
||||
assert state.attributes.get("release_summary") == "Test release summary 1"
|
||||
assert state.attributes.get("release_url") == "https://example.com/release1"
|
||||
assert state.attributes.get("title") == "Test Update 1 Title"
|
||||
assert state.attributes.get("entity_picture") == "https://example.com/icon1.png"
|
||||
|
||||
# Test update schema validation with invalid value in JSON update
|
||||
async_fire_mqtt_message(hass, state_topic, '{"update_percentage":101}')
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert (
|
||||
"Schema violation after processing payload '{\"update_percentage\":101}' on "
|
||||
"topic 'test/state-topic' for entity 'update.test_update': value must be at "
|
||||
"most 100 for dictionary value @ data['update_percentage']" in caplog.text
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hass_config",
|
||||
[
|
||||
{
|
||||
mqtt.DOMAIN: {
|
||||
update.DOMAIN: {
|
||||
"state_topic": "test/state-topic",
|
||||
"name": "Test Update",
|
||||
"display_precision": 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
)
|
||||
async def test_json_state_message(
|
||||
hass: HomeAssistant, mqtt_mock_entry: MqttMockHAClientGenerator
|
||||
) -> None:
|
||||
|
@ -355,6 +409,45 @@ async def test_json_state_message(
|
|||
assert state.attributes.get("installed_version") == "1.9.0"
|
||||
assert state.attributes.get("latest_version") == "2.0.0"
|
||||
assert state.attributes.get("entity_picture") == "https://example.com/icon2.png"
|
||||
assert state.attributes.get("in_progress") is False
|
||||
assert state.attributes.get("update_percentage") is None
|
||||
|
||||
# Test in_progress status
|
||||
async_fire_mqtt_message(hass, state_topic, '{"in_progress":true}')
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get("update.test_update")
|
||||
assert state.state == STATE_ON
|
||||
assert state.attributes.get("installed_version") == "1.9.0"
|
||||
assert state.attributes.get("latest_version") == "2.0.0"
|
||||
assert state.attributes.get("entity_picture") == "https://example.com/icon2.png"
|
||||
assert state.attributes.get("in_progress") is True
|
||||
assert state.attributes.get("update_percentage") is None
|
||||
|
||||
async_fire_mqtt_message(hass, state_topic, '{"in_progress":false}')
|
||||
await hass.async_block_till_done()
|
||||
state = hass.states.get("update.test_update")
|
||||
assert state.attributes.get("in_progress") is False
|
||||
|
||||
# Test update_percentage status
|
||||
async_fire_mqtt_message(hass, state_topic, '{"update_percentage":51.75}')
|
||||
await hass.async_block_till_done()
|
||||
state = hass.states.get("update.test_update")
|
||||
assert state.attributes.get("in_progress") is True
|
||||
assert state.attributes.get("update_percentage") == 51.75
|
||||
assert state.attributes.get("display_precision") == 1
|
||||
|
||||
async_fire_mqtt_message(hass, state_topic, '{"update_percentage":100}')
|
||||
await hass.async_block_till_done()
|
||||
state = hass.states.get("update.test_update")
|
||||
assert state.attributes.get("in_progress") is True
|
||||
assert state.attributes.get("update_percentage") == 100
|
||||
|
||||
async_fire_mqtt_message(hass, state_topic, '{"update_percentage":null}')
|
||||
await hass.async_block_till_done()
|
||||
state = hass.states.get("update.test_update")
|
||||
assert state.attributes.get("in_progress") is False
|
||||
assert state.attributes.get("update_percentage") is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -725,6 +818,10 @@ async def test_reloadable(
|
|||
'{"entity_picture": "https://example.com/icon1.png"}',
|
||||
'{"entity_picture": "https://example.com/icon2.png"}',
|
||||
),
|
||||
("test-topic", '{"in_progress": true}', '{"in_progress": false}'),
|
||||
("test-topic", '{"update_percentage": 0}', '{"update_percentage": 50}'),
|
||||
("test-topic", '{"update_percentage": 50}', '{"update_percentage": 100}'),
|
||||
("test-topic", '{"update_percentage": 100}', '{"update_percentage": null}'),
|
||||
("availability-topic", "online", "offline"),
|
||||
("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
|
||||
],
|
||||
|
|
Loading…
Reference in New Issue