diff --git a/homeassistant/components/counter/__init__.py b/homeassistant/components/counter/__init__.py index 5580518a9a3..ad5e4000116 100644 --- a/homeassistant/components/counter/__init__.py +++ b/homeassistant/components/counter/__init__.py @@ -1,12 +1,24 @@ """Component to count within automations.""" import logging +from typing import Dict, Optional import voluptuous as vol -from homeassistant.const import CONF_ICON, CONF_MAXIMUM, CONF_MINIMUM, CONF_NAME +from homeassistant.const import ( + ATTR_EDITABLE, + CONF_ICON, + CONF_ID, + CONF_MAXIMUM, + CONF_MINIMUM, + CONF_NAME, +) +from homeassistant.core import callback +from homeassistant.helpers import collection import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.restore_state import RestoreEntity +from homeassistant.helpers.storage import Store +from homeassistant.helpers.typing import ConfigType, HomeAssistantType _LOGGER = logging.getLogger(__name__) @@ -31,6 +43,29 @@ SERVICE_INCREMENT = "increment" SERVICE_RESET = "reset" SERVICE_CONFIGURE = "configure" +STORAGE_KEY = DOMAIN +STORAGE_VERSION = 1 + +CREATE_FIELDS = { + vol.Optional(CONF_ICON): cv.icon, + vol.Optional(CONF_INITIAL, default=DEFAULT_INITIAL): cv.positive_int, + vol.Required(CONF_NAME): vol.All(cv.string, vol.Length(min=1)), + vol.Optional(CONF_MAXIMUM, default=None): vol.Any(None, vol.Coerce(int)), + vol.Optional(CONF_MINIMUM, default=None): vol.Any(None, vol.Coerce(int)), + vol.Optional(CONF_RESTORE, default=True): cv.boolean, + vol.Optional(CONF_STEP, default=DEFAULT_STEP): cv.positive_int, +} + +UPDATE_FIELDS = { + vol.Optional(CONF_ICON): cv.icon, + vol.Optional(CONF_INITIAL): cv.positive_int, + vol.Optional(CONF_NAME): cv.string, + vol.Optional(CONF_MAXIMUM): vol.Any(None, vol.Coerce(int)), + vol.Optional(CONF_MINIMUM): vol.Any(None, vol.Coerce(int)), + vol.Optional(CONF_RESTORE): cv.boolean, + vol.Optional(CONF_STEP): cv.positive_int, +} + def _none_to_empty_dict(value): if value is None: @@ -65,30 +100,38 @@ CONFIG_SCHEMA = vol.Schema( ) -async def async_setup(hass, config): +async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool: """Set up the counters.""" component = EntityComponent(_LOGGER, DOMAIN, hass) + id_manager = collection.IDManager() - entities = [] + yaml_collection = collection.YamlCollection( + logging.getLogger(f"{__name__}.yaml_collection"), id_manager + ) + collection.attach_entity_component_collection( + component, yaml_collection, Counter.from_yaml + ) - for object_id, cfg in config[DOMAIN].items(): - if not cfg: - cfg = {} + storage_collection = CounterStorageCollection( + Store(hass, STORAGE_VERSION, STORAGE_KEY), + logging.getLogger(f"{__name__}.storage_collection"), + id_manager, + ) + collection.attach_entity_component_collection( + component, storage_collection, Counter + ) - name = cfg.get(CONF_NAME) - initial = cfg[CONF_INITIAL] - restore = cfg[CONF_RESTORE] - step = cfg[CONF_STEP] - icon = cfg.get(CONF_ICON) - minimum = cfg[CONF_MINIMUM] - maximum = cfg[CONF_MAXIMUM] + await yaml_collection.async_load( + [{CONF_ID: id_, **(conf or {})} for id_, conf in config.get(DOMAIN, {}).items()] + ) + await storage_collection.async_load() - entities.append( - Counter(object_id, name, initial, minimum, maximum, restore, step, icon) - ) + collection.StorageCollectionWebsocket( + storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS + ).async_setup(hass) - if not entities: - return False + collection.attach_entity_registry_cleaner(hass, DOMAIN, DOMAIN, yaml_collection) + collection.attach_entity_registry_cleaner(hass, DOMAIN, DOMAIN, storage_collection) component.async_register_entity_service(SERVICE_INCREMENT, {}, "async_increment") component.async_register_entity_service(SERVICE_DECREMENT, {}, "async_decrement") @@ -105,104 +148,137 @@ async def async_setup(hass, config): "async_configure", ) - await component.async_add_entities(entities) return True +class CounterStorageCollection(collection.StorageCollection): + """Input storage based collection.""" + + CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) + UPDATE_SCHEMA = vol.Schema(UPDATE_FIELDS) + + async def _process_create_data(self, data: Dict) -> Dict: + """Validate the config is valid.""" + return self.CREATE_SCHEMA(data) + + @callback + def _get_suggested_id(self, info: Dict) -> str: + """Suggest an ID based on the config.""" + return info[CONF_NAME] + + async def _update_data(self, data: dict, update_data: Dict) -> Dict: + """Return a new updated data object.""" + update_data = self.UPDATE_SCHEMA(update_data) + return {**data, **update_data} + + class Counter(RestoreEntity): """Representation of a counter.""" - def __init__(self, object_id, name, initial, minimum, maximum, restore, step, icon): + def __init__(self, config: Dict): """Initialize a counter.""" - self.entity_id = ENTITY_ID_FORMAT.format(object_id) - self._name = name - self._restore = restore - self._step = step - self._state = self._initial = initial - self._min = minimum - self._max = maximum - self._icon = icon + self._config: Dict = config + self._state: Optional[int] = config[CONF_INITIAL] + self.editable: bool = True + + @classmethod + def from_yaml(cls, config: Dict) -> "Counter": + """Create counter instance from yaml config.""" + counter = cls(config) + counter.editable = False + counter.entity_id = ENTITY_ID_FORMAT.format(config[CONF_ID]) + return counter @property - def should_poll(self): + def should_poll(self) -> bool: """If entity should be polled.""" return False @property - def name(self): + def name(self) -> Optional[str]: """Return name of the counter.""" - return self._name + return self._config.get(CONF_NAME) @property - def icon(self): + def icon(self) -> Optional[str]: """Return the icon to be used for this entity.""" - return self._icon + return self._config.get(CONF_ICON) @property - def state(self): + def state(self) -> Optional[int]: """Return the current value of the counter.""" return self._state @property - def state_attributes(self): + def state_attributes(self) -> Dict: """Return the state attributes.""" - ret = {ATTR_INITIAL: self._initial, ATTR_STEP: self._step} - if self._min is not None: - ret[CONF_MINIMUM] = self._min - if self._max is not None: - ret[CONF_MAXIMUM] = self._max + ret = { + ATTR_EDITABLE: self.editable, + ATTR_INITIAL: self._config[CONF_INITIAL], + ATTR_STEP: self._config[CONF_STEP], + } + if self._config[CONF_MINIMUM] is not None: + ret[CONF_MINIMUM] = self._config[CONF_MINIMUM] + if self._config[CONF_MAXIMUM] is not None: + ret[CONF_MAXIMUM] = self._config[CONF_MAXIMUM] return ret - def compute_next_state(self, state): + @property + def unique_id(self) -> Optional[str]: + """Return unique id of the entity.""" + return self._config[CONF_ID] + + def compute_next_state(self, state) -> int: """Keep the state within the range of min/max values.""" - if self._min is not None: - state = max(self._min, state) - if self._max is not None: - state = min(self._max, state) + if self._config[CONF_MINIMUM] is not None: + state = max(self._config[CONF_MINIMUM], state) + if self._config[CONF_MAXIMUM] is not None: + state = min(self._config[CONF_MAXIMUM], state) return state - async def async_added_to_hass(self): + async def async_added_to_hass(self) -> None: """Call when entity about to be added to Home Assistant.""" await super().async_added_to_hass() # __init__ will set self._state to self._initial, only override # if needed. - if self._restore: + if self._config[CONF_RESTORE]: state = await self.async_get_last_state() if state is not None: self._state = self.compute_next_state(int(state.state)) - self._initial = state.attributes.get(ATTR_INITIAL) - self._max = state.attributes.get(ATTR_MAXIMUM) - self._min = state.attributes.get(ATTR_MINIMUM) - self._step = state.attributes.get(ATTR_STEP) + self._config[CONF_INITIAL] = state.attributes.get(ATTR_INITIAL) + self._config[CONF_MAXIMUM] = state.attributes.get(ATTR_MAXIMUM) + self._config[CONF_MINIMUM] = state.attributes.get(ATTR_MINIMUM) + self._config[CONF_STEP] = state.attributes.get(ATTR_STEP) - async def async_decrement(self): + @callback + def async_decrement(self) -> None: """Decrement the counter.""" - self._state = self.compute_next_state(self._state - self._step) - await self.async_update_ha_state() + self._state = self.compute_next_state(self._state - self._config[CONF_STEP]) + self.async_write_ha_state() - async def async_increment(self): + @callback + def async_increment(self) -> None: """Increment a counter.""" - self._state = self.compute_next_state(self._state + self._step) - await self.async_update_ha_state() + self._state = self.compute_next_state(self._state + self._config[CONF_STEP]) + self.async_write_ha_state() - async def async_reset(self): + @callback + def async_reset(self) -> None: """Reset a counter.""" - self._state = self.compute_next_state(self._initial) - await self.async_update_ha_state() + self._state = self.compute_next_state(self._config[CONF_INITIAL]) + self.async_write_ha_state() - async def async_configure(self, **kwargs): + @callback + def async_configure(self, **kwargs) -> None: """Change the counter's settings with a service.""" - if CONF_MINIMUM in kwargs: - self._min = kwargs[CONF_MINIMUM] - if CONF_MAXIMUM in kwargs: - self._max = kwargs[CONF_MAXIMUM] - if CONF_STEP in kwargs: - self._step = kwargs[CONF_STEP] - if CONF_INITIAL in kwargs: - self._initial = kwargs[CONF_INITIAL] - if VALUE in kwargs: - self._state = kwargs[VALUE] + new_state = kwargs.pop(VALUE, self._state) + self._config = {**self._config, **kwargs} + self._state = self.compute_next_state(new_state) + self.async_write_ha_state() + async def async_update_config(self, config: Dict) -> None: + """Change the counter's settings WS CRUD.""" + self._config = config self._state = self.compute_next_state(self._state) - await self.async_update_ha_state() + self.async_write_ha_state() diff --git a/tests/components/counter/test_init.py b/tests/components/counter/test_init.py index f5ff825e7fb..d6a41af6deb 100644 --- a/tests/components/counter/test_init.py +++ b/tests/components/counter/test_init.py @@ -2,8 +2,13 @@ # pylint: disable=protected-access import logging +import pytest + from homeassistant.components.counter import ( + ATTR_EDITABLE, ATTR_INITIAL, + ATTR_MAXIMUM, + ATTR_MINIMUM, ATTR_STEP, CONF_ICON, CONF_INITIAL, @@ -14,8 +19,9 @@ from homeassistant.components.counter import ( DEFAULT_STEP, DOMAIN, ) -from homeassistant.const import ATTR_FRIENDLY_NAME, ATTR_ICON +from homeassistant.const import ATTR_FRIENDLY_NAME, ATTR_ICON, ATTR_NAME from homeassistant.core import Context, CoreState, State +from homeassistant.helpers import entity_registry from homeassistant.setup import async_setup_component from tests.common import mock_restore_cache @@ -28,6 +34,42 @@ from tests.components.counter.common import ( _LOGGER = logging.getLogger(__name__) +@pytest.fixture +def storage_setup(hass, hass_storage): + """Storage setup.""" + + async def _storage(items=None, config=None): + if items is None: + hass_storage[DOMAIN] = { + "key": DOMAIN, + "version": 1, + "data": { + "items": [ + { + "id": "from_storage", + "initial": 10, + "name": "from storage", + "maximum": 100, + "minimum": 3, + "step": 2, + "restore": False, + } + ] + }, + } + else: + hass_storage[DOMAIN] = { + "key": DOMAIN, + "version": 1, + "data": {"items": items}, + } + if config is None: + config = {DOMAIN: {}} + return await async_setup_component(hass, DOMAIN, config) + + return _storage + + async def test_config(hass): """Test config.""" invalid_configs = [None, 1, {}, {"name with space": None}] @@ -452,3 +494,209 @@ async def test_configure(hass, hass_admin_user): assert 0 == state.attributes.get("minimum") assert 9 == state.attributes.get("maximum") assert 6 == state.attributes.get("initial") + + +async def test_load_from_storage(hass, storage_setup): + """Test set up from storage.""" + assert await storage_setup() + state = hass.states.get(f"{DOMAIN}.from_storage") + assert int(state.state) == 10 + assert state.attributes.get(ATTR_FRIENDLY_NAME) == "from storage" + assert state.attributes.get(ATTR_EDITABLE) + + +async def test_editable_state_attribute(hass, storage_setup): + """Test editable attribute.""" + assert await storage_setup( + config={ + DOMAIN: { + "from_yaml": { + "minimum": 1, + "maximum": 10, + "initial": 5, + "step": 1, + "restore": False, + } + } + } + ) + + state = hass.states.get(f"{DOMAIN}.from_storage") + assert int(state.state) == 10 + assert state.attributes[ATTR_FRIENDLY_NAME] == "from storage" + assert state.attributes[ATTR_EDITABLE] is True + + state = hass.states.get(f"{DOMAIN}.from_yaml") + assert int(state.state) == 5 + assert state.attributes[ATTR_EDITABLE] is False + + +async def test_ws_list(hass, hass_ws_client, storage_setup): + """Test listing via WS.""" + assert await storage_setup( + config={ + DOMAIN: { + "from_yaml": { + "minimum": 1, + "maximum": 10, + "initial": 5, + "step": 1, + "restore": False, + } + } + } + ) + + client = await hass_ws_client(hass) + + await client.send_json({"id": 6, "type": f"{DOMAIN}/list"}) + resp = await client.receive_json() + assert resp["success"] + + storage_ent = "from_storage" + yaml_ent = "from_yaml" + result = {item["id"]: item for item in resp["result"]} + + assert len(result) == 1 + assert storage_ent in result + assert yaml_ent not in result + assert result[storage_ent][ATTR_NAME] == "from storage" + + +async def test_ws_delete(hass, hass_ws_client, storage_setup): + """Test WS delete cleans up entity registry.""" + assert await storage_setup() + + input_id = "from_storage" + input_entity_id = f"{DOMAIN}.{input_id}" + ent_reg = await entity_registry.async_get_registry(hass) + + state = hass.states.get(input_entity_id) + assert state is not None + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, input_id) is not None + + client = await hass_ws_client(hass) + + await client.send_json( + {"id": 6, "type": f"{DOMAIN}/delete", f"{DOMAIN}_id": f"{input_id}"} + ) + resp = await client.receive_json() + assert resp["success"] + + state = hass.states.get(input_entity_id) + assert state is None + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, input_id) is None + + +async def test_update_min_max(hass, hass_ws_client, storage_setup): + """Test updating min/max updates the state.""" + + items = [ + { + "id": "from_storage", + "initial": 15, + "name": "from storage", + "maximum": 100, + "minimum": 10, + "step": 3, + "restore": True, + } + ] + assert await storage_setup(items) + + input_id = "from_storage" + input_entity_id = f"{DOMAIN}.{input_id}" + ent_reg = await entity_registry.async_get_registry(hass) + + state = hass.states.get(input_entity_id) + assert state is not None + assert int(state.state) == 15 + assert state.attributes[ATTR_MAXIMUM] == 100 + assert state.attributes[ATTR_MINIMUM] == 10 + assert state.attributes[ATTR_STEP] == 3 + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, input_id) is not None + + client = await hass_ws_client(hass) + + await client.send_json( + { + "id": 6, + "type": f"{DOMAIN}/update", + f"{DOMAIN}_id": f"{input_id}", + "minimum": 19, + } + ) + resp = await client.receive_json() + assert resp["success"] + + state = hass.states.get(input_entity_id) + assert int(state.state) == 19 + assert state.attributes[ATTR_MINIMUM] == 19 + assert state.attributes[ATTR_MAXIMUM] == 100 + assert state.attributes[ATTR_STEP] == 3 + + await client.send_json( + { + "id": 7, + "type": f"{DOMAIN}/update", + f"{DOMAIN}_id": f"{input_id}", + "maximum": 5, + "minimum": 2, + "step": 5, + } + ) + resp = await client.receive_json() + assert resp["success"] + + state = hass.states.get(input_entity_id) + assert int(state.state) == 5 + assert state.attributes[ATTR_MINIMUM] == 2 + assert state.attributes[ATTR_MAXIMUM] == 5 + assert state.attributes[ATTR_STEP] == 5 + + await client.send_json( + { + "id": 8, + "type": f"{DOMAIN}/update", + f"{DOMAIN}_id": f"{input_id}", + "maximum": None, + "minimum": None, + "step": 6, + } + ) + resp = await client.receive_json() + assert resp["success"] + + state = hass.states.get(input_entity_id) + assert int(state.state) == 5 + assert ATTR_MINIMUM not in state.attributes + assert ATTR_MAXIMUM not in state.attributes + assert state.attributes[ATTR_STEP] == 6 + + +async def test_create(hass, hass_ws_client, storage_setup): + """Test creating counter using WS.""" + + items = [] + + assert await storage_setup(items) + + counter_id = "new_counter" + input_entity_id = f"{DOMAIN}.{counter_id}" + ent_reg = await entity_registry.async_get_registry(hass) + + state = hass.states.get(input_entity_id) + assert state is None + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, counter_id) is None + + client = await hass_ws_client(hass) + + await client.send_json({"id": 6, "type": f"{DOMAIN}/create", "name": "new counter"}) + resp = await client.receive_json() + assert resp["success"] + + state = hass.states.get(input_entity_id) + assert int(state.state) == 0 + assert ATTR_MINIMUM not in state.attributes + assert ATTR_MAXIMUM not in state.attributes + assert state.attributes[ATTR_STEP] == 1