Enforce lower case for services and warn if local unknown service called (#2764)

pull/2778/head
Paulus Schoutsen 2016-08-09 19:41:45 -07:00 committed by GitHub
parent 180a7ec295
commit d80c05b6b6
2 changed files with 15 additions and 10 deletions

View File

@ -582,8 +582,8 @@ class ServiceCall(object):
def __init__(self, domain, service, data=None, call_id=None):
"""Initialize a service call."""
self.domain = domain
self.service = service
self.domain = domain.lower()
self.service = service.lower()
self.data = data or {}
self.call_id = call_id
@ -618,7 +618,7 @@ class ServiceRegistry(object):
def has_service(self, domain, service):
"""Test if specified service exists."""
return service in self._services.get(domain, [])
return service.lower() in self._services.get(domain.lower(), [])
# pylint: disable=too-many-arguments
def register(self, domain, service, service_func, description=None,
@ -631,6 +631,8 @@ class ServiceRegistry(object):
Schema is called to coerce and validate the service data.
"""
domain = domain.lower()
service = service.lower()
description = description or {}
service_obj = Service(service_func, description.get('description'),
description.get('fields', {}), schema)
@ -664,8 +666,8 @@ class ServiceRegistry(object):
call_id = self._generate_unique_id()
event_data = {
ATTR_DOMAIN: domain,
ATTR_SERVICE: service,
ATTR_DOMAIN: domain.lower(),
ATTR_SERVICE: service.lower(),
ATTR_SERVICE_DATA: service_data,
ATTR_SERVICE_CALL_ID: call_id,
}
@ -691,11 +693,14 @@ class ServiceRegistry(object):
def _event_to_service_call(self, event):
"""Callback for SERVICE_CALLED events from the event bus."""
service_data = event.data.get(ATTR_SERVICE_DATA)
domain = event.data.get(ATTR_DOMAIN)
service = event.data.get(ATTR_SERVICE)
domain = event.data.get(ATTR_DOMAIN).lower()
service = event.data.get(ATTR_SERVICE).lower()
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
if not self.has_service(domain, service):
if event.origin == EventOrigin.local:
_LOGGER.warning('Unable to find service %s/%s',
domain, service)
return
service_handler = self._services[domain][service]

View File

@ -386,7 +386,7 @@ class TestServiceRegistry(unittest.TestCase):
return ha.HomeAssistant.add_job(self, *args, **kwargs)
self.services = ha.ServiceRegistry(self.bus, add_job)
self.services.register("test_domain", "test_service", lambda x: None)
self.services.register("Test_Domain", "TEST_SERVICE", lambda x: None)
def tearDown(self): # pylint: disable=invalid-name
"""Stop down stuff we started."""
@ -396,7 +396,7 @@ class TestServiceRegistry(unittest.TestCase):
def test_has_service(self):
"""Test has_service method."""
self.assertTrue(
self.services.has_service("test_domain", "test_service"))
self.services.has_service("tesT_domaiN", "tesT_servicE"))
self.assertFalse(
self.services.has_service("test_domain", "non_existing"))
self.assertFalse(
@ -418,7 +418,7 @@ class TestServiceRegistry(unittest.TestCase):
lambda x: calls.append(1))
self.assertTrue(
self.services.call('test_domain', 'register_calls', blocking=True))
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
self.assertEqual(1, len(calls))
def test_call_with_blocking_not_done_in_time(self):