allow wildcards in subscription (#12247)

* allow wildcards in subscription

* remove whitespaces

* make function public

* also implement for mqtt_json

* avoid mqtt-outside topic matching

* add wildcard tests

* add not matching wildcard tests

* fix not-matching tests
pull/12267/head
escoand 2018-02-10 00:22:50 +01:00 committed by Paulus Schoutsen
parent 3333dcc6c2
commit a9412d27aa
4 changed files with 175 additions and 34 deletions

View File

@ -31,17 +31,14 @@ def async_setup_scanner(hass, config, async_see, discovery_info=None):
devices = config[CONF_DEVICES]
qos = config[CONF_QOS]
dev_id_lookup = {}
@callback
def async_tracker_message_received(topic, payload, qos):
"""Handle received MQTT message."""
hass.async_add_job(
async_see(dev_id=dev_id_lookup[topic], location_name=payload))
for dev_id, topic in devices.items():
dev_id_lookup[topic] = dev_id
@callback
def async_message_received(topic, payload, qos, dev_id=dev_id):
"""Handle received MQTT message."""
hass.async_add_job(
async_see(dev_id=dev_id, location_name=payload))
yield from mqtt.async_subscribe(
hass, topic, async_tracker_message_received, qos)
hass, topic, async_message_received, qos)
return True

View File

@ -41,32 +41,26 @@ def async_setup_scanner(hass, config, async_see, discovery_info=None):
devices = config[CONF_DEVICES]
qos = config[CONF_QOS]
dev_id_lookup = {}
@callback
def async_tracker_message_received(topic, payload, qos):
"""Handle received MQTT message."""
dev_id = dev_id_lookup[topic]
try:
data = GPS_JSON_PAYLOAD_SCHEMA(json.loads(payload))
except vol.MultipleInvalid:
_LOGGER.error("Skipping update for following data "
"because of missing or malformatted data: %s",
payload)
return
except ValueError:
_LOGGER.error("Error parsing JSON payload: %s", payload)
return
kwargs = _parse_see_args(dev_id, data)
hass.async_add_job(
async_see(**kwargs))
for dev_id, topic in devices.items():
dev_id_lookup[topic] = dev_id
@callback
def async_message_received(topic, payload, qos, dev_id=dev_id):
"""Handle received MQTT message."""
try:
data = GPS_JSON_PAYLOAD_SCHEMA(json.loads(payload))
except vol.MultipleInvalid:
_LOGGER.error("Skipping update for following data "
"because of missing or malformatted data: %s",
payload)
return
except ValueError:
_LOGGER.error("Error parsing JSON payload: %s", payload)
return
kwargs = _parse_see_args(dev_id, data)
hass.async_add_job(async_see(**kwargs))
yield from mqtt.async_subscribe(
hass, topic, async_tracker_message_received, qos)
hass, topic, async_message_received, qos)
return True

View File

@ -70,3 +70,79 @@ class TestComponentsDeviceTrackerMQTT(unittest.TestCase):
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
self.assertEqual(location, self.hass.states.get(entity_id).state)
def test_single_level_wildcard_topic(self):
"""Test single level wildcard topic."""
dev_id = 'paulus'
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
subscription = '/location/+/paulus'
topic = '/location/room/paulus'
location = 'work'
self.hass.config.components = set(['mqtt', 'zone'])
assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
self.assertEqual(location, self.hass.states.get(entity_id).state)
def test_multi_level_wildcard_topic(self):
"""Test multi level wildcard topic."""
dev_id = 'paulus'
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
subscription = '/location/#'
topic = '/location/room/paulus'
location = 'work'
self.hass.config.components = set(['mqtt', 'zone'])
assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
self.assertEqual(location, self.hass.states.get(entity_id).state)
def test_single_level_wildcard_topic_not_matching(self):
"""Test not matching single level wildcard topic."""
dev_id = 'paulus'
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
subscription = '/location/+/paulus'
topic = '/location/paulus'
location = 'work'
self.hass.config.components = set(['mqtt', 'zone'])
assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
self.assertIsNone(self.hass.states.get(entity_id))
def test_multi_level_wildcard_topic_not_matching(self):
"""Test not matching multi level wildcard topic."""
dev_id = 'paulus'
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
subscription = '/location/#'
topic = '/somewhere/room/paulus'
location = 'work'
self.hass.config.components = set(['mqtt', 'zone'])
assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
self.assertIsNone(self.hass.states.get(entity_id))

View File

@ -123,3 +123,77 @@ class TestComponentsDeviceTrackerJSONMQTT(unittest.TestCase):
"Skipping update for following data because of missing "
"or malformatted data: {\"longitude\": 2.0}",
test_handle.output[0])
def test_single_level_wildcard_topic(self):
"""Test single level wildcard topic."""
dev_id = 'zanzito'
subscription = 'location/+/zanzito'
topic = 'location/room/zanzito'
location = json.dumps(LOCATION_MESSAGE)
assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt_json',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
state = self.hass.states.get('device_tracker.zanzito')
self.assertEqual(state.attributes.get('latitude'), 2.0)
self.assertEqual(state.attributes.get('longitude'), 1.0)
def test_multi_level_wildcard_topic(self):
"""Test multi level wildcard topic."""
dev_id = 'zanzito'
subscription = 'location/#'
topic = 'location/zanzito'
location = json.dumps(LOCATION_MESSAGE)
assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt_json',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
state = self.hass.states.get('device_tracker.zanzito')
self.assertEqual(state.attributes.get('latitude'), 2.0)
self.assertEqual(state.attributes.get('longitude'), 1.0)
def test_single_level_wildcard_topic_not_matching(self):
"""Test not matching single level wildcard topic."""
dev_id = 'zanzito'
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
subscription = 'location/+/zanzito'
topic = 'location/zanzito'
location = json.dumps(LOCATION_MESSAGE)
assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt_json',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
self.assertIsNone(self.hass.states.get(entity_id))
def test_multi_level_wildcard_topic_not_matching(self):
"""Test not matching multi level wildcard topic."""
dev_id = 'zanzito'
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
subscription = 'location/#'
topic = 'somewhere/zanzito'
location = json.dumps(LOCATION_MESSAGE)
assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt_json',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
self.assertIsNone(self.hass.states.get(entity_id))