Cleaned up device_tracker and added tests

pull/12/head
Paulus Schoutsen 2014-12-02 21:53:00 -08:00
parent 12c734fa48
commit eef4817804
5 changed files with 429 additions and 171 deletions

View File

@ -0,0 +1,41 @@
"""
custom_components.device_tracker.test
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Provides a mock device scanner.
"""
def get_scanner(hass, config):
""" Returns a mock scanner. """
return SCANNER
class MockScanner(object):
""" Mock device scanner. """
def __init__(self):
""" Initialize the MockScanner. """
self.devices_home = []
def come_home(self, device):
""" Make a device come home. """
self.devices_home.append(device)
def leave_home(self, device):
""" Make a device leave the house. """
self.devices_home.remove(device)
def scan_devices(self):
""" Returns a list of fake devices. """
return list(self.devices_home)
def get_device_name(self, device):
"""
Returns a name for a mock device.
Returns None for dev1 for testing.
"""
return None if device == 'dev1' else device.upper()
SCANNER = MockScanner()

View File

@ -0,0 +1,190 @@
"""
ha_test.test_component_group
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tests the group compoments.
"""
# pylint: disable=protected-access,too-many-public-methods
import unittest
from datetime import datetime, timedelta
import logging
import os
import homeassistant as ha
import homeassistant.loader as loader
from homeassistant.components import (
STATE_HOME, STATE_NOT_HOME, ATTR_ENTITY_PICTURE)
import homeassistant.components.device_tracker as device_tracker
from helper import get_test_home_assistant
def setUpModule(): # pylint: disable=invalid-name
""" Setup to ignore group errors. """
logging.disable(logging.CRITICAL)
class TestComponentsDeviceTracker(unittest.TestCase):
""" Tests homeassistant.components.device_tracker module. """
def setUp(self): # pylint: disable=invalid-name
""" Init needed objects. """
self.hass = get_test_home_assistant()
loader.prepare(self.hass)
self.known_dev_path = self.hass.get_config_path(
device_tracker.KNOWN_DEVICES_FILE)
def tearDown(self): # pylint: disable=invalid-name
""" Stop down stuff we started. """
self.hass.stop()
if os.path.isfile(self.known_dev_path):
os.remove(self.known_dev_path)
def test_is_on(self):
""" Test is_on method. """
entity_id = device_tracker.ENTITY_ID_FORMAT.format('test')
self.hass.states.set(entity_id, STATE_HOME)
self.assertTrue(device_tracker.is_on(self.hass, entity_id))
self.hass.states.set(entity_id, STATE_NOT_HOME)
self.assertFalse(device_tracker.is_on(self.hass, entity_id))
def test_setup(self):
""" Test setup method. """
# Bogus config
self.assertFalse(device_tracker.setup(self.hass, {}))
self.assertFalse(
device_tracker.setup(self.hass, {device_tracker.DOMAIN: {}}))
# Test with non-existing component
self.assertFalse(device_tracker.setup(
self.hass, {device_tracker.DOMAIN: {ha.CONF_TYPE: 'nonexisting'}}
))
# Test with a bad known device file around
with open(self.known_dev_path, 'w') as fil:
fil.write("bad data\nbad data\n")
self.assertFalse(device_tracker.setup(self.hass, {
device_tracker.DOMAIN: {ha.CONF_TYPE: 'test'}
}))
def test_device_tracker(self):
""" Test the device tracker class. """
scanner = loader.get_component(
'device_tracker.test').get_scanner(None, None)
scanner.come_home('dev1')
scanner.come_home('dev2')
self.assertTrue(device_tracker.setup(self.hass, {
device_tracker.DOMAIN: {ha.CONF_TYPE: 'test'}
}))
# Ensure a new known devices file has been created.
# Since the device_tracker uses a set internally we cannot
# know what the order of the devices in the known devices file is.
# To ensure all the three expected lines are there, we sort the file
with open(self.known_dev_path) as fil:
self.assertEqual(
['dev1,unknown_device,0,\n', 'dev2,DEV2,0,\n',
'device,name,track,picture\n'],
sorted(fil))
# Write one where we track dev1, dev2
with open(self.known_dev_path, 'w') as fil:
fil.write('device,name,track,picture\n')
fil.write('dev1,Device 1,1,http://example.com/dev1.jpg\n')
fil.write('dev2,Device 2,1,http://example.com/dev2.jpg\n')
scanner.leave_home('dev1')
scanner.come_home('dev3')
self.hass.services.call(
device_tracker.DOMAIN,
device_tracker.SERVICE_DEVICE_TRACKER_RELOAD)
self.hass._pool.block_till_done()
dev1 = device_tracker.ENTITY_ID_FORMAT.format('Device_1')
dev2 = device_tracker.ENTITY_ID_FORMAT.format('Device_2')
dev3 = device_tracker.ENTITY_ID_FORMAT.format('DEV3')
now = datetime.now()
nowNext = now + timedelta(seconds=ha.TIMER_INTERVAL)
nowAlmostMinGone = (now + device_tracker.TIME_DEVICE_NOT_FOUND -
timedelta(seconds=1))
nowMinGone = nowAlmostMinGone + timedelta(seconds=2)
# Test initial is correct
self.assertTrue(device_tracker.is_on(self.hass))
self.assertFalse(device_tracker.is_on(self.hass, dev1))
self.assertTrue(device_tracker.is_on(self.hass, dev2))
self.assertIsNone(self.hass.states.get(dev3))
self.assertEqual(
'http://example.com/dev1.jpg',
self.hass.states.get(dev1).attributes.get(ATTR_ENTITY_PICTURE))
self.assertEqual(
'http://example.com/dev2.jpg',
self.hass.states.get(dev2).attributes.get(ATTR_ENTITY_PICTURE))
# Test if dev3 got added to known dev file
with open(self.known_dev_path) as fil:
self.assertEqual('dev3,DEV3,0,\n', list(fil)[-1])
# Change dev3 to track
with open(self.known_dev_path, 'w') as fil:
fil.write("device,name,track,picture\n")
fil.write('dev1,Device 1,1,http://example.com/picture.jpg\n')
fil.write('dev2,Device 2,1,http://example.com/picture.jpg\n')
fil.write('dev3,DEV3,1,\n')
# reload dev file
scanner.come_home('dev1')
scanner.leave_home('dev2')
self.hass.services.call(
device_tracker.DOMAIN,
device_tracker.SERVICE_DEVICE_TRACKER_RELOAD)
self.hass._pool.block_till_done()
# Test what happens if a device comes home and another leaves
self.assertTrue(device_tracker.is_on(self.hass))
self.assertTrue(device_tracker.is_on(self.hass, dev1))
# Dev2 will still be home because of the error margin on time
self.assertTrue(device_tracker.is_on(self.hass, dev2))
# dev3 should be tracked now after we reload the known devices
self.assertTrue(device_tracker.is_on(self.hass, dev3))
self.assertIsNone(
self.hass.states.get(dev3).attributes.get(ATTR_ENTITY_PICTURE))
# Test if device leaves what happens, test the time span
self.hass.bus.fire(
ha.EVENT_TIME_CHANGED, {ha.ATTR_NOW: nowAlmostMinGone})
self.hass._pool.block_till_done()
self.assertTrue(device_tracker.is_on(self.hass))
self.assertTrue(device_tracker.is_on(self.hass, dev1))
# Dev2 will still be home because of the error time
self.assertTrue(device_tracker.is_on(self.hass, dev2))
self.assertTrue(device_tracker.is_on(self.hass, dev3))
# Now test if gone for longer then error margin
self.hass.bus.fire(ha.EVENT_TIME_CHANGED, {ha.ATTR_NOW: nowMinGone})
self.hass._pool.block_till_done()
self.assertTrue(device_tracker.is_on(self.hass))
self.assertTrue(device_tracker.is_on(self.hass, dev1))
self.assertFalse(device_tracker.is_on(self.hass, dev2))
self.assertTrue(device_tracker.is_on(self.hass, dev3))

