Skip to content

Commit a3af1da

Browse files
committed
Ignore the cancellation error on memory stream receive
As per the discussion on #147, it's better to ignore the cancellation exception now and have it triggered at the next checkpoint than to push the item to the buffer, potentially going over the buffer's limit.
1 parent b5a2f08 commit a3af1da

File tree

2 files changed

+19
-64
lines changed

2 files changed

+19
-64
lines changed

Diff for: src/anyio/streams/memory.py

+16-33
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import deque, OrderedDict
22
from dataclasses import dataclass, field
3-
from typing import TypeVar, Generic, List, Deque, Tuple
3+
from typing import TypeVar, Generic, List, Deque
44

55
from .. import get_cancelled_exc_class
66
from .._core._lowlevel import checkpoint
@@ -18,7 +18,8 @@ class MemoryObjectStreamState(Generic[T_Item]):
1818
buffer: Deque[T_Item] = field(init=False, default_factory=deque)
1919
open_send_channels: int = field(init=False, default=0)
2020
open_receive_channels: int = field(init=False, default=0)
21-
waiting_receivers: Deque[Tuple[Event, List[T_Item]]] = field(init=False, default_factory=deque)
21+
waiting_receivers: 'OrderedDict[Event, List[T_Item]]' = field(init=False,
22+
default_factory=OrderedDict)
2223
waiting_senders: 'OrderedDict[Event, T_Item]' = field(init=False, default_factory=OrderedDict)
2324

2425

@@ -66,34 +67,17 @@ async def receive(self) -> T_Item:
6667
# Add ourselves in the queue
6768
receive_event = create_event()
6869
container: List[T_Item] = []
69-
ticket = receive_event, container
70-
self._state.waiting_receivers.append(ticket)
70+
self._state.waiting_receivers[receive_event] = container
7171

7272
try:
7373
await receive_event.wait()
7474
except get_cancelled_exc_class():
75-
# If we already received an item in the container, pass it to the next receiver in
76-
# line
77-
index = self._state.waiting_receivers.index(ticket) + 1
78-
if container:
79-
item = container[0]
80-
while index < len(self._state.waiting_receivers):
81-
receive_event, container = self._state.waiting_receivers[index]
82-
if container:
83-
item, container[0] = container[0], item
84-
else:
85-
# Found an untriggered receiver
86-
container.append(item)
87-
await receive_event.set()
88-
break
89-
else:
90-
# Could not find an untriggered receiver, so in order to not lose any
91-
# items, put it in the buffer, even if it exceeds the maximum buffer size
92-
self._state.buffer.append(item)
93-
94-
raise
75+
# Ignore the immediate cancellation if we already received an item, so as not to
76+
# lose it
77+
if not container:
78+
raise
9579
finally:
96-
self._state.waiting_receivers.remove(ticket)
80+
self._state.waiting_receivers.pop(receive_event, None)
9781

9882
if container:
9983
return container[0]
@@ -151,13 +135,11 @@ async def send_nowait(self, item: T_Item) -> None:
151135
if not self._state.open_receive_channels:
152136
raise BrokenResourceError
153137

154-
for receive_event, container in self._state.waiting_receivers:
155-
if not container:
156-
container.append(item)
157-
await receive_event.set()
158-
return
159-
160-
if len(self._state.buffer) < self._state.max_buffer_size:
138+
if self._state.waiting_receivers:
139+
receive_event, container = self._state.waiting_receivers.popitem(last=False)
140+
container.append(item)
141+
await receive_event.set()
142+
elif len(self._state.buffer) < self._state.max_buffer_size:
161143
self._state.buffer.append(item)
162144
else:
163145
raise WouldBlock
@@ -199,6 +181,7 @@ async def aclose(self) -> None:
199181
self._closed = True
200182
self._state.open_send_channels -= 1
201183
if self._state.open_send_channels == 0:
202-
receive_events = [event for event, container in self._state.waiting_receivers]
184+
receive_events = list(self._state.waiting_receivers.keys())
185+
self._state.waiting_receivers.clear()
203186
for event in receive_events:
204187
await event.set()

Diff for: tests/streams/test_memory.py

+3-31
Original file line numberDiff line numberDiff line change
@@ -231,45 +231,17 @@ async def test_cancel_during_receive():
231231
async def scoped_receiver():
232232
nonlocal receiver_scope
233233
async with open_cancel_scope() as receiver_scope:
234-
await receive.receive()
234+
received.append(await receive.receive())
235235

236-
async def receiver():
237-
received.append(await receive.receive())
236+
assert receiver_scope.cancel_called
238237

239238
receiver_scope = None
240239
received = []
241240
send, receive = create_memory_object_stream()
242241
async with create_task_group() as tg:
243242
await tg.spawn(scoped_receiver)
244243
await wait_all_tasks_blocked()
245-
await tg.spawn(receiver)
244+
await send.send_nowait('hello')
246245
await receiver_scope.cancel()
247-
await send.send('hello')
248246

249247
assert received == ['hello']
250-
251-
252-
async def test_cancel_during_receive_last_receiver():
253-
"""
254-
Test that cancelling a pending receive() operation does not cause an item in the stream to be
255-
lost, even if there are no other receivers waiting.
256-
257-
"""
258-
async def scoped_receiver():
259-
nonlocal receiver_scope
260-
async with open_cancel_scope() as receiver_scope:
261-
await receive.receive()
262-
pytest.fail('This point should never be reached')
263-
264-
receiver_scope = None
265-
send, receive = create_memory_object_stream()
266-
async with create_task_group() as tg:
267-
await tg.spawn(scoped_receiver)
268-
await wait_all_tasks_blocked()
269-
await receiver_scope.cancel()
270-
await send.send_nowait('hello')
271-
272-
with pytest.raises(WouldBlock):
273-
await send.send_nowait('world')
274-
275-
assert await receive.receive_nowait() == 'hello'

0 commit comments

Comments
 (0)