Skip to content

Commit e7f750b

Browse files
authored
Fixed memory object stream sometimes dropping sent items (#735)
Check if the receiving task has a pending cancellation before sending an item. Fixes #728.
1 parent 9f5f14b commit e7f750b

File tree

8 files changed

+154
-56
lines changed

8 files changed

+154
-56
lines changed

Diff for: docs/versionhistory.rst

+8
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
1010
portals
1111
- Added ``__slots__`` to ``AsyncResource`` so that child classes can use ``__slots__``
1212
(`#733 <https://github.com/agronholm/anyio/pull/733>`_; PR by Justin Su)
13+
- Added the ``TaskInfo.has_pending_cancellation()`` method
14+
- Fixed erroneous ``RuntimeError: called 'started' twice on the same task status``
15+
when cancelling a task in a TaskGroup created with the ``start()`` method before
16+
the first checkpoint is reached after calling ``task_status.started()``
17+
(`#706 <https://github.com/agronholm/anyio/issues/706>`_; PR by Dominik Schwabe)
1318
- Fixed two bugs with ``TaskGroup.start()`` on asyncio:
1419

1520
* Fixed erroneous ``RuntimeError: called 'started' twice on the same task status``
@@ -32,6 +37,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
3237
variable when setting the ``debug`` flag in ``anyio.run()``
3338
- Fixed ``SocketStream.receive()`` not detecting EOF on asyncio if there is also data in
3439
the read buffer (`#701 <https://github.com/agronholm/anyio/issues/701>`_)
40+
- Fixed ``MemoryObjectStream`` dropping an item if the item is delivered to a recipient
41+
that is waiting to receive an item but has a cancellation pending
42+
(`#728 <https://github.com/agronholm/anyio/issues/728>`_)
3543
- Emit a ``ResourceWarning`` for ``MemoryObjectReceiveStream`` and
3644
``MemoryObjectSendStream`` that were garbage collected without being closed (PR by
3745
Andrey Kazantcev)

Diff for: src/anyio/_backends/_asyncio.py

+35-12
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import socket
88
import sys
99
import threading
10+
import weakref
1011
from asyncio import (
1112
AbstractEventLoop,
1213
CancelledError,
@@ -596,14 +597,14 @@ class TaskState:
596597
itself because there are no guarantees about its implementation.
597598
"""
598599

599-
__slots__ = "parent_id", "cancel_scope"
600+
__slots__ = "parent_id", "cancel_scope", "__weakref__"
600601

601602
def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
602603
self.parent_id = parent_id
603604
self.cancel_scope = cancel_scope
604605

605606

606-
_task_states = WeakKeyDictionary() # type: WeakKeyDictionary[asyncio.Task, TaskState]
607+
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
607608

608609

609610
#
@@ -1833,14 +1834,36 @@ async def __anext__(self) -> Signals:
18331834
#
18341835

18351836

1836-
def _create_task_info(task: asyncio.Task) -> TaskInfo:
1837-
task_state = _task_states.get(task)
1838-
if task_state is None:
1839-
parent_id = None
1840-
else:
1841-
parent_id = task_state.parent_id
1837+
class AsyncIOTaskInfo(TaskInfo):
1838+
def __init__(self, task: asyncio.Task):
1839+
task_state = _task_states.get(task)
1840+
if task_state is None:
1841+
parent_id = None
1842+
else:
1843+
parent_id = task_state.parent_id
1844+
1845+
super().__init__(id(task), parent_id, task.get_name(), task.get_coro())
1846+
self._task = weakref.ref(task)
18421847

1843-
return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro())
1848+
def has_pending_cancellation(self) -> bool:
1849+
if not (task := self._task()):
1850+
# If the task isn't around anymore, it won't have a pending cancellation
1851+
return False
1852+
1853+
if sys.version_info >= (3, 11):
1854+
if task.cancelling():
1855+
return True
1856+
elif (
1857+
isinstance(task._fut_waiter, asyncio.Future)
1858+
and task._fut_waiter.cancelled()
1859+
):
1860+
return True
1861+
1862+
if task_state := _task_states.get(task):
1863+
if cancel_scope := task_state.cancel_scope:
1864+
return cancel_scope.cancel_called or cancel_scope._parent_cancelled()
1865+
1866+
return False
18441867

18451868

18461869
class TestRunner(abc.TestRunner):
@@ -2458,11 +2481,11 @@ def open_signal_receiver(
24582481

24592482
@classmethod
24602483
def get_current_task(cls) -> TaskInfo:
2461-
return _create_task_info(current_task()) # type: ignore[arg-type]
2484+
return AsyncIOTaskInfo(current_task()) # type: ignore[arg-type]
24622485

24632486
@classmethod
2464-
def get_running_tasks(cls) -> list[TaskInfo]:
2465-
return [_create_task_info(task) for task in all_tasks() if not task.done()]
2487+
def get_running_tasks(cls) -> Sequence[TaskInfo]:
2488+
return [AsyncIOTaskInfo(task) for task in all_tasks() if not task.done()]
24662489

24672490
@classmethod
24682491
async def wait_all_tasks_blocked(cls) -> None:

Diff for: src/anyio/_backends/_trio.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import socket
66
import sys
77
import types
8+
import weakref
89
from collections.abc import AsyncIterator, Iterable
910
from concurrent.futures import Future
1011
from dataclasses import dataclass
@@ -839,6 +840,24 @@ def run_test(
839840
self._call_in_runner_task(test_func, **kwargs)
840841

841842

843+
class TrioTaskInfo(TaskInfo):
844+
def __init__(self, task: trio.lowlevel.Task):
845+
parent_id = None
846+
if task.parent_nursery and task.parent_nursery.parent_task:
847+
parent_id = id(task.parent_nursery.parent_task)
848+
849+
super().__init__(id(task), parent_id, task.name, task.coro)
850+
self._task = weakref.proxy(task)
851+
852+
def has_pending_cancellation(self) -> bool:
853+
try:
854+
return self._task._cancel_status.effectively_cancelled
855+
except ReferenceError:
856+
# If the task is no longer around, it surely doesn't have a cancellation
857+
# pending
858+
return False
859+
860+
842861
class TrioBackend(AsyncBackend):
843862
@classmethod
844863
def run(
@@ -1125,28 +1144,19 @@ def open_signal_receiver(
11251144
@classmethod
11261145
def get_current_task(cls) -> TaskInfo:
11271146
task = current_task()
1128-
1129-
parent_id = None
1130-
if task.parent_nursery and task.parent_nursery.parent_task:
1131-
parent_id = id(task.parent_nursery.parent_task)
1132-
1133-
return TaskInfo(id(task), parent_id, task.name, task.coro)
1147+
return TrioTaskInfo(task)
11341148

11351149
@classmethod
1136-
def get_running_tasks(cls) -> list[TaskInfo]:
1150+
def get_running_tasks(cls) -> Sequence[TaskInfo]:
11371151
root_task = current_root_task()
11381152
assert root_task
1139-
task_infos = [TaskInfo(id(root_task), None, root_task.name, root_task.coro)]
1153+
task_infos = [TrioTaskInfo(root_task)]
11401154
nurseries = root_task.child_nurseries
11411155
while nurseries:
11421156
new_nurseries: list[trio.Nursery] = []
11431157
for nursery in nurseries:
11441158
for task in nursery.child_tasks:
1145-
task_infos.append(
1146-
TaskInfo(
1147-
id(task), id(nursery.parent_task), task.name, task.coro
1148-
)
1149-
)
1159+
task_infos.append(TrioTaskInfo(task))
11501160
new_nurseries.extend(task.child_nurseries)
11511161

11521162
nurseries = new_nurseries

Diff for: src/anyio/_core/_testing.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Awaitable, Generator
4-
from typing import Any
4+
from typing import Any, cast
55

66
from ._eventloop import get_async_backend
77

@@ -45,8 +45,12 @@ def __hash__(self) -> int:
4545
def __repr__(self) -> str:
4646
return f"{self.__class__.__name__}(id={self.id!r}, name={self.name!r})"
4747

48-
def _unwrap(self) -> TaskInfo:
49-
return self
48+
def has_pending_cancellation(self) -> bool:
49+
"""
50+
Return ``True`` if the task has a cancellation pending, ``False`` otherwise.
51+
52+
"""
53+
return False
5054

5155

5256
def get_current_task() -> TaskInfo:
@@ -66,7 +70,7 @@ def get_running_tasks() -> list[TaskInfo]:
6670
:return: a list of task info objects
6771
6872
"""
69-
return get_async_backend().get_running_tasks()
73+
return cast("list[TaskInfo]", get_async_backend().get_running_tasks())
7074

7175

7276
async def wait_all_tasks_blocked() -> None:

Diff for: src/anyio/abc/_eventloop.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def get_current_task(cls) -> TaskInfo:
376376

377377
@classmethod
378378
@abstractmethod
379-
def get_running_tasks(cls) -> list[TaskInfo]:
379+
def get_running_tasks(cls) -> Sequence[TaskInfo]:
380380
pass
381381

382382
@classmethod

Diff for: src/anyio/streams/memory.py

+21-11
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
EndOfStream,
1313
WouldBlock,
1414
)
15+
from .._core._testing import TaskInfo, get_current_task
1516
from ..abc import Event, ObjectReceiveStream, ObjectSendStream
1617
from ..lowlevel import checkpoint
1718

@@ -32,13 +33,19 @@ class MemoryObjectStreamStatistics(NamedTuple):
3233
tasks_waiting_receive: int
3334

3435

36+
@dataclass(eq=False)
37+
class MemoryObjectItemReceiver(Generic[T_Item]):
38+
task_info: TaskInfo = field(init=False, default_factory=get_current_task)
39+
item: T_Item = field(init=False)
40+
41+
3542
@dataclass(eq=False)
3643
class MemoryObjectStreamState(Generic[T_Item]):
3744
max_buffer_size: float = field()
3845
buffer: deque[T_Item] = field(init=False, default_factory=deque)
3946
open_send_channels: int = field(init=False, default=0)
4047
open_receive_channels: int = field(init=False, default=0)
41-
waiting_receivers: OrderedDict[Event, list[T_Item]] = field(
48+
waiting_receivers: OrderedDict[Event, MemoryObjectItemReceiver[T_Item]] = field(
4249
init=False, default_factory=OrderedDict
4350
)
4451
waiting_senders: OrderedDict[Event, T_Item] = field(
@@ -99,17 +106,17 @@ async def receive(self) -> T_co:
99106
except WouldBlock:
100107
# Add ourselves in the queue
101108
receive_event = Event()
102-
container: list[T_co] = []
103-
self._state.waiting_receivers[receive_event] = container
109+
receiver = MemoryObjectItemReceiver[T_co]()
110+
self._state.waiting_receivers[receive_event] = receiver
104111

105112
try:
106113
await receive_event.wait()
107114
finally:
108115
self._state.waiting_receivers.pop(receive_event, None)
109116

110-
if container:
111-
return container[0]
112-
else:
117+
try:
118+
return receiver.item
119+
except AttributeError:
113120
raise EndOfStream
114121

115122
def clone(self) -> MemoryObjectReceiveStream[T_co]:
@@ -199,11 +206,14 @@ def send_nowait(self, item: T_contra) -> None:
199206
if not self._state.open_receive_channels:
200207
raise BrokenResourceError
201208

202-
if self._state.waiting_receivers:
203-
receive_event, container = self._state.waiting_receivers.popitem(last=False)
204-
container.append(item)
205-
receive_event.set()
206-
elif len(self._state.buffer) < self._state.max_buffer_size:
209+
while self._state.waiting_receivers:
210+
receive_event, receiver = self._state.waiting_receivers.popitem(last=False)
211+
if not receiver.task_info.has_pending_cancellation():
212+
receiver.item = item
213+
receive_event.set()
214+
return
215+
216+
if len(self._state.buffer) < self._state.max_buffer_size:
207217
self._state.buffer.append(item)
208218
else:
209219
raise WouldBlock

Diff for: tests/streams/test_memory.py

+57-14
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
fail_after,
1818
wait_all_tasks_blocked,
1919
)
20-
from anyio.abc import ObjectReceiveStream, ObjectSendStream
20+
from anyio.abc import ObjectReceiveStream, ObjectSendStream, TaskStatus
2121
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2222

2323
if sys.version_info < (3, 11):
@@ -305,28 +305,49 @@ async def test_cancel_during_receive() -> None:
305305
stream to be lost.
306306
307307
"""
308-
receiver_scope = None
309308

310-
async def scoped_receiver() -> None:
311-
nonlocal receiver_scope
312-
with CancelScope() as receiver_scope:
309+
async def scoped_receiver(task_status: TaskStatus[CancelScope]) -> None:
310+
with CancelScope() as cancel_scope:
311+
task_status.started(cancel_scope)
313312
received.append(await receive.receive())
314313

315-
assert receiver_scope.cancel_called
314+
assert cancel_scope.cancel_called
316315

317316
received: list[str] = []
318317
send, receive = create_memory_object_stream[str]()
319-
async with create_task_group() as tg:
320-
tg.start_soon(scoped_receiver)
321-
await wait_all_tasks_blocked()
322-
send.send_nowait("hello")
323-
assert receiver_scope is not None
324-
receiver_scope.cancel()
318+
with send, receive:
319+
async with create_task_group() as tg:
320+
receiver_scope = await tg.start(scoped_receiver)
321+
await wait_all_tasks_blocked()
322+
send.send_nowait("hello")
323+
receiver_scope.cancel()
325324

326325
assert received == ["hello"]
327326

328-
send.close()
329-
receive.close()
327+
328+
async def test_cancel_during_receive_buffered() -> None:
329+
"""
330+
Test that sending an item to a memory object stream when the receiver that is next
331+
in line has been cancelled will not result in the item being lost.
332+
"""
333+
334+
async def scoped_receiver(
335+
receive: MemoryObjectReceiveStream[str], task_status: TaskStatus[CancelScope]
336+
) -> None:
337+
with CancelScope() as cancel_scope:
338+
task_status.started(cancel_scope)
339+
await receive.receive()
340+
341+
send, receive = create_memory_object_stream[str](1)
342+
with send, receive:
343+
async with create_task_group() as tg:
344+
cancel_scope = await tg.start(scoped_receiver, receive)
345+
await wait_all_tasks_blocked()
346+
cancel_scope.cancel()
347+
send.send_nowait("item")
348+
349+
# Since the item was not sent to the cancelled task, it should be available here
350+
assert receive.receive_nowait() == "item"
330351

331352

332353
async def test_close_receive_after_send() -> None:
@@ -455,3 +476,25 @@ async def test_not_closed_warning() -> None:
455476
with pytest.warns(ResourceWarning, match="Unclosed <MemoryObjectReceiveStream>"):
456477
del receive
457478
gc.collect()
479+
480+
481+
@pytest.mark.parametrize("anyio_backend", ["asyncio"], indirect=True)
482+
async def test_send_to_natively_cancelled_receiver() -> None:
483+
"""
484+
Test that if a task waiting on receive.receive() is cancelled and then another
485+
task sends an item, said item is not delivered to the task with a pending
486+
cancellation, but rather to the next one in line.
487+
488+
"""
489+
from asyncio import CancelledError, create_task
490+
491+
send, receive = create_memory_object_stream[str](1)
492+
with send, receive:
493+
receive_task = create_task(receive.receive())
494+
await wait_all_tasks_blocked() # ensure that the task is waiting to receive
495+
receive_task.cancel()
496+
send.send_nowait("hello")
497+
with pytest.raises(CancelledError):
498+
await receive_task
499+
500+
assert receive.receive_nowait() == "hello"

Diff for: tests/test_debugging.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ async def inspect() -> None:
9696
for task, expected_name in zip(task_infos, expected_names):
9797
assert task.parent_id == host_task.id
9898
assert task.name == expected_name
99-
assert repr(task) == f"TaskInfo(id={task.id}, name={expected_name!r})"
99+
assert repr(task).endswith(f"TaskInfo(id={task.id}, name={expected_name!r})")
100100

101101

102102
@pytest.mark.skipif(

0 commit comments

Comments
 (0)