Use collection helpers for counter integration (#32295)

* Refactor counter to use config dict.

* Use collection helpers for counter integration.

* Update tests.

* Use callbacks were applicable.
pull/32334/head
Alexei Chetroi 2020-02-28 17:06:39 -05:00 committed by GitHub
parent 4a95eee40f
commit 0670b4f457
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 397 additions and 73 deletions

View File

@ -1,12 +1,24 @@
"""Component to count within automations.""" """Component to count within automations."""
import logging import logging
from typing import Dict, Optional
import voluptuous as vol import voluptuous as vol
from homeassistant.const import CONF_ICON, CONF_MAXIMUM, CONF_MINIMUM, CONF_NAME from homeassistant.const import (
ATTR_EDITABLE,
CONF_ICON,
CONF_ID,
CONF_MAXIMUM,
CONF_MINIMUM,
CONF_NAME,
)
from homeassistant.core import callback
from homeassistant.helpers import collection
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.storage import Store
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -31,6 +43,29 @@ SERVICE_INCREMENT = "increment"
SERVICE_RESET = "reset" SERVICE_RESET = "reset"
SERVICE_CONFIGURE = "configure" SERVICE_CONFIGURE = "configure"
STORAGE_KEY = DOMAIN
STORAGE_VERSION = 1
CREATE_FIELDS = {
vol.Optional(CONF_ICON): cv.icon,
vol.Optional(CONF_INITIAL, default=DEFAULT_INITIAL): cv.positive_int,
vol.Required(CONF_NAME): vol.All(cv.string, vol.Length(min=1)),
vol.Optional(CONF_MAXIMUM, default=None): vol.Any(None, vol.Coerce(int)),
vol.Optional(CONF_MINIMUM, default=None): vol.Any(None, vol.Coerce(int)),
vol.Optional(CONF_RESTORE, default=True): cv.boolean,
vol.Optional(CONF_STEP, default=DEFAULT_STEP): cv.positive_int,
}
UPDATE_FIELDS = {
vol.Optional(CONF_ICON): cv.icon,
vol.Optional(CONF_INITIAL): cv.positive_int,
vol.Optional(CONF_NAME): cv.string,
vol.Optional(CONF_MAXIMUM): vol.Any(None, vol.Coerce(int)),
vol.Optional(CONF_MINIMUM): vol.Any(None, vol.Coerce(int)),
vol.Optional(CONF_RESTORE): cv.boolean,
vol.Optional(CONF_STEP): cv.positive_int,
}
def _none_to_empty_dict(value): def _none_to_empty_dict(value):
if value is None: if value is None:
@ -65,30 +100,38 @@ CONFIG_SCHEMA = vol.Schema(
) )
async def async_setup(hass, config): async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
"""Set up the counters.""" """Set up the counters."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent(_LOGGER, DOMAIN, hass)
id_manager = collection.IDManager()
entities = [] yaml_collection = collection.YamlCollection(
logging.getLogger(f"{__name__}.yaml_collection"), id_manager
)
collection.attach_entity_component_collection(
component, yaml_collection, Counter.from_yaml
)
for object_id, cfg in config[DOMAIN].items(): storage_collection = CounterStorageCollection(
if not cfg: Store(hass, STORAGE_VERSION, STORAGE_KEY),
cfg = {} logging.getLogger(f"{__name__}.storage_collection"),
id_manager,
)
collection.attach_entity_component_collection(
component, storage_collection, Counter
)
name = cfg.get(CONF_NAME) await yaml_collection.async_load(
initial = cfg[CONF_INITIAL] [{CONF_ID: id_, **(conf or {})} for id_, conf in config.get(DOMAIN, {}).items()]
restore = cfg[CONF_RESTORE] )
step = cfg[CONF_STEP] await storage_collection.async_load()
icon = cfg.get(CONF_ICON)
minimum = cfg[CONF_MINIMUM]
maximum = cfg[CONF_MAXIMUM]
entities.append( collection.StorageCollectionWebsocket(
Counter(object_id, name, initial, minimum, maximum, restore, step, icon) storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS
) ).async_setup(hass)
if not entities: collection.attach_entity_registry_cleaner(hass, DOMAIN, DOMAIN, yaml_collection)
return False collection.attach_entity_registry_cleaner(hass, DOMAIN, DOMAIN, storage_collection)
component.async_register_entity_service(SERVICE_INCREMENT, {}, "async_increment") component.async_register_entity_service(SERVICE_INCREMENT, {}, "async_increment")
component.async_register_entity_service(SERVICE_DECREMENT, {}, "async_decrement") component.async_register_entity_service(SERVICE_DECREMENT, {}, "async_decrement")
@ -105,104 +148,137 @@ async def async_setup(hass, config):
"async_configure", "async_configure",
) )
await component.async_add_entities(entities)
return True return True
class CounterStorageCollection(collection.StorageCollection):
"""Input storage based collection."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
UPDATE_SCHEMA = vol.Schema(UPDATE_FIELDS)
async def _process_create_data(self, data: Dict) -> Dict:
"""Validate the config is valid."""
return self.CREATE_SCHEMA(data)
@callback
def _get_suggested_id(self, info: Dict) -> str:
"""Suggest an ID based on the config."""
return info[CONF_NAME]
async def _update_data(self, data: dict, update_data: Dict) -> Dict:
"""Return a new updated data object."""
update_data = self.UPDATE_SCHEMA(update_data)
return {**data, **update_data}
class Counter(RestoreEntity): class Counter(RestoreEntity):
"""Representation of a counter.""" """Representation of a counter."""
def __init__(self, object_id, name, initial, minimum, maximum, restore, step, icon): def __init__(self, config: Dict):
"""Initialize a counter.""" """Initialize a counter."""
self.entity_id = ENTITY_ID_FORMAT.format(object_id) self._config: Dict = config
self._name = name self._state: Optional[int] = config[CONF_INITIAL]
self._restore = restore self.editable: bool = True
self._step = step
self._state = self._initial = initial @classmethod
self._min = minimum def from_yaml(cls, config: Dict) -> "Counter":
self._max = maximum """Create counter instance from yaml config."""
self._icon = icon counter = cls(config)
counter.editable = False
counter.entity_id = ENTITY_ID_FORMAT.format(config[CONF_ID])
return counter
@property @property
def should_poll(self): def should_poll(self) -> bool:
"""If entity should be polled.""" """If entity should be polled."""
return False return False
@property @property
def name(self): def name(self) -> Optional[str]:
"""Return name of the counter.""" """Return name of the counter."""
return self._name return self._config.get(CONF_NAME)
@property @property
def icon(self): def icon(self) -> Optional[str]:
"""Return the icon to be used for this entity.""" """Return the icon to be used for this entity."""
return self._icon return self._config.get(CONF_ICON)
@property @property
def state(self): def state(self) -> Optional[int]:
"""Return the current value of the counter.""" """Return the current value of the counter."""
return self._state return self._state
@property @property
def state_attributes(self): def state_attributes(self) -> Dict:
"""Return the state attributes.""" """Return the state attributes."""
ret = {ATTR_INITIAL: self._initial, ATTR_STEP: self._step} ret = {
if self._min is not None: ATTR_EDITABLE: self.editable,
ret[CONF_MINIMUM] = self._min ATTR_INITIAL: self._config[CONF_INITIAL],
if self._max is not None: ATTR_STEP: self._config[CONF_STEP],
ret[CONF_MAXIMUM] = self._max }
if self._config[CONF_MINIMUM] is not None:
ret[CONF_MINIMUM] = self._config[CONF_MINIMUM]
if self._config[CONF_MAXIMUM] is not None:
ret[CONF_MAXIMUM] = self._config[CONF_MAXIMUM]
return ret return ret
def compute_next_state(self, state): @property
def unique_id(self) -> Optional[str]:
"""Return unique id of the entity."""
return self._config[CONF_ID]
def compute_next_state(self, state) -> int:
"""Keep the state within the range of min/max values.""" """Keep the state within the range of min/max values."""
if self._min is not None: if self._config[CONF_MINIMUM] is not None:
state = max(self._min, state) state = max(self._config[CONF_MINIMUM], state)
if self._max is not None: if self._config[CONF_MAXIMUM] is not None:
state = min(self._max, state) state = min(self._config[CONF_MAXIMUM], state)
return state return state
async def async_added_to_hass(self): async def async_added_to_hass(self) -> None:
"""Call when entity about to be added to Home Assistant.""" """Call when entity about to be added to Home Assistant."""
await super().async_added_to_hass() await super().async_added_to_hass()
# __init__ will set self._state to self._initial, only override # __init__ will set self._state to self._initial, only override
# if needed. # if needed.
if self._restore: if self._config[CONF_RESTORE]:
state = await self.async_get_last_state() state = await self.async_get_last_state()
if state is not None: if state is not None:
self._state = self.compute_next_state(int(state.state)) self._state = self.compute_next_state(int(state.state))
self._initial = state.attributes.get(ATTR_INITIAL) self._config[CONF_INITIAL] = state.attributes.get(ATTR_INITIAL)
self._max = state.attributes.get(ATTR_MAXIMUM) self._config[CONF_MAXIMUM] = state.attributes.get(ATTR_MAXIMUM)
self._min = state.attributes.get(ATTR_MINIMUM) self._config[CONF_MINIMUM] = state.attributes.get(ATTR_MINIMUM)
self._step = state.attributes.get(ATTR_STEP) self._config[CONF_STEP] = state.attributes.get(ATTR_STEP)
async def async_decrement(self): @callback
def async_decrement(self) -> None:
"""Decrement the counter.""" """Decrement the counter."""
self._state = self.compute_next_state(self._state - self._step) self._state = self.compute_next_state(self._state - self._config[CONF_STEP])
await self.async_update_ha_state() self.async_write_ha_state()
async def async_increment(self): @callback
def async_increment(self) -> None:
"""Increment a counter.""" """Increment a counter."""
self._state = self.compute_next_state(self._state + self._step) self._state = self.compute_next_state(self._state + self._config[CONF_STEP])
await self.async_update_ha_state() self.async_write_ha_state()
async def async_reset(self): @callback
def async_reset(self) -> None:
"""Reset a counter.""" """Reset a counter."""
self._state = self.compute_next_state(self._initial) self._state = self.compute_next_state(self._config[CONF_INITIAL])
await self.async_update_ha_state() self.async_write_ha_state()
async def async_configure(self, **kwargs): @callback
def async_configure(self, **kwargs) -> None:
"""Change the counter's settings with a service.""" """Change the counter's settings with a service."""
if CONF_MINIMUM in kwargs: new_state = kwargs.pop(VALUE, self._state)
self._min = kwargs[CONF_MINIMUM] self._config = {**self._config, **kwargs}
if CONF_MAXIMUM in kwargs: self._state = self.compute_next_state(new_state)
self._max = kwargs[CONF_MAXIMUM] self.async_write_ha_state()
if CONF_STEP in kwargs:
self._step = kwargs[CONF_STEP]
if CONF_INITIAL in kwargs:
self._initial = kwargs[CONF_INITIAL]
if VALUE in kwargs:
self._state = kwargs[VALUE]
async def async_update_config(self, config: Dict) -> None:
"""Change the counter's settings WS CRUD."""
self._config = config
self._state = self.compute_next_state(self._state) self._state = self.compute_next_state(self._state)
await self.async_update_ha_state() self.async_write_ha_state()

