Skip to content

Commit bd9a310

Browse files
committed
Pass along the received item to the next receiver if the task was cancelled
1 parent 1d548ca commit bd9a310

File tree

2 files changed

+139
-19
lines changed

2 files changed

+139
-19
lines changed

Diff for: src/anyio/streams/memory.py

+41-18
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from collections import deque, OrderedDict
22
from dataclasses import dataclass, field
3-
from typing import TypeVar, Generic, List, Deque
3+
from typing import TypeVar, Generic, List, Deque, Tuple
44

5-
import anyio
5+
from .. import get_cancelled_exc_class
6+
from .._core._lowlevel import checkpoint
7+
from .._core._synchronization import create_event
68
from ..abc.synchronization import Event
79
from ..abc.streams import ObjectSendStream, ObjectReceiveStream
810
from ..exceptions import ClosedResourceError, BrokenResourceError, WouldBlock, EndOfStream
@@ -16,8 +18,7 @@ class MemoryObjectStreamState(Generic[T_Item]):
1618
buffer: Deque[T_Item] = field(init=False, default_factory=deque)
1719
open_send_channels: int = field(init=False, default=0)
1820
open_receive_channels: int = field(init=False, default=0)
19-
waiting_receivers: 'OrderedDict[Event, List[T_Item]]' = field(init=False,
20-
default_factory=OrderedDict)
21+
waiting_receivers: Deque[Tuple[Event, List[T_Item]]] = field(init=False, default_factory=deque)
2122
waiting_senders: 'OrderedDict[Event, T_Item]' = field(init=False, default_factory=OrderedDict)
2223

2324

@@ -58,20 +59,41 @@ async def receive_nowait(self) -> T_Item:
5859
raise WouldBlock
5960

6061
async def receive(self) -> T_Item:
61-
# anyio.check_cancelled()
62+
await checkpoint()
6263
try:
6364
return await self.receive_nowait()
6465
except WouldBlock:
6566
# Add ourselves in the queue
66-
receive_event = anyio.create_event()
67+
receive_event = create_event()
6768
container: List[T_Item] = []
68-
self._state.waiting_receivers[receive_event] = container
69+
ticket = receive_event, container
70+
self._state.waiting_receivers.append(ticket)
6971

7072
try:
7173
await receive_event.wait()
72-
except BaseException:
73-
self._state.waiting_receivers.pop(receive_event, None)
74+
except get_cancelled_exc_class():
75+
# If we already received an item in the container, pass it to the next receiver in
76+
# line
77+
index = self._state.waiting_receivers.index(ticket) + 1
78+
if container:
79+
item = container[0]
80+
while index < len(self._state.waiting_receivers):
81+
receive_event, container = self._state.waiting_receivers[index]
82+
if container:
83+
item, container[0] = container[0], item
84+
else:
85+
# Found an untriggered receiver
86+
container.append(item)
87+
await receive_event.set()
88+
break
89+
else:
90+
# Could not find an untriggered receiver, so in order to not lose any
91+
# items, put it in the buffer, even if it exceeds the maximum buffer size
92+
self._state.buffer.append(item)
93+
7494
raise
95+
finally:
96+
self._state.waiting_receivers.remove(ticket)
7597

7698
if container:
7799
return container[0]
@@ -129,22 +151,24 @@ async def send_nowait(self, item: T_Item) -> None:
129151
if not self._state.open_receive_channels:
130152
raise BrokenResourceError
131153

132-
if self._state.waiting_receivers:
133-
receive_event, container = self._state.waiting_receivers.popitem(last=False)
134-
container.append(item)
135-
await receive_event.set()
136-
elif len(self._state.buffer) < self._state.max_buffer_size:
154+
for receive_event, container in self._state.waiting_receivers:
155+
if not container:
156+
container.append(item)
157+
await receive_event.set()
158+
return
159+
160+
if len(self._state.buffer) < self._state.max_buffer_size:
137161
self._state.buffer.append(item)
138162
else:
139163
raise WouldBlock
140164

