core/homeassistant/components/dhcp/__init__.py

269 lines
8.1 KiB
Python
Raw Normal View History

"""The dhcp integration."""
from abc import abstractmethod
import fnmatch
import logging
import os
import threading
from scapy.error import Scapy_Exception
from scapy.layers.dhcp import DHCP
from scapy.layers.l2 import Ether
from scapy.sendrecv import sniff
from homeassistant.components.device_tracker.const import (
ATTR_HOST_NAME,
ATTR_IP,
ATTR_MAC,
ATTR_SOURCE_TYPE,
DOMAIN as DEVICE_TRACKER_DOMAIN,
SOURCE_TYPE_ROUTER,
)
from homeassistant.const import (
EVENT_HOMEASSISTANT_STARTED,
EVENT_HOMEASSISTANT_STOP,
STATE_HOME,
)
from homeassistant.core import Event, HomeAssistant, State, callback
from homeassistant.helpers.device_registry import format_mac
from homeassistant.helpers.event import async_track_state_added_domain
from homeassistant.loader import async_get_dhcp
from .const import DOMAIN
# Capture filter passed to scapy's sniff(): UDP traffic on port 67 or 68
# (the DHCP ports).
FILTER = "udp and (port 67 or 68)"
# DHCP option names looked up in captured packets via _decode_dhcp_option.
REQUESTED_ADDR = "requested_addr"
MESSAGE_TYPE = "message-type"
# HOSTNAME doubles as a DHCP option name, a key in the per-IP address data,
# and a pattern key in the integration matchers.
HOSTNAME = "hostname"
# Key for the MAC in the per-IP address data, the discovery flow data, and
# the integration matchers.
MAC_ADDRESS = "macaddress"
# Key under which the client IP is added to the discovery flow data.
IP_ADDRESS = "ip"
# message-type option value for a DHCP request; other message types are ignored.
DHCP_REQUEST = 3
_LOGGER = logging.getLogger(__name__)
async def async_setup(hass: HomeAssistant, config: dict) -> bool:
    """Set up the dhcp component.

    No watcher is created here. They are built and started once Home Assistant
    fires EVENT_HOMEASSISTANT_STARTED, and stopped on EVENT_HOMEASSISTANT_STOP.
    """

    async def _initialize(_):
        # Both watchers share one address cache, so a client already processed
        # with identical data is not processed again when the other source
        # reports it.
        shared_address_data = {}
        matchers = await async_get_dhcp(hass)
        running = []
        for watcher_cls in (DHCPWatcher, DeviceTrackerWatcher):
            instance = watcher_cls(hass, shared_address_data, matchers)
            instance.async_start()
            running.append(instance)

        async def _async_stop(*_):
            for instance in running:
                if not hasattr(instance, "async_stop"):
                    # Thread-based watchers may block while stopping, so their
                    # stop() runs in the executor rather than on the loop.
                    await hass.async_add_executor_job(instance.stop)
                    continue
                instance.async_stop()

        hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop)

    hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, _initialize)
    return True
class WatcherBase:
    """Shared client-matching logic for the dhcp and device tracker watchers.

    Subclasses feed (ip, hostname, mac) observations into process_client and
    implement create_task to schedule coroutines from their own context.
    """

    def __init__(self, hass, address_data, integration_matchers):
        """Initialize class."""
        super().__init__()
        self.hass = hass
        self._integration_matchers = integration_matchers
        self._address_data = address_data

    def process_client(self, ip_address, hostname, mac_address):
        """Record a client seen at ip_address and match it if anything changed."""
        previous = self._address_data.get(ip_address)
        unchanged = (
            previous
            and previous[MAC_ADDRESS] == mac_address
            and previous[HOSTNAME] == hostname
        )
        if unchanged:
            # Same MAC and hostname as last time: nothing new to match.
            return
        current = {MAC_ADDRESS: mac_address, HOSTNAME: hostname}
        self._address_data[ip_address] = current
        self.process_updated_address_data(ip_address, current)

    @staticmethod
    def _matcher_applies(matcher, uppercase_mac, lowercase_hostname):
        """Return True when every pattern present in matcher fits the client."""
        if MAC_ADDRESS in matcher and not fnmatch.fnmatch(
            uppercase_mac, matcher[MAC_ADDRESS]
        ):
            return False
        return HOSTNAME not in matcher or fnmatch.fnmatch(
            lowercase_hostname, matcher[HOSTNAME]
        )

    def process_updated_address_data(self, ip_address, data):
        """Start a discovery config flow for each matcher that fits data."""
        # Matching compares an upper-cased MAC and a lower-cased hostname
        # against the matcher's fnmatch patterns.
        hostname_lc = data[HOSTNAME].lower()
        mac_uc = data[MAC_ADDRESS].upper()
        _LOGGER.debug(
            "Processing updated address data for %s: mac=%s hostname=%s",
            ip_address,
            mac_uc,
            hostname_lc,
        )
        for matcher in self._integration_matchers:
            if not self._matcher_applies(matcher, mac_uc, hostname_lc):
                continue
            _LOGGER.debug("Matched %s against %s", data, matcher)
            flow_data = {IP_ADDRESS: ip_address, **data}
            flow = self.hass.config_entries.flow.async_init(
                matcher["domain"], context={"source": DOMAIN}, data=flow_data
            )
            self.create_task(flow)

    @abstractmethod
    def create_task(self, task):
        """Pass a task to async_add_task based on which context we are in."""
