Fix cancellation leaking upward from the timeout util (#129003)

pull/129181/head
J. Nick Koston 2024-10-23 12:00:01 -10:00 committed by Franck Nijhof
parent 9dd8c0cc4f
commit 096d50617f
No known key found for this signature in database
GPG Key ID: D62583BA8AB11CA3
2 changed files with 143 additions and 4 deletions

View File

@ -16,7 +16,7 @@ from .async_ import run_callback_threadsafe
ZONE_GLOBAL = "global"
class _State(str, enum.Enum):
class _State(enum.Enum):
"""States of a task."""
INIT = "INIT"
@ -160,11 +160,16 @@ class _GlobalTaskContext:
self._wait_zone: asyncio.Event = asyncio.Event()
self._state: _State = _State.INIT
self._cool_down: float = cool_down
self._cancelling = 0
async def __aenter__(self) -> Self:
self._manager.global_tasks.append(self)
self._start_timer()
self._state = _State.ACTIVE
# Remember if the task was already cancelling
# so when we __aexit__ we can decide if we should
# raise asyncio.TimeoutError or let the cancellation propagate
self._cancelling = self._task.cancelling()
return self
async def __aexit__(
@ -177,7 +182,15 @@ class _GlobalTaskContext:
self._manager.global_tasks.remove(self)
# Timeout on exit
if exc_type is asyncio.CancelledError and self.state == _State.TIMEOUT:
if exc_type is asyncio.CancelledError and self.state is _State.TIMEOUT:
# The timeout was hit, and the task was cancelled
# so we need to uncancel the task since the cancellation
# should not leak out of the context manager
if self._task.uncancel() > self._cancelling:
# If the task was already cancelling don't raise
# asyncio.TimeoutError and instead return None
# to allow the cancellation to propagate
return None
raise TimeoutError
self._state = _State.EXIT
@ -266,6 +279,7 @@ class _ZoneTaskContext:
self._time_left: float = timeout
self._expiration_time: float | None = None
self._timeout_handler: asyncio.Handle | None = None
self._cancelling = 0
@property
def state(self) -> _State:
@ -280,6 +294,11 @@ class _ZoneTaskContext:
if self._zone.freezes_done:
self._start_timer()
# Remember if the task was already cancelling
# so when we __aexit__ we can decide if we should
# raise asyncio.TimeoutError or let the cancellation propagate
self._cancelling = self._task.cancelling()
return self
async def __aexit__(
@ -292,7 +311,15 @@ class _ZoneTaskContext:
self._stop_timer()
# Timeout on exit
if exc_type is asyncio.CancelledError and self.state == _State.TIMEOUT:
if exc_type is asyncio.CancelledError and self.state is _State.TIMEOUT:
# The timeout was hit, and the task was cancelled
# so we need to uncancel the task since the cancellation
# should not leak out of the context manager
if self._task.uncancel() > self._cancelling:
# If the task was already cancelling don't raise
# asyncio.TimeoutError and instead return None
# to allow the cancellation to propagate
return None
raise TimeoutError
self._state = _State.EXIT

View File

@ -146,6 +146,62 @@ async def test_simple_global_timeout_freeze_with_executor_job(
await hass.async_add_executor_job(time.sleep, 0.3)
async def test_simple_global_timeout_does_not_leak_upward(
hass: HomeAssistant,
) -> None:
"""Test a global timeout does not leak upward."""
timeout = TimeoutManager()
current_task = asyncio.current_task()
assert current_task is not None
cancelling_inside_timeout = None
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
async with timeout.async_timeout(0.1):
cancelling_inside_timeout = current_task.cancelling()
await asyncio.sleep(0.3)
assert cancelling_inside_timeout == 0
# After the context manager exits, the task should no longer be cancelling
assert current_task.cancelling() == 0
async def test_simple_global_timeout_does_swallow_cancellation(
hass: HomeAssistant,
) -> None:
"""Test a global timeout does not swallow cancellation."""
timeout = TimeoutManager()
current_task = asyncio.current_task()
assert current_task is not None
cancelling_inside_timeout = None
async def task_with_timeout() -> None:
nonlocal cancelling_inside_timeout
new_task = asyncio.current_task()
assert new_task is not None
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
cancelling_inside_timeout = new_task.cancelling()
async with timeout.async_timeout(0.1):
await asyncio.sleep(0.3)
# After the context manager exits, the task should no longer be cancelling
assert current_task.cancelling() == 0
task = asyncio.create_task(task_with_timeout())
await asyncio.sleep(0)
task.cancel()
assert task.cancelling() == 1
assert cancelling_inside_timeout == 0
# Cancellation should not leak into the current task
assert current_task.cancelling() == 0
# Cancellation should not be swallowed if the task is cancelled
# and it also times out
await asyncio.sleep(0)
with pytest.raises(asyncio.CancelledError):
await task
assert task.cancelling() == 1
async def test_simple_global_timeout_freeze_reset() -> None:
"""Test a simple global timeout freeze reset."""
timeout = TimeoutManager()
@ -166,6 +222,62 @@ async def test_simple_zone_timeout() -> None:
await asyncio.sleep(0.3)
async def test_simple_zone_timeout_does_not_leak_upward(
hass: HomeAssistant,
) -> None:
"""Test a zone timeout does not leak upward."""
timeout = TimeoutManager()
current_task = asyncio.current_task()
assert current_task is not None
cancelling_inside_timeout = None
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
async with timeout.async_timeout(0.1, "test"):
cancelling_inside_timeout = current_task.cancelling()
await asyncio.sleep(0.3)
assert cancelling_inside_timeout == 0
# After the context manager exits, the task should no longer be cancelling
assert current_task.cancelling() == 0
async def test_simple_zone_timeout_does_swallow_cancellation(
hass: HomeAssistant,
) -> None:
"""Test a zone timeout does not swallow cancellation."""
timeout = TimeoutManager()
current_task = asyncio.current_task()
assert current_task is not None
cancelling_inside_timeout = None
async def task_with_timeout() -> None:
nonlocal cancelling_inside_timeout
new_task = asyncio.current_task()
assert new_task is not None
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
async with timeout.async_timeout(0.1, "test"):
cancelling_inside_timeout = current_task.cancelling()
await asyncio.sleep(0.3)
# After the context manager exits, the task should no longer be cancelling
assert current_task.cancelling() == 0
task = asyncio.create_task(task_with_timeout())
await asyncio.sleep(0)
task.cancel()
assert task.cancelling() == 1
# Cancellation should not leak into the current task
assert cancelling_inside_timeout == 0
assert current_task.cancelling() == 0
# Cancellation should not be swallowed if the task is cancelled
# and it also times out
await asyncio.sleep(0)
with pytest.raises(asyncio.CancelledError):
await task
assert task.cancelling() == 1
async def test_multiple_zone_timeout() -> None:
"""Test a simple zone timeout."""
timeout = TimeoutManager()
@ -327,7 +439,7 @@ async def test_simple_zone_timeout_freeze_without_timeout_exeption() -> None:
await asyncio.sleep(0.4)
async def test_simple_zone_timeout_zone_with_timeout_exeption() -> None:
async def test_simple_zone_timeout_zone_with_timeout_exception() -> None:
"""Test a simple zone timeout freeze on a zone that does not have a timeout set."""
timeout = TimeoutManager()