diff --git a/homeassistant/components/coronavirus/__init__.py b/homeassistant/components/coronavirus/__init__.py index d05c4cef862..4bda4edcd37 100644 --- a/homeassistant/components/coronavirus/__init__.py +++ b/homeassistant/components/coronavirus/__init__.py @@ -15,14 +15,14 @@ from .const import DOMAIN PLATFORMS = ["sensor"] -async def async_setup(hass: HomeAssistant, config: dict): +async def async_setup(hass: HomeAssistant, config: dict) -> bool: """Set up the Coronavirus component.""" # Make sure coordinator is initialized. await get_coordinator(hass) return True -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up Coronavirus from a config entry.""" if isinstance(entry.data["country"], int): hass.config_entries.async_update_entry( @@ -44,6 +44,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): if not entry.unique_id: hass.config_entries.async_update_entry(entry, unique_id=entry.data["country"]) + coordinator = await get_coordinator(hass) + if not coordinator.last_update_success: + await coordinator.async_config_entry_first_refresh() + for platform in PLATFORMS: hass.async_create_task( hass.config_entries.async_forward_entry_setup(entry, platform) @@ -52,9 +56,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): return True -async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry): +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" - unload_ok = all( + return all( await asyncio.gather( *[ hass.config_entries.async_forward_entry_unload(entry, platform) @@ -63,10 +67,10 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry): ) ) - return unload_ok - -async def get_coordinator(hass): +async def get_coordinator( + hass: HomeAssistant, +) -> update_coordinator.DataUpdateCoordinator: """Get the data update coordinator.""" if DOMAIN in hass.data: return hass.data[DOMAIN] diff --git a/homeassistant/components/coronavirus/config_flow.py b/homeassistant/components/coronavirus/config_flow.py index 6d2776c7ecc..4f6e865fa37 100644 --- a/homeassistant/components/coronavirus/config_flow.py +++ b/homeassistant/components/coronavirus/config_flow.py @@ -1,4 +1,8 @@ """Config flow for Coronavirus integration.""" +from __future__ import annotations + +from typing import Any + import voluptuous as vol from homeassistant import config_entries @@ -15,13 +19,18 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): _options = None - async def async_step_user(self, user_input=None): + async def async_step_user( + self, user_input: dict[str, Any] | None = None + ) -> dict[str, Any]: """Handle the initial step.""" errors = {} if self._options is None: - self._options = {OPTION_WORLDWIDE: "Worldwide"} coordinator = await get_coordinator(self.hass) + if not coordinator.last_update_success: + return self.async_abort(reason="cannot_connect") + + self._options = {OPTION_WORLDWIDE: "Worldwide"} for case in sorted( coordinator.data.values(), key=lambda case: case.country ): diff --git a/homeassistant/components/coronavirus/strings.json b/homeassistant/components/coronavirus/strings.json index 6a5b2626003..e0b29d6c8db 100644 --- a/homeassistant/components/coronavirus/strings.json +++ b/homeassistant/components/coronavirus/strings.json @@ -7,6 +7,7 @@ } }, "abort": { + "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", "already_configured": "[%key:common::config_flow::abort::already_configured_service%]" } } diff --git a/homeassistant/components/coronavirus/translations/en.json b/homeassistant/components/coronavirus/translations/en.json index cbd057bfce1..ea7ba1f6f9d 100644 --- a/homeassistant/components/coronavirus/translations/en.json +++ b/homeassistant/components/coronavirus/translations/en.json @@ -1,7 +1,8 @@ { "config": { "abort": { - "already_configured": "Service is already configured" + "already_configured": "Service is already configured", + "cannot_connect": "Failed to connect" }, "step": { "user": { diff --git a/tests/components/coronavirus/test_config_flow.py b/tests/components/coronavirus/test_config_flow.py index 06d586ba2a5..bfc69200893 100644 --- a/tests/components/coronavirus/test_config_flow.py +++ b/tests/components/coronavirus/test_config_flow.py @@ -1,9 +1,14 @@ """Test the Coronavirus config flow.""" +from unittest.mock import MagicMock, patch + +from aiohttp import ClientError + from homeassistant import config_entries, setup from homeassistant.components.coronavirus.const import DOMAIN, OPTION_WORLDWIDE +from homeassistant.core import HomeAssistant -async def test_form(hass): +async def test_form(hass: HomeAssistant) -> None: """Test we get the form.""" await setup.async_setup_component(hass, "persistent_notification", {}) result = await hass.config_entries.flow.async_init( @@ -24,3 +29,22 @@ async def test_form(hass): } await hass.async_block_till_done() assert len(hass.states.async_all()) == 4 + + +@patch( + "coronavirus.get_cases", + side_effect=ClientError, +) +async def test_abort_on_connection_error( + mock_get_cases: MagicMock, hass: HomeAssistant +) -> None: + """Test we abort on connection error.""" + await setup.async_setup_component(hass, "persistent_notification", {}) + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + assert "type" in result + assert result["type"] == "abort" + assert "reason" in result + assert result["reason"] == "cannot_connect" diff --git a/tests/components/coronavirus/test_init.py b/tests/components/coronavirus/test_init.py index cc49bf7d4b6..c36255db9d1 100644 --- a/tests/components/coronavirus/test_init.py +++ b/tests/components/coronavirus/test_init.py @@ -1,12 +1,18 @@ """Test init of Coronavirus integration.""" +from unittest.mock import MagicMock, patch + +from aiohttp import ClientError + from homeassistant.components.coronavirus.const import DOMAIN, OPTION_WORLDWIDE +from homeassistant.config_entries import ENTRY_STATE_SETUP_RETRY +from homeassistant.core import HomeAssistant from homeassistant.helpers import entity_registry as er from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry, mock_registry -async def test_migration(hass): +async def test_migration(hass: HomeAssistant) -> None: """Test that we can migrate coronavirus to stable unique ID.""" nl_entry = MockConfigEntry(domain=DOMAIN, title="Netherlands", data={"country": 34}) nl_entry.add_to_hass(hass) @@ -47,3 +53,20 @@ async def test_migration(hass): assert nl_entry.unique_id == "Netherlands" assert worldwide_entry.unique_id == OPTION_WORLDWIDE + + +@patch( + "coronavirus.get_cases", + side_effect=ClientError, +) +async def test_config_entry_not_ready( + mock_get_cases: MagicMock, hass: HomeAssistant +) -> None: + """Test the configuration entry not ready.""" + entry = MockConfigEntry(domain=DOMAIN, title="Netherlands", data={"country": 34}) + entry.add_to_hass(hass) + + assert await async_setup_component(hass, DOMAIN, {}) + await hass.async_block_till_done() + + assert entry.state == ENTRY_STATE_SETUP_RETRY