View File

@ -2,8 +2,13 @@
# pylint: disable=protected-access # pylint: disable=protected-access
import logging import logging
import pytest
from homeassistant.components.counter import ( from homeassistant.components.counter import (
ATTR_EDITABLE,
ATTR_INITIAL, ATTR_INITIAL,
ATTR_MAXIMUM,
ATTR_MINIMUM,
ATTR_STEP, ATTR_STEP,
CONF_ICON, CONF_ICON,
CONF_INITIAL, CONF_INITIAL,
@ -14,8 +19,9 @@ from homeassistant.components.counter import (
DEFAULT_STEP, DEFAULT_STEP,
DOMAIN, DOMAIN,
) )
from homeassistant.const import ATTR_FRIENDLY_NAME, ATTR_ICON from homeassistant.const import ATTR_FRIENDLY_NAME, ATTR_ICON, ATTR_NAME
from homeassistant.core import Context, CoreState, State from homeassistant.core import Context, CoreState, State
from homeassistant.helpers import entity_registry
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import mock_restore_cache from tests.common import mock_restore_cache
@ -28,6 +34,42 @@ from tests.components.counter.common import (
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@pytest.fixture
def storage_setup(hass, hass_storage):
"""Storage setup."""
async def _storage(items=None, config=None):
if items is None:
hass_storage[DOMAIN] = {
"key": DOMAIN,
"version": 1,
"data": {
"items": [
{
"id": "from_storage",
"initial": 10,
"name": "from storage",
"maximum": 100,
"minimum": 3,
"step": 2,
"restore": False,
}
]
},
}
else:
hass_storage[DOMAIN] = {
"key": DOMAIN,
"version": 1,
"data": {"items": items},
}
if config is None:
config = {DOMAIN: {}}
return await async_setup_component(hass, DOMAIN, config)
return _storage
async def test_config(hass): async def test_config(hass):
"""Test config.""" """Test config."""
invalid_configs = [None, 1, {}, {"name with space": None}] invalid_configs = [None, 1, {}, {"name with space": None}]
@ -452,3 +494,209 @@ async def test_configure(hass, hass_admin_user):
assert 0 == state.attributes.get("minimum") assert 0 == state.attributes.get("minimum")
assert 9 == state.attributes.get("maximum") assert 9 == state.attributes.get("maximum")
assert 6 == state.attributes.get("initial") assert 6 == state.attributes.get("initial")
async def test_load_from_storage(hass, storage_setup):
"""Test set up from storage."""
assert await storage_setup()
state = hass.states.get(f"{DOMAIN}.from_storage")
assert int(state.state) == 10
assert state.attributes.get(ATTR_FRIENDLY_NAME) == "from storage"
assert state.attributes.get(ATTR_EDITABLE)
async def test_editable_state_attribute(hass, storage_setup):
"""Test editable attribute."""
assert await storage_setup(
config={
DOMAIN: {
"from_yaml": {
"minimum": 1,
"maximum": 10,
"initial": 5,
"step": 1,
"restore": False,
}
}
}
)
state = hass.states.get(f"{DOMAIN}.from_storage")
assert int(state.state) == 10
assert state.attributes[ATTR_FRIENDLY_NAME] == "from storage"
assert state.attributes[ATTR_EDITABLE] is True
state = hass.states.get(f"{DOMAIN}.from_yaml")
assert int(state.state) == 5
assert state.attributes[ATTR_EDITABLE] is False
async def test_ws_list(hass, hass_ws_client, storage_setup):
"""Test listing via WS."""
assert await storage_setup(
config={
DOMAIN: {
"from_yaml": {
"minimum": 1,
"maximum": 10,
"initial": 5,
"step": 1,
"restore": False,
}
}
}
)
client = await hass_ws_client(hass)
await client.send_json({"id": 6, "type": f"{DOMAIN}/list"})
resp = await client.receive_json()
assert resp["success"]
storage_ent = "from_storage"
yaml_ent = "from_yaml"
result = {item["id"]: item for item in resp["result"]}
assert len(result) == 1
assert storage_ent in result
assert yaml_ent not in result
assert result[storage_ent][ATTR_NAME] == "from storage"
async def test_ws_delete(hass, hass_ws_client, storage_setup):
"""Test WS delete cleans up entity registry."""
assert await storage_setup()
input_id = "from_storage"
input_entity_id = f"{DOMAIN}.{input_id}"
ent_reg = await entity_registry.async_get_registry(hass)
state = hass.states.get(input_entity_id)
assert state is not None
assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, input_id) is not None
client = await hass_ws_client(hass)
await client.send_json(
{"id": 6, "type": f"{DOMAIN}/delete", f"{DOMAIN}_id": f"{input_id}"}
)
resp = await client.receive_json()
assert resp["success"]
state = hass.states.get(input_entity_id)
assert state is None
assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, input_id) is None
async def test_update_min_max(hass, hass_ws_client, storage_setup):
"""Test updating min/max updates the state."""
items = [
{
"id": "from_storage",
"initial": 15,
"name": "from storage",
"maximum": 100,
"minimum": 10,
"step": 3,
"restore": True,
}
]
assert await storage_setup(items)
input_id = "from_storage"
input_entity_id = f"{DOMAIN}.{input_id}"
ent_reg = await entity_registry.async_get_registry(hass)
state = hass.states.get(input_entity_id)
assert state is not None
assert int(state.state) == 15
assert state.attributes[ATTR_MAXIMUM] == 100
assert state.attributes[ATTR_MINIMUM] == 10
assert state.attributes[ATTR_STEP] == 3
assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, input_id) is not None
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 6,
"type": f"{DOMAIN}/update",
f"{DOMAIN}_id": f"{input_id}",
"minimum": 19,
}
)
resp = await client.receive_json()
assert resp["success"]
state = hass.states.get(input_entity_id)
assert int(state.state) == 19
assert state.attributes[ATTR_MINIMUM] == 19
assert state.attributes[ATTR_MAXIMUM] == 100
assert state.attributes[ATTR_STEP] == 3
await client.send_json(
{
"id": 7,
"type": f"{DOMAIN}/update",
f"{DOMAIN}_id": f"{input_id}",
"maximum": 5,
"minimum": 2,
"step": 5,
}
)
resp = await client.receive_json()
assert resp["success"]
state = hass.states.get(input_entity_id)
assert int(state.state) == 5
assert state.attributes[ATTR_MINIMUM] == 2
assert state.attributes[ATTR_MAXIMUM] == 5
assert state.attributes[ATTR_STEP] == 5
await client.send_json(
{
"id": 8,
"type": f"{DOMAIN}/update",
f"{DOMAIN}_id": f"{input_id}",
"maximum": None,
"minimum": None,
"step": 6,
}
)
resp = await client.receive_json()
assert resp["success"]
state = hass.states.get(input_entity_id)
assert int(state.state) == 5
assert ATTR_MINIMUM not in state.attributes
assert ATTR_MAXIMUM not in state.attributes
assert state.attributes[ATTR_STEP] == 6
async def test_create(hass, hass_ws_client, storage_setup):
"""Test creating counter using WS."""
items = []
assert await storage_setup(items)
counter_id = "new_counter"
input_entity_id = f"{DOMAIN}.{counter_id}"
ent_reg = await entity_registry.async_get_registry(hass)
state = hass.states.get(input_entity_id)
assert state is None
assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, counter_id) is None
client = await hass_ws_client(hass)
await client.send_json({"id": 6, "type": f"{DOMAIN}/create", "name": "new counter"})
resp = await client.receive_json()
assert resp["success"]
state = hass.states.get(input_entity_id)
assert int(state.state) == 0
assert ATTR_MINIMUM not in state.attributes
assert ATTR_MAXIMUM not in state.attributes
assert state.attributes[ATTR_STEP] == 1