Support custom interval for platforms

pull/1059/head
Paulus Schoutsen 2016-01-31 00:55:46 -08:00
parent 90e17fc77f
commit fce8815ab4
3 changed files with 133 additions and 69 deletions

View File

@ -26,7 +26,7 @@ CONF_PASSWORD = "password"
CONF_API_KEY = "api_key"
CONF_ACCESS_TOKEN = "access_token"
CONF_FILENAME = "filename"
CONF_SCAN_INTERVAL = "scan_interval"
CONF_VALUE_TEMPLATE = "value_template"
# #### EVENTS ####

View File

@ -1,6 +1,7 @@
"""Provides helpers for components that manage entities."""
from threading import Lock
from homeassistant.const import CONF_SCAN_INTERVAL
from homeassistant.bootstrap import prepare_setup_platform
from homeassistant.helpers import config_per_platform
from homeassistant.helpers.entity import generate_entity_id
@ -37,6 +38,9 @@ class EntityComponent(object):
self.config = None
self.lock = Lock()
self.add_entities = EntityPlatform(self,
self.scan_interval).add_entities
def setup(self, config):
"""
Set up a full entity component.
@ -59,47 +63,6 @@ class EntityComponent(object):
self._setup_platform(self.discovery_platforms[service], {},
info))
def add_entities(self, new_entities):
"""
Add new entities to this component.
For each entity will see if it already exists. If not, will add it,
set it up and push the first state.
"""
with self.lock:
for entity in new_entities:
if entity is None or entity in self.entities.values():
continue
entity.hass = self.hass
if getattr(entity, 'entity_id', None) is None:
entity.entity_id = generate_entity_id(
self.entity_id_format, entity.name,
self.entities.keys())
self.entities[entity.entity_id] = entity
entity.update_ha_state()
if self.group is None and self.group_name is not None:
self.group = group.Group(self.hass, self.group_name,
user_defined=False)
if self.group is not None:
self.group.update_tracked_entity_ids(self.entities.keys())
if self.is_polling or \
not any(entity.should_poll for entity
in self.entities.values()):
return
self.is_polling = True
track_utc_time_change(
self.hass, self._update_entity_states,
second=range(0, 60, self.scan_interval))
def extract_from_service(self, service):
"""
Extract all known entities from a service call.
@ -115,19 +78,6 @@ class EntityComponent(object):
in extract_entity_ids(self.hass, service)
if entity_id in self.entities]
def _update_entity_states(self, now):
"""Update the states of all the polling entities."""
with self.lock:
# We copy the entities because new entities might be detected
# during state update causing deadlocks.
entities = list(entity for entity in self.entities.values()
if entity.should_poll)
self.logger.info("Updating %s entities", self.domain)
for entity in entities:
entity.update_ha_state(True)
def _setup_platform(self, platform_type, platform_config,
discovery_info=None):
"""Setup a platform for this component."""
@ -138,12 +88,85 @@ class EntityComponent(object):
return
try:
# Config > Platform > Component
scan_interval = platform_config.get(
CONF_SCAN_INTERVAL,
getattr(platform, 'SCAN_INTERVAL', self.scan_interval))
platform.setup_platform(
self.hass, platform_config, self.add_entities, discovery_info)
self.hass, platform_config,
EntityPlatform(self, scan_interval).add_entities,
discovery_info)
platform_name = '{}.{}'.format(self.domain, platform_type)
self.hass.config.components.append(platform_name)
except Exception: # pylint: disable=broad-except
self.logger.exception(
'Error while setting up platform %s', platform_type)
return
platform_name = '{}.{}'.format(self.domain, platform_type)
self.hass.config.components.append(platform_name)
def add_entity(self, entity):
"""Add entity to component."""
if entity is None or entity in self.entities.values():
return False
entity.hass = self.hass
if getattr(entity, 'entity_id', None) is None:
entity.entity_id = generate_entity_id(
self.entity_id_format, entity.name,
self.entities.keys())
self.entities[entity.entity_id] = entity
entity.update_ha_state()
return True
def update_group(self):
"""Set up and/or update component group."""
if self.group is None and self.group_name is not None:
self.group = group.Group(self.hass, self.group_name,
user_defined=False)
if self.group is not None:
self.group.update_tracked_entity_ids(self.entities.keys())
class EntityPlatform(object):
"""Keep track of entities for a single platform."""
# pylint: disable=too-few-public-methods
def __init__(self, component, scan_interval):
self.component = component
self.scan_interval = scan_interval
self.platform_entities = []
self.is_polling = False
def add_entities(self, new_entities):
"""Add entities for a single platform."""
with self.component.lock:
for entity in new_entities:
if self.component.add_entity(entity):
self.platform_entities.append(entity)
self.component.update_group()
if self.is_polling or \
not any(entity.should_poll for entity
in self.platform_entities):
return
self.is_polling = True
track_utc_time_change(
self.component.hass, self._update_entity_states,
second=range(0, 60, self.scan_interval))
def _update_entity_states(self, now):
"""Update the states of all the polling entities."""
with self.component.lock:
# We copy the entities because new entities might be detected
# during state update causing deadlocks.
entities = list(entity for entity in self.platform_entities
if entity.should_poll)
for entity in entities:
entity.update_ha_state(True)

