Fix websocket api reaching queue (#7590)
* Fix websocket api reaching queue * Fix outside task message sending * Fix Py34 testspull/7612/head
parent
6d245c43fc
commit
36d7fe72eb
|
@ -5,6 +5,7 @@ For more details about this component, please refer to the documentation at
|
|||
https://home-assistant.io/developers/websocket_api/
|
||||
"""
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
import json
|
||||
import logging
|
||||
|
@ -201,19 +202,20 @@ class WebsocketAPIView(HomeAssistantView):
|
|||
def get(self, request):
|
||||
"""Handle an incoming websocket connection."""
|
||||
# pylint: disable=no-self-use
|
||||
return ActiveConnection(request.app['hass'], request).handle()
|
||||
return ActiveConnection(request.app['hass']).handle(request)
|
||||
|
||||
|
||||
class ActiveConnection:
|
||||
"""Handle an active websocket client connection."""
|
||||
|
||||
def __init__(self, hass, request):
|
||||
def __init__(self, hass):
|
||||
"""Initialize an active connection."""
|
||||
self.hass = hass
|
||||
self.request = request
|
||||
self.wsock = None
|
||||
self.event_listeners = {}
|
||||
self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop)
|
||||
self._handle_task = None
|
||||
self._writer_task = None
|
||||
|
||||
def debug(self, message1, message2=''):
|
||||
"""Print a debug message."""
|
||||
|
@ -226,42 +228,60 @@ class ActiveConnection:
|
|||
@asyncio.coroutine
|
||||
def _writer(self):
|
||||
"""Write outgoing messages."""
|
||||
try:
|
||||
while True:
|
||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||
with suppress(RuntimeError, asyncio.CancelledError):
|
||||
while not self.wsock.closed:
|
||||
message = yield from self.to_write.get()
|
||||
if message is None:
|
||||
break
|
||||
self.debug("Sending", message)
|
||||
yield from self.wsock.send_json(message, dumps=JSON_DUMP)
|
||||
except (RuntimeError, asyncio.CancelledError):
|
||||
# Socket disconnected or cancelled by connection handler
|
||||
pass
|
||||
|
||||
@callback
|
||||
def send_message_outside(self, message):
|
||||
"""Send a message to the client outside of the main task.
|
||||
|
||||
Closes connection if the client is not reading the messages.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
try:
|
||||
self.to_write.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
self.log_error("Client exceeded max pending messages [2]:",
|
||||
MAX_PENDING_MSG)
|
||||
self.cancel()
|
||||
|
||||
@callback
|
||||
def cancel(self):
|
||||
"""Cancel the connection."""
|
||||
self._handle_task.cancel()
|
||||
self._writer_task.cancel()
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle(self):
|
||||
def handle(self, request):
|
||||
"""Handle the websocket connection."""
|
||||
wsock = self.wsock = web.WebSocketResponse()
|
||||
yield from wsock.prepare(self.request)
|
||||
|
||||
# Set up to cancel this connection when Home Assistant shuts down
|
||||
socket_task = asyncio.Task.current_task(loop=self.hass.loop)
|
||||
|
||||
@callback
|
||||
def cancel_connection(event):
|
||||
"""Cancel this connection."""
|
||||
socket_task.cancel()
|
||||
|
||||
unsub_stop = self.hass.bus.async_listen(
|
||||
EVENT_HOMEASSISTANT_STOP, cancel_connection)
|
||||
writer_task = self.hass.async_add_job(self._writer())
|
||||
final_message = None
|
||||
yield from wsock.prepare(request)
|
||||
self.debug("Connected")
|
||||
|
||||
# Get a reference to current task so we can cancel our connection
|
||||
self._handle_task = asyncio.Task.current_task(loop=self.hass.loop)
|
||||
|
||||
@callback
|
||||
def handle_hass_stop(event):
|
||||
"""Cancel this connection."""
|
||||
self.cancel()
|
||||
|
||||
unsub_stop = self.hass.bus.async_listen(
|
||||
EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
|
||||
self._writer_task = self.hass.async_add_job(self._writer())
|
||||
final_message = None
|
||||
msg = None
|
||||
authenticated = False
|
||||
|
||||
try:
|
||||
if self.request[KEY_AUTHENTICATED]:
|
||||
if request[KEY_AUTHENTICATED]:
|
||||
authenticated = True
|
||||
|
||||
else:
|
||||
|
@ -269,7 +289,7 @@ class ActiveConnection:
|
|||
msg = yield from wsock.receive_json()
|
||||
msg = AUTH_MESSAGE_SCHEMA(msg)
|
||||
|
||||
if validate_password(self.request, msg['api_password']):
|
||||
if validate_password(request, msg['api_password']):
|
||||
authenticated = True
|
||||
|
||||
else:
|
||||
|
@ -278,13 +298,14 @@ class ActiveConnection:
|
|||
auth_invalid_message('Invalid password'))
|
||||
|
||||
if not authenticated:
|
||||
yield from process_wrong_login(self.request)
|
||||
yield from process_wrong_login(request)
|
||||
return wsock
|
||||
|
||||
yield from self.wsock.send_json(auth_ok_message())
|
||||
|
||||
msg = yield from wsock.receive_json()
|
||||
# ---------- AUTH PHASE OVER ----------
|
||||
|
||||
msg = yield from wsock.receive_json()
|
||||
last_id = 0
|
||||
|
||||
while msg:
|
||||
|
@ -337,14 +358,15 @@ class ActiveConnection:
|
|||
if value:
|
||||
msg += ': {}'.format(value)
|
||||
self.log_error(msg)
|
||||
self._writer_task.cancel()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.debug("Connection cancelled by server")
|
||||
|
||||
except asyncio.QueueFull:
|
||||
self.log_error("Client exceeded max pending messages:",
|
||||
self.log_error("Client exceeded max pending messages [1]:",
|
||||
MAX_PENDING_MSG)
|
||||
writer_task.cancel()
|
||||
self._writer_task.cancel()
|
||||
|
||||
except Exception: # pylint: disable=broad-except
|
||||
error = "Unexpected error inside websocket API. "
|
||||
|
@ -353,19 +375,19 @@ class ActiveConnection:
|
|||
_LOGGER.exception(error)
|
||||
|
||||
finally:
|
||||
unsub_stop()
|
||||
|
||||
for unsub in self.event_listeners.values():
|
||||
unsub()
|
||||
|
||||
try:
|
||||
if final_message is not None:
|
||||
self.to_write.put_nowait(final_message)
|
||||
self.to_write.put_nowait(None)
|
||||
# Make sure all error messages are written before closing
|
||||
yield from writer_task
|
||||
yield from self._writer_task
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
unsub_stop()
|
||||
|
||||
for unsub in self.event_listeners.values():
|
||||
unsub()
|
||||
self._writer_task.cancel()
|
||||
|
||||
yield from wsock.close()
|
||||
self.debug("Closed connection")
|
||||
|
@ -385,7 +407,7 @@ class ActiveConnection:
|
|||
if event.event_type == EVENT_TIME_CHANGED:
|
||||
return
|
||||
|
||||
self.to_write.put_nowait(event_message(msg['id'], event))
|
||||
self.send_message_outside(event_message(msg['id'], event))
|
||||
|
||||
self.event_listeners[msg['id']] = self.hass.bus.async_listen(
|
||||
msg['event_type'], forward_events)
|
||||
|
@ -421,7 +443,7 @@ class ActiveConnection:
|
|||
"""Call a service and fire complete message."""
|
||||
yield from self.hass.services.async_call(
|
||||
msg['domain'], msg['service'], msg['service_data'], True)
|
||||
self.to_write.put_nowait(result_message(msg['id']))
|
||||
self.send_message_outside(result_message(msg['id']))
|
||||
|
||||
self.hass.async_add_job(call_service_helper(msg))
|
||||
|
||||
|
|
|
@ -50,6 +50,13 @@ def no_auth_websocket_client(hass, loop, test_client):
|
|||
loop.run_until_complete(ws.close())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_low_queue():
|
||||
"""Mock a low queue."""
|
||||
with patch.object(wapi, 'MAX_PENDING_MSG', 5):
|
||||
yield
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_auth_via_msg(no_auth_websocket_client):
|
||||
"""Test authenticating."""
|
||||
|
@ -304,3 +311,15 @@ def test_ping(websocket_client):
|
|||
msg = yield from websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_PONG
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
|
||||
"""Test get_panels command."""
|
||||
for idx in range(10):
|
||||
websocket_client.send_json({
|
||||
'id': idx + 1,
|
||||
'type': wapi.TYPE_PING,
|
||||
})
|
||||
msg = yield from websocket_client.receive()
|
||||
assert msg.type == WSMsgType.close
|
||||
|
|
Loading…
Reference in New Issue