diff --git a/homeassistant/components/shelly/__init__.py b/homeassistant/components/shelly/__init__.py index 6009f8613fe..c0d0016392f 100644 --- a/homeassistant/components/shelly/__init__.py +++ b/homeassistant/components/shelly/__init__.py @@ -46,7 +46,11 @@ from .const import ( SLEEP_PERIOD_MULTIPLIER, UPDATE_PERIOD_MULTIPLIER, ) -from .utils import get_coap_context, get_device_name, get_device_sleep_period +from .utils import ( + get_block_device_name, + get_block_device_sleep_period, + get_coap_context, +) PLATFORMS: Final = ["binary_sensor", "cover", "light", "sensor", "switch"] SLEEPING_PLATFORMS: Final = ["binary_sensor", "sensor"] @@ -85,6 +89,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ) return False + if entry.data.get("gen") == 2: + return True + hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id] = {} hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][DEVICE] = None @@ -124,7 +131,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: if sleep_period is None: data = {**entry.data} - data["sleep_period"] = get_device_sleep_period(device.settings) + data["sleep_period"] = get_block_device_sleep_period(device.settings) data["model"] = device.settings["device"]["type"] hass.config_entries.async_update_entry(entry, data=data) @@ -192,7 +199,9 @@ class ShellyDeviceWrapper(update_coordinator.DataUpdateCoordinator): UPDATE_PERIOD_MULTIPLIER * device.settings["coiot"]["update_period"] ) - device_name = get_device_name(device) if device.initialized else entry.title + device_name = ( + get_block_device_name(device) if device.initialized else entry.title + ) super().__init__( hass, _LOGGER, @@ -338,7 +347,7 @@ class ShellyDeviceRestWrapper(update_coordinator.DataUpdateCoordinator): super().__init__( hass, _LOGGER, - name=get_device_name(device), + name=get_block_device_name(device), update_interval=timedelta(seconds=update_interval), ) self.device = device @@ -360,6 +369,9 @@ class ShellyDeviceRestWrapper(update_coordinator.DataUpdateCoordinator): async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" + if entry.data.get("gen") == 2: + return True + device = hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id].get(DEVICE) if device is not None: # If device is present, device wrapper is not setup yet diff --git a/homeassistant/components/shelly/config_flow.py b/homeassistant/components/shelly/config_flow.py index da4413e16b7..31f99b2b1fb 100644 --- a/homeassistant/components/shelly/config_flow.py +++ b/homeassistant/components/shelly/config_flow.py @@ -3,27 +3,37 @@ from __future__ import annotations import asyncio import logging -from typing import Any, Dict, Final, cast +from typing import Any, Final import aiohttp import aioshelly from aioshelly.block_device import BlockDevice +from aioshelly.rpc_device import RpcDevice import async_timeout import voluptuous as vol -from homeassistant import config_entries, core +from homeassistant import config_entries from homeassistant.const import ( CONF_HOST, CONF_PASSWORD, CONF_USERNAME, HTTP_UNAUTHORIZED, ) +from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers import aiohttp_client from homeassistant.helpers.typing import DiscoveryInfoType from .const import AIOSHELLY_DEVICE_TIMEOUT_SEC, DOMAIN -from .utils import get_coap_context, get_device_sleep_period +from .utils import ( + get_block_device_name, + get_block_device_sleep_period, + get_coap_context, + get_info_auth, + get_info_gen, + get_model_name, + get_rpc_device_name, +) _LOGGER: Final = logging.getLogger(__name__) @@ -33,34 +43,49 @@ HTTP_CONNECT_ERRORS: Final = (asyncio.TimeoutError, aiohttp.ClientError) async def validate_input( - hass: core.HomeAssistant, host: str, data: dict[str, Any] + hass: HomeAssistant, + host: str, + info: dict[str, Any], + data: dict[str, Any], ) -> dict[str, Any]: """Validate the user input allows us to connect. - Data has the keys from DATA_SCHEMA with values provided by the user. + Data has the keys from HOST_SCHEMA with values provided by the user. """ - options = aioshelly.common.ConnectionOptions( - host, data.get(CONF_USERNAME), data.get(CONF_PASSWORD) + host, + data.get(CONF_USERNAME), + data.get(CONF_PASSWORD), ) - coap_context = await get_coap_context(hass) async with async_timeout.timeout(AIOSHELLY_DEVICE_TIMEOUT_SEC): - device = await BlockDevice.create( + if get_info_gen(info) == 2: + rpc_device = await RpcDevice.create( + aiohttp_client.async_get_clientsession(hass), + options, + ) + await rpc_device.shutdown() + return { + "title": get_rpc_device_name(rpc_device), + "sleep_period": 0, + "model": rpc_device.model, + "gen": 2, + } + + # Gen1 + coap_context = await get_coap_context(hass) + block_device = await BlockDevice.create( aiohttp_client.async_get_clientsession(hass), coap_context, options, ) - - device.shutdown() - - # Return info that you want to store in the config entry. - return { - "title": device.settings["name"], - "hostname": device.settings["device"]["hostname"], - "sleep_period": get_device_sleep_period(device.settings), - "model": device.settings["device"]["type"], - } + block_device.shutdown() + return { + "title": get_block_device_name(block_device), + "sleep_period": get_block_device_sleep_period(block_device.settings), + "model": block_device.model, + "gen": 1, + } class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): @@ -80,7 +105,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): if user_input is not None: host: str = user_input[CONF_HOST] try: - info = await self._async_get_info(host) + self.info = await self._async_get_info(host) except HTTP_CONNECT_ERRORS: errors["base"] = "cannot_connect" except aioshelly.exceptions.FirmwareUnsupported: @@ -89,14 +114,16 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): _LOGGER.exception("Unexpected exception") errors["base"] = "unknown" else: - await self.async_set_unique_id(info["mac"]) + await self.async_set_unique_id(self.info["mac"]) self._abort_if_unique_id_configured({CONF_HOST: host}) self.host = host - if info["auth"]: + if get_info_auth(self.info): return await self.async_step_credentials() try: - device_info = await validate_input(self.hass, self.host, {}) + device_info = await validate_input( + self.hass, self.host, self.info, {} + ) except HTTP_CONNECT_ERRORS: errors["base"] = "cannot_connect" except Exception: # pylint: disable=broad-except @@ -104,11 +131,12 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): errors["base"] = "unknown" else: return self.async_create_entry( - title=device_info["title"] or device_info["hostname"], + title=device_info["title"], data={ **user_input, "sleep_period": device_info["sleep_period"], "model": device_info["model"], + "gen": device_info["gen"], }, ) @@ -123,7 +151,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): errors: dict[str, str] = {} if user_input is not None: try: - device_info = await validate_input(self.hass, self.host, user_input) + device_info = await validate_input( + self.hass, self.host, self.info, user_input + ) except aiohttp.ClientResponseError as error: if error.status == HTTP_UNAUTHORIZED: errors["base"] = "invalid_auth" @@ -136,12 +166,13 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): errors["base"] = "unknown" else: return self.async_create_entry( - title=device_info["title"] or device_info["hostname"], + title=device_info["title"], data={ **user_input, CONF_HOST: self.host, "sleep_period": device_info["sleep_period"], "model": device_info["model"], + "gen": device_info["gen"], }, ) else: @@ -163,13 +194,13 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): ) -> FlowResult: """Handle zeroconf discovery.""" try: - self.info = info = await self._async_get_info(discovery_info["host"]) + self.info = await self._async_get_info(discovery_info["host"]) except HTTP_CONNECT_ERRORS: return self.async_abort(reason="cannot_connect") except aioshelly.exceptions.FirmwareUnsupported: return self.async_abort(reason="unsupported_firmware") - await self.async_set_unique_id(info["mac"]) + await self.async_set_unique_id(self.info["mac"]) self._abort_if_unique_id_configured({CONF_HOST: discovery_info["host"]}) self.host = discovery_info["host"] @@ -177,11 +208,11 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): "name": discovery_info.get("name", "").split(".")[0] } - if info["auth"]: + if get_info_auth(self.info): return await self.async_step_credentials() try: - self.device_info = await validate_input(self.hass, self.host, {}) + self.device_info = await validate_input(self.hass, self.host, self.info, {}) except HTTP_CONNECT_ERRORS: return self.async_abort(reason="cannot_connect") @@ -194,11 +225,12 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): errors: dict[str, str] = {} if user_input is not None: return self.async_create_entry( - title=self.device_info["title"] or self.device_info["hostname"], + title=self.device_info["title"], data={ "host": self.host, "sleep_period": self.device_info["sleep_period"], "model": self.device_info["model"], + "gen": self.device_info["gen"], }, ) @@ -207,9 +239,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): return self.async_show_form( step_id="confirm_discovery", description_placeholders={ - "model": aioshelly.const.MODEL_NAMES.get( - self.info["type"], self.info["type"] - ), + "model": get_model_name(self.info), "host": self.host, }, errors=errors, @@ -218,10 +248,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): async def _async_get_info(self, host: str) -> dict[str, Any]: """Get info from shelly device.""" async with async_timeout.timeout(AIOSHELLY_DEVICE_TIMEOUT_SEC): - return cast( - Dict[str, Any], - await aioshelly.common.get_info( - aiohttp_client.async_get_clientsession(self.hass), - host, - ), + return await aioshelly.common.get_info( + aiohttp_client.async_get_clientsession(self.hass), host ) diff --git a/homeassistant/components/shelly/logbook.py b/homeassistant/components/shelly/logbook.py index deac3b5c05b..ca4818085d0 100644 --- a/homeassistant/components/shelly/logbook.py +++ b/homeassistant/components/shelly/logbook.py @@ -15,7 +15,7 @@ from .const import ( DOMAIN, EVENT_SHELLY_CLICK, ) -from .utils import get_device_name +from .utils import get_block_device_name @callback @@ -30,7 +30,7 @@ def async_describe_events( """Describe shelly.click logbook event.""" wrapper = get_device_wrapper(hass, event.data[ATTR_DEVICE_ID]) if wrapper and wrapper.device.initialized: - device_name = get_device_name(wrapper.device) + device_name = get_block_device_name(wrapper.device) else: device_name = event.data[ATTR_DEVICE] diff --git a/homeassistant/components/shelly/manifest.json b/homeassistant/components/shelly/manifest.json index 0c1e90eaaee..ca092295473 100644 --- a/homeassistant/components/shelly/manifest.json +++ b/homeassistant/components/shelly/manifest.json @@ -3,7 +3,7 @@ "name": "Shelly", "config_flow": true, "documentation": "https://www.home-assistant.io/integrations/shelly", - "requirements": ["aioshelly==1.0.0"], + "requirements": ["aioshelly==1.0.1"], "zeroconf": [ { "type": "_http._tcp.local.", diff --git a/homeassistant/components/shelly/utils.py b/homeassistant/components/shelly/utils.py index dfd4b1dc78a..405c34e6eb9 100644 --- a/homeassistant/components/shelly/utils.py +++ b/homeassistant/components/shelly/utils.py @@ -6,6 +6,8 @@ import logging from typing import Any, Final, cast from aioshelly.block_device import BLOCK_VALUE_UNIT, COAP, Block, BlockDevice +from aioshelly.const import MODEL_NAMES +from aioshelly.rpc_device import RpcDevice from homeassistant.const import EVENT_HOMEASSISTANT_STOP, TEMP_CELSIUS, TEMP_FAHRENHEIT from homeassistant.core import HomeAssistant, callback @@ -45,11 +47,18 @@ def temperature_unit(block_info: dict[str, Any]) -> str: return TEMP_CELSIUS -def get_device_name(device: BlockDevice) -> str: +def get_block_device_name(device: BlockDevice) -> str: """Naming for device.""" return cast(str, device.settings["name"] or device.settings["device"]["hostname"]) +def get_rpc_device_name(device: RpcDevice) -> str: + """Naming for device.""" + # Gen2 does not support setting device name + # AP SSID name is used as a nicely formatted device name + return cast(str, device.config["wifi"]["ap"]["ssid"] or device.hostname) + + def get_number_of_channels(device: BlockDevice, block: Block) -> int: """Get number of channels for block type.""" assert isinstance(device.shelly, dict) @@ -88,7 +97,7 @@ def get_entity_name( def get_device_channel_name(device: BlockDevice, block: Block | None) -> str: """Get name based on device and channel name.""" - entity_name = get_device_name(device) + entity_name = get_block_device_name(device) if ( not block @@ -200,7 +209,7 @@ async def get_coap_context(hass: HomeAssistant) -> COAP: return context -def get_device_sleep_period(settings: dict[str, Any]) -> int: +def get_block_device_sleep_period(settings: dict[str, Any]) -> int: """Return the device sleep period in seconds or 0 for non sleeping devices.""" sleep_period = 0 @@ -210,3 +219,21 @@ def get_device_sleep_period(settings: dict[str, Any]) -> int: sleep_period *= 60 # hours to minutes return sleep_period * 60 # minutes to seconds + + +def get_info_auth(info: dict[str, Any]) -> bool: + """Return true if device has authorization enabled.""" + return cast(bool, info.get("auth") or info.get("auth_en")) + + +def get_info_gen(info: dict[str, Any]) -> int: + """Return the device generation from shelly info.""" + return int(info.get("gen", 1)) + + +def get_model_name(info: dict[str, Any]) -> str: + """Return the device model name.""" + if get_info_gen(info) == 2: + return cast(str, MODEL_NAMES.get(info["model"], info["model"])) + + return cast(str, MODEL_NAMES.get(info["type"], info["type"])) diff --git a/requirements_all.txt b/requirements_all.txt index d96a8792be4..1d4c6981105 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -240,7 +240,7 @@ aiopylgtv==0.4.0 aiorecollect==1.0.8 # homeassistant.components.shelly -aioshelly==1.0.0 +aioshelly==1.0.1 # homeassistant.components.switcher_kis aioswitcher==2.0.5 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 2f483571e3f..ce70997fd12 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -161,7 +161,7 @@ aiopylgtv==0.4.0 aiorecollect==1.0.8 # homeassistant.components.shelly -aioshelly==1.0.0 +aioshelly==1.0.1 # homeassistant.components.switcher_kis aioswitcher==2.0.5 diff --git a/tests/components/shelly/test_config_flow.py b/tests/components/shelly/test_config_flow.py index 81118f928d3..1cc102715c5 100644 --- a/tests/components/shelly/test_config_flow.py +++ b/tests/components/shelly/test_config_flow.py @@ -20,9 +20,13 @@ DISCOVERY_INFO = { "name": "shelly1pm-12345", "properties": {"id": "shelly1pm-12345"}, } +MOCK_CONFIG = { + "wifi": {"ap": {"ssid": "Test name"}}, +} -async def test_form(hass): +@pytest.mark.parametrize("gen", [1, 2]) +async def test_form(hass, gen): """Test we get the form.""" await setup.async_setup_component(hass, "persistent_notification", {}) result = await hass.config_entries.flow.async_init( @@ -33,14 +37,24 @@ async def test_form(hass): with patch( "aioshelly.common.get_info", - return_value={"mac": "test-mac", "type": "SHSW-1", "auth": False}, + return_value={"mac": "test-mac", "type": "SHSW-1", "auth": False, "gen": gen}, ), patch( "aioshelly.block_device.BlockDevice.create", new=AsyncMock( return_value=Mock( + model="SHSW-1", settings=MOCK_SETTINGS, ) ), + ), patch( + "aioshelly.rpc_device.RpcDevice.create", + new=AsyncMock( + return_value=Mock( + model="SHSW-1", + config=MOCK_CONFIG, + shutdown=AsyncMock(), + ) + ), ), patch( "homeassistant.components.shelly.async_setup", return_value=True ) as mock_setup, patch( @@ -59,6 +73,7 @@ async def test_form(hass): "host": "1.1.1.1", "model": "SHSW-1", "sleep_period": 0, + "gen": gen, } assert len(mock_setup.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1 @@ -84,6 +99,7 @@ async def test_title_without_name(hass): "aioshelly.block_device.BlockDevice.create", new=AsyncMock( return_value=Mock( + model="SHSW-1", settings=settings, ) ), @@ -105,6 +121,7 @@ async def test_title_without_name(hass): "host": "1.1.1.1", "model": "SHSW-1", "sleep_period": 0, + "gen": 1, } assert len(mock_setup.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1 @@ -134,6 +151,7 @@ async def test_form_auth(hass): "aioshelly.block_device.BlockDevice.create", new=AsyncMock( return_value=Mock( + model="SHSW-1", settings=MOCK_SETTINGS, ) ), @@ -155,6 +173,7 @@ async def test_form_auth(hass): "host": "1.1.1.1", "model": "SHSW-1", "sleep_period": 0, + "gen": 1, "username": "test username", "password": "test password", } @@ -260,6 +279,7 @@ async def test_user_setup_ignored_device(hass): "aioshelly.block_device.BlockDevice.create", new=AsyncMock( return_value=Mock( + model="SHSW-1", settings=settings, ) ), @@ -350,6 +370,7 @@ async def test_zeroconf(hass): "aioshelly.block_device.BlockDevice.create", new=AsyncMock( return_value=Mock( + model="SHSW-1", settings=MOCK_SETTINGS, ) ), @@ -386,6 +407,7 @@ async def test_zeroconf(hass): "host": "1.1.1.1", "model": "SHSW-1", "sleep_period": 0, + "gen": 1, } assert len(mock_setup.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1 @@ -407,6 +429,7 @@ async def test_zeroconf_sleeping_device(hass): "aioshelly.block_device.BlockDevice.create", new=AsyncMock( return_value=Mock( + model="SHSW-1", settings={ "name": "Test name", "device": { @@ -450,6 +473,7 @@ async def test_zeroconf_sleeping_device(hass): "host": "1.1.1.1", "model": "SHSW-1", "sleep_period": 600, + "gen": 1, } assert len(mock_setup.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1 @@ -560,6 +584,7 @@ async def test_zeroconf_require_auth(hass): "aioshelly.block_device.BlockDevice.create", new=AsyncMock( return_value=Mock( + model="SHSW-1", settings=MOCK_SETTINGS, ) ), @@ -581,6 +606,7 @@ async def test_zeroconf_require_auth(hass): "host": "1.1.1.1", "model": "SHSW-1", "sleep_period": 0, + "gen": 1, "username": "test username", "password": "test password", }