Refactor yeelight scanner to avoid creating tasks to wait for scanner start (#113919)

pull/114285/head
J. Nick Koston 2024-03-26 23:17:35 -10:00 committed by GitHub
parent 13d6ebaabf
commit 2421b42f10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 25 deletions

View File

@ -3,9 +3,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, ValuesView
from collections.abc import ValuesView
import contextlib
from datetime import datetime
from functools import partial
from ipaddress import IPv4Address
import logging
from typing import Self
@ -19,6 +20,7 @@ from homeassistant.components import network, ssdp
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import discovery_flow
from homeassistant.helpers.event import async_call_later, async_track_time_interval
from homeassistant.util.async_ import create_eager_task
from .const import (
DISCOVERY_ATTEMPTS,
@ -33,6 +35,12 @@ from .const import (
_LOGGER = logging.getLogger(__name__)
@callback
def _set_future_if_not_done(future: asyncio.Future[None]) -> None:
if not future.done():
future.set_result(None)
class YeelightScanner:
"""Scan for Yeelight devices."""
@ -54,26 +62,18 @@ class YeelightScanner:
self._host_capabilities: dict[str, CaseInsensitiveDict] = {}
self._track_interval: CALLBACK_TYPE | None = None
self._listeners: list[SsdpSearchListener] = []
self._connected_events: list[asyncio.Event] = []
self._setup_future: asyncio.Future[None] | None = None
async def async_setup(self) -> None:
"""Set up the scanner."""
if self._connected_events:
await self._async_wait_connected()
return
for idx, source_ip in enumerate(await self._async_build_source_set()):
self._connected_events.append(asyncio.Event())
def _wrap_async_connected_idx(idx) -> Callable[[], None]:
"""Create a function to capture the idx cell variable."""
@callback
def _async_connected() -> None:
self._connected_events[idx].set()
return _async_connected
if self._setup_future is not None:
return await self._setup_future
self._setup_future = self._hass.loop.create_future()
connected_futures: list[asyncio.Future[None]] = []
for source_ip in await self._async_build_source_set():
future = self._hass.loop.create_future()
connected_futures.append(future)
source = (str(source_ip), 0)
self._listeners.append(
SsdpSearchListener(
@ -81,12 +81,15 @@ class YeelightScanner:
search_target=SSDP_ST,
target=SSDP_TARGET,
source=source,
connect_callback=_wrap_async_connected_idx(idx),
connect_callback=partial(_set_future_if_not_done, future),
)
)
results = await asyncio.gather(
*(listener.async_start() for listener in self._listeners),
*(
create_eager_task(listener.async_start())
for listener in self._listeners
),
return_exceptions=True,
)
failed_listeners = []
@ -99,20 +102,17 @@ class YeelightScanner:
result,
)
failed_listeners.append(self._listeners[idx])
self._connected_events[idx].set()
_set_future_if_not_done(connected_futures[idx])
for listener in failed_listeners:
self._listeners.remove(listener)
await self._async_wait_connected()
await asyncio.wait(connected_futures)
self._track_interval = async_track_time_interval(
self._hass, self.async_scan, DISCOVERY_INTERVAL, cancel_on_shutdown=True
)
self.async_scan()
async def _async_wait_connected(self):
"""Wait for the listeners to be up and connected."""
await asyncio.gather(*(event.wait() for event in self._connected_events))
_set_future_if_not_done(self._setup_future)
async def _async_build_source_set(self) -> set[IPv4Address]:
"""Build the list of ssdp sources."""

View File

@ -1413,6 +1413,7 @@ async def test_effects(hass: HomeAssistant) -> None:
}
},
)
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA)
config_entry.add_to_hass(hass)