View File

@ -557,6 +557,9 @@ class StateMachine(object):
Track specific state changes.
entity_ids, from_state and to_state can be string or list.
Use list to match multiple.
Returns the listener that listens on the bus for EVENT_STATE_CHANGED.
Pass the return value into hass.bus.remove_listener to remove it.
"""
from_state = _process_match_param(from_state)
to_state = _process_match_param(to_state)
@ -579,6 +582,8 @@ class StateMachine(object):
self._bus.listen(EVENT_STATE_CHANGED, state_listener)
return state_listener
# pylint: disable=too-few-public-methods
class ServiceCall(object):

View File

@ -1,6 +1,6 @@
"""
homeassistant.components.tracker
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Provides functionality to keep track of devices.
"""
@ -13,9 +13,9 @@ from datetime import datetime, timedelta
import homeassistant as ha
from homeassistant.loader import get_component
import homeassistant.util as util
import homeassistant.components as components
from homeassistant.components import group
from homeassistant.components import (
group, STATE_HOME, STATE_NOT_HOME, ATTR_ENTITY_PICTURE, ATTR_FRIENDLY_NAME)
DOMAIN = "device_tracker"
DEPENDENCIES = []
@ -30,7 +30,7 @@ ENTITY_ID_FORMAT = DOMAIN + '.{}'
# After how much time do we consider a device not home if
# it does not show up on scans
TIME_SPAN_FOR_ERROR_IN_SCANNING = timedelta(minutes=3)
TIME_DEVICE_NOT_FOUND = timedelta(minutes=3)
# Filename to save known devices to
KNOWN_DEVICES_FILE = "known_devices.csv"
@ -43,7 +43,7 @@ def is_on(hass, entity_id=None):
""" Returns if any or specified device is home. """
entity = entity_id or ENTITY_ID_ALL_DEVICES
return hass.states.is_state(entity, components.STATE_HOME)
return hass.states.is_state(entity, STATE_HOME)
def setup(hass, config):
@ -70,223 +70,231 @@ def setup(hass, config):
return False
DeviceTracker(hass, device_scanner)
tracker = DeviceTracker(hass, device_scanner)
return True
# We only succeeded if we got to parse the known devices file
return not tracker.invalid_known_devices_file
# pylint: disable=too-many-instance-attributes
class DeviceTracker(object):
""" Class that tracks which devices are home and which are not. """
def __init__(self, hass, device_scanner):
self.states = hass.states
self.hass = hass
self.device_scanner = device_scanner
self.error_scanning = TIME_SPAN_FOR_ERROR_IN_SCANNING
self.lock = threading.Lock()
self.path_known_devices_file = hass.get_config_path(KNOWN_DEVICES_FILE)
# Dictionary to keep track of known devices and devices we track
self.known_devices = {}
self.tracked = {}
self.untracked_devices = set()
# Did we encounter an invalid known devices file
self.invalid_known_devices_file = False
self._read_known_devices_file()
if self.invalid_known_devices_file:
return
# Wrap it in a func instead of lambda so it can be identified in
# the bus by its __name__ attribute.
def update_device_state(time): # pylint: disable=unused-argument
def update_device_state(now):
""" Triggers update of the device states. """
self.update_devices()
self.update_devices(now)
# pylint: disable=unused-argument
def reload_known_devices_service(service):
""" Reload known devices file. """
group.remove_group(self.hass, GROUP_NAME_ALL_DEVICES)
self._read_known_devices_file()
self.update_devices(datetime.now())
if self.tracked:
group.setup_group(
self.hass, GROUP_NAME_ALL_DEVICES,
self.device_entity_ids, False)
hass.track_time_change(update_device_state)
hass.services.register(DOMAIN,
SERVICE_DEVICE_TRACKER_RELOAD,
lambda service: self._read_known_devices_file())
reload_known_devices_service)
self.update_devices()
group.setup_group(
hass, GROUP_NAME_ALL_DEVICES, self.device_entity_ids, False)
reload_known_devices_service(None)
@property
def device_entity_ids(self):
""" Returns a set containing all device entity ids
that are being tracked. """
return set([self.known_devices[device]['entity_id'] for device
in self.known_devices
if self.known_devices[device]['track']])
return set(device['entity_id'] for device in self.tracked.values())
def update_devices(self, found_devices=None):
def _update_state(self, now, device, is_home):
""" Update the state of a device. """
dev_info = self.tracked[device]
if is_home:
# Update last seen if at home
dev_info['last_seen'] = now
else:
# State remains at home if it has been seen in the last
# TIME_DEVICE_NOT_FOUND
is_home = now - dev_info['last_seen'] < TIME_DEVICE_NOT_FOUND
state = STATE_HOME if is_home else STATE_NOT_HOME
self.hass.states.set(
dev_info['entity_id'], state,
dev_info['state_attr'])
def update_devices(self, now):
""" Update device states based on the found devices. """
self.lock.acquire()
found_devices = found_devices or self.device_scanner.scan_devices()
found_devices = set(self.device_scanner.scan_devices())
now = datetime.now()
for device in self.tracked:
is_home = device in found_devices
known_dev = self.known_devices
self._update_state(now, device, is_home)
temp_tracking_devices = [device for device in known_dev
if known_dev[device]['track']]
if is_home:
found_devices.remove(device)
for device in found_devices:
# Are we tracking this device?
if device in temp_tracking_devices:
temp_tracking_devices.remove(device)
# Did we find any devices that we didn't know about yet?
new_devices = found_devices - self.untracked_devices
known_dev[device]['last_seen'] = now
# Write new devices to known devices file
if not self.invalid_known_devices_file and new_devices:
self.states.set(
known_dev[device]['entity_id'], components.STATE_HOME,
known_dev[device]['default_state_attr'])
known_dev_path = self.hass.get_config_path(KNOWN_DEVICES_FILE)
# For all devices we did not find, set state to NH
# But only if they have been gone for longer then the error time span
# Because we do not want to have stuff happening when the device does
# not show up for 1 scan beacuse of reboot etc
for device in temp_tracking_devices:
if now - known_dev[device]['last_seen'] > self.error_scanning:
try:
# If file does not exist we will write the header too
is_new_file = not os.path.isfile(known_dev_path)
self.states.set(known_dev[device]['entity_id'],
components.STATE_NOT_HOME,
known_dev[device]['default_state_attr'])
with open(known_dev_path, 'a') as outp:
_LOGGER.info(
"Found %d new devices, updating %s",
len(new_devices), known_dev_path)
# If we come along any unknown devices we will write them to the
# known devices file but only if we did not encounter an invalid
# known devices file
if not self.invalid_known_devices_file:
writer = csv.writer(outp)
known_dev_path = self.path_known_devices_file
if is_new_file:
writer.writerow((
"device", "name", "track", "picture"))
unknown_devices = [device for device in found_devices
if device not in known_dev]
for device in new_devices:
# See if the device scanner knows the name
# else defaults to unknown device
name = (self.device_scanner.get_device_name(device)
or "unknown_device")
if unknown_devices:
try:
# If file does not exist we will write the header too
is_new_file = not os.path.isfile(known_dev_path)
writer.writerow((device, name, 0, ""))
with open(known_dev_path, 'a') as outp:
_LOGGER.info(
"Found %d new devices, updating %s",
len(unknown_devices), known_dev_path)
writer = csv.writer(outp)
if is_new_file:
writer.writerow((
"device", "name", "track", "picture"))
for device in unknown_devices:
# See if the device scanner knows the name
# else defaults to unknown device
name = (self.device_scanner.get_device_name(device)
or "unknown_device")
writer.writerow((device, name, 0, ""))
known_dev[device] = {'name': name,
'track': False,
'picture': ""}
except IOError:
_LOGGER.exception(
"Error updating %s with %d new devices",
known_dev_path, len(unknown_devices))
except IOError:
_LOGGER.exception(
"Error updating %s with %d new devices",
known_dev_path, len(new_devices))
self.lock.release()
# pylint: disable=too-many-branches
def _read_known_devices_file(self):
""" Parse and process the known devices file. """
known_dev_path = self.hass.get_config_path(KNOWN_DEVICES_FILE)
# Read known devices if file exists
if os.path.isfile(self.path_known_devices_file):
self.lock.acquire()
# Return if no known devices file exists
if not os.path.isfile(known_dev_path):
return
known_devices = {}
self.lock.acquire()
with open(self.path_known_devices_file) as inp:
default_last_seen = datetime(1990, 1, 1)
self.untracked_devices.clear()
# Temp variable to keep track of which entity ids we use
# so we can ensure we have unique entity ids.
used_entity_ids = []
with open(known_dev_path) as inp:
default_last_seen = datetime(1990, 1, 1)
try:
for row in csv.DictReader(inp):
device = row['device']
# To track which devices need an entity_id assigned
need_entity_id = []
row['track'] = True if row['track'] == '1' else False
# All devices that are still in this set after we read the CSV file
# have been removed from the file and thus need to be cleaned up.
removed_devices = set(self.tracked.keys())
try:
for row in csv.DictReader(inp):
device = row['device']
if row['track'] == '1':
if device in self.tracked:
# Device exists
removed_devices.remove(device)
else:
# We found a new device
need_entity_id.append(device)
self.tracked[device] = {
'name': row['name'],
'last_seen': default_last_seen
}
# Update state_attr with latest from file
state_attr = {
ATTR_FRIENDLY_NAME: row['name']
}
if row['picture']:
row['default_state_attr'] = {
components.ATTR_ENTITY_PICTURE: row['picture']}
state_attr[ATTR_ENTITY_PICTURE] = row['picture']
else:
row['default_state_attr'] = None
self.tracked[device]['state_attr'] = state_attr
# If we track this device setup tracking variables
if row['track']:
row['last_seen'] = default_last_seen
else:
self.untracked_devices.add(device)
# Make sure that each device is mapped
# to a unique entity_id name
name = util.slugify(row['name']) if row['name'] \
else "unnamed_device"
# Remove existing devices that we no longer track
for device in removed_devices:
entity_id = self.tracked[device]['entity_id']
entity_id = ENTITY_ID_FORMAT.format(name)
tries = 1
_LOGGER.info("Removing entity %s", entity_id)
while entity_id in used_entity_ids:
tries += 1
self.hass.states.remove(entity_id)
suffix = "_{}".format(tries)
self.tracked.pop(device)
entity_id = ENTITY_ID_FORMAT.format(
name + suffix)
# Setup entity_ids for the new devices
used_entity_ids = [info['entity_id'] for device, info
in self.tracked.items()
if device not in need_entity_id]
row['entity_id'] = entity_id
used_entity_ids.append(entity_id)
for device in need_entity_id:
name = self.tracked[device]['name']
row['picture'] = row['picture']
entity_id = util.ensure_unique_string(
ENTITY_ID_FORMAT.format(util.slugify(name)),
used_entity_ids)
known_devices[device] = row
used_entity_ids.append(entity_id)
if not known_devices:
_LOGGER.warning(
"No devices to track. Please update %s.",
self.path_known_devices_file)
self.tracked[device]['entity_id'] = entity_id
# Remove entities that are no longer maintained
new_entity_ids = set([known_devices[dev]['entity_id']
for dev in known_devices
if known_devices[dev]['track']])
for entity_id in \
self.device_entity_ids - new_entity_ids:
_LOGGER.info("Removing entity %s", entity_id)
self.states.remove(entity_id)
# File parsed, warnings given if necessary
# entities cleaned up, make it available
self.known_devices = known_devices
_LOGGER.info("Loaded devices from %s",
self.path_known_devices_file)
except KeyError:
self.invalid_known_devices_file = True
if not self.tracked:
_LOGGER.warning(
("Invalid known devices file: %s. "
"We won't update it with new found devices."),
self.path_known_devices_file)
"No devices to track. Please update %s.",
known_dev_path)
finally:
self.lock.release()
_LOGGER.info("Loaded devices from %s", known_dev_path)
except KeyError:
self.invalid_known_devices_file = True
_LOGGER.warning(
("Invalid known devices file: %s. "
"We won't update it with new found devices."),
known_dev_path)
finally:
self.lock.release()

View File

@ -7,6 +7,7 @@ Provides functionality to group devices that can be turned on or off.
import logging
import homeassistant as ha
import homeassistant.util as util
from homeassistant.components import (STATE_ON, STATE_OFF,
STATE_HOME, STATE_NOT_HOME,
@ -24,6 +25,8 @@ _GROUP_TYPES = {
"home_not_home": (STATE_HOME, STATE_NOT_HOME)
}
_GROUPS = {}
def _get_group_type(state):
""" Determine the group type based on the given group type. """
@ -105,7 +108,6 @@ def setup(hass, config):
def setup_group(hass, name, entity_ids, user_defined=True):
""" Sets up a group state that is the combined state of
several states. Supports ON/OFF and DEVICE_HOME/DEVICE_NOT_HOME. """
# In case an iterable is passed in
entity_ids = list(entity_ids)
@ -159,35 +161,47 @@ def setup_group(hass, name, entity_ids, user_defined=True):
return False
else:
group_entity_id = ENTITY_ID_FORMAT.format(name)
state_attr = {ATTR_ENTITY_ID: entity_ids, ATTR_AUTO: not user_defined}
group_entity_id = ENTITY_ID_FORMAT.format(util.slugify(name))
state_attr = {ATTR_ENTITY_ID: entity_ids, ATTR_AUTO: not user_defined}
# pylint: disable=unused-argument
def update_group_state(entity_id, old_state, new_state):
""" Updates the group state based on a state change by
a tracked entity. """
# pylint: disable=unused-argument
def update_group_state(entity_id, old_state, new_state):
""" Updates the group state based on a state change by
a tracked entity. """
cur_gr_state = hass.states.get(group_entity_id).state
cur_gr_state = hass.states.get(group_entity_id).state
# if cur_gr_state = OFF and new_state = ON: set ON
# if cur_gr_state = ON and new_state = OFF: research
# else: ignore
# if cur_gr_state = OFF and new_state = ON: set ON
# if cur_gr_state = ON and new_state = OFF: research
# else: ignore
if cur_gr_state == group_off and new_state.state == group_on:
if cur_gr_state == group_off and new_state.state == group_on:
hass.states.set(group_entity_id, group_on, state_attr)
hass.states.set(group_entity_id, group_on, state_attr)
elif cur_gr_state == group_on and new_state.state == group_off:
elif cur_gr_state == group_on and new_state.state == group_off:
# Check if any of the other states is still on
if not any([hass.states.is_state(ent_id, group_on)
for ent_id in entity_ids
if entity_id != ent_id]):
hass.states.set(group_entity_id, group_off, state_attr)
# Check if any of the other states is still on
if not any([hass.states.is_state(ent_id, group_on)
for ent_id in entity_ids
if entity_id != ent_id]):
hass.states.set(group_entity_id, group_off, state_attr)
hass.states.track_change(entity_ids, update_group_state)
_GROUPS[group_entity_id] = hass.states.track_change(
entity_ids, update_group_state)
hass.states.set(group_entity_id, group_state, state_attr)
hass.states.set(group_entity_id, group_state, state_attr)
return True
return True
def remove_group(hass, name):
""" Remove a group and its state listener from Home Assistant. """
group_entity_id = ENTITY_ID_FORMAT.format(util.slugify(name))
if hass.states.get(group_entity_id) is not None:
hass.states.remove(group_entity_id)
if group_entity_id in _GROUPS:
hass.bus.remove_listener(
ha.EVENT_STATE_CHANGED, _GROUPS.pop(group_entity_id))