Refactor yeelight scanner to avoid creating tasks to wait for scanner start (#113919)

pull/114285/head
J. Nick Koston 2024-03-26 23:17:35 -10:00 committed by GitHub
parent 13d6ebaabf
commit 2421b42f10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 25 deletions

View File

@ -3,9 +3,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, ValuesView from collections.abc import ValuesView
import contextlib import contextlib
from datetime import datetime from datetime import datetime
from functools import partial
from ipaddress import IPv4Address from ipaddress import IPv4Address
import logging import logging
from typing import Self from typing import Self
@ -19,6 +20,7 @@ from homeassistant.components import network, ssdp
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import discovery_flow from homeassistant.helpers import discovery_flow
from homeassistant.helpers.event import async_call_later, async_track_time_interval from homeassistant.helpers.event import async_call_later, async_track_time_interval
from homeassistant.util.async_ import create_eager_task
from .const import ( from .const import (
DISCOVERY_ATTEMPTS, DISCOVERY_ATTEMPTS,
@ -33,6 +35,12 @@ from .const import (
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@callback
def _set_future_if_not_done(future: asyncio.Future[None]) -> None:
if not future.done():
future.set_result(None)
class YeelightScanner: class YeelightScanner:
"""Scan for Yeelight devices.""" """Scan for Yeelight devices."""
@ -54,26 +62,18 @@ class YeelightScanner:
self._host_capabilities: dict[str, CaseInsensitiveDict] = {} self._host_capabilities: dict[str, CaseInsensitiveDict] = {}
self._track_interval: CALLBACK_TYPE | None = None self._track_interval: CALLBACK_TYPE | None = None
self._listeners: list[SsdpSearchListener] = [] self._listeners: list[SsdpSearchListener] = []
self._connected_events: list[asyncio.Event] = [] self._setup_future: asyncio.Future[None] | None = None
async def async_setup(self) -> None: async def async_setup(self) -> None:
"""Set up the scanner.""" """Set up the scanner."""
if self._connected_events: if self._setup_future is not None:
await self._async_wait_connected() return await self._setup_future
return
for idx, source_ip in enumerate(await self._async_build_source_set()):
self._connected_events.append(asyncio.Event())
def _wrap_async_connected_idx(idx) -> Callable[[], None]:
"""Create a function to capture the idx cell variable."""
@callback
def _async_connected() -> None:
self._connected_events[idx].set()
return _async_connected
self._setup_future = self._hass.loop.create_future()
connected_futures: list[asyncio.Future[None]] = []
for source_ip in await self._async_build_source_set():
future = self._hass.loop.create_future()
connected_futures.append(future)
source = (str(source_ip), 0) source = (str(source_ip), 0)
self._listeners.append( self._listeners.append(
SsdpSearchListener( SsdpSearchListener(
@ -81,12 +81,15 @@ class YeelightScanner:
search_target=SSDP_ST, search_target=SSDP_ST,
target=SSDP_TARGET, target=SSDP_TARGET,
source=source, source=source,
connect_callback=_wrap_async_connected_idx(idx), connect_callback=partial(_set_future_if_not_done, future),
) )
) )
results = await asyncio.gather( results = await asyncio.gather(
*(listener.async_start() for listener in self._listeners), *(
create_eager_task(listener.async_start())
for listener in self._listeners
),
return_exceptions=True, return_exceptions=True,
) )
failed_listeners = [] failed_listeners = []
@ -99,20 +102,17 @@ class YeelightScanner:
result, result,
) )
failed_listeners.append(self._listeners[idx]) failed_listeners.append(self._listeners[idx])
self._connected_events[idx].set() _set_future_if_not_done(connected_futures[idx])
for listener in failed_listeners: for listener in failed_listeners:
self._listeners.remove(listener) self._listeners.remove(listener)
await self._async_wait_connected() await asyncio.wait(connected_futures)
self._track_interval = async_track_time_interval( self._track_interval = async_track_time_interval(
self._hass, self.async_scan, DISCOVERY_INTERVAL, cancel_on_shutdown=True self._hass, self.async_scan, DISCOVERY_INTERVAL, cancel_on_shutdown=True
) )
self.async_scan() self.async_scan()
_set_future_if_not_done(self._setup_future)
async def _async_wait_connected(self):
"""Wait for the listeners to be up and connected."""
await asyncio.gather(*(event.wait() for event in self._connected_events))
async def _async_build_source_set(self) -> set[IPv4Address]: async def _async_build_source_set(self) -> set[IPv4Address]:
"""Build the list of ssdp sources.""" """Build the list of ssdp sources."""

View File

@ -1413,6 +1413,7 @@ async def test_effects(hass: HomeAssistant) -> None:
} }
}, },
) )
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA) config_entry = MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA)
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)