Rewrite go2rtc binary handling to be async (#128078)

pull/128391/head
Robert Resch 2024-10-14 15:32:00 +02:00 committed by GitHub
parent cdb1b1df15
commit f5b55d5eb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 115 additions and 81 deletions

View File

@ -50,9 +50,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up WebRTC from a config entry."""
if binary := entry.data.get(CONF_BINARY):
# HA will manage the binary
server = Server(binary)
server = Server(hass, binary)
entry.async_on_unload(server.stop)
server.start()
await server.start()
client = Go2RtcClient(async_get_clientsession(hass), entry.data[CONF_HOST])

View File

@ -1,56 +1,70 @@
"""Go2rtc server."""
from __future__ import annotations
import asyncio
import logging
import subprocess
from tempfile import NamedTemporaryFile
from threading import Thread
from .const import DOMAIN
from homeassistant.core import HomeAssistant
_LOGGER = logging.getLogger(__name__)
_TERMINATE_TIMEOUT = 5
class Server(Thread):
"""Server thread."""
def _create_temp_file() -> str:
"""Create temporary config file."""
# Set delete=False to prevent the file from being deleted when the file is closed
# Linux is clearing tmp folder on reboot, so no need to delete it manually
with NamedTemporaryFile(prefix="go2rtc", suffix=".yaml", delete=False) as file:
return file.name
def __init__(self, binary: str) -> None:
async def _log_output(process: asyncio.subprocess.Process) -> None:
"""Log the output of the process."""
assert process.stdout is not None
async for line in process.stdout:
_LOGGER.debug(line[:-1].decode().strip())
class Server:
"""Go2rtc server."""
def __init__(self, hass: HomeAssistant, binary: str) -> None:
"""Initialize the server."""
super().__init__(name=DOMAIN, daemon=True)
self._hass = hass
self._binary = binary
self._stop_requested = False
self._process: asyncio.subprocess.Process | None = None
def run(self) -> None:
"""Run the server."""
async def start(self) -> None:
"""Start the server."""
_LOGGER.debug("Starting go2rtc server")
self._stop_requested = False
with (
NamedTemporaryFile(prefix="go2rtc", suffix=".yaml") as file,
subprocess.Popen(
[self._binary, "-c", "webrtc.ice_servers=[]", "-c", file.name],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
) as process,
):
while not self._stop_requested and process.poll() is None:
assert process.stdout
line = process.stdout.readline()
if line == b"":
break
_LOGGER.debug(line[:-1].decode())
config_file = await self._hass.async_add_executor_job(_create_temp_file)
_LOGGER.debug("Terminating go2rtc server")
self._process = await asyncio.create_subprocess_exec(
self._binary,
"-c",
"webrtc.ice_servers=[]",
"-c",
config_file,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
self._hass.async_create_background_task(
_log_output(self._process), "Go2rtc log output"
)
async def stop(self) -> None:
"""Stop the server."""
if self._process:
_LOGGER.debug("Stopping go2rtc server")
process = self._process
self._process = None
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
_LOGGER.warning("Go2rtc server didn't terminate gracefully.Killing it")
await asyncio.wait_for(process.wait(), timeout=_TERMINATE_TIMEOUT)
except TimeoutError:
_LOGGER.warning("Go2rtc server didn't terminate gracefully. Killing it")
process.kill()
_LOGGER.debug("Go2rtc server has been stopped")
def stop(self) -> None:
"""Stop the server."""
self._stop_requested = True
if self.is_alive():
self.join()
else:
_LOGGER.debug("Go2rtc server has been stopped")

View File

@ -7,6 +7,7 @@ from go2rtc_client.client import _StreamClient, _WebRTCClient
import pytest
from homeassistant.components.go2rtc.const import CONF_BINARY, DOMAIN
from homeassistant.components.go2rtc.server import Server
from homeassistant.const import CONF_HOST
from tests.common import MockConfigEntry
@ -41,9 +42,11 @@ def mock_client() -> Generator[AsyncMock]:
@pytest.fixture
def mock_server() -> Generator[Mock]:
def mock_server() -> Generator[AsyncMock]:
"""Mock a go2rtc server."""
with patch("homeassistant.components.go2rtc.Server", autoSpec=True) as mock_server:
with patch(
"homeassistant.components.go2rtc.Server", spec_set=Server
) as mock_server:
yield mock_server

View File

@ -184,13 +184,13 @@ async def _test_setup(
async def test_setup_go_binary(
hass: HomeAssistant,
mock_client: AsyncMock,
mock_server: Mock,
mock_server: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test the go2rtc config entry with binary."""
def after_setup() -> None:
mock_server.assert_called_once_with("/usr/bin/go2rtc")
mock_server.assert_called_once_with(hass, "/usr/bin/go2rtc")
mock_server.return_value.start.assert_called_once()
await _test_setup(hass, mock_client, mock_config_entry, after_setup)

View File

@ -2,20 +2,22 @@
import asyncio
from collections.abc import Generator
import logging
import subprocess
from unittest.mock import MagicMock, patch
import pytest
from homeassistant.components.go2rtc.server import Server
from homeassistant.core import HomeAssistant
TEST_BINARY = "/bin/go2rtc"
@pytest.fixture
def server() -> Server:
def server(hass: HomeAssistant) -> Server:
"""Fixture to initialize the Server."""
return Server(binary=TEST_BINARY)
return Server(hass, binary=TEST_BINARY)
@pytest.fixture
@ -29,63 +31,77 @@ def mock_tempfile() -> Generator[MagicMock]:
@pytest.fixture
def mock_popen() -> Generator[MagicMock]:
def mock_process() -> Generator[MagicMock]:
"""Fixture to mock subprocess.Popen."""
with patch("homeassistant.components.go2rtc.server.subprocess.Popen") as mock_popen:
with patch(
"homeassistant.components.go2rtc.server.asyncio.create_subprocess_exec"
) as mock_popen:
mock_popen.return_value.returncode = None
yield mock_popen
@pytest.mark.usefixtures("mock_tempfile")
async def test_server_run_success(mock_popen: MagicMock, server: Server) -> None:
async def test_server_run_success(
mock_process: MagicMock,
server: Server,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that the server runs successfully."""
mock_process = MagicMock()
mock_process.poll.return_value = None # Simulate process running
# Simulate process output
mock_process.stdout.readline.side_effect = [
b"log line 1\n",
b"log line 2\n",
b"",
]
mock_popen.return_value.__enter__.return_value = mock_process
mock_process.return_value.stdout.__aiter__.return_value = iter(
[
b"log line 1\n",
b"log line 2\n",
]
)
server.start()
await asyncio.sleep(0)
await server.start()
# Check that Popen was called with the right arguments
mock_popen.assert_called_once_with(
[TEST_BINARY, "-c", "webrtc.ice_servers=[]", "-c", "test.yaml"],
mock_process.assert_called_once_with(
TEST_BINARY,
"-c",
"webrtc.ice_servers=[]",
"-c",
"test.yaml",
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
# Check that server read the log lines
assert mock_process.stdout.readline.call_count == 3
for entry in ("log line 1", "log line 2"):
assert (
"homeassistant.components.go2rtc.server",
logging.DEBUG,
entry,
) in caplog.record_tuples
server.stop()
mock_process.terminate.assert_called_once()
assert not server.is_alive()
await server.stop()
mock_process.return_value.terminate.assert_called_once()
@pytest.mark.usefixtures("mock_tempfile")
def test_server_run_process_timeout(mock_popen: MagicMock, server: Server) -> None:
async def test_server_run_process_timeout(
mock_process: MagicMock, server: Server
) -> None:
"""Test server run where the process takes too long to terminate."""
mock_process.return_value.stdout.__aiter__.return_value = iter(
[
b"log line 1\n",
]
)
async def sleep() -> None:
await asyncio.sleep(1)
mock_process = MagicMock()
mock_process.poll.return_value = None # Simulate process running
# Simulate process output
mock_process.stdout.readline.side_effect = [
b"log line 1\n",
b"",
]
# Simulate timeout
mock_process.wait.side_effect = subprocess.TimeoutExpired(cmd="go2rtc", timeout=5)
mock_popen.return_value.__enter__.return_value = mock_process
mock_process.return_value.wait.side_effect = sleep
# Start server thread
server.start()
server.stop()
with patch("homeassistant.components.go2rtc.server._TERMINATE_TIMEOUT", new=0.1):
# Start server thread
await server.start()
await server.stop()
# Ensure terminate and kill were called due to timeout
mock_process.terminate.assert_called_once()
mock_process.kill.assert_called_once()
assert not server.is_alive()
mock_process.return_value.terminate.assert_called_once()
mock_process.return_value.kill.assert_called_once()