From 96745abf5da6e457000e748fed01b078ba6c1e89 Mon Sep 17 00:00:00 2001 From: Johan Bloemberg Date: Thu, 2 Feb 2017 06:00:05 +0100 Subject: [PATCH] Prevent infinite loop in crossconfigured mqtt event streams (#5624) * Prevent events about MQTT messages received to cause infinite loop when two HA instances are crossconfigured for mqtt_eventstream. * Fix linting * Publish all MQTT received events except incoming from eventstream. Also make it configurable. --- homeassistant/components/mqtt_eventstream.py | 14 +++++++ tests/components/test_mqtt_eventstream.py | 43 ++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/homeassistant/components/mqtt_eventstream.py b/homeassistant/components/mqtt_eventstream.py index 293b644da1f..8632f8aa99d 100644 --- a/homeassistant/components/mqtt_eventstream.py +++ b/homeassistant/components/mqtt_eventstream.py @@ -15,18 +15,23 @@ from homeassistant.const import ( ATTR_SERVICE_DATA, EVENT_CALL_SERVICE, EVENT_SERVICE_EXECUTED, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL) from homeassistant.core import EventOrigin, State +import homeassistant.helpers.config_validation as cv from homeassistant.remote import JSONEncoder +from .mqtt import EVENT_MQTT_MESSAGE_RECEIVED DOMAIN = "mqtt_eventstream" DEPENDENCIES = ['mqtt'] CONF_PUBLISH_TOPIC = 'publish_topic' CONF_SUBSCRIBE_TOPIC = 'subscribe_topic' +CONF_PUBLISH_EVENTSTREAM_RECEIVED = 'publish_eventstream_received' CONFIG_SCHEMA = vol.Schema({ DOMAIN: vol.Schema({ vol.Optional(CONF_PUBLISH_TOPIC): valid_publish_topic, vol.Optional(CONF_SUBSCRIBE_TOPIC): valid_subscribe_topic, + vol.Optional(CONF_PUBLISH_EVENTSTREAM_RECEIVED, default=False): + cv.boolean, }), }, extra=vol.ALLOW_EXTRA) @@ -45,6 +50,15 @@ def setup(hass, config): if event.event_type == EVENT_TIME_CHANGED: return + # MQTT fires a bus event for every incoming message, also messages from + # eventstream. Disable publishing these messages to other HA instances + # and possibly creating an infinite loop if these instances publish + # back to this one. + if all([not conf.get(CONF_PUBLISH_EVENTSTREAM_RECEIVED), + event.event_type == EVENT_MQTT_MESSAGE_RECEIVED, + event.data.get('topic') == sub_topic]): + return + # Filter out the events that were triggered by publishing # to the MQTT topic, or you will end up in an infinite loop. if event.event_type == EVENT_CALL_SERVICE: diff --git a/tests/components/test_mqtt_eventstream.py b/tests/components/test_mqtt_eventstream.py index a60e54df016..3dbe6338e3f 100644 --- a/tests/components/test_mqtt_eventstream.py +++ b/tests/components/test_mqtt_eventstream.py @@ -1,10 +1,12 @@ """The tests for the MQTT eventstream component.""" +from collections import namedtuple import json import unittest from unittest.mock import ANY, patch from homeassistant.bootstrap import setup_component import homeassistant.components.mqtt_eventstream as eventstream +import homeassistant.components.mqtt as mqtt from homeassistant.const import EVENT_STATE_CHANGED from homeassistant.core import State, callback from homeassistant.remote import JSONEncoder @@ -146,3 +148,44 @@ class TestMqttEventStream(unittest.TestCase): self.hass.block_till_done() self.assertEqual(1, len(calls)) + + @patch('homeassistant.components.mqtt.publish') + def test_mqtt_received_event(self, mock_pub): + """Don't filter events from the mqtt component about received message. + + Mqtt component sends an event if a message is received. Also + messages that originate from an incoming eventstream. + Broadcasting these messages result in an infinite loop if two HA + instances are crossconfigured for the same mqtt topics. + + """ + SUB_TOPIC = 'from_slaves' + self.assertTrue( + self.add_eventstream( + pub_topic='bar', + sub_topic=SUB_TOPIC)) + self.hass.block_till_done() + + # Reset the mock because it will have already gotten calls for the + # mqtt_eventstream state change on initialization, etc. + mock_pub.reset_mock() + + # Use MQTT component message handler to simulate firing message + # received event. + MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload']) + message = MQTTMessage(SUB_TOPIC, 1, 'Hello World!'.encode('utf-8')) + mqtt.MQTT._mqtt_on_message(self, None, {'hass': self.hass}, message) + + self.hass.block_till_done() + + # 'normal' incoming mqtt messages should be broadcasted + self.assertEqual(mock_pub.call_count, 0) + + MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload']) + message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8')) + mqtt.MQTT._mqtt_on_message(self, None, {'hass': self.hass}, message) + + self.hass.block_till_done() + + # but event from the event stream not + self.assertEqual(mock_pub.call_count, 1)