Initial state over restore state (#6924)

* Input Boolean: initial state > restore state

* Input select: initial state overrules restored state

* Input slider: initial state overrule restore state

* Lint

* Lint
pull/6940/head
Paulus Schoutsen 2017-04-04 09:29:49 -07:00 committed by GitHub
parent c5574c2684
commit c4e1255a84
7 changed files with 147 additions and 63 deletions

View File

@ -30,13 +30,11 @@ SERVICE_SCHEMA = vol.Schema({
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids, vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
}) })
DEFAULT_CONFIG = {CONF_INITIAL: DEFAULT_INITIAL}
CONFIG_SCHEMA = vol.Schema({ CONFIG_SCHEMA = vol.Schema({
DOMAIN: vol.Schema({ DOMAIN: vol.Schema({
cv.slug: vol.Any({ cv.slug: vol.Any({
vol.Optional(CONF_NAME): cv.string, vol.Optional(CONF_NAME): cv.string,
vol.Optional(CONF_INITIAL, default=DEFAULT_INITIAL): cv.boolean, vol.Optional(CONF_INITIAL): cv.boolean,
vol.Optional(CONF_ICON): cv.icon, vol.Optional(CONF_ICON): cv.icon,
}, None) }, None)
}) })
@ -72,13 +70,13 @@ def async_setup(hass, config):
for object_id, cfg in config[DOMAIN].items(): for object_id, cfg in config[DOMAIN].items():
if not cfg: if not cfg:
cfg = DEFAULT_CONFIG cfg = {}
name = cfg.get(CONF_NAME) name = cfg.get(CONF_NAME)
state = cfg.get(CONF_INITIAL) initial = cfg.get(CONF_INITIAL)
icon = cfg.get(CONF_ICON) icon = cfg.get(CONF_ICON)
entities.append(InputBoolean(object_id, name, state, icon)) entities.append(InputBoolean(object_id, name, initial, icon))
if not entities: if not entities:
return False return False
@ -113,11 +111,11 @@ def async_setup(hass, config):
class InputBoolean(ToggleEntity): class InputBoolean(ToggleEntity):
"""Representation of a boolean input.""" """Representation of a boolean input."""
def __init__(self, object_id, name, state, icon): def __init__(self, object_id, name, initial, icon):
"""Initialize a boolean input.""" """Initialize a boolean input."""
self.entity_id = ENTITY_ID_FORMAT.format(object_id) self.entity_id = ENTITY_ID_FORMAT.format(object_id)
self._name = name self._name = name
self._state = state self._state = initial
self._icon = icon self._icon = icon
@property @property
@ -143,10 +141,12 @@ class InputBoolean(ToggleEntity):
@asyncio.coroutine @asyncio.coroutine
def async_added_to_hass(self): def async_added_to_hass(self):
"""Called when entity about to be added to hass.""" """Called when entity about to be added to hass."""
state = yield from async_get_last_state(self.hass, self.entity_id) # If not None, we got an initial value.
if not state: if self._state is not None:
return return
self._state = state.state == STATE_ON
state = yield from async_get_last_state(self.hass, self.entity_id)
self._state = state and state.state == STATE_ON
@asyncio.coroutine @asyncio.coroutine
def async_turn_on(self, **kwargs): def async_turn_on(self, **kwargs):

View File

@ -58,10 +58,10 @@ SERVICE_SET_OPTIONS_SCHEMA = vol.Schema({
def _cv_input_select(cfg): def _cv_input_select(cfg):
"""Config validation helper for input select (Voluptuous).""" """Config validation helper for input select (Voluptuous)."""
options = cfg[CONF_OPTIONS] options = cfg[CONF_OPTIONS]
state = cfg.get(CONF_INITIAL, options[0]) initial = cfg.get(CONF_INITIAL)
if state not in options: if initial is not None and initial not in options:
raise vol.Invalid('initial state "{}" is not part of the options: {}' raise vol.Invalid('initial state "{}" is not part of the options: {}'
.format(state, ','.join(options))) .format(initial, ','.join(options)))
return cfg return cfg
@ -117,9 +117,9 @@ def async_setup(hass, config):
for object_id, cfg in config[DOMAIN].items(): for object_id, cfg in config[DOMAIN].items():
name = cfg.get(CONF_NAME) name = cfg.get(CONF_NAME)
options = cfg.get(CONF_OPTIONS) options = cfg.get(CONF_OPTIONS)
state = cfg.get(CONF_INITIAL, options[0]) initial = cfg.get(CONF_INITIAL)
icon = cfg.get(CONF_ICON) icon = cfg.get(CONF_ICON)
entities.append(InputSelect(object_id, name, state, options, icon)) entities.append(InputSelect(object_id, name, initial, options, icon))
if not entities: if not entities:
return False return False
@ -187,23 +187,25 @@ def async_setup(hass, config):
class InputSelect(Entity): class InputSelect(Entity):
"""Representation of a select input.""" """Representation of a select input."""
def __init__(self, object_id, name, state, options, icon): def __init__(self, object_id, name, initial, options, icon):
"""Initialize a select input.""" """Initialize a select input."""
self.entity_id = ENTITY_ID_FORMAT.format(object_id) self.entity_id = ENTITY_ID_FORMAT.format(object_id)
self._name = name self._name = name
self._current_option = state self._current_option = initial
self._options = options self._options = options
self._icon = icon self._icon = icon
@asyncio.coroutine @asyncio.coroutine
def async_added_to_hass(self): def async_added_to_hass(self):
"""Called when entity about to be added to hass.""" """Called when entity about to be added to hass."""
if self._current_option is not None:
return
state = yield from async_get_last_state(self.hass, self.entity_id) state = yield from async_get_last_state(self.hass, self.entity_id)
if not state: if not state or state.state not in self._options:
return self._current_option = self._options[0]
if state.state not in self._options: else:
return self._current_option = state.state
self._current_option = state.state
@property @property
def should_poll(self): def should_poll(self):

View File

@ -45,11 +45,10 @@ def _cv_input_slider(cfg):
if minimum >= maximum: if minimum >= maximum:
raise vol.Invalid('Maximum ({}) is not greater than minimum ({})' raise vol.Invalid('Maximum ({}) is not greater than minimum ({})'
.format(minimum, maximum)) .format(minimum, maximum))
state = cfg.get(CONF_INITIAL, minimum) state = cfg.get(CONF_INITIAL)
if state < minimum or state > maximum: if state is not None and (state < minimum or state > maximum):
raise vol.Invalid('Initial value {} not in range {}-{}' raise vol.Invalid('Initial value {} not in range {}-{}'
.format(state, minimum, maximum)) .format(state, minimum, maximum))
cfg[CONF_INITIAL] = state
return cfg return cfg
@ -88,12 +87,12 @@ def async_setup(hass, config):
name = cfg.get(CONF_NAME) name = cfg.get(CONF_NAME)
minimum = cfg.get(CONF_MIN) minimum = cfg.get(CONF_MIN)
maximum = cfg.get(CONF_MAX) maximum = cfg.get(CONF_MAX)
state = cfg.get(CONF_INITIAL, minimum) initial = cfg.get(CONF_INITIAL)
step = cfg.get(CONF_STEP) step = cfg.get(CONF_STEP)
icon = cfg.get(CONF_ICON) icon = cfg.get(CONF_ICON)
unit = cfg.get(ATTR_UNIT_OF_MEASUREMENT) unit = cfg.get(ATTR_UNIT_OF_MEASUREMENT)
entities.append(InputSlider(object_id, name, state, minimum, maximum, entities.append(InputSlider(object_id, name, initial, minimum, maximum,
step, icon, unit)) step, icon, unit))
if not entities: if not entities:
@ -120,12 +119,12 @@ def async_setup(hass, config):
class InputSlider(Entity): class InputSlider(Entity):
"""Represent an slider.""" """Represent an slider."""
def __init__(self, object_id, name, state, minimum, maximum, step, icon, def __init__(self, object_id, name, initial, minimum, maximum, step, icon,
unit): unit):
"""Initialize a select input.""" """Initialize a select input."""
self.entity_id = ENTITY_ID_FORMAT.format(object_id) self.entity_id = ENTITY_ID_FORMAT.format(object_id)
self._name = name self._name = name
self._current_value = state self._current_value = initial
self._minimum = minimum self._minimum = minimum
self._maximum = maximum self._maximum = maximum
self._step = step self._step = step
@ -169,14 +168,17 @@ class InputSlider(Entity):
@asyncio.coroutine @asyncio.coroutine
def async_added_to_hass(self): def async_added_to_hass(self):
"""Called when entity about to be added to hass.""" """Called when entity about to be added to hass."""
state = yield from async_get_last_state(self.hass, self.entity_id) if self._current_value is not None:
if not state:
return return
num_value = float(state.state) state = yield from async_get_last_state(self.hass, self.entity_id)
if num_value < self._minimum or num_value > self._maximum: value = state and float(state.state)
return
self._current_value = num_value # Check against False because value can be 0
if value is not False and self._minimum < value < self._maximum:
self._current_value = value
else:
self._current_value = self._minimum
@asyncio.coroutine @asyncio.coroutine
def async_select_value(self, value): def async_select_value(self, value):

View File

@ -12,7 +12,7 @@ from contextlib import contextmanager
from aiohttp import web from aiohttp import web
from homeassistant import core as ha, loader from homeassistant import core as ha, loader
from homeassistant.setup import setup_component, DATA_SETUP from homeassistant.setup import setup_component
from homeassistant.config import async_process_component_config from homeassistant.config import async_process_component_config
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity import ToggleEntity
@ -271,15 +271,10 @@ def mock_mqtt_component(hass):
def mock_component(hass, component): def mock_component(hass, component):
"""Mock a component is setup.""" """Mock a component is setup."""
setup_tasks = hass.data.get(DATA_SETUP) if component in hass.config.components:
if setup_tasks is None:
setup_tasks = hass.data[DATA_SETUP] = {}
if component not in setup_tasks:
AssertionError("Component {} is already setup".format(component)) AssertionError("Component {} is already setup".format(component))
hass.config.components.add(component) hass.config.components.add(component)
setup_tasks[component] = asyncio.Task(mock_coro(True), loop=hass.loop)
class MockModule(object): class MockModule(object):
@ -499,4 +494,4 @@ def mock_restore_cache(hass, states):
assert len(hass.data[DATA_RESTORE_CACHE]) == len(states), \ assert len(hass.data[DATA_RESTORE_CACHE]) == len(states), \
"Duplicate entity_id? {}".format(states) "Duplicate entity_id? {}".format(states)
hass.state = ha.CoreState.starting hass.state = ha.CoreState.starting
hass.config.components.add(recorder.DOMAIN) mock_component(hass, recorder.DOMAIN)

View File

@ -4,15 +4,15 @@ import asyncio
import unittest import unittest
import logging import logging
from tests.common import get_test_home_assistant, mock_component
from homeassistant.core import CoreState, State from homeassistant.core import CoreState, State
from homeassistant.setup import setup_component, async_setup_component from homeassistant.setup import setup_component, async_setup_component
from homeassistant.components.input_boolean import ( from homeassistant.components.input_boolean import (
DOMAIN, is_on, toggle, turn_off, turn_on) DOMAIN, is_on, toggle, turn_off, turn_on, CONF_INITIAL)
from homeassistant.const import ( from homeassistant.const import (
STATE_ON, STATE_OFF, ATTR_ICON, ATTR_FRIENDLY_NAME) STATE_ON, STATE_OFF, ATTR_ICON, ATTR_FRIENDLY_NAME)
from homeassistant.helpers.restore_state import DATA_RESTORE_CACHE
from tests.common import (
get_test_home_assistant, mock_component, mock_restore_cache)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -111,11 +111,11 @@ class TestInputBoolean(unittest.TestCase):
@asyncio.coroutine @asyncio.coroutine
def test_restore_state(hass): def test_restore_state(hass):
"""Ensure states are restored on startup.""" """Ensure states are restored on startup."""
hass.data[DATA_RESTORE_CACHE] = { mock_restore_cache(hass, (
'input_boolean.b1': State('input_boolean.b1', 'on'), State('input_boolean.b1', 'on'),
'input_boolean.b2': State('input_boolean.b2', 'off'), State('input_boolean.b2', 'off'),
'input_boolean.b3': State('input_boolean.b3', 'on'), State('input_boolean.b3', 'on'),
} ))
hass.state = CoreState.starting hass.state = CoreState.starting
mock_component(hass, 'recorder') mock_component(hass, 'recorder')
@ -133,3 +133,28 @@ def test_restore_state(hass):
state = hass.states.get('input_boolean.b2') state = hass.states.get('input_boolean.b2')
assert state assert state
assert state.state == 'off' assert state.state == 'off'
@asyncio.coroutine
def test_initial_state_overrules_restore_state(hass):
"""Ensure states are restored on startup."""
mock_restore_cache(hass, (
State('input_boolean.b1', 'on'),
State('input_boolean.b2', 'off'),
))
hass.state = CoreState.starting
yield from async_setup_component(hass, DOMAIN, {
DOMAIN: {
'b1': {CONF_INITIAL: False},
'b2': {CONF_INITIAL: True},
}})
state = hass.states.get('input_boolean.b1')
assert state
assert state.state == 'off'
state = hass.states.get('input_boolean.b2')
assert state
assert state.state == 'on'

View File

@ -229,7 +229,6 @@ def test_restore_state(hass):
'middle option', 'middle option',
'last option', 'last option',
], ],
'initial': 'middle option',
} }
yield from async_setup_component(hass, DOMAIN, { yield from async_setup_component(hass, DOMAIN, {
@ -242,6 +241,38 @@ def test_restore_state(hass):
assert state assert state
assert state.state == 'last option' assert state.state == 'last option'
state = hass.states.get('input_select.s2')
assert state
assert state.state == 'first option'
@asyncio.coroutine
def test_initial_state_overrules_restore_state(hass):
"""Ensure states are restored on startup."""
mock_restore_cache(hass, (
State('input_select.s1', 'last option'),
State('input_select.s2', 'bad option'),
))
options = {
'options': [
'first option',
'middle option',
'last option',
],
'initial': 'middle option',
}
yield from async_setup_component(hass, DOMAIN, {
DOMAIN: {
's1': options,
's2': options,
}})
state = hass.states.get('input_select.s1')
assert state
assert state.state == 'middle option'
state = hass.states.get('input_select.s2') state = hass.states.get('input_select.s2')
assert state assert state
assert state.state == 'middle option' assert state.state == 'middle option'

View File

@ -3,12 +3,11 @@
import asyncio import asyncio
import unittest import unittest
from tests.common import get_test_home_assistant, mock_component
from homeassistant.core import CoreState, State from homeassistant.core import CoreState, State
from homeassistant.setup import setup_component, async_setup_component from homeassistant.setup import setup_component, async_setup_component
from homeassistant.components.input_slider import (DOMAIN, select_value) from homeassistant.components.input_slider import (DOMAIN, select_value)
from homeassistant.helpers.restore_state import DATA_RESTORE_CACHE
from tests.common import get_test_home_assistant, mock_restore_cache
class TestInputSlider(unittest.TestCase): class TestInputSlider(unittest.TestCase):
@ -75,13 +74,43 @@ class TestInputSlider(unittest.TestCase):
@asyncio.coroutine @asyncio.coroutine
def test_restore_state(hass): def test_restore_state(hass):
"""Ensure states are restored on startup.""" """Ensure states are restored on startup."""
hass.data[DATA_RESTORE_CACHE] = { mock_restore_cache(hass, (
'input_slider.b1': State('input_slider.b1', '70'), State('input_slider.b1', '70'),
'input_slider.b2': State('input_slider.b2', '200'), State('input_slider.b2', '200'),
} ))
hass.state = CoreState.starting
yield from async_setup_component(hass, DOMAIN, {
DOMAIN: {
'b1': {
'min': 0,
'max': 100,
},
'b2': {
'min': 10,
'max': 100,
},
}})
state = hass.states.get('input_slider.b1')
assert state
assert float(state.state) == 70
state = hass.states.get('input_slider.b2')
assert state
assert float(state.state) == 10
@asyncio.coroutine
def test_initial_state_overrules_restore_state(hass):
"""Ensure states are restored on startup."""
mock_restore_cache(hass, (
State('input_slider.b1', '70'),
State('input_slider.b2', '200'),
))
hass.state = CoreState.starting hass.state = CoreState.starting
mock_component(hass, 'recorder')
yield from async_setup_component(hass, DOMAIN, { yield from async_setup_component(hass, DOMAIN, {
DOMAIN: { DOMAIN: {
@ -99,7 +128,7 @@ def test_restore_state(hass):
state = hass.states.get('input_slider.b1') state = hass.states.get('input_slider.b1')
assert state assert state
assert float(state.state) == 70 assert float(state.state) == 50
state = hass.states.get('input_slider.b2') state = hass.states.get('input_slider.b2')
assert state assert state