1
1
from collections import deque , OrderedDict
2
2
from dataclasses import dataclass , field
3
- from typing import TypeVar , Generic , List , Deque
3
+ from typing import TypeVar , Generic , List , Deque , Tuple
4
4
5
- import anyio
5
+ from .. import get_cancelled_exc_class
6
+ from .._core ._lowlevel import checkpoint
7
+ from .._core ._synchronization import create_event
6
8
from ..abc .synchronization import Event
7
9
from ..abc .streams import ObjectSendStream , ObjectReceiveStream
8
10
from ..exceptions import ClosedResourceError , BrokenResourceError , WouldBlock , EndOfStream
@@ -16,8 +18,7 @@ class MemoryObjectStreamState(Generic[T_Item]):
16
18
buffer : Deque [T_Item ] = field (init = False , default_factory = deque )
17
19
open_send_channels : int = field (init = False , default = 0 )
18
20
open_receive_channels : int = field (init = False , default = 0 )
19
- waiting_receivers : 'OrderedDict[Event, List[T_Item]]' = field (init = False ,
20
- default_factory = OrderedDict )
21
+ waiting_receivers : Deque [Tuple [Event , List [T_Item ]]] = field (init = False , default_factory = deque )
21
22
waiting_senders : 'OrderedDict[Event, T_Item]' = field (init = False , default_factory = OrderedDict )
22
23
23
24
@@ -58,20 +59,41 @@ async def receive_nowait(self) -> T_Item:
58
59
raise WouldBlock
59
60
60
61
async def receive (self ) -> T_Item :
61
- # anyio.check_cancelled ()
62
+ await checkpoint ()
62
63
try :
63
64
return await self .receive_nowait ()
64
65
except WouldBlock :
65
66
# Add ourselves in the queue
66
- receive_event = anyio . create_event ()
67
+ receive_event = create_event ()
67
68
container : List [T_Item ] = []
68
- self ._state .waiting_receivers [receive_event ] = container
69
+ ticket = receive_event , container
70
+ self ._state .waiting_receivers .append (ticket )
69
71
70
72
try :
71
73
await receive_event .wait ()
72
- except BaseException :
73
- self ._state .waiting_receivers .pop (receive_event , None )
74
+ except get_cancelled_exc_class ():
75
+ # If we already received an item in the container, pass it to the next receiver in
76
+ # line
77
+ index = self ._state .waiting_receivers .index (ticket ) + 1
78
+ if container :
79
+ item = container [0 ]
80
+ while index < len (self ._state .waiting_receivers ):
81
+ receive_event , container = self ._state .waiting_receivers [index ]
82
+ if container :
83
+ item , container [0 ] = container [0 ], item
84
+ else :
85
+ # Found an untriggered receiver
86
+ container .append (item )
87
+ await receive_event .set ()
88
+ break
89
+ else :
90
+ # Could not find an untriggered receiver, so in order to not lose any
91
+ # items, put it in the buffer, even if it exceeds the maximum buffer size
92
+ self ._state .buffer .append (item )
93
+
74
94
raise
95
+ finally :
96
+ self ._state .waiting_receivers .remove (ticket )
75
97
76
98
if container :
77
99
return container [0 ]
@@ -129,22 +151,24 @@ async def send_nowait(self, item: T_Item) -> None:
129
151
if not self ._state .open_receive_channels :
130
152
raise BrokenResourceError
131
153
132
- if self ._state .waiting_receivers :
133
- receive_event , container = self ._state .waiting_receivers .popitem (last = False )
134
- container .append (item )
135
- await receive_event .set ()
136
- elif len (self ._state .buffer ) < self ._state .max_buffer_size :
154
+ for receive_event , container in self ._state .waiting_receivers :
155
+ if not container :
156
+ container .append (item )
157
+ await receive_event .set ()
158
+ return
159
+
160
+ if len (self ._state .buffer ) < self ._state .max_buffer_size :
137
161
self ._state .buffer .append (item )
138
162
else :
139
163
raise WouldBlock
140
164
141
165
async def send (self , item : T_Item ) -> None :
142
- # await check_cancelled ()
166
+ await checkpoint ()
143
167
try :
144
168
await self .send_nowait (item )
145
169
except WouldBlock :
146
170
# Wait until there's someone on the receiving end
147
- send_event = anyio . create_event ()
171
+ send_event = create_event ()
148
172
self ._state .waiting_senders [send_event ] = item
149
173
try :
150
174
await send_event .wait ()
@@ -175,7 +199,6 @@ async def aclose(self) -> None:
175
199
self ._closed = True
176
200
self ._state .open_send_channels -= 1
177
201
if self ._state .open_send_channels == 0 :
178
- receive_events = list (self ._state .waiting_receivers .keys ())
179
- self ._state .waiting_receivers .clear ()
202
+ receive_events = [event for event , container in self ._state .waiting_receivers ]
180
203
for event in receive_events :
181
204
await event .set ()
0 commit comments