diff --git a/homeassistant/components/mqtt/event.py b/homeassistant/components/mqtt/event.py index 6f8be33f21a..6fe39b5e899 100644 --- a/homeassistant/components/mqtt/event.py +++ b/homeassistant/components/mqtt/event.py @@ -32,14 +32,18 @@ from .const import ( PAYLOAD_NONE, ) from .debug_info import log_messages -from .mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, async_setup_entry_helper +from .mixins import ( + MQTT_ENTITY_COMMON_SCHEMA, + MqttEntity, + async_setup_entry_helper, + write_state_on_attr_change, +) from .models import ( MqttValueTemplate, PayloadSentinel, ReceiveMessage, ReceivePayloadType, ) -from .util import get_mqtt_data _LOGGER = logging.getLogger(__name__) @@ -133,6 +137,7 @@ class MqttEvent(MqttEntity, EventEntity): @callback @log_messages(self.hass, self.entity_id) + @write_state_on_attr_change(self, {"state"}) def message_received(msg: ReceiveMessage) -> None: """Handle new MQTT messages.""" event_attributes: dict[str, Any] = {} @@ -195,7 +200,6 @@ class MqttEvent(MqttEntity, EventEntity): payload, ) return - get_mqtt_data(self.hass).state_write_requests.write_state_request(self) topics["state_topic"] = { "topic": self._config[CONF_STATE_TOPIC], diff --git a/tests/components/mqtt/test_event.py b/tests/components/mqtt/test_event.py index abcd6e8f3ee..401caac8007 100644 --- a/tests/components/mqtt/test_event.py +++ b/tests/components/mqtt/test_event.py @@ -13,6 +13,7 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers import device_registry as dr from .test_common import ( + help_custom_config, help_test_availability_when_connection_lost, help_test_availability_without_topic, help_test_custom_availability_payload, @@ -42,6 +43,7 @@ from .test_common import ( help_test_setting_attribute_via_mqtt_json_message, help_test_setting_attribute_with_template, help_test_setting_blocked_attribute_via_mqtt_json_message, + help_test_skipped_async_ha_write_state, help_test_unique_id, help_test_unload_config_entry_with_platform, help_test_update_with_json_attrs_bad_json, @@ -668,3 +670,68 @@ async def test_entity_name( await help_test_entity_name( hass, mqtt_mock_entry, domain, config, expected_friendly_name, device_class ) + + +@pytest.mark.parametrize( + "hass_config", + [ + help_custom_config( + event.DOMAIN, + DEFAULT_CONFIG, + ( + { + "availability_topic": "availability-topic", + "json_attributes_topic": "json-attributes-topic", + }, + ), + ) + ], +) +@pytest.mark.parametrize( + ("topic", "payload1", "payload2"), + [ + ("availability-topic", "online", "offline"), + ("json-attributes-topic", '{"attr1": "val1"}', '{"attr1": "val2"}'), + ], +) +async def test_skipped_async_ha_write_state( + hass: HomeAssistant, + mqtt_mock_entry: MqttMockHAClientGenerator, + topic: str, + payload1: str, + payload2: str, +) -> None: + """Test a write state command is only called when there is change.""" + await mqtt_mock_entry() + await help_test_skipped_async_ha_write_state(hass, topic, payload1, payload2) + + +@pytest.mark.parametrize("hass_config", [DEFAULT_CONFIG]) +async def test_skipped_async_ha_write_state2( + hass: HomeAssistant, + mqtt_mock_entry: MqttMockHAClientGenerator, +) -> None: + """Test a write state command is only called when there is a valid event.""" + await mqtt_mock_entry() + topic = "test-topic" + payload1 = '{"event_type": "press"}' + payload2 = '{"event_type": "unknown"}' + with patch( + "homeassistant.components.mqtt.mixins.MqttEntity.async_write_ha_state" + ) as mock_async_ha_write_state: + assert len(mock_async_ha_write_state.mock_calls) == 0 + async_fire_mqtt_message(hass, topic, payload1) + await hass.async_block_till_done() + assert len(mock_async_ha_write_state.mock_calls) == 1 + + async_fire_mqtt_message(hass, topic, payload1) + await hass.async_block_till_done() + assert len(mock_async_ha_write_state.mock_calls) == 2 + + async_fire_mqtt_message(hass, topic, payload2) + await hass.async_block_till_done() + assert len(mock_async_ha_write_state.mock_calls) == 2 + + async_fire_mqtt_message(hass, topic, payload2) + await hass.async_block_till_done() + assert len(mock_async_ha_write_state.mock_calls) == 2