View File

@ -15,8 +15,10 @@ import homeassistant.loader as loader
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.components import discovery
import homeassistant.util.dt as dt_util
from tests.common import get_test_home_assistant, MockPlatform, MockModule
from tests.common import (
get_test_home_assistant, MockPlatform, MockModule, fire_time_changed)
_LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain"
@ -84,8 +86,7 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert ['test_domain.hello', 'test_domain.hello2'] == \
sorted(group.attributes.get('entity_id'))
@patch('homeassistant.helpers.entity_component.track_utc_time_change')
def test_polling_only_updates_entities_it_should_poll(self, mock_track):
def test_polling_only_updates_entities_it_should_poll(self):
component = EntityComponent(_LOGGER, DOMAIN, self.hass, 20)
no_poll_ent = EntityTest(should_poll=False)
@ -93,17 +94,13 @@ class TestHelpersEntityComponent(unittest.TestCase):
poll_ent = EntityTest(should_poll=True)
poll_ent.update_ha_state = Mock()
component.add_entities([no_poll_ent])
assert not mock_track.called
component.add_entities([poll_ent])
assert mock_track.called
assert [0, 20, 40] == list(mock_track.call_args[1].get('second'))
component.add_entities([no_poll_ent, poll_ent])
no_poll_ent.update_ha_state.reset_mock()
poll_ent.update_ha_state.reset_mock()
component._update_entity_states(None)
fire_time_changed(self.hass, dt_util.utcnow().replace(second=0))
self.hass.pool.block_till_done()
assert not no_poll_ent.update_ha_state.called
assert poll_ent.update_ha_state.called
@ -118,7 +115,10 @@ class TestHelpersEntityComponent(unittest.TestCase):
component.add_entities([ent2])
assert 1 == len(self.hass.states.entity_ids())
ent2.update_ha_state = lambda *_: component.add_entities([ent1])
component._update_entity_states(None)
fire_time_changed(self.hass, dt_util.utcnow().replace(second=0))
self.hass.pool.block_till_done()
assert 2 == len(self.hass.states.entity_ids())
def test_not_adding_duplicate_entities(self):
@ -234,3 +234,44 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert mock_setup.called
assert ('platform_test', {}, 'discovery_info') == \
mock_setup.call_args[0]
@patch('homeassistant.helpers.entity_component.track_utc_time_change')
def test_set_scan_interval_via_config(self, mock_track):
def platform_setup(hass, config, add_devices, discovery_info=None):
add_devices([EntityTest(should_poll=True)])
loader.set_component('test_domain.platform',
MockPlatform(platform_setup))
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.setup({
DOMAIN: {
'platform': 'platform',
'scan_interval': 30,
}
})
assert mock_track.called
assert [0, 30] == list(mock_track.call_args[1]['second'])
@patch('homeassistant.helpers.entity_component.track_utc_time_change')
def test_set_scan_interval_via_platform(self, mock_track):
def platform_setup(hass, config, add_devices, discovery_info=None):
add_devices([EntityTest(should_poll=True)])
platform = MockPlatform(platform_setup)
platform.SCAN_INTERVAL = 30
loader.set_component('test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.setup({
DOMAIN: {
'platform': 'platform',
}
})
assert mock_track.called
assert [0, 30] == list(mock_track.call_args[1]['second'])