From bf8e2bd77ee743624a4f1a15936fa4885857f8f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cezar=20S=C3=A1=20Espinola?= Date: Fri, 17 Nov 2017 16:29:23 -0200 Subject: [PATCH] Make MQTT reconnection logic more resilient and fix race condition (#10133) --- homeassistant/components/mqtt/__init__.py | 34 ++++++++--------------- tests/components/mqtt/test_init.py | 28 ++++++++++++------- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 9decc9a14aa..3a6abec0ddf 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -438,7 +438,8 @@ class MQTT(object): self.broker = broker self.port = port self.keepalive = keepalive - self.topics = {} + self.wanted_topics = {} + self.subscribed_topics = {} self.progress = {} self.birth_message = birth_message self._mqttc = None @@ -526,15 +527,14 @@ class MQTT(object): raise HomeAssistantError("topic need to be a string!") with (yield from self._paho_lock): - if topic in self.topics: + if topic in self.subscribed_topics: return - + self.wanted_topics[topic] = qos result, mid = yield from self.hass.async_add_job( self._mqttc.subscribe, topic, qos) _raise_on_error(result) self.progress[mid] = topic - self.topics[topic] = None @asyncio.coroutine def async_unsubscribe(self, topic): @@ -542,6 +542,7 @@ class MQTT(object): This method is a coroutine. """ + self.wanted_topics.pop(topic, None) result, mid = yield from self.hass.async_add_job( self._mqttc.unsubscribe, topic) @@ -562,15 +563,10 @@ class MQTT(object): self._mqttc.disconnect() return - old_topics = self.topics - - self.topics = {key: value for key, value in self.topics.items() - if value is None} - - for topic, qos in old_topics.items(): - # qos is None if we were in process of subscribing - if qos is not None: - self.hass.add_job(self.async_subscribe, topic, qos) + self.progress = {} + self.subscribed_topics = {} + for topic, qos in self.wanted_topics.items(): + self.hass.add_job(self.async_subscribe, topic, qos) if self.birth_message: self.hass.add_job(self.async_publish( @@ -584,7 +580,7 @@ class MQTT(object): topic = self.progress.pop(mid, None) if topic is None: return - self.topics[topic] = granted_qos[0] + self.subscribed_topics[topic] = granted_qos[0] def _mqtt_on_message(self, _mqttc, _userdata, msg): """Message received callback.""" @@ -598,18 +594,12 @@ class MQTT(object): topic = self.progress.pop(mid, None) if topic is None: return - self.topics.pop(topic, None) + self.subscribed_topics.pop(topic, None) def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code): """Disconnected callback.""" self.progress = {} - self.topics = {key: value for key, value in self.topics.items() - if value is not None} - - # Remove None values from topic list - for key in list(self.topics): - if self.topics[key] is None: - self.topics.pop(key) + self.subscribed_topics = {} # When disconnected because of calling disconnect() if result_code == 0: diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 3d068224243..55ff0e9ff05 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -388,9 +388,12 @@ class TestMQTTCallbacks(unittest.TestCase): @mock.patch('homeassistant.components.mqtt.time.sleep') def test_mqtt_disconnect_tries_reconnect(self, mock_sleep): """Test the re-connect tries.""" - self.hass.data['mqtt'].topics = { + self.hass.data['mqtt'].subscribed_topics = { 'test/topic': 1, - 'test/progress': None + } + self.hass.data['mqtt'].wanted_topics = { + 'test/progress': 0, + 'test/topic': 2, } self.hass.data['mqtt'].progress = { 1: 'test/progress' @@ -403,7 +406,9 @@ class TestMQTTCallbacks(unittest.TestCase): self.assertEqual([1, 2, 4], [call[1][0] for call in mock_sleep.mock_calls]) - self.assertEqual({'test/topic': 1}, self.hass.data['mqtt'].topics) + self.assertEqual({'test/topic': 2, 'test/progress': 0}, + self.hass.data['mqtt'].wanted_topics) + self.assertEqual({}, self.hass.data['mqtt'].subscribed_topics) self.assertEqual({}, self.hass.data['mqtt'].progress) def test_invalid_mqtt_topics(self): @@ -556,12 +561,15 @@ def test_mqtt_subscribes_topics_on_connect(hass): """Test subscription to topic on connect.""" mqtt_client = yield from mock_mqtt_client(hass) - prev_topics = OrderedDict() - prev_topics['topic/test'] = 1, - prev_topics['home/sensor'] = 2, - prev_topics['still/pending'] = None + subscribed_topics = OrderedDict() + subscribed_topics['topic/test'] = 1 + subscribed_topics['home/sensor'] = 2 - hass.data['mqtt'].topics = prev_topics + wanted_topics = subscribed_topics.copy() + wanted_topics['still/pending'] = 0 + + hass.data['mqtt'].wanted_topics = wanted_topics + hass.data['mqtt'].subscribed_topics = subscribed_topics hass.data['mqtt'].progress = {1: 'still/pending'} # Return values for subscribe calls (rc, mid) @@ -574,7 +582,7 @@ def test_mqtt_subscribes_topics_on_connect(hass): assert not mqtt_client.disconnect.called - expected = [(topic, qos) for topic, qos in prev_topics.items() - if qos is not None] + expected = [(topic, qos) for topic, qos in wanted_topics.items()] assert [call[1][1:] for call in hass.add_job.mock_calls] == expected + assert hass.data['mqtt'].progress == {}