Merge branch 'dev'

* dev:
  Bug fixes related to entity_ids being lowercase
  Have statemachine.track_change work on new states
pull/34/head
Paulus Schoutsen 2015-02-08 22:28:28 -08:00
commit b643ef628b
6 changed files with 118 additions and 57 deletions

View File

@ -614,13 +614,19 @@ class StateMachine(object):
@ft.wraps(action) @ft.wraps(action)
def state_listener(event): def state_listener(event):
""" The listener that listens for specific state changes. """ """ The listener that listens for specific state changes. """
if event.data['entity_id'] in entity_ids and \ if event.data['entity_id'] not in entity_ids:
'old_state' in event.data and \ return
_matcher(event.data['old_state'].state, from_state) and \
_matcher(event.data['new_state'].state, to_state): if 'old_state' in event.data:
old_state = event.data['old_state'].state
else:
old_state = None
if _matcher(old_state, from_state) and \
_matcher(event.data['new_state'].state, to_state):
action(event.data['entity_id'], action(event.data['entity_id'],
event.data['old_state'], event.data.get('old_state'),
event.data['new_state']) event.data['new_state'])
self._bus.listen(EVENT_STATE_CHANGED, state_listener) self._bus.listen(EVENT_STATE_CHANGED, state_listener)

View File

@ -51,6 +51,11 @@ def expand_entity_ids(hass, entity_ids):
found_ids = [] found_ids = []
for entity_id in entity_ids: for entity_id in entity_ids:
if not isinstance(entity_id, str):
continue
entity_id = entity_id.lower()
try: try:
# If entity_id points at a group, expand it # If entity_id points at a group, expand it
domain, _ = util.split_entity_id(entity_id) domain, _ = util.split_entity_id(entity_id)
@ -74,10 +79,14 @@ def expand_entity_ids(hass, entity_ids):
def get_entity_ids(hass, entity_id, domain_filter=None): def get_entity_ids(hass, entity_id, domain_filter=None):
""" Get the entity ids that make up this group. """ """ Get the entity ids that make up this group. """
entity_id = entity_id.lower()
try: try:
entity_ids = hass.states.get(entity_id).attributes[ATTR_ENTITY_ID] entity_ids = hass.states.get(entity_id).attributes[ATTR_ENTITY_ID]
if domain_filter: if domain_filter:
domain_filter = domain_filter.lower()
return [ent_id for ent_id in entity_ids return [ent_id for ent_id in entity_ids
if ent_id.startswith(domain_filter)] if ent_id.startswith(domain_filter)]
else: else:
@ -131,7 +140,7 @@ class Group(object):
def update_tracked_entity_ids(self, entity_ids): def update_tracked_entity_ids(self, entity_ids):
""" Update the tracked entity IDs. """ """ Update the tracked entity IDs. """
self.stop() self.stop()
self.tracking = tuple(entity_ids) self.tracking = tuple(ent_id.lower() for ent_id in entity_ids)
self.group_on, self.group_off = None, None self.group_on, self.group_off = None, None
self.force_update() self.force_update()

View File

@ -29,24 +29,18 @@ def extract_entity_ids(hass, service):
Helper method to extract a list of entity ids from a service call. Helper method to extract a list of entity ids from a service call.
Will convert group entity ids to the entity ids it represents. Will convert group entity ids to the entity ids it represents.
""" """
entity_ids = [] if not (service.data and ATTR_ENTITY_ID in service.data):
return []
if service.data and ATTR_ENTITY_ID in service.data: group = get_component('group')
group = get_component('group')
# Entity ID attr can be a list or a string # Entity ID attr can be a list or a string
service_ent_id = service.data[ATTR_ENTITY_ID] service_ent_id = service.data[ATTR_ENTITY_ID]
if isinstance(service_ent_id, list):
ent_ids = service_ent_id
else:
ent_ids = [service_ent_id]
entity_ids.extend( if isinstance(service_ent_id, str):
ent_id for ent_id return group.expand_entity_ids(hass, [service_ent_id.lower()])
in group.expand_entity_ids(hass, ent_ids)
if ent_id not in entity_ids)
return entity_ids return [ent_id for ent_id in group.expand_entity_ids(hass, service_ent_id)]
# pylint: disable=too-few-public-methods, attribute-defined-outside-init # pylint: disable=too-few-public-methods, attribute-defined-outside-init

View File

@ -78,8 +78,13 @@ class MockToggleDevice(ToggleDevice):
self._state = STATE_OFF self._state = STATE_OFF
def last_call(self, method=None): def last_call(self, method=None):
if method is None: if not self.calls:
return None
elif method is None:
return self.calls[-1] return self.calls[-1]
else: else:
return next(call for call in reversed(self.calls) try:
if call[0] == method) return next(call for call in reversed(self.calls)
if call[0] == method)
except StopIteration:
return None

