Coronavirus updates (#32417)
* Sort countries alphabetically * Update sensor name * Add migration to stable unique IDs * Update sensor.pypull/32424/head
parent
08f5b49dc4
commit
815502044e
|
@ -8,8 +8,8 @@ import async_timeout
|
|||
import coronavirus
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import aiohttp_client, update_coordinator
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import aiohttp_client, entity_registry, update_coordinator
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
|
@ -25,6 +25,23 @@ async def async_setup(hass: HomeAssistant, config: dict):
|
|||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
|
||||
"""Set up Coronavirus from a config entry."""
|
||||
if isinstance(entry.data["country"], int):
|
||||
hass.config_entries.async_update_entry(
|
||||
entry, data={**entry.data, "country": entry.title}
|
||||
)
|
||||
|
||||
@callback
|
||||
def _async_migrator(entity_entry: entity_registry.RegistryEntry):
|
||||
"""Migrate away from unstable ID."""
|
||||
country, info_type = entity_entry.unique_id.rsplit("-", 1)
|
||||
if not country.isnumeric():
|
||||
return None
|
||||
return {"new_unique_id": f"{entry.title}-{info_type}"}
|
||||
|
||||
await entity_registry.async_migrate_entries(
|
||||
hass, entry.entry_id, _async_migrator
|
||||
)
|
||||
|
||||
for component in PLATFORMS:
|
||||
hass.async_create_task(
|
||||
hass.config_entries.async_forward_entry_setup(entry, component)
|
||||
|
@ -56,7 +73,7 @@ async def get_coordinator(hass):
|
|||
try:
|
||||
with async_timeout.timeout(10):
|
||||
return {
|
||||
case.id: case
|
||||
case.country: case
|
||||
for case in await coronavirus.get_cases(
|
||||
aiohttp_client.async_get_clientsession(hass)
|
||||
)
|
||||
|
|
|
@ -26,8 +26,10 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
if self._options is None:
|
||||
self._options = {OPTION_WORLDWIDE: "Worldwide"}
|
||||
coordinator = await get_coordinator(self.hass)
|
||||
for case_id in sorted(coordinator.data):
|
||||
self._options[case_id] = coordinator.data[case_id].country
|
||||
for case in sorted(
|
||||
coordinator.data.values(), key=lambda case: case.country
|
||||
):
|
||||
self._options[case.country] = case.country
|
||||
|
||||
if user_input is not None:
|
||||
return self.async_create_entry(
|
||||
|
|
|
@ -25,9 +25,9 @@ class CoronavirusSensor(Entity):
|
|||
def __init__(self, coordinator, country, info_type):
|
||||
"""Initialize coronavirus sensor."""
|
||||
if country == OPTION_WORLDWIDE:
|
||||
self.name = f"Worldwide {info_type}"
|
||||
self.name = f"Worldwide Coronavirus {info_type}"
|
||||
else:
|
||||
self.name = f"{coordinator.data[country].country} {info_type}"
|
||||
self.name = f"{coordinator.data[country].country} Coronavirus {info_type}"
|
||||
self.unique_id = f"{country}-{info_type}"
|
||||
self.coordinator = coordinator
|
||||
self.country = country
|
||||
|
|
|
@ -11,7 +11,7 @@ import asyncio
|
|||
from collections import OrderedDict
|
||||
from itertools import chain
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, cast
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -560,3 +560,21 @@ def async_setup_entity_restore(
|
|||
states.async_set(entry.entity_id, STATE_UNAVAILABLE, attrs)
|
||||
|
||||
hass.bus.async_listen(EVENT_HOMEASSISTANT_START, _write_unavailable_states)
|
||||
|
||||
|
||||
async def async_migrate_entries(
|
||||
hass: HomeAssistantType,
|
||||
config_entry_id: str,
|
||||
entry_callback: Callable[[RegistryEntry], Optional[dict]],
|
||||
) -> None:
|
||||
"""Migrator of unique IDs."""
|
||||
ent_reg = await async_get_registry(hass)
|
||||
|
||||
for entry in ent_reg.entities.values():
|
||||
if entry.config_entry_id != config_entry_id:
|
||||
continue
|
||||
|
||||
updates = entry_callback(entry)
|
||||
|
||||
if updates is not None:
|
||||
ent_reg.async_update_entity(entry.entity_id, **updates) # type: ignore
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
"""Test init of Coronavirus integration."""
|
||||
from asynctest import Mock, patch
|
||||
|
||||
from homeassistant.components.coronavirus.const import DOMAIN, OPTION_WORLDWIDE
|
||||
from homeassistant.helpers import entity_registry
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry, mock_registry
|
||||
|
||||
|
||||
async def test_migration(hass):
|
||||
"""Test that we can migrate coronavirus to stable unique ID."""
|
||||
nl_entry = MockConfigEntry(domain=DOMAIN, title="Netherlands", data={"country": 34})
|
||||
nl_entry.add_to_hass(hass)
|
||||
worldwide_entry = MockConfigEntry(
|
||||
domain=DOMAIN, title="Worldwide", data={"country": OPTION_WORLDWIDE}
|
||||
)
|
||||
worldwide_entry.add_to_hass(hass)
|
||||
mock_registry(
|
||||
hass,
|
||||
{
|
||||
"sensor.netherlands_confirmed": entity_registry.RegistryEntry(
|
||||
entity_id="sensor.netherlands_confirmed",
|
||||
unique_id="34-confirmed",
|
||||
platform="coronavirus",
|
||||
config_entry_id=nl_entry.entry_id,
|
||||
),
|
||||
"sensor.worldwide_confirmed": entity_registry.RegistryEntry(
|
||||
entity_id="sensor.worldwide_confirmed",
|
||||
unique_id="__worldwide-confirmed",
|
||||
platform="coronavirus",
|
||||
config_entry_id=worldwide_entry.entry_id,
|
||||
),
|
||||
},
|
||||
)
|
||||
with patch(
|
||||
"coronavirus.get_cases",
|
||||
return_value=[
|
||||
Mock(country="Netherlands", confirmed=10, recovered=8, deaths=1, current=1),
|
||||
Mock(country="Germany", confirmed=1, recovered=0, deaths=0, current=0),
|
||||
],
|
||||
):
|
||||
assert await async_setup_component(hass, DOMAIN, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
ent_reg = await entity_registry.async_get_registry(hass)
|
||||
|
||||
sensor_nl = ent_reg.async_get("sensor.netherlands_confirmed")
|
||||
assert sensor_nl.unique_id == "Netherlands-confirmed"
|
||||
|
||||
sensor_worldwide = ent_reg.async_get("sensor.worldwide_confirmed")
|
||||
assert sensor_worldwide.unique_id == "__worldwide-confirmed"
|
||||
|
||||
assert hass.states.get("sensor.netherlands_confirmed").state == "10"
|
||||
assert hass.states.get("sensor.worldwide_confirmed").state == "11"
|
Loading…
Reference in New Issue