Reconfigure MQTT binary_sensor component if discovery info is changed (#18169)

* Recreate component if discovery info is changed

* Update component instead of remove+add

* Set name and unique_id in __init__

* Update unit test

* Cleanup

* More cleanup

* Refactor according to review comments

* Change discovery_hash

* Review comments, add tests

* Fix handling of value_template
pull/18585/head
emontnemery 2018-11-19 16:49:04 +01:00 committed by Paulus Schoutsen
parent 01953ab46b
commit de9bac9ee3
8 changed files with 431 additions and 66 deletions

View File

@ -5,7 +5,6 @@ For more details about this platform, please refer to the documentation at
https://home-assistant.io/components/binary_sensor.mqtt/
"""
import logging
from typing import Optional
import voluptuous as vol
@ -19,7 +18,8 @@ from homeassistant.const import (
from homeassistant.components.mqtt import (
ATTR_DISCOVERY_HASH, CONF_STATE_TOPIC, CONF_AVAILABILITY_TOPIC,
CONF_PAYLOAD_AVAILABLE, CONF_PAYLOAD_NOT_AVAILABLE, CONF_QOS,
MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo)
MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo,
subscription)
from homeassistant.components.mqtt.discovery import MQTT_DISCOVERY_NEW
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_connect
@ -79,21 +79,8 @@ async def _async_setup_entity(hass, config, async_add_entities,
value_template.hass = hass
async_add_entities([MqttBinarySensor(
config.get(CONF_NAME),
config.get(CONF_STATE_TOPIC),
config.get(CONF_AVAILABILITY_TOPIC),
config.get(CONF_DEVICE_CLASS),
config.get(CONF_QOS),
config.get(CONF_FORCE_UPDATE),
config.get(CONF_OFF_DELAY),
config.get(CONF_PAYLOAD_ON),
config.get(CONF_PAYLOAD_OFF),
config.get(CONF_PAYLOAD_AVAILABLE),
config.get(CONF_PAYLOAD_NOT_AVAILABLE),
value_template,
config.get(CONF_UNIQUE_ID),
config.get(CONF_DEVICE),
discovery_hash,
config,
discovery_hash
)])
@ -101,35 +88,71 @@ class MqttBinarySensor(MqttAvailability, MqttDiscoveryUpdate,
MqttEntityDeviceInfo, BinarySensorDevice):
"""Representation a binary sensor that is updated by MQTT."""
def __init__(self, name, state_topic, availability_topic, device_class,
qos, force_update, off_delay, payload_on, payload_off,
payload_available, payload_not_available, value_template,
unique_id: Optional[str], device_config: Optional[ConfigType],
discovery_hash):
def __init__(self, config, discovery_hash):
"""Initialize the MQTT binary sensor."""
MqttAvailability.__init__(self, availability_topic, qos,
payload_available, payload_not_available)
MqttDiscoveryUpdate.__init__(self, discovery_hash)
MqttEntityDeviceInfo.__init__(self, device_config)
self._name = name
self._config = config
self._state = None
self._state_topic = state_topic
self._device_class = device_class
self._payload_on = payload_on
self._payload_off = payload_off
self._qos = qos
self._force_update = force_update
self._off_delay = off_delay
self._template = value_template
self._unique_id = unique_id
self._discovery_hash = discovery_hash
self._sub_state = None
self._delay_listener = None
self._name = None
self._state_topic = None
self._device_class = None
self._payload_on = None
self._payload_off = None
self._qos = None
self._force_update = None
self._off_delay = None
self._template = None
self._unique_id = None
# Load config
self._setup_from_config(config)
availability_topic = config.get(CONF_AVAILABILITY_TOPIC)
payload_available = config.get(CONF_PAYLOAD_AVAILABLE)
payload_not_available = config.get(CONF_PAYLOAD_NOT_AVAILABLE)
device_config = config.get(CONF_DEVICE)
MqttAvailability.__init__(self, availability_topic, self._qos,
payload_available, payload_not_available)
MqttDiscoveryUpdate.__init__(self, discovery_hash,
self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config)
async def async_added_to_hass(self):
"""Subscribe mqtt events."""
await MqttAvailability.async_added_to_hass(self)
await MqttDiscoveryUpdate.async_added_to_hass(self)
await self._subscribe_topics()
async def discovery_update(self, discovery_payload):
"""Handle updated discovery message."""
config = PLATFORM_SCHEMA(discovery_payload)
self._setup_from_config(config)
await self.availability_discovery_update(config)
await self._subscribe_topics()
self.async_schedule_update_ha_state()
def _setup_from_config(self, config):
"""(Re)Setup the entity."""
self._name = config.get(CONF_NAME)
self._state_topic = config.get(CONF_STATE_TOPIC)
self._device_class = config.get(CONF_DEVICE_CLASS)
self._qos = config.get(CONF_QOS)
self._force_update = config.get(CONF_FORCE_UPDATE)
self._off_delay = config.get(CONF_OFF_DELAY)
self._payload_on = config.get(CONF_PAYLOAD_ON)
self._payload_off = config.get(CONF_PAYLOAD_OFF)
value_template = config.get(CONF_VALUE_TEMPLATE)
if value_template is not None and value_template.hass is None:
value_template.hass = self.hass
self._template = value_template
self._unique_id = config.get(CONF_UNIQUE_ID)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
@callback
def off_delay_listener(now):
"""Switch device off after a delay."""
@ -163,8 +186,16 @@ class MqttBinarySensor(MqttAvailability, MqttDiscoveryUpdate,
self.async_schedule_update_ha_state()
await mqtt.async_subscribe(
self.hass, self._state_topic, state_message_received, self._qos)
self._sub_state = await subscription.async_subscribe_topics(
self.hass, self._sub_state,
{'state_topic': {'topic': self._state_topic,
'msg_callback': state_message_received,
'qos': self._qos}})
async def async_will_remove_from_hass(self):
"""Unsubscribe when removed."""
await subscription.async_unsubscribe_topics(self.hass, self._sub_state)
await MqttAvailability.async_will_remove_from_hass(self)
@property
def should_poll(self):

View File

@ -832,12 +832,30 @@ class MqttAvailability(Entity):
self._available = availability_topic is None # type: bool
self._payload_available = payload_available
self._payload_not_available = payload_not_available
self._availability_sub_state = None
async def async_added_to_hass(self) -> None:
"""Subscribe MQTT events.
This method must be run in the event loop and returns a coroutine.
"""
await self._availability_subscribe_topics()
async def availability_discovery_update(self, config: dict):
"""Handle updated discovery message."""
self._availability_setup_from_config(config)
await self._availability_subscribe_topics()
def _availability_setup_from_config(self, config):
"""(Re)Setup."""
self._availability_topic = config.get(CONF_AVAILABILITY_TOPIC)
self._payload_available = config.get(CONF_PAYLOAD_AVAILABLE)
self._payload_not_available = config.get(CONF_PAYLOAD_NOT_AVAILABLE)
async def _availability_subscribe_topics(self):
"""(Re)Subscribe to topics."""
from .subscription import async_subscribe_topics
@callback
def availability_message_received(topic: str,
payload: SubscribePayloadType,
@ -850,10 +868,17 @@ class MqttAvailability(Entity):
self.async_schedule_update_ha_state()
if self._availability_topic is not None:
await async_subscribe(
self.hass, self._availability_topic,
availability_message_received, self._availability_qos)
self._availability_sub_state = await async_subscribe_topics(
self.hass, self._availability_sub_state,
{'availability_topic': {
'topic': self._availability_topic,
'msg_callback': availability_message_received,
'qos': self._availability_qos}})
async def async_will_remove_from_hass(self):
"""Unsubscribe when removed."""
from .subscription import async_unsubscribe_topics
await async_unsubscribe_topics(self.hass, self._availability_sub_state)
@property
def available(self) -> bool:
@ -864,9 +889,10 @@ class MqttAvailability(Entity):
class MqttDiscoveryUpdate(Entity):
"""Mixin used to handle updated discovery message."""
def __init__(self, discovery_hash) -> None:
def __init__(self, discovery_hash, discovery_update=None) -> None:
"""Initialize the discovery update mixin."""
self._discovery_hash = discovery_hash
self._discovery_update = discovery_update
self._remove_signal = None
async def async_added_to_hass(self) -> None:
@ -886,6 +912,10 @@ class MqttDiscoveryUpdate(Entity):
self.hass.async_create_task(self.async_remove())
del self.hass.data[ALREADY_DISCOVERED][self._discovery_hash]
self._remove_signal()
elif self._discovery_update:
# Non-empty payload: Notify component
_LOGGER.info("Updating component: %s", self.entity_id)
self.hass.async_create_task(self._discovery_update(payload))
if self._discovery_hash:
self._remove_signal = async_dispatcher_connect(

View File

@ -208,15 +208,32 @@ async def async_start(hass: HomeAssistantType, discovery_topic, hass_config,
if value[-1] == TOPIC_BASE and key.endswith('_topic'):
payload[key] = "{}{}".format(value[:-1], base)
# If present, the node_id will be included in the discovered object id
discovery_id = '_'.join((node_id, object_id)) if node_id else object_id
# If present, unique_id is used as the discovered object id. Otherwise,
# if present, the node_id will be included in the discovered object id
discovery_id = payload.get(
'unique_id', ' '.join(
(node_id, object_id)) if node_id else object_id)
discovery_hash = (component, discovery_id)
if payload:
platform = payload.get(CONF_PLATFORM, 'mqtt')
if platform not in ALLOWED_PLATFORMS.get(component, []):
_LOGGER.warning("Platform %s (component %s) is not allowed",
platform, component)
return
payload[CONF_PLATFORM] = platform
if CONF_STATE_TOPIC not in payload:
payload[CONF_STATE_TOPIC] = '{}/{}/{}{}/state'.format(
discovery_topic, component,
'%s/' % node_id if node_id else '', object_id)
payload[ATTR_DISCOVERY_HASH] = discovery_hash
if ALREADY_DISCOVERED not in hass.data:
hass.data[ALREADY_DISCOVERED] = {}
discovery_hash = (component, discovery_id)
if discovery_hash in hass.data[ALREADY_DISCOVERED]:
# Dispatch update
_LOGGER.info(
"Component has already been discovered: %s %s, sending update",
component, discovery_id)
@ -224,22 +241,8 @@ async def async_start(hass: HomeAssistantType, discovery_topic, hass_config,
hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), payload)
elif payload:
# Add component
platform = payload.get(CONF_PLATFORM, 'mqtt')
if platform not in ALLOWED_PLATFORMS.get(component, []):
_LOGGER.warning("Platform %s (component %s) is not allowed",
platform, component)
return
payload[CONF_PLATFORM] = platform
if CONF_STATE_TOPIC not in payload:
payload[CONF_STATE_TOPIC] = '{}/{}/{}{}/state'.format(
discovery_topic, component,
'%s/' % node_id if node_id else '', object_id)
hass.data[ALREADY_DISCOVERED][discovery_hash] = None
payload[ATTR_DISCOVERY_HASH] = discovery_hash
_LOGGER.info("Found new component: %s %s", component, discovery_id)
hass.data[ALREADY_DISCOVERED][discovery_hash] = None
if platform not in CONFIG_ENTRY_PLATFORMS.get(component, []):
await async_load_platform(

View File

@ -0,0 +1,54 @@
"""
Helper to handle a set of topics to subscribe to.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/mqtt/
"""
import logging
from homeassistant.components import mqtt
from homeassistant.components.mqtt import DEFAULT_QOS
from homeassistant.loader import bind_hass
from homeassistant.helpers.typing import HomeAssistantType
_LOGGER = logging.getLogger(__name__)
@bind_hass
async def async_subscribe_topics(hass: HomeAssistantType, sub_state: dict,
topics: dict):
"""(Re)Subscribe to a set of MQTT topics.
State is kept in sub_state.
"""
cur_state = sub_state if sub_state is not None else {}
sub_state = {}
for key in topics:
topic = topics[key].get('topic', None)
msg_callback = topics[key].get('msg_callback', None)
qos = topics[key].get('qos', DEFAULT_QOS)
encoding = topics[key].get('encoding', 'utf-8')
topic = (topic, msg_callback, qos, encoding)
(cur_topic, unsub) = cur_state.pop(
key, ((None, None, None, None), None))
if topic != cur_topic and topic[0] is not None:
if unsub is not None:
unsub()
unsub = await mqtt.async_subscribe(
hass, topic[0], topic[1], topic[2], topic[3])
sub_state[key] = (topic, unsub)
for key, (topic, unsub) in list(cur_state.items()):
if unsub is not None:
unsub()
return sub_state
@bind_hass
async def async_unsubscribe_topics(hass: HomeAssistantType, sub_state: dict):
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
await async_subscribe_topics(hass, sub_state, {})
return sub_state

View File

@ -295,6 +295,7 @@ def async_mock_mqtt_component(hass, config=None):
with patch('paho.mqtt.client.Client') as mock_client:
mock_client().connect.return_value = 0
mock_client().subscribe.return_value = (0, 0)
mock_client().unsubscribe.return_value = (0, 0)
mock_client().publish.return_value = (0, 0)
result = yield from async_setup_component(hass, mqtt.DOMAIN, {

View File

@ -284,7 +284,8 @@ async def test_discovery_removal_binary_sensor(hass, mqtt_mock, caplog):
await async_start(hass, 'homeassistant', {}, entry)
data = (
'{ "name": "Beer",'
' "status_topic": "test_topic" }'
' "state_topic": "test_topic",'
' "availability_topic": "availability_topic" }'
)
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config',
data)
@ -300,6 +301,71 @@ async def test_discovery_removal_binary_sensor(hass, mqtt_mock, caplog):
assert state is None
async def test_discovery_update_binary_sensor(hass, mqtt_mock, caplog):
"""Test removal of discovered binary_sensor."""
entry = MockConfigEntry(domain=mqtt.DOMAIN)
await async_start(hass, 'homeassistant', {}, entry)
data1 = (
'{ "name": "Beer",'
' "state_topic": "test_topic",'
' "availability_topic": "availability_topic1" }'
)
data2 = (
'{ "name": "Milk",'
' "state_topic": "test_topic2",'
' "availability_topic": "availability_topic2" }'
)
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config',
data1)
await hass.async_block_till_done()
state = hass.states.get('binary_sensor.beer')
assert state is not None
assert state.name == 'Beer'
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config',
data2)
await hass.async_block_till_done()
await hass.async_block_till_done()
state = hass.states.get('binary_sensor.beer')
assert state is not None
assert state.name == 'Milk'
state = hass.states.get('binary_sensor.milk')
assert state is None
async def test_discovery_unique_id(hass, mqtt_mock, caplog):
"""Test unique id option only creates one sensor per unique_id."""
entry = MockConfigEntry(domain=mqtt.DOMAIN)
await async_start(hass, 'homeassistant', {}, entry)
data1 = (
'{ "name": "Beer",'
' "state_topic": "test_topic",'
' "unique_id": "TOTALLY_UNIQUE" }'
)
data2 = (
'{ "name": "Milk",'
' "state_topic": "test_topic",'
' "unique_id": "TOTALLY_DIFFERENT" }'
)
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config',
data1)
await hass.async_block_till_done()
state = hass.states.get('binary_sensor.beer')
assert state is not None
assert state.name == 'Beer'
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config',
data2)
await hass.async_block_till_done()
await hass.async_block_till_done()
state = hass.states.get('binary_sensor.beer')
assert state is not None
assert state.name == 'Beer'
state = hass.states.get('binary_sensor.milk')
assert state is not None
assert state.name == 'Milk'
async def test_entity_device_info_with_identifier(hass, mqtt_mock):
"""Test MQTT binary sensor device registry integration."""
entry = MockConfigEntry(domain=mqtt.DOMAIN)

View File

@ -185,7 +185,7 @@ def test_discovery_incl_nodeid(hass, mqtt_mock, caplog):
assert state is not None
assert state.name == 'Beer'
assert ('binary_sensor', 'my_node_id_bla') in hass.data[ALREADY_DISCOVERED]
assert ('binary_sensor', 'my_node_id bla') in hass.data[ALREADY_DISCOVERED]
@asyncio.coroutine

View File

@ -0,0 +1,180 @@
"""The tests for the MQTT subscription component."""
from homeassistant.core import callback
from homeassistant.components.mqtt.subscription import (
async_subscribe_topics, async_unsubscribe_topics)
from tests.common import async_fire_mqtt_message, async_mock_mqtt_component
async def test_subscribe_topics(hass, mqtt_mock, caplog):
"""Test subscription to topics."""
calls1 = []
@callback
def record_calls1(*args):
"""Record calls."""
calls1.append(args)
calls2 = []
@callback
def record_calls2(*args):
"""Record calls."""
calls2.append(args)
sub_state = None
sub_state = await async_subscribe_topics(
hass, sub_state,
{'test_topic1': {'topic': 'test-topic1',
'msg_callback': record_calls1},
'test_topic2': {'topic': 'test-topic2',
'msg_callback': record_calls2}})
async_fire_mqtt_message(hass, 'test-topic1', 'test-payload1')
await hass.async_block_till_done()
assert 1 == len(calls1)
assert 'test-topic1' == calls1[0][0]
assert 'test-payload1' == calls1[0][1]
assert 0 == len(calls2)
async_fire_mqtt_message(hass, 'test-topic2', 'test-payload2')
await hass.async_block_till_done()
await hass.async_block_till_done()
assert 1 == len(calls1)
assert 1 == len(calls2)
assert 'test-topic2' == calls2[0][0]
assert 'test-payload2' == calls2[0][1]
await async_unsubscribe_topics(hass, sub_state)
async_fire_mqtt_message(hass, 'test-topic1', 'test-payload')
async_fire_mqtt_message(hass, 'test-topic2', 'test-payload')
await hass.async_block_till_done()
assert 1 == len(calls1)
assert 1 == len(calls2)
async def test_modify_topics(hass, mqtt_mock, caplog):
"""Test modification of topics."""
calls1 = []
@callback
def record_calls1(*args):
"""Record calls."""
calls1.append(args)
calls2 = []
@callback
def record_calls2(*args):
"""Record calls."""
calls2.append(args)
sub_state = None
sub_state = await async_subscribe_topics(
hass, sub_state,
{'test_topic1': {'topic': 'test-topic1',
'msg_callback': record_calls1},
'test_topic2': {'topic': 'test-topic2',
'msg_callback': record_calls2}})
async_fire_mqtt_message(hass, 'test-topic1', 'test-payload')
await hass.async_block_till_done()
assert 1 == len(calls1)
assert 0 == len(calls2)
async_fire_mqtt_message(hass, 'test-topic2', 'test-payload')
await hass.async_block_till_done()
await hass.async_block_till_done()
assert 1 == len(calls1)
assert 1 == len(calls2)
sub_state = await async_subscribe_topics(
hass, sub_state,
{'test_topic1': {'topic': 'test-topic1_1',
'msg_callback': record_calls1}})
async_fire_mqtt_message(hass, 'test-topic1', 'test-payload')
async_fire_mqtt_message(hass, 'test-topic2', 'test-payload')
await hass.async_block_till_done()
await hass.async_block_till_done()
assert 1 == len(calls1)
assert 1 == len(calls2)
async_fire_mqtt_message(hass, 'test-topic1_1', 'test-payload')
await hass.async_block_till_done()
await hass.async_block_till_done()
assert 2 == len(calls1)
assert 'test-topic1_1' == calls1[1][0]
assert 'test-payload' == calls1[1][1]
assert 1 == len(calls2)
await async_unsubscribe_topics(hass, sub_state)
async_fire_mqtt_message(hass, 'test-topic1_1', 'test-payload')
async_fire_mqtt_message(hass, 'test-topic2', 'test-payload')
await hass.async_block_till_done()
assert 2 == len(calls1)
assert 1 == len(calls2)
async def test_qos_encoding_default(hass, mqtt_mock, caplog):
"""Test default qos and encoding."""
mock_mqtt = await async_mock_mqtt_component(hass)
@callback
def msg_callback(*args):
"""Do nothing."""
pass
sub_state = None
sub_state = await async_subscribe_topics(
hass, sub_state,
{'test_topic1': {'topic': 'test-topic1',
'msg_callback': msg_callback}})
mock_mqtt.async_subscribe.assert_called_once_with(
'test-topic1', msg_callback, 0, 'utf-8')
async def test_qos_encoding_custom(hass, mqtt_mock, caplog):
"""Test custom qos and encoding."""
mock_mqtt = await async_mock_mqtt_component(hass)
@callback
def msg_callback(*args):
"""Do nothing."""
pass
sub_state = None
sub_state = await async_subscribe_topics(
hass, sub_state,
{'test_topic1': {'topic': 'test-topic1',
'msg_callback': msg_callback,
'qos': 1,
'encoding': 'utf-16'}})
mock_mqtt.async_subscribe.assert_called_once_with(
'test-topic1', msg_callback, 1, 'utf-16')
async def test_no_change(hass, mqtt_mock, caplog):
"""Test subscription to topics without change."""
mock_mqtt = await async_mock_mqtt_component(hass)
@callback
def msg_callback(*args):
"""Do nothing."""
pass
sub_state = None
sub_state = await async_subscribe_topics(
hass, sub_state,
{'test_topic1': {'topic': 'test-topic1',
'msg_callback': msg_callback}})
call_count = mock_mqtt.async_subscribe.call_count
sub_state = await async_subscribe_topics(
hass, sub_state,
{'test_topic1': {'topic': 'test-topic1',
'msg_callback': msg_callback}})
assert call_count == mock_mqtt.async_subscribe.call_count