Avoid redundant calls to async_ha_write_state mqtt update platform (#100819)

Avoid redundant calls to async_ha_write_state
pull/100873/head
Jan Bouwhuis 2023-09-25 18:08:02 +02:00 committed by GitHub
parent cd3d3b76a3
commit 30c7e7fbdf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 12 deletions

View File

@ -33,9 +33,14 @@ from .const import (
PAYLOAD_EMPTY_JSON,
)
from .debug_info import log_messages
from .mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, async_setup_entry_helper
from .mixins import (
MQTT_ENTITY_COMMON_SCHEMA,
MqttEntity,
async_setup_entry_helper,
write_state_on_attr_change,
)
from .models import MessageCallbackType, MqttValueTemplate, ReceiveMessage
from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic
from .util import valid_publish_topic, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__)
@ -171,6 +176,17 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self,
{
"_attr_installed_version",
"_attr_latest_version",
"_attr_title",
"_attr_release_summary",
"_attr_release_url",
"_entity_picture",
},
)
def handle_state_message_received(msg: ReceiveMessage) -> None:
"""Handle receiving state message via MQTT."""
payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload)
@ -219,39 +235,33 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
if "installed_version" in json_payload:
self._attr_installed_version = json_payload["installed_version"]
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
if "latest_version" in json_payload:
self._attr_latest_version = json_payload["latest_version"]
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
if "title" in json_payload:
self._attr_title = json_payload["title"]
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
if "release_summary" in json_payload:
self._attr_release_summary = json_payload["release_summary"]
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
if "release_url" in json_payload:
self._attr_release_url = json_payload["release_url"]
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
if "entity_picture" in json_payload:
self._entity_picture = json_payload["entity_picture"]
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_latest_version"})
def handle_latest_version_received(msg: ReceiveMessage) -> None:
"""Handle receiving latest version via MQTT."""
latest_version = self._templates[CONF_LATEST_VERSION_TEMPLATE](msg.payload)
if isinstance(latest_version, str) and latest_version != "":
self._attr_latest_version = latest_version
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
add_subscription(
topics, CONF_LATEST_VERSION_TOPIC, handle_latest_version_received
@ -279,8 +289,6 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
self._config[CONF_ENCODING],
)
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
@property
def supported_features(self) -> UpdateEntityFeature:
"""Return the list of supported features."""

View File

@ -16,6 +16,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant
from .test_common import (
help_custom_config,
help_test_availability_when_connection_lost,
help_test_availability_without_topic,
help_test_custom_availability_payload,
@ -33,6 +34,7 @@ from .test_common import (
help_test_reloadable,
help_test_setting_attribute_via_mqtt_json_message,
help_test_setting_attribute_with_template,
help_test_skipped_async_ha_write_state,
help_test_unique_id,
help_test_unload_config_entry_with_platform,
help_test_update_with_json_attrs_bad_json,
@ -47,7 +49,7 @@ DEFAULT_CONFIG = {
update.DOMAIN: {
"name": "test",
"state_topic": "test-topic",
"latest_version_topic": "test-topic",
"latest_version_topic": "latest-version-topic",
"command_topic": "test-topic",
"payload_install": "install",
}
@ -730,3 +732,53 @@ async def test_reloadable(
domain = update.DOMAIN
config = DEFAULT_CONFIG
await help_test_reloadable(hass, mqtt_client_mock, domain, config)
@pytest.mark.parametrize(
"hass_config",
[
help_custom_config(
update.DOMAIN,
DEFAULT_CONFIG,
(
{
"availability_topic": "availability-topic",
"json_attributes_topic": "json-attributes-topic",
},
),
)
],
)
@pytest.mark.parametrize(
("topic", "payload1", "payload2"),
[
("latest-version-topic", "1.1", "1.2"),
("test-topic", "1.1", "1.2"),
("test-topic", '{"installed_version": "1.1"}', '{"installed_version": "1.2"}'),
("test-topic", '{"latest_version": "1.1"}', '{"latest_version": "1.2"}'),
("test-topic", '{"title": "Update"}', '{"title": "Patch"}'),
("test-topic", '{"release_summary": "bla1"}', '{"release_summary": "bla2"}'),
(
"test-topic",
'{"release_url": "https://example.com/update?r=1"}',
'{"release_url": "https://example.com/update?r=2"}',
),
(
"test-topic",
'{"entity_picture": "https://example.com/icon1.png"}',
'{"entity_picture": "https://example.com/icon2.png"}',
),
("availability-topic", "online", "offline"),
("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'),
],
)
async def test_skipped_async_ha_write_state(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
topic: str,
payload1: str,
payload2: str,
) -> None:
"""Test a write state command is only called when there is change."""
await mqtt_mock_entry()
await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2)