Make MQTT reconnection logic more resilient and fix race condition (#10133)

pull/10630/head
Cezar Sá Espinola 2017-11-17 16:29:23 -02:00 committed by Paulus Schoutsen
parent 0202e966ea
commit bf8e2bd77e
2 changed files with 30 additions and 32 deletions

View File

@ -438,7 +438,8 @@ class MQTT(object):
self.broker = broker self.broker = broker
self.port = port self.port = port
self.keepalive = keepalive self.keepalive = keepalive
self.topics = {} self.wanted_topics = {}
self.subscribed_topics = {}
self.progress = {} self.progress = {}
self.birth_message = birth_message self.birth_message = birth_message
self._mqttc = None self._mqttc = None
@ -526,15 +527,14 @@ class MQTT(object):
raise HomeAssistantError("topic need to be a string!") raise HomeAssistantError("topic need to be a string!")
with (yield from self._paho_lock): with (yield from self._paho_lock):
if topic in self.topics: if topic in self.subscribed_topics:
return return
self.wanted_topics[topic] = qos
result, mid = yield from self.hass.async_add_job( result, mid = yield from self.hass.async_add_job(
self._mqttc.subscribe, topic, qos) self._mqttc.subscribe, topic, qos)
_raise_on_error(result) _raise_on_error(result)
self.progress[mid] = topic self.progress[mid] = topic
self.topics[topic] = None
@asyncio.coroutine @asyncio.coroutine
def async_unsubscribe(self, topic): def async_unsubscribe(self, topic):
@ -542,6 +542,7 @@ class MQTT(object):
This method is a coroutine. This method is a coroutine.
""" """
self.wanted_topics.pop(topic, None)
result, mid = yield from self.hass.async_add_job( result, mid = yield from self.hass.async_add_job(
self._mqttc.unsubscribe, topic) self._mqttc.unsubscribe, topic)
@ -562,15 +563,10 @@ class MQTT(object):
self._mqttc.disconnect() self._mqttc.disconnect()
return return
old_topics = self.topics self.progress = {}
self.subscribed_topics = {}
self.topics = {key: value for key, value in self.topics.items() for topic, qos in self.wanted_topics.items():
if value is None} self.hass.add_job(self.async_subscribe, topic, qos)
for topic, qos in old_topics.items():
# qos is None if we were in process of subscribing
if qos is not None:
self.hass.add_job(self.async_subscribe, topic, qos)
if self.birth_message: if self.birth_message:
self.hass.add_job(self.async_publish( self.hass.add_job(self.async_publish(
@ -584,7 +580,7 @@ class MQTT(object):
topic = self.progress.pop(mid, None) topic = self.progress.pop(mid, None)
if topic is None: if topic is None:
return return
self.topics[topic] = granted_qos[0] self.subscribed_topics[topic] = granted_qos[0]
def _mqtt_on_message(self, _mqttc, _userdata, msg): def _mqtt_on_message(self, _mqttc, _userdata, msg):
"""Message received callback.""" """Message received callback."""
@ -598,18 +594,12 @@ class MQTT(object):
topic = self.progress.pop(mid, None) topic = self.progress.pop(mid, None)
if topic is None: if topic is None:
return return
self.topics.pop(topic, None) self.subscribed_topics.pop(topic, None)
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code): def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code):
"""Disconnected callback.""" """Disconnected callback."""
self.progress = {} self.progress = {}
self.topics = {key: value for key, value in self.topics.items() self.subscribed_topics = {}
if value is not None}
# Remove None values from topic list
for key in list(self.topics):
if self.topics[key] is None:
self.topics.pop(key)
# When disconnected because of calling disconnect() # When disconnected because of calling disconnect()
if result_code == 0: if result_code == 0:

View File

@ -388,9 +388,12 @@ class TestMQTTCallbacks(unittest.TestCase):
@mock.patch('homeassistant.components.mqtt.time.sleep') @mock.patch('homeassistant.components.mqtt.time.sleep')
def test_mqtt_disconnect_tries_reconnect(self, mock_sleep): def test_mqtt_disconnect_tries_reconnect(self, mock_sleep):
"""Test the re-connect tries.""" """Test the re-connect tries."""
self.hass.data['mqtt'].topics = { self.hass.data['mqtt'].subscribed_topics = {
'test/topic': 1, 'test/topic': 1,
'test/progress': None }
self.hass.data['mqtt'].wanted_topics = {
'test/progress': 0,
'test/topic': 2,
} }
self.hass.data['mqtt'].progress = { self.hass.data['mqtt'].progress = {
1: 'test/progress' 1: 'test/progress'
@ -403,7 +406,9 @@ class TestMQTTCallbacks(unittest.TestCase):
self.assertEqual([1, 2, 4], self.assertEqual([1, 2, 4],
[call[1][0] for call in mock_sleep.mock_calls]) [call[1][0] for call in mock_sleep.mock_calls])
self.assertEqual({'test/topic': 1}, self.hass.data['mqtt'].topics) self.assertEqual({'test/topic': 2, 'test/progress': 0},
self.hass.data['mqtt'].wanted_topics)
self.assertEqual({}, self.hass.data['mqtt'].subscribed_topics)
self.assertEqual({}, self.hass.data['mqtt'].progress) self.assertEqual({}, self.hass.data['mqtt'].progress)
def test_invalid_mqtt_topics(self): def test_invalid_mqtt_topics(self):
@ -556,12 +561,15 @@ def test_mqtt_subscribes_topics_on_connect(hass):
"""Test subscription to topic on connect.""" """Test subscription to topic on connect."""
mqtt_client = yield from mock_mqtt_client(hass) mqtt_client = yield from mock_mqtt_client(hass)
prev_topics = OrderedDict() subscribed_topics = OrderedDict()
prev_topics['topic/test'] = 1, subscribed_topics['topic/test'] = 1
prev_topics['home/sensor'] = 2, subscribed_topics['home/sensor'] = 2
prev_topics['still/pending'] = None
hass.data['mqtt'].topics = prev_topics wanted_topics = subscribed_topics.copy()
wanted_topics['still/pending'] = 0
hass.data['mqtt'].wanted_topics = wanted_topics
hass.data['mqtt'].subscribed_topics = subscribed_topics
hass.data['mqtt'].progress = {1: 'still/pending'} hass.data['mqtt'].progress = {1: 'still/pending'}
# Return values for subscribe calls (rc, mid) # Return values for subscribe calls (rc, mid)
@ -574,7 +582,7 @@ def test_mqtt_subscribes_topics_on_connect(hass):
assert not mqtt_client.disconnect.called assert not mqtt_client.disconnect.called
expected = [(topic, qos) for topic, qos in prev_topics.items() expected = [(topic, qos) for topic, qos in wanted_topics.items()]
if qos is not None]
assert [call[1][1:] for call in hass.add_job.mock_calls] == expected assert [call[1][1:] for call in hass.add_job.mock_calls] == expected
assert hass.data['mqtt'].progress == {}