Log exceptions thrown by MQTT message callbacks (#19977)
* Log exceptions thrown by MQTT message callbacks * Fix tests * Correct method for skipping wrapper in traceback * Lint * Simplify traceback print * Add test * Move wrapper to common helper function * Typing * Lintpull/20172/head
parent
1d86905d5b
commit
368682647d
|
@ -34,6 +34,7 @@ from homeassistant.loader import bind_hass
|
|||
from homeassistant.setup import async_prepare_setup_platform
|
||||
from homeassistant.util.async_ import (
|
||||
run_callback_threadsafe, run_coroutine_threadsafe)
|
||||
from homeassistant.util.logging import catch_log_exception
|
||||
|
||||
# Loading the config flow file will register the flow
|
||||
from . import config_flow # noqa pylint: disable=unused-import
|
||||
|
@ -311,7 +312,11 @@ async def async_subscribe(hass: HomeAssistantType, topic: str,
|
|||
Call the return value to unsubscribe.
|
||||
"""
|
||||
async_remove = await hass.data[DATA_MQTT].async_subscribe(
|
||||
topic, msg_callback, qos, encoding)
|
||||
topic, catch_log_exception(
|
||||
msg_callback, lambda topic, msg, qos:
|
||||
"Exception in {} when handling msg on '{}': '{}'".format(
|
||||
msg_callback.__name__, topic, msg)),
|
||||
qos, encoding)
|
||||
return async_remove
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
"""Logging utilities."""
|
||||
import asyncio
|
||||
from asyncio.events import AbstractEventLoop
|
||||
from functools import wraps
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional
|
||||
import traceback
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from .async_ import run_coroutine_threadsafe
|
||||
|
||||
|
@ -121,3 +124,38 @@ class AsyncHandler:
|
|||
def name(self, name: str) -> None:
|
||||
"""Wrap property get_name to handler."""
|
||||
self.handler.set_name(name) # type: ignore
|
||||
|
||||
|
||||
def catch_log_exception(
|
||||
func: Callable[..., Any],
|
||||
format_err: Callable[..., Any],
|
||||
*args: Any) -> Callable[[], None]:
|
||||
"""Decorate an callback to catch and log exceptions."""
|
||||
def log_exception(*args: Any) -> None:
|
||||
module_name = inspect.getmodule(inspect.trace()[1][0]).__name__
|
||||
# Do not print the wrapper in the traceback
|
||||
frames = len(inspect.trace()) - 1
|
||||
exc_msg = traceback.format_exc(-frames)
|
||||
friendly_msg = format_err(*args)
|
||||
logging.getLogger(module_name).error('%s\n%s', friendly_msg, exc_msg)
|
||||
|
||||
wrapper_func = None
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: Any) -> None:
|
||||
"""Catch and log exception."""
|
||||
try:
|
||||
await func(*args)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
log_exception(*args)
|
||||
wrapper_func = async_wrapper
|
||||
else:
|
||||
@wraps(func)
|
||||
def wrapper(*args: Any) -> None:
|
||||
"""Catch and log exception."""
|
||||
try:
|
||||
func(*args)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
log_exception(*args)
|
||||
wrapper_func = wrapper
|
||||
return wrapper_func
|
||||
|
|
|
@ -297,6 +297,23 @@ class TestMQTTCallbacks(unittest.TestCase):
|
|||
"b'\\x9a' on test-topic with encoding utf-8" in \
|
||||
test_handle.output[0]
|
||||
|
||||
def test_message_callback_exception_gets_logged(self):
|
||||
"""Test exception raised by message handler."""
|
||||
@callback
|
||||
def bad_handler(*args):
|
||||
"""Record calls."""
|
||||
raise Exception('This is a bad message callback')
|
||||
mqtt.subscribe(self.hass, 'test-topic', bad_handler)
|
||||
|
||||
with self.assertLogs(level='WARNING') as test_handle:
|
||||
fire_mqtt_message(self.hass, 'test-topic', 'test')
|
||||
|
||||
self.hass.block_till_done()
|
||||
assert \
|
||||
"Exception in bad_handler when handling msg on 'test-topic':" \
|
||||
" 'test'" in \
|
||||
test_handle.output[0]
|
||||
|
||||
def test_all_subscriptions_run_when_decode_fails(self):
|
||||
"""Test all other subscriptions still run when decode fails for one."""
|
||||
mqtt.subscribe(self.hass, 'test-topic', self.record_calls,
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
"""The tests for the MQTT subscription component."""
|
||||
from unittest import mock
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components.mqtt.subscription import (
|
||||
async_subscribe_topics, async_unsubscribe_topics)
|
||||
|
@ -135,7 +137,7 @@ async def test_qos_encoding_default(hass, mqtt_mock, caplog):
|
|||
{'test_topic1': {'topic': 'test-topic1',
|
||||
'msg_callback': msg_callback}})
|
||||
mock_mqtt.async_subscribe.assert_called_once_with(
|
||||
'test-topic1', msg_callback, 0, 'utf-8')
|
||||
'test-topic1', mock.ANY, 0, 'utf-8')
|
||||
|
||||
|
||||
async def test_qos_encoding_custom(hass, mqtt_mock, caplog):
|
||||
|
@ -155,7 +157,7 @@ async def test_qos_encoding_custom(hass, mqtt_mock, caplog):
|
|||
'qos': 1,
|
||||
'encoding': 'utf-16'}})
|
||||
mock_mqtt.async_subscribe.assert_called_once_with(
|
||||
'test-topic1', msg_callback, 1, 'utf-16')
|
||||
'test-topic1', mock.ANY, 1, 'utf-16')
|
||||
|
||||
|
||||
async def test_no_change(hass, mqtt_mock, caplog):
|
||||
|
|
Loading…
Reference in New Issue