allow wildcards in subscription (#12247)
* allow wildcards in subscription * remove whitespaces * make function public * also implement for mqtt_json * avoid mqtt-outside topic matching * add wildcard tests * add not matching wildcard tests * fix not-matching testspull/12267/head
parent
3333dcc6c2
commit
a9412d27aa
|
@ -31,17 +31,14 @@ def async_setup_scanner(hass, config, async_see, discovery_info=None):
|
|||
devices = config[CONF_DEVICES]
|
||||
qos = config[CONF_QOS]
|
||||
|
||||
dev_id_lookup = {}
|
||||
|
||||
@callback
|
||||
def async_tracker_message_received(topic, payload, qos):
|
||||
"""Handle received MQTT message."""
|
||||
hass.async_add_job(
|
||||
async_see(dev_id=dev_id_lookup[topic], location_name=payload))
|
||||
|
||||
for dev_id, topic in devices.items():
|
||||
dev_id_lookup[topic] = dev_id
|
||||
@callback
|
||||
def async_message_received(topic, payload, qos, dev_id=dev_id):
|
||||
"""Handle received MQTT message."""
|
||||
hass.async_add_job(
|
||||
async_see(dev_id=dev_id, location_name=payload))
|
||||
|
||||
yield from mqtt.async_subscribe(
|
||||
hass, topic, async_tracker_message_received, qos)
|
||||
hass, topic, async_message_received, qos)
|
||||
|
||||
return True
|
||||
|
|
|
@ -41,32 +41,26 @@ def async_setup_scanner(hass, config, async_see, discovery_info=None):
|
|||
devices = config[CONF_DEVICES]
|
||||
qos = config[CONF_QOS]
|
||||
|
||||
dev_id_lookup = {}
|
||||
|
||||
@callback
|
||||
def async_tracker_message_received(topic, payload, qos):
|
||||
"""Handle received MQTT message."""
|
||||
dev_id = dev_id_lookup[topic]
|
||||
|
||||
try:
|
||||
data = GPS_JSON_PAYLOAD_SCHEMA(json.loads(payload))
|
||||
except vol.MultipleInvalid:
|
||||
_LOGGER.error("Skipping update for following data "
|
||||
"because of missing or malformatted data: %s",
|
||||
payload)
|
||||
return
|
||||
except ValueError:
|
||||
_LOGGER.error("Error parsing JSON payload: %s", payload)
|
||||
return
|
||||
|
||||
kwargs = _parse_see_args(dev_id, data)
|
||||
hass.async_add_job(
|
||||
async_see(**kwargs))
|
||||
|
||||
for dev_id, topic in devices.items():
|
||||
dev_id_lookup[topic] = dev_id
|
||||
@callback
|
||||
def async_message_received(topic, payload, qos, dev_id=dev_id):
|
||||
"""Handle received MQTT message."""
|
||||
try:
|
||||
data = GPS_JSON_PAYLOAD_SCHEMA(json.loads(payload))
|
||||
except vol.MultipleInvalid:
|
||||
_LOGGER.error("Skipping update for following data "
|
||||
"because of missing or malformatted data: %s",
|
||||
payload)
|
||||
return
|
||||
except ValueError:
|
||||
_LOGGER.error("Error parsing JSON payload: %s", payload)
|
||||
return
|
||||
|
||||
kwargs = _parse_see_args(dev_id, data)
|
||||
hass.async_add_job(async_see(**kwargs))
|
||||
|
||||
yield from mqtt.async_subscribe(
|
||||
hass, topic, async_tracker_message_received, qos)
|
||||
hass, topic, async_message_received, qos)
|
||||
|
||||
return True
|
||||
|
||||
|
|
|
@ -70,3 +70,79 @@ class TestComponentsDeviceTrackerMQTT(unittest.TestCase):
|
|||
fire_mqtt_message(self.hass, topic, location)
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(location, self.hass.states.get(entity_id).state)
|
||||
|
||||
def test_single_level_wildcard_topic(self):
|
||||
"""Test single level wildcard topic."""
|
||||
dev_id = 'paulus'
|
||||
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
|
||||
subscription = '/location/+/paulus'
|
||||
topic = '/location/room/paulus'
|
||||
location = 'work'
|
||||
|
||||
self.hass.config.components = set(['mqtt', 'zone'])
|
||||
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||
device_tracker.DOMAIN: {
|
||||
CONF_PLATFORM: 'mqtt',
|
||||
'devices': {dev_id: subscription}
|
||||
}
|
||||
})
|
||||
fire_mqtt_message(self.hass, topic, location)
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(location, self.hass.states.get(entity_id).state)
|
||||
|
||||
def test_multi_level_wildcard_topic(self):
|
||||
"""Test multi level wildcard topic."""
|
||||
dev_id = 'paulus'
|
||||
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
|
||||
subscription = '/location/#'
|
||||
topic = '/location/room/paulus'
|
||||
location = 'work'
|
||||
|
||||
self.hass.config.components = set(['mqtt', 'zone'])
|
||||
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||
device_tracker.DOMAIN: {
|
||||
CONF_PLATFORM: 'mqtt',
|
||||
'devices': {dev_id: subscription}
|
||||
}
|
||||
})
|
||||
fire_mqtt_message(self.hass, topic, location)
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(location, self.hass.states.get(entity_id).state)
|
||||
|
||||
def test_single_level_wildcard_topic_not_matching(self):
|
||||
"""Test not matching single level wildcard topic."""
|
||||
dev_id = 'paulus'
|
||||
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
|
||||
subscription = '/location/+/paulus'
|
||||
topic = '/location/paulus'
|
||||
location = 'work'
|
||||
|
||||
self.hass.config.components = set(['mqtt', 'zone'])
|
||||
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||
device_tracker.DOMAIN: {
|
||||
CONF_PLATFORM: 'mqtt',
|
||||
'devices': {dev_id: subscription}
|
||||
}
|
||||
})
|
||||
fire_mqtt_message(self.hass, topic, location)
|
||||
self.hass.block_till_done()
|
||||
self.assertIsNone(self.hass.states.get(entity_id))
|
||||
|
||||
def test_multi_level_wildcard_topic_not_matching(self):
|
||||
"""Test not matching multi level wildcard topic."""
|
||||
dev_id = 'paulus'
|
||||
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
|
||||
subscription = '/location/#'
|
||||
topic = '/somewhere/room/paulus'
|
||||
location = 'work'
|
||||
|
||||
self.hass.config.components = set(['mqtt', 'zone'])
|
||||
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||
device_tracker.DOMAIN: {
|
||||
CONF_PLATFORM: 'mqtt',
|
||||
'devices': {dev_id: subscription}
|
||||
}
|
||||
})
|
||||
fire_mqtt_message(self.hass, topic, location)
|
||||
self.hass.block_till_done()
|
||||
self.assertIsNone(self.hass.states.get(entity_id))
|
||||
|
|
|
@ -123,3 +123,77 @@ class TestComponentsDeviceTrackerJSONMQTT(unittest.TestCase):
|
|||
"Skipping update for following data because of missing "
|
||||
"or malformatted data: {\"longitude\": 2.0}",
|
||||
test_handle.output[0])
|
||||
|
||||
def test_single_level_wildcard_topic(self):
|
||||
"""Test single level wildcard topic."""
|
||||
dev_id = 'zanzito'
|
||||
subscription = 'location/+/zanzito'
|
||||
topic = 'location/room/zanzito'
|
||||
location = json.dumps(LOCATION_MESSAGE)
|
||||
|
||||
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||
device_tracker.DOMAIN: {
|
||||
CONF_PLATFORM: 'mqtt_json',
|
||||
'devices': {dev_id: subscription}
|
||||
}
|
||||
})
|
||||
fire_mqtt_message(self.hass, topic, location)
|
||||
self.hass.block_till_done()
|
||||
state = self.hass.states.get('device_tracker.zanzito')
|
||||
self.assertEqual(state.attributes.get('latitude'), 2.0)
|
||||
self.assertEqual(state.attributes.get('longitude'), 1.0)
|
||||
|
||||
def test_multi_level_wildcard_topic(self):
|
||||
"""Test multi level wildcard topic."""
|
||||
dev_id = 'zanzito'
|
||||
subscription = 'location/#'
|
||||
topic = 'location/zanzito'
|
||||
location = json.dumps(LOCATION_MESSAGE)
|
||||
|
||||
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||
device_tracker.DOMAIN: {
|
||||
CONF_PLATFORM: 'mqtt_json',
|
||||
'devices': {dev_id: subscription}
|
||||
}
|
||||
})
|
||||
fire_mqtt_message(self.hass, topic, location)
|
||||
self.hass.block_till_done()
|
||||
state = self.hass.states.get('device_tracker.zanzito')
|
||||
self.assertEqual(state.attributes.get('latitude'), 2.0)
|
||||
self.assertEqual(state.attributes.get('longitude'), 1.0)
|
||||
|
||||
def test_single_level_wildcard_topic_not_matching(self):
|
||||
"""Test not matching single level wildcard topic."""
|
||||
dev_id = 'zanzito'
|
||||
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
|
||||
subscription = 'location/+/zanzito'
|
||||
topic = 'location/zanzito'
|
||||
location = json.dumps(LOCATION_MESSAGE)
|
||||
|
||||
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||
device_tracker.DOMAIN: {
|
||||
CONF_PLATFORM: 'mqtt_json',
|
||||
'devices': {dev_id: subscription}
|
||||
}
|
||||
})
|
||||
fire_mqtt_message(self.hass, topic, location)
|
||||
self.hass.block_till_done()
|
||||
self.assertIsNone(self.hass.states.get(entity_id))
|
||||
|
||||
def test_multi_level_wildcard_topic_not_matching(self):
|
||||
"""Test not matching multi level wildcard topic."""
|
||||
dev_id = 'zanzito'
|
||||
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
|
||||
subscription = 'location/#'
|
||||
topic = 'somewhere/zanzito'
|
||||
location = json.dumps(LOCATION_MESSAGE)
|
||||
|
||||
assert setup_component(self.hass, device_tracker.DOMAIN, {
|
||||
device_tracker.DOMAIN: {
|
||||
CONF_PLATFORM: 'mqtt_json',
|
||||
'devices': {dev_id: subscription}
|
||||
}
|
||||
})
|
||||
fire_mqtt_message(self.hass, topic, location)
|
||||
self.hass.block_till_done()
|
||||
self.assertIsNone(self.hass.states.get(entity_id))
|
||||
|
|
Loading…
Reference in New Issue