class DeviceTrackerWatcher(WatcherBase):
    """Class to watch dhcp data from routers.

    Passes the IP, hostname and MAC of router-sourced device_tracker entities
    that are home into the shared matching logic.
    """

    def __init__(self, hass, address_data, integration_matchers):
        """Initialize class."""
        super().__init__(hass, address_data, integration_matchers)
        self._unsub = None

    @callback
    def async_stop(self):
        """Stop watching for new device trackers."""
        if self._unsub:
            self._unsub()
            self._unsub = None

    @callback
    def async_start(self):
        """Start watching for new device trackers.

        Subscribes to device_tracker entities being added, then processes the
        device_tracker states that already exist.
        """
        self._unsub = async_track_state_added_domain(
            self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event
        )
        for state in self.hass.states.async_all(DEVICE_TRACKER_DOMAIN):
            self._async_process_device_state(state)

    @callback
    def _async_process_device_event(self, event: Event):
        """Process a device tracker state change event."""
        new_state = event.data.get("new_state")
        # .get() can yield None; _async_process_device_state needs a State.
        if new_state is None:
            return
        self._async_process_device_state(new_state)

    @callback
    def _async_process_device_state(self, state: State):
        """Process a device tracker state.

        Only trackers that are home, whose source type is router, and that
        report an IP, hostname and MAC are passed on to process_client.
        """
        if state.state != STATE_HOME:
            return
        attributes = state.attributes
        if attributes.get(ATTR_SOURCE_TYPE) != SOURCE_TYPE_ROUTER:
            return
        ip_address = attributes.get(ATTR_IP)
        hostname = attributes.get(ATTR_HOST_NAME)
        mac_address = attributes.get(ATTR_MAC)
        if ip_address is None or hostname is None or mac_address is None:
            return
        self.process_client(ip_address, hostname, _format_mac(mac_address))

    def create_task(self, task):
        """Pass a task to async_create_task since we are in async context."""
        self.hass.async_create_task(task)
class DHCPWatcher(WatcherBase, threading.Thread):
    """Class to watch dhcp requests.

    Runs scapy's sniff() on its own thread and hands each DHCP request's
    requested IP, hostname and source MAC to process_client.
    """

    # Upper bound, in seconds, that stop() waits for the sniffing thread.
    # sniff() evaluates stop_filter only when a packet passing FILTER
    # arrives, so on a quiet network the thread may not notice the stop
    # request for a long time. Shutdown must not block on it indefinitely.
    _STOP_JOIN_TIMEOUT = 5

    def __init__(self, hass, address_data, integration_matchers):
        """Initialize class."""
        super().__init__(hass, address_data, integration_matchers)
        self.name = "dhcp-discovery"
        # Daemon, so a sniffer still waiting for a packet after stop() has
        # given up cannot keep the interpreter alive at exit.
        self.daemon = True
        self._stop_event = threading.Event()

    def stop(self):
        """Stop the thread.

        Signals the sniffer to stop and waits at most _STOP_JOIN_TIMEOUT
        seconds for the thread to finish. Does nothing beyond setting the
        stop flag if the thread is not running (join() would raise on a
        thread that was never started).
        """
        self._stop_event.set()
        if self.is_alive():
            self.join(self._STOP_JOIN_TIMEOUT)

    @callback
    def async_start(self):
        """Start the thread."""
        self.start()

    def run(self):
        """Start watching for dhcp packets."""
        try:
            sniff(
                filter=FILTER,
                prn=self.handle_dhcp_packet,
                stop_filter=lambda _: self._stop_event.is_set(),
                # Packets are fully handled by prn. Storing them would make
                # sniff keep every captured packet for the thread's lifetime.
                store=False,
            )
        except (Scapy_Exception, OSError) as ex:
            if os.geteuid() == 0:
                _LOGGER.error("Cannot watch for dhcp packets: %s", ex)
            else:
                _LOGGER.debug(
                    "Cannot watch for dhcp packets without root or CAP_NET_RAW: %s", ex
                )

    def handle_dhcp_packet(self, packet):
        """Process a dhcp packet."""
        if DHCP not in packet:
            return
        options = packet[DHCP].options
        request_type = _decode_dhcp_option(options, MESSAGE_TYPE)
        if request_type != DHCP_REQUEST:
            # Not a DHCP request
            return
        ip_address = _decode_dhcp_option(options, REQUESTED_ADDR)
        hostname = _decode_dhcp_option(options, HOSTNAME)
        mac_address = _format_mac(packet[Ether].src)
        if ip_address is None or hostname is None or mac_address is None:
            return
        self.process_client(ip_address, hostname, mac_address)

    def create_task(self, task):
        """Pass a task to hass.add_job since we are in a thread."""
        self.hass.add_job(task)
def _decode_dhcp_option(dhcp_options, key):
"""Extract and decode data from a packet option."""
for option in dhcp_options:
if len(option) < 2 or option[0] != key:
continue
value = option[1]
if value is None or key != HOSTNAME:
return value
# hostname is unicode
try:
return value.decode()
except (AttributeError, UnicodeDecodeError):
return None
def _format_mac(mac_address):
    """Return mac_address normalized by format_mac with its colons removed."""
    return "".join(format_mac(mac_address).split(":"))