259 lines
6.8 KiB
Python
259 lines
6.8 KiB
Python
"""Block blocking calls being done in asyncio."""
|
|
|
|
import builtins
|
|
from collections.abc import Callable
|
|
from contextlib import suppress
|
|
from dataclasses import dataclass
|
|
import glob
|
|
from http.client import HTTPConnection
|
|
import importlib
|
|
import os
|
|
from pathlib import Path
|
|
from ssl import SSLContext
|
|
import sys
|
|
import threading
|
|
import time
|
|
from typing import Any
|
|
|
|
from .helpers.frame import get_current_frame
|
|
from .util.loop import protect_loop
|
|
|
|
# True when the ``unittest`` module has already been imported into this
# interpreter; enable() uses it to leave ``skip_for_tests`` entries unpatched.
_IN_TESTS: bool = "unittest" in sys.modules

# Path prefixes whose files may be accessed from the event loop without the
# access being reported by _check_file_allowed.
ALLOWED_FILE_PREFIXES: tuple[str, ...] = ("/proc",)
|
|
|
|
|
|
def _check_import_call_allowed(mapped_args: dict[str, Any]) -> bool:
    """Return True if the module requested from ``import_module`` is loaded.

    If the module is already imported (present in ``sys.modules``) the call
    can be ignored. The module name is taken from the first positional
    argument or, when the caller passed it by keyword, from ``name``.
    Relative names (leading ``.``) are not resolved here, so they are only
    accepted if that literal string is a key of ``sys.modules``.

    Args:
        mapped_args: ``{"args": ..., "kwargs": ...}`` of the intercepted call.

    Returns:
        True when the named module is already in ``sys.modules``; False
        otherwise, including when no usable name could be found.
    """
    if args := mapped_args.get("args"):
        name = args[0]
    else:
        # import_module(name="pkg.mod") supplies the name by keyword only.
        name = (mapped_args.get("kwargs") or {}).get("name")
    # Only look up real strings so that odd (e.g. unhashable) values cannot
    # make the check itself raise; such calls are simply reported.
    return isinstance(name, str) and name in sys.modules
|
|
|
|
|
|
def _check_file_allowed(mapped_args: dict[str, Any]) -> bool:
|
|
# If the file is in /proc we can ignore it.
|
|
args = mapped_args["args"]
|
|
path = args[0] if type(args[0]) is str else str(args[0])
|
|
return path.startswith(ALLOWED_FILE_PREFIXES)
|
|
|
|
|
|
def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool:
    """Return True if a ``time.sleep`` call seen in the event loop is tolerated.

    ``mapped_args`` is not inspected. Instead, the frame returned by
    ``get_current_frame(4)`` is examined, and the call is tolerated when that
    frame's code comes from a file whose name ends in ``pydevd.py``.
    If ``get_current_frame`` raises ValueError (presumably because the stack
    is not that deep), False is returned.
    """
    #
    # Avoid extracting the stack unless we need to since it
    # will have to access the linecache which can do blocking
    # I/O and we are trying to avoid blocking calls.
    #
    # frame[0] is us
    # frame[1] is raise_for_blocking_call
    # frame[2] is protected_loop_func
    # frame[3] is the offender
    #
    # NOTE(review): the call below requests depth 4, one past the "offender"
    # entry listed above. Whether that is intended depends on
    # get_current_frame's depth convention, which is not visible here.
    # Confirm before changing either the number or this comment.
    with suppress(ValueError):
        # presumably pydevd.py is the PyDev/debugpy debugger, whose own sleeps
        # should not be reported; TODO confirm.
        return get_current_frame(4).f_code.co_filename.endswith("pydevd.py")
    return False
|
|
|
|
|
|
def _check_load_verify_locations_call_allowed(mapped_args: dict[str, Any]) -> bool:
    """Return True if ``SSLContext.load_verify_locations`` reads no files.

    Of the three sources, ``cafile`` and ``capath`` name files on disk, while
    ``cadata`` is in-memory certificate data. The call is allowed when neither
    ``cafile`` nor ``capath`` is supplied (or both are None). Positional and
    keyword spellings are both handled. Previously only a single ``cadata``
    keyword was accepted, which flagged positional ``cadata`` and let a
    positional ``cafile`` slip through.

    The wrapper is installed on the ``SSLContext`` class, so
    ``mapped_args["args"][0]`` is expected to be the instance (``self``) and
    the real parameters start at index 1.

    Args:
        mapped_args: ``{"args": ..., "kwargs": ...}`` of the intercepted call.

    Returns:
        True when no file-system source is requested, else False.
    """
    args = mapped_args.get("args") or ()
    kwargs = mapped_args.get("kwargs") or {}
    # Signature: load_verify_locations(self, cafile=None, capath=None, cadata=None)
    cafile = args[1] if len(args) > 1 else kwargs.get("cafile")
    capath = args[2] if len(args) > 2 else kwargs.get("capath")
    return cafile is None and capath is None
