diff --git a/mycroft/skills/fallback_skill.py b/mycroft/skills/fallback_skill.py index cef2cd7b9c..fed14bce57 100644 --- a/mycroft/skills/fallback_skill.py +++ b/mycroft/skills/fallback_skill.py @@ -48,6 +48,7 @@ class FallbackSkill(MycroftSkill): utterance will not be see by any other Fallback handlers. """ fallback_handlers = {} + wrapper_map = [] # Map containing (handler, wrapper) tuples def __init__(self, name=None, bus=None, use_settings=True): super().__init__(name, bus, use_settings) @@ -98,18 +99,25 @@ class FallbackSkill(MycroftSkill): return handler @classmethod - def _register_fallback(cls, handler, priority): + def _register_fallback(cls, handler, wrapper, priority): """Register a function to be called as a general info fallback Fallback should receive message and return a boolean (True if succeeded or False if failed) Lower priority gets run first 0 for high priority 100 for low priority + + Arguments: + handler (callable): original handler, used as a reference when + removing + wrapper (callable): wrapped version of handler + priority (int): fallback priority """ while priority in cls.fallback_handlers: priority += 1 - cls.fallback_handlers[priority] = handler + cls.fallback_handlers[priority] = wrapper + cls.wrapper_map.append((handler, wrapper)) def register_fallback(self, handler, priority): """Register a fallback with the list of fallback handlers and with the @@ -122,8 +130,28 @@ class FallbackSkill(MycroftSkill): return True return False - self.instance_fallback_handlers.append(wrapper) - self._register_fallback(wrapper, priority) + self.instance_fallback_handlers.append(handler) + self._register_fallback(handler, wrapper, priority) + + @classmethod + def _remove_registered_handler(cls, wrapper_to_del): + """Remove a registered wrapper. + + Arguments: + wrapper_to_del (callable): wrapped handler to be removed + + Returns: + (bool) True if one or more handlers were removed, otherwise False. + """ + found_handler = False + for priority, handler in list(cls.fallback_handlers.items()): + if handler == wrapper_to_del: + found_handler = True + del cls.fallback_handlers[priority] + + if not found_handler: + LOG.warning('No fallback matching {}'.format(wrapper_to_del)) + return found_handler @classmethod def remove_fallback(cls, handler_to_del): @@ -131,15 +159,27 @@ class FallbackSkill(MycroftSkill): Arguments: handler_to_del: reference to handler + Returns: + (bool) True if at least one handler was removed, otherwise False """ - for priority, handler in cls.fallback_handlers.items(): - if handler == handler_to_del: - del cls.fallback_handlers[priority] - return - LOG.warning('Could not remove fallback!') + # Find wrapper from handler or wrapper + wrapper_to_del = None + for h, w in cls.wrapper_map: + if handler_to_del in (h, w): + wrapper_to_del = w + break + + if wrapper_to_del: + cls.wrapper_map.remove((h, w)) + remove_ok = cls._remove_registered_handler(wrapper_to_del) + else: + LOG.warning('Could not find matching fallback handler') + remove_ok = False + return remove_ok def remove_instance_handlers(self): """Remove all fallback handlers registered by the fallback skill.""" + self.log.info('Removing all handlers...') while len(self.instance_fallback_handlers): handler = self.instance_fallback_handlers.pop() self.remove_fallback(handler) diff --git a/test/unittests/skills/test_fallback_skill.py b/test/unittests/skills/test_fallback_skill.py new file mode 100644 index 0000000000..d13f136261 --- /dev/null +++ b/test/unittests/skills/test_fallback_skill.py @@ -0,0 +1,53 @@ +from unittest import TestCase, mock + +from mycroft.skills import FallbackSkill + + +def setup_fallback(fb_class): + fb_skill = fb_class() + fb_skill.bind(mock.Mock(name='bus')) + fb_skill.initialize() + return fb_skill + + +class TestFallbackSkill(TestCase): + def test_life_cycle(self): + """Test startup and shutdown of a fallback skill. + + Ensure that an added handler is removed as part of default shutdown. + """ + self.assertEqual(len(FallbackSkill.fallback_handlers), 0) + fb_skill = setup_fallback(SimpleFallback) + self.assertEqual(len(FallbackSkill.fallback_handlers), 1) + self.assertEqual(FallbackSkill.wrapper_map[0][0], + fb_skill.fallback_handler) + self.assertEqual(len(FallbackSkill.wrapper_map), 1) + + fb_skill.default_shutdown() + self.assertEqual(len(FallbackSkill.fallback_handlers), 0) + self.assertEqual(len(FallbackSkill.wrapper_map), 0) + + def test_manual_removal(self): + """Test that the call to remove_fallback() removes the handler""" + self.assertEqual(len(FallbackSkill.fallback_handlers), 0) + + # Create skill adding a single handler + fb_skill = setup_fallback(SimpleFallback) + self.assertEqual(len(FallbackSkill.fallback_handlers), 1) + + self.assertTrue(fb_skill.remove_fallback(fb_skill.fallback_handler)) + # Both internal trackers of handlers should be cleared now + self.assertEqual(len(FallbackSkill.fallback_handlers), 0) + self.assertEqual(len(FallbackSkill.wrapper_map), 0) + + # Removing after it's already been removed should fail + self.assertFalse(fb_skill.remove_fallback(fb_skill.fallback_handler)) + + +class SimpleFallback(FallbackSkill): + """Simple fallback skill used for test.""" + def initialize(self): + self.register_fallback(self.fallback_handler, 42) + + def fallback_handler(self): + pass