Prevent socket leak on SSDP when finding available port (#150999)

Co-authored-by: Abílio Costa <abmantis@users.noreply.github.com>
pull/151990/head
skbeh 2025-09-09 15:50:52 +00:00 committed by GitHub
parent 36edfd8c04
commit 9e73ff06d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 61 additions and 41 deletions

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
from contextlib import ExitStack
import logging
import socket
from time import time
@ -89,22 +90,29 @@ class HassUpnpServiceDevice(UpnpServerDevice):
SERVICES: list[type[UpnpServerService]] = []
async def _async_find_next_available_port(source: AddressTupleVXType) -> int:
async def _async_find_next_available_port(
source: AddressTupleVXType,
) -> tuple[int, socket.socket]:
"""Get a free TCP port."""
family = socket.AF_INET if is_ipv4_address(source) else socket.AF_INET6
test_socket = socket.socket(family, socket.SOCK_STREAM)
test_socket.setblocking(False)
test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# We use an ExitStack to ensure the socket is closed if we fail to find a port.
with ExitStack() as stack:
test_socket = stack.enter_context(socket.socket(family, socket.SOCK_STREAM))
test_socket.setblocking(False)
test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
for port in range(UPNP_SERVER_MIN_PORT, UPNP_SERVER_MAX_PORT):
addr = (source[0], port, *source[2:])
try:
test_socket.bind(addr)
except OSError:
if port == UPNP_SERVER_MAX_PORT - 1:
raise
else:
return port
for port in range(UPNP_SERVER_MIN_PORT, UPNP_SERVER_MAX_PORT):
addr = (source[0], port, *source[2:])
try:
test_socket.bind(addr)
except OSError:
if port == UPNP_SERVER_MAX_PORT - 1:
raise
else:
# The socket will be dealt by the caller, so we detach it from the stack
# before returning it to prevent it from being closed.
stack.pop_all()
return port, test_socket
raise RuntimeError("unreachable")
@ -167,35 +175,43 @@ class Server:
# Start a server on all source IPs.
boot_id = int(time())
for source_ip in await async_build_source_set(self.hass):
source_ip_str = str(source_ip)
if source_ip.version == 6:
assert source_ip.scope_id is not None
source_tuple: AddressTupleVXType = (
source_ip_str,
0,
0,
int(source_ip.scope_id),
# We use an ExitStack to ensure that all sockets are closed.
# The socket is created in _async_find_next_available_port,
# and should be kept open until UpnpServer is started to
# keep the kernel from reassigning the port.
with ExitStack() as stack:
for source_ip in await async_build_source_set(self.hass):
source_ip_str = str(source_ip)
if source_ip.version == 6:
assert source_ip.scope_id is not None
source_tuple: AddressTupleVXType = (
source_ip_str,
0,
0,
int(source_ip.scope_id),
)
else:
source_tuple = (source_ip_str, 0)
source, target = determine_source_target(source_tuple)
source = fix_ipv6_address_scope_id(source) or source
http_port, http_socket = await _async_find_next_available_port(source)
stack.enter_context(http_socket)
_LOGGER.debug(
"Binding UPnP HTTP server to: %s:%s", source_ip, http_port
)
else:
source_tuple = (source_ip_str, 0)
source, target = determine_source_target(source_tuple)
source = fix_ipv6_address_scope_id(source) or source
http_port = await _async_find_next_available_port(source)
_LOGGER.debug("Binding UPnP HTTP server to: %s:%s", source_ip, http_port)
self._upnp_servers.append(
UpnpServer(
source=source,
target=target,
http_port=http_port,
server_device=HassUpnpServiceDevice,
boot_id=boot_id,
self._upnp_servers.append(
UpnpServer(
source=source,
target=target,
http_port=http_port,
server_device=HassUpnpServiceDevice,
boot_id=boot_id,
)
)
results = await asyncio.gather(
*(upnp_server.async_start() for upnp_server in self._upnp_servers),
return_exceptions=True,
)
results = await asyncio.gather(
*(upnp_server.async_start() for upnp_server in self._upnp_servers),
return_exceptions=True,
)
failed_servers = []
for idx, result in enumerate(results):
if isinstance(result, Exception):

View File

@ -1,7 +1,8 @@
"""Configuration for SSDP tests."""
from collections.abc import Generator
from unittest.mock import AsyncMock, patch
import socket
from unittest.mock import AsyncMock, MagicMock, patch
from async_upnp_client.server import UpnpServer
from async_upnp_client.ssdp_listener import SsdpListener
@ -29,7 +30,10 @@ async def disabled_upnp_server():
with (
patch("homeassistant.components.ssdp.server.UpnpServer.async_start"),
patch("homeassistant.components.ssdp.server.UpnpServer.async_stop"),
patch("homeassistant.components.ssdp.server._async_find_next_available_port"),
patch(
"homeassistant.components.ssdp.server._async_find_next_available_port",
return_value=(40000, MagicMock(spec_set=socket.socket)),
),
):
yield UpnpServer