diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index dbdb7f7eb32..fcaa05f7921 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -34,6 +34,7 @@ from homeassistant.loader import bind_hass from homeassistant.setup import async_prepare_setup_platform from homeassistant.util.async_ import ( run_callback_threadsafe, run_coroutine_threadsafe) +from homeassistant.util.logging import catch_log_exception # Loading the config flow file will register the flow from . import config_flow # noqa pylint: disable=unused-import @@ -311,7 +312,11 @@ async def async_subscribe(hass: HomeAssistantType, topic: str, Call the return value to unsubscribe. """ async_remove = await hass.data[DATA_MQTT].async_subscribe( - topic, msg_callback, qos, encoding) + topic, catch_log_exception( + msg_callback, lambda topic, msg, qos: + "Exception in {} when handling msg on '{}': '{}'".format( + msg_callback.__name__, topic, msg)), + qos, encoding) return async_remove diff --git a/homeassistant/util/logging.py b/homeassistant/util/logging.py index f2bf15d8a03..ae32566c73c 100644 --- a/homeassistant/util/logging.py +++ b/homeassistant/util/logging.py @@ -1,9 +1,12 @@ """Logging utilities.""" import asyncio from asyncio.events import AbstractEventLoop +from functools import wraps +import inspect import logging import threading -from typing import Optional +import traceback +from typing import Any, Callable, Optional from .async_ import run_coroutine_threadsafe @@ -121,3 +124,38 @@ class AsyncHandler: def name(self, name: str) -> None: """Wrap property get_name to handler.""" self.handler.set_name(name) # type: ignore + + +def catch_log_exception( + func: Callable[..., Any], + format_err: Callable[..., Any], + *args: Any) -> Callable[[], None]: + """Decorate an callback to catch and log exceptions.""" + def log_exception(*args: Any) -> None: + module_name = inspect.getmodule(inspect.trace()[1][0]).__name__ + # Do not print the wrapper in the traceback + frames = len(inspect.trace()) - 1 + exc_msg = traceback.format_exc(-frames) + friendly_msg = format_err(*args) + logging.getLogger(module_name).error('%s\n%s', friendly_msg, exc_msg) + + wrapper_func = None + if asyncio.iscoroutinefunction(func): + @wraps(func) + async def async_wrapper(*args: Any) -> None: + """Catch and log exception.""" + try: + await func(*args) + except Exception: # pylint: disable=broad-except + log_exception(*args) + wrapper_func = async_wrapper + else: + @wraps(func) + def wrapper(*args: Any) -> None: + """Catch and log exception.""" + try: + func(*args) + except Exception: # pylint: disable=broad-except + log_exception(*args) + wrapper_func = wrapper + return wrapper_func diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 6652eddd20b..540cfe0369d 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -297,6 +297,23 @@ class TestMQTTCallbacks(unittest.TestCase): "b'\\x9a' on test-topic with encoding utf-8" in \ test_handle.output[0] + def test_message_callback_exception_gets_logged(self): + """Test exception raised by message handler.""" + @callback + def bad_handler(*args): + """Record calls.""" + raise Exception('This is a bad message callback') + mqtt.subscribe(self.hass, 'test-topic', bad_handler) + + with self.assertLogs(level='WARNING') as test_handle: + fire_mqtt_message(self.hass, 'test-topic', 'test') + + self.hass.block_till_done() + assert \ + "Exception in bad_handler when handling msg on 'test-topic':" \ + " 'test'" in \ + test_handle.output[0] + def test_all_subscriptions_run_when_decode_fails(self): """Test all other subscriptions still run when decode fails for one.""" mqtt.subscribe(self.hass, 'test-topic', self.record_calls, diff --git a/tests/components/mqtt/test_subscription.py b/tests/components/mqtt/test_subscription.py index 102b71d7b53..69386e2bad4 100644 --- a/tests/components/mqtt/test_subscription.py +++ b/tests/components/mqtt/test_subscription.py @@ -1,4 +1,6 @@ """The tests for the MQTT subscription component.""" +from unittest import mock + from homeassistant.core import callback from homeassistant.components.mqtt.subscription import ( async_subscribe_topics, async_unsubscribe_topics) @@ -135,7 +137,7 @@ async def test_qos_encoding_default(hass, mqtt_mock, caplog): {'test_topic1': {'topic': 'test-topic1', 'msg_callback': msg_callback}}) mock_mqtt.async_subscribe.assert_called_once_with( - 'test-topic1', msg_callback, 0, 'utf-8') + 'test-topic1', mock.ANY, 0, 'utf-8') async def test_qos_encoding_custom(hass, mqtt_mock, caplog): @@ -155,7 +157,7 @@ async def test_qos_encoding_custom(hass, mqtt_mock, caplog): 'qos': 1, 'encoding': 'utf-16'}}) mock_mqtt.async_subscribe.assert_called_once_with( - 'test-topic1', msg_callback, 1, 'utf-16') + 'test-topic1', mock.ANY, 1, 'utf-16') async def test_no_change(hass, mqtt_mock, caplog):