Log exceptions thrown by MQTT message callbacks (#19977)

* Log exceptions thrown by MQTT message callbacks

* Fix tests

* Correct method for skipping wrapper in traceback

* Lint

* Simplify traceback print

* Add test

* Move wrapper to common helper function

* Typing

* Lint
pull/20172/head
emontnemery 2019-01-16 22:50:21 +01:00 committed by Paulus Schoutsen
parent 1d86905d5b
commit 368682647d
4 changed files with 66 additions and 4 deletions

View File

@ -34,6 +34,7 @@ from homeassistant.loader import bind_hass
from homeassistant.setup import async_prepare_setup_platform
from homeassistant.util.async_ import (
run_callback_threadsafe, run_coroutine_threadsafe)
from homeassistant.util.logging import catch_log_exception
# Loading the config flow file will register the flow
from . import config_flow # noqa pylint: disable=unused-import
@ -311,7 +312,11 @@ async def async_subscribe(hass: HomeAssistantType, topic: str,
Call the return value to unsubscribe.
"""
async_remove = await hass.data[DATA_MQTT].async_subscribe(
topic, msg_callback, qos, encoding)
topic, catch_log_exception(
msg_callback, lambda topic, msg, qos:
"Exception in {} when handling msg on '{}': '{}'".format(
msg_callback.__name__, topic, msg)),
qos, encoding)
return async_remove

View File

@ -1,9 +1,12 @@
"""Logging utilities."""
import asyncio
from asyncio.events import AbstractEventLoop
from functools import wraps
import inspect
import logging
import threading
from typing import Optional
import traceback
from typing import Any, Callable, Optional
from .async_ import run_coroutine_threadsafe
@ -121,3 +124,38 @@ class AsyncHandler:
def name(self, name: str) -> None:
"""Wrap property get_name to handler."""
self.handler.set_name(name) # type: ignore
def catch_log_exception(
func: Callable[..., Any],
format_err: Callable[..., Any],
*args: Any) -> Callable[[], None]:
"""Decorate an callback to catch and log exceptions."""
def log_exception(*args: Any) -> None:
module_name = inspect.getmodule(inspect.trace()[1][0]).__name__
# Do not print the wrapper in the traceback
frames = len(inspect.trace()) - 1
exc_msg = traceback.format_exc(-frames)
friendly_msg = format_err(*args)
logging.getLogger(module_name).error('%s\n%s', friendly_msg, exc_msg)
wrapper_func = None
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(*args: Any) -> None:
"""Catch and log exception."""
try:
await func(*args)
except Exception: # pylint: disable=broad-except
log_exception(*args)
wrapper_func = async_wrapper
else:
@wraps(func)
def wrapper(*args: Any) -> None:
"""Catch and log exception."""
try:
func(*args)
except Exception: # pylint: disable=broad-except
log_exception(*args)
wrapper_func = wrapper
return wrapper_func

View File

@ -297,6 +297,23 @@ class TestMQTTCallbacks(unittest.TestCase):
"b'\\x9a' on test-topic with encoding utf-8" in \
test_handle.output[0]
def test_message_callback_exception_gets_logged(self):
"""Test exception raised by message handler."""
@callback
def bad_handler(*args):
"""Record calls."""
raise Exception('This is a bad message callback')
mqtt.subscribe(self.hass, 'test-topic', bad_handler)
with self.assertLogs(level='WARNING') as test_handle:
fire_mqtt_message(self.hass, 'test-topic', 'test')
self.hass.block_till_done()
assert \
"Exception in bad_handler when handling msg on 'test-topic':" \
" 'test'" in \
test_handle.output[0]
def test_all_subscriptions_run_when_decode_fails(self):
"""Test all other subscriptions still run when decode fails for one."""
mqtt.subscribe(self.hass, 'test-topic', self.record_calls,

View File

@ -1,4 +1,6 @@
"""The tests for the MQTT subscription component."""
from unittest import mock
from homeassistant.core import callback
from homeassistant.components.mqtt.subscription import (
async_subscribe_topics, async_unsubscribe_topics)
@ -135,7 +137,7 @@ async def test_qos_encoding_default(hass, mqtt_mock, caplog):
{'test_topic1': {'topic': 'test-topic1',
'msg_callback': msg_callback}})
mock_mqtt.async_subscribe.assert_called_once_with(
'test-topic1', msg_callback, 0, 'utf-8')
'test-topic1', mock.ANY, 0, 'utf-8')
async def test_qos_encoding_custom(hass, mqtt_mock, caplog):
@ -155,7 +157,7 @@ async def test_qos_encoding_custom(hass, mqtt_mock, caplog):
'qos': 1,
'encoding': 'utf-16'}})
mock_mqtt.async_subscribe.assert_called_once_with(
'test-topic1', msg_callback, 1, 'utf-16')
'test-topic1', mock.ANY, 1, 'utf-16')
async def test_no_change(hass, mqtt_mock, caplog):