From f06b00c6d875e3ad0150966a31f56317ea0134ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niccol=C3=B2=20Maggioni?= Date: Mon, 8 Apr 2024 10:04:59 +0200 Subject: [PATCH] Fix hang in SNMP device_tracker implementation (#112815) Co-authored-by: J. Nick Koston --- CODEOWNERS | 2 + .../components/snmp/device_tracker.py | 154 ++++++++++++------ homeassistant/components/snmp/manifest.json | 2 +- 3 files changed, 110 insertions(+), 48 deletions(-) diff --git a/CODEOWNERS b/CODEOWNERS index 946caef629e..40d7c0f502a 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1270,6 +1270,8 @@ build.json @home-assistant/supervisor /homeassistant/components/sms/ @ocalvo /homeassistant/components/snapcast/ @luar123 /tests/components/snapcast/ @luar123 +/homeassistant/components/snmp/ @nmaggioni +/tests/components/snmp/ @nmaggioni /homeassistant/components/snooz/ @AustinBrunkhorst /tests/components/snooz/ @AustinBrunkhorst /homeassistant/components/solaredge/ @frenck diff --git a/homeassistant/components/snmp/device_tracker.py b/homeassistant/components/snmp/device_tracker.py index 4b8ab073b9c..a1a91116f0f 100644 --- a/homeassistant/components/snmp/device_tracker.py +++ b/homeassistant/components/snmp/device_tracker.py @@ -5,8 +5,19 @@ from __future__ import annotations import binascii import logging -from pysnmp.entity import config as cfg -from pysnmp.entity.rfc3413.oneliner import cmdgen +from pysnmp.error import PySnmpError +from pysnmp.hlapi.asyncio import ( + CommunityData, + ContextData, + ObjectIdentity, + ObjectType, + SnmpEngine, + Udp6TransportTarget, + UdpTransportTarget, + UsmUserData, + bulkWalkCmd, + isEndOfMib, +) import voluptuous as vol from homeassistant.components.device_tracker import ( @@ -24,7 +35,13 @@ from .const import ( CONF_BASEOID, CONF_COMMUNITY, CONF_PRIV_KEY, + DEFAULT_AUTH_PROTOCOL, DEFAULT_COMMUNITY, + DEFAULT_PORT, + DEFAULT_PRIV_PROTOCOL, + DEFAULT_TIMEOUT, + DEFAULT_VERSION, + SNMP_VERSIONS, ) _LOGGER = logging.getLogger(__name__) @@ -40,9 +57,12 @@ PLATFORM_SCHEMA = PARENT_PLATFORM_SCHEMA.extend( ) -def get_scanner(hass: HomeAssistant, config: ConfigType) -> SnmpScanner | None: +async def async_get_scanner( + hass: HomeAssistant, config: ConfigType +) -> SnmpScanner | None: """Validate the configuration and return an SNMP scanner.""" scanner = SnmpScanner(config[DOMAIN]) + await scanner.async_init() return scanner if scanner.success_init else None @@ -51,39 +71,75 @@ class SnmpScanner(DeviceScanner): """Queries any SNMP capable Access Point for connected devices.""" def __init__(self, config): - """Initialize the scanner.""" + """Initialize the scanner and test the target device.""" + host = config[CONF_HOST] + community = config[CONF_COMMUNITY] + baseoid = config[CONF_BASEOID] + authkey = config.get(CONF_AUTH_KEY) + authproto = DEFAULT_AUTH_PROTOCOL + privkey = config.get(CONF_PRIV_KEY) + privproto = DEFAULT_PRIV_PROTOCOL - self.snmp = cmdgen.CommandGenerator() + try: + # Try IPv4 first. + target = UdpTransportTarget((host, DEFAULT_PORT), timeout=DEFAULT_TIMEOUT) + except PySnmpError: + # Then try IPv6. + try: + target = Udp6TransportTarget( + (host, DEFAULT_PORT), timeout=DEFAULT_TIMEOUT + ) + except PySnmpError as err: + _LOGGER.error("Invalid SNMP host: %s", err) + return - self.host = cmdgen.UdpTransportTarget((config[CONF_HOST], 161)) - if CONF_AUTH_KEY not in config or CONF_PRIV_KEY not in config: - self.auth = cmdgen.CommunityData(config[CONF_COMMUNITY]) + if authkey is not None or privkey is not None: + if not authkey: + authproto = "none" + if not privkey: + privproto = "none" + + request_args = [ + SnmpEngine(), + UsmUserData( + community, + authKey=authkey or None, + privKey=privkey or None, + authProtocol=authproto, + privProtocol=privproto, + ), + target, + ContextData(), + ] else: - self.auth = cmdgen.UsmUserData( - config[CONF_COMMUNITY], - config[CONF_AUTH_KEY], - config[CONF_PRIV_KEY], - authProtocol=cfg.usmHMACSHAAuthProtocol, - privProtocol=cfg.usmAesCfb128Protocol, - ) - self.baseoid = cmdgen.MibVariable(config[CONF_BASEOID]) - self.last_results = [] + request_args = [ + SnmpEngine(), + CommunityData(community, mpModel=SNMP_VERSIONS[DEFAULT_VERSION]), + target, + ContextData(), + ] - # Test the router is accessible - data = self.get_snmp_data() + self.request_args = request_args + self.baseoid = baseoid + self.last_results = [] + self.success_init = False + + async def async_init(self): + """Make a one-off read to check if the target device is reachable and readable.""" + data = await self.async_get_snmp_data() self.success_init = data is not None - def scan_devices(self): + async def async_scan_devices(self): """Scan for new devices and return a list with found device IDs.""" - self._update_info() + await self._async_update_info() return [client["mac"] for client in self.last_results if client.get("mac")] - def get_device_name(self, device): + async def async_get_device_name(self, device): """Return the name of the given device or None if we don't know.""" # We have no names return None - def _update_info(self): + async def _async_update_info(self): """Ensure the information from the device is up to date. Return boolean if scanning successful. @@ -91,38 +147,42 @@ class SnmpScanner(DeviceScanner): if not self.success_init: return False - if not (data := self.get_snmp_data()): + if not (data := await self.async_get_snmp_data()): return False self.last_results = data return True - def get_snmp_data(self): + async def async_get_snmp_data(self): """Fetch MAC addresses from access point via SNMP.""" devices = [] - errindication, errstatus, errindex, restable = self.snmp.nextCmd( - self.auth, self.host, self.baseoid + walker = bulkWalkCmd( + *self.request_args, + 0, + 50, + ObjectType(ObjectIdentity(self.baseoid)), + lexicographicMode=False, ) + async for errindication, errstatus, errindex, res in walker: + if errindication: + _LOGGER.error("SNMPLIB error: %s", errindication) + return + if errstatus: + _LOGGER.error( + "SNMP error: %s at %s", + errstatus.prettyPrint(), + errindex and res[int(errindex) - 1][0] or "?", + ) + return - if errindication: - _LOGGER.error("SNMPLIB error: %s", errindication) - return - if errstatus: - _LOGGER.error( - "SNMP error: %s at %s", - errstatus.prettyPrint(), - errindex and restable[int(errindex) - 1][0] or "?", - ) - return - - for resrow in restable: - for _, val in resrow: - try: - mac = binascii.hexlify(val.asOctets()).decode("utf-8") - except AttributeError: - continue - _LOGGER.debug("Found MAC address: %s", mac) - mac = ":".join([mac[i : i + 2] for i in range(0, len(mac), 2)]) - devices.append({"mac": mac}) + for _oid, value in res: + if not isEndOfMib(res): + try: + mac = binascii.hexlify(value.asOctets()).decode("utf-8") + except AttributeError: + continue + _LOGGER.debug("Found MAC address: %s", mac) + mac = ":".join([mac[i : i + 2] for i in range(0, len(mac), 2)]) + devices.append({"mac": mac}) return devices diff --git a/homeassistant/components/snmp/manifest.json b/homeassistant/components/snmp/manifest.json index c4aa82f2a74..d79910c44cd 100644 --- a/homeassistant/components/snmp/manifest.json +++ b/homeassistant/components/snmp/manifest.json @@ -1,7 +1,7 @@ { "domain": "snmp", "name": "SNMP", - "codeowners": [], + "codeowners": ["@nmaggioni"], "documentation": "https://www.home-assistant.io/integrations/snmp", "iot_class": "local_polling", "loggers": ["pyasn1", "pysmi", "pysnmp"],