Skip to content

Commit 535bcb5

Browse files
committed
Fixed memory object stream sometimes dropping sent items
Check if the receiving task has a pending cancellation before sending an item. Fixes #728.
1 parent 4b3de97 commit 535bcb5

File tree

8 files changed

+121
-57
lines changed

8 files changed

+121
-57
lines changed

Diff for: docs/versionhistory.rst

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
2525
variable when setting the ``debug`` flag in ``anyio.run()``
2626
- Fixed ``SocketStream.receive()`` not detecting EOF on asyncio if there is also data in
2727
the read buffer (`#701 <https://github.com/agronholm/anyio/issues/701>`_)
28+
- Fixed ``MemoryObjectStream`` dropping an item if the item is delivered to a recipient
29+
that is waiting to receive an item but has a cancellation pending
30+
(`#728 <https://github.com/agronholm/anyio/issues/728>`_)
2831
- Emit a ``ResourceWarning`` for ``MemoryObjectReceiveStream`` and
2932
``MemoryObjectSendStream`` that were garbage collected without being closed (PR by
3033
Andrey Kazantcev)

Diff for: src/anyio/_backends/_asyncio.py

+32-12
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import socket
88
import sys
99
import threading
10+
import weakref
1011
from asyncio import (
1112
AbstractEventLoop,
1213
CancelledError,
@@ -596,14 +597,14 @@ class TaskState:
596597
itself because there are no guarantees about its implementation.
597598
"""
598599

599-
__slots__ = "parent_id", "cancel_scope"
600+
__slots__ = "parent_id", "cancel_scope", "__weakref__"
600601

601602
def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
602603
self.parent_id = parent_id
603604
self.cancel_scope = cancel_scope
604605

605606

606-
_task_states = WeakKeyDictionary() # type: WeakKeyDictionary[asyncio.Task, TaskState]
607+
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
607608

608609

609610
#
@@ -1827,14 +1828,33 @@ async def __anext__(self) -> Signals:
18271828
#
18281829

18291830

1830-
def _create_task_info(task: asyncio.Task) -> TaskInfo:
1831-
task_state = _task_states.get(task)
1832-
if task_state is None:
1833-
parent_id = None
1834-
else:
1835-
parent_id = task_state.parent_id
1831+
class AsyncIOTaskInfo(TaskInfo):
1832+
def __init__(self, task: asyncio.Task):
1833+
task_state = _task_states.get(task)
1834+
if task_state is None:
1835+
parent_id = None
1836+
else:
1837+
parent_id = task_state.parent_id
1838+
1839+
super().__init__(id(task), parent_id, task.get_name(), task.get_coro())
1840+
self._task = weakref.ref(task)
1841+
1842+
def has_pending_cancellation(self) -> bool:
1843+
if not (task := self._task()):
1844+
# If the task isn't around anymore, it won't have a pending cancellation
1845+
return False
1846+
1847+
if sys.version_info >= (3, 11):
1848+
if task.cancelling():
1849+
return True
1850+
elif task._must_cancel:
1851+
return True
18361852

1837-
return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro())
1853+
if task_state := _task_states.get(task):
1854+
if cancel_scope := task_state.cancel_scope:
1855+
return cancel_scope.cancel_called or cancel_scope._parent_cancelled()
1856+
1857+
return False
18381858

18391859

18401860
class TestRunner(abc.TestRunner):
@@ -2452,11 +2472,11 @@ def open_signal_receiver(
24522472

24532473
@classmethod
24542474
def get_current_task(cls) -> TaskInfo:
2455-
return _create_task_info(current_task()) # type: ignore[arg-type]
2475+
return AsyncIOTaskInfo(current_task()) # type: ignore[arg-type]
24562476

24572477
@classmethod
2458-
def get_running_tasks(cls) -> list[TaskInfo]:
2459-
return [_create_task_info(task) for task in all_tasks() if not task.done()]
2478+
def get_running_tasks(cls) -> Sequence[TaskInfo]:
2479+
return [AsyncIOTaskInfo(task) for task in all_tasks() if not task.done()]
24602480

24612481
@classmethod
24622482
async def wait_all_tasks_blocked(cls) -> None:

Diff for: src/anyio/_backends/_trio.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import socket
66
import sys
77
import types
8+
import weakref
89
from collections.abc import AsyncIterator, Iterable
910
from concurrent.futures import Future
1011
from dataclasses import dataclass
@@ -839,6 +840,24 @@ def run_test(
839840
self._call_in_runner_task(test_func, **kwargs)
840841

841842

843+
class TrioTaskInfo(TaskInfo):
844+
def __init__(self, task: trio.lowlevel.Task):
845+
parent_id = None
846+
if task.parent_nursery and task.parent_nursery.parent_task:
847+
parent_id = id(task.parent_nursery.parent_task)
848+
849+
super().__init__(id(task), parent_id, task.name, task.coro)
850+
self._task = weakref.proxy(task)
851+
852+
def has_pending_cancellation(self) -> bool:
853+
try:
854+
return self._task._cancel_status.effectively_cancelled
855+
except ReferenceError:
856+
# If the task is no longer around, it surely doesn't have a cancellation
857+
# pending
858+
return False
859+
860+
842861
class TrioBackend(AsyncBackend):
843862
@classmethod
844863
def run(
@@ -1125,28 +1144,19 @@ def open_signal_receiver(
11251144
@classmethod
11261145
def get_current_task(cls) -> TaskInfo:
11271146
task = current_task()
1128-
1129-
parent_id = None
1130-
if task.parent_nursery and task.parent_nursery.parent_task:
1131-
parent_id = id(task.parent_nursery.parent_task)
1132-
1133-
return TaskInfo(id(task), parent_id, task.name, task.coro)
1147+
return TrioTaskInfo(task)
11341148

11351149
@classmethod
1136-
def get_running_tasks(cls) -> list[TaskInfo]:
1150+
def get_running_tasks(cls) -> Sequence[TaskInfo]:
11371151
root_task = current_root_task()
11381152
assert root_task
1139-
task_infos = [TaskInfo(id(root_task), None, root_task.name, root_task.coro)]
1153+
task_infos = [TrioTaskInfo(root_task)]
11401154
nurseries = root_task.child_nurseries
11411155
while nurseries:
11421156
new_nurseries: list[trio.Nursery] = []
11431157
for nursery in nurseries:
11441158
for task in nursery.child_tasks:
1145-
task_infos.append(
1146-
TaskInfo(
1147-
id(task), id(nursery.parent_task), task.name, task.coro
1148-
)
1149-
)
1159+
task_infos.append(TrioTaskInfo(task))
11501160
new_nurseries.extend(task.child_nurseries)
11511161

11521162
nurseries = new_nurseries

Diff for: src/anyio/_core/_testing.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Awaitable, Generator
4-
from typing import Any
4+
from typing import Any, cast
55

66
from ._eventloop import get_async_backend
77

@@ -45,8 +45,8 @@ def __hash__(self) -> int:
4545
def __repr__(self) -> str:
4646
return f"{self.__class__.__name__}(id={self.id!r}, name={self.name!r})"
4747

48-
def _unwrap(self) -> TaskInfo:
49-
return self
48+
def has_pending_cancellation(self) -> bool:
49+
return False
5050

5151

5252
def get_current_task() -> TaskInfo:
@@ -63,10 +63,10 @@ def get_running_tasks() -> list[TaskInfo]:
6363
"""
6464
Return a list of running tasks in the current event loop.
6565
66-
:return: a list of task info objects
66+
:return: a sequence of task info objects
6767
6868
"""
69-
return get_async_backend().get_running_tasks()
69+
return cast("list[TaskInfo]", get_async_backend().get_running_tasks())
7070

7171

7272
async def wait_all_tasks_blocked() -> None:

Diff for: src/anyio/abc/_eventloop.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def get_current_task(cls) -> TaskInfo:
376376

377377
@classmethod
378378
@abstractmethod
379-
def get_running_tasks(cls) -> list[TaskInfo]:
379+
def get_running_tasks(cls) -> Sequence[TaskInfo]:
380380
pass
381381

382382
@classmethod

Diff for: src/anyio/streams/memory.py

+21-11
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
EndOfStream,
1313
WouldBlock,
1414
)
15+
from .._core._testing import TaskInfo, get_current_task
1516
from ..abc import Event, ObjectReceiveStream, ObjectSendStream
1617
from ..lowlevel import checkpoint
1718

@@ -32,13 +33,19 @@ class MemoryObjectStreamStatistics(NamedTuple):
3233
tasks_waiting_receive: int
3334

3435

36+
@dataclass(eq=False)
37+
class MemoryObjectItemReceiver(Generic[T_Item]):
38+
task_info: TaskInfo = field(init=False, default_factory=get_current_task)
39+
item: T_Item = field(init=False)
40+
41+
3542
@dataclass(eq=False)
3643
class MemoryObjectStreamState(Generic[T_Item]):
3744
max_buffer_size: float = field()
3845
buffer: deque[T_Item] = field(init=False, default_factory=deque)
3946
open_send_channels: int = field(init=False, default=0)
4047
open_receive_channels: int = field(init=False, default=0)
41-
waiting_receivers: OrderedDict[Event, list[T_Item]] = field(
48+
waiting_receivers: OrderedDict[Event, MemoryObjectItemReceiver] = field(
4249
init=False, default_factory=OrderedDict
4350
)
4451
waiting_senders: OrderedDict[Event, T_Item] = field(
@@ -99,17 +106,17 @@ async def receive(self) -> T_co:
99106
except WouldBlock:
100107
# Add ourselves in the queue
101108
receive_event = Event()
102-
container: list[T_co] = []
103-
self._state.waiting_receivers[receive_event] = container
109+
receiver = MemoryObjectItemReceiver[T_co]()
110+
self._state.waiting_receivers[receive_event] = receiver
104111

105112
try:
106113
await receive_event.wait()
107114
finally:
108115
self._state.waiting_receivers.pop(receive_event, None)
109116

110-
if container:
111-
return container[0]
112-
else:
117+
try:
118+
return receiver.item
119+
except AttributeError:
113120
raise EndOfStream
114121

115122
def clone(self) -> MemoryObjectReceiveStream[T_co]:
@@ -199,11 +206,14 @@ def send_nowait(self, item: T_contra) -> None:
199206
if not self._state.open_receive_channels:
200207
raise BrokenResourceError
201208

202-
if self._state.waiting_receivers:
203-
receive_event, container = self._state.waiting_receivers.popitem(last=False)
204-
container.append(item)
205-
receive_event.set()
206-
elif len(self._state.buffer) < self._state.max_buffer_size:
209+
while self._state.waiting_receivers:
210+
receive_event, receiver = self._state.waiting_receivers.popitem(last=False)
211+
if not receiver.task_info.has_pending_cancellation():
212+
receiver.item = item
213+
receive_event.set()
214+
return
215+
216+
if len(self._state.buffer) < self._state.max_buffer_size:
207217
self._state.buffer.append(item)
208218
else:
209219
raise WouldBlock

Diff for: tests/streams/test_memory.py

+35-14
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
fail_after,
1818
wait_all_tasks_blocked,
1919
)
20-
from anyio.abc import ObjectReceiveStream, ObjectSendStream
20+
from anyio.abc import ObjectReceiveStream, ObjectSendStream, TaskStatus
2121
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2222

2323
if sys.version_info < (3, 11):
@@ -304,28 +304,49 @@ async def test_cancel_during_receive() -> None:
304304
stream to be lost.
305305
306306
"""
307-
receiver_scope = None
308307

309-
async def scoped_receiver() -> None:
310-
nonlocal receiver_scope
311-
with CancelScope() as receiver_scope:
308+
async def scoped_receiver(task_status: TaskStatus[CancelScope]) -> None:
309+
with CancelScope() as cancel_scope:
310+
task_status.started(cancel_scope)
312311
received.append(await receive.receive())
313312

314-
assert receiver_scope.cancel_called
313+
assert cancel_scope.cancel_called
315314

316315
received: list[str] = []
317316
send, receive = create_memory_object_stream[str]()
318-
async with create_task_group() as tg:
319-
tg.start_soon(scoped_receiver)
320-
await wait_all_tasks_blocked()
321-
send.send_nowait("hello")
322-
assert receiver_scope is not None
323-
receiver_scope.cancel()
317+
with send, receive:
318+
async with create_task_group() as tg:
319+
receiver_scope = await tg.start(scoped_receiver)
320+
await wait_all_tasks_blocked()
321+
send.send_nowait("hello")
322+
receiver_scope.cancel()
324323

325324
assert received == ["hello"]
326325

327-
send.close()
328-
receive.close()
326+
327+
async def test_cancel_during_receive_buffered() -> None:
328+
"""
329+
Test that sending an item to a memory object stream when the receiver that is next
330+
in line has been cancelled will not result in the item being lost.
331+
"""
332+
333+
async def scoped_receiver(
334+
receive: MemoryObjectReceiveStream[str], task_status: TaskStatus[CancelScope]
335+
) -> None:
336+
with CancelScope() as cancel_scope:
337+
task_status.started(cancel_scope)
338+
await receive.receive()
339+
340+
send, receive = create_memory_object_stream[str](1)
341+
with send, receive:
342+
async with create_task_group() as tg:
343+
cancel_scope = await tg.start(scoped_receiver, receive)
344+
await wait_all_tasks_blocked()
345+
cancel_scope.cancel()
346+
send.send_nowait("item")
347+
348+
# Since the item was not sent to the cancelled task, it should be available here
349+
assert receive.receive_nowait() == "item"
329350

330351

331352
async def test_close_receive_after_send() -> None:

Diff for: tests/test_debugging.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ async def inspect() -> None:
9696
for task, expected_name in zip(task_infos, expected_names):
9797
assert task.parent_id == host_task.id
9898
assert task.name == expected_name
99-
assert repr(task) == f"TaskInfo(id={task.id}, name={expected_name!r})"
99+
assert repr(task).endswith(f"TaskInfo(id={task.id}, name={expected_name!r})")
100100

101101

102102
@pytest.mark.skipif(

0 commit comments

Comments
 (0)