From d80c05b6b61dd0d2d7b329ea7e7f86b50028ceab Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 9 Aug 2016 19:41:45 -0700 Subject: [PATCH] Enforce lower case for services and warn if local unknown service called (#2764) --- homeassistant/core.py | 19 ++++++++++++------- tests/test_core.py | 6 +++--- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/homeassistant/core.py b/homeassistant/core.py index 81660d65a9b..63273ce789a 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -582,8 +582,8 @@ class ServiceCall(object): def __init__(self, domain, service, data=None, call_id=None): """Initialize a service call.""" - self.domain = domain - self.service = service + self.domain = domain.lower() + self.service = service.lower() self.data = data or {} self.call_id = call_id @@ -618,7 +618,7 @@ class ServiceRegistry(object): def has_service(self, domain, service): """Test if specified service exists.""" - return service in self._services.get(domain, []) + return service.lower() in self._services.get(domain.lower(), []) # pylint: disable=too-many-arguments def register(self, domain, service, service_func, description=None, @@ -631,6 +631,8 @@ class ServiceRegistry(object): Schema is called to coerce and validate the service data. """ + domain = domain.lower() + service = service.lower() description = description or {} service_obj = Service(service_func, description.get('description'), description.get('fields', {}), schema) @@ -664,8 +666,8 @@ class ServiceRegistry(object): call_id = self._generate_unique_id() event_data = { - ATTR_DOMAIN: domain, - ATTR_SERVICE: service, + ATTR_DOMAIN: domain.lower(), + ATTR_SERVICE: service.lower(), ATTR_SERVICE_DATA: service_data, ATTR_SERVICE_CALL_ID: call_id, } @@ -691,11 +693,14 @@ class ServiceRegistry(object): def _event_to_service_call(self, event): """Callback for SERVICE_CALLED events from the event bus.""" service_data = event.data.get(ATTR_SERVICE_DATA) - domain = event.data.get(ATTR_DOMAIN) - service = event.data.get(ATTR_SERVICE) + domain = event.data.get(ATTR_DOMAIN).lower() + service = event.data.get(ATTR_SERVICE).lower() call_id = event.data.get(ATTR_SERVICE_CALL_ID) if not self.has_service(domain, service): + if event.origin == EventOrigin.local: + _LOGGER.warning('Unable to find service %s/%s', + domain, service) return service_handler = self._services[domain][service] diff --git a/tests/test_core.py b/tests/test_core.py index 78a676708df..aa3cdd2aecc 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -386,7 +386,7 @@ class TestServiceRegistry(unittest.TestCase): return ha.HomeAssistant.add_job(self, *args, **kwargs) self.services = ha.ServiceRegistry(self.bus, add_job) - self.services.register("test_domain", "test_service", lambda x: None) + self.services.register("Test_Domain", "TEST_SERVICE", lambda x: None) def tearDown(self): # pylint: disable=invalid-name """Stop down stuff we started.""" @@ -396,7 +396,7 @@ class TestServiceRegistry(unittest.TestCase): def test_has_service(self): """Test has_service method.""" self.assertTrue( - self.services.has_service("test_domain", "test_service")) + self.services.has_service("tesT_domaiN", "tesT_servicE")) self.assertFalse( self.services.has_service("test_domain", "non_existing")) self.assertFalse( @@ -418,7 +418,7 @@ class TestServiceRegistry(unittest.TestCase): lambda x: calls.append(1)) self.assertTrue( - self.services.call('test_domain', 'register_calls', blocking=True)) + self.services.call('test_domain', 'REGISTER_CALLS', blocking=True)) self.assertEqual(1, len(calls)) def test_call_with_blocking_not_done_in_time(self):