141165
async def send(self, item: T_Item) -> None:
142-
# await check_cancelled()
166+
await checkpoint()
143167
try:
144168
await self.send_nowait(item)
145169
except WouldBlock:
146170
# Wait until there's someone on the receiving end
147-
send_event = anyio.create_event()
171+
send_event = create_event()
148172
self._state.waiting_senders[send_event] = item
149173
try:
150174
await send_event.wait()
@@ -175,7 +199,6 @@ async def aclose(self) -> None:
175199
self._closed = True
176200
self._state.open_send_channels -= 1
177201
if self._state.open_send_channels == 0:
178-
receive_events = list(self._state.waiting_receivers.keys())
179-
self._state.waiting_receivers.clear()
202+
receive_events = [event for event, container in self._state.waiting_receivers]
180203
for event in receive_events:
181204
await event.set()

Diff for: tests/streams/test_memory.py

+98-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import pytest
22

33
from anyio import (
4-
create_task_group, wait_all_tasks_blocked, create_memory_object_stream, fail_after)
4+
create_task_group, wait_all_tasks_blocked, create_memory_object_stream, fail_after,
5+
open_cancel_scope)
56
from anyio.exceptions import EndOfStream, ClosedResourceError, BrokenResourceError, WouldBlock
67

78
pytestmark = pytest.mark.anyio
@@ -177,3 +178,99 @@ async def test_receive_after_send_closed():
177178
await send.send('hello')
178179
await send.aclose()
179180
assert await receive.receive() == 'hello'
181+
182+
183+
async def test_receive_when_cancelled():
184+
"""
185+
Test that calling receive() in a cancelled scope prevents it from going through with the
186+
operation.
187+
188+
"""
189+
send, receive = create_memory_object_stream()
190+
async with create_task_group() as tg:
191+
await tg.spawn(send.send, 'hello')
192+
await wait_all_tasks_blocked()
193+
await tg.spawn(send.send, 'world')
194+
await wait_all_tasks_blocked()
195+
196+
async with open_cancel_scope() as scope:
197+
await scope.cancel()
198+
await receive.receive()
199+
200+
assert await receive.receive() == 'hello'
201+
assert await receive.receive() == 'world'
202+
203+
204+
async def test_send_when_cancelled():
205+
"""
206+
Test that calling send() in a cancelled scope prevents it from going through with the
207+
operation.
208+
209+
"""
210+
async def receiver():
211+
received.append(await receive.receive())
212+
213+
received = []
214+
send, receive = create_memory_object_stream()
215+
async with create_task_group() as tg:
216+
await tg.spawn(receiver)
217+
async with open_cancel_scope() as scope:
218+
await scope.cancel()
219+
await send.send('hello')
220+
221+
await send.send('world')
222+
223+
assert received == ['world']
224+
225+
226+
async def test_cancel_during_receive():
227+
"""
228+
Test that cancelling a pending receive() operation does not cause an item in the stream to be
229+
lost.
230+
231+
"""
232+
async def scoped_receiver():
233+
nonlocal receiver_scope
234+
async with open_cancel_scope() as receiver_scope:
235+
await receive.receive()
236+
237+
async def receiver():
238+
received.append(await receive.receive())
239+
240+
receiver_scope = None
241+
received = []
242+
send, receive = create_memory_object_stream()
243+
async with create_task_group() as tg:
244+
await tg.spawn(scoped_receiver)
245+
await wait_all_tasks_blocked()
246+
await tg.spawn(receiver)
247+
await receiver_scope.cancel()
248+
await send.send('hello')
249+
250+
assert received == ['hello']
251+
252+
253+
async def test_cancel_during_receive_last_receiver():
254+
"""
255+
Test that cancelling a pending receive() operation does not cause an item in the stream to be
256+
lost, even if there are no other receivers waiting.
257+
258+
"""
259+
async def scoped_receiver():
260+
nonlocal receiver_scope
261+
async with open_cancel_scope() as receiver_scope:
262+
await receive.receive()
263+
pytest.fail('This point should never be reached')
264+
265+
receiver_scope = None
266+
send, receive = create_memory_object_stream()
267+
async with create_task_group() as tg:
268+
await tg.spawn(scoped_receiver)
269+
await wait_all_tasks_blocked()
270+
await receiver_scope.cancel()
271+
await send.send_nowait('hello')
272+
273+
with pytest.raises(WouldBlock):
274+
await send.send_nowait('world')
275+
276+
assert await receive.receive_nowait() == 'hello'

0 commit comments

Comments
 (0)