From a9412d27aab92a694127949f4a7b1385322d08a1 Mon Sep 17 00:00:00 2001 From: escoand Date: Sat, 10 Feb 2018 00:22:50 +0100 Subject: [PATCH] allow wildcards in subscription (#12247) * allow wildcards in subscription * remove whitespaces * make function public * also implement for mqtt_json * avoid mqtt-outside topic matching * add wildcard tests * add not matching wildcard tests * fix not-matching tests --- .../components/device_tracker/mqtt.py | 17 ++--- .../components/device_tracker/mqtt_json.py | 42 +++++----- tests/components/device_tracker/test_mqtt.py | 76 +++++++++++++++++++ .../device_tracker/test_mqtt_json.py | 74 ++++++++++++++++++ 4 files changed, 175 insertions(+), 34 deletions(-) diff --git a/homeassistant/components/device_tracker/mqtt.py b/homeassistant/components/device_tracker/mqtt.py index aab5b43acea..2e2d9b10d98 100644 --- a/homeassistant/components/device_tracker/mqtt.py +++ b/homeassistant/components/device_tracker/mqtt.py @@ -31,17 +31,14 @@ def async_setup_scanner(hass, config, async_see, discovery_info=None): devices = config[CONF_DEVICES] qos = config[CONF_QOS] - dev_id_lookup = {} - - @callback - def async_tracker_message_received(topic, payload, qos): - """Handle received MQTT message.""" - hass.async_add_job( - async_see(dev_id=dev_id_lookup[topic], location_name=payload)) - for dev_id, topic in devices.items(): - dev_id_lookup[topic] = dev_id + @callback + def async_message_received(topic, payload, qos, dev_id=dev_id): + """Handle received MQTT message.""" + hass.async_add_job( + async_see(dev_id=dev_id, location_name=payload)) + yield from mqtt.async_subscribe( - hass, topic, async_tracker_message_received, qos) + hass, topic, async_message_received, qos) return True diff --git a/homeassistant/components/device_tracker/mqtt_json.py b/homeassistant/components/device_tracker/mqtt_json.py index 0ef4f1835b6..7bcad60236a 100644 --- a/homeassistant/components/device_tracker/mqtt_json.py +++ b/homeassistant/components/device_tracker/mqtt_json.py @@ -41,32 +41,26 @@ def async_setup_scanner(hass, config, async_see, discovery_info=None): devices = config[CONF_DEVICES] qos = config[CONF_QOS] - dev_id_lookup = {} - - @callback - def async_tracker_message_received(topic, payload, qos): - """Handle received MQTT message.""" - dev_id = dev_id_lookup[topic] - - try: - data = GPS_JSON_PAYLOAD_SCHEMA(json.loads(payload)) - except vol.MultipleInvalid: - _LOGGER.error("Skipping update for following data " - "because of missing or malformatted data: %s", - payload) - return - except ValueError: - _LOGGER.error("Error parsing JSON payload: %s", payload) - return - - kwargs = _parse_see_args(dev_id, data) - hass.async_add_job( - async_see(**kwargs)) - for dev_id, topic in devices.items(): - dev_id_lookup[topic] = dev_id + @callback + def async_message_received(topic, payload, qos, dev_id=dev_id): + """Handle received MQTT message.""" + try: + data = GPS_JSON_PAYLOAD_SCHEMA(json.loads(payload)) + except vol.MultipleInvalid: + _LOGGER.error("Skipping update for following data " + "because of missing or malformatted data: %s", + payload) + return + except ValueError: + _LOGGER.error("Error parsing JSON payload: %s", payload) + return + + kwargs = _parse_see_args(dev_id, data) + hass.async_add_job(async_see(**kwargs)) + yield from mqtt.async_subscribe( - hass, topic, async_tracker_message_received, qos) + hass, topic, async_message_received, qos) return True diff --git a/tests/components/device_tracker/test_mqtt.py b/tests/components/device_tracker/test_mqtt.py index 4905ab4d029..78750e91f83 100644 --- a/tests/components/device_tracker/test_mqtt.py +++ b/tests/components/device_tracker/test_mqtt.py @@ -70,3 +70,79 @@ class TestComponentsDeviceTrackerMQTT(unittest.TestCase): fire_mqtt_message(self.hass, topic, location) self.hass.block_till_done() self.assertEqual(location, self.hass.states.get(entity_id).state) + + def test_single_level_wildcard_topic(self): + """Test single level wildcard topic.""" + dev_id = 'paulus' + entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id) + subscription = '/location/+/paulus' + topic = '/location/room/paulus' + location = 'work' + + self.hass.config.components = set(['mqtt', 'zone']) + assert setup_component(self.hass, device_tracker.DOMAIN, { + device_tracker.DOMAIN: { + CONF_PLATFORM: 'mqtt', + 'devices': {dev_id: subscription} + } + }) + fire_mqtt_message(self.hass, topic, location) + self.hass.block_till_done() + self.assertEqual(location, self.hass.states.get(entity_id).state) + + def test_multi_level_wildcard_topic(self): + """Test multi level wildcard topic.""" + dev_id = 'paulus' + entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id) + subscription = '/location/#' + topic = '/location/room/paulus' + location = 'work' + + self.hass.config.components = set(['mqtt', 'zone']) + assert setup_component(self.hass, device_tracker.DOMAIN, { + device_tracker.DOMAIN: { + CONF_PLATFORM: 'mqtt', + 'devices': {dev_id: subscription} + } + }) + fire_mqtt_message(self.hass, topic, location) + self.hass.block_till_done() + self.assertEqual(location, self.hass.states.get(entity_id).state) + + def test_single_level_wildcard_topic_not_matching(self): + """Test not matching single level wildcard topic.""" + dev_id = 'paulus' + entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id) + subscription = '/location/+/paulus' + topic = '/location/paulus' + location = 'work' + + self.hass.config.components = set(['mqtt', 'zone']) + assert setup_component(self.hass, device_tracker.DOMAIN, { + device_tracker.DOMAIN: { + CONF_PLATFORM: 'mqtt', + 'devices': {dev_id: subscription} + } + }) + fire_mqtt_message(self.hass, topic, location) + self.hass.block_till_done() + self.assertIsNone(self.hass.states.get(entity_id)) + + def test_multi_level_wildcard_topic_not_matching(self): + """Test not matching multi level wildcard topic.""" + dev_id = 'paulus' + entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id) + subscription = '/location/#' + topic = '/somewhere/room/paulus' + location = 'work' + + self.hass.config.components = set(['mqtt', 'zone']) + assert setup_component(self.hass, device_tracker.DOMAIN, { + device_tracker.DOMAIN: { + CONF_PLATFORM: 'mqtt', + 'devices': {dev_id: subscription} + } + }) + fire_mqtt_message(self.hass, topic, location) + self.hass.block_till_done() + self.assertIsNone(self.hass.states.get(entity_id)) diff --git a/tests/components/device_tracker/test_mqtt_json.py b/tests/components/device_tracker/test_mqtt_json.py index 1755f424d29..43f4fc3bbf3 100644 --- a/tests/components/device_tracker/test_mqtt_json.py +++ b/tests/components/device_tracker/test_mqtt_json.py @@ -123,3 +123,77 @@ class TestComponentsDeviceTrackerJSONMQTT(unittest.TestCase): "Skipping update for following data because of missing " "or malformatted data: {\"longitude\": 2.0}", test_handle.output[0]) + + def test_single_level_wildcard_topic(self): + """Test single level wildcard topic.""" + dev_id = 'zanzito' + subscription = 'location/+/zanzito' + topic = 'location/room/zanzito' + location = json.dumps(LOCATION_MESSAGE) + + assert setup_component(self.hass, device_tracker.DOMAIN, { + device_tracker.DOMAIN: { + CONF_PLATFORM: 'mqtt_json', + 'devices': {dev_id: subscription} + } + }) + fire_mqtt_message(self.hass, topic, location) + self.hass.block_till_done() + state = self.hass.states.get('device_tracker.zanzito') + self.assertEqual(state.attributes.get('latitude'), 2.0) + self.assertEqual(state.attributes.get('longitude'), 1.0) + + def test_multi_level_wildcard_topic(self): + """Test multi level wildcard topic.""" + dev_id = 'zanzito' + subscription = 'location/#' + topic = 'location/zanzito' + location = json.dumps(LOCATION_MESSAGE) + + assert setup_component(self.hass, device_tracker.DOMAIN, { + device_tracker.DOMAIN: { + CONF_PLATFORM: 'mqtt_json', + 'devices': {dev_id: subscription} + } + }) + fire_mqtt_message(self.hass, topic, location) + self.hass.block_till_done() + state = self.hass.states.get('device_tracker.zanzito') + self.assertEqual(state.attributes.get('latitude'), 2.0) + self.assertEqual(state.attributes.get('longitude'), 1.0) + + def test_single_level_wildcard_topic_not_matching(self): + """Test not matching single level wildcard topic.""" + dev_id = 'zanzito' + entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id) + subscription = 'location/+/zanzito' + topic = 'location/zanzito' + location = json.dumps(LOCATION_MESSAGE) + + assert setup_component(self.hass, device_tracker.DOMAIN, { + device_tracker.DOMAIN: { + CONF_PLATFORM: 'mqtt_json', + 'devices': {dev_id: subscription} + } + }) + fire_mqtt_message(self.hass, topic, location) + self.hass.block_till_done() + self.assertIsNone(self.hass.states.get(entity_id)) + + def test_multi_level_wildcard_topic_not_matching(self): + """Test not matching multi level wildcard topic.""" + dev_id = 'zanzito' + entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id) + subscription = 'location/#' + topic = 'somewhere/zanzito' + location = json.dumps(LOCATION_MESSAGE) + + assert setup_component(self.hass, device_tracker.DOMAIN, { + device_tracker.DOMAIN: { + CONF_PLATFORM: 'mqtt_json', + 'devices': {dev_id: subscription} + } + }) + fire_mqtt_message(self.hass, topic, location) + self.hass.block_till_done() + self.assertIsNone(self.hass.states.get(entity_id))