Add WS subscription command for MQTT (#21696)

* Add WS subscription command for MQTT

* Add test

* Add check for connected

* Rename event_listeners to subscriptions
pull/21922/head
Paulus Schoutsen 2019-03-10 20:07:09 -07:00 committed by GitHub
parent fc85b3fc5f
commit 429bbc05dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 90 additions and 17 deletions

View File

@ -25,7 +25,7 @@ from homeassistant.const import (
CONF_PROTOCOL, CONF_USERNAME, CONF_VALUE_TEMPLATE,
EVENT_HOMEASSISTANT_STOP)
from homeassistant.core import Event, ServiceCall, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.typing import (
@ -35,6 +35,7 @@ from homeassistant.setup import async_prepare_setup_platform
from homeassistant.util.async_ import (
run_callback_threadsafe, run_coroutine_threadsafe)
from homeassistant.util.logging import catch_log_exception
from homeassistant.components import websocket_api
# Loading the config flow file will register the flow
from . import config_flow # noqa pylint: disable=unused-import
@ -391,6 +392,8 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
# This needs a better solution.
hass.data[DATA_MQTT_HASS_CONFIG] = config
websocket_api.async_register_command(hass, websocket_subscribe)
if conf is None:
# If we have a config entry, setup is done by that config entry.
# If there is no config entry, this should fail.
@ -602,6 +605,7 @@ class MQTT:
self.keepalive = keepalive
self.subscriptions = [] # type: List[Subscription]
self.birth_message = birth_message
self.connected = False
self._mqttc = None # type: mqtt.Client
self._paho_lock = asyncio.Lock(loop=hass.loop)
@ -703,7 +707,10 @@ class MQTT:
if any(other.topic == topic for other in self.subscriptions):
# Other subscriptions on topic remaining - don't unsubscribe.
return
self.hass.async_create_task(self._async_unsubscribe(topic))
# Only unsubscribe if currently connected.
if self.connected:
self.hass.async_create_task(self._async_unsubscribe(topic))
return async_remove
@ -743,6 +750,8 @@ class MQTT:
self._mqttc.disconnect()
return
self.connected = True
# Group subscriptions to only re-subscribe once for each topic.
keyfunc = attrgetter('topic')
for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc),
@ -782,6 +791,8 @@ class MQTT:
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
"""Disconnected callback."""
self.connected = False
# When disconnected because of calling disconnect()
if result_code == 0:
return
@ -791,6 +802,7 @@ class MQTT:
while True:
try:
if self._mqttc.reconnect() == 0:
self.connected = True
_LOGGER.info("Successfully reconnected to the MQTT server")
break
except socket.error:
@ -1040,3 +1052,27 @@ class MqttEntityDeviceInfo(Entity):
info['via_hub'] = (DOMAIN, self._device_config[CONF_VIA_HUB])
return info
@websocket_api.async_response
@websocket_api.websocket_command({
vol.Required('type'): 'mqtt/subscribe',
vol.Required('topic'): valid_subscribe_topic,
})
async def websocket_subscribe(hass, connection, msg):
"""Subscribe to a MQTT topic."""
if not connection.user.is_admin:
raise Unauthorized
async def forward_messages(topic: str, payload: str, qos: int):
"""Forward events to websocket."""
connection.send_message(websocket_api.event_message(msg['id'], {
'topic': topic,
'payload': payload,
'qos': qos,
}))
connection.subscriptions[msg['id']] = await async_subscribe(
hass, msg['topic'], forward_messages)
connection.send_message(websocket_api.result_message(msg['id']))

View File

@ -14,6 +14,7 @@ ActiveConnection = connection.ActiveConnection
BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA
error_message = messages.error_message
result_message = messages.result_message
event_message = messages.event_message
async_response = decorators.async_response
require_admin = decorators.require_admin
ws_require_user = decorators.ws_require_user

View File

@ -24,15 +24,6 @@ def async_register_commands(hass):
async_reg(handle_ping)
def event_message(iden, event):
"""Return an event message."""
return {
'id': iden,
'type': 'event',
'event': event.as_dict(),
}
def pong_message(iden):
"""Return a pong message."""
return {
@ -59,9 +50,11 @@ def handle_subscribe_events(hass, connection, msg):
if event.event_type == EVENT_TIME_CHANGED:
return
connection.send_message(event_message(msg['id'], event))
connection.send_message(messages.event_message(
msg['id'], event.as_dict()
))
connection.event_listeners[msg['id']] = hass.bus.async_listen(
connection.subscriptions[msg['id']] = hass.bus.async_listen(
msg['event_type'], forward_events)
connection.send_message(messages.result_message(msg['id']))
@ -79,8 +72,8 @@ def handle_unsubscribe_events(hass, connection, msg):
"""
subscription = msg['subscription']
if subscription in connection.event_listeners:
connection.event_listeners.pop(subscription)()
if subscription in connection.subscriptions:
connection.subscriptions.pop(subscription)()
connection.send_message(messages.result_message(msg['id']))
else:
connection.send_message(messages.error_message(

View File

@ -21,7 +21,7 @@ class ActiveConnection:
else:
self.refresh_token_id = None
self.event_listeners = {}
self.subscriptions = {}
self.last_id = 0
def context(self, msg):
@ -82,7 +82,7 @@ class ActiveConnection:
@callback
def async_close(self):
"""Close down connection."""
for unsub in self.event_listeners.values():
for unsub in self.subscriptions.values():
unsub()
@callback

View File

@ -40,3 +40,12 @@ def error_message(iden, code, message):
'message': message,
},
}
def event_message(iden, event):
"""Return an event message."""
return {
'id': iden,
'type': 'event',
'event': event,
}

View File

@ -767,3 +767,37 @@ async def test_message_callback_exception_gets_logged(hass, caplog):
assert \
"Exception in bad_handler when handling msg on 'test-topic':" \
" 'test'" in caplog.text
async def test_mqtt_ws_subscription(hass, hass_ws_client):
"""Test MQTT websocket subscription."""
await async_mock_mqtt_component(hass)
client = await hass_ws_client(hass)
await client.send_json({
'id': 5,
'type': 'mqtt/subscribe',
'topic': 'test-topic',
})
response = await client.receive_json()
assert response['success']
async_fire_mqtt_message(hass, 'test-topic', 'test1')
async_fire_mqtt_message(hass, 'test-topic', 'test2')
response = await client.receive_json()
assert response['event']['topic'] == 'test-topic'
assert response['event']['payload'] == 'test1'
response = await client.receive_json()
assert response['event']['topic'] == 'test-topic'
assert response['event']['payload'] == 'test2'
# Unsubscribe
await client.send_json({
'id': 8,
'type': 'unsubscribe_events',
'subscription': 5,
})
response = await client.receive_json()
assert response['success']