Add WS subscription command for MQTT (#21696)
* Add WS subscription command for MQTT * Add test * Add check for connected * Rename event_listeners to subscriptionspull/21922/head
parent
fc85b3fc5f
commit
429bbc05dc
|
@ -25,7 +25,7 @@ from homeassistant.const import (
|
|||
CONF_PROTOCOL, CONF_USERNAME, CONF_VALUE_TEMPLATE,
|
||||
EVENT_HOMEASSISTANT_STOP)
|
||||
from homeassistant.core import Event, ServiceCall, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.exceptions import HomeAssistantError, Unauthorized
|
||||
from homeassistant.helpers import config_validation as cv, template
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.typing import (
|
||||
|
@ -35,6 +35,7 @@ from homeassistant.setup import async_prepare_setup_platform
|
|||
from homeassistant.util.async_ import (
|
||||
run_callback_threadsafe, run_coroutine_threadsafe)
|
||||
from homeassistant.util.logging import catch_log_exception
|
||||
from homeassistant.components import websocket_api
|
||||
|
||||
# Loading the config flow file will register the flow
|
||||
from . import config_flow # noqa pylint: disable=unused-import
|
||||
|
@ -391,6 +392,8 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
|
|||
# This needs a better solution.
|
||||
hass.data[DATA_MQTT_HASS_CONFIG] = config
|
||||
|
||||
websocket_api.async_register_command(hass, websocket_subscribe)
|
||||
|
||||
if conf is None:
|
||||
# If we have a config entry, setup is done by that config entry.
|
||||
# If there is no config entry, this should fail.
|
||||
|
@ -602,6 +605,7 @@ class MQTT:
|
|||
self.keepalive = keepalive
|
||||
self.subscriptions = [] # type: List[Subscription]
|
||||
self.birth_message = birth_message
|
||||
self.connected = False
|
||||
self._mqttc = None # type: mqtt.Client
|
||||
self._paho_lock = asyncio.Lock(loop=hass.loop)
|
||||
|
||||
|
@ -703,7 +707,10 @@ class MQTT:
|
|||
if any(other.topic == topic for other in self.subscriptions):
|
||||
# Other subscriptions on topic remaining - don't unsubscribe.
|
||||
return
|
||||
self.hass.async_create_task(self._async_unsubscribe(topic))
|
||||
|
||||
# Only unsubscribe if currently connected.
|
||||
if self.connected:
|
||||
self.hass.async_create_task(self._async_unsubscribe(topic))
|
||||
|
||||
return async_remove
|
||||
|
||||
|
@ -743,6 +750,8 @@ class MQTT:
|
|||
self._mqttc.disconnect()
|
||||
return
|
||||
|
||||
self.connected = True
|
||||
|
||||
# Group subscriptions to only re-subscribe once for each topic.
|
||||
keyfunc = attrgetter('topic')
|
||||
for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc),
|
||||
|
@ -782,6 +791,8 @@ class MQTT:
|
|||
|
||||
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
|
||||
"""Disconnected callback."""
|
||||
self.connected = False
|
||||
|
||||
# When disconnected because of calling disconnect()
|
||||
if result_code == 0:
|
||||
return
|
||||
|
@ -791,6 +802,7 @@ class MQTT:
|
|||
while True:
|
||||
try:
|
||||
if self._mqttc.reconnect() == 0:
|
||||
self.connected = True
|
||||
_LOGGER.info("Successfully reconnected to the MQTT server")
|
||||
break
|
||||
except socket.error:
|
||||
|
@ -1040,3 +1052,27 @@ class MqttEntityDeviceInfo(Entity):
|
|||
info['via_hub'] = (DOMAIN, self._device_config[CONF_VIA_HUB])
|
||||
|
||||
return info
|
||||
|
||||
|
||||
@websocket_api.async_response
|
||||
@websocket_api.websocket_command({
|
||||
vol.Required('type'): 'mqtt/subscribe',
|
||||
vol.Required('topic'): valid_subscribe_topic,
|
||||
})
|
||||
async def websocket_subscribe(hass, connection, msg):
|
||||
"""Subscribe to a MQTT topic."""
|
||||
if not connection.user.is_admin:
|
||||
raise Unauthorized
|
||||
|
||||
async def forward_messages(topic: str, payload: str, qos: int):
|
||||
"""Forward events to websocket."""
|
||||
connection.send_message(websocket_api.event_message(msg['id'], {
|
||||
'topic': topic,
|
||||
'payload': payload,
|
||||
'qos': qos,
|
||||
}))
|
||||
|
||||
connection.subscriptions[msg['id']] = await async_subscribe(
|
||||
hass, msg['topic'], forward_messages)
|
||||
|
||||
connection.send_message(websocket_api.result_message(msg['id']))
|
||||
|
|
|
@ -14,6 +14,7 @@ ActiveConnection = connection.ActiveConnection
|
|||
BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA
|
||||
error_message = messages.error_message
|
||||
result_message = messages.result_message
|
||||
event_message = messages.event_message
|
||||
async_response = decorators.async_response
|
||||
require_admin = decorators.require_admin
|
||||
ws_require_user = decorators.ws_require_user
|
||||
|
|
|
@ -24,15 +24,6 @@ def async_register_commands(hass):
|
|||
async_reg(handle_ping)
|
||||
|
||||
|
||||
def event_message(iden, event):
|
||||
"""Return an event message."""
|
||||
return {
|
||||
'id': iden,
|
||||
'type': 'event',
|
||||
'event': event.as_dict(),
|
||||
}
|
||||
|
||||
|
||||
def pong_message(iden):
|
||||
"""Return a pong message."""
|
||||
return {
|
||||
|
@ -59,9 +50,11 @@ def handle_subscribe_events(hass, connection, msg):
|
|||
if event.event_type == EVENT_TIME_CHANGED:
|
||||
return
|
||||
|
||||
connection.send_message(event_message(msg['id'], event))
|
||||
connection.send_message(messages.event_message(
|
||||
msg['id'], event.as_dict()
|
||||
))
|
||||
|
||||
connection.event_listeners[msg['id']] = hass.bus.async_listen(
|
||||
connection.subscriptions[msg['id']] = hass.bus.async_listen(
|
||||
msg['event_type'], forward_events)
|
||||
|
||||
connection.send_message(messages.result_message(msg['id']))
|
||||
|
@ -79,8 +72,8 @@ def handle_unsubscribe_events(hass, connection, msg):
|
|||
"""
|
||||
subscription = msg['subscription']
|
||||
|
||||
if subscription in connection.event_listeners:
|
||||
connection.event_listeners.pop(subscription)()
|
||||
if subscription in connection.subscriptions:
|
||||
connection.subscriptions.pop(subscription)()
|
||||
connection.send_message(messages.result_message(msg['id']))
|
||||
else:
|
||||
connection.send_message(messages.error_message(
|
||||
|
|
|
@ -21,7 +21,7 @@ class ActiveConnection:
|
|||
else:
|
||||
self.refresh_token_id = None
|
||||
|
||||
self.event_listeners = {}
|
||||
self.subscriptions = {}
|
||||
self.last_id = 0
|
||||
|
||||
def context(self, msg):
|
||||
|
@ -82,7 +82,7 @@ class ActiveConnection:
|
|||
@callback
|
||||
def async_close(self):
|
||||
"""Close down connection."""
|
||||
for unsub in self.event_listeners.values():
|
||||
for unsub in self.subscriptions.values():
|
||||
unsub()
|
||||
|
||||
@callback
|
||||
|
|
|
@ -40,3 +40,12 @@ def error_message(iden, code, message):
|
|||
'message': message,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def event_message(iden, event):
|
||||
"""Return an event message."""
|
||||
return {
|
||||
'id': iden,
|
||||
'type': 'event',
|
||||
'event': event,
|
||||
}
|
||||
|
|
|
@ -767,3 +767,37 @@ async def test_message_callback_exception_gets_logged(hass, caplog):
|
|||
assert \
|
||||
"Exception in bad_handler when handling msg on 'test-topic':" \
|
||||
" 'test'" in caplog.text
|
||||
|
||||
|
||||
async def test_mqtt_ws_subscription(hass, hass_ws_client):
|
||||
"""Test MQTT websocket subscription."""
|
||||
await async_mock_mqtt_component(hass)
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': 'mqtt/subscribe',
|
||||
'topic': 'test-topic',
|
||||
})
|
||||
response = await client.receive_json()
|
||||
assert response['success']
|
||||
|
||||
async_fire_mqtt_message(hass, 'test-topic', 'test1')
|
||||
async_fire_mqtt_message(hass, 'test-topic', 'test2')
|
||||
|
||||
response = await client.receive_json()
|
||||
assert response['event']['topic'] == 'test-topic'
|
||||
assert response['event']['payload'] == 'test1'
|
||||
|
||||
response = await client.receive_json()
|
||||
assert response['event']['topic'] == 'test-topic'
|
||||
assert response['event']['payload'] == 'test2'
|
||||
|
||||
# Unsubscribe
|
||||
await client.send_json({
|
||||
'id': 8,
|
||||
'type': 'unsubscribe_events',
|
||||
'subscription': 5,
|
||||
})
|
||||
response = await client.receive_json()
|
||||
assert response['success']
|
||||
|
|
Loading…
Reference in New Issue