# core/homeassistant/components/yeelight/scanner.py
#
# 201 lines
# 7.1 KiB
# Python

"""Support for Xiaomi Yeelight WiFi color bulb."""
from __future__ import annotations
import asyncio
from collections.abc import Callable, ValuesView
import contextlib
from datetime import datetime
from ipaddress import IPv4Address
import logging
from urllib.parse import urlparse
import async_timeout
from async_upnp_client.search import SsdpSearchListener
from async_upnp_client.utils import CaseInsensitiveDict
from typing_extensions import Self
from homeassistant import config_entries
from homeassistant.components import network, ssdp
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers import discovery_flow
from homeassistant.helpers.event import async_call_later, async_track_time_interval
from .const import (
DISCOVERY_ATTEMPTS,
DISCOVERY_INTERVAL,
DISCOVERY_SEARCH_INTERVAL,
DISCOVERY_TIMEOUT,
DOMAIN,
SSDP_ST,
SSDP_TARGET,
)
_LOGGER = logging.getLogger(__name__)
class YeelightScanner:
    """Scan for Yeelight devices.

    One shared instance (see ``async_get``) owns an SSDP search listener per
    source address and caches the latest response headers it has seen, keyed
    both by bulb id and by host.
    """

    # Shared instance created lazily by ``async_get``.
    _scanner: Self | None = None

    @classmethod
    @callback
    def async_get(cls, hass: HomeAssistant) -> YeelightScanner:
        """Get scanner instance, creating it on first use."""
        if cls._scanner is None:
            cls._scanner = cls(hass)
        return cls._scanner

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize class."""
        self._hass = hass
        # host -> events of callers currently waiting for that host to answer.
        self._host_discovered_events: dict[str, list[asyncio.Event]] = {}
        # Latest response headers, keyed by bulb id and by host respectively.
        self._unique_id_capabilities: dict[str, CaseInsensitiveDict] = {}
        self._host_capabilities: dict[str, CaseInsensitiveDict] = {}
        # Cancel handle of the periodic scan, once scheduled.
        self._track_interval: CALLBACK_TYPE | None = None
        self._listeners: list[SsdpSearchListener] = []
        # One event per created listener; set on connect or on start failure.
        self._connected_events: list[asyncio.Event] = []

    async def async_setup(self) -> None:
        """Set up the scanner.

        Creates and starts one listener per source address, waits until each
        has either connected or failed to start, then schedules a scan every
        DISCOVERY_INTERVAL and sends an initial one. If listeners were already
        created by an earlier call, this only waits for them to be connected.
        """
        if self._connected_events:
            await self._async_wait_connected()
            return

        def _wrap_async_connected_idx(idx: int) -> Callable[[], None]:
            """Create a function to capture the idx cell variable."""

            @callback
            def _async_connected() -> None:
                self._connected_events[idx].set()

            return _async_connected

        for idx, source_ip in enumerate(await self._async_build_source_set()):
            self._connected_events.append(asyncio.Event())
            source = (str(source_ip), 0)
            self._listeners.append(
                SsdpSearchListener(
                    callback=self._async_process_entry,
                    search_target=SSDP_ST,
                    target=SSDP_TARGET,
                    source=source,
                    connect_callback=_wrap_async_connected_idx(idx),
                )
            )

        results = await asyncio.gather(
            *(listener.async_start() for listener in self._listeners),
            return_exceptions=True,
        )
        failed_listeners = []
        for idx, result in enumerate(results):
            if not isinstance(result, Exception):
                continue
            _LOGGER.warning(
                "Failed to setup listener for %s: %s",
                self._listeners[idx].source,
                result,
            )
            failed_listeners.append(self._listeners[idx])
            # A listener that failed to start will never report a connection;
            # set its event so waiters are not blocked on it.
            self._connected_events[idx].set()

        # Removal happens after the loop so the indexes used above stay
        # aligned with ``results`` and ``_connected_events``.
        for listener in failed_listeners:
            self._listeners.remove(listener)

        await self._async_wait_connected()

        # When no source address was available, ``_connected_events`` is still
        # empty and the guard at the top will let a later call reach this
        # point again. Cancel the previous periodic scan instead of
        # overwriting its handle and leaving it running unreferenced.
        if self._track_interval is not None:
            self._track_interval()
        self._track_interval = async_track_time_interval(
            self._hass, self.async_scan, DISCOVERY_INTERVAL
        )
        self.async_scan()

    async def _async_wait_connected(self) -> None:
        """Wait for the listeners to be up and connected."""
        await asyncio.gather(*(event.wait() for event in self._connected_events))

    async def _async_build_source_set(self) -> set[IPv4Address]:
        """Build the list of ssdp sources.

        Returns only ``0.0.0.0`` when just the default interface is enabled;
        otherwise every enabled, non-loopback IPv4 source address.
        """
        adapters = await network.async_get_adapters(self._hass)
        if network.async_only_default_interface_enabled(adapters):
            return {IPv4Address("0.0.0.0")}

        return {
            source_ip
            for source_ip in await network.async_get_enabled_source_ips(self._hass)
            if isinstance(source_ip, IPv4Address) and not source_ip.is_loopback
        }

    async def async_discover(self) -> ValuesView[CaseInsensitiveDict]:
        """Discover bulbs.

        Sends DISCOVERY_ATTEMPTS scans, sleeping DISCOVERY_SEARCH_INTERVAL
        after each, and returns a view over the capabilities cached by id.
        """
        _LOGGER.debug("Yeelight discover with interval %s", DISCOVERY_SEARCH_INTERVAL)
        await self.async_setup()
        for _ in range(DISCOVERY_ATTEMPTS):
            self.async_scan()
            await asyncio.sleep(DISCOVERY_SEARCH_INTERVAL.total_seconds())
        return self._unique_id_capabilities.values()

    @callback
    def async_scan(self, _: datetime | None = None) -> None:
        """Send discovery packets.

        The positional argument is ignored; it lets this method be used
        directly as the time-interval callback.
        """
        _LOGGER.debug("Yeelight scanning")
        for listener in self._listeners:
            listener.async_search()

    async def async_get_capabilities(self, host: str) -> CaseInsensitiveDict | None:
        """Get capabilities via SSDP.

        Returns the cached headers for ``host`` when available. Otherwise
        searches ``host`` directly and waits up to DISCOVERY_TIMEOUT for a
        response; returns ``None`` if none was recorded in that time.
        """
        if host in self._host_capabilities:
            return self._host_capabilities[host]

        host_event = asyncio.Event()
        self._host_discovered_events.setdefault(host, []).append(host_event)
        try:
            await self.async_setup()
            for listener in self._listeners:
                listener.async_search((host, SSDP_TARGET[1]))
            with contextlib.suppress(asyncio.TimeoutError):
                async with async_timeout.timeout(DISCOVERY_TIMEOUT):
                    await host_event.wait()
        finally:
            # Always deregister the waiter, including when setup raises or
            # this task is cancelled, and drop the per-host list once empty so
            # entries do not accumulate for hosts nobody is waiting on.
            waiters = self._host_discovered_events[host]
            waiters.remove(host_event)
            if not waiters:
                del self._host_discovered_events[host]
        return self._host_capabilities.get(host)

    def _async_discovered_by_ssdp(self, response: CaseInsensitiveDict) -> None:
        """Schedule an SSDP-sourced discovery flow for a response."""

        @callback
        def _async_start_flow(*_) -> None:
            discovery_flow.async_create_flow(
                self._hass,
                DOMAIN,
                context={"source": config_entries.SOURCE_SSDP},
                data=ssdp.SsdpServiceInfo(
                    ssdp_usn="",
                    ssdp_st=SSDP_ST,
                    ssdp_headers=response,
                    upnp={},
                ),
            )

        # Delay starting the flow in case the discovery is the result
        # of another discovery
        async_call_later(self._hass, 1, _async_start_flow)

    @callback
    def _async_process_entry(self, headers: CaseInsensitiveDict) -> None:
        """Process a discovery.

        Caches the headers by id and by host, starts a discovery flow when
        the id is new or its host changed, and wakes callers waiting on that
        host. Responses without an id or a usable location are ignored.
        """
        _LOGGER.debug("Discovered via SSDP: %s", headers)
        # The headers come from the network, so validate them explicitly
        # rather than relying on KeyError or an assert that -O strips.
        unique_id = headers.get("id")
        location = headers.get("location")
        try:
            host = urlparse(location).hostname if location else None
        except ValueError:
            # urlparse rejects some malformed locations.
            host = None
        if unique_id is None or not host:
            _LOGGER.debug("Ignoring SSDP response without id or host: %s", headers)
            return

        current_entry = self._unique_id_capabilities.get(unique_id)
        # Make sure we handle ip changes
        if not current_entry or host != urlparse(current_entry["location"]).hostname:
            _LOGGER.debug("Yeelight discovered with %s", headers)
            self._async_discovered_by_ssdp(headers)
        self._host_capabilities[host] = headers
        self._unique_id_capabilities[unique_id] = headers
        for event in self._host_discovered_events.get(host, []):
            event.set()