Prevent socket leak on SSDP when finding available port (#150999)
Co-authored-by: Abílio Costa <abmantis@users.noreply.github.com>pull/151990/head
parent
36edfd8c04
commit
9e73ff06d2
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import ExitStack
|
||||
import logging
|
||||
import socket
|
||||
from time import time
|
||||
|
@ -89,22 +90,29 @@ class HassUpnpServiceDevice(UpnpServerDevice):
|
|||
SERVICES: list[type[UpnpServerService]] = []
|
||||
|
||||
|
||||
async def _async_find_next_available_port(source: AddressTupleVXType) -> int:
|
||||
async def _async_find_next_available_port(
|
||||
source: AddressTupleVXType,
|
||||
) -> tuple[int, socket.socket]:
|
||||
"""Get a free TCP port."""
|
||||
family = socket.AF_INET if is_ipv4_address(source) else socket.AF_INET6
|
||||
test_socket = socket.socket(family, socket.SOCK_STREAM)
|
||||
test_socket.setblocking(False)
|
||||
test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
# We use an ExitStack to ensure the socket is closed if we fail to find a port.
|
||||
with ExitStack() as stack:
|
||||
test_socket = stack.enter_context(socket.socket(family, socket.SOCK_STREAM))
|
||||
test_socket.setblocking(False)
|
||||
test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
|
||||
for port in range(UPNP_SERVER_MIN_PORT, UPNP_SERVER_MAX_PORT):
|
||||
addr = (source[0], port, *source[2:])
|
||||
try:
|
||||
test_socket.bind(addr)
|
||||
except OSError:
|
||||
if port == UPNP_SERVER_MAX_PORT - 1:
|
||||
raise
|
||||
else:
|
||||
return port
|
||||
for port in range(UPNP_SERVER_MIN_PORT, UPNP_SERVER_MAX_PORT):
|
||||
addr = (source[0], port, *source[2:])
|
||||
try:
|
||||
test_socket.bind(addr)
|
||||
except OSError:
|
||||
if port == UPNP_SERVER_MAX_PORT - 1:
|
||||
raise
|
||||
else:
|
||||
# The socket will be dealt by the caller, so we detach it from the stack
|
||||
# before returning it to prevent it from being closed.
|
||||
stack.pop_all()
|
||||
return port, test_socket
|
||||
|
||||
raise RuntimeError("unreachable")
|
||||
|
||||
|
@ -167,35 +175,43 @@ class Server:
|
|||
|
||||
# Start a server on all source IPs.
|
||||
boot_id = int(time())
|
||||
for source_ip in await async_build_source_set(self.hass):
|
||||
source_ip_str = str(source_ip)
|
||||
if source_ip.version == 6:
|
||||
assert source_ip.scope_id is not None
|
||||
source_tuple: AddressTupleVXType = (
|
||||
source_ip_str,
|
||||
0,
|
||||
0,
|
||||
int(source_ip.scope_id),
|
||||
# We use an ExitStack to ensure that all sockets are closed.
|
||||
# The socket is created in _async_find_next_available_port,
|
||||
# and should be kept open until UpnpServer is started to
|
||||
# keep the kernel from reassigning the port.
|
||||
with ExitStack() as stack:
|
||||
for source_ip in await async_build_source_set(self.hass):
|
||||
source_ip_str = str(source_ip)
|
||||
if source_ip.version == 6:
|
||||
assert source_ip.scope_id is not None
|
||||
source_tuple: AddressTupleVXType = (
|
||||
source_ip_str,
|
||||
0,
|
||||
0,
|
||||
int(source_ip.scope_id),
|
||||
)
|
||||
else:
|
||||
source_tuple = (source_ip_str, 0)
|
||||
source, target = determine_source_target(source_tuple)
|
||||
source = fix_ipv6_address_scope_id(source) or source
|
||||
http_port, http_socket = await _async_find_next_available_port(source)
|
||||
stack.enter_context(http_socket)
|
||||
_LOGGER.debug(
|
||||
"Binding UPnP HTTP server to: %s:%s", source_ip, http_port
|
||||
)
|
||||
else:
|
||||
source_tuple = (source_ip_str, 0)
|
||||
source, target = determine_source_target(source_tuple)
|
||||
source = fix_ipv6_address_scope_id(source) or source
|
||||
http_port = await _async_find_next_available_port(source)
|
||||
_LOGGER.debug("Binding UPnP HTTP server to: %s:%s", source_ip, http_port)
|
||||
self._upnp_servers.append(
|
||||
UpnpServer(
|
||||
source=source,
|
||||
target=target,
|
||||
http_port=http_port,
|
||||
server_device=HassUpnpServiceDevice,
|
||||
boot_id=boot_id,
|
||||
self._upnp_servers.append(
|
||||
UpnpServer(
|
||||
source=source,
|
||||
target=target,
|
||||
http_port=http_port,
|
||||
server_device=HassUpnpServiceDevice,
|
||||
boot_id=boot_id,
|
||||
)
|
||||
)
|
||||
results = await asyncio.gather(
|
||||
*(upnp_server.async_start() for upnp_server in self._upnp_servers),
|
||||
return_exceptions=True,
|
||||
)
|
||||
results = await asyncio.gather(
|
||||
*(upnp_server.async_start() for upnp_server in self._upnp_servers),
|
||||
return_exceptions=True,
|
||||
)
|
||||
failed_servers = []
|
||||
for idx, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
"""Configuration for SSDP tests."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
import socket
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from async_upnp_client.server import UpnpServer
|
||||
from async_upnp_client.ssdp_listener import SsdpListener
|
||||
|
@ -29,7 +30,10 @@ async def disabled_upnp_server():
|
|||
with (
|
||||
patch("homeassistant.components.ssdp.server.UpnpServer.async_start"),
|
||||
patch("homeassistant.components.ssdp.server.UpnpServer.async_stop"),
|
||||
patch("homeassistant.components.ssdp.server._async_find_next_available_port"),
|
||||
patch(
|
||||
"homeassistant.components.ssdp.server._async_find_next_available_port",
|
||||
return_value=(40000, MagicMock(spec_set=socket.socket)),
|
||||
),
|
||||
):
|
||||
yield UpnpServer
|
||||
|
||||
|
|
Loading…
Reference in New Issue