Fix cancellation leaking upward from the timeout util (#129003)
parent
9dd8c0cc4f
commit
096d50617f
|
@ -16,7 +16,7 @@ from .async_ import run_callback_threadsafe
|
|||
ZONE_GLOBAL = "global"
|
||||
|
||||
|
||||
class _State(str, enum.Enum):
|
||||
class _State(enum.Enum):
|
||||
"""States of a task."""
|
||||
|
||||
INIT = "INIT"
|
||||
|
@ -160,11 +160,16 @@ class _GlobalTaskContext:
|
|||
self._wait_zone: asyncio.Event = asyncio.Event()
|
||||
self._state: _State = _State.INIT
|
||||
self._cool_down: float = cool_down
|
||||
self._cancelling = 0
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
self._manager.global_tasks.append(self)
|
||||
self._start_timer()
|
||||
self._state = _State.ACTIVE
|
||||
# Remember if the task was already cancelling
|
||||
# so when we __aexit__ we can decide if we should
|
||||
# raise asyncio.TimeoutError or let the cancellation propagate
|
||||
self._cancelling = self._task.cancelling()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
|
@ -177,7 +182,15 @@ class _GlobalTaskContext:
|
|||
self._manager.global_tasks.remove(self)
|
||||
|
||||
# Timeout on exit
|
||||
if exc_type is asyncio.CancelledError and self.state == _State.TIMEOUT:
|
||||
if exc_type is asyncio.CancelledError and self.state is _State.TIMEOUT:
|
||||
# The timeout was hit, and the task was cancelled
|
||||
# so we need to uncancel the task since the cancellation
|
||||
# should not leak out of the context manager
|
||||
if self._task.uncancel() > self._cancelling:
|
||||
# If the task was already cancelling don't raise
|
||||
# asyncio.TimeoutError and instead return None
|
||||
# to allow the cancellation to propagate
|
||||
return None
|
||||
raise TimeoutError
|
||||
|
||||
self._state = _State.EXIT
|
||||
|
@ -266,6 +279,7 @@ class _ZoneTaskContext:
|
|||
self._time_left: float = timeout
|
||||
self._expiration_time: float | None = None
|
||||
self._timeout_handler: asyncio.Handle | None = None
|
||||
self._cancelling = 0
|
||||
|
||||
@property
|
||||
def state(self) -> _State:
|
||||
|
@ -280,6 +294,11 @@ class _ZoneTaskContext:
|
|||
if self._zone.freezes_done:
|
||||
self._start_timer()
|
||||
|
||||
# Remember if the task was already cancelling
|
||||
# so when we __aexit__ we can decide if we should
|
||||
# raise asyncio.TimeoutError or let the cancellation propagate
|
||||
self._cancelling = self._task.cancelling()
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
|
@ -292,7 +311,15 @@ class _ZoneTaskContext:
|
|||
self._stop_timer()
|
||||
|
||||
# Timeout on exit
|
||||
if exc_type is asyncio.CancelledError and self.state == _State.TIMEOUT:
|
||||
if exc_type is asyncio.CancelledError and self.state is _State.TIMEOUT:
|
||||
# The timeout was hit, and the task was cancelled
|
||||
# so we need to uncancel the task since the cancellation
|
||||
# should not leak out of the context manager
|
||||
if self._task.uncancel() > self._cancelling:
|
||||
# If the task was already cancelling don't raise
|
||||
# asyncio.TimeoutError and instead return None
|
||||
# to allow the cancellation to propagate
|
||||
return None
|
||||
raise TimeoutError
|
||||
|
||||
self._state = _State.EXIT
|
||||
|
|
|
@ -146,6 +146,62 @@ async def test_simple_global_timeout_freeze_with_executor_job(
|
|||
await hass.async_add_executor_job(time.sleep, 0.3)
|
||||
|
||||
|
||||
async def test_simple_global_timeout_does_not_leak_upward(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test a global timeout does not leak upward."""
|
||||
timeout = TimeoutManager()
|
||||
current_task = asyncio.current_task()
|
||||
assert current_task is not None
|
||||
cancelling_inside_timeout = None
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||
async with timeout.async_timeout(0.1):
|
||||
cancelling_inside_timeout = current_task.cancelling()
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
assert cancelling_inside_timeout == 0
|
||||
# After the context manager exits, the task should no longer be cancelling
|
||||
assert current_task.cancelling() == 0
|
||||
|
||||
|
||||
async def test_simple_global_timeout_does_swallow_cancellation(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test a global timeout does not swallow cancellation."""
|
||||
timeout = TimeoutManager()
|
||||
current_task = asyncio.current_task()
|
||||
assert current_task is not None
|
||||
cancelling_inside_timeout = None
|
||||
|
||||
async def task_with_timeout() -> None:
|
||||
nonlocal cancelling_inside_timeout
|
||||
new_task = asyncio.current_task()
|
||||
assert new_task is not None
|
||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||
cancelling_inside_timeout = new_task.cancelling()
|
||||
async with timeout.async_timeout(0.1):
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# After the context manager exits, the task should no longer be cancelling
|
||||
assert current_task.cancelling() == 0
|
||||
|
||||
task = asyncio.create_task(task_with_timeout())
|
||||
await asyncio.sleep(0)
|
||||
task.cancel()
|
||||
assert task.cancelling() == 1
|
||||
|
||||
assert cancelling_inside_timeout == 0
|
||||
# Cancellation should not leak into the current task
|
||||
assert current_task.cancelling() == 0
|
||||
# Cancellation should not be swallowed if the task is cancelled
|
||||
# and it also times out
|
||||
await asyncio.sleep(0)
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
assert task.cancelling() == 1
|
||||
|
||||
|
||||
async def test_simple_global_timeout_freeze_reset() -> None:
|
||||
"""Test a simple global timeout freeze reset."""
|
||||
timeout = TimeoutManager()
|
||||
|
@ -166,6 +222,62 @@ async def test_simple_zone_timeout() -> None:
|
|||
await asyncio.sleep(0.3)
|
||||
|
||||
|
||||
async def test_simple_zone_timeout_does_not_leak_upward(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test a zone timeout does not leak upward."""
|
||||
timeout = TimeoutManager()
|
||||
current_task = asyncio.current_task()
|
||||
assert current_task is not None
|
||||
cancelling_inside_timeout = None
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||
async with timeout.async_timeout(0.1, "test"):
|
||||
cancelling_inside_timeout = current_task.cancelling()
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
assert cancelling_inside_timeout == 0
|
||||
# After the context manager exits, the task should no longer be cancelling
|
||||
assert current_task.cancelling() == 0
|
||||
|
||||
|
||||
async def test_simple_zone_timeout_does_swallow_cancellation(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test a zone timeout does not swallow cancellation."""
|
||||
timeout = TimeoutManager()
|
||||
current_task = asyncio.current_task()
|
||||
assert current_task is not None
|
||||
cancelling_inside_timeout = None
|
||||
|
||||
async def task_with_timeout() -> None:
|
||||
nonlocal cancelling_inside_timeout
|
||||
new_task = asyncio.current_task()
|
||||
assert new_task is not None
|
||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||
async with timeout.async_timeout(0.1, "test"):
|
||||
cancelling_inside_timeout = current_task.cancelling()
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# After the context manager exits, the task should no longer be cancelling
|
||||
assert current_task.cancelling() == 0
|
||||
|
||||
task = asyncio.create_task(task_with_timeout())
|
||||
await asyncio.sleep(0)
|
||||
task.cancel()
|
||||
assert task.cancelling() == 1
|
||||
|
||||
# Cancellation should not leak into the current task
|
||||
assert cancelling_inside_timeout == 0
|
||||
assert current_task.cancelling() == 0
|
||||
# Cancellation should not be swallowed if the task is cancelled
|
||||
# and it also times out
|
||||
await asyncio.sleep(0)
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
assert task.cancelling() == 1
|
||||
|
||||
|
||||
async def test_multiple_zone_timeout() -> None:
|
||||
"""Test a simple zone timeout."""
|
||||
timeout = TimeoutManager()
|
||||
|
@ -327,7 +439,7 @@ async def test_simple_zone_timeout_freeze_without_timeout_exeption() -> None:
|
|||
await asyncio.sleep(0.4)
|
||||
|
||||
|
||||
async def test_simple_zone_timeout_zone_with_timeout_exeption() -> None:
|
||||
async def test_simple_zone_timeout_zone_with_timeout_exception() -> None:
|
||||
"""Test a simple zone timeout freeze on a zone that does not have a timeout set."""
|
||||
timeout = TimeoutManager()
|
||||
|
||||
|
|
Loading…
Reference in New Issue