diff --git a/homeassistant/components/ssdp/server.py b/homeassistant/components/ssdp/server.py index b6e105b9560..a9cea01a517 100644 --- a/homeassistant/components/ssdp/server.py +++ b/homeassistant/components/ssdp/server.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +from contextlib import ExitStack import logging import socket from time import time @@ -89,22 +90,29 @@ class HassUpnpServiceDevice(UpnpServerDevice): SERVICES: list[type[UpnpServerService]] = [] -async def _async_find_next_available_port(source: AddressTupleVXType) -> int: +async def _async_find_next_available_port( + source: AddressTupleVXType, +) -> tuple[int, socket.socket]: """Get a free TCP port.""" family = socket.AF_INET if is_ipv4_address(source) else socket.AF_INET6 - test_socket = socket.socket(family, socket.SOCK_STREAM) - test_socket.setblocking(False) - test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # We use an ExitStack to ensure the socket is closed if we fail to find a port. + with ExitStack() as stack: + test_socket = stack.enter_context(socket.socket(family, socket.SOCK_STREAM)) + test_socket.setblocking(False) + test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - for port in range(UPNP_SERVER_MIN_PORT, UPNP_SERVER_MAX_PORT): - addr = (source[0], port, *source[2:]) - try: - test_socket.bind(addr) - except OSError: - if port == UPNP_SERVER_MAX_PORT - 1: - raise - else: - return port + for port in range(UPNP_SERVER_MIN_PORT, UPNP_SERVER_MAX_PORT): + addr = (source[0], port, *source[2:]) + try: + test_socket.bind(addr) + except OSError: + if port == UPNP_SERVER_MAX_PORT - 1: + raise + else: + # The socket will be dealt by the caller, so we detach it from the stack + # before returning it to prevent it from being closed. + stack.pop_all() + return port, test_socket raise RuntimeError("unreachable") @@ -167,35 +175,43 @@ class Server: # Start a server on all source IPs. boot_id = int(time()) - for source_ip in await async_build_source_set(self.hass): - source_ip_str = str(source_ip) - if source_ip.version == 6: - assert source_ip.scope_id is not None - source_tuple: AddressTupleVXType = ( - source_ip_str, - 0, - 0, - int(source_ip.scope_id), + # We use an ExitStack to ensure that all sockets are closed. + # The socket is created in _async_find_next_available_port, + # and should be kept open until UpnpServer is started to + # keep the kernel from reassigning the port. + with ExitStack() as stack: + for source_ip in await async_build_source_set(self.hass): + source_ip_str = str(source_ip) + if source_ip.version == 6: + assert source_ip.scope_id is not None + source_tuple: AddressTupleVXType = ( + source_ip_str, + 0, + 0, + int(source_ip.scope_id), + ) + else: + source_tuple = (source_ip_str, 0) + source, target = determine_source_target(source_tuple) + source = fix_ipv6_address_scope_id(source) or source + http_port, http_socket = await _async_find_next_available_port(source) + stack.enter_context(http_socket) + _LOGGER.debug( + "Binding UPnP HTTP server to: %s:%s", source_ip, http_port ) - else: - source_tuple = (source_ip_str, 0) - source, target = determine_source_target(source_tuple) - source = fix_ipv6_address_scope_id(source) or source - http_port = await _async_find_next_available_port(source) - _LOGGER.debug("Binding UPnP HTTP server to: %s:%s", source_ip, http_port) - self._upnp_servers.append( - UpnpServer( - source=source, - target=target, - http_port=http_port, - server_device=HassUpnpServiceDevice, - boot_id=boot_id, + self._upnp_servers.append( + UpnpServer( + source=source, + target=target, + http_port=http_port, + server_device=HassUpnpServiceDevice, + boot_id=boot_id, + ) ) + results = await asyncio.gather( + *(upnp_server.async_start() for upnp_server in self._upnp_servers), + return_exceptions=True, ) - results = await asyncio.gather( - *(upnp_server.async_start() for upnp_server in self._upnp_servers), - return_exceptions=True, - ) failed_servers = [] for idx, result in enumerate(results): if isinstance(result, Exception): diff --git a/tests/components/ssdp/conftest.py b/tests/components/ssdp/conftest.py index 61c763ce7d4..644f449fe38 100644 --- a/tests/components/ssdp/conftest.py +++ b/tests/components/ssdp/conftest.py @@ -1,7 +1,8 @@ """Configuration for SSDP tests.""" from collections.abc import Generator -from unittest.mock import AsyncMock, patch +import socket +from unittest.mock import AsyncMock, MagicMock, patch from async_upnp_client.server import UpnpServer from async_upnp_client.ssdp_listener import SsdpListener @@ -29,7 +30,10 @@ async def disabled_upnp_server(): with ( patch("homeassistant.components.ssdp.server.UpnpServer.async_start"), patch("homeassistant.components.ssdp.server.UpnpServer.async_stop"), - patch("homeassistant.components.ssdp.server._async_find_next_available_port"), + patch( + "homeassistant.components.ssdp.server._async_find_next_available_port", + return_value=(40000, MagicMock(spec_set=socket.socket)), + ), ): yield UpnpServer