|
|
|
|
|
|
@dataclass(slots=True, frozen=True)
|
|
class BlockingCall:
|
|
"""Class to hold information about a blocking call."""
|
|
|
|
original_func: Callable
|
|
object: object
|
|
function: str
|
|
check_allowed: Callable[[dict[str, Any]], bool] | None
|
|
strict: bool
|
|
strict_core: bool
|
|
skip_for_tests: bool
|
|
|
|
|
|
# Catalogue of callables that enable() wraps with protect_loop. For each entry
# the attribute ``function`` on ``object`` is replaced with a wrapper around
# ``original_func``; ``check_allowed`` (when set) and the strict flags are
# forwarded to protect_loop. Entries with ``skip_for_tests=True`` are not
# patched when ``unittest`` has been imported.
_BLOCKING_CALLS: tuple[BlockingCall, ...] = (
    # Outgoing HTTP requests through http.client.
    BlockingCall(
        original_func=HTTPConnection.putrequest,
        object=HTTPConnection,
        function="putrequest",
        check_allowed=None,
        strict=True,
        strict_core=True,
        skip_for_tests=False,
    ),
    # Sleeping; see _check_sleep_call_allowed for the one exemption.
    BlockingCall(
        original_func=time.sleep,
        object=time,
        function="sleep",
        check_allowed=_check_sleep_call_allowed,
        strict=True,
        strict_core=True,
        skip_for_tests=False,
    ),
    # Filesystem globbing and tree walking.
    BlockingCall(
        original_func=glob.glob,
        object=glob,
        function="glob",
        check_allowed=None,
        strict=False,
        strict_core=False,
        skip_for_tests=False,
    ),
    BlockingCall(
        original_func=glob.iglob,
        object=glob,
        function="iglob",
        check_allowed=None,
        strict=False,
        strict_core=False,
        skip_for_tests=False,
    ),
    BlockingCall(
        original_func=os.walk,
        object=os,
        function="walk",
        check_allowed=None,
        strict=False,
        strict_core=False,
        skip_for_tests=False,
    ),
    # Directory listing (not patched under unittest).
    BlockingCall(
        original_func=os.listdir,
        object=os,
        function="listdir",
        check_allowed=None,
        strict=False,
        strict_core=False,
        skip_for_tests=True,
    ),
    BlockingCall(
        original_func=os.scandir,
        object=os,
        function="scandir",
        check_allowed=None,
        strict=False,
        strict_core=False,
        skip_for_tests=True,
    ),
    # Built-in open(); see _check_file_allowed for exempt paths.
    BlockingCall(
        original_func=builtins.open,
        object=builtins,
        function="open",
        check_allowed=_check_file_allowed,
        strict=False,
        strict_core=False,
        skip_for_tests=True,
    ),
    # Dynamic imports; see _check_import_call_allowed for the exemption.
    BlockingCall(
        original_func=importlib.import_module,
        object=importlib,
        function="import_module",
        check_allowed=_check_import_call_allowed,
        strict=False,
        strict_core=False,
        skip_for_tests=True,
    ),
    # TLS certificate loading on SSLContext.
    BlockingCall(
        original_func=SSLContext.load_default_certs,
        object=SSLContext,
        function="load_default_certs",
        check_allowed=None,
        strict=False,
        strict_core=False,
        skip_for_tests=True,
    ),
    # See _check_load_verify_locations_call_allowed for the exemption.
    BlockingCall(
        original_func=SSLContext.load_verify_locations,
        object=SSLContext,
        function="load_verify_locations",
        check_allowed=_check_load_verify_locations_call_allowed,
        strict=False,
        strict_core=False,
        skip_for_tests=True,
    ),
    BlockingCall(
        original_func=SSLContext.load_cert_chain,
        object=SSLContext,
        function="load_cert_chain",
        check_allowed=None,
        strict=False,
        strict_core=False,
        skip_for_tests=True,
    ),
    # pathlib.Path file access; see _check_file_allowed for exempt paths.
    BlockingCall(
        original_func=Path.open,
        object=Path,
        function="open",
        check_allowed=_check_file_allowed,
        strict=False,
        strict_core=False,
        skip_for_tests=True,
    ),
    BlockingCall(
        original_func=Path.read_text,
        object=Path,
        function="read_text",
        check_allowed=_check_file_allowed,
        strict=False,
        strict_core=False,
        skip_for_tests=True,
    ),
    BlockingCall(
        original_func=Path.read_bytes,
        object=Path,
        function="read_bytes",
        check_allowed=_check_file_allowed,
        strict=False,
        strict_core=False,
        skip_for_tests=True,
    ),
    BlockingCall(
        original_func=Path.write_text,
        object=Path,
        function="write_text",
        check_allowed=_check_file_allowed,
        strict=False,
        strict_core=False,
        skip_for_tests=True,
    ),
    BlockingCall(
        original_func=Path.write_bytes,
        object=Path,
        function="write_bytes",
        check_allowed=_check_file_allowed,
        strict=False,
        strict_core=False,
        skip_for_tests=True,
    ),
)
|
|
|
|
|
|
@dataclass(slots=True)
class BlockedCalls:
    """Mutable registry of the BlockingCall entries whose wrappers are installed."""

    # Filled in as wrappers are installed; while it is empty, enable() treats
    # detection as not yet enabled.
    calls: set[BlockingCall]


# Module-wide registry shared by enable().
_BLOCKED_CALLS = BlockedCalls(calls=set())
|
|
|
|
|
|
def enable() -> None:
    """Install loop-protection wrappers for every catalogued blocking call.

    The thread that calls this function is passed to protect_loop as the
    event-loop thread. Entries marked ``skip_for_tests`` are left alone while
    running under unittest.

    Raises:
        RuntimeError: if the wrappers have already been installed.
    """
    installed = _BLOCKED_CALLS.calls
    if installed:
        raise RuntimeError("Blocking call detection is already enabled")

    thread_id = threading.get_ident()
    # Lazily filtered, so the skip condition is evaluated per entry exactly
    # as it would be inside the loop body.
    wanted = (
        entry
        for entry in _BLOCKING_CALLS
        if not (_IN_TESTS and entry.skip_for_tests)
    )
    for entry in wanted:
        wrapper = protect_loop(
            entry.original_func,
            strict=entry.strict,
            strict_core=entry.strict_core,
            check_allowed=entry.check_allowed,
            loop_thread_id=thread_id,
        )
        setattr(entry.object, entry.function, wrapper)
        installed.add(entry)
|