Refactor yeelight scanner to avoid creating tasks to wait for scanner start (#113919)
parent
13d6ebaabf
commit
2421b42f10
|
@ -3,9 +3,10 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable, ValuesView
|
from collections.abc import ValuesView
|
||||||
import contextlib
|
import contextlib
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from functools import partial
|
||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
import logging
|
import logging
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
@ -19,6 +20,7 @@ from homeassistant.components import network, ssdp
|
||||||
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
||||||
from homeassistant.helpers import discovery_flow
|
from homeassistant.helpers import discovery_flow
|
||||||
from homeassistant.helpers.event import async_call_later, async_track_time_interval
|
from homeassistant.helpers.event import async_call_later, async_track_time_interval
|
||||||
|
from homeassistant.util.async_ import create_eager_task
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
DISCOVERY_ATTEMPTS,
|
DISCOVERY_ATTEMPTS,
|
||||||
|
@ -33,6 +35,12 @@ from .const import (
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _set_future_if_not_done(future: asyncio.Future[None]) -> None:
|
||||||
|
if not future.done():
|
||||||
|
future.set_result(None)
|
||||||
|
|
||||||
|
|
||||||
class YeelightScanner:
|
class YeelightScanner:
|
||||||
"""Scan for Yeelight devices."""
|
"""Scan for Yeelight devices."""
|
||||||
|
|
||||||
|
@ -54,26 +62,18 @@ class YeelightScanner:
|
||||||
self._host_capabilities: dict[str, CaseInsensitiveDict] = {}
|
self._host_capabilities: dict[str, CaseInsensitiveDict] = {}
|
||||||
self._track_interval: CALLBACK_TYPE | None = None
|
self._track_interval: CALLBACK_TYPE | None = None
|
||||||
self._listeners: list[SsdpSearchListener] = []
|
self._listeners: list[SsdpSearchListener] = []
|
||||||
self._connected_events: list[asyncio.Event] = []
|
self._setup_future: asyncio.Future[None] | None = None
|
||||||
|
|
||||||
async def async_setup(self) -> None:
|
async def async_setup(self) -> None:
|
||||||
"""Set up the scanner."""
|
"""Set up the scanner."""
|
||||||
if self._connected_events:
|
if self._setup_future is not None:
|
||||||
await self._async_wait_connected()
|
return await self._setup_future
|
||||||
return
|
|
||||||
|
|
||||||
for idx, source_ip in enumerate(await self._async_build_source_set()):
|
|
||||||
self._connected_events.append(asyncio.Event())
|
|
||||||
|
|
||||||
def _wrap_async_connected_idx(idx) -> Callable[[], None]:
|
|
||||||
"""Create a function to capture the idx cell variable."""
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def _async_connected() -> None:
|
|
||||||
self._connected_events[idx].set()
|
|
||||||
|
|
||||||
return _async_connected
|
|
||||||
|
|
||||||
|
self._setup_future = self._hass.loop.create_future()
|
||||||
|
connected_futures: list[asyncio.Future[None]] = []
|
||||||
|
for source_ip in await self._async_build_source_set():
|
||||||
|
future = self._hass.loop.create_future()
|
||||||
|
connected_futures.append(future)
|
||||||
source = (str(source_ip), 0)
|
source = (str(source_ip), 0)
|
||||||
self._listeners.append(
|
self._listeners.append(
|
||||||
SsdpSearchListener(
|
SsdpSearchListener(
|
||||||
|
@ -81,12 +81,15 @@ class YeelightScanner:
|
||||||
search_target=SSDP_ST,
|
search_target=SSDP_ST,
|
||||||
target=SSDP_TARGET,
|
target=SSDP_TARGET,
|
||||||
source=source,
|
source=source,
|
||||||
connect_callback=_wrap_async_connected_idx(idx),
|
connect_callback=partial(_set_future_if_not_done, future),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*(listener.async_start() for listener in self._listeners),
|
*(
|
||||||
|
create_eager_task(listener.async_start())
|
||||||
|
for listener in self._listeners
|
||||||
|
),
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
failed_listeners = []
|
failed_listeners = []
|
||||||
|
@ -99,20 +102,17 @@ class YeelightScanner:
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
failed_listeners.append(self._listeners[idx])
|
failed_listeners.append(self._listeners[idx])
|
||||||
self._connected_events[idx].set()
|
_set_future_if_not_done(connected_futures[idx])
|
||||||
|
|
||||||
for listener in failed_listeners:
|
for listener in failed_listeners:
|
||||||
self._listeners.remove(listener)
|
self._listeners.remove(listener)
|
||||||
|
|
||||||
await self._async_wait_connected()
|
await asyncio.wait(connected_futures)
|
||||||
self._track_interval = async_track_time_interval(
|
self._track_interval = async_track_time_interval(
|
||||||
self._hass, self.async_scan, DISCOVERY_INTERVAL, cancel_on_shutdown=True
|
self._hass, self.async_scan, DISCOVERY_INTERVAL, cancel_on_shutdown=True
|
||||||
)
|
)
|
||||||
self.async_scan()
|
self.async_scan()
|
||||||
|
_set_future_if_not_done(self._setup_future)
|
||||||
async def _async_wait_connected(self):
|
|
||||||
"""Wait for the listeners to be up and connected."""
|
|
||||||
await asyncio.gather(*(event.wait() for event in self._connected_events))
|
|
||||||
|
|
||||||
async def _async_build_source_set(self) -> set[IPv4Address]:
|
async def _async_build_source_set(self) -> set[IPv4Address]:
|
||||||
"""Build the list of ssdp sources."""
|
"""Build the list of ssdp sources."""
|
||||||
|
|
|
@ -1413,6 +1413,7 @@ async def test_effects(hass: HomeAssistant) -> None:
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
config_entry = MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA)
|
config_entry = MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA)
|
||||||
config_entry.add_to_hass(hass)
|
config_entry.add_to_hass(hass)
|
||||||
|
|
Loading…
Reference in New Issue