Fix websocket api reaching queue (#7590)
* Fix websocket api reaching queue * Fix outside task message sending * Fix Py34 testspull/7612/head
parent
6d245c43fc
commit
36d7fe72eb
|
@ -5,6 +5,7 @@ For more details about this component, please refer to the documentation at
|
||||||
https://home-assistant.io/developers/websocket_api/
|
https://home-assistant.io/developers/websocket_api/
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from contextlib import suppress
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
@ -201,19 +202,20 @@ class WebsocketAPIView(HomeAssistantView):
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
"""Handle an incoming websocket connection."""
|
"""Handle an incoming websocket connection."""
|
||||||
# pylint: disable=no-self-use
|
# pylint: disable=no-self-use
|
||||||
return ActiveConnection(request.app['hass'], request).handle()
|
return ActiveConnection(request.app['hass']).handle(request)
|
||||||
|
|
||||||
|
|
||||||
class ActiveConnection:
|
class ActiveConnection:
|
||||||
"""Handle an active websocket client connection."""
|
"""Handle an active websocket client connection."""
|
||||||
|
|
||||||
def __init__(self, hass, request):
|
def __init__(self, hass):
|
||||||
"""Initialize an active connection."""
|
"""Initialize an active connection."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.request = request
|
|
||||||
self.wsock = None
|
self.wsock = None
|
||||||
self.event_listeners = {}
|
self.event_listeners = {}
|
||||||
self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop)
|
self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop)
|
||||||
|
self._handle_task = None
|
||||||
|
self._writer_task = None
|
||||||
|
|
||||||
def debug(self, message1, message2=''):
|
def debug(self, message1, message2=''):
|
||||||
"""Print a debug message."""
|
"""Print a debug message."""
|
||||||
|
@ -226,42 +228,60 @@ class ActiveConnection:
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def _writer(self):
|
def _writer(self):
|
||||||
"""Write outgoing messages."""
|
"""Write outgoing messages."""
|
||||||
try:
|
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||||
while True:
|
with suppress(RuntimeError, asyncio.CancelledError):
|
||||||
|
while not self.wsock.closed:
|
||||||
message = yield from self.to_write.get()
|
message = yield from self.to_write.get()
|
||||||
if message is None:
|
if message is None:
|
||||||
break
|
break
|
||||||
self.debug("Sending", message)
|
self.debug("Sending", message)
|
||||||
yield from self.wsock.send_json(message, dumps=JSON_DUMP)
|
yield from self.wsock.send_json(message, dumps=JSON_DUMP)
|
||||||
except (RuntimeError, asyncio.CancelledError):
|
|
||||||
# Socket disconnected or cancelled by connection handler
|
@callback
|
||||||
pass
|
def send_message_outside(self, message):
|
||||||
|
"""Send a message to the client outside of the main task.
|
||||||
|
|
||||||
|
Closes connection if the client is not reading the messages.
|
||||||
|
|
||||||
|
Async friendly.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.to_write.put_nowait(message)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
self.log_error("Client exceeded max pending messages [2]:",
|
||||||
|
MAX_PENDING_MSG)
|
||||||
|
self.cancel()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def cancel(self):
|
||||||
|
"""Cancel the connection."""
|
||||||
|
self._handle_task.cancel()
|
||||||
|
self._writer_task.cancel()
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def handle(self):
|
def handle(self, request):
|
||||||
"""Handle the websocket connection."""
|
"""Handle the websocket connection."""
|
||||||
wsock = self.wsock = web.WebSocketResponse()
|
wsock = self.wsock = web.WebSocketResponse()
|
||||||
yield from wsock.prepare(self.request)
|
yield from wsock.prepare(request)
|
||||||
|
|
||||||
# Set up to cancel this connection when Home Assistant shuts down
|
|
||||||
socket_task = asyncio.Task.current_task(loop=self.hass.loop)
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def cancel_connection(event):
|
|
||||||
"""Cancel this connection."""
|
|
||||||
socket_task.cancel()
|
|
||||||
|
|
||||||
unsub_stop = self.hass.bus.async_listen(
|
|
||||||
EVENT_HOMEASSISTANT_STOP, cancel_connection)
|
|
||||||
writer_task = self.hass.async_add_job(self._writer())
|
|
||||||
final_message = None
|
|
||||||
self.debug("Connected")
|
self.debug("Connected")
|
||||||
|
|
||||||
|
# Get a reference to current task so we can cancel our connection
|
||||||
|
self._handle_task = asyncio.Task.current_task(loop=self.hass.loop)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def handle_hass_stop(event):
|
||||||
|
"""Cancel this connection."""
|
||||||
|
self.cancel()
|
||||||
|
|
||||||
|
unsub_stop = self.hass.bus.async_listen(
|
||||||
|
EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
|
||||||
|
self._writer_task = self.hass.async_add_job(self._writer())
|
||||||
|
final_message = None
|
||||||
msg = None
|
msg = None
|
||||||
authenticated = False
|
authenticated = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.request[KEY_AUTHENTICATED]:
|
if request[KEY_AUTHENTICATED]:
|
||||||
authenticated = True
|
authenticated = True
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -269,7 +289,7 @@ class ActiveConnection:
|
||||||
msg = yield from wsock.receive_json()
|
msg = yield from wsock.receive_json()
|
||||||
msg = AUTH_MESSAGE_SCHEMA(msg)
|
msg = AUTH_MESSAGE_SCHEMA(msg)
|
||||||
|
|
||||||
if validate_password(self.request, msg['api_password']):
|
if validate_password(request, msg['api_password']):
|
||||||
authenticated = True
|
authenticated = True
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -278,13 +298,14 @@ class ActiveConnection:
|
||||||
auth_invalid_message('Invalid password'))
|
auth_invalid_message('Invalid password'))
|
||||||
|
|
||||||
if not authenticated:
|
if not authenticated:
|
||||||
yield from process_wrong_login(self.request)
|
yield from process_wrong_login(request)
|
||||||
return wsock
|
return wsock
|
||||||
|
|
||||||
yield from self.wsock.send_json(auth_ok_message())
|
yield from self.wsock.send_json(auth_ok_message())
|
||||||
|
|
||||||
msg = yield from wsock.receive_json()
|
# ---------- AUTH PHASE OVER ----------
|
||||||
|
|
||||||
|
msg = yield from wsock.receive_json()
|
||||||
last_id = 0
|
last_id = 0
|
||||||
|
|
||||||
while msg:
|
while msg:
|
||||||
|
@ -337,14 +358,15 @@ class ActiveConnection:
|
||||||
if value:
|
if value:
|
||||||
msg += ': {}'.format(value)
|
msg += ': {}'.format(value)
|
||||||
self.log_error(msg)
|
self.log_error(msg)
|
||||||
|
self._writer_task.cancel()
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
self.debug("Connection cancelled by server")
|
self.debug("Connection cancelled by server")
|
||||||
|
|
||||||
except asyncio.QueueFull:
|
except asyncio.QueueFull:
|
||||||
self.log_error("Client exceeded max pending messages:",
|
self.log_error("Client exceeded max pending messages [1]:",
|
||||||
MAX_PENDING_MSG)
|
MAX_PENDING_MSG)
|
||||||
writer_task.cancel()
|
self._writer_task.cancel()
|
||||||
|
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
error = "Unexpected error inside websocket API. "
|
error = "Unexpected error inside websocket API. "
|
||||||
|
@ -353,19 +375,19 @@ class ActiveConnection:
|
||||||
_LOGGER.exception(error)
|
_LOGGER.exception(error)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
unsub_stop()
|
||||||
|
|
||||||
|
for unsub in self.event_listeners.values():
|
||||||
|
unsub()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if final_message is not None:
|
if final_message is not None:
|
||||||
self.to_write.put_nowait(final_message)
|
self.to_write.put_nowait(final_message)
|
||||||
self.to_write.put_nowait(None)
|
self.to_write.put_nowait(None)
|
||||||
# Make sure all error messages are written before closing
|
# Make sure all error messages are written before closing
|
||||||
yield from writer_task
|
yield from self._writer_task
|
||||||
except asyncio.QueueFull:
|
except asyncio.QueueFull:
|
||||||
pass
|
self._writer_task.cancel()
|
||||||
|
|
||||||
unsub_stop()
|
|
||||||
|
|
||||||
for unsub in self.event_listeners.values():
|
|
||||||
unsub()
|
|
||||||
|
|
||||||
yield from wsock.close()
|
yield from wsock.close()
|
||||||
self.debug("Closed connection")
|
self.debug("Closed connection")
|
||||||
|
@ -385,7 +407,7 @@ class ActiveConnection:
|
||||||
if event.event_type == EVENT_TIME_CHANGED:
|
if event.event_type == EVENT_TIME_CHANGED:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.to_write.put_nowait(event_message(msg['id'], event))
|
self.send_message_outside(event_message(msg['id'], event))
|
||||||
|
|
||||||
self.event_listeners[msg['id']] = self.hass.bus.async_listen(
|
self.event_listeners[msg['id']] = self.hass.bus.async_listen(
|
||||||
msg['event_type'], forward_events)
|
msg['event_type'], forward_events)
|
||||||
|
@ -421,7 +443,7 @@ class ActiveConnection:
|
||||||
"""Call a service and fire complete message."""
|
"""Call a service and fire complete message."""
|
||||||
yield from self.hass.services.async_call(
|
yield from self.hass.services.async_call(
|
||||||
msg['domain'], msg['service'], msg['service_data'], True)
|
msg['domain'], msg['service'], msg['service_data'], True)
|
||||||
self.to_write.put_nowait(result_message(msg['id']))
|
self.send_message_outside(result_message(msg['id']))
|
||||||
|
|
||||||
self.hass.async_add_job(call_service_helper(msg))
|
self.hass.async_add_job(call_service_helper(msg))
|
||||||
|
|
||||||
|
|
|
@ -50,6 +50,13 @@ def no_auth_websocket_client(hass, loop, test_client):
|
||||||
loop.run_until_complete(ws.close())
|
loop.run_until_complete(ws.close())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_low_queue():
|
||||||
|
"""Mock a low queue."""
|
||||||
|
with patch.object(wapi, 'MAX_PENDING_MSG', 5):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def test_auth_via_msg(no_auth_websocket_client):
|
def test_auth_via_msg(no_auth_websocket_client):
|
||||||
"""Test authenticating."""
|
"""Test authenticating."""
|
||||||
|
@ -304,3 +311,15 @@ def test_ping(websocket_client):
|
||||||
msg = yield from websocket_client.receive_json()
|
msg = yield from websocket_client.receive_json()
|
||||||
assert msg['id'] == 5
|
assert msg['id'] == 5
|
||||||
assert msg['type'] == wapi.TYPE_PONG
|
assert msg['type'] == wapi.TYPE_PONG
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
|
||||||
|
"""Test get_panels command."""
|
||||||
|
for idx in range(10):
|
||||||
|
websocket_client.send_json({
|
||||||
|
'id': idx + 1,
|
||||||
|
'type': wapi.TYPE_PING,
|
||||||
|
})
|
||||||
|
msg = yield from websocket_client.receive()
|
||||||
|
assert msg.type == WSMsgType.close
|
||||||
|
|
Loading…
Reference in New Issue