Prevent entities running multiple updates simultaneously (#6511)
* Protect entity for multible updates on same time. * Address all comments / make update more robust * fix unittest * fix lint * address commentspull/6626/head
parent
c4e151f621
commit
5529d77c62
|
@ -19,6 +19,7 @@ from homeassistant.util.async import (
|
|||
run_coroutine_threadsafe, run_callback_threadsafe)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
SLOW_UPDATE_WARNING = 10
|
||||
|
||||
|
||||
def generate_entity_id(entity_id_format: str, name: Optional[str],
|
||||
|
@ -70,6 +71,9 @@ class Entity(object):
|
|||
# If we reported if this entity was slow
|
||||
_slow_reported = False
|
||||
|
||||
# protect for multible updates
|
||||
_update_warn = None
|
||||
|
||||
@property
|
||||
def should_poll(self) -> bool:
|
||||
"""Return True if entity has to be polled for state.
|
||||
|
@ -199,12 +203,32 @@ class Entity(object):
|
|||
raise NoEntitySpecifiedError(
|
||||
"No entity id specified for entity {}".format(self.name))
|
||||
|
||||
# update entity data
|
||||
if force_refresh:
|
||||
if hasattr(self, 'async_update'):
|
||||
# pylint: disable=no-member
|
||||
yield from self.async_update()
|
||||
else:
|
||||
yield from self.hass.loop.run_in_executor(None, self.update)
|
||||
if self._update_warn:
|
||||
_LOGGER.warning('Update for %s is already in progress',
|
||||
self.entity_id)
|
||||
return
|
||||
|
||||
self._update_warn = self.hass.loop.call_later(
|
||||
SLOW_UPDATE_WARNING, _LOGGER.warning,
|
||||
'Update of %s is taking over %s seconds.', self.entity_id,
|
||||
SLOW_UPDATE_WARNING
|
||||
)
|
||||
|
||||
try:
|
||||
if hasattr(self, 'async_update'):
|
||||
# pylint: disable=no-member
|
||||
yield from self.async_update()
|
||||
else:
|
||||
yield from self.hass.loop.run_in_executor(
|
||||
None, self.update)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception('Update for %s fails', self.entity_id)
|
||||
return
|
||||
finally:
|
||||
self._update_warn.cancel()
|
||||
self._update_warn = None
|
||||
|
||||
start = timer()
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
"""Test the entity helper."""
|
||||
# pylint: disable=protected-access
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -132,3 +131,68 @@ class TestHelpersEntity(object):
|
|||
self.hass.block_till_done()
|
||||
state = self.hass.states.get(self.entity.entity_id)
|
||||
assert state.attributes.get(ATTR_DEVICE_CLASS) == 'test_class'
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_warn_slow_update(hass):
|
||||
"""Warn we log when entity update takes a long time."""
|
||||
update_call = False
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_update():
|
||||
"""Mock async update."""
|
||||
nonlocal update_call
|
||||
update_call = True
|
||||
|
||||
mock_entity = entity.Entity()
|
||||
mock_entity.hass = hass
|
||||
mock_entity.entity_id = 'comp_test.test_entity'
|
||||
mock_entity.async_update = async_update
|
||||
|
||||
with patch.object(hass.loop, 'call_later', MagicMock()) \
|
||||
as mock_call:
|
||||
yield from mock_entity.async_update_ha_state(True)
|
||||
assert mock_call.called
|
||||
assert len(mock_call.mock_calls) == 2
|
||||
|
||||
timeout, logger_method = mock_call.mock_calls[0][1][:2]
|
||||
|
||||
assert timeout == entity.SLOW_UPDATE_WARNING
|
||||
assert logger_method == entity._LOGGER.warning
|
||||
|
||||
assert mock_call().cancel.called
|
||||
|
||||
assert update_call
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_warn_slow_update_with_exception(hass):
|
||||
"""Warn we log when entity update takes a long time and trow exception."""
|
||||
update_call = False
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_update():
|
||||
"""Mock async update."""
|
||||
nonlocal update_call
|
||||
update_call = True
|
||||
raise AssertionError("Fake update error")
|
||||
|
||||
mock_entity = entity.Entity()
|
||||
mock_entity.hass = hass
|
||||
mock_entity.entity_id = 'comp_test.test_entity'
|
||||
mock_entity.async_update = async_update
|
||||
|
||||
with patch.object(hass.loop, 'call_later', MagicMock()) \
|
||||
as mock_call:
|
||||
yield from mock_entity.async_update_ha_state(True)
|
||||
assert mock_call.called
|
||||
assert len(mock_call.mock_calls) == 2
|
||||
|
||||
timeout, logger_method = mock_call.mock_calls[0][1][:2]
|
||||
|
||||
assert timeout == entity.SLOW_UPDATE_WARNING
|
||||
assert logger_method == entity._LOGGER.warning
|
||||
|
||||
assert mock_call().cancel.called
|
||||
|
||||
assert update_call
|
||||
|
|
|
@ -116,6 +116,43 @@ class TestHelpersEntityComponent(unittest.TestCase):
|
|||
assert not no_poll_ent.async_update.called
|
||||
assert poll_ent.async_update.called
|
||||
|
||||
def test_polling_updates_entities_with_exception(self):
|
||||
"""Test the updated entities that not brake with a exception."""
|
||||
component = EntityComponent(
|
||||
_LOGGER, DOMAIN, self.hass, timedelta(seconds=20))
|
||||
|
||||
update_ok = []
|
||||
update_err = []
|
||||
|
||||
def update_mock():
|
||||
"""Mock normal update."""
|
||||
update_ok.append(None)
|
||||
|
||||
def update_mock_err():
|
||||
"""Mock error update."""
|
||||
update_err.append(None)
|
||||
raise AssertionError("Fake error update")
|
||||
|
||||
ent1 = EntityTest(should_poll=True)
|
||||
ent1.update = update_mock_err
|
||||
ent2 = EntityTest(should_poll=True)
|
||||
ent2.update = update_mock
|
||||
ent3 = EntityTest(should_poll=True)
|
||||
ent3.update = update_mock
|
||||
ent4 = EntityTest(should_poll=True)
|
||||
ent4.update = update_mock
|
||||
|
||||
component.add_entities([ent1, ent2, ent3, ent4])
|
||||
|
||||
update_ok.clear()
|
||||
update_err.clear()
|
||||
|
||||
fire_time_changed(self.hass, dt_util.utcnow() + timedelta(seconds=20))
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(update_ok) == 3
|
||||
assert len(update_err) == 1
|
||||
|
||||
def test_update_state_adds_entities(self):
|
||||
"""Test if updating poll entities cause an entity to be added works."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
|
||||
|
|
Loading…
Reference in New Issue