diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index 1e47595058d..347166fdb82 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -240,9 +240,17 @@ async def async_start( hass.data[ALREADY_DISCOVERED] = {} hass.data[PENDING_DISCOVERED] = {} - hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe( - hass, f"{discovery_topic}/#", async_discovery_message_received, 0 + discovery_topics = [ + f"{discovery_topic}/+/+/config", + f"{discovery_topic}/+/+/+/config", + ] + hass.data[DISCOVERY_UNSUBSCRIBE] = await asyncio.gather( + *( + mqtt.async_subscribe(hass, topic, async_discovery_message_received, 0) + for topic in discovery_topics + ) ) + hass.data[LAST_DISCOVERY] = time.time() mqtt_integrations = await async_get_mqtt(hass) @@ -289,9 +297,10 @@ async def async_start( async def async_stop(hass: HomeAssistantType) -> bool: """Stop MQTT Discovery.""" - if DISCOVERY_UNSUBSCRIBE in hass.data and hass.data[DISCOVERY_UNSUBSCRIBE]: - hass.data[DISCOVERY_UNSUBSCRIBE]() - hass.data[DISCOVERY_UNSUBSCRIBE] = None + if DISCOVERY_UNSUBSCRIBE in hass.data: + for unsub in hass.data[DISCOVERY_UNSUBSCRIBE]: + unsub() + hass.data[DISCOVERY_UNSUBSCRIBE] = [] if INTEGRATION_UNSUBSCRIBE in hass.data: for key, unsub in list(hass.data[INTEGRATION_UNSUBSCRIBE].items()): unsub() diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index 124f40c31fa..c9b0879d490 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -46,10 +46,13 @@ async def test_subscribing_config_topic(hass, mqtt_mock): discovery_topic = "homeassistant" await async_start(hass, discovery_topic, entry) - assert mqtt_mock.async_subscribe.called - call_args = mqtt_mock.async_subscribe.mock_calls[0][1] - assert call_args[0] == discovery_topic + "/#" - assert call_args[2] == 0 + call_args1 = mqtt_mock.async_subscribe.mock_calls[0][1] + assert call_args1[2] == 0 + call_args2 = mqtt_mock.async_subscribe.mock_calls[1][1] + assert call_args2[2] == 0 + topics = [call_args1[0], call_args2[0]] + assert discovery_topic + "/+/+/config" in topics + assert discovery_topic + "/+/+/+/config" in topics async def test_invalid_topic(hass, mqtt_mock):