View File

@ -27,14 +27,10 @@ class TestComponentsGroup(unittest.TestCase):
self.hass.states.set('light.Bowl', STATE_ON) self.hass.states.set('light.Bowl', STATE_ON)
self.hass.states.set('light.Ceiling', STATE_OFF) self.hass.states.set('light.Ceiling', STATE_OFF)
self.hass.states.set('switch.AC', STATE_OFF) test_group = group.Group(
group.setup_group(self.hass, 'init_group', self.hass, 'init_group', ['light.Bowl', 'light.Ceiling'], False)
['light.Bowl', 'light.Ceiling'], False)
group.setup_group(self.hass, 'mixed_group',
['light.Bowl', 'switch.AC'], False)
self.group_name = group.ENTITY_ID_FORMAT.format('init_group') self.group_entity_id = test_group.entity_id
self.mixed_group_name = group.ENTITY_ID_FORMAT.format('mixed_group')
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
""" Stop down stuff we started. """ """ Stop down stuff we started. """
@ -80,72 +76,122 @@ class TestComponentsGroup(unittest.TestCase):
""" Test if the group keeps track of states. """ """ Test if the group keeps track of states. """
# Test if group setup in our init mode is ok # Test if group setup in our init mode is ok
self.assertIn(self.group_name, self.hass.states.entity_ids()) self.assertIn(self.group_entity_id, self.hass.states.entity_ids())
group_state = self.hass.states.get(self.group_name) group_state = self.hass.states.get(self.group_entity_id)
self.assertEqual(STATE_ON, group_state.state) self.assertEqual(STATE_ON, group_state.state)
self.assertTrue(group_state.attributes[group.ATTR_AUTO]) self.assertTrue(group_state.attributes[group.ATTR_AUTO])
# Turn the Bowl off and see if group turns off def test_group_turns_off_if_all_off(self):
"""
Test if the group turns off if the last device that was on turns off.
"""
self.hass.states.set('light.Bowl', STATE_OFF) self.hass.states.set('light.Bowl', STATE_OFF)
self.hass.pool.block_till_done() self.hass.pool.block_till_done()
group_state = self.hass.states.get(self.group_name) group_state = self.hass.states.get(self.group_entity_id)
self.assertEqual(STATE_OFF, group_state.state) self.assertEqual(STATE_OFF, group_state.state)
# Turn the Ceiling on and see if group turns on def test_group_turns_on_if_all_are_off_and_one_turns_on(self):
self.hass.states.set('light.Ceiling', STATE_ON) """
Test if group turns on if all devices were turned off and one turns on.
"""
# Make sure all are off.
self.hass.states.set('light.Bowl', STATE_OFF)
self.hass.pool.block_till_done() self.hass.pool.block_till_done()
group_state = self.hass.states.get(self.group_name) # Turn one on
self.hass.states.set('light.Ceiling', STATE_ON)
self.hass.pool.block_till_done()
group_state = self.hass.states.get(self.group_entity_id)
self.assertEqual(STATE_ON, group_state.state) self.assertEqual(STATE_ON, group_state.state)
def test_is_on(self): def test_is_on(self):
""" Test is_on method. """ """ Test is_on method. """
self.assertTrue(group.is_on(self.hass, self.group_name)) self.assertTrue(group.is_on(self.hass, self.group_entity_id))
self.hass.states.set('light.Bowl', STATE_OFF) self.hass.states.set('light.Bowl', STATE_OFF)
self.hass.pool.block_till_done() self.hass.pool.block_till_done()
self.assertFalse(group.is_on(self.hass, self.group_name)) self.assertFalse(group.is_on(self.hass, self.group_entity_id))
# Try on non existing state # Try on non existing state
self.assertFalse(group.is_on(self.hass, 'non.existing')) self.assertFalse(group.is_on(self.hass, 'non.existing'))
def test_expand_entity_ids(self): def test_expand_entity_ids(self):
""" Test expand_entity_ids method. """ """ Test expand_entity_ids method. """
self.assertEqual(sorted(['light.Ceiling', 'light.Bowl']), self.assertEqual(sorted(['light.ceiling', 'light.bowl']),
sorted(group.expand_entity_ids( sorted(group.expand_entity_ids(
self.hass, [self.group_name]))) self.hass, [self.group_entity_id])))
# Make sure that no duplicates are returned def test_expand_entity_ids_does_not_return_duplicates(self):
""" Test that expand_entity_ids does not return duplicates. """
self.assertEqual( self.assertEqual(
sorted(['light.Ceiling', 'light.Bowl']), ['light.bowl', 'light.ceiling'],
sorted(group.expand_entity_ids( sorted(group.expand_entity_ids(
self.hass, [self.group_name, 'light.Ceiling']))) self.hass, [self.group_entity_id, 'light.Ceiling'])))
# Test that non strings are ignored self.assertEqual(
['light.bowl', 'light.ceiling'],
sorted(group.expand_entity_ids(
self.hass, ['light.bowl', self.group_entity_id])))
def test_expand_entity_ids_ignores_non_strings(self):
""" Test that non string elements in lists are ignored. """
self.assertEqual([], group.expand_entity_ids(self.hass, [5, True])) self.assertEqual([], group.expand_entity_ids(self.hass, [5, True]))
def test_get_entity_ids(self): def test_get_entity_ids(self):
""" Test get_entity_ids method. """ """ Test get_entity_ids method. """
# Get entity IDs from our group
self.assertEqual( self.assertEqual(
sorted(['light.Ceiling', 'light.Bowl']), ['light.bowl', 'light.ceiling'],
sorted(group.get_entity_ids(self.hass, self.group_name))) sorted(group.get_entity_ids(self.hass, self.group_entity_id)))
def test_get_entity_ids_with_domain_filter(self):
""" Test if get_entity_ids works with a domain_filter. """
self.hass.states.set('switch.AC', STATE_OFF)
mixed_group = group.Group(
self.hass, 'mixed_group', ['light.Bowl', 'switch.AC'], False)
# Test domain_filter
self.assertEqual( self.assertEqual(
['switch.AC'], ['switch.ac'],
group.get_entity_ids( group.get_entity_ids(
self.hass, self.mixed_group_name, domain_filter="switch")) self.hass, mixed_group.entity_id, domain_filter="switch"))
# Test with non existing group name def test_get_entity_ids_with_non_existing_group_name(self):
""" Tests get_entity_ids with a non existing group. """
self.assertEqual([], group.get_entity_ids(self.hass, 'non_existing')) self.assertEqual([], group.get_entity_ids(self.hass, 'non_existing'))
# Test with non-group state def test_get_entity_ids_with_non_group_state(self):
""" Tests get_entity_ids with a non group state. """
self.assertEqual([], group.get_entity_ids(self.hass, 'switch.AC')) self.assertEqual([], group.get_entity_ids(self.hass, 'switch.AC'))
def test_group_being_init_before_first_tracked_state_is_set_to_on(self):
""" Test if the group turns on if no states existed and now a state it is
tracking is being added as ON. """
test_group = group.Group(
self.hass, 'test group', ['light.not_there_1'])
self.hass.states.set('light.not_there_1', STATE_ON)
self.hass.pool.block_till_done()
group_state = self.hass.states.get(test_group.entity_id)
self.assertEqual(STATE_ON, group_state.state)
def test_group_being_init_before_first_tracked_state_is_set_to_off(self):
""" Test if the group turns off if no states existed and now a state it is
tracking is being added as OFF. """
test_group = group.Group(
self.hass, 'test group', ['light.not_there_1'])
self.hass.states.set('light.not_there_1', STATE_OFF)
self.hass.pool.block_till_done()
group_state = self.hass.states.get(test_group.entity_id)
self.assertEqual(STATE_OFF, group_state.state)
def test_setup(self): def test_setup(self):
""" Test setup method. """ """ Test setup method. """
self.assertTrue( self.assertTrue(
@ -153,7 +199,8 @@ class TestComponentsGroup(unittest.TestCase):
self.hass, self.hass,
{ {
group.DOMAIN: { group.DOMAIN: {
'second_group': '{},light.Bowl'.format(self.group_name) 'second_group': ','.join((self.group_entity_id,
'light.Bowl'))
} }
})) }))

View File

@ -39,11 +39,11 @@ class TestComponentsCore(unittest.TestCase):
call = ha.ServiceCall('light', 'turn_on', call = ha.ServiceCall('light', 'turn_on',
{ATTR_ENTITY_ID: 'light.Bowl'}) {ATTR_ENTITY_ID: 'light.Bowl'})
self.assertEqual(['light.Bowl'], self.assertEqual(['light.bowl'],
extract_entity_ids(self.hass, call)) extract_entity_ids(self.hass, call))
call = ha.ServiceCall('light', 'turn_on', call = ha.ServiceCall('light', 'turn_on',
{ATTR_ENTITY_ID: 'group.test'}) {ATTR_ENTITY_ID: 'group.test'})
self.assertEqual(['light.Ceiling', 'light.Kitchen'], self.assertEqual(['light.ceiling', 'light.kitchen'],
extract_entity_ids(self.hass, call)) extract_entity_ids(self.hass, call))