Refactor input_select (#53334)

pull/58451/head
Franck Nijhof 2021-10-26 05:38:06 +02:00 committed by GitHub
parent 3732ae738e
commit 65b19da3ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 162 additions and 175 deletions

View File

@ -58,6 +58,7 @@ homeassistant.components.http.*
homeassistant.components.huawei_lte.*
homeassistant.components.hyperion.*
homeassistant.components.image_processing.*
homeassistant.components.input_select.*
homeassistant.components.integration.*
homeassistant.components.iqvia.*
homeassistant.components.jewish_calendar.*

View File

@ -2,9 +2,11 @@
from __future__ import annotations
import logging
from typing import Any, Dict, cast
import voluptuous as vol
from homeassistant.components.select import SelectEntity
from homeassistant.const import (
ATTR_EDITABLE,
ATTR_OPTION,
@ -55,7 +57,7 @@ UPDATE_FIELDS = {
}
def _cv_input_select(cfg):
def _cv_input_select(cfg: dict[str, Any]) -> dict[str, Any]:
"""Configure validation helper for input select (voluptuous)."""
options = cfg[CONF_OPTIONS]
initial = cfg.get(CONF_INITIAL)
@ -183,138 +185,137 @@ class InputSelectStorageCollection(collection.StorageCollection):
CREATE_SCHEMA = vol.Schema(vol.All(CREATE_FIELDS, _cv_input_select))
UPDATE_SCHEMA = vol.Schema(UPDATE_FIELDS)
async def _process_create_data(self, data: dict) -> dict:
async def _process_create_data(self, data: dict[str, Any]) -> dict[str, Any]:
"""Validate the config is valid."""
return self.CREATE_SCHEMA(data)
return cast(Dict[str, Any], self.CREATE_SCHEMA(data))
@callback
def _get_suggested_id(self, info: dict) -> str:
def _get_suggested_id(self, info: dict[str, Any]) -> str:
"""Suggest an ID based on the config."""
return info[CONF_NAME]
return cast(str, info[CONF_NAME])
async def _update_data(self, data: dict, update_data: dict) -> dict:
async def _update_data(
self, data: dict[str, Any], update_data: dict[str, Any]
) -> dict[str, Any]:
"""Return a new updated data object."""
update_data = self.UPDATE_SCHEMA(update_data)
return _cv_input_select({**data, **update_data})
class InputSelect(RestoreEntity):
class InputSelect(SelectEntity, RestoreEntity):
"""Representation of a select input."""
def __init__(self, config: dict) -> None:
_attr_should_poll = False
editable = True
def __init__(self, config: ConfigType) -> None:
"""Initialize a select input."""
self._config = config
self.editable = True
self._current_option = config.get(CONF_INITIAL)
self._attr_current_option = config.get(CONF_INITIAL)
self._attr_icon = config.get(CONF_ICON)
self._attr_name = config.get(CONF_NAME)
self._attr_options = config[CONF_OPTIONS]
self._attr_unique_id = config[CONF_ID]
@classmethod
def from_yaml(cls, config: dict) -> InputSelect:
def from_yaml(cls, config: ConfigType) -> InputSelect:
"""Return entity instance initialized from yaml storage."""
input_select = cls(config)
input_select.entity_id = f"{DOMAIN}.{config[CONF_ID]}"
input_select.editable = False
return input_select
async def async_added_to_hass(self):
async def async_added_to_hass(self) -> None:
"""Run when entity about to be added."""
await super().async_added_to_hass()
if self._current_option is not None:
if self.current_option is not None:
return
state = await self.async_get_last_state()
if not state or state.state not in self._options:
self._current_option = self._options[0]
if not state or state.state not in self.options:
self._attr_current_option = self.options[0]
else:
self._current_option = state.state
self._attr_current_option = state.state
@property
def should_poll(self):
"""If entity should be polled."""
return False
@property
def name(self):
"""Return the name of the select input."""
return self._config.get(CONF_NAME)
@property
def icon(self):
"""Return the icon to be used for this entity."""
return self._config.get(CONF_ICON)
@property
def _options(self) -> list[str]:
"""Return a list of selection options."""
return self._config[CONF_OPTIONS]
@property
def state(self):
"""Return the state of the component."""
return self._current_option
@property
def extra_state_attributes(self):
def extra_state_attributes(self) -> dict[str, bool]:
"""Return the state attributes."""
return {ATTR_OPTIONS: self._config[ATTR_OPTIONS], ATTR_EDITABLE: self.editable}
return {ATTR_EDITABLE: self.editable}
@property
def unique_id(self) -> str | None:
"""Return unique id for the entity."""
return self._config[CONF_ID]
@callback
def async_select_option(self, option):
async def async_select_option(self, option: str) -> None:
"""Select new option."""
if option not in self._options:
if option not in self.options:
_LOGGER.warning(
"Invalid option: %s (possible options: %s)",
option,
", ".join(self._options),
", ".join(self.options),
)
return
self._current_option = option
self._attr_current_option = option
self.async_write_ha_state()
@callback
def async_select_index(self, idx):
def async_select_index(self, idx: int) -> None:
"""Select new option by index."""
new_index = idx % len(self._options)
self._current_option = self._options[new_index]
new_index = idx % len(self.options)
self._attr_current_option = self.options[new_index]
self.async_write_ha_state()
@callback
def async_offset_index(self, offset, cycle):
def async_offset_index(self, offset: int, cycle: bool) -> None:
"""Offset current index."""
current_index = self._options.index(self._current_option)
current_index = (
self.options.index(self.current_option)
if self.current_option is not None
else 0
)
new_index = current_index + offset
if cycle:
new_index = new_index % len(self._options)
else:
if new_index < 0:
new_index = 0
elif new_index >= len(self._options):
new_index = len(self._options) - 1
self._current_option = self._options[new_index]
new_index = new_index % len(self.options)
elif new_index < 0:
new_index = 0
elif new_index >= len(self.options):
new_index = len(self.options) - 1
self._attr_current_option = self.options[new_index]
self.async_write_ha_state()
@callback
def async_next(self, cycle):
def async_next(self, cycle: bool) -> None:
"""Select next option."""
# If there is no current option, first item is the next
if self.current_option is None:
self.async_select_index(0)
return
self.async_offset_index(1, cycle)
@callback
def async_previous(self, cycle):
def async_previous(self, cycle: bool) -> None:
"""Select previous option."""
# If there is no current option, last item is the previous
if self.current_option is None:
self.async_select_index(-1)
return
self.async_offset_index(-1, cycle)
@callback
def async_set_options(self, options):
async def async_set_options(self, options: list[str]) -> None:
"""Set options."""
self._current_option = options[0]
self._config[CONF_OPTIONS] = options
self._attr_options = options
if self.current_option not in self.options:
_LOGGER.warning(
"Current option: %s no longer valid (possible options: %s)",
self.current_option,
", ".join(self.options),
)
self._attr_current_option = options[0]
self.async_write_ha_state()
async def async_update_config(self, config: dict) -> None:
async def async_update_config(self, config: ConfigType) -> None:
"""Handle when the config is updated."""
self._config = config
self._attr_icon = config.get(CONF_ICON)
self._attr_name = config.get(CONF_NAME)
self._attr_options = config[CONF_OPTIONS]
self.async_write_ha_state()

View File

@ -649,6 +649,17 @@ no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.input_select.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.integration.*]
check_untyped_defs = true
disallow_incomplete_defs = true

View File

@ -27,7 +27,6 @@ from homeassistant.const import (
from homeassistant.core import Context, State
from homeassistant.exceptions import Unauthorized
from homeassistant.helpers import entity_registry as er
from homeassistant.loader import bind_hass
from homeassistant.setup import async_setup_component
from tests.common import mock_restore_cache
@ -65,80 +64,12 @@ def storage_setup(hass, hass_storage):
return _storage
@bind_hass
def select_option(hass, entity_id, option):
"""Set value of input_select.
This is a legacy helper method. Do not use it for new tests.
"""
hass.async_create_task(
hass.services.async_call(
DOMAIN,
SERVICE_SELECT_OPTION,
{ATTR_ENTITY_ID: entity_id, ATTR_OPTION: option},
)
)
@bind_hass
def select_next(hass, entity_id):
"""Set next value of input_select.
This is a legacy helper method. Do not use it for new tests.
"""
hass.async_create_task(
hass.services.async_call(
DOMAIN, SERVICE_SELECT_NEXT, {ATTR_ENTITY_ID: entity_id}
)
)
@bind_hass
def select_previous(hass, entity_id):
"""Set previous value of input_select.
This is a legacy helper method. Do not use it for new tests.
"""
hass.async_create_task(
hass.services.async_call(
DOMAIN, SERVICE_SELECT_PREVIOUS, {ATTR_ENTITY_ID: entity_id}
)
)
@bind_hass
def select_first(hass, entity_id):
"""Set first value of input_select.
This is a legacy helper method. Do not use it for new tests.
"""
hass.async_create_task(
hass.services.async_call(
DOMAIN, SERVICE_SELECT_FIRST, {ATTR_ENTITY_ID: entity_id}
)
)
@bind_hass
def select_last(hass, entity_id):
"""Set last value of input_select.
This is a legacy helper method. Do not use it for new tests.
"""
hass.async_create_task(
hass.services.async_call(
DOMAIN, SERVICE_SELECT_LAST, {ATTR_ENTITY_ID: entity_id}
)
)
async def test_config(hass):
"""Test config."""
invalid_configs = [
None,
{},
{"name with space": None},
# {'bad_options': {'options': None}},
{"bad_initial": {"options": [1, 2], "initial": 3}},
]
@ -158,15 +89,21 @@ async def test_select_option(hass):
state = hass.states.get(entity_id)
assert state.state == "some option"
select_option(hass, entity_id, "another option")
await hass.async_block_till_done()
await hass.services.async_call(
DOMAIN,
SERVICE_SELECT_OPTION,
{ATTR_ENTITY_ID: entity_id, ATTR_OPTION: "another option"},
blocking=True,
)
state = hass.states.get(entity_id)
assert state.state == "another option"
select_option(hass, entity_id, "non existing option")
await hass.async_block_till_done()
await hass.services.async_call(
DOMAIN,
SERVICE_SELECT_OPTION,
{ATTR_ENTITY_ID: entity_id, ATTR_OPTION: "non existing option"},
blocking=True,
)
state = hass.states.get(entity_id)
assert state.state == "another option"
@ -190,15 +127,21 @@ async def test_select_next(hass):
state = hass.states.get(entity_id)
assert state.state == "middle option"
select_next(hass, entity_id)
await hass.async_block_till_done()
await hass.services.async_call(
DOMAIN,
SERVICE_SELECT_NEXT,
{ATTR_ENTITY_ID: entity_id},
blocking=True,
)
state = hass.states.get(entity_id)
assert state.state == "last option"
select_next(hass, entity_id)
await hass.async_block_till_done()
await hass.services.async_call(
DOMAIN,
SERVICE_SELECT_NEXT,
{ATTR_ENTITY_ID: entity_id},
blocking=True,
)
state = hass.states.get(entity_id)
assert state.state == "first option"
@ -222,15 +165,21 @@ async def test_select_previous(hass):
state = hass.states.get(entity_id)
assert state.state == "middle option"
select_previous(hass, entity_id)
await hass.async_block_till_done()
await hass.services.async_call(
DOMAIN,
SERVICE_SELECT_PREVIOUS,
{ATTR_ENTITY_ID: entity_id},
blocking=True,
)
state = hass.states.get(entity_id)
assert state.state == "first option"
select_previous(hass, entity_id)
await hass.async_block_till_done()
await hass.services.async_call(
DOMAIN,
SERVICE_SELECT_PREVIOUS,
{ATTR_ENTITY_ID: entity_id},
blocking=True,
)
state = hass.states.get(entity_id)
assert state.state == "last option"
@ -254,14 +203,22 @@ async def test_select_first_last(hass):
state = hass.states.get(entity_id)
assert state.state == "middle option"
select_first(hass, entity_id)
await hass.async_block_till_done()
await hass.services.async_call(
DOMAIN,
SERVICE_SELECT_FIRST,
{ATTR_ENTITY_ID: entity_id},
blocking=True,
)
state = hass.states.get(entity_id)
assert state.state == "first option"
select_last(hass, entity_id)
await hass.async_block_till_done()
await hass.services.async_call(
DOMAIN,
SERVICE_SELECT_LAST,
{ATTR_ENTITY_ID: entity_id},
blocking=True,
)
state = hass.states.get(entity_id)
assert state.state == "last option"
@ -326,20 +283,39 @@ async def test_set_options_service(hass):
state = hass.states.get(entity_id)
assert state.state == "middle option"
data = {ATTR_OPTIONS: ["test1", "test2"], "entity_id": entity_id}
await hass.services.async_call(DOMAIN, SERVICE_SET_OPTIONS, data)
await hass.async_block_till_done()
await hass.services.async_call(
DOMAIN,
SERVICE_SET_OPTIONS,
{ATTR_OPTIONS: ["first option", "middle option"], ATTR_ENTITY_ID: entity_id},
blocking=True,
)
state = hass.states.get(entity_id)
assert state.state == "middle option"
await hass.services.async_call(
DOMAIN,
SERVICE_SET_OPTIONS,
{ATTR_OPTIONS: ["test1", "test2"], ATTR_ENTITY_ID: entity_id},
blocking=True,
)
state = hass.states.get(entity_id)
assert state.state == "test1"
select_option(hass, entity_id, "first option")
await hass.async_block_till_done()
await hass.services.async_call(
DOMAIN,
SERVICE_SELECT_OPTION,
{ATTR_ENTITY_ID: entity_id, ATTR_OPTION: "first option"},
blocking=True,
)
state = hass.states.get(entity_id)
assert state.state == "test1"
select_option(hass, entity_id, "test2")
await hass.async_block_till_done()
await hass.services.async_call(
DOMAIN,
SERVICE_SELECT_OPTION,
{ATTR_ENTITY_ID: entity_id, ATTR_OPTION: "test2"},
blocking=True,
)
state = hass.states.get(entity_id)
assert state.state == "test2"
@ -488,7 +464,6 @@ async def test_reload(hass, hass_admin_user, hass_read_only_user):
blocking=True,
context=Context(user_id=hass_admin_user.id),
)
await hass.async_block_till_done()
assert count_start + 2 == len(hass.states.async_entity_ids())
@ -671,6 +646,5 @@ async def test_setup_no_config(hass, hass_admin_user):
blocking=True,
context=Context(user_id=hass_admin_user.id),
)
await hass.async_block_till_done()
assert count_start == len(hass.states.async_entity_ids())

View File

@ -214,7 +214,7 @@ async def test_templates_with_entities(hass, calls):
blocking=True,
)
await hass.async_block_till_done()
_verify(hass, "a", ["a", "b", "c"])
_verify(hass, "b", ["a", "b", "c"])
await hass.services.async_call(
SELECT_DOMAIN,