Make MQTT reconnection logic more resilient and fix race condition (#10133)
parent
0202e966ea
commit
bf8e2bd77e
|
@ -438,7 +438,8 @@ class MQTT(object):
|
||||||
self.broker = broker
|
self.broker = broker
|
||||||
self.port = port
|
self.port = port
|
||||||
self.keepalive = keepalive
|
self.keepalive = keepalive
|
||||||
self.topics = {}
|
self.wanted_topics = {}
|
||||||
|
self.subscribed_topics = {}
|
||||||
self.progress = {}
|
self.progress = {}
|
||||||
self.birth_message = birth_message
|
self.birth_message = birth_message
|
||||||
self._mqttc = None
|
self._mqttc = None
|
||||||
|
@ -526,15 +527,14 @@ class MQTT(object):
|
||||||
raise HomeAssistantError("topic need to be a string!")
|
raise HomeAssistantError("topic need to be a string!")
|
||||||
|
|
||||||
with (yield from self._paho_lock):
|
with (yield from self._paho_lock):
|
||||||
if topic in self.topics:
|
if topic in self.subscribed_topics:
|
||||||
return
|
return
|
||||||
|
self.wanted_topics[topic] = qos
|
||||||
result, mid = yield from self.hass.async_add_job(
|
result, mid = yield from self.hass.async_add_job(
|
||||||
self._mqttc.subscribe, topic, qos)
|
self._mqttc.subscribe, topic, qos)
|
||||||
|
|
||||||
_raise_on_error(result)
|
_raise_on_error(result)
|
||||||
self.progress[mid] = topic
|
self.progress[mid] = topic
|
||||||
self.topics[topic] = None
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_unsubscribe(self, topic):
|
def async_unsubscribe(self, topic):
|
||||||
|
@ -542,6 +542,7 @@ class MQTT(object):
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
|
self.wanted_topics.pop(topic, None)
|
||||||
result, mid = yield from self.hass.async_add_job(
|
result, mid = yield from self.hass.async_add_job(
|
||||||
self._mqttc.unsubscribe, topic)
|
self._mqttc.unsubscribe, topic)
|
||||||
|
|
||||||
|
@ -562,15 +563,10 @@ class MQTT(object):
|
||||||
self._mqttc.disconnect()
|
self._mqttc.disconnect()
|
||||||
return
|
return
|
||||||
|
|
||||||
old_topics = self.topics
|
self.progress = {}
|
||||||
|
self.subscribed_topics = {}
|
||||||
self.topics = {key: value for key, value in self.topics.items()
|
for topic, qos in self.wanted_topics.items():
|
||||||
if value is None}
|
self.hass.add_job(self.async_subscribe, topic, qos)
|
||||||
|
|
||||||
for topic, qos in old_topics.items():
|
|
||||||
# qos is None if we were in process of subscribing
|
|
||||||
if qos is not None:
|
|
||||||
self.hass.add_job(self.async_subscribe, topic, qos)
|
|
||||||
|
|
||||||
if self.birth_message:
|
if self.birth_message:
|
||||||
self.hass.add_job(self.async_publish(
|
self.hass.add_job(self.async_publish(
|
||||||
|
@ -584,7 +580,7 @@ class MQTT(object):
|
||||||
topic = self.progress.pop(mid, None)
|
topic = self.progress.pop(mid, None)
|
||||||
if topic is None:
|
if topic is None:
|
||||||
return
|
return
|
||||||
self.topics[topic] = granted_qos[0]
|
self.subscribed_topics[topic] = granted_qos[0]
|
||||||
|
|
||||||
def _mqtt_on_message(self, _mqttc, _userdata, msg):
|
def _mqtt_on_message(self, _mqttc, _userdata, msg):
|
||||||
"""Message received callback."""
|
"""Message received callback."""
|
||||||
|
@ -598,18 +594,12 @@ class MQTT(object):
|
||||||
topic = self.progress.pop(mid, None)
|
topic = self.progress.pop(mid, None)
|
||||||
if topic is None:
|
if topic is None:
|
||||||
return
|
return
|
||||||
self.topics.pop(topic, None)
|
self.subscribed_topics.pop(topic, None)
|
||||||
|
|
||||||
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code):
|
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code):
|
||||||
"""Disconnected callback."""
|
"""Disconnected callback."""
|
||||||
self.progress = {}
|
self.progress = {}
|
||||||
self.topics = {key: value for key, value in self.topics.items()
|
self.subscribed_topics = {}
|
||||||
if value is not None}
|
|
||||||
|
|
||||||
# Remove None values from topic list
|
|
||||||
for key in list(self.topics):
|
|
||||||
if self.topics[key] is None:
|
|
||||||
self.topics.pop(key)
|
|
||||||
|
|
||||||
# When disconnected because of calling disconnect()
|
# When disconnected because of calling disconnect()
|
||||||
if result_code == 0:
|
if result_code == 0:
|
||||||
|
|
|
@ -388,9 +388,12 @@ class TestMQTTCallbacks(unittest.TestCase):
|
||||||
@mock.patch('homeassistant.components.mqtt.time.sleep')
|
@mock.patch('homeassistant.components.mqtt.time.sleep')
|
||||||
def test_mqtt_disconnect_tries_reconnect(self, mock_sleep):
|
def test_mqtt_disconnect_tries_reconnect(self, mock_sleep):
|
||||||
"""Test the re-connect tries."""
|
"""Test the re-connect tries."""
|
||||||
self.hass.data['mqtt'].topics = {
|
self.hass.data['mqtt'].subscribed_topics = {
|
||||||
'test/topic': 1,
|
'test/topic': 1,
|
||||||
'test/progress': None
|
}
|
||||||
|
self.hass.data['mqtt'].wanted_topics = {
|
||||||
|
'test/progress': 0,
|
||||||
|
'test/topic': 2,
|
||||||
}
|
}
|
||||||
self.hass.data['mqtt'].progress = {
|
self.hass.data['mqtt'].progress = {
|
||||||
1: 'test/progress'
|
1: 'test/progress'
|
||||||
|
@ -403,7 +406,9 @@ class TestMQTTCallbacks(unittest.TestCase):
|
||||||
self.assertEqual([1, 2, 4],
|
self.assertEqual([1, 2, 4],
|
||||||
[call[1][0] for call in mock_sleep.mock_calls])
|
[call[1][0] for call in mock_sleep.mock_calls])
|
||||||
|
|
||||||
self.assertEqual({'test/topic': 1}, self.hass.data['mqtt'].topics)
|
self.assertEqual({'test/topic': 2, 'test/progress': 0},
|
||||||
|
self.hass.data['mqtt'].wanted_topics)
|
||||||
|
self.assertEqual({}, self.hass.data['mqtt'].subscribed_topics)
|
||||||
self.assertEqual({}, self.hass.data['mqtt'].progress)
|
self.assertEqual({}, self.hass.data['mqtt'].progress)
|
||||||
|
|
||||||
def test_invalid_mqtt_topics(self):
|
def test_invalid_mqtt_topics(self):
|
||||||
|
@ -556,12 +561,15 @@ def test_mqtt_subscribes_topics_on_connect(hass):
|
||||||
"""Test subscription to topic on connect."""
|
"""Test subscription to topic on connect."""
|
||||||
mqtt_client = yield from mock_mqtt_client(hass)
|
mqtt_client = yield from mock_mqtt_client(hass)
|
||||||
|
|
||||||
prev_topics = OrderedDict()
|
subscribed_topics = OrderedDict()
|
||||||
prev_topics['topic/test'] = 1,
|
subscribed_topics['topic/test'] = 1
|
||||||
prev_topics['home/sensor'] = 2,
|
subscribed_topics['home/sensor'] = 2
|
||||||
prev_topics['still/pending'] = None
|
|
||||||
|
|
||||||
hass.data['mqtt'].topics = prev_topics
|
wanted_topics = subscribed_topics.copy()
|
||||||
|
wanted_topics['still/pending'] = 0
|
||||||
|
|
||||||
|
hass.data['mqtt'].wanted_topics = wanted_topics
|
||||||
|
hass.data['mqtt'].subscribed_topics = subscribed_topics
|
||||||
hass.data['mqtt'].progress = {1: 'still/pending'}
|
hass.data['mqtt'].progress = {1: 'still/pending'}
|
||||||
|
|
||||||
# Return values for subscribe calls (rc, mid)
|
# Return values for subscribe calls (rc, mid)
|
||||||
|
@ -574,7 +582,7 @@ def test_mqtt_subscribes_topics_on_connect(hass):
|
||||||
|
|
||||||
assert not mqtt_client.disconnect.called
|
assert not mqtt_client.disconnect.called
|
||||||
|
|
||||||
expected = [(topic, qos) for topic, qos in prev_topics.items()
|
expected = [(topic, qos) for topic, qos in wanted_topics.items()]
|
||||||
if qos is not None]
|
|
||||||
|
|
||||||
assert [call[1][1:] for call in hass.add_job.mock_calls] == expected
|
assert [call[1][1:] for call in hass.add_job.mock_calls] == expected
|
||||||
|
assert hass.data['mqtt'].progress == {}
|
||||||
|
|
Loading…
Reference in New Issue