diff --git a/homeassistant/components/websocket_api.py b/homeassistant/components/websocket_api.py index fcfd7f404e9..6566a20814b 100644 --- a/homeassistant/components/websocket_api.py +++ b/homeassistant/components/websocket_api.py @@ -5,6 +5,7 @@ For more details about this component, please refer to the documentation at https://home-assistant.io/developers/websocket_api/ """ import asyncio +from contextlib import suppress from functools import partial import json import logging @@ -201,19 +202,20 @@ class WebsocketAPIView(HomeAssistantView): def get(self, request): """Handle an incoming websocket connection.""" # pylint: disable=no-self-use - return ActiveConnection(request.app['hass'], request).handle() + return ActiveConnection(request.app['hass']).handle(request) class ActiveConnection: """Handle an active websocket client connection.""" - def __init__(self, hass, request): + def __init__(self, hass): """Initialize an active connection.""" self.hass = hass - self.request = request self.wsock = None self.event_listeners = {} self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop) + self._handle_task = None + self._writer_task = None def debug(self, message1, message2=''): """Print a debug message.""" @@ -226,42 +228,60 @@ class ActiveConnection: @asyncio.coroutine def _writer(self): """Write outgoing messages.""" - try: - while True: + # Exceptions if Socket disconnected or cancelled by connection handler + with suppress(RuntimeError, asyncio.CancelledError): + while not self.wsock.closed: message = yield from self.to_write.get() if message is None: break self.debug("Sending", message) yield from self.wsock.send_json(message, dumps=JSON_DUMP) - except (RuntimeError, asyncio.CancelledError): - # Socket disconnected or cancelled by connection handler - pass + + @callback + def send_message_outside(self, message): + """Send a message to the client outside of the main task. + + Closes connection if the client is not reading the messages. + + Async friendly. + """ + try: + self.to_write.put_nowait(message) + except asyncio.QueueFull: + self.log_error("Client exceeded max pending messages [2]:", + MAX_PENDING_MSG) + self.cancel() + + @callback + def cancel(self): + """Cancel the connection.""" + self._handle_task.cancel() + self._writer_task.cancel() @asyncio.coroutine - def handle(self): + def handle(self, request): """Handle the websocket connection.""" wsock = self.wsock = web.WebSocketResponse() - yield from wsock.prepare(self.request) - - # Set up to cancel this connection when Home Assistant shuts down - socket_task = asyncio.Task.current_task(loop=self.hass.loop) - - @callback - def cancel_connection(event): - """Cancel this connection.""" - socket_task.cancel() - - unsub_stop = self.hass.bus.async_listen( - EVENT_HOMEASSISTANT_STOP, cancel_connection) - writer_task = self.hass.async_add_job(self._writer()) - final_message = None + yield from wsock.prepare(request) self.debug("Connected") + # Get a reference to current task so we can cancel our connection + self._handle_task = asyncio.Task.current_task(loop=self.hass.loop) + + @callback + def handle_hass_stop(event): + """Cancel this connection.""" + self.cancel() + + unsub_stop = self.hass.bus.async_listen( + EVENT_HOMEASSISTANT_STOP, handle_hass_stop) + self._writer_task = self.hass.async_add_job(self._writer()) + final_message = None msg = None authenticated = False try: - if self.request[KEY_AUTHENTICATED]: + if request[KEY_AUTHENTICATED]: authenticated = True else: @@ -269,7 +289,7 @@ class ActiveConnection: msg = yield from wsock.receive_json() msg = AUTH_MESSAGE_SCHEMA(msg) - if validate_password(self.request, msg['api_password']): + if validate_password(request, msg['api_password']): authenticated = True else: @@ -278,13 +298,14 @@ class ActiveConnection: auth_invalid_message('Invalid password')) if not authenticated: - yield from process_wrong_login(self.request) + yield from process_wrong_login(request) return wsock yield from self.wsock.send_json(auth_ok_message()) - msg = yield from wsock.receive_json() + # ---------- AUTH PHASE OVER ---------- + msg = yield from wsock.receive_json() last_id = 0 while msg: @@ -337,14 +358,15 @@ class ActiveConnection: if value: msg += ': {}'.format(value) self.log_error(msg) + self._writer_task.cancel() except asyncio.CancelledError: self.debug("Connection cancelled by server") except asyncio.QueueFull: - self.log_error("Client exceeded max pending messages:", + self.log_error("Client exceeded max pending messages [1]:", MAX_PENDING_MSG) - writer_task.cancel() + self._writer_task.cancel() except Exception: # pylint: disable=broad-except error = "Unexpected error inside websocket API. " @@ -353,19 +375,19 @@ class ActiveConnection: _LOGGER.exception(error) finally: + unsub_stop() + + for unsub in self.event_listeners.values(): + unsub() + try: if final_message is not None: self.to_write.put_nowait(final_message) self.to_write.put_nowait(None) # Make sure all error messages are written before closing - yield from writer_task + yield from self._writer_task except asyncio.QueueFull: - pass - - unsub_stop() - - for unsub in self.event_listeners.values(): - unsub() + self._writer_task.cancel() yield from wsock.close() self.debug("Closed connection") @@ -385,7 +407,7 @@ class ActiveConnection: if event.event_type == EVENT_TIME_CHANGED: return - self.to_write.put_nowait(event_message(msg['id'], event)) + self.send_message_outside(event_message(msg['id'], event)) self.event_listeners[msg['id']] = self.hass.bus.async_listen( msg['event_type'], forward_events) @@ -421,7 +443,7 @@ class ActiveConnection: """Call a service and fire complete message.""" yield from self.hass.services.async_call( msg['domain'], msg['service'], msg['service_data'], True) - self.to_write.put_nowait(result_message(msg['id'])) + self.send_message_outside(result_message(msg['id'])) self.hass.async_add_job(call_service_helper(msg)) diff --git a/tests/components/test_websocket_api.py b/tests/components/test_websocket_api.py index 658a5e0be53..9ca429f6f52 100644 --- a/tests/components/test_websocket_api.py +++ b/tests/components/test_websocket_api.py @@ -50,6 +50,13 @@ def no_auth_websocket_client(hass, loop, test_client): loop.run_until_complete(ws.close()) +@pytest.fixture +def mock_low_queue(): + """Mock a low queue.""" + with patch.object(wapi, 'MAX_PENDING_MSG', 5): + yield + + @asyncio.coroutine def test_auth_via_msg(no_auth_websocket_client): """Test authenticating.""" @@ -304,3 +311,15 @@ def test_ping(websocket_client): msg = yield from websocket_client.receive_json() assert msg['id'] == 5 assert msg['type'] == wapi.TYPE_PONG + + +@asyncio.coroutine +def test_pending_msg_overflow(hass, mock_low_queue, websocket_client): + """Test get_panels command.""" + for idx in range(10): + websocket_client.send_json({ + 'id': idx + 1, + 'type': wapi.TYPE_PING, + }) + msg = yield from websocket_client.receive() + assert msg.type == WSMsgType.close