Add config flow support for RPC device (#56118)

pull/56253/head
Shay Levy 2021-09-11 23:28:33 +03:00 committed by GitHub
parent 8c3c2ad8e3
commit f1a88f0563
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 145 additions and 54 deletions

View File

@ -46,7 +46,11 @@ from .const import (
SLEEP_PERIOD_MULTIPLIER,
UPDATE_PERIOD_MULTIPLIER,
)
from .utils import get_coap_context, get_device_name, get_device_sleep_period
from .utils import (
get_block_device_name,
get_block_device_sleep_period,
get_coap_context,
)
PLATFORMS: Final = ["binary_sensor", "cover", "light", "sensor", "switch"]
SLEEPING_PLATFORMS: Final = ["binary_sensor", "sensor"]
@ -85,6 +89,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
)
return False
if entry.data.get("gen") == 2:
return True
hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id] = {}
hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][DEVICE] = None
@ -124,7 +131,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
if sleep_period is None:
data = {**entry.data}
data["sleep_period"] = get_device_sleep_period(device.settings)
data["sleep_period"] = get_block_device_sleep_period(device.settings)
data["model"] = device.settings["device"]["type"]
hass.config_entries.async_update_entry(entry, data=data)
@ -192,7 +199,9 @@ class ShellyDeviceWrapper(update_coordinator.DataUpdateCoordinator):
UPDATE_PERIOD_MULTIPLIER * device.settings["coiot"]["update_period"]
)
device_name = get_device_name(device) if device.initialized else entry.title
device_name = (
get_block_device_name(device) if device.initialized else entry.title
)
super().__init__(
hass,
_LOGGER,
@ -338,7 +347,7 @@ class ShellyDeviceRestWrapper(update_coordinator.DataUpdateCoordinator):
super().__init__(
hass,
_LOGGER,
name=get_device_name(device),
name=get_block_device_name(device),
update_interval=timedelta(seconds=update_interval),
)
self.device = device
@ -360,6 +369,9 @@ class ShellyDeviceRestWrapper(update_coordinator.DataUpdateCoordinator):
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
if entry.data.get("gen") == 2:
return True
device = hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id].get(DEVICE)
if device is not None:
# If device is present, device wrapper is not setup yet

View File

