Add locks to entity component

pull/679/head
Paulus Schoutsen 2015-11-28 15:55:01 -08:00
parent ef394b8af7
commit 99aa4307ef
1 changed files with 31 additions and 26 deletions

View File

@ -4,6 +4,8 @@ homeassistant.helpers.entity_component
Provides helpers for components that manage entities.
"""
from threading import Lock
from homeassistant.bootstrap import prepare_setup_platform
from homeassistant.helpers import (
generate_entity_id, config_per_platform, extract_entity_ids)
@ -37,6 +39,7 @@ class EntityComponent(object):
self.is_polling = False
self.config = None
self.lock = Lock()
def setup(self, config):
"""
@ -61,8 +64,11 @@ class EntityComponent(object):
Takes in a list of new entities. For each entity will see if it already
exists. If not, will add it, set it up and push the first state.
"""
for entity in new_entities:
if entity is not None and entity not in self.entities.values():
with self.lock:
for entity in new_entities:
if entity is None or entity in self.entities.values():
continue
entity.hass = self.hass
if getattr(entity, 'entity_id', None) is None:
@ -74,23 +80,33 @@ class EntityComponent(object):
entity.update_ha_state()
if self.group is None and self.group_name is not None:
self.group = group.Group(self.hass, self.group_name,
user_defined=False)
if self.group is None and self.group_name is not None:
self.group = group.Group(self.hass, self.group_name,
user_defined=False)
if self.group is not None:
self.group.update_tracked_entity_ids(self.entities.keys())
if self.group is not None:
self.group.update_tracked_entity_ids(self.entities.keys())
self._start_polling()
if self.is_polling or \
not any(entity.should_poll for entity
in self.entities.values()):
return
self.is_polling = True
track_utc_time_change(
self.hass, self._update_entity_states,
second=range(0, 60, self.scan_interval))
def extract_from_service(self, service):
"""
Takes a service and extracts all known entities.
Will return all if no entity IDs given in service.
"""
if ATTR_ENTITY_ID not in service.data:
return self.entities.values()
else:
with self.lock:
if ATTR_ENTITY_ID not in service.data:
return list(self.entities.values())
return [self.entities[entity_id] for entity_id
in extract_entity_ids(self.hass, service)
if entity_id in self.entities]
@ -99,9 +115,10 @@ class EntityComponent(object):
""" Update the states of all the entities. """
self.logger.info("Updating %s entities", self.domain)
for entity in self.entities.values():
if entity.should_poll:
entity.update_ha_state(True)
with self.lock:
for entity in self.entities.values():
if entity.should_poll:
entity.update_ha_state(True)
def _entity_discovered(self, service, info):
""" Called when a entity is discovered. """
@ -110,18 +127,6 @@ class EntityComponent(object):
self._setup_platform(self.discovery_platforms[service], {}, info)
def _start_polling(self):
""" Start polling entities if necessary. """
if self.is_polling or \
not any(entity.should_poll for entity in self.entities.values()):
return
self.is_polling = True
track_utc_time_change(
self.hass, self._update_entity_states,
second=range(0, 60, self.scan_interval))
def _setup_platform(self, platform_type, platform_config,
discovery_info=None):
""" Tries to setup a platform for this component. """