Add progress support to MQTT update platform (#129468)

* Add progress support to MQTT update platform and add validation on state updates

* Clean up cast to type class

* Add support for display_precision attribute
pull/125950/head^2
Jan Bouwhuis 2024-10-30 17:22:55 +01:00 committed by GitHub
parent 1773f2aadc
commit 9fbd484dfe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 140 additions and 14 deletions

View File

@ -46,6 +46,7 @@ ABBREVIATIONS = {
"dir_cmd_tpl": "direction_command_template",
"dir_stat_t": "direction_state_topic",
"dir_val_tpl": "direction_value_template",
"dsp_prc": "display_precision",
"dock_cmd_t": "dock_command_topic",
"dock_cmd_tpl": "dock_command_template",
"e": "encoding",

View File

@ -3,7 +3,7 @@
from __future__ import annotations
import logging
from typing import Any, TypedDict, cast
from typing import Any
import voluptuous as vol
@ -34,6 +34,7 @@ _LOGGER = logging.getLogger(__name__)
DEFAULT_NAME = "MQTT Update"
CONF_DISPLAY_PRECISION = "display_precision"
CONF_LATEST_VERSION_TEMPLATE = "latest_version_template"
CONF_LATEST_VERSION_TOPIC = "latest_version_topic"
CONF_PAYLOAD_INSTALL = "payload_install"
@ -46,6 +47,7 @@ PLATFORM_SCHEMA_MODERN = MQTT_RO_SCHEMA.extend(
{
vol.Optional(CONF_COMMAND_TOPIC): valid_publish_topic,
vol.Optional(CONF_DEVICE_CLASS): vol.Any(DEVICE_CLASSES_SCHEMA, None),
vol.Optional(CONF_DISPLAY_PRECISION, default=0): cv.positive_int,
vol.Optional(CONF_LATEST_VERSION_TEMPLATE): cv.template,
vol.Optional(CONF_LATEST_VERSION_TOPIC): valid_subscribe_topic,
vol.Optional(CONF_NAME): vol.Any(cv.string, None),
@ -61,15 +63,18 @@ PLATFORM_SCHEMA_MODERN = MQTT_RO_SCHEMA.extend(
DISCOVERY_SCHEMA = vol.All(PLATFORM_SCHEMA_MODERN.extend({}, extra=vol.REMOVE_EXTRA))
class _MqttUpdatePayloadType(TypedDict, total=False):
"""Presentation of supported JSON payload to process state updates."""
installed_version: str
latest_version: str
title: str
release_summary: str
release_url: str
entity_picture: str
MQTT_JSON_UPDATE_SCHEMA = vol.Schema(
{
vol.Optional("installed_version"): cv.string,
vol.Optional("latest_version"): cv.string,
vol.Optional("title"): cv.string,
vol.Optional("release_summary"): cv.string,
vol.Optional("release_url"): cv.url,
vol.Optional("entity_picture"): cv.url,
vol.Optional("in_progress"): cv.boolean,
vol.Optional("update_percentage"): vol.Any(vol.Range(min=0, max=100), None),
}
)
async def async_setup_entry(
@ -111,6 +116,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
def _setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup the entity."""
self._attr_device_class = self._config.get(CONF_DEVICE_CLASS)
self._attr_display_precision = self._config[CONF_DISPLAY_PRECISION]
self._attr_release_summary = self._config.get(CONF_RELEASE_SUMMARY)
self._attr_release_url = self._config.get(CONF_RELEASE_URL)
self._attr_title = self._config.get(CONF_TITLE)
@ -138,7 +144,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
)
return
json_payload: _MqttUpdatePayloadType = {}
json_payload: dict[str, Any] = {}
try:
rendered_json_payload = json_loads(payload)
if isinstance(rendered_json_payload, dict):
@ -150,7 +156,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
rendered_json_payload,
msg.topic,
)
json_payload = cast(_MqttUpdatePayloadType, rendered_json_payload)
json_payload = MQTT_JSON_UPDATE_SCHEMA(rendered_json_payload)
else:
_LOGGER.debug(
(
@ -161,14 +167,27 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
msg.topic,
)
json_payload = {"installed_version": str(payload)}
except vol.MultipleInvalid as exc:
_LOGGER.warning(
(
"Schema violation after processing payload '%s'"
" on topic '%s' for entity '%s': %s"
),
payload,
msg.topic,
self.entity_id,
exc,
)
return
except JSON_DECODE_EXCEPTIONS:
_LOGGER.debug(
(
"No valid (JSON) payload detected after processing payload '%s'"
" on topic %s"
" on topic '%s' for entity '%s'"
),
payload,
msg.topic,
self.entity_id,
)
json_payload["installed_version"] = str(payload)
@ -190,6 +209,13 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
if "entity_picture" in json_payload:
self._attr_entity_picture = json_payload["entity_picture"]
if "update_percentage" in json_payload:
self._attr_update_percentage = json_payload["update_percentage"]
self._attr_in_progress = self._attr_update_percentage is not None
if "in_progress" in json_payload:
self._attr_in_progress = json_payload["in_progress"]
@callback
def _handle_latest_version_received(self, msg: ReceiveMessage) -> None:
"""Handle receiving latest version via MQTT."""
@ -206,11 +232,13 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
self._handle_state_message_received,
{
"_attr_entity_picture",
"_attr_in_progress",
"_attr_installed_version",
"_attr_latest_version",
"_attr_title",
"_attr_release_summary",
"_attr_release_url",
"_attr_update_percentage",
},
)
self.add_subscription(
@ -233,7 +261,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
@property
def supported_features(self) -> UpdateEntityFeature:
"""Return the list of supported features."""
support = UpdateEntityFeature(0)
support = UpdateEntityFeature(UpdateEntityFeature.PROGRESS)
if self._config.get(CONF_COMMAND_TOPIC) is not None:
support |= UpdateEntityFeature.INSTALL

View File

@ -314,6 +314,60 @@ async def test_empty_json_state_message(
}
],
)
async def test_invalid_json_state_message(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test an empty JSON payload."""
state_topic = "test/state-topic"
await mqtt_mock_entry()
async_fire_mqtt_message(
hass,
state_topic,
'{"installed_version":"1.9.0","latest_version":"1.9.0",'
'"title":"Test Update 1 Title","release_url":"https://example.com/release1",'
'"release_summary":"Test release summary 1",'
'"entity_picture": "https://example.com/icon1.png"}',
)
await hass.async_block_till_done()
state = hass.states.get("update.test_update")
assert state.state == STATE_OFF
assert state.attributes.get("installed_version") == "1.9.0"
assert state.attributes.get("latest_version") == "1.9.0"
assert state.attributes.get("release_summary") == "Test release summary 1"
assert state.attributes.get("release_url") == "https://example.com/release1"
assert state.attributes.get("title") == "Test Update 1 Title"
assert state.attributes.get("entity_picture") == "https://example.com/icon1.png"
# Test update schema validation with invalid value in JSON update
async_fire_mqtt_message(hass, state_topic, '{"update_percentage":101}')
await hass.async_block_till_done()
assert (
"Schema violation after processing payload '{\"update_percentage\":101}' on "
"topic 'test/state-topic' for entity 'update.test_update': value must be at "
"most 100 for dictionary value @ data['update_percentage']" in caplog.text
)
@pytest.mark.parametrize(
"hass_config",
[
{
mqtt.DOMAIN: {
update.DOMAIN: {
"state_topic": "test/state-topic",
"name": "Test Update",
"display_precision": 1,
}
}
}
],
)
async def test_json_state_message(
hass: HomeAssistant, mqtt_mock_entry: MqttMockHAClientGenerator
) -> None:
@ -355,6 +409,45 @@ async def test_json_state_message(
assert state.attributes.get("installed_version") == "1.9.0"
assert state.attributes.get("latest_version") == "2.0.0"
assert state.attributes.get("entity_picture") == "https://example.com/icon2.png"
assert state.attributes.get("in_progress") is False
assert state.attributes.get("update_percentage") is None
# Test in_progress status
async_fire_mqtt_message(hass, state_topic, '{"in_progress":true}')
await hass.async_block_till_done()
state = hass.states.get("update.test_update")
assert state.state == STATE_ON
assert state.attributes.get("installed_version") == "1.9.0"
assert state.attributes.get("latest_version") == "2.0.0"
assert state.attributes.get("entity_picture") == "https://example.com/icon2.png"
assert state.attributes.get("in_progress") is True
assert state.attributes.get("update_percentage") is None
async_fire_mqtt_message(hass, state_topic, '{"in_progress":false}')
await hass.async_block_till_done()
state = hass.states.get("update.test_update")
assert state.attributes.get("in_progress") is False
# Test update_percentage status
async_fire_mqtt_message(hass, state_topic, '{"update_percentage":51.75}')
await hass.async_block_till_done()
state = hass.states.get("update.test_update")
assert state.attributes.get("in_progress") is True
assert state.attributes.get("update_percentage") == 51.75
assert state.attributes.get("display_precision") == 1
async_fire_mqtt_message(hass, state_topic, '{"update_percentage":100}')
await hass.async_block_till_done()
state = hass.states.get("update.test_update")
assert state.attributes.get("in_progress") is True
assert state.attributes.get("update_percentage") == 100
async_fire_mqtt_message(hass, state_topic, '{"update_percentage":null}')
await hass.async_block_till_done()
state = hass.states.get("update.test_update")
assert state.attributes.get("in_progress") is False
assert state.attributes.get("update_percentage") is None
@pytest.mark.parametrize(
@ -725,6 +818,10 @@ async def test_reloadable(
'{"entity_picture": "https://example.com/icon1.png"}',
'{"entity_picture": "https://example.com/icon2.png"}',
),
("test-topic", '{"in_progress": true}', '{"in_progress": false}'),
("test-topic", '{"update_percentage": 0}', '{"update_percentage": 50}'),
("test-topic", '{"update_percentage": 50}', '{"update_percentage": 100}'),
("test-topic", '{"update_percentage": 100}', '{"update_percentage": null}'),
("availability-topic", "online", "offline"),
("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
],