Skip to content

Commit bcdedc4

Browse files
committed
Fix Event.wait() raising cancelled on asyncio when set() before scope cancelled
Fixes agronholm#536.
1 parent bfdc46a commit bcdedc4

File tree

5 files changed

+120
-15
lines changed

5 files changed

+120
-15
lines changed

Diff for: docs/versionhistory.rst

+5
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
4444
the event loop to be closed
4545
- Fixed ``current_effective_deadline()`` not returning ``-inf`` on asyncio when the
4646
currently active cancel scope has been cancelled (PR by Ganden Schaffner)
47+
- Fixed ``Event.set()`` failing to notify a waiter on asyncio if an ``Event.wait()``'s
48+
scope was cancelled after ``Event.set()`` but before the the scheduler resumed the
49+
waiting task. This also fixed a race condition where ``MemoryObjectSendStream.send()``
50+
could raise a ``CancelledError`` on asyncio after successfully delivering an item to a
51+
receiver (PR by Ganden Schaffner)
4752

4853
**3.6.1**
4954

Diff for: src/anyio/_backends/_asyncio.py

+65-3
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,42 @@
8080
from ..lowlevel import RunVar
8181
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
8282

83-
if sys.version_info < (3, 11):
83+
if sys.version_info >= (3, 11):
84+
85+
def cancelling(task: asyncio.Task) -> bool:
86+
"""
87+
Return ``True`` if the task is cancelling.
88+
89+
NOTE: If the task finished cancelling and is now done, this function can return
90+
anything. This is because on Python >= 3.11, the task can be uncancelled after
91+
it finishes. (One might think we could avoid this by instead returning
92+
``bool(task.cancelling()) or task.cancelled()``, but on Python < 3.8
93+
``task.cancelled()`` can be ``False`` when it should be ``True``. On Python <
94+
3.8 it appears to be impossible to determine whether a done task was cancelled
95+
or not (see https://github.com/python/cpython/pull/16330).)
96+
97+
"""
98+
return bool(task.cancelling())
99+
100+
else:
101+
102+
def cancelling(task: asyncio.Task) -> bool:
103+
if task.cancelled():
104+
return True
105+
106+
if task._must_cancel: # type: ignore[attr-defined]
107+
return True
108+
109+
waiter = task._fut_waiter # type: ignore[attr-defined]
110+
if waiter is None:
111+
return False
112+
if waiter.cancelled():
113+
return True
114+
elif isinstance(waiter, asyncio.Task):
115+
return cancelling(waiter)
116+
else:
117+
return False
118+
84119
from exceptiongroup import BaseExceptionGroup, ExceptionGroup
85120

86121
if sys.version_info >= (3, 8):
@@ -1474,16 +1509,43 @@ def __new__(cls) -> Event:
14741509

14751510
def __init__(self) -> None:
14761511
self._event = asyncio.Event()
1512+
self._waiter_cancelling_when_set: dict[asyncio.Task, bool | None] = {}
14771513

14781514
def set(self) -> None:
1479-
self._event.set()
1515+
if not self._event.is_set():
1516+
self._event.set()
1517+
for waiter in tuple(self._waiter_cancelling_when_set):
1518+
self._waiter_cancelling_when_set[waiter] = cancelling(waiter)
14801519

14811520
def is_set(self) -> bool:
14821521
return self._event.is_set()
14831522

14841523
async def wait(self) -> None:
1485-
if await self._event.wait():
1524+
if self._event.is_set():
14861525
await AsyncIOBackend.checkpoint()
1526+
else:
1527+
waiter = cast(asyncio.Task, current_task())
1528+
self._waiter_cancelling_when_set[waiter] = None
1529+
try:
1530+
if await self._event.wait():
1531+
await AsyncIOBackend.checkpoint()
1532+
except CancelledError:
1533+
if not self._event.is_set():
1534+
raise
1535+
else:
1536+
# If we are here, then the event was not set before `wait()`. Then,
1537+
# in either order:
1538+
#
1539+
# * the event got set.
1540+
# * the current cancel scope was cancelled.
1541+
#
1542+
# To match trio, `Event.wait()` must raise a cancellation exception
1543+
# now if and only if the current scope was cancelled *before* the
1544+
# event was set.
1545+
if self._waiter_cancelling_when_set[waiter]:
1546+
raise
1547+
finally:
1548+
del self._waiter_cancelling_when_set[waiter]
14871549

14881550
def statistics(self) -> EventStatistics:
14891551
return EventStatistics(len(self._event._waiters)) # type: ignore[attr-defined]

Diff for: src/anyio/streams/memory.py

+1-12
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,7 @@
55
from types import TracebackType
66
from typing import Generic, NamedTuple, TypeVar
77

8-
from .. import (
9-
BrokenResourceError,
10-
ClosedResourceError,
11-
EndOfStream,
12-
WouldBlock,
13-
get_cancelled_exc_class,
14-
)
8+
from .. import BrokenResourceError, ClosedResourceError, EndOfStream, WouldBlock
159
from ..abc import Event, ObjectReceiveStream, ObjectSendStream
1610
from ..lowlevel import checkpoint
1711

@@ -104,11 +98,6 @@ async def receive(self) -> T_co:
10498

10599
try:
106100
await receive_event.wait()
107-
except get_cancelled_exc_class():
108-
# Ignore the immediate cancellation if we already received an item, so
109-
# as not to lose it
110-
if not container:
111-
raise
112101
finally:
113102
self._state.waiting_receivers.pop(receive_event, None)
114103

Diff for: tests/_backends/__init__.py

Whitespace-only changes.

Diff for: tests/_backends/test_asyncio.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from typing import Any, cast
5+
6+
import pytest
7+
from _pytest.fixtures import getfixturemarker
8+
9+
from anyio import create_task_group
10+
from anyio._backends._asyncio import cancelling
11+
from anyio.abc import TaskStatus
12+
from anyio.lowlevel import cancel_shielded_checkpoint, checkpoint
13+
from anyio.pytest_plugin import anyio_backend_name
14+
15+
try:
16+
from .conftest import anyio_backend as parent_anyio_backend
17+
except ImportError:
18+
from ..conftest import anyio_backend as parent_anyio_backend
19+
20+
pytestmark = pytest.mark.anyio
21+
22+
# Use the inherited anyio_backend, but filter out non-asyncio
23+
anyio_backend = pytest.fixture(
24+
params=[
25+
param
26+
for param in cast(Any, getfixturemarker(parent_anyio_backend)).params
27+
if any(
28+
"asyncio"
29+
in anyio_backend_name.__wrapped__(backend) # type: ignore[attr-defined]
30+
for backend in param.values
31+
)
32+
]
33+
)(parent_anyio_backend.__wrapped__)
34+
35+
36+
async def test_cancelling() -> None:
37+
async def func(*, task_status: TaskStatus[asyncio.Task]) -> None:
38+
task = cast(asyncio.Task, asyncio.current_task())
39+
task_status.started(task)
40+
try:
41+
await checkpoint()
42+
finally:
43+
await cancel_shielded_checkpoint()
44+
45+
async with create_task_group() as tg:
46+
task = cast(asyncio.Task, await tg.start(func))
47+
assert not cancelling(task)
48+
tg.cancel_scope.cancel()
49+
assert cancelling(task)

0 commit comments

Comments
 (0)