diff --git a/homeassistant/components/yeelight/__init__.py b/homeassistant/components/yeelight/__init__.py index fb908775d1b..64fa7b01f28 100644 --- a/homeassistant/components/yeelight/__init__.py +++ b/homeassistant/components/yeelight/__init__.py @@ -26,10 +26,7 @@ from homeassistant.const import ( ) from homeassistant.core import HomeAssistant, callback import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.dispatcher import ( - async_dispatcher_connect, - async_dispatcher_send, -) +from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.entity import DeviceInfo, Entity from homeassistant.helpers.event import async_call_later, async_track_time_interval from homeassistant.helpers.typing import ConfigType @@ -42,7 +39,6 @@ POWER_STATE_CHANGE_TIME = 1 # seconds DOMAIN = "yeelight" DATA_YEELIGHT = DOMAIN DATA_UPDATED = "yeelight_{}_data_updated" -DEVICE_INITIALIZED = "yeelight_{}_device_initialized" DEFAULT_NAME = "Yeelight" DEFAULT_TRANSITION = 350 @@ -203,24 +199,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def _async_initialize( hass: HomeAssistant, entry: ConfigEntry, - host: str, - device: YeelightDevice | None = None, + device: YeelightDevice, ) -> None: - entry_data = hass.data[DOMAIN][DATA_CONFIG_ENTRIES][entry.entry_id] = { - DATA_PLATFORMS_LOADED: False - } - - @callback - def _async_load_platforms(): - if entry_data[DATA_PLATFORMS_LOADED]: - return - entry_data[DATA_PLATFORMS_LOADED] = True - hass.config_entries.async_setup_platforms(entry, PLATFORMS) - - if not device: - # get device and start listening for local pushes - device = await _async_get_device(hass, host, entry) - + entry_data = hass.data[DOMAIN][DATA_CONFIG_ENTRIES][entry.entry_id] = {} await device.async_setup() entry_data[DATA_DEVICE] = device @@ -232,15 +213,9 @@ async def _async_initialize( entry, options={**entry.options, CONF_MODEL: device.capabilities["model"]} ) - entry.async_on_unload(entry.add_update_listener(_async_update_listener)) - entry.async_on_unload( - async_dispatcher_connect( - hass, DEVICE_INITIALIZED.format(host), _async_load_platforms - ) - ) - # fetch initial state - asyncio.create_task(device.async_update()) + await device.async_update() + entry.async_on_unload(entry.add_update_listener(_async_update_listener)) @callback @@ -256,7 +231,7 @@ def _async_normalize_config_entry(hass: HomeAssistant, entry: ConfigEntry) -> No entry, data={ CONF_HOST: entry.data.get(CONF_HOST), - CONF_ID: entry.data.get(CONF_ID, entry.unique_id), + CONF_ID: entry.data.get(CONF_ID) or entry.unique_id, }, options={ CONF_NAME: entry.data.get(CONF_NAME, ""), @@ -270,68 +245,44 @@ def _async_normalize_config_entry(hass: HomeAssistant, entry: ConfigEntry) -> No CONF_NIGHTLIGHT_SWITCH, DEFAULT_NIGHTLIGHT_SWITCH ), }, + unique_id=entry.unique_id or entry.data.get(CONF_ID), ) elif entry.unique_id and not entry.data.get(CONF_ID): hass.config_entries.async_update_entry( entry, data={CONF_HOST: entry.data.get(CONF_HOST), CONF_ID: entry.unique_id}, ) + elif entry.data.get(CONF_ID) and not entry.unique_id: + hass.config_entries.async_update_entry( + entry, + unique_id=entry.data[CONF_ID], + ) async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up Yeelight from a config entry.""" _async_normalize_config_entry(hass, entry) - if entry.data.get(CONF_HOST): - try: - device = await _async_get_device(hass, entry.data[CONF_HOST], entry) - except BULB_EXCEPTIONS as ex: - # Always retry later since bulbs can stop responding to SSDP - # sometimes even though they are online. If it has changed - # IP we will update it via discovery to the config flow - raise ConfigEntryNotReady from ex - else: - # Since device is passed this cannot throw an exception anymore - await _async_initialize(hass, entry, entry.data[CONF_HOST], device=device) - return True + if not entry.data.get(CONF_HOST): + bulb_id = async_format_id(entry.data.get(CONF_ID, entry.unique_id)) + raise ConfigEntryNotReady(f"Waiting for {bulb_id} to be discovered") - async def _async_from_discovery(capabilities: dict[str, str]) -> None: - host = urlparse(capabilities["location"]).hostname - try: - await _async_initialize(hass, entry, host) - except BULB_EXCEPTIONS: - _LOGGER.exception("Failed to connect to bulb at %s", host) + try: + device = await _async_get_device(hass, entry.data[CONF_HOST], entry) + await _async_initialize(hass, entry, device) + except BULB_EXCEPTIONS as ex: + raise ConfigEntryNotReady from ex + + hass.config_entries.async_setup_platforms(entry, PLATFORMS) - scanner = YeelightScanner.async_get(hass) - await scanner.async_register_callback(entry.data[CONF_ID], _async_from_discovery) return True async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" - if entry.data.get(CONF_ID): - # discovery - scanner = YeelightScanner.async_get(hass) - scanner.async_unregister_callback(entry.data[CONF_ID]) - data_config_entries = hass.data[DOMAIN][DATA_CONFIG_ENTRIES] - if entry.entry_id not in data_config_entries: - # Device not online - return True - - entry_data = data_config_entries[entry.entry_id] - unload_ok = True - if entry_data[DATA_PLATFORMS_LOADED]: - unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) - - if DATA_DEVICE in entry_data: - device = entry_data[DATA_DEVICE] - _LOGGER.debug("Shutting down Yeelight Listener") - await device.bulb.async_stop_listening() - _LOGGER.debug("Yeelight Listener stopped") - data_config_entries.pop(entry.entry_id) - return unload_ok + return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) @callback @@ -380,7 +331,6 @@ class YeelightScanner: def __init__(self, hass: HomeAssistant) -> None: """Initialize class.""" self._hass = hass - self._callbacks = {} self._host_discovered_events = {} self._unique_id_capabilities = {} self._host_capabilities = {} @@ -391,7 +341,7 @@ class YeelightScanner: async def async_setup(self): """Set up the scanner.""" if self._connected_events: - await asyncio.gather(*(event.wait() for event in self._connected_events)) + await self._async_wait_connected() return for idx, source_ip in enumerate(await self._async_build_source_set()): @@ -434,9 +384,16 @@ class YeelightScanner: for listener in failed_listeners: self._listeners.remove(listener) - await asyncio.gather(*(event.wait() for event in self._connected_events)) + await self._async_wait_connected() + self._track_interval = async_track_time_interval( + self._hass, self.async_scan, DISCOVERY_INTERVAL + ) self.async_scan() + async def _async_wait_connected(self): + """Wait for the listeners to be up and connected.""" + await asyncio.gather(*(event.wait() for event in self._connected_events)) + async def _async_build_source_set(self) -> set[IPv4Address]: """Build the list of ssdp sources.""" adapters = await network.async_get_adapters(self._hass) @@ -453,6 +410,7 @@ class YeelightScanner: async def async_discover(self): """Discover bulbs.""" + _LOGGER.debug("Yeelight discover with interval %s", DISCOVERY_SEARCH_INTERVAL) await self.async_setup() for _ in range(DISCOVERY_ATTEMPTS): self.async_scan() @@ -513,45 +471,6 @@ class YeelightScanner: self._unique_id_capabilities[unique_id] = response for event in self._host_discovered_events.get(host, []): event.set() - if unique_id in self._callbacks: - self._hass.async_create_task(self._callbacks[unique_id](response)) - self._callbacks.pop(unique_id) - if not self._callbacks: - self._async_stop_scan() - - async def _async_start_scan(self): - """Start scanning for Yeelight devices.""" - _LOGGER.debug("Start scanning") - await self.async_setup() - if not self._track_interval: - self._track_interval = async_track_time_interval( - self._hass, self.async_scan, DISCOVERY_INTERVAL - ) - self.async_scan() - - @callback - def _async_stop_scan(self): - """Stop scanning.""" - if self._track_interval is None: - return - _LOGGER.debug("Stop scanning interval") - self._track_interval() - self._track_interval = None - - async def async_register_callback(self, unique_id, callback_func): - """Register callback function.""" - if capabilities := self._unique_id_capabilities.get(unique_id): - self._hass.async_create_task(callback_func(capabilities)) - return - self._callbacks[unique_id] = callback_func - await self._async_start_scan() - - @callback - def async_unregister_callback(self, unique_id): - """Unregister callback function.""" - self._callbacks.pop(unique_id, None) - if not self._callbacks: - self._async_stop_scan() def update_needs_bg_power_workaround(data): @@ -675,7 +594,6 @@ class YeelightDevice: self._available = True if not self._initialized: self._initialized = True - async_dispatcher_send(self._hass, DEVICE_INITIALIZED.format(self._host)) except BULB_NETWORK_EXCEPTIONS as ex: if self._available: # just inform once _LOGGER.error( @@ -725,9 +643,6 @@ class YeelightDevice: ): # On reconnect the properties may be out of sync # - # We need to make sure the DEVICE_INITIALIZED dispatcher is setup - # before we can update on reconnect by checking self._did_first_update - # # If the device drops the connection right away, we do not want to # do a property resync via async_update since its about # to be called when async_setup_entry reaches the end of the @@ -743,10 +658,7 @@ class YeelightEntity(Entity): def __init__(self, device: YeelightDevice, entry: ConfigEntry) -> None: """Initialize the entity.""" self._device = device - self._unique_id = entry.entry_id - if entry.unique_id is not None: - # Use entry unique id (device id) whenever possible - self._unique_id = entry.unique_id + self._unique_id = entry.unique_id or entry.entry_id @property def unique_id(self) -> str: @@ -794,12 +706,19 @@ async def _async_get_device( # register stop callback to shutdown listening for local pushes async def async_stop_listen_task(event): - """Stop listen thread.""" - _LOGGER.debug("Shutting down Yeelight Listener") + """Stop listen task.""" + _LOGGER.debug("Shutting down Yeelight Listener (stop event)") await device.bulb.async_stop_listening() + @callback + def _async_stop_listen_on_unload(): + """Stop listen task.""" + _LOGGER.debug("Shutting down Yeelight Listener (unload)") + hass.async_create_task(device.bulb.async_stop_listening()) + entry.async_on_unload( hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_listen_task) ) + entry.async_on_unload(_async_stop_listen_on_unload) return device diff --git a/tests/components/yeelight/test_init.py b/tests/components/yeelight/test_init.py index 3ad99fa34ac..7ddb2845ac8 100644 --- a/tests/components/yeelight/test_init.py +++ b/tests/components/yeelight/test_init.py @@ -111,7 +111,9 @@ async def test_ip_changes_id_missing_cannot_fallback(hass: HomeAssistant): async def test_setup_discovery(hass: HomeAssistant): """Test setting up Yeelight by discovery.""" - config_entry = MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA) + config_entry = MockConfigEntry( + domain=DOMAIN, data={CONF_HOST: IP_ADDRESS, **CONFIG_ENTRY_DATA} + ) config_entry.add_to_hass(hass) mocked_bulb = _mocked_bulb() @@ -151,7 +153,9 @@ async def test_setup_discovery_with_manually_configured_network_adapter( hass: HomeAssistant, ): """Test setting up Yeelight by discovery with a manually configured network adapter.""" - config_entry = MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA) + config_entry = MockConfigEntry( + domain=DOMAIN, data={CONF_HOST: IP_ADDRESS, **CONFIG_ENTRY_DATA} + ) config_entry.add_to_hass(hass) mocked_bulb = _mocked_bulb() @@ -205,7 +209,9 @@ async def test_setup_discovery_with_manually_configured_network_adapter_one_fail hass: HomeAssistant, caplog ): """Test setting up Yeelight by discovery with a manually configured network adapter with one that fails to bind.""" - config_entry = MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA) + config_entry = MockConfigEntry( + domain=DOMAIN, data={CONF_HOST: IP_ADDRESS, **CONFIG_ENTRY_DATA} + ) config_entry.add_to_hass(hass) mocked_bulb = _mocked_bulb() @@ -268,7 +274,7 @@ async def test_unique_ids_device(hass: HomeAssistant): """Test Yeelight unique IDs from yeelight device IDs.""" config_entry = MockConfigEntry( domain=DOMAIN, - data={**CONFIG_ENTRY_DATA, CONF_NIGHTLIGHT_SWITCH: True}, + data={CONF_HOST: IP_ADDRESS, **CONFIG_ENTRY_DATA, CONF_NIGHTLIGHT_SWITCH: True}, unique_id=ID, ) config_entry.add_to_hass(hass) @@ -292,7 +298,8 @@ async def test_unique_ids_device(hass: HomeAssistant): async def test_unique_ids_entry(hass: HomeAssistant): """Test Yeelight unique IDs from entry IDs.""" config_entry = MockConfigEntry( - domain=DOMAIN, data={**CONFIG_ENTRY_DATA, CONF_NIGHTLIGHT_SWITCH: True} + domain=DOMAIN, + data={CONF_HOST: IP_ADDRESS, CONF_NIGHTLIGHT_SWITCH: True}, ) config_entry.add_to_hass(hass) @@ -357,18 +364,16 @@ async def test_async_listen_error_late_discovery(hass, caplog): await hass.config_entries.async_setup(config_entry.entry_id) await hass.async_block_till_done() - assert config_entry.state is ConfigEntryState.LOADED - assert "Failed to connect to bulb at" in caplog.text - await hass.config_entries.async_unload(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.SETUP_RETRY await hass.async_block_till_done() - - caplog.clear() + assert "Waiting for 0x15243f to be discovered" in caplog.text with _patch_discovery(), patch(f"{MODULE}.AsyncBulb", return_value=_mocked_bulb()): - await hass.config_entries.async_setup(config_entry.entry_id) + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=5)) + await hass.async_block_till_done() + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=10)) await hass.async_block_till_done() - assert "Failed to connect to bulb at" not in caplog.text assert config_entry.state is ConfigEntryState.LOADED assert config_entry.options[CONF_MODEL] == MODEL @@ -386,7 +391,7 @@ async def test_unload_before_discovery(hass, caplog): await hass.config_entries.async_setup(config_entry.entry_id) await hass.async_block_till_done() - assert config_entry.state is ConfigEntryState.LOADED + assert config_entry.state is ConfigEntryState.SETUP_RETRY await hass.config_entries.async_unload(config_entry.entry_id) await hass.async_block_till_done() @@ -451,6 +456,31 @@ async def test_async_setup_with_missing_id(hass: HomeAssistant): assert config_entry.state is ConfigEntryState.LOADED +async def test_async_setup_with_missing_unique_id(hass: HomeAssistant): + """Test that setting adds the missing unique_id from CONF_ID.""" + config_entry = MockConfigEntry( + domain=DOMAIN, + data={CONF_HOST: "127.0.0.1", CONF_ID: ID}, + options={CONF_NAME: "Test name"}, + ) + config_entry.add_to_hass(hass) + + with _patch_discovery(), _patch_discovery_timeout(), _patch_discovery_interval(), patch( + f"{MODULE}.AsyncBulb", return_value=_mocked_bulb(cannot_connect=True) + ): + await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + assert config_entry.state is ConfigEntryState.SETUP_RETRY + assert config_entry.unique_id == ID + + with _patch_discovery(), _patch_discovery_timeout(), _patch_discovery_interval(), patch( + f"{MODULE}.AsyncBulb", return_value=_mocked_bulb() + ): + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=2)) + await hass.async_block_till_done() + assert config_entry.state is ConfigEntryState.LOADED + + async def test_connection_dropped_resyncs_properties(hass: HomeAssistant): """Test handling a connection drop results in a property resync.""" config_entry = MockConfigEntry(