diff --git a/homeassistant/components/input_boolean/__init__.py b/homeassistant/components/input_boolean/__init__.py index 7dee3614ad5..a43b132a0e2 100644 --- a/homeassistant/components/input_boolean/__init__.py +++ b/homeassistant/components/input_boolean/__init__.py @@ -37,20 +37,25 @@ _LOGGER = logging.getLogger(__name__) CONF_INITIAL = "initial" -CREATE_FIELDS = { +STORAGE_FIELDS = { vol.Required(CONF_NAME): vol.All(str, vol.Length(min=1)), vol.Optional(CONF_INITIAL): cv.boolean, vol.Optional(CONF_ICON): cv.icon, } -UPDATE_FIELDS = { - vol.Optional(CONF_NAME): cv.string, - vol.Optional(CONF_INITIAL): cv.boolean, - vol.Optional(CONF_ICON): cv.icon, -} - CONFIG_SCHEMA = vol.Schema( - {DOMAIN: cv.schema_with_slug_keys(vol.Any(UPDATE_FIELDS, None))}, + { + DOMAIN: cv.schema_with_slug_keys( + vol.Any( + { + vol.Optional(CONF_NAME): cv.string, + vol.Optional(CONF_INITIAL): cv.boolean, + vol.Optional(CONF_ICON): cv.icon, + }, + None, + ) + ) + }, extra=vol.ALLOW_EXTRA, ) @@ -62,12 +67,11 @@ STORAGE_VERSION = 1 class InputBooleanStorageCollection(collection.StorageCollection): """Input boolean collection stored in storage.""" - CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) - UPDATE_SCHEMA = vol.Schema(UPDATE_FIELDS) + CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) async def _process_create_data(self, data: dict) -> dict: """Validate the config is valid.""" - return self.CREATE_SCHEMA(data) + return self.CREATE_UPDATE_SCHEMA(data) @callback def _get_suggested_id(self, info: dict) -> str: @@ -76,8 +80,8 @@ class InputBooleanStorageCollection(collection.StorageCollection): async def _update_data(self, data: dict, update_data: dict) -> dict: """Return a new updated data object.""" - update_data = self.UPDATE_SCHEMA(update_data) - return {**data, **update_data} + update_data = self.CREATE_UPDATE_SCHEMA(update_data) + return {CONF_ID: data[CONF_ID]} | update_data @bind_hass @@ -118,7 +122,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: await storage_collection.async_load() collection.StorageCollectionWebsocket( - storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS + storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS ).async_setup(hass) async def reload_service_handler(service_call: ServiceCall) -> None: diff --git a/tests/components/input_boolean/test_init.py b/tests/components/input_boolean/test_init.py index 2b7a1f88ef1..2e044c7a90f 100644 --- a/tests/components/input_boolean/test_init.py +++ b/tests/components/input_boolean/test_init.py @@ -40,7 +40,11 @@ def storage_setup(hass, hass_storage): "data": {"items": [{"id": "from_storage", "name": "from storage"}]}, } else: - hass_storage[DOMAIN] = items + hass_storage[DOMAIN] = { + "key": DOMAIN, + "version": 1, + "data": {"items": items}, + } if config is None: config = {DOMAIN: {}} return await async_setup_component(hass, DOMAIN, config) @@ -332,6 +336,89 @@ async def test_ws_delete(hass, hass_ws_client, storage_setup): assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, input_id) is None +async def test_ws_update(hass, hass_ws_client, storage_setup): + """Test update WS.""" + + settings = { + "name": "from storage", + } + items = [{"id": "from_storage"} | settings] + assert await storage_setup(items) + + input_id = "from_storage" + input_entity_id = f"{DOMAIN}.{input_id}" + ent_reg = er.async_get(hass) + + state = hass.states.get(input_entity_id) + assert state is not None + assert state.state + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, input_id) is not None + + client = await hass_ws_client(hass) + + updated_settings = settings | {"name": "new_name", "icon": "mdi:blah"} + await client.send_json( + { + "id": 6, + "type": f"{DOMAIN}/update", + f"{DOMAIN}_id": f"{input_id}", + **updated_settings, + } + ) + resp = await client.receive_json() + assert resp["success"] + assert resp["result"] == {"id": "from_storage"} | updated_settings + + state = hass.states.get(input_entity_id) + assert state.attributes["icon"] == "mdi:blah" + assert state.attributes["friendly_name"] == "new_name" + + updated_settings = settings | {"name": "new_name_2"} + await client.send_json( + { + "id": 7, + "type": f"{DOMAIN}/update", + f"{DOMAIN}_id": f"{input_id}", + **updated_settings, + } + ) + resp = await client.receive_json() + assert resp["success"] + assert resp["result"] == {"id": "from_storage"} | updated_settings + + state = hass.states.get(input_entity_id) + assert "icon" not in state.attributes + assert state.attributes["friendly_name"] == "new_name_2" + + +async def test_ws_create(hass, hass_ws_client, storage_setup): + """Test create WS.""" + assert await storage_setup(items=[]) + + input_id = "new_input" + input_entity_id = f"{DOMAIN}.{input_id}" + ent_reg = er.async_get(hass) + + state = hass.states.get(input_entity_id) + assert state is None + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, input_id) is None + + client = await hass_ws_client(hass) + + await client.send_json( + { + "id": 6, + "type": f"{DOMAIN}/create", + "name": "New Input", + } + ) + resp = await client.receive_json() + assert resp["success"] + + state = hass.states.get(input_entity_id) + assert state.state + + async def test_setup_no_config(hass, hass_admin_user): """Test component setup with no config.""" count_start = len(hass.states.async_entity_ids())