Fix websocket api reaching queue (#7590)

* Fix websocket api reaching queue

* Fix outside task message sending

* Fix Py34 tests
pull/7612/head
Paulus Schoutsen 2017-05-15 00:10:06 -07:00 committed by GitHub
parent 6d245c43fc
commit 36d7fe72eb
2 changed files with 79 additions and 38 deletions

View File

@ -5,6 +5,7 @@ For more details about this component, please refer to the documentation at
https://home-assistant.io/developers/websocket_api/
"""
import asyncio
from contextlib import suppress
from functools import partial
import json
import logging
@ -201,19 +202,20 @@ class WebsocketAPIView(HomeAssistantView):
def get(self, request):
"""Handle an incoming websocket connection."""
# pylint: disable=no-self-use
return ActiveConnection(request.app['hass'], request).handle()
return ActiveConnection(request.app['hass']).handle(request)
class ActiveConnection:
"""Handle an active websocket client connection."""
def __init__(self, hass, request):
def __init__(self, hass):
"""Initialize an active connection."""
self.hass = hass
self.request = request
self.wsock = None
self.event_listeners = {}
self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop)
self._handle_task = None
self._writer_task = None
def debug(self, message1, message2=''):
"""Print a debug message."""
@ -226,42 +228,60 @@ class ActiveConnection:
@asyncio.coroutine
def _writer(self):
"""Write outgoing messages."""
try:
while True:
# Exceptions if Socket disconnected or cancelled by connection handler
with suppress(RuntimeError, asyncio.CancelledError):
while not self.wsock.closed:
message = yield from self.to_write.get()
if message is None:
break
self.debug("Sending", message)
yield from self.wsock.send_json(message, dumps=JSON_DUMP)
except (RuntimeError, asyncio.CancelledError):
# Socket disconnected or cancelled by connection handler
pass
@callback
def send_message_outside(self, message):
"""Send a message to the client outside of the main task.
Closes connection if the client is not reading the messages.
Async friendly.
"""
try:
self.to_write.put_nowait(message)
except asyncio.QueueFull:
self.log_error("Client exceeded max pending messages [2]:",
MAX_PENDING_MSG)
self.cancel()
@callback
def cancel(self):
"""Cancel the connection."""
self._handle_task.cancel()
self._writer_task.cancel()
@asyncio.coroutine
def handle(self):
def handle(self, request):
"""Handle the websocket connection."""
wsock = self.wsock = web.WebSocketResponse()
yield from wsock.prepare(self.request)
# Set up to cancel this connection when Home Assistant shuts down
socket_task = asyncio.Task.current_task(loop=self.hass.loop)
@callback
def cancel_connection(event):
"""Cancel this connection."""
socket_task.cancel()
unsub_stop = self.hass.bus.async_listen(
EVENT_HOMEASSISTANT_STOP, cancel_connection)
writer_task = self.hass.async_add_job(self._writer())
final_message = None
yield from wsock.prepare(request)
self.debug("Connected")
# Get a reference to current task so we can cancel our connection
self._handle_task = asyncio.Task.current_task(loop=self.hass.loop)
@callback
def handle_hass_stop(event):
"""Cancel this connection."""
self.cancel()
unsub_stop = self.hass.bus.async_listen(
EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
self._writer_task = self.hass.async_add_job(self._writer())
final_message = None
msg = None
authenticated = False
try:
if self.request[KEY_AUTHENTICATED]:
if request[KEY_AUTHENTICATED]:
authenticated = True
else:
@ -269,7 +289,7 @@ class ActiveConnection:
msg = yield from wsock.receive_json()
msg = AUTH_MESSAGE_SCHEMA(msg)
if validate_password(self.request, msg['api_password']):
if validate_password(request, msg['api_password']):
authenticated = True
else:
@ -278,13 +298,14 @@ class ActiveConnection:
auth_invalid_message('Invalid password'))
if not authenticated:
yield from process_wrong_login(self.request)
yield from process_wrong_login(request)
return wsock
yield from self.wsock.send_json(auth_ok_message())
msg = yield from wsock.receive_json()
# ---------- AUTH PHASE OVER ----------
msg = yield from wsock.receive_json()
last_id = 0
while msg:
@ -337,14 +358,15 @@ class ActiveConnection:
if value:
msg += ': {}'.format(value)
self.log_error(msg)
self._writer_task.cancel()
except asyncio.CancelledError:
self.debug("Connection cancelled by server")
except asyncio.QueueFull:
self.log_error("Client exceeded max pending messages:",
self.log_error("Client exceeded max pending messages [1]:",
MAX_PENDING_MSG)
writer_task.cancel()
self._writer_task.cancel()
except Exception: # pylint: disable=broad-except
error = "Unexpected error inside websocket API. "
@ -353,19 +375,19 @@ class ActiveConnection:
_LOGGER.exception(error)
finally:
unsub_stop()
for unsub in self.event_listeners.values():
unsub()
try:
if final_message is not None:
self.to_write.put_nowait(final_message)
self.to_write.put_nowait(None)
# Make sure all error messages are written before closing
yield from writer_task
yield from self._writer_task
except asyncio.QueueFull:
pass
unsub_stop()
for unsub in self.event_listeners.values():
unsub()
self._writer_task.cancel()
yield from wsock.close()
self.debug("Closed connection")
@ -385,7 +407,7 @@ class ActiveConnection:
if event.event_type == EVENT_TIME_CHANGED:
return
self.to_write.put_nowait(event_message(msg['id'], event))
self.send_message_outside(event_message(msg['id'], event))
self.event_listeners[msg['id']] = self.hass.bus.async_listen(
msg['event_type'], forward_events)
@ -421,7 +443,7 @@ class ActiveConnection:
"""Call a service and fire complete message."""
yield from self.hass.services.async_call(
msg['domain'], msg['service'], msg['service_data'], True)
self.to_write.put_nowait(result_message(msg['id']))
self.send_message_outside(result_message(msg['id']))
self.hass.async_add_job(call_service_helper(msg))

View File

@ -50,6 +50,13 @@ def no_auth_websocket_client(hass, loop, test_client):
loop.run_until_complete(ws.close())
@pytest.fixture
def mock_low_queue():
"""Mock a low queue."""
with patch.object(wapi, 'MAX_PENDING_MSG', 5):
yield
@asyncio.coroutine
def test_auth_via_msg(no_auth_websocket_client):
"""Test authenticating."""
@ -304,3 +311,15 @@ def test_ping(websocket_client):
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_PONG
@asyncio.coroutine
def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
"""Test get_panels command."""
for idx in range(10):
websocket_client.send_json({
'id': idx + 1,
'type': wapi.TYPE_PING,
})
msg = yield from websocket_client.receive()
assert msg.type == WSMsgType.close