Subscribe only to valid MQTT discovery topics (#45456)

pull/45466/head
Erik Montnemery 2021-01-23 14:51:25 +01:00 committed by GitHub
parent a0b906005d
commit f86beed7b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 9 deletions

View File

@ -240,9 +240,17 @@ async def async_start(
hass.data[ALREADY_DISCOVERED] = {}
hass.data[PENDING_DISCOVERED] = {}
hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe(
hass, f"{discovery_topic}/#", async_discovery_message_received, 0
discovery_topics = [
f"{discovery_topic}/+/+/config",
f"{discovery_topic}/+/+/+/config",
]
hass.data[DISCOVERY_UNSUBSCRIBE] = await asyncio.gather(
*(
mqtt.async_subscribe(hass, topic, async_discovery_message_received, 0)
for topic in discovery_topics
)
)
hass.data[LAST_DISCOVERY] = time.time()
mqtt_integrations = await async_get_mqtt(hass)
@ -289,9 +297,10 @@ async def async_start(
async def async_stop(hass: HomeAssistantType) -> bool:
"""Stop MQTT Discovery."""
if DISCOVERY_UNSUBSCRIBE in hass.data and hass.data[DISCOVERY_UNSUBSCRIBE]:
hass.data[DISCOVERY_UNSUBSCRIBE]()
hass.data[DISCOVERY_UNSUBSCRIBE] = None
if DISCOVERY_UNSUBSCRIBE in hass.data:
for unsub in hass.data[DISCOVERY_UNSUBSCRIBE]:
unsub()
hass.data[DISCOVERY_UNSUBSCRIBE] = []
if INTEGRATION_UNSUBSCRIBE in hass.data:
for key, unsub in list(hass.data[INTEGRATION_UNSUBSCRIBE].items()):
unsub()

View File

@ -46,10 +46,13 @@ async def test_subscribing_config_topic(hass, mqtt_mock):
discovery_topic = "homeassistant"
await async_start(hass, discovery_topic, entry)
assert mqtt_mock.async_subscribe.called
call_args = mqtt_mock.async_subscribe.mock_calls[0][1]
assert call_args[0] == discovery_topic + "/#"
assert call_args[2] == 0
call_args1 = mqtt_mock.async_subscribe.mock_calls[0][1]
assert call_args1[2] == 0
call_args2 = mqtt_mock.async_subscribe.mock_calls[1][1]
assert call_args2[2] == 0
topics = [call_args1[0], call_args2[0]]
assert discovery_topic + "/+/+/config" in topics
assert discovery_topic + "/+/+/+/config" in topics
async def test_invalid_topic(hass, mqtt_mock):