Subscribe per component for MQTT discovery (#119974)

* Subscribe per component for MQTT discovery

* Use single assignment

* Handle wildcard subscriptions first

* Split subsRecription handling, update helper

* Fix help_all_subscribe_calls

* Fix import

* Fix test

* Update import order

* Undo move self._last_subscribe

* Recover removed test

* Revert not needed changes to binary_sensor platform tests

* Revert line removal

* Rework interation of discovery topics

* Reduce

* Add comment

* Move comment

* Chain subscriptions
pull/124303/head^2
Jan Bouwhuis 2024-08-20 17:02:48 +02:00 committed by GitHub
parent a1e3e7f24f
commit b74aced6f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 50 additions and 26 deletions

View File

@ -111,6 +111,7 @@ UNSUBSCRIBE_COOLDOWN = 0.1
TIMEOUT_ACK = 10
RECONNECT_INTERVAL_SECONDS = 10
MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1
MAX_SUBSCRIBES_PER_CALL = 500
MAX_UNSUBSCRIBES_PER_CALL = 500
@ -893,14 +894,27 @@ class MQTT:
if not self._pending_subscriptions:
return
subscriptions: dict[str, int] = self._pending_subscriptions
# Split out the wildcard subscriptions, we subscribe to them one by one
pending_subscriptions: dict[str, int] = self._pending_subscriptions
pending_wildcard_subscriptions = {
subscription.topic: pending_subscriptions.pop(subscription.topic)
for subscription in self._wildcard_subscriptions
if subscription.topic in pending_subscriptions
}
self._pending_subscriptions = {}
subscription_list = list(subscriptions.items())
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
for chunk in chunked_or_all(subscription_list, MAX_SUBSCRIBES_PER_CALL):
for chunk in chain(
chunked_or_all(
pending_wildcard_subscriptions.items(), MAX_WILDCARD_SUBSCRIBES_PER_CALL
),
chunked_or_all(pending_subscriptions.items(), MAX_SUBSCRIBES_PER_CALL),
):
chunk_list = list(chunk)
if not chunk_list:
continue
result, mid = self._mqttc.subscribe(chunk_list)

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
from collections import deque
import functools
from itertools import chain
import logging
import re
import time
@ -238,10 +239,6 @@ async def async_start( # noqa: C901
component, node_id, object_id = match.groups()
if component not in SUPPORTED_COMPONENTS:
_LOGGER.warning("Integration %s is not supported", component)
return
if payload:
try:
discovery_payload = MQTTDiscoveryPayload(json_loads_object(payload))
@ -351,9 +348,15 @@ async def async_start( # noqa: C901
0,
job_type=HassJobType.Callback,
)
for topic in (
f"{discovery_topic}/+/+/config",
f"{discovery_topic}/+/+/+/config",
for topic in chain(
(
f"{discovery_topic}/{component}/+/config"
for component in SUPPORTED_COMPONENTS
),
(
f"{discovery_topic}/{component}/+/+/config"
for component in SUPPORTED_COMPONENTS
),
)
]

View File

@ -13,6 +13,7 @@ import pytest
from homeassistant.components import mqtt
from homeassistant.components.mqtt.client import RECONNECT_INTERVAL_SECONDS
from homeassistant.components.mqtt.const import SUPPORTED_COMPONENTS
from homeassistant.components.mqtt.models import MessageCallbackType, ReceiveMessage
from homeassistant.config_entries import ConfigEntryDisabler, ConfigEntryState
from homeassistant.const import (
@ -1614,8 +1615,9 @@ async def test_subscription_done_when_birth_message_is_sent(
"""Test sending birth message until initial subscription has been completed."""
mqtt_client_mock = setup_with_birth_msg_client_mock
subscribe_calls = help_all_subscribe_calls(mqtt_client_mock)
assert ("homeassistant/+/+/config", 0) in subscribe_calls
assert ("homeassistant/+/+/+/config", 0) in subscribe_calls
for component in SUPPORTED_COMPONENTS:
assert (f"homeassistant/{component}/+/config", 0) in subscribe_calls
assert (f"homeassistant/{component}/+/+/config", 0) in subscribe_calls
mqtt_client_mock.publish.assert_called_with(
"homeassistant/status", "online", 0, False
)

View File

@ -16,7 +16,10 @@ import yaml
from homeassistant import config as module_hass_config
from homeassistant.components import mqtt
from homeassistant.components.mqtt import debug_info
from homeassistant.components.mqtt.const import MQTT_CONNECTION_STATE
from homeassistant.components.mqtt.const import (
MQTT_CONNECTION_STATE,
SUPPORTED_COMPONENTS,
)
from homeassistant.components.mqtt.mixins import MQTT_ATTRIBUTES_BLOCKED
from homeassistant.components.mqtt.models import PublishPayloadType
from homeassistant.config_entries import ConfigEntryState
@ -75,9 +78,12 @@ type _StateDataType = list[tuple[_MqttMessageType, str | None, _AttributesType |
def help_all_subscribe_calls(mqtt_client_mock: MqttMockPahoClient) -> list[Any]:
"""Test of a call."""
all_calls = []
for calls in mqtt_client_mock.subscribe.mock_calls:
for call in calls[1]:
all_calls.extend(call)
for call_l1 in mqtt_client_mock.subscribe.mock_calls:
if isinstance(call_l1[1][0], list):
for call_l2 in call_l1[1]:
all_calls.extend(call_l2)
else:
all_calls.append(call_l1[1])
return all_calls
@ -1178,7 +1184,10 @@ async def help_test_entity_id_update_subscriptions(
state = hass.states.get(f"{domain}.test")
assert state is not None
assert mqtt_mock.async_subscribe.call_count == len(topics) + 2 + DISCOVERY_COUNT
assert (
mqtt_mock.async_subscribe.call_count
== len(topics) + 2 * len(SUPPORTED_COMPONENTS) + DISCOVERY_COUNT
)
for topic in topics:
mqtt_mock.async_subscribe.assert_any_call(
topic, ANY, ANY, ANY, HassJobType.Callback

View File

@ -15,6 +15,7 @@ from homeassistant.components.mqtt.abbreviations import (
ABBREVIATIONS,
DEVICE_ABBREVIATIONS,
)
from homeassistant.components.mqtt.const import SUPPORTED_COMPONENTS
from homeassistant.components.mqtt.discovery import (
MQTT_DISCOVERY_DONE,
MQTT_DISCOVERY_NEW,
@ -73,13 +74,10 @@ async def test_subscribing_config_topic(
discovery_topic = "homeassistant"
await async_start(hass, discovery_topic, entry)
call_args1 = mqtt_mock.async_subscribe.mock_calls[0][1]
assert call_args1[2] == 0
call_args2 = mqtt_mock.async_subscribe.mock_calls[1][1]
assert call_args2[2] == 0
topics = [call_args1[0], call_args2[0]]
assert discovery_topic + "/+/+/config" in topics
assert discovery_topic + "/+/+/+/config" in topics
topics = [call[1][0] for call in mqtt_mock.async_subscribe.mock_calls]
for component in SUPPORTED_COMPONENTS:
assert f"{discovery_topic}/{component}/+/config" in topics
assert f"{discovery_topic}/{component}/+/+/config" in topics
@pytest.mark.parametrize(
@ -198,8 +196,6 @@ async def test_only_valid_components(
await hass.async_block_till_done()
assert f"Integration {invalid_component} is not supported" in caplog.text
assert not mock_dispatcher_send.called