diff --git a/homeassistant/components/mqtt/abbreviations.py b/homeassistant/components/mqtt/abbreviations.py index 215585f465a..65e24d5d780 100644 --- a/homeassistant/components/mqtt/abbreviations.py +++ b/homeassistant/components/mqtt/abbreviations.py @@ -46,6 +46,7 @@ ABBREVIATIONS = { "dir_cmd_tpl": "direction_command_template", "dir_stat_t": "direction_state_topic", "dir_val_tpl": "direction_value_template", + "dsp_prc": "display_precision", "dock_cmd_t": "dock_command_topic", "dock_cmd_tpl": "dock_command_template", "e": "encoding", diff --git a/homeassistant/components/mqtt/update.py b/homeassistant/components/mqtt/update.py index 42aeea1f715..8878ff63127 100644 --- a/homeassistant/components/mqtt/update.py +++ b/homeassistant/components/mqtt/update.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import Any, TypedDict, cast +from typing import Any import voluptuous as vol @@ -34,6 +34,7 @@ _LOGGER = logging.getLogger(__name__) DEFAULT_NAME = "MQTT Update" +CONF_DISPLAY_PRECISION = "display_precision" CONF_LATEST_VERSION_TEMPLATE = "latest_version_template" CONF_LATEST_VERSION_TOPIC = "latest_version_topic" CONF_PAYLOAD_INSTALL = "payload_install" @@ -46,6 +47,7 @@ PLATFORM_SCHEMA_MODERN = MQTT_RO_SCHEMA.extend( { vol.Optional(CONF_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_DEVICE_CLASS): vol.Any(DEVICE_CLASSES_SCHEMA, None), + vol.Optional(CONF_DISPLAY_PRECISION, default=0): cv.positive_int, vol.Optional(CONF_LATEST_VERSION_TEMPLATE): cv.template, vol.Optional(CONF_LATEST_VERSION_TOPIC): valid_subscribe_topic, vol.Optional(CONF_NAME): vol.Any(cv.string, None), @@ -61,15 +63,18 @@ PLATFORM_SCHEMA_MODERN = MQTT_RO_SCHEMA.extend( DISCOVERY_SCHEMA = vol.All(PLATFORM_SCHEMA_MODERN.extend({}, extra=vol.REMOVE_EXTRA)) -class _MqttUpdatePayloadType(TypedDict, total=False): - """Presentation of supported JSON payload to process state updates.""" - - installed_version: str - latest_version: str - title: str - release_summary: str - release_url: str - entity_picture: str +MQTT_JSON_UPDATE_SCHEMA = vol.Schema( + { + vol.Optional("installed_version"): cv.string, + vol.Optional("latest_version"): cv.string, + vol.Optional("title"): cv.string, + vol.Optional("release_summary"): cv.string, + vol.Optional("release_url"): cv.url, + vol.Optional("entity_picture"): cv.url, + vol.Optional("in_progress"): cv.boolean, + vol.Optional("update_percentage"): vol.Any(vol.Range(min=0, max=100), None), + } +) async def async_setup_entry( @@ -111,6 +116,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): def _setup_from_config(self, config: ConfigType) -> None: """(Re)Setup the entity.""" self._attr_device_class = self._config.get(CONF_DEVICE_CLASS) + self._attr_display_precision = self._config[CONF_DISPLAY_PRECISION] self._attr_release_summary = self._config.get(CONF_RELEASE_SUMMARY) self._attr_release_url = self._config.get(CONF_RELEASE_URL) self._attr_title = self._config.get(CONF_TITLE) @@ -138,7 +144,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): ) return - json_payload: _MqttUpdatePayloadType = {} + json_payload: dict[str, Any] = {} try: rendered_json_payload = json_loads(payload) if isinstance(rendered_json_payload, dict): @@ -150,7 +156,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): rendered_json_payload, msg.topic, ) - json_payload = cast(_MqttUpdatePayloadType, rendered_json_payload) + json_payload = MQTT_JSON_UPDATE_SCHEMA(rendered_json_payload) else: _LOGGER.debug( ( @@ -161,14 +167,27 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): msg.topic, ) json_payload = {"installed_version": str(payload)} + except vol.MultipleInvalid as exc: + _LOGGER.warning( + ( + "Schema violation after processing payload '%s'" + " on topic '%s' for entity '%s': %s" + ), + payload, + msg.topic, + self.entity_id, + exc, + ) + return except JSON_DECODE_EXCEPTIONS: _LOGGER.debug( ( "No valid (JSON) payload detected after processing payload '%s'" - " on topic %s" + " on topic '%s' for entity '%s'" ), payload, msg.topic, + self.entity_id, ) json_payload["installed_version"] = str(payload) @@ -190,6 +209,13 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): if "entity_picture" in json_payload: self._attr_entity_picture = json_payload["entity_picture"] + if "update_percentage" in json_payload: + self._attr_update_percentage = json_payload["update_percentage"] + self._attr_in_progress = self._attr_update_percentage is not None + + if "in_progress" in json_payload: + self._attr_in_progress = json_payload["in_progress"] + @callback def _handle_latest_version_received(self, msg: ReceiveMessage) -> None: """Handle receiving latest version via MQTT.""" @@ -206,11 +232,13 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): self._handle_state_message_received, { "_attr_entity_picture", + "_attr_in_progress", "_attr_installed_version", "_attr_latest_version", "_attr_title", "_attr_release_summary", "_attr_release_url", + "_attr_update_percentage", }, ) self.add_subscription( @@ -233,7 +261,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): @property def supported_features(self) -> UpdateEntityFeature: """Return the list of supported features.""" - support = UpdateEntityFeature(0) + support = UpdateEntityFeature(UpdateEntityFeature.PROGRESS) if self._config.get(CONF_COMMAND_TOPIC) is not None: support |= UpdateEntityFeature.INSTALL diff --git a/tests/components/mqtt/test_update.py b/tests/components/mqtt/test_update.py index 2bf592f85fb..4ca10cbe8b2 100644 --- a/tests/components/mqtt/test_update.py +++ b/tests/components/mqtt/test_update.py @@ -314,6 +314,60 @@ async def test_empty_json_state_message( } ], ) +async def test_invalid_json_state_message( + hass: HomeAssistant, + mqtt_mock_entry: MqttMockHAClientGenerator, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test an empty JSON payload.""" + state_topic = "test/state-topic" + await mqtt_mock_entry() + + async_fire_mqtt_message( + hass, + state_topic, + '{"installed_version":"1.9.0","latest_version":"1.9.0",' + '"title":"Test Update 1 Title","release_url":"https://example.com/release1",' + '"release_summary":"Test release summary 1",' + '"entity_picture": "https://example.com/icon1.png"}', + ) + + await hass.async_block_till_done() + + state = hass.states.get("update.test_update") + assert state.state == STATE_OFF + assert state.attributes.get("installed_version") == "1.9.0" + assert state.attributes.get("latest_version") == "1.9.0" + assert state.attributes.get("release_summary") == "Test release summary 1" + assert state.attributes.get("release_url") == "https://example.com/release1" + assert state.attributes.get("title") == "Test Update 1 Title" + assert state.attributes.get("entity_picture") == "https://example.com/icon1.png" + + # Test update schema validation with invalid value in JSON update + async_fire_mqtt_message(hass, state_topic, '{"update_percentage":101}') + + await hass.async_block_till_done() + assert ( + "Schema violation after processing payload '{\"update_percentage\":101}' on " + "topic 'test/state-topic' for entity 'update.test_update': value must be at " + "most 100 for dictionary value @ data['update_percentage']" in caplog.text + ) + + +@pytest.mark.parametrize( + "hass_config", + [ + { + mqtt.DOMAIN: { + update.DOMAIN: { + "state_topic": "test/state-topic", + "name": "Test Update", + "display_precision": 1, + } + } + } + ], +) async def test_json_state_message( hass: HomeAssistant, mqtt_mock_entry: MqttMockHAClientGenerator ) -> None: @@ -355,6 +409,45 @@ async def test_json_state_message( assert state.attributes.get("installed_version") == "1.9.0" assert state.attributes.get("latest_version") == "2.0.0" assert state.attributes.get("entity_picture") == "https://example.com/icon2.png" + assert state.attributes.get("in_progress") is False + assert state.attributes.get("update_percentage") is None + + # Test in_progress status + async_fire_mqtt_message(hass, state_topic, '{"in_progress":true}') + await hass.async_block_till_done() + + state = hass.states.get("update.test_update") + assert state.state == STATE_ON + assert state.attributes.get("installed_version") == "1.9.0" + assert state.attributes.get("latest_version") == "2.0.0" + assert state.attributes.get("entity_picture") == "https://example.com/icon2.png" + assert state.attributes.get("in_progress") is True + assert state.attributes.get("update_percentage") is None + + async_fire_mqtt_message(hass, state_topic, '{"in_progress":false}') + await hass.async_block_till_done() + state = hass.states.get("update.test_update") + assert state.attributes.get("in_progress") is False + + # Test update_percentage status + async_fire_mqtt_message(hass, state_topic, '{"update_percentage":51.75}') + await hass.async_block_till_done() + state = hass.states.get("update.test_update") + assert state.attributes.get("in_progress") is True + assert state.attributes.get("update_percentage") == 51.75 + assert state.attributes.get("display_precision") == 1 + + async_fire_mqtt_message(hass, state_topic, '{"update_percentage":100}') + await hass.async_block_till_done() + state = hass.states.get("update.test_update") + assert state.attributes.get("in_progress") is True + assert state.attributes.get("update_percentage") == 100 + + async_fire_mqtt_message(hass, state_topic, '{"update_percentage":null}') + await hass.async_block_till_done() + state = hass.states.get("update.test_update") + assert state.attributes.get("in_progress") is False + assert state.attributes.get("update_percentage") is None @pytest.mark.parametrize( @@ -725,6 +818,10 @@ async def test_reloadable( '{"entity_picture": "https://example.com/icon1.png"}', '{"entity_picture": "https://example.com/icon2.png"}', ), + ("test-topic", '{"in_progress": true}', '{"in_progress": false}'), + ("test-topic", '{"update_percentage": 0}', '{"update_percentage": 50}'), + ("test-topic", '{"update_percentage": 50}', '{"update_percentage": 100}'), + ("test-topic", '{"update_percentage": 100}', '{"update_percentage": null}'), ("availability-topic", "online", "offline"), ("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'), ],