@ -3,27 +3,37 @@ from __future__ import annotations
import asyncio
import logging
from typing import Any, Dict, Final, cast
from typing import Any, Final
import aiohttp
import aioshelly
from aioshelly.block_device import BlockDevice
from aioshelly.rpc_device import RpcDevice
import async_timeout
import voluptuous as vol
from homeassistant import config_entries, core
from homeassistant import config_entries
from homeassistant.const import (
CONF_HOST,
CONF_PASSWORD,
CONF_USERNAME,
HTTP_UNAUTHORIZED,
)
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import aiohttp_client
from homeassistant.helpers.typing import DiscoveryInfoType
from .const import AIOSHELLY_DEVICE_TIMEOUT_SEC, DOMAIN
from .utils import get_coap_context, get_device_sleep_period
from .utils import (
get_block_device_name,
get_block_device_sleep_period,
get_coap_context,
get_info_auth,
get_info_gen,
get_model_name,
get_rpc_device_name,
)
_LOGGER: Final = logging.getLogger(__name__)
@ -33,33 +43,48 @@ HTTP_CONNECT_ERRORS: Final = (asyncio.TimeoutError, aiohttp.ClientError)
async def validate_input(
hass: core.HomeAssistant, host: str, data: dict[str, Any]
hass: HomeAssistant,
host: str,
info: dict[str, Any],
data: dict[str, Any],
) -> dict[str, Any]:
"""Validate the user input allows us to connect.
Data has the keys from DATA_SCHEMA with values provided by the user.
Data has the keys from HOST_SCHEMA with values provided by the user.
"""
options = aioshelly.common.ConnectionOptions(
host, data.get(CONF_USERNAME), data.get(CONF_PASSWORD)
host,
data.get(CONF_USERNAME),
data.get(CONF_PASSWORD),
)
coap_context = await get_coap_context(hass)
async with async_timeout.timeout(AIOSHELLY_DEVICE_TIMEOUT_SEC):
device = await BlockDevice.create(
if get_info_gen(info) == 2:
rpc_device = await RpcDevice.create(
aiohttp_client.async_get_clientsession(hass),
options,
)
await rpc_device.shutdown()
return {
"title": get_rpc_device_name(rpc_device),
"sleep_period": 0,
"model": rpc_device.model,
"gen": 2,
}
# Gen1
coap_context = await get_coap_context(hass)
block_device = await BlockDevice.create(
aiohttp_client.async_get_clientsession(hass),
coap_context,
options,
)
device.shutdown()
# Return info that you want to store in the config entry.
block_device.shutdown()
return {
"title": device.settings["name"],
"hostname": device.settings["device"]["hostname"],
"sleep_period": get_device_sleep_period(device.settings),
"model": device.settings["device"]["type"],
"title": get_block_device_name(block_device),
"sleep_period": get_block_device_sleep_period(block_device.settings),
"model": block_device.model,
"gen": 1,
}
@ -80,7 +105,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
if user_input is not None:
host: str = user_input[CONF_HOST]
try:
info = await self._async_get_info(host)
self.info = await self._async_get_info(host)
except HTTP_CONNECT_ERRORS:
errors["base"] = "cannot_connect"
except aioshelly.exceptions.FirmwareUnsupported:
@ -89,14 +114,16 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
else:
await self.async_set_unique_id(info["mac"])
await self.async_set_unique_id(self.info["mac"])
self._abort_if_unique_id_configured({CONF_HOST: host})
self.host = host
if info["auth"]:
if get_info_auth(self.info):
return await self.async_step_credentials()
try:
device_info = await validate_input(self.hass, self.host, {})
device_info = await validate_input(
self.hass, self.host, self.info, {}
)
except HTTP_CONNECT_ERRORS:
errors["base"] = "cannot_connect"
except Exception: # pylint: disable=broad-except
@ -104,11 +131,12 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors["base"] = "unknown"
else:
return self.async_create_entry(
title=device_info["title"] or device_info["hostname"],
title=device_info["title"],
data={
**user_input,
"sleep_period": device_info["sleep_period"],
"model": device_info["model"],
"gen": device_info["gen"],
},
)
@ -123,7 +151,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors: dict[str, str] = {}
if user_input is not None:
try:
device_info = await validate_input(self.hass, self.host, user_input)
device_info = await validate_input(
self.hass, self.host, self.info, user_input
)
except aiohttp.ClientResponseError as error:
if error.status == HTTP_UNAUTHORIZED:
errors["base"] = "invalid_auth"
@ -136,12 +166,13 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors["base"] = "unknown"
else:
return self.async_create_entry(
title=device_info["title"] or device_info["hostname"],
title=device_info["title"],
data={
**user_input,
CONF_HOST: self.host,
"sleep_period": device_info["sleep_period"],
"model": device_info["model"],
"gen": device_info["gen"],
},
)
else:
@ -163,13 +194,13 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
) -> FlowResult:
"""Handle zeroconf discovery."""
try:
self.info = info = await self._async_get_info(discovery_info["host"])
self.info = await self._async_get_info(discovery_info["host"])
except HTTP_CONNECT_ERRORS:
return self.async_abort(reason="cannot_connect")
except aioshelly.exceptions.FirmwareUnsupported:
return self.async_abort(reason="unsupported_firmware")
await self.async_set_unique_id(info["mac"])
await self.async_set_unique_id(self.info["mac"])
self._abort_if_unique_id_configured({CONF_HOST: discovery_info["host"]})
self.host = discovery_info["host"]
@ -177,11 +208,11 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"name": discovery_info.get("name", "").split(".")[0]
}
if info["auth"]:
if get_info_auth(self.info):
return await self.async_step_credentials()
try:
self.device_info = await validate_input(self.hass, self.host, {})
self.device_info = await validate_input(self.hass, self.host, self.info, {})
except HTTP_CONNECT_ERRORS:
return self.async_abort(reason="cannot_connect")
@ -194,11 +225,12 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors: dict[str, str] = {}
if user_input is not None:
return self.async_create_entry(
title=self.device_info["title"] or self.device_info["hostname"],
title=self.device_info["title"],
data={
"host": self.host,
"sleep_period": self.device_info["sleep_period"],
"model": self.device_info["model"],
"gen": self.device_info["gen"],
},
)
@ -207,9 +239,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return self.async_show_form(
step_id="confirm_discovery",
description_placeholders={
"model": aioshelly.const.MODEL_NAMES.get(
self.info["type"], self.info["type"]
),
"model": get_model_name(self.info),
"host": self.host,
},
errors=errors,
@ -218,10 +248,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
async def _async_get_info(self, host: str) -> dict[str, Any]:
"""Get info from shelly device."""
async with async_timeout.timeout(AIOSHELLY_DEVICE_TIMEOUT_SEC):
return cast(
Dict[str, Any],
await aioshelly.common.get_info(
aiohttp_client.async_get_clientsession(self.hass),
host,
),
return await aioshelly.common.get_info(
aiohttp_client.async_get_clientsession(self.hass), host
)

View File

@ -15,7 +15,7 @@ from .const import (
DOMAIN,
EVENT_SHELLY_CLICK,
)
from .utils import get_device_name
from .utils import get_block_device_name
@callback
@ -30,7 +30,7 @@ def async_describe_events(
"""Describe shelly.click logbook event."""
wrapper = get_device_wrapper(hass, event.data[ATTR_DEVICE_ID])
if wrapper and wrapper.device.initialized:
device_name = get_device_name(wrapper.device)
device_name = get_block_device_name(wrapper.device)
else:
device_name = event.data[ATTR_DEVICE]

View File

@ -3,7 +3,7 @@
"name": "Shelly",
"config_flow": true,
"documentation": "https://www.home-assistant.io/integrations/shelly",
"requirements": ["aioshelly==1.0.0"],
"requirements": ["aioshelly==1.0.1"],
"zeroconf": [
{
"type": "_http._tcp.local.",

View File

@ -6,6 +6,8 @@ import logging
from typing import Any, Final, cast
from aioshelly.block_device import BLOCK_VALUE_UNIT, COAP, Block, BlockDevice
from aioshelly.const import MODEL_NAMES
from aioshelly.rpc_device import RpcDevice
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, TEMP_CELSIUS, TEMP_FAHRENHEIT
from homeassistant.core import HomeAssistant, callback
@ -45,11 +47,18 @@ def temperature_unit(block_info: dict[str, Any]) -> str:
return TEMP_CELSIUS
def get_device_name(device: BlockDevice) -> str:
def get_block_device_name(device: BlockDevice) -> str:
"""Naming for device."""
return cast(str, device.settings["name"] or device.settings["device"]["hostname"])
def get_rpc_device_name(device: RpcDevice) -> str:
"""Naming for device."""
# Gen2 does not support setting device name
# AP SSID name is used as a nicely formatted device name
return cast(str, device.config["wifi"]["ap"]["ssid"] or device.hostname)
def get_number_of_channels(device: BlockDevice, block: Block) -> int:
"""Get number of channels for block type."""
assert isinstance(device.shelly, dict)
@ -88,7 +97,7 @@ def get_entity_name(
def get_device_channel_name(device: BlockDevice, block: Block | None) -> str:
"""Get name based on device and channel name."""
entity_name = get_device_name(device)
entity_name = get_block_device_name(device)
if (
not block
@ -200,7 +209,7 @@ async def get_coap_context(hass: HomeAssistant) -> COAP:
return context
def get_device_sleep_period(settings: dict[str, Any]) -> int:
def get_block_device_sleep_period(settings: dict[str, Any]) -> int:
"""Return the device sleep period in seconds or 0 for non sleeping devices."""
sleep_period = 0
@ -210,3 +219,21 @@ def get_device_sleep_period(settings: dict[str, Any]) -> int:
sleep_period *= 60 # hours to minutes
return sleep_period * 60 # minutes to seconds
def get_info_auth(info: dict[str, Any]) -> bool:
"""Return true if device has authorization enabled."""
return cast(bool, info.get("auth") or info.get("auth_en"))
def get_info_gen(info: dict[str, Any]) -> int:
"""Return the device generation from shelly info."""
return int(info.get("gen", 1))
def get_model_name(info: dict[str, Any]) -> str:
"""Return the device model name."""
if get_info_gen(info) == 2:
return cast(str, MODEL_NAMES.get(info["model"], info["model"]))
return cast(str, MODEL_NAMES.get(info["type"], info["type"]))

View File

@ -240,7 +240,7 @@ aiopylgtv==0.4.0
aiorecollect==1.0.8
# homeassistant.components.shelly
aioshelly==1.0.0
aioshelly==1.0.1
# homeassistant.components.switcher_kis
aioswitcher==2.0.5

View File

@ -161,7 +161,7 @@ aiopylgtv==0.4.0
aiorecollect==1.0.8
# homeassistant.components.shelly
aioshelly==1.0.0
aioshelly==1.0.1
# homeassistant.components.switcher_kis
aioswitcher==2.0.5

View File

@ -20,9 +20,13 @@ DISCOVERY_INFO = {
"name": "shelly1pm-12345",
"properties": {"id": "shelly1pm-12345"},
}
MOCK_CONFIG = {
"wifi": {"ap": {"ssid": "Test name"}},
}
async def test_form(hass):
@pytest.mark.parametrize("gen", [1, 2])
async def test_form(hass, gen):
"""Test we get the form."""
await setup.async_setup_component(hass, "persistent_notification", {})
result = await hass.config_entries.flow.async_init(
@ -33,14 +37,24 @@ async def test_form(hass):
with patch(
"aioshelly.common.get_info",
return_value={"mac": "test-mac", "type": "SHSW-1", "auth": False},
return_value={"mac": "test-mac", "type": "SHSW-1", "auth": False, "gen": gen},
), patch(
"aioshelly.block_device.BlockDevice.create",
new=AsyncMock(
return_value=Mock(
model="SHSW-1",
settings=MOCK_SETTINGS,
)
),
), patch(
"aioshelly.rpc_device.RpcDevice.create",
new=AsyncMock(
return_value=Mock(
model="SHSW-1",
config=MOCK_CONFIG,
shutdown=AsyncMock(),
)
),
), patch(
"homeassistant.components.shelly.async_setup", return_value=True
) as mock_setup, patch(
@ -59,6 +73,7 @@ async def test_form(hass):
"host": "1.1.1.1",
"model": "SHSW-1",
"sleep_period": 0,
"gen": gen,
}
assert len(mock_setup.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 1
@ -84,6 +99,7 @@ async def test_title_without_name(hass):
"aioshelly.block_device.BlockDevice.create",
new=AsyncMock(
return_value=Mock(
model="SHSW-1",
settings=settings,
)
),
@ -105,6 +121,7 @@ async def test_title_without_name(hass):
"host": "1.1.1.1",
"model": "SHSW-1",
"sleep_period": 0,
"gen": 1,
}
assert len(mock_setup.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 1
@ -134,6 +151,7 @@ async def test_form_auth(hass):
"aioshelly.block_device.BlockDevice.create",
new=AsyncMock(
return_value=Mock(
model="SHSW-1",
settings=MOCK_SETTINGS,
)
),
@ -155,6 +173,7 @@ async def test_form_auth(hass):
"host": "1.1.1.1",
"model": "SHSW-1",
"sleep_period": 0,
"gen": 1,
"username": "test username",
"password": "test password",
}
@ -260,6 +279,7 @@ async def test_user_setup_ignored_device(hass):
"aioshelly.block_device.BlockDevice.create",
new=AsyncMock(
return_value=Mock(
model="SHSW-1",
settings=settings,
)
),
@ -350,6 +370,7 @@ async def test_zeroconf(hass):
"aioshelly.block_device.BlockDevice.create",
new=AsyncMock(
return_value=Mock(
model="SHSW-1",
settings=MOCK_SETTINGS,
)
),
@ -386,6 +407,7 @@ async def test_zeroconf(hass):
"host": "1.1.1.1",
"model": "SHSW-1",
"sleep_period": 0,
"gen": 1,
}
assert len(mock_setup.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 1
@ -407,6 +429,7 @@ async def test_zeroconf_sleeping_device(hass):
"aioshelly.block_device.BlockDevice.create",
new=AsyncMock(
return_value=Mock(
model="SHSW-1",
settings={
"name": "Test name",
"device": {
@ -450,6 +473,7 @@ async def test_zeroconf_sleeping_device(hass):
"host": "1.1.1.1",
"model": "SHSW-1",
"sleep_period": 600,
"gen": 1,
}
assert len(mock_setup.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 1
@ -560,6 +584,7 @@ async def test_zeroconf_require_auth(hass):
"aioshelly.block_device.BlockDevice.create",
new=AsyncMock(
return_value=Mock(
model="SHSW-1",
settings=MOCK_SETTINGS,
)
),
@ -581,6 +606,7 @@ async def test_zeroconf_require_auth(hass):
"host": "1.1.1.1",
"model": "SHSW-1",
"sleep_period": 0,
"gen": 1,
"username": "test username",
"password": "test password",
}