Enforce lower case for services and warn if local unknown service called (#2764)
parent
180a7ec295
commit
d80c05b6b6
|
@ -582,8 +582,8 @@ class ServiceCall(object):
|
|||
|
||||
def __init__(self, domain, service, data=None, call_id=None):
|
||||
"""Initialize a service call."""
|
||||
self.domain = domain
|
||||
self.service = service
|
||||
self.domain = domain.lower()
|
||||
self.service = service.lower()
|
||||
self.data = data or {}
|
||||
self.call_id = call_id
|
||||
|
||||
|
@ -618,7 +618,7 @@ class ServiceRegistry(object):
|
|||
|
||||
def has_service(self, domain, service):
|
||||
"""Test if specified service exists."""
|
||||
return service in self._services.get(domain, [])
|
||||
return service.lower() in self._services.get(domain.lower(), [])
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
def register(self, domain, service, service_func, description=None,
|
||||
|
@ -631,6 +631,8 @@ class ServiceRegistry(object):
|
|||
|
||||
Schema is called to coerce and validate the service data.
|
||||
"""
|
||||
domain = domain.lower()
|
||||
service = service.lower()
|
||||
description = description or {}
|
||||
service_obj = Service(service_func, description.get('description'),
|
||||
description.get('fields', {}), schema)
|
||||
|
@ -664,8 +666,8 @@ class ServiceRegistry(object):
|
|||
call_id = self._generate_unique_id()
|
||||
|
||||
event_data = {
|
||||
ATTR_DOMAIN: domain,
|
||||
ATTR_SERVICE: service,
|
||||
ATTR_DOMAIN: domain.lower(),
|
||||
ATTR_SERVICE: service.lower(),
|
||||
ATTR_SERVICE_DATA: service_data,
|
||||
ATTR_SERVICE_CALL_ID: call_id,
|
||||
}
|
||||
|
@ -691,11 +693,14 @@ class ServiceRegistry(object):
|
|||
def _event_to_service_call(self, event):
|
||||
"""Callback for SERVICE_CALLED events from the event bus."""
|
||||
service_data = event.data.get(ATTR_SERVICE_DATA)
|
||||
domain = event.data.get(ATTR_DOMAIN)
|
||||
service = event.data.get(ATTR_SERVICE)
|
||||
domain = event.data.get(ATTR_DOMAIN).lower()
|
||||
service = event.data.get(ATTR_SERVICE).lower()
|
||||
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
|
||||
|
||||
if not self.has_service(domain, service):
|
||||
if event.origin == EventOrigin.local:
|
||||
_LOGGER.warning('Unable to find service %s/%s',
|
||||
domain, service)
|
||||
return
|
||||
|
||||
service_handler = self._services[domain][service]
|
||||
|
|
|
@ -386,7 +386,7 @@ class TestServiceRegistry(unittest.TestCase):
|
|||
return ha.HomeAssistant.add_job(self, *args, **kwargs)
|
||||
|
||||
self.services = ha.ServiceRegistry(self.bus, add_job)
|
||||
self.services.register("test_domain", "test_service", lambda x: None)
|
||||
self.services.register("Test_Domain", "TEST_SERVICE", lambda x: None)
|
||||
|
||||
def tearDown(self): # pylint: disable=invalid-name
|
||||
"""Stop down stuff we started."""
|
||||
|
@ -396,7 +396,7 @@ class TestServiceRegistry(unittest.TestCase):
|
|||
def test_has_service(self):
|
||||
"""Test has_service method."""
|
||||
self.assertTrue(
|
||||
self.services.has_service("test_domain", "test_service"))
|
||||
self.services.has_service("tesT_domaiN", "tesT_servicE"))
|
||||
self.assertFalse(
|
||||
self.services.has_service("test_domain", "non_existing"))
|
||||
self.assertFalse(
|
||||
|
@ -418,7 +418,7 @@ class TestServiceRegistry(unittest.TestCase):
|
|||
lambda x: calls.append(1))
|
||||
|
||||
self.assertTrue(
|
||||
self.services.call('test_domain', 'register_calls', blocking=True))
|
||||
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
|
||||
self.assertEqual(1, len(calls))
|
||||
|
||||
def test_call_with_blocking_not_done_in_time(self):
|
||||
|
|
Loading…
Reference in New Issue