Refactor input_select (#53334)
parent
3732ae738e
commit
65b19da3ff
|
@ -58,6 +58,7 @@ homeassistant.components.http.*
|
|||
homeassistant.components.huawei_lte.*
|
||||
homeassistant.components.hyperion.*
|
||||
homeassistant.components.image_processing.*
|
||||
homeassistant.components.input_select.*
|
||||
homeassistant.components.integration.*
|
||||
homeassistant.components.iqvia.*
|
||||
homeassistant.components.jewish_calendar.*
|
||||
|
|
|
@ -2,9 +2,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.select import SelectEntity
|
||||
from homeassistant.const import (
|
||||
ATTR_EDITABLE,
|
||||
ATTR_OPTION,
|
||||
|
@ -55,7 +57,7 @@ UPDATE_FIELDS = {
|
|||
}
|
||||
|
||||
|
||||
def _cv_input_select(cfg):
|
||||
def _cv_input_select(cfg: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Configure validation helper for input select (voluptuous)."""
|
||||
options = cfg[CONF_OPTIONS]
|
||||
initial = cfg.get(CONF_INITIAL)
|
||||
|
@ -183,138 +185,137 @@ class InputSelectStorageCollection(collection.StorageCollection):
|
|||
CREATE_SCHEMA = vol.Schema(vol.All(CREATE_FIELDS, _cv_input_select))
|
||||
UPDATE_SCHEMA = vol.Schema(UPDATE_FIELDS)
|
||||
|
||||
async def _process_create_data(self, data: dict) -> dict:
|
||||
async def _process_create_data(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate the config is valid."""
|
||||
return self.CREATE_SCHEMA(data)
|
||||
return cast(Dict[str, Any], self.CREATE_SCHEMA(data))
|
||||
|
||||
@callback
|
||||
def _get_suggested_id(self, info: dict) -> str:
|
||||
def _get_suggested_id(self, info: dict[str, Any]) -> str:
|
||||
"""Suggest an ID based on the config."""
|
||||
return info[CONF_NAME]
|
||||
return cast(str, info[CONF_NAME])
|
||||
|
||||
async def _update_data(self, data: dict, update_data: dict) -> dict:
|
||||
async def _update_data(
|
||||
self, data: dict[str, Any], update_data: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Return a new updated data object."""
|
||||
update_data = self.UPDATE_SCHEMA(update_data)
|
||||
return _cv_input_select({**data, **update_data})
|
||||
|
||||
|
||||
class InputSelect(RestoreEntity):
|
||||
class InputSelect(SelectEntity, RestoreEntity):
|
||||
"""Representation of a select input."""
|
||||
|
||||
def __init__(self, config: dict) -> None:
|
||||
_attr_should_poll = False
|
||||
editable = True
|
||||
|
||||
def __init__(self, config: ConfigType) -> None:
|
||||
"""Initialize a select input."""
|
||||
self._config = config
|
||||
self.editable = True
|
||||
self._current_option = config.get(CONF_INITIAL)
|
||||
self._attr_current_option = config.get(CONF_INITIAL)
|
||||
self._attr_icon = config.get(CONF_ICON)
|
||||
self._attr_name = config.get(CONF_NAME)
|
||||
self._attr_options = config[CONF_OPTIONS]
|
||||
self._attr_unique_id = config[CONF_ID]
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, config: dict) -> InputSelect:
|
||||
def from_yaml(cls, config: ConfigType) -> InputSelect:
|
||||
"""Return entity instance initialized from yaml storage."""
|
||||
input_select = cls(config)
|
||||
input_select.entity_id = f"{DOMAIN}.{config[CONF_ID]}"
|
||||
input_select.editable = False
|
||||
return input_select
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added."""
|
||||
await super().async_added_to_hass()
|
||||
if self._current_option is not None:
|
||||
if self.current_option is not None:
|
||||
return
|
||||
|
||||
state = await self.async_get_last_state()
|
||||
if not state or state.state not in self._options:
|
||||
self._current_option = self._options[0]
|
||||
if not state or state.state not in self.options:
|
||||
self._attr_current_option = self.options[0]
|
||||
else:
|
||||
self._current_option = state.state
|
||||
self._attr_current_option = state.state
|
||||
|
||||
@property
|
||||
def should_poll(self):
|
||||
"""If entity should be polled."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Return the name of the select input."""
|
||||
return self._config.get(CONF_NAME)
|
||||
|
||||
@property
|
||||
def icon(self):
|
||||
"""Return the icon to be used for this entity."""
|
||||
return self._config.get(CONF_ICON)
|
||||
|
||||
@property
|
||||
def _options(self) -> list[str]:
|
||||
"""Return a list of selection options."""
|
||||
return self._config[CONF_OPTIONS]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
"""Return the state of the component."""
|
||||
return self._current_option
|
||||
|
||||
@property
|
||||
def extra_state_attributes(self):
|
||||
def extra_state_attributes(self) -> dict[str, bool]:
|
||||
"""Return the state attributes."""
|
||||
return {ATTR_OPTIONS: self._config[ATTR_OPTIONS], ATTR_EDITABLE: self.editable}
|
||||
return {ATTR_EDITABLE: self.editable}
|
||||
|
||||
@property
|
||||
def unique_id(self) -> str | None:
|
||||
"""Return unique id for the entity."""
|
||||
return self._config[CONF_ID]
|
||||
|
||||
@callback
|
||||
def async_select_option(self, option):
|
||||
async def async_select_option(self, option: str) -> None:
|
||||
"""Select new option."""
|
||||
if option not in self._options:
|
||||
if option not in self.options:
|
||||
_LOGGER.warning(
|
||||
"Invalid option: %s (possible options: %s)",
|
||||
option,
|
||||
", ".join(self._options),
|
||||
", ".join(self.options),
|
||||
)
|
||||
return
|
||||
self._current_option = option
|
||||
self._attr_current_option = option
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def async_select_index(self, idx):
|
||||
def async_select_index(self, idx: int) -> None:
|
||||
"""Select new option by index."""
|
||||
new_index = idx % len(self._options)
|
||||
self._current_option = self._options[new_index]
|
||||
new_index = idx % len(self.options)
|
||||
self._attr_current_option = self.options[new_index]
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def async_offset_index(self, offset, cycle):
|
||||
def async_offset_index(self, offset: int, cycle: bool) -> None:
|
||||
"""Offset current index."""
|
||||
current_index = self._options.index(self._current_option)
|
||||
|
||||
current_index = (
|
||||
self.options.index(self.current_option)
|
||||
if self.current_option is not None
|
||||
else 0
|
||||
)
|
||||
|
||||
new_index = current_index + offset
|
||||
if cycle:
|
||||
new_index = new_index % len(self._options)
|
||||
else:
|
||||
if new_index < 0:
|
||||
new_index = 0
|
||||
elif new_index >= len(self._options):
|
||||
new_index = len(self._options) - 1
|
||||
self._current_option = self._options[new_index]
|
||||
new_index = new_index % len(self.options)
|
||||
elif new_index < 0:
|
||||
new_index = 0
|
||||
elif new_index >= len(self.options):
|
||||
new_index = len(self.options) - 1
|
||||
|
||||
self._attr_current_option = self.options[new_index]
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def async_next(self, cycle):
|
||||
def async_next(self, cycle: bool) -> None:
|
||||
"""Select next option."""
|
||||
# If there is no current option, first item is the next
|
||||
if self.current_option is None:
|
||||
self.async_select_index(0)
|
||||
return
|
||||
self.async_offset_index(1, cycle)
|
||||
|
||||
@callback
|
||||
def async_previous(self, cycle):
|
||||
def async_previous(self, cycle: bool) -> None:
|
||||
"""Select previous option."""
|
||||
# If there is no current option, last item is the previous
|
||||
if self.current_option is None:
|
||||
self.async_select_index(-1)
|
||||
return
|
||||
self.async_offset_index(-1, cycle)
|
||||
|
||||
@callback
|
||||
def async_set_options(self, options):
|
||||
async def async_set_options(self, options: list[str]) -> None:
|
||||
"""Set options."""
|
||||
self._current_option = options[0]
|
||||
self._config[CONF_OPTIONS] = options
|
||||
self._attr_options = options
|
||||
|
||||
if self.current_option not in self.options:
|
||||
_LOGGER.warning(
|
||||
"Current option: %s no longer valid (possible options: %s)",
|
||||
self.current_option,
|
||||
", ".join(self.options),
|
||||
)
|
||||
self._attr_current_option = options[0]
|
||||
|
||||
self.async_write_ha_state()
|
||||
|
||||
async def async_update_config(self, config: dict) -> None:
|
||||
async def async_update_config(self, config: ConfigType) -> None:
|
||||
"""Handle when the config is updated."""
|
||||
self._config = config
|
||||
self._attr_icon = config.get(CONF_ICON)
|
||||
self._attr_name = config.get(CONF_NAME)
|
||||
self._attr_options = config[CONF_OPTIONS]
|
||||
self.async_write_ha_state()
|
||||
|
|
11
mypy.ini
11
mypy.ini
|
@ -649,6 +649,17 @@ no_implicit_optional = true
|
|||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.input_select.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.integration.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
|
|
|
@ -27,7 +27,6 @@ from homeassistant.const import (
|
|||
from homeassistant.core import Context, State
|
||||
from homeassistant.exceptions import Unauthorized
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import mock_restore_cache
|
||||
|
@ -65,80 +64,12 @@ def storage_setup(hass, hass_storage):
|
|||
return _storage
|
||||
|
||||
|
||||
@bind_hass
|
||||
def select_option(hass, entity_id, option):
|
||||
"""Set value of input_select.
|
||||
|
||||
This is a legacy helper method. Do not use it for new tests.
|
||||
"""
|
||||
hass.async_create_task(
|
||||
hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
{ATTR_ENTITY_ID: entity_id, ATTR_OPTION: option},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@bind_hass
|
||||
def select_next(hass, entity_id):
|
||||
"""Set next value of input_select.
|
||||
|
||||
This is a legacy helper method. Do not use it for new tests.
|
||||
"""
|
||||
hass.async_create_task(
|
||||
hass.services.async_call(
|
||||
DOMAIN, SERVICE_SELECT_NEXT, {ATTR_ENTITY_ID: entity_id}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@bind_hass
|
||||
def select_previous(hass, entity_id):
|
||||
"""Set previous value of input_select.
|
||||
|
||||
This is a legacy helper method. Do not use it for new tests.
|
||||
"""
|
||||
hass.async_create_task(
|
||||
hass.services.async_call(
|
||||
DOMAIN, SERVICE_SELECT_PREVIOUS, {ATTR_ENTITY_ID: entity_id}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@bind_hass
|
||||
def select_first(hass, entity_id):
|
||||
"""Set first value of input_select.
|
||||
|
||||
This is a legacy helper method. Do not use it for new tests.
|
||||
"""
|
||||
hass.async_create_task(
|
||||
hass.services.async_call(
|
||||
DOMAIN, SERVICE_SELECT_FIRST, {ATTR_ENTITY_ID: entity_id}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@bind_hass
|
||||
def select_last(hass, entity_id):
|
||||
"""Set last value of input_select.
|
||||
|
||||
This is a legacy helper method. Do not use it for new tests.
|
||||
"""
|
||||
hass.async_create_task(
|
||||
hass.services.async_call(
|
||||
DOMAIN, SERVICE_SELECT_LAST, {ATTR_ENTITY_ID: entity_id}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def test_config(hass):
|
||||
"""Test config."""
|
||||
invalid_configs = [
|
||||
None,
|
||||
{},
|
||||
{"name with space": None},
|
||||
# {'bad_options': {'options': None}},
|
||||
{"bad_initial": {"options": [1, 2], "initial": 3}},
|
||||
]
|
||||
|
||||
|
@ -158,15 +89,21 @@ async def test_select_option(hass):
|
|||
state = hass.states.get(entity_id)
|
||||
assert state.state == "some option"
|
||||
|
||||
select_option(hass, entity_id, "another option")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
{ATTR_ENTITY_ID: entity_id, ATTR_OPTION: "another option"},
|
||||
blocking=True,
|
||||
)
|
||||
state = hass.states.get(entity_id)
|
||||
assert state.state == "another option"
|
||||
|
||||
select_option(hass, entity_id, "non existing option")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
{ATTR_ENTITY_ID: entity_id, ATTR_OPTION: "non existing option"},
|
||||
blocking=True,
|
||||
)
|
||||
state = hass.states.get(entity_id)
|
||||
assert state.state == "another option"
|
||||
|
||||
|
@ -190,15 +127,21 @@ async def test_select_next(hass):
|
|||
state = hass.states.get(entity_id)
|
||||
assert state.state == "middle option"
|
||||
|
||||
select_next(hass, entity_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SELECT_NEXT,
|
||||
{ATTR_ENTITY_ID: entity_id},
|
||||
blocking=True,
|
||||
)
|
||||
state = hass.states.get(entity_id)
|
||||
assert state.state == "last option"
|
||||
|
||||
select_next(hass, entity_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SELECT_NEXT,
|
||||
{ATTR_ENTITY_ID: entity_id},
|
||||
blocking=True,
|
||||
)
|
||||
state = hass.states.get(entity_id)
|
||||
assert state.state == "first option"
|
||||
|
||||
|
@ -222,15 +165,21 @@ async def test_select_previous(hass):
|
|||
state = hass.states.get(entity_id)
|
||||
assert state.state == "middle option"
|
||||
|
||||
select_previous(hass, entity_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SELECT_PREVIOUS,
|
||||
{ATTR_ENTITY_ID: entity_id},
|
||||
blocking=True,
|
||||
)
|
||||
state = hass.states.get(entity_id)
|
||||
assert state.state == "first option"
|
||||
|
||||
select_previous(hass, entity_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SELECT_PREVIOUS,
|
||||
{ATTR_ENTITY_ID: entity_id},
|
||||
blocking=True,
|
||||
)
|
||||
state = hass.states.get(entity_id)
|
||||
assert state.state == "last option"
|
||||
|
||||
|
@ -254,14 +203,22 @@ async def test_select_first_last(hass):
|
|||
state = hass.states.get(entity_id)
|
||||
assert state.state == "middle option"
|
||||
|
||||
select_first(hass, entity_id)
|
||||
await hass.async_block_till_done()
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SELECT_FIRST,
|
||||
{ATTR_ENTITY_ID: entity_id},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
state = hass.states.get(entity_id)
|
||||
assert state.state == "first option"
|
||||
|
||||
select_last(hass, entity_id)
|
||||
await hass.async_block_till_done()
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SELECT_LAST,
|
||||
{ATTR_ENTITY_ID: entity_id},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
state = hass.states.get(entity_id)
|
||||
assert state.state == "last option"
|
||||
|
@ -326,20 +283,39 @@ async def test_set_options_service(hass):
|
|||
state = hass.states.get(entity_id)
|
||||
assert state.state == "middle option"
|
||||
|
||||
data = {ATTR_OPTIONS: ["test1", "test2"], "entity_id": entity_id}
|
||||
await hass.services.async_call(DOMAIN, SERVICE_SET_OPTIONS, data)
|
||||
await hass.async_block_till_done()
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SET_OPTIONS,
|
||||
{ATTR_OPTIONS: ["first option", "middle option"], ATTR_ENTITY_ID: entity_id},
|
||||
blocking=True,
|
||||
)
|
||||
state = hass.states.get(entity_id)
|
||||
assert state.state == "middle option"
|
||||
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SET_OPTIONS,
|
||||
{ATTR_OPTIONS: ["test1", "test2"], ATTR_ENTITY_ID: entity_id},
|
||||
blocking=True,
|
||||
)
|
||||
state = hass.states.get(entity_id)
|
||||
assert state.state == "test1"
|
||||
|
||||
select_option(hass, entity_id, "first option")
|
||||
await hass.async_block_till_done()
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
{ATTR_ENTITY_ID: entity_id, ATTR_OPTION: "first option"},
|
||||
blocking=True,
|
||||
)
|
||||
state = hass.states.get(entity_id)
|
||||
assert state.state == "test1"
|
||||
|
||||
select_option(hass, entity_id, "test2")
|
||||
await hass.async_block_till_done()
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
{ATTR_ENTITY_ID: entity_id, ATTR_OPTION: "test2"},
|
||||
blocking=True,
|
||||
)
|
||||
state = hass.states.get(entity_id)
|
||||
assert state.state == "test2"
|
||||
|
||||
|
@ -488,7 +464,6 @@ async def test_reload(hass, hass_admin_user, hass_read_only_user):
|
|||
blocking=True,
|
||||
context=Context(user_id=hass_admin_user.id),
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert count_start + 2 == len(hass.states.async_entity_ids())
|
||||
|
||||
|
@ -671,6 +646,5 @@ async def test_setup_no_config(hass, hass_admin_user):
|
|||
blocking=True,
|
||||
context=Context(user_id=hass_admin_user.id),
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert count_start == len(hass.states.async_entity_ids())
|
||||
|
|
|
@ -214,7 +214,7 @@ async def test_templates_with_entities(hass, calls):
|
|||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
_verify(hass, "a", ["a", "b", "c"])
|
||||
_verify(hass, "b", ["a", "b", "c"])
|
||||
|
||||
await hass.services.async_call(
|
||||
SELECT_DOMAIN,
|
||||
|
|
Loading…
Reference in New Issue