Refactor yeelight scanner to avoid creating tasks to wait for scanner start (#113919)
parent
13d6ebaabf
commit
2421b42f10
|
@ -3,9 +3,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, ValuesView
|
||||
from collections.abc import ValuesView
|
||||
import contextlib
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from ipaddress import IPv4Address
|
||||
import logging
|
||||
from typing import Self
|
||||
|
@ -19,6 +20,7 @@ from homeassistant.components import network, ssdp
|
|||
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
||||
from homeassistant.helpers import discovery_flow
|
||||
from homeassistant.helpers.event import async_call_later, async_track_time_interval
|
||||
from homeassistant.util.async_ import create_eager_task
|
||||
|
||||
from .const import (
|
||||
DISCOVERY_ATTEMPTS,
|
||||
|
@ -33,6 +35,12 @@ from .const import (
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@callback
|
||||
def _set_future_if_not_done(future: asyncio.Future[None]) -> None:
|
||||
if not future.done():
|
||||
future.set_result(None)
|
||||
|
||||
|
||||
class YeelightScanner:
|
||||
"""Scan for Yeelight devices."""
|
||||
|
||||
|
@ -54,26 +62,18 @@ class YeelightScanner:
|
|||
self._host_capabilities: dict[str, CaseInsensitiveDict] = {}
|
||||
self._track_interval: CALLBACK_TYPE | None = None
|
||||
self._listeners: list[SsdpSearchListener] = []
|
||||
self._connected_events: list[asyncio.Event] = []
|
||||
self._setup_future: asyncio.Future[None] | None = None
|
||||
|
||||
async def async_setup(self) -> None:
|
||||
"""Set up the scanner."""
|
||||
if self._connected_events:
|
||||
await self._async_wait_connected()
|
||||
return
|
||||
|
||||
for idx, source_ip in enumerate(await self._async_build_source_set()):
|
||||
self._connected_events.append(asyncio.Event())
|
||||
|
||||
def _wrap_async_connected_idx(idx) -> Callable[[], None]:
|
||||
"""Create a function to capture the idx cell variable."""
|
||||
|
||||
@callback
|
||||
def _async_connected() -> None:
|
||||
self._connected_events[idx].set()
|
||||
|
||||
return _async_connected
|
||||
if self._setup_future is not None:
|
||||
return await self._setup_future
|
||||
|
||||
self._setup_future = self._hass.loop.create_future()
|
||||
connected_futures: list[asyncio.Future[None]] = []
|
||||
for source_ip in await self._async_build_source_set():
|
||||
future = self._hass.loop.create_future()
|
||||
connected_futures.append(future)
|
||||
source = (str(source_ip), 0)
|
||||
self._listeners.append(
|
||||
SsdpSearchListener(
|
||||
|
@ -81,12 +81,15 @@ class YeelightScanner:
|
|||
search_target=SSDP_ST,
|
||||
target=SSDP_TARGET,
|
||||
source=source,
|
||||
connect_callback=_wrap_async_connected_idx(idx),
|
||||
connect_callback=partial(_set_future_if_not_done, future),
|
||||
)
|
||||
)
|
||||
|
||||
results = await asyncio.gather(
|
||||
*(listener.async_start() for listener in self._listeners),
|
||||
*(
|
||||
create_eager_task(listener.async_start())
|
||||
for listener in self._listeners
|
||||
),
|
||||
return_exceptions=True,
|
||||
)
|
||||
failed_listeners = []
|
||||
|
@ -99,20 +102,17 @@ class YeelightScanner:
|
|||
result,
|
||||
)
|
||||
failed_listeners.append(self._listeners[idx])
|
||||
self._connected_events[idx].set()
|
||||
_set_future_if_not_done(connected_futures[idx])
|
||||
|
||||
for listener in failed_listeners:
|
||||
self._listeners.remove(listener)
|
||||
|
||||
await self._async_wait_connected()
|
||||
await asyncio.wait(connected_futures)
|
||||
self._track_interval = async_track_time_interval(
|
||||
self._hass, self.async_scan, DISCOVERY_INTERVAL, cancel_on_shutdown=True
|
||||
)
|
||||
self.async_scan()
|
||||
|
||||
async def _async_wait_connected(self):
|
||||
"""Wait for the listeners to be up and connected."""
|
||||
await asyncio.gather(*(event.wait() for event in self._connected_events))
|
||||
_set_future_if_not_done(self._setup_future)
|
||||
|
||||
async def _async_build_source_set(self) -> set[IPv4Address]:
|
||||
"""Build the list of ssdp sources."""
|
||||
|
|
|
@ -1413,6 +1413,7 @@ async def test_effects(hass: HomeAssistant) -> None:
|
|||
}
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
config_entry = MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA)
|
||||
config_entry.add_to_hass(hass)
|
||||
|
|
Loading…
Reference in New Issue