diff --git a/homeassistant/components/input_number/__init__.py b/homeassistant/components/input_number/__init__.py index 77625ffa7f8..a4438020886 100644 --- a/homeassistant/components/input_number/__init__.py +++ b/homeassistant/components/input_number/__init__.py @@ -3,16 +3,18 @@ import logging import voluptuous as vol -import homeassistant.helpers.config_validation as cv from homeassistant.const import ( - ATTR_UNIT_OF_MEASUREMENT, ATTR_MODE, + ATTR_UNIT_OF_MEASUREMENT, CONF_ICON, - CONF_NAME, CONF_MODE, + CONF_NAME, + SERVICE_RELOAD, ) +import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.restore_state import RestoreEntity +import homeassistant.helpers.service _LOGGER = logging.getLogger(__name__) @@ -77,12 +79,49 @@ CONFIG_SCHEMA = vol.Schema( required=True, extra=vol.ALLOW_EXTRA, ) +RELOAD_SERVICE_SCHEMA = vol.Schema({}) async def async_setup(hass, config): """Set up an input slider.""" component = EntityComponent(_LOGGER, DOMAIN, hass) + entities = await _async_process_config(config) + + async def reload_service_handler(service_call): + """Remove all entities and load new ones from config.""" + conf = await component.async_prepare_reload() + if conf is None: + return + new_entities = await _async_process_config(conf) + if new_entities: + await component.async_add_entities(new_entities) + + homeassistant.helpers.service.async_register_admin_service( + hass, + DOMAIN, + SERVICE_RELOAD, + reload_service_handler, + schema=RELOAD_SERVICE_SCHEMA, + ) + + component.async_register_entity_service( + SERVICE_SET_VALUE, + {vol.Required(ATTR_VALUE): vol.Coerce(float)}, + "async_set_value", + ) + + component.async_register_entity_service(SERVICE_INCREMENT, {}, "async_increment") + + component.async_register_entity_service(SERVICE_DECREMENT, {}, "async_decrement") + + if entities: + await component.async_add_entities(entities) + return True + + +async def _async_process_config(config): + """Process config and create list of entities.""" entities = [] for object_id, cfg in config[DOMAIN].items(): @@ -101,21 +140,7 @@ async def async_setup(hass, config): ) ) - if not entities: - return False - - component.async_register_entity_service( - SERVICE_SET_VALUE, - {vol.Required(ATTR_VALUE): vol.Coerce(float)}, - "async_set_value", - ) - - component.async_register_entity_service(SERVICE_INCREMENT, {}, "async_increment") - - component.async_register_entity_service(SERVICE_DECREMENT, {}, "async_decrement") - - await component.async_add_entities(entities) - return True + return entities class InputNumber(RestoreEntity): diff --git a/homeassistant/components/input_number/services.yaml b/homeassistant/components/input_number/services.yaml index 650abc056a9..9cd1b913ccd 100644 --- a/homeassistant/components/input_number/services.yaml +++ b/homeassistant/components/input_number/services.yaml @@ -14,3 +14,5 @@ set_value: entity_id: {description: Entity id of the input number to set the new value., example: input_number.threshold} value: {description: The target value the entity should be set to., example: 42} +reload: + description: Reload the input_number configuration. diff --git a/tests/components/input_number/test_init.py b/tests/components/input_number/test_init.py index 02d59c367c9..a3b46212daf 100644 --- a/tests/components/input_number/test_init.py +++ b/tests/components/input_number/test_init.py @@ -1,16 +1,21 @@ """The tests for the Input number component.""" # pylint: disable=protected-access import asyncio +from unittest.mock import patch + +import pytest -from homeassistant.core import CoreState, State, Context from homeassistant.components.input_number import ( ATTR_VALUE, DOMAIN, SERVICE_DECREMENT, SERVICE_INCREMENT, + SERVICE_RELOAD, SERVICE_SET_VALUE, ) from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.core import Context, CoreState, State +from homeassistant.exceptions import Unauthorized from homeassistant.loader import bind_hass from homeassistant.setup import async_setup_component @@ -254,3 +259,57 @@ async def test_input_number_context(hass, hass_admin_user): assert state2 is not None assert state.state != state2.state assert state2.context.user_id == hass_admin_user.id + + +async def test_reload(hass, hass_admin_user, hass_read_only_user): + """Test reload service.""" + count_start = len(hass.states.async_entity_ids()) + + assert await async_setup_component( + hass, DOMAIN, {DOMAIN: {"test_1": {"initial": 50, "min": 0, "max": 51}}} + ) + + assert count_start + 1 == len(hass.states.async_entity_ids()) + + state_1 = hass.states.get("input_number.test_1") + state_2 = hass.states.get("input_number.test_2") + + assert state_1 is not None + assert state_2 is None + assert 50 == float(state_1.state) + + with patch( + "homeassistant.config.load_yaml_config_file", + autospec=True, + return_value={ + DOMAIN: { + "test_1": {"initial": 40, "min": 0, "max": 51}, + "test_2": {"initial": 20, "min": 10, "max": 30}, + } + }, + ): + with patch("homeassistant.config.find_config_file", return_value=""): + with pytest.raises(Unauthorized): + await hass.services.async_call( + DOMAIN, + SERVICE_RELOAD, + blocking=True, + context=Context(user_id=hass_read_only_user.id), + ) + await hass.services.async_call( + DOMAIN, + SERVICE_RELOAD, + blocking=True, + context=Context(user_id=hass_admin_user.id), + ) + await hass.async_block_till_done() + + assert count_start + 2 == len(hass.states.async_entity_ids()) + + state_1 = hass.states.get("input_number.test_1") + state_2 = hass.states.get("input_number.test_2") + + assert state_1 is not None + assert state_2 is not None + assert 40 == float(state_1.state) + assert 20 == float(state_2.state)