From 815502044e4d7d6634af5651dd4dcc42c141a040 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 2 Mar 2020 17:59:32 -0800 Subject: [PATCH] Coronavirus updates (#32417) * Sort countries alphabetically * Update sensor name * Add migration to stable unique IDs * Update sensor.py --- .../components/coronavirus/__init__.py | 23 +++++++- .../components/coronavirus/config_flow.py | 6 +- .../components/coronavirus/sensor.py | 4 +- homeassistant/helpers/entity_registry.py | 20 ++++++- tests/components/coronavirus/test_init.py | 55 +++++++++++++++++++ 5 files changed, 100 insertions(+), 8 deletions(-) create mode 100644 tests/components/coronavirus/test_init.py diff --git a/homeassistant/components/coronavirus/__init__.py b/homeassistant/components/coronavirus/__init__.py index 95c3cd1c024..d5dbcd9f3f4 100644 --- a/homeassistant/components/coronavirus/__init__.py +++ b/homeassistant/components/coronavirus/__init__.py @@ -8,8 +8,8 @@ import async_timeout import coronavirus from homeassistant.config_entries import ConfigEntry -from homeassistant.core import HomeAssistant -from homeassistant.helpers import aiohttp_client, update_coordinator +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import aiohttp_client, entity_registry, update_coordinator from .const import DOMAIN @@ -25,6 +25,23 @@ async def async_setup(hass: HomeAssistant, config: dict): async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): """Set up Coronavirus from a config entry.""" + if isinstance(entry.data["country"], int): + hass.config_entries.async_update_entry( + entry, data={**entry.data, "country": entry.title} + ) + + @callback + def _async_migrator(entity_entry: entity_registry.RegistryEntry): + """Migrate away from unstable ID.""" + country, info_type = entity_entry.unique_id.rsplit("-", 1) + if not country.isnumeric(): + return None + return {"new_unique_id": f"{entry.title}-{info_type}"} + + await entity_registry.async_migrate_entries( + hass, entry.entry_id, _async_migrator + ) + for component in PLATFORMS: hass.async_create_task( hass.config_entries.async_forward_entry_setup(entry, component) @@ -56,7 +73,7 @@ async def get_coordinator(hass): try: with async_timeout.timeout(10): return { - case.id: case + case.country: case for case in await coronavirus.get_cases( aiohttp_client.async_get_clientsession(hass) ) diff --git a/homeassistant/components/coronavirus/config_flow.py b/homeassistant/components/coronavirus/config_flow.py index 59d25e16709..4a313a6837f 100644 --- a/homeassistant/components/coronavirus/config_flow.py +++ b/homeassistant/components/coronavirus/config_flow.py @@ -26,8 +26,10 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): if self._options is None: self._options = {OPTION_WORLDWIDE: "Worldwide"} coordinator = await get_coordinator(self.hass) - for case_id in sorted(coordinator.data): - self._options[case_id] = coordinator.data[case_id].country + for case in sorted( + coordinator.data.values(), key=lambda case: case.country + ): + self._options[case.country] = case.country if user_input is not None: return self.async_create_entry( diff --git a/homeassistant/components/coronavirus/sensor.py b/homeassistant/components/coronavirus/sensor.py index 770ab78b43e..20f18896431 100644 --- a/homeassistant/components/coronavirus/sensor.py +++ b/homeassistant/components/coronavirus/sensor.py @@ -25,9 +25,9 @@ class CoronavirusSensor(Entity): def __init__(self, coordinator, country, info_type): """Initialize coronavirus sensor.""" if country == OPTION_WORLDWIDE: - self.name = f"Worldwide {info_type}" + self.name = f"Worldwide Coronavirus {info_type}" else: - self.name = f"{coordinator.data[country].country} {info_type}" + self.name = f"{coordinator.data[country].country} Coronavirus {info_type}" self.unique_id = f"{country}-{info_type}" self.coordinator = coordinator self.country = country diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 5996fb6eaf7..87383d45635 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -11,7 +11,7 @@ import asyncio from collections import OrderedDict from itertools import chain import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, cast import attr @@ -560,3 +560,21 @@ def async_setup_entity_restore( states.async_set(entry.entity_id, STATE_UNAVAILABLE, attrs) hass.bus.async_listen(EVENT_HOMEASSISTANT_START, _write_unavailable_states) + + +async def async_migrate_entries( + hass: HomeAssistantType, + config_entry_id: str, + entry_callback: Callable[[RegistryEntry], Optional[dict]], +) -> None: + """Migrator of unique IDs.""" + ent_reg = await async_get_registry(hass) + + for entry in ent_reg.entities.values(): + if entry.config_entry_id != config_entry_id: + continue + + updates = entry_callback(entry) + + if updates is not None: + ent_reg.async_update_entity(entry.entity_id, **updates) # type: ignore diff --git a/tests/components/coronavirus/test_init.py b/tests/components/coronavirus/test_init.py new file mode 100644 index 00000000000..05a14f2f296 --- /dev/null +++ b/tests/components/coronavirus/test_init.py @@ -0,0 +1,55 @@ +"""Test init of Coronavirus integration.""" +from asynctest import Mock, patch + +from homeassistant.components.coronavirus.const import DOMAIN, OPTION_WORLDWIDE +from homeassistant.helpers import entity_registry +from homeassistant.setup import async_setup_component + +from tests.common import MockConfigEntry, mock_registry + + +async def test_migration(hass): + """Test that we can migrate coronavirus to stable unique ID.""" + nl_entry = MockConfigEntry(domain=DOMAIN, title="Netherlands", data={"country": 34}) + nl_entry.add_to_hass(hass) + worldwide_entry = MockConfigEntry( + domain=DOMAIN, title="Worldwide", data={"country": OPTION_WORLDWIDE} + ) + worldwide_entry.add_to_hass(hass) + mock_registry( + hass, + { + "sensor.netherlands_confirmed": entity_registry.RegistryEntry( + entity_id="sensor.netherlands_confirmed", + unique_id="34-confirmed", + platform="coronavirus", + config_entry_id=nl_entry.entry_id, + ), + "sensor.worldwide_confirmed": entity_registry.RegistryEntry( + entity_id="sensor.worldwide_confirmed", + unique_id="__worldwide-confirmed", + platform="coronavirus", + config_entry_id=worldwide_entry.entry_id, + ), + }, + ) + with patch( + "coronavirus.get_cases", + return_value=[ + Mock(country="Netherlands", confirmed=10, recovered=8, deaths=1, current=1), + Mock(country="Germany", confirmed=1, recovered=0, deaths=0, current=0), + ], + ): + assert await async_setup_component(hass, DOMAIN, {}) + await hass.async_block_till_done() + + ent_reg = await entity_registry.async_get_registry(hass) + + sensor_nl = ent_reg.async_get("sensor.netherlands_confirmed") + assert sensor_nl.unique_id == "Netherlands-confirmed" + + sensor_worldwide = ent_reg.async_get("sensor.worldwide_confirmed") + assert sensor_worldwide.unique_id == "__worldwide-confirmed" + + assert hass.states.get("sensor.netherlands_confirmed").state == "10" + assert hass.states.get("sensor.worldwide_confirmed").state == "11"