Strict typing for dhcp (#67361)
parent
21ce441a97
commit
076fe97110
|
@ -68,6 +68,7 @@ homeassistant.components.device_automation.*
|
|||
homeassistant.components.device_tracker.*
|
||||
homeassistant.components.devolo_home_control.*
|
||||
homeassistant.components.devolo_home_network.*
|
||||
homeassistant.components.dhcp.*
|
||||
homeassistant.components.dlna_dmr.*
|
||||
homeassistant.components.dnsip.*
|
||||
homeassistant.components.dsmr.*
|
||||
|
|
|
@ -2,6 +2,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Callable, Iterable
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
import fnmatch
|
||||
|
@ -9,7 +12,7 @@ from ipaddress import ip_address as make_ip_address
|
|||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Final
|
||||
from typing import TYPE_CHECKING, Any, Final, cast
|
||||
|
||||
from aiodiscover import DiscoverHosts
|
||||
from aiodiscover.discovery import (
|
||||
|
@ -51,12 +54,16 @@ from homeassistant.helpers.event import (
|
|||
)
|
||||
from homeassistant.helpers.frame import report
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import async_get_dhcp
|
||||
from homeassistant.loader import DHCPMatcher, async_get_dhcp
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
from homeassistant.util.network import is_invalid, is_link_local, is_loopback
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from scapy.packet import Packet
|
||||
from scapy.sendrecv import AsyncSniffer
|
||||
|
||||
FILTER = "udp and (port 67 or 68)"
|
||||
REQUESTED_ADDR = "requested_addr"
|
||||
MESSAGE_TYPE = "message-type"
|
||||
|
@ -115,7 +122,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
watchers: list[WatcherBase] = []
|
||||
address_data: dict[str, dict[str, str]] = {}
|
||||
integration_matchers = await async_get_dhcp(hass)
|
||||
|
||||
# For the passive classes we need to start listening
|
||||
# for state changes and connect the dispatchers before
|
||||
# everything else starts up or we will miss events
|
||||
|
@ -124,13 +130,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
await passive_watcher.async_start()
|
||||
watchers.append(passive_watcher)
|
||||
|
||||
async def _initialize(_):
|
||||
async def _initialize(event: Event) -> None:
|
||||
for active_cls in (DHCPWatcher, NetworkWatcher):
|
||||
active_watcher = active_cls(hass, address_data, integration_matchers)
|
||||
await active_watcher.async_start()
|
||||
watchers.append(active_watcher)
|
||||
|
||||
async def _async_stop(*_):
|
||||
async def _async_stop(event: Event) -> None:
|
||||
for watcher in watchers:
|
||||
await watcher.async_stop()
|
||||
|
||||
|
@ -143,7 +149,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
class WatcherBase:
|
||||
"""Base class for dhcp and device tracker watching."""
|
||||
|
||||
def __init__(self, hass, address_data, integration_matchers):
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
address_data: dict[str, dict[str, str]],
|
||||
integration_matchers: list[DHCPMatcher],
|
||||
) -> None:
|
||||
"""Initialize class."""
|
||||
super().__init__()
|
||||
|
||||
|
@ -152,11 +163,11 @@ class WatcherBase:
|
|||
self._address_data = address_data
|
||||
|
||||
@abstractmethod
|
||||
async def async_stop(self):
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop the watcher."""
|
||||
|
||||
@abstractmethod
|
||||
async def async_start(self):
|
||||
async def async_start(self) -> None:
|
||||
"""Start the watcher."""
|
||||
|
||||
def process_client(self, ip_address: str, hostname: str, mac_address: str) -> None:
|
||||
|
@ -197,8 +208,8 @@ class WatcherBase:
|
|||
data = {MAC_ADDRESS: mac_address, HOSTNAME: hostname}
|
||||
self._address_data[ip_address] = data
|
||||
|
||||
lowercase_hostname = data[HOSTNAME].lower()
|
||||
uppercase_mac = data[MAC_ADDRESS].upper()
|
||||
lowercase_hostname = hostname.lower()
|
||||
uppercase_mac = mac_address.upper()
|
||||
|
||||
_LOGGER.debug(
|
||||
"Processing updated address data for %s: mac=%s hostname=%s",
|
||||
|
@ -218,22 +229,24 @@ class WatcherBase:
|
|||
if entry := self.hass.config_entries.async_get_entry(entry_id):
|
||||
device_domains.add(entry.domain)
|
||||
|
||||
for entry in self._integration_matchers:
|
||||
if entry.get(REGISTERED_DEVICES) and not entry["domain"] in device_domains:
|
||||
for matcher in self._integration_matchers:
|
||||
domain = matcher["domain"]
|
||||
|
||||
if matcher.get(REGISTERED_DEVICES) and domain not in device_domains:
|
||||
continue
|
||||
|
||||
if MAC_ADDRESS in entry and not fnmatch.fnmatch(
|
||||
uppercase_mac, entry[MAC_ADDRESS]
|
||||
):
|
||||
if (
|
||||
matcher_mac := matcher.get(MAC_ADDRESS)
|
||||
) is not None and not fnmatch.fnmatch(uppercase_mac, matcher_mac):
|
||||
continue
|
||||
|
||||
if HOSTNAME in entry and not fnmatch.fnmatch(
|
||||
lowercase_hostname, entry[HOSTNAME]
|
||||
):
|
||||
if (
|
||||
matcher_hostname := matcher.get(HOSTNAME)
|
||||
) is not None and not fnmatch.fnmatch(lowercase_hostname, matcher_hostname):
|
||||
continue
|
||||
|
||||
_LOGGER.debug("Matched %s against %s", data, entry)
|
||||
matched_domains.add(entry["domain"])
|
||||
_LOGGER.debug("Matched %s against %s", data, matcher)
|
||||
matched_domains.add(domain)
|
||||
|
||||
for domain in matched_domains:
|
||||
discovery_flow.async_create_flow(
|
||||
|
@ -243,7 +256,7 @@ class WatcherBase:
|
|||
DhcpServiceInfo(
|
||||
ip=ip_address,
|
||||
hostname=lowercase_hostname,
|
||||
macaddress=data[MAC_ADDRESS],
|
||||
macaddress=mac_address,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -251,14 +264,19 @@ class WatcherBase:
|
|||
class NetworkWatcher(WatcherBase):
|
||||
"""Class to query ptr records routers."""
|
||||
|
||||
def __init__(self, hass, address_data, integration_matchers):
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
address_data: dict[str, dict[str, str]],
|
||||
integration_matchers: list[DHCPMatcher],
|
||||
) -> None:
|
||||
"""Initialize class."""
|
||||
super().__init__(hass, address_data, integration_matchers)
|
||||
self._unsub = None
|
||||
self._discover_hosts = None
|
||||
self._discover_task = None
|
||||
self._unsub: Callable[[], None] | None = None
|
||||
self._discover_hosts: DiscoverHosts | None = None
|
||||
self._discover_task: asyncio.Task | None = None
|
||||
|
||||
async def async_stop(self):
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop scanning for new devices on the network."""
|
||||
if self._unsub:
|
||||
self._unsub()
|
||||
|
@ -267,7 +285,7 @@ class NetworkWatcher(WatcherBase):
|
|||
self._discover_task.cancel()
|
||||
self._discover_task = None
|
||||
|
||||
async def async_start(self):
|
||||
async def async_start(self) -> None:
|
||||
"""Start scanning for new devices on the network."""
|
||||
self._discover_hosts = DiscoverHosts()
|
||||
self._unsub = async_track_time_interval(
|
||||
|
@ -276,14 +294,15 @@ class NetworkWatcher(WatcherBase):
|
|||
self.async_start_discover()
|
||||
|
||||
@callback
|
||||
def async_start_discover(self, *_):
|
||||
def async_start_discover(self, *_: Any) -> None:
|
||||
"""Start a new discovery task if one is not running."""
|
||||
if self._discover_task and not self._discover_task.done():
|
||||
return
|
||||
self._discover_task = self.hass.async_create_task(self.async_discover())
|
||||
|
||||
async def async_discover(self):
|
||||
async def async_discover(self) -> None:
|
||||
"""Process discovery."""
|
||||
assert self._discover_hosts is not None
|
||||
for host in await self._discover_hosts.async_discover():
|
||||
self.async_process_client(
|
||||
host[DISCOVERY_IP_ADDRESS],
|
||||
|
@ -295,18 +314,23 @@ class NetworkWatcher(WatcherBase):
|
|||
class DeviceTrackerWatcher(WatcherBase):
|
||||
"""Class to watch dhcp data from routers."""
|
||||
|
||||
def __init__(self, hass, address_data, integration_matchers):
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
address_data: dict[str, dict[str, str]],
|
||||
integration_matchers: list[DHCPMatcher],
|
||||
) -> None:
|
||||
"""Initialize class."""
|
||||
super().__init__(hass, address_data, integration_matchers)
|
||||
self._unsub = None
|
||||
self._unsub: Callable[[], None] | None = None
|
||||
|
||||
async def async_stop(self):
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop watching for new device trackers."""
|
||||
if self._unsub:
|
||||
self._unsub()
|
||||
self._unsub = None
|
||||
|
||||
async def async_start(self):
|
||||
async def async_start(self) -> None:
|
||||
"""Stop watching for new device trackers."""
|
||||
self._unsub = async_track_state_added_domain(
|
||||
self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event
|
||||
|
@ -315,12 +339,12 @@ class DeviceTrackerWatcher(WatcherBase):
|
|||
self._async_process_device_state(state)
|
||||
|
||||
@callback
|
||||
def _async_process_device_event(self, event: Event):
|
||||
def _async_process_device_event(self, event: Event) -> None:
|
||||
"""Process a device tracker state change event."""
|
||||
self._async_process_device_state(event.data["new_state"])
|
||||
|
||||
@callback
|
||||
def _async_process_device_state(self, state: State):
|
||||
def _async_process_device_state(self, state: State) -> None:
|
||||
"""Process a device tracker state."""
|
||||
if state.state != STATE_HOME:
|
||||
return
|
||||
|
@ -343,18 +367,23 @@ class DeviceTrackerWatcher(WatcherBase):
|
|||
class DeviceTrackerRegisteredWatcher(WatcherBase):
|
||||
"""Class to watch data from device tracker registrations."""
|
||||
|
||||
def __init__(self, hass, address_data, integration_matchers):
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
address_data: dict[str, dict[str, str]],
|
||||
integration_matchers: list[DHCPMatcher],
|
||||
) -> None:
|
||||
"""Initialize class."""
|
||||
super().__init__(hass, address_data, integration_matchers)
|
||||
self._unsub = None
|
||||
self._unsub: Callable[[], None] | None = None
|
||||
|
||||
async def async_stop(self):
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop watching for device tracker registrations."""
|
||||
if self._unsub:
|
||||
self._unsub()
|
||||
self._unsub = None
|
||||
|
||||
async def async_start(self):
|
||||
async def async_start(self) -> None:
|
||||
"""Stop watching for device tracker registrations."""
|
||||
self._unsub = async_dispatcher_connect(
|
||||
self.hass, CONNECTED_DEVICE_REGISTERED, self._async_process_device_data
|
||||
|
@ -376,26 +405,32 @@ class DeviceTrackerRegisteredWatcher(WatcherBase):
|
|||
class DHCPWatcher(WatcherBase):
|
||||
"""Class to watch dhcp requests."""
|
||||
|
||||
def __init__(self, hass, address_data, integration_matchers):
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
address_data: dict[str, dict[str, str]],
|
||||
integration_matchers: list[DHCPMatcher],
|
||||
) -> None:
|
||||
"""Initialize class."""
|
||||
super().__init__(hass, address_data, integration_matchers)
|
||||
self._sniffer = None
|
||||
self._sniffer: AsyncSniffer | None = None
|
||||
self._started = threading.Event()
|
||||
|
||||
async def async_stop(self):
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop watching for new device trackers."""
|
||||
await self.hass.async_add_executor_job(self._stop)
|
||||
|
||||
def _stop(self):
|
||||
def _stop(self) -> None:
|
||||
"""Stop the thread."""
|
||||
if self._started.is_set():
|
||||
assert self._sniffer is not None
|
||||
self._sniffer.stop()
|
||||
|
||||
async def async_start(self):
|
||||
async def async_start(self) -> None:
|
||||
"""Start watching for dhcp packets."""
|
||||
await self.hass.async_add_executor_job(self._start)
|
||||
|
||||
def _start(self):
|
||||
def _start(self) -> None:
|
||||
"""Start watching for dhcp packets."""
|
||||
# Local import because importing from scapy has side effects such as opening
|
||||
# sockets
|
||||
|
@ -417,20 +452,25 @@ class DHCPWatcher(WatcherBase):
|
|||
AsyncSniffer,
|
||||
)
|
||||
|
||||
def _handle_dhcp_packet(packet):
|
||||
def _handle_dhcp_packet(packet: Packet) -> None:
|
||||
"""Process a dhcp packet."""
|
||||
if DHCP not in packet:
|
||||
return
|
||||
|
||||
options = packet[DHCP].options
|
||||
request_type = _decode_dhcp_option(options, MESSAGE_TYPE)
|
||||
if request_type != DHCP_REQUEST:
|
||||
options_dict = _dhcp_options_as_dict(packet[DHCP].options)
|
||||
if options_dict.get(MESSAGE_TYPE) != DHCP_REQUEST:
|
||||
# Not a DHCP request
|
||||
return
|
||||
|
||||
ip_address = _decode_dhcp_option(options, REQUESTED_ADDR) or packet[IP].src
|
||||
hostname = _decode_dhcp_option(options, HOSTNAME) or ""
|
||||
mac_address = _format_mac(packet[Ether].src)
|
||||
ip_address = options_dict.get(REQUESTED_ADDR) or cast(str, packet[IP].src)
|
||||
assert isinstance(ip_address, str)
|
||||
hostname = ""
|
||||
if (hostname_bytes := options_dict.get(HOSTNAME)) and isinstance(
|
||||
hostname_bytes, bytes
|
||||
):
|
||||
with contextlib.suppress(AttributeError, UnicodeDecodeError):
|
||||
hostname = hostname_bytes.decode()
|
||||
mac_address = _format_mac(cast(str, packet[Ether].src))
|
||||
|
||||
if ip_address is not None and mac_address is not None:
|
||||
self.process_client(ip_address, hostname, mac_address)
|
||||
|
@ -470,29 +510,19 @@ class DHCPWatcher(WatcherBase):
|
|||
self._sniffer.thread.name = self.__class__.__name__
|
||||
|
||||
|
||||
def _decode_dhcp_option(dhcp_options, key):
|
||||
"""Extract and decode data from a packet option."""
|
||||
for option in dhcp_options:
|
||||
if len(option) < 2 or option[0] != key:
|
||||
continue
|
||||
|
||||
value = option[1]
|
||||
if value is None or key != HOSTNAME:
|
||||
return value
|
||||
|
||||
# hostname is unicode
|
||||
try:
|
||||
return value.decode()
|
||||
except (AttributeError, UnicodeDecodeError):
|
||||
return None
|
||||
def _dhcp_options_as_dict(
|
||||
dhcp_options: Iterable[tuple[str, int | bytes | None]]
|
||||
) -> dict[str, str | int | bytes | None]:
|
||||
"""Extract data from packet options as a dict."""
|
||||
return {option[0]: option[1] for option in dhcp_options if len(option) >= 2}
|
||||
|
||||
|
||||
def _format_mac(mac_address):
|
||||
def _format_mac(mac_address: str) -> str:
|
||||
"""Format a mac address for matching."""
|
||||
return format_mac(mac_address).replace(":", "")
|
||||
|
||||
|
||||
def _verify_l2socket_setup(cap_filter):
|
||||
def _verify_l2socket_setup(cap_filter: str) -> None:
|
||||
"""Create a socket using the scapy configured l2socket.
|
||||
|
||||
Try to create the socket
|
||||
|
@ -504,7 +534,7 @@ def _verify_l2socket_setup(cap_filter):
|
|||
conf.L2socket(filter=cap_filter)
|
||||
|
||||
|
||||
def _verify_working_pcap(cap_filter):
|
||||
def _verify_working_pcap(cap_filter: str) -> None:
|
||||
"""Verify we can create a packet filter.
|
||||
|
||||
If we cannot create a filter we will be listening for
|
||||
|
|
|
@ -60,6 +60,24 @@ MAX_LOAD_CONCURRENTLY = 4
|
|||
MOVED_ZEROCONF_PROPS = ("macaddress", "model", "manufacturer")
|
||||
|
||||
|
||||
class DHCPMatcherRequired(TypedDict, total=True):
|
||||
"""Matcher for the dhcp integration for required fields."""
|
||||
|
||||
domain: str
|
||||
|
||||
|
||||
class DHCPMatcherOptional(TypedDict, total=False):
|
||||
"""Matcher for the dhcp integration for optional fields."""
|
||||
|
||||
macaddress: str
|
||||
hostname: str
|
||||
registered_devices: bool
|
||||
|
||||
|
||||
class DHCPMatcher(DHCPMatcherRequired, DHCPMatcherOptional):
|
||||
"""Matcher for the dhcp integration."""
|
||||
|
||||
|
||||
class Manifest(TypedDict, total=False):
|
||||
"""
|
||||
Integration manifest.
|
||||
|
@ -228,16 +246,16 @@ async def async_get_zeroconf(
|
|||
return zeroconf
|
||||
|
||||
|
||||
async def async_get_dhcp(hass: HomeAssistant) -> list[dict[str, str | bool]]:
|
||||
async def async_get_dhcp(hass: HomeAssistant) -> list[DHCPMatcher]:
|
||||
"""Return cached list of dhcp types."""
|
||||
dhcp: list[dict[str, str | bool]] = DHCP.copy()
|
||||
dhcp = cast(list[DHCPMatcher], DHCP.copy())
|
||||
|
||||
integrations = await async_get_custom_components(hass)
|
||||
for integration in integrations.values():
|
||||
if not integration.dhcp:
|
||||
continue
|
||||
for entry in integration.dhcp:
|
||||
dhcp.append({"domain": integration.domain, **entry})
|
||||
dhcp.append(cast(DHCPMatcher, {"domain": integration.domain, **entry}))
|
||||
|
||||
return dhcp
|
||||
|
||||
|
|
11
mypy.ini
11
mypy.ini
|
@ -549,6 +549,17 @@ no_implicit_optional = true
|
|||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.dhcp.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.dlna_dmr.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
|
|
Loading…
Reference in New Issue