core/tests/components/mqtt/test_init.py

619 lines
24 KiB
Python

"""The tests for the MQTT component."""
import asyncio
import unittest
from unittest import mock
import socket
import ssl
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.setup import async_setup_component
import homeassistant.components.mqtt as mqtt
from homeassistant.const import (EVENT_CALL_SERVICE, ATTR_DOMAIN, ATTR_SERVICE,
EVENT_HOMEASSISTANT_STOP)
from tests.common import (get_test_home_assistant, mock_coro,
mock_mqtt_component,
threadsafe_coroutine_factory, fire_mqtt_message,
async_fire_mqtt_message)
@asyncio.coroutine
def async_mock_mqtt_client(hass, config=None):
"""Mock the MQTT paho client."""
if config is None:
config = {mqtt.CONF_BROKER: 'mock-broker'}
with mock.patch('paho.mqtt.client.Client') as mock_client:
mock_client().connect.return_value = 0
mock_client().subscribe.return_value = (0, 0)
mock_client().unsubscribe.return_value = (0, 0)
mock_client().publish.return_value = (0, 0)
result = yield from async_setup_component(hass, mqtt.DOMAIN, {
mqtt.DOMAIN: config
})
assert result
return mock_client()
mock_mqtt_client = threadsafe_coroutine_factory(async_mock_mqtt_client)
# pylint: disable=invalid-name
class TestMQTTComponent(unittest.TestCase):
"""Test the MQTT component."""
def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
mock_mqtt_component(self.hass)
self.calls = []
def tearDown(self): # pylint: disable=invalid-name
"""Stop everything that was started."""
self.hass.stop()
@callback
def record_calls(self, *args):
"""Helper for recording calls."""
self.calls.append(args)
def test_client_stops_on_home_assistant_start(self):
"""Test if client stops on HA stop."""
self.hass.bus.fire(EVENT_HOMEASSISTANT_STOP)
self.hass.block_till_done()
self.assertTrue(self.hass.data['mqtt'].async_disconnect.called)
def test_publish_calls_service(self):
"""Test the publishing of call to services."""
self.hass.bus.listen_once(EVENT_CALL_SERVICE, self.record_calls)
mqtt.publish(self.hass, 'test-topic', 'test-payload')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual(
'test-topic',
self.calls[0][0].data['service_data'][mqtt.ATTR_TOPIC])
self.assertEqual(
'test-payload',
self.calls[0][0].data['service_data'][mqtt.ATTR_PAYLOAD])
def test_service_call_without_topic_does_not_publish(self):
"""Test the service call if topic is missing."""
self.hass.bus.fire(EVENT_CALL_SERVICE, {
ATTR_DOMAIN: mqtt.DOMAIN,
ATTR_SERVICE: mqtt.SERVICE_PUBLISH
})
self.hass.block_till_done()
self.assertTrue(not self.hass.data['mqtt'].async_publish.called)
def test_service_call_with_template_payload_renders_template(self):
"""Test the service call with rendered template.
If 'payload_template' is provided and 'payload' is not, then render it.
"""
mqtt.publish_template(self.hass, "test/topic", "{{ 1+1 }}")
self.hass.block_till_done()
self.assertTrue(self.hass.data['mqtt'].async_publish.called)
self.assertEqual(
self.hass.data['mqtt'].async_publish.call_args[0][1], "2")
def test_service_call_with_payload_doesnt_render_template(self):
"""Test the service call with unrendered template.
If both 'payload' and 'payload_template' are provided then fail.
"""
payload = "not a template"
payload_template = "a template"
self.hass.services.call(mqtt.DOMAIN, mqtt.SERVICE_PUBLISH, {
mqtt.ATTR_TOPIC: "test/topic",
mqtt.ATTR_PAYLOAD: payload,
mqtt.ATTR_PAYLOAD_TEMPLATE: payload_template
}, blocking=True)
self.assertFalse(self.hass.data['mqtt'].async_publish.called)
def test_service_call_with_ascii_qos_retain_flags(self):
"""Test the service call with args that can be misinterpreted.
Empty payload message and ascii formatted qos and retain flags.
"""
self.hass.services.call(mqtt.DOMAIN, mqtt.SERVICE_PUBLISH, {
mqtt.ATTR_TOPIC: "test/topic",
mqtt.ATTR_PAYLOAD: "",
mqtt.ATTR_QOS: '2',
mqtt.ATTR_RETAIN: 'no'
}, blocking=True)
self.assertTrue(self.hass.data['mqtt'].async_publish.called)
self.assertEqual(
self.hass.data['mqtt'].async_publish.call_args[0][2], 2)
self.assertFalse(self.hass.data['mqtt'].async_publish.call_args[0][3])
def test_invalid_mqtt_topics(self):
"""Test invalid topics."""
self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, 'bad+topic')
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'bad\0one')
# pylint: disable=invalid-name
class TestMQTTCallbacks(unittest.TestCase):
"""Test the MQTT callbacks."""
def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
mock_mqtt_client(self.hass)
self.calls = []
def tearDown(self): # pylint: disable=invalid-name
"""Stop everything that was started."""
self.hass.stop()
@callback
def record_calls(self, *args):
"""Helper for recording calls."""
self.calls.append(args)
def test_client_starts_on_home_assistant_mqtt_setup(self):
"""Test if client is connected after mqtt init on bootstrap."""
self.assertEqual(self.hass.data['mqtt']._mqttc.connect.call_count, 1)
def test_receiving_non_utf8_message_gets_logged(self):
"""Test receiving a non utf8 encoded message."""
mqtt.subscribe(self.hass, 'test-topic', self.record_calls)
with self.assertLogs(level='WARNING') as test_handle:
fire_mqtt_message(self.hass, 'test-topic', b'\x9a')
self.hass.block_till_done()
self.assertIn(
"WARNING:homeassistant.components.mqtt:Can't decode payload "
"b'\\x9a' on test-topic with encoding utf-8",
test_handle.output[0])
def test_all_subscriptions_run_when_decode_fails(self):
"""Test all other subscriptions still run when decode fails for one."""
mqtt.subscribe(self.hass, 'test-topic', self.record_calls,
encoding='ascii')
mqtt.subscribe(self.hass, 'test-topic', self.record_calls)
fire_mqtt_message(self.hass, 'test-topic', '°C')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
def test_subscribe_topic(self):
"""Test the subscription of a topic."""
unsub = mqtt.subscribe(self.hass, 'test-topic', self.record_calls)
fire_mqtt_message(self.hass, 'test-topic', 'test-payload')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('test-topic', self.calls[0][0])
self.assertEqual('test-payload', self.calls[0][1])
unsub()
fire_mqtt_message(self.hass, 'test-topic', 'test-payload')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
def test_subscribe_topic_not_match(self):
"""Test if subscribed topic is not a match."""
mqtt.subscribe(self.hass, 'test-topic', self.record_calls)
fire_mqtt_message(self.hass, 'another-test-topic', 'test-payload')
self.hass.block_till_done()
self.assertEqual(0, len(self.calls))
def test_subscribe_topic_level_wildcard(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, 'test-topic/+/on', self.record_calls)
fire_mqtt_message(self.hass, 'test-topic/bier/on', 'test-payload')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('test-topic/bier/on', self.calls[0][0])
self.assertEqual('test-payload', self.calls[0][1])
def test_subscribe_topic_level_wildcard_no_subtree_match(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, 'test-topic/+/on', self.record_calls)
fire_mqtt_message(self.hass, 'test-topic/bier', 'test-payload')
self.hass.block_till_done()
self.assertEqual(0, len(self.calls))
def test_subscribe_topic_subtree_wildcard_subtree_topic(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, 'test-topic/#', self.record_calls)
fire_mqtt_message(self.hass, 'test-topic/bier/on', 'test-payload')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('test-topic/bier/on', self.calls[0][0])
self.assertEqual('test-payload', self.calls[0][1])
def test_subscribe_topic_subtree_wildcard_root_topic(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, 'test-topic/#', self.record_calls)
fire_mqtt_message(self.hass, 'test-topic', 'test-payload')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('test-topic', self.calls[0][0])
self.assertEqual('test-payload', self.calls[0][1])
def test_subscribe_topic_subtree_wildcard_no_match(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, 'test-topic/#', self.record_calls)
fire_mqtt_message(self.hass, 'another-test-topic', 'test-payload')
self.hass.block_till_done()
self.assertEqual(0, len(self.calls))
def test_subscribe_topic_level_wildcard_and_wildcard_root_topic(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)
fire_mqtt_message(self.hass, 'hi/test-topic', 'test-payload')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('hi/test-topic', self.calls[0][0])
self.assertEqual('test-payload', self.calls[0][1])
def test_subscribe_topic_level_wildcard_and_wildcard_subtree_topic(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)
fire_mqtt_message(self.hass, 'hi/test-topic/here-iam', 'test-payload')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('hi/test-topic/here-iam', self.calls[0][0])
self.assertEqual('test-payload', self.calls[0][1])
def test_subscribe_topic_level_wildcard_and_wildcard_level_no_match(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)
fire_mqtt_message(self.hass, 'hi/here-iam/test-topic', 'test-payload')
self.hass.block_till_done()
self.assertEqual(0, len(self.calls))
def test_subscribe_topic_level_wildcard_and_wildcard_no_match(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)
fire_mqtt_message(self.hass, 'hi/another-test-topic', 'test-payload')
self.hass.block_till_done()
self.assertEqual(0, len(self.calls))
def test_subscribe_topic_sys_root(self):
"""Test the subscription of $ root topics."""
mqtt.subscribe(self.hass, '$test-topic/subtree/on', self.record_calls)
fire_mqtt_message(self.hass, '$test-topic/subtree/on', 'test-payload')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('$test-topic/subtree/on', self.calls[0][0])
self.assertEqual('test-payload', self.calls[0][1])
def test_subscribe_topic_sys_root_and_wildcard_topic(self):
"""Test the subscription of $ root and wildcard topics."""
mqtt.subscribe(self.hass, '$test-topic/#', self.record_calls)
fire_mqtt_message(self.hass, '$test-topic/some-topic', 'test-payload')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('$test-topic/some-topic', self.calls[0][0])
self.assertEqual('test-payload', self.calls[0][1])
def test_subscribe_topic_sys_root_and_wildcard_subtree_topic(self):
"""Test the subscription of $ root and wildcard subtree topics."""
mqtt.subscribe(self.hass, '$test-topic/subtree/#', self.record_calls)
fire_mqtt_message(self.hass, '$test-topic/subtree/some-topic',
'test-payload')
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('$test-topic/subtree/some-topic', self.calls[0][0])
self.assertEqual('test-payload', self.calls[0][1])
def test_subscribe_special_characters(self):
"""Test the subscription to topics with special characters."""
topic = '/test-topic/$(.)[^]{-}'
payload = 'p4y.l[]a|> ?'
mqtt.subscribe(self.hass, topic, self.record_calls)
fire_mqtt_message(self.hass, topic, payload)
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual(topic, self.calls[0][0])
self.assertEqual(payload, self.calls[0][1])
def test_mqtt_failed_connection_results_in_disconnect(self):
"""Test if connection failure leads to disconnect."""
for result_code in range(1, 6):
self.hass.data['mqtt']._mqttc = mock.MagicMock()
self.hass.data['mqtt']._mqtt_on_connect(
None, {'topics': {}}, 0, result_code)
self.assertTrue(self.hass.data['mqtt']._mqttc.disconnect.called)
def test_mqtt_disconnect_tries_no_reconnect_on_stop(self):
"""Test the disconnect tries."""
self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 0)
self.assertFalse(self.hass.data['mqtt']._mqttc.reconnect.called)
@mock.patch('homeassistant.components.mqtt.time.sleep')
def test_mqtt_disconnect_tries_reconnect(self, mock_sleep):
"""Test the re-connect tries."""
self.hass.data['mqtt'].subscriptions = [
mqtt.Subscription('test/progress', None, 0),
mqtt.Subscription('test/progress', None, 1),
mqtt.Subscription('test/topic', None, 2),
]
self.hass.data['mqtt']._mqttc.reconnect.side_effect = [1, 1, 1, 0]
self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 1)
self.assertTrue(self.hass.data['mqtt']._mqttc.reconnect.called)
self.assertEqual(
4, len(self.hass.data['mqtt']._mqttc.reconnect.mock_calls))
self.assertEqual([1, 2, 4],
[call[1][0] for call in mock_sleep.mock_calls])
def test_retained_message_on_subscribe_received(self):
"""Test every subscriber receives retained message on subscribe."""
def side_effect(*args):
async_fire_mqtt_message(self.hass, 'test/state', 'online')
return 0, 0
self.hass.data['mqtt']._mqttc.subscribe.side_effect = side_effect
calls_a = mock.MagicMock()
mqtt.subscribe(self.hass, 'test/state', calls_a)
self.hass.block_till_done()
self.assertTrue(calls_a.called)
calls_b = mock.MagicMock()
mqtt.subscribe(self.hass, 'test/state', calls_b)
self.hass.block_till_done()
self.assertTrue(calls_b.called)
def test_not_calling_unsubscribe_with_active_subscribers(self):
"""Test not calling unsubscribe() when other subscribers are active."""
unsub = mqtt.subscribe(self.hass, 'test/state', None)
mqtt.subscribe(self.hass, 'test/state', None)
self.hass.block_till_done()
self.assertTrue(self.hass.data['mqtt']._mqttc.subscribe.called)
unsub()
self.hass.block_till_done()
self.assertFalse(self.hass.data['mqtt']._mqttc.unsubscribe.called)
def test_restore_subscriptions_on_reconnect(self):
"""Test subscriptions are restored on reconnect."""
mqtt.subscribe(self.hass, 'test/state', None)
self.hass.block_till_done()
self.assertEqual(self.hass.data['mqtt']._mqttc.subscribe.call_count, 1)
self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 0)
self.hass.data['mqtt']._mqtt_on_connect(None, None, None, 0)
self.hass.block_till_done()
self.assertEqual(self.hass.data['mqtt']._mqttc.subscribe.call_count, 2)
def test_restore_all_active_subscriptions_on_reconnect(self):
"""Test active subscriptions are restored correctly on reconnect."""
self.hass.data['mqtt']._mqttc.subscribe.side_effect = (
(0, 1), (0, 2), (0, 3), (0, 4)
)
unsub = mqtt.subscribe(self.hass, 'test/state', None, qos=2)
mqtt.subscribe(self.hass, 'test/state', None)
mqtt.subscribe(self.hass, 'test/state', None, qos=1)
self.hass.block_till_done()
expected = [
mock.call('test/state', 2),
mock.call('test/state', 0),
mock.call('test/state', 1)
]
self.assertEqual(self.hass.data['mqtt']._mqttc.subscribe.mock_calls,
expected)
unsub()
self.hass.block_till_done()
self.assertEqual(self.hass.data['mqtt']._mqttc.unsubscribe.call_count,
0)
self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 0)
self.hass.data['mqtt']._mqtt_on_connect(None, None, None, 0)
self.hass.block_till_done()
expected.append(mock.call('test/state', 1))
self.assertEqual(self.hass.data['mqtt']._mqttc.subscribe.mock_calls,
expected)
@asyncio.coroutine
def test_setup_embedded_starts_with_no_config(hass):
"""Test setting up embedded server with no config."""
client_config = ('localhost', 1883, 'user', 'pass', None, '3.1.1')
with mock.patch('homeassistant.components.mqtt.server.async_start',
return_value=mock_coro(
return_value=(True, client_config))
) as _start:
yield from async_mock_mqtt_client(hass, {})
assert _start.call_count == 1
@asyncio.coroutine
def test_setup_embedded_with_embedded(hass):
"""Test setting up embedded server with no config."""
client_config = ('localhost', 1883, 'user', 'pass', None, '3.1.1')
with mock.patch('homeassistant.components.mqtt.server.async_start',
return_value=mock_coro(
return_value=(True, client_config))
) as _start:
_start.return_value = mock_coro(return_value=(True, client_config))
yield from async_mock_mqtt_client(hass, {'embedded': None})
assert _start.call_count == 1
@asyncio.coroutine
def test_setup_fails_if_no_connect_broker(hass):
"""Test for setup failure if connection to broker is missing."""
test_broker_cfg = {mqtt.DOMAIN: {mqtt.CONF_BROKER: 'test-broker'}}
with mock.patch('homeassistant.components.mqtt.MQTT',
side_effect=socket.error()):
result = yield from async_setup_component(hass, mqtt.DOMAIN,
test_broker_cfg)
assert not result
with mock.patch('paho.mqtt.client.Client') as mock_client:
mock_client().connect = lambda *args: 1
result = yield from async_setup_component(hass, mqtt.DOMAIN,
test_broker_cfg)
assert not result
@asyncio.coroutine
def test_setup_uses_certificate_on_certificate_set_to_auto(hass):
"""Test setup uses bundled certs when certificate is set to auto."""
test_broker_cfg = {mqtt.DOMAIN: {mqtt.CONF_BROKER: 'test-broker',
'certificate': 'auto'}}
with mock.patch('homeassistant.components.mqtt.MQTT') as mock_MQTT:
yield from async_setup_component(hass, mqtt.DOMAIN, test_broker_cfg)
assert mock_MQTT.called
import requests.certs
expectedCertificate = requests.certs.where()
assert mock_MQTT.mock_calls[0][1][7] == expectedCertificate
@asyncio.coroutine
def test_setup_does_not_use_certificate_on_mqtts_port(hass):
"""Test setup doesn't use bundled certs when certificate is not set."""
test_broker_cfg = {mqtt.DOMAIN: {mqtt.CONF_BROKER: 'test-broker',
'port': 8883}}
with mock.patch('homeassistant.components.mqtt.MQTT') as mock_MQTT:
yield from async_setup_component(hass, mqtt.DOMAIN, test_broker_cfg)
assert mock_MQTT.called
assert mock_MQTT.mock_calls[0][1][2] == 8883
import requests.certs
mqttsCertificateBundle = requests.certs.where()
assert mock_MQTT.mock_calls[0][1][7] != mqttsCertificateBundle
@asyncio.coroutine
def test_setup_without_tls_config_uses_tlsv1_under_python36(hass):
"""Test setup defaults to TLSv1 under python3.6."""
test_broker_cfg = {mqtt.DOMAIN: {mqtt.CONF_BROKER: 'test-broker'}}
with mock.patch('homeassistant.components.mqtt.MQTT') as mock_MQTT:
yield from async_setup_component(hass, mqtt.DOMAIN, test_broker_cfg)
assert mock_MQTT.called
import sys
if sys.hexversion >= 0x03060000:
expectedTlsVersion = ssl.PROTOCOL_TLS # pylint: disable=no-member
else:
expectedTlsVersion = ssl.PROTOCOL_TLSv1
assert mock_MQTT.mock_calls[0][1][14] == expectedTlsVersion
@asyncio.coroutine
def test_setup_with_tls_config_uses_tls_version1_2(hass):
"""Test setup uses specified TLS version."""
test_broker_cfg = {mqtt.DOMAIN: {mqtt.CONF_BROKER: 'test-broker',
'tls_version': '1.2'}}
with mock.patch('homeassistant.components.mqtt.MQTT') as mock_MQTT:
yield from async_setup_component(hass, mqtt.DOMAIN, test_broker_cfg)
assert mock_MQTT.called
assert mock_MQTT.mock_calls[0][1][14] == ssl.PROTOCOL_TLSv1_2
@asyncio.coroutine
def test_setup_with_tls_config_of_v1_under_python36_only_uses_v1(hass):
"""Test setup uses TLSv1.0 if explicitly chosen."""
test_broker_cfg = {mqtt.DOMAIN: {mqtt.CONF_BROKER: 'test-broker',
'tls_version': '1.0'}}
with mock.patch('homeassistant.components.mqtt.MQTT') as mock_MQTT:
yield from async_setup_component(hass, mqtt.DOMAIN, test_broker_cfg)
assert mock_MQTT.called
assert mock_MQTT.mock_calls[0][1][14] == ssl.PROTOCOL_TLSv1
@asyncio.coroutine
def test_birth_message(hass):
"""Test sending birth message."""
mqtt_client = yield from async_mock_mqtt_client(hass, {
mqtt.CONF_BROKER: 'mock-broker',
mqtt.CONF_BIRTH_MESSAGE: {mqtt.ATTR_TOPIC: 'birth',
mqtt.ATTR_PAYLOAD: 'birth'}
})
calls = []
mqtt_client.publish.side_effect = lambda *args: calls.append(args)
hass.data['mqtt']._mqtt_on_connect(None, None, 0, 0)
yield from hass.async_block_till_done()
assert calls[-1] == ('birth', 'birth', 0, False)
@asyncio.coroutine
def test_mqtt_subscribes_topics_on_connect(hass):
"""Test subscription to topic on connect."""
mqtt_client = yield from async_mock_mqtt_client(hass)
hass.data['mqtt'].subscriptions = [
mqtt.Subscription('topic/test', None),
mqtt.Subscription('home/sensor', None, 2),
mqtt.Subscription('still/pending', None),
mqtt.Subscription('still/pending', None, 1),
]
hass.add_job = mock.MagicMock()
hass.data['mqtt']._mqtt_on_connect(None, None, 0, 0)
yield from hass.async_block_till_done()
assert mqtt_client.disconnect.call_count == 0
expected = {
'topic/test': 0,
'home/sensor': 2,
'still/pending': 1
}
calls = {call[1][1]: call[1][2] for call in hass.add_job.mock_calls}
assert calls == expected