13
13
# limitations under the License.
14
14
15
15
import logging
16
+ import uuid
17
+ import zlib
18
+ import pyarrow
16
19
17
20
from .actors import new_client
21
+ from .config import options
18
22
from .errors import GraphNotExists
19
23
from .scheduler import SessionActor , GraphActor , GraphMetaActor , ResourceActor , \
20
24
SessionManagerActor , ChunkMetaClient
21
25
from .scheduler .graph import ResultReceiverActor
22
26
from .scheduler .node_info import NodeInfoActor
23
27
from .scheduler .utils import SchedulerClusterInfoActor
24
28
from .serialize import dataserializer
29
+ from .utils import tokenize
25
30
26
31
logger = logging .getLogger (__name__ )
27
32
@@ -51,6 +56,13 @@ def get_schedulers_info(self):
51
56
infos [scheduler ] = info_ref .get_info ()
52
57
return infos
53
58
59
def _get_receiver_ref(self, chunk_key):
    """
    Build an actor reference to the receiver slot selected for ``chunk_key``.

    The endpoint is obtained from ``self.cluster_info.get_scheduler(chunk_key)``;
    the dispatch actor at that endpoint is then asked for the ``'receiver'``
    hash slot of the key, and a reference to that uid on the same endpoint
    is returned.

    :param chunk_key: key of the chunk whose receiver is wanted
    :return: actor reference created by ``self.actor_client``
    """
    from .worker.dispatcher import DispatchActor

    endpoint = self.cluster_info.get_scheduler(chunk_key)
    dispatcher_uid = DispatchActor.default_uid()
    dispatcher = self.actor_client.actor_ref(dispatcher_uid, address=endpoint)
    receiver_uid = dispatcher.get_hash_slot('receiver', chunk_key)
    return self.actor_client.actor_ref(receiver_uid, address=endpoint)
65
+
54
66
def count_workers (self ):
55
67
try :
56
68
uid = ResourceActor .default_uid ()
@@ -71,6 +83,66 @@ def submit_graph(self, session_id, serialized_graph, graph_key, target,
71
83
session_ref .submit_tileable_graph (
72
84
serialized_graph , graph_key , target , compose = compose , _tell = not wait )
73
85
86
def create_mutable_tensor(self, session_id, name, shape, dtype, *args, **kwargs):
    """
    Ask the session actor of ``session_id`` to create a mutable tensor.

    :param session_id: id of the session that will own the tensor
    :param name: name of the mutable tensor
    :param shape: shape forwarded to the session actor
    :param dtype: dtype forwarded to the session actor
    :param args: extra positional arguments forwarded unchanged
    :param kwargs: extra keyword arguments forwarded unchanged
    :return: whatever the session actor's ``create_mutable_tensor`` returns
    """
    session = self.get_actor_ref(SessionActor.gen_uid(session_id))
    return session.create_mutable_tensor(name, shape, dtype, *args, **kwargs)
90
+
91
def get_mutable_tensor(self, session_id, name):
    """
    Look up a mutable tensor by name through the session actor.

    :param session_id: id of the session that owns the tensor
    :param name: name of the mutable tensor
    :return: whatever the session actor's ``get_mutable_tensor`` returns
    """
    session = self.get_actor_ref(SessionActor.gen_uid(session_id))
    return session.get_mutable_tensor(name)
95
+
96
def send_chunk_records(self, session_id, name, chunk_records_to_send, directly=True):
    """
    Upload pending records of a mutable tensor and register them with the session.

    For every ``(chunk_key, records)`` item, a fresh record chunk key is
    generated, memory quota is requested for it, the records are serialized
    with pyarrow and streamed block by block to the receiver selected for
    ``chunk_key``. Once all items are sent, the collected
    ``(chunk_key, record_chunk_key)`` pairs are handed to the session actor's
    ``append_chunk_records``.

    :param session_id: id of the session that owns the mutable tensor
    :param name: name of the mutable tensor
    :param chunk_records_to_send: mapping (anything with ``.items()``) from chunk
        key to the records to send; each records object must expose ``nbytes``
        and be serializable with ``pyarrow.serialize``
    :param directly: not used by this method; kept for interface compatibility
    :raises: any error raised while transferring a record chunk is re-raised
        after the receiver has been told to cancel that transfer
    """
    from .worker.dataio import ArrowBufferIO
    from .worker.quota import MemQuotaActor

    session_uid = SessionActor.gen_uid(session_id)
    session_ref = self.get_actor_ref(session_uid)

    # the transfer block size does not depend on the chunk, read it once
    block_size = options.worker.transfer_block_size

    chunk_records = []
    for chunk_key, records in chunk_records_to_send.items():
        # every batch of records is stored under its own unique key
        record_chunk_key = tokenize(chunk_key, uuid.uuid4().hex)
        ep = self.cluster_info.get_scheduler(chunk_key)
        # register quota
        quota_ref = self.actor_client.actor_ref(MemQuotaActor.default_uid(), address=ep)
        quota_ref.request_batch_quota({record_chunk_key: records.nbytes})
        # send record chunk
        buf = pyarrow.serialize(records).to_buffer()
        receiver_ref = self._get_receiver_ref(chunk_key)
        receiver_ref.create_data_writer(session_id, record_chunk_key, buf.size, None,
                                        ensure_cached=False, use_promise=False)

        # bind the name before ``try`` so that ``finally`` never hits an
        # unbound local when constructing the reader itself fails
        reader = None
        try:
            reader = ArrowBufferIO(buf, 'r', block_size=block_size)
            checksum = 0
            while True:
                next_chunk = reader.read(block_size)
                if not next_chunk:
                    break
                # running CRC32 over all blocks sent so far
                checksum = zlib.crc32(next_chunk, checksum)
                receiver_ref.receive_data_part(session_id, record_chunk_key, next_chunk, checksum)
            # close the reader before finishing, as the transfer is complete;
            # clearing the name prevents a second close() in ``finally``
            reader.close()
            reader = None
            receiver_ref.finish_receive(session_id, record_chunk_key, checksum)
        except:  # noqa: E722  # cancel on any failure, then re-raise unchanged
            # the data writer was created under ``record_chunk_key``, so that
            # is the key which must be cancelled (not the tensor chunk key)
            receiver_ref.cancel_receive(session_id, record_chunk_key)
            raise
        finally:
            if reader is not None:
                reader.close()
            # drop the reference so the reader does not outlive this iteration
            reader = None

        chunk_records.append((chunk_key, record_chunk_key))

    # register the record chunk to MutableTensorActor
    session_ref.append_chunk_records(name, chunk_records)
140
+
141
def seal(self, session_id, name):
    """
    Seal the named mutable tensor through the session actor.

    :param session_id: id of the session that owns the tensor
    :param name: name of the mutable tensor
    :return: whatever the session actor's ``seal`` returns
    """
    session = self.get_actor_ref(SessionActor.gen_uid(session_id))
    return session.seal(name)
145
+
74
146
def delete_graph (self , session_id , graph_key ):
75
147
graph_uid = GraphActor .gen_uid (session_id , graph_key )
76
148
graph_ref = self .get_actor_ref (graph_uid )
0 commit comments