95 lines
2.7 KiB
Python
95 lines
2.7 KiB
Python
"""Executor util helpers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
import contextlib
|
|
import logging
|
|
import sys
|
|
from threading import Thread
|
|
import time
|
|
import traceback
|
|
from typing import Any
|
|
|
|
from .thread import async_raise
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
MAX_LOG_ATTEMPTS = 2
|
|
|
|
_JOIN_ATTEMPTS = 10
|
|
|
|
EXECUTOR_SHUTDOWN_TIMEOUT = 10
|
|
|
|
|
|
def _log_thread_running_at_shutdown(name: str, ident: int) -> None:
|
|
"""Log the stack of a thread that was still running at shutdown."""
|
|
frames = sys._current_frames() # pylint: disable=protected-access
|
|
stack = frames.get(ident)
|
|
formatted_stack = traceback.format_stack(stack)
|
|
_LOGGER.warning(
|
|
"Thread[%s] is still running at shutdown: %s",
|
|
name,
|
|
"".join(formatted_stack).strip(),
|
|
)
|
|
|
|
|
|
def join_or_interrupt_threads(
|
|
threads: set[Thread], timeout: float, log: bool
|
|
) -> set[Thread]:
|
|
"""Attempt to join or interrupt a set of threads."""
|
|
joined = set()
|
|
timeout_per_thread = timeout / len(threads)
|
|
|
|
for thread in threads:
|
|
thread.join(timeout=timeout_per_thread)
|
|
|
|
if not thread.is_alive() or thread.ident is None:
|
|
joined.add(thread)
|
|
continue
|
|
|
|
if log:
|
|
_log_thread_running_at_shutdown(thread.name, thread.ident)
|
|
|
|
with contextlib.suppress(SystemError):
|
|
# SystemError at this stage is usually a race condition
|
|
# where the thread happens to die right before we force
|
|
# it to raise the exception
|
|
async_raise(thread.ident, SystemExit)
|
|
|
|
return joined
|
|
|
|
|
|
class InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
|
|
"""A ThreadPoolExecutor instance that will not deadlock on shutdown."""
|
|
|
|
def shutdown(self, *args: Any, **kwargs: Any) -> None:
|
|
"""Shutdown with interrupt support added."""
|
|
super().shutdown(wait=False, cancel_futures=True)
|
|
self.join_threads_or_timeout()
|
|
|
|
def join_threads_or_timeout(self) -> None:
|
|
"""Join threads or timeout."""
|
|
remaining_threads = set(self._threads)
|
|
start_time = time.monotonic()
|
|
timeout_remaining: float = EXECUTOR_SHUTDOWN_TIMEOUT
|
|
attempt = 0
|
|
|
|
while True:
|
|
if not remaining_threads:
|
|
return
|
|
|
|
attempt += 1
|
|
|
|
remaining_threads -= join_or_interrupt_threads(
|
|
remaining_threads,
|
|
timeout_remaining / _JOIN_ATTEMPTS,
|
|
attempt <= MAX_LOG_ATTEMPTS,
|
|
)
|
|
|
|
timeout_remaining = EXECUTOR_SHUTDOWN_TIMEOUT - (
|
|
time.monotonic() - start_time
|
|
)
|
|
if timeout_remaining <= 0:
|
|
return
|