diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 18ebb004209..d603b6637b0 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -25,7 +25,7 @@ from homeassistant.const import ( CONF_PROTOCOL, CONF_USERNAME, CONF_VALUE_TEMPLATE, EVENT_HOMEASSISTANT_STOP) from homeassistant.core import Event, ServiceCall, callback -from homeassistant.exceptions import HomeAssistantError +from homeassistant.exceptions import HomeAssistantError, Unauthorized from homeassistant.helpers import config_validation as cv, template from homeassistant.helpers.entity import Entity from homeassistant.helpers.typing import ( @@ -35,6 +35,7 @@ from homeassistant.setup import async_prepare_setup_platform from homeassistant.util.async_ import ( run_callback_threadsafe, run_coroutine_threadsafe) from homeassistant.util.logging import catch_log_exception +from homeassistant.components import websocket_api # Loading the config flow file will register the flow from . import config_flow # noqa pylint: disable=unused-import @@ -391,6 +392,8 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool: # This needs a better solution. hass.data[DATA_MQTT_HASS_CONFIG] = config + websocket_api.async_register_command(hass, websocket_subscribe) + if conf is None: # If we have a config entry, setup is done by that config entry. # If there is no config entry, this should fail. @@ -602,6 +605,7 @@ class MQTT: self.keepalive = keepalive self.subscriptions = [] # type: List[Subscription] self.birth_message = birth_message + self.connected = False self._mqttc = None # type: mqtt.Client self._paho_lock = asyncio.Lock(loop=hass.loop) @@ -703,7 +707,10 @@ class MQTT: if any(other.topic == topic for other in self.subscriptions): # Other subscriptions on topic remaining - don't unsubscribe. return - self.hass.async_create_task(self._async_unsubscribe(topic)) + + # Only unsubscribe if currently connected. + if self.connected: + self.hass.async_create_task(self._async_unsubscribe(topic)) return async_remove @@ -743,6 +750,8 @@ class MQTT: self._mqttc.disconnect() return + self.connected = True + # Group subscriptions to only re-subscribe once for each topic. keyfunc = attrgetter('topic') for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc), @@ -782,6 +791,8 @@ class MQTT: def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None: """Disconnected callback.""" + self.connected = False + # When disconnected because of calling disconnect() if result_code == 0: return @@ -791,6 +802,7 @@ class MQTT: while True: try: if self._mqttc.reconnect() == 0: + self.connected = True _LOGGER.info("Successfully reconnected to the MQTT server") break except socket.error: @@ -1040,3 +1052,27 @@ class MqttEntityDeviceInfo(Entity): info['via_hub'] = (DOMAIN, self._device_config[CONF_VIA_HUB]) return info + + +@websocket_api.async_response +@websocket_api.websocket_command({ + vol.Required('type'): 'mqtt/subscribe', + vol.Required('topic'): valid_subscribe_topic, +}) +async def websocket_subscribe(hass, connection, msg): + """Subscribe to a MQTT topic.""" + if not connection.user.is_admin: + raise Unauthorized + + async def forward_messages(topic: str, payload: str, qos: int): + """Forward events to websocket.""" + connection.send_message(websocket_api.event_message(msg['id'], { + 'topic': topic, + 'payload': payload, + 'qos': qos, + })) + + connection.subscriptions[msg['id']] = await async_subscribe( + hass, msg['topic'], forward_messages) + + connection.send_message(websocket_api.result_message(msg['id'])) diff --git a/homeassistant/components/websocket_api/__init__.py b/homeassistant/components/websocket_api/__init__.py index 3734f46abb7..6c4935b9d95 100644 --- a/homeassistant/components/websocket_api/__init__.py +++ b/homeassistant/components/websocket_api/__init__.py @@ -14,6 +14,7 @@ ActiveConnection = connection.ActiveConnection BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA error_message = messages.error_message result_message = messages.result_message +event_message = messages.event_message async_response = decorators.async_response require_admin = decorators.require_admin ws_require_user = decorators.ws_require_user diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index b367e3392ed..b64fac0ed51 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -24,15 +24,6 @@ def async_register_commands(hass): async_reg(handle_ping) -def event_message(iden, event): - """Return an event message.""" - return { - 'id': iden, - 'type': 'event', - 'event': event.as_dict(), - } - - def pong_message(iden): """Return a pong message.""" return { @@ -59,9 +50,11 @@ def handle_subscribe_events(hass, connection, msg): if event.event_type == EVENT_TIME_CHANGED: return - connection.send_message(event_message(msg['id'], event)) + connection.send_message(messages.event_message( + msg['id'], event.as_dict() + )) - connection.event_listeners[msg['id']] = hass.bus.async_listen( + connection.subscriptions[msg['id']] = hass.bus.async_listen( msg['event_type'], forward_events) connection.send_message(messages.result_message(msg['id'])) @@ -79,8 +72,8 @@ def handle_unsubscribe_events(hass, connection, msg): """ subscription = msg['subscription'] - if subscription in connection.event_listeners: - connection.event_listeners.pop(subscription)() + if subscription in connection.subscriptions: + connection.subscriptions.pop(subscription)() connection.send_message(messages.result_message(msg['id'])) else: connection.send_message(messages.error_message( diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 041aad3969e..d65ba4c54d8 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -21,7 +21,7 @@ class ActiveConnection: else: self.refresh_token_id = None - self.event_listeners = {} + self.subscriptions = {} self.last_id = 0 def context(self, msg): @@ -82,7 +82,7 @@ class ActiveConnection: @callback def async_close(self): """Close down connection.""" - for unsub in self.event_listeners.values(): + for unsub in self.subscriptions.values(): unsub() @callback diff --git a/homeassistant/components/websocket_api/messages.py b/homeassistant/components/websocket_api/messages.py index d616b6ad670..c0f899d279e 100644 --- a/homeassistant/components/websocket_api/messages.py +++ b/homeassistant/components/websocket_api/messages.py @@ -40,3 +40,12 @@ def error_message(iden, code, message): 'message': message, }, } + + +def event_message(iden, event): + """Return an event message.""" + return { + 'id': iden, + 'type': 'event', + 'event': event, + } diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 94506efa909..81941173d68 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -767,3 +767,37 @@ async def test_message_callback_exception_gets_logged(hass, caplog): assert \ "Exception in bad_handler when handling msg on 'test-topic':" \ " 'test'" in caplog.text + + +async def test_mqtt_ws_subscription(hass, hass_ws_client): + """Test MQTT websocket subscription.""" + await async_mock_mqtt_component(hass) + + client = await hass_ws_client(hass) + await client.send_json({ + 'id': 5, + 'type': 'mqtt/subscribe', + 'topic': 'test-topic', + }) + response = await client.receive_json() + assert response['success'] + + async_fire_mqtt_message(hass, 'test-topic', 'test1') + async_fire_mqtt_message(hass, 'test-topic', 'test2') + + response = await client.receive_json() + assert response['event']['topic'] == 'test-topic' + assert response['event']['payload'] == 'test1' + + response = await client.receive_json() + assert response['event']['topic'] == 'test-topic' + assert response['event']['payload'] == 'test2' + + # Unsubscribe + await client.send_json({ + 'id': 8, + 'type': 'unsubscribe_events', + 'subscription': 5, + }) + response = await client.receive_json() + assert response['success']