@@ -128,6 +128,7 @@ def __init__(
128
128
if graph_assigner_cls is None :
129
129
graph_assigner_cls = GraphAssigner
130
130
self ._graph_assigner_cls = graph_assigner_cls
131
+ self ._chunk_to_copied = dict ()
131
132
self ._logic_key_generator = LogicKeyGenerator ()
132
133
133
134
@classmethod
@@ -226,6 +227,7 @@ def _gen_subtask_info(
226
227
result_chunks_set = set ()
227
228
chunk_graph = ChunkGraph (result_chunks )
228
229
out_of_scope_chunks = []
230
+ chunk_to_copied = self ._chunk_to_copied
229
231
update_meta_chunks = []
230
232
# subtask properties
231
233
band = None
@@ -271,11 +273,13 @@ def _gen_subtask_info(
271
273
chunk_priority = chunk .op .priority
272
274
# process input chunks
273
275
inp_chunks = []
276
+ input_changed = False
274
277
build_fetch_index_to_chunks = dict ()
275
278
for i , inp_chunk in enumerate (chunk .inputs ):
276
279
if inp_chunk in chunks_set :
277
- inp_chunks .append (inp_chunk )
280
+ inp_chunks .append (chunk_to_copied [ inp_chunk ] )
278
281
else :
282
+ input_changed = True
279
283
build_fetch_index_to_chunks [i ] = inp_chunk
280
284
inp_chunks .append (None )
281
285
if not isinstance (inp_chunk .op , Fetch ):
@@ -285,14 +289,31 @@ def _gen_subtask_info(
285
289
)
286
290
for i , fetch_chunk in zip (build_fetch_index_to_chunks , fetch_chunks ):
287
291
inp_chunks [i ] = fetch_chunk
288
- for out_chunk in chunk .op .outputs :
292
+
293
+ if input_changed :
294
+ copied_op = chunk .op .copy ()
295
+ copied_op ._key = chunk .op .key
296
+ out_chunks = [
297
+ c .data
298
+ for c in copied_op .new_chunks (
299
+ inp_chunks , kws = [c .params .copy () for c in chunk .op .outputs ]
300
+ )
301
+ ]
302
+ else :
303
+ out_chunks = chunk .op .outputs
289
304
# Note: `dtypes`, `index_value`, and `columns_value` are lazily
290
305
# initialized, so we should call property `params` to initialize
291
306
# these fields.
292
- out_chunk .params
293
- processed .add (out_chunk )
307
+ [o .params for o in out_chunks ]
308
+
309
+ for src_chunk , out_chunk in zip (chunk .op .outputs , out_chunks ):
310
+ processed .add (src_chunk )
311
+ out_chunk ._key = src_chunk .key
294
312
chunk_graph .add_node (out_chunk )
295
- if out_chunk in self ._final_result_chunks_set :
313
+ # cannot be copied twice
314
+ assert src_chunk not in chunk_to_copied
315
+ chunk_to_copied [src_chunk ] = out_chunk
316
+ if src_chunk in self ._final_result_chunks_set :
296
317
if out_chunk not in result_chunks_set :
297
318
# add to result chunks
298
319
result_chunks .append (out_chunk )
@@ -320,12 +341,18 @@ def _gen_subtask_info(
320
341
if out_of_scope_chunks :
321
342
inp_subtasks = []
322
343
for out_of_scope_chunk in out_of_scope_chunks :
344
+ copied_out_of_scope_chunk = chunk_to_copied [out_of_scope_chunk ]
323
345
inp_subtask = chunk_to_subtask [out_of_scope_chunk ]
324
- if out_of_scope_chunk not in inp_subtask .chunk_graph .result_chunks :
346
+ if (
347
+ copied_out_of_scope_chunk
348
+ not in inp_subtask .chunk_graph .result_chunks
349
+ ):
325
350
# make sure the chunk that is out of scope
326
351
# is in the input subtask's results,
327
352
# or the meta may be lost
328
- inp_subtask .chunk_graph .result_chunks .append (out_of_scope_chunk )
353
+ inp_subtask .chunk_graph .result_chunks .append (
354
+ copied_out_of_scope_chunk
355
+ )
329
356
inp_subtasks .append (inp_subtask )
330
357
depth = max (st .priority [0 ] for st in inp_subtasks ) + 1
331
358
else :
@@ -383,9 +410,10 @@ def _gen_map_reduce_info(
383
410
# record analyzer map reduce id for mapper op
384
411
# copied chunk exists because map chunk must have
385
412
# been processed before shuffle proxy
386
- if not hasattr (map_chunk , "extra_params" ): # pragma: no cover
387
- map_chunk .extra_params = dict ()
388
- map_chunk .extra_params ["analyzer_map_reduce_id" ] = map_reduce_id
413
+ copied_map_chunk = self ._chunk_to_copied [map_chunk ]
414
+ if not hasattr (copied_map_chunk , "extra_params" ): # pragma: no cover
415
+ copied_map_chunk .extra_params = dict ()
416
+ copied_map_chunk .extra_params ["analyzer_map_reduce_id" ] = map_reduce_id
389
417
reducer_bands = [assign_results [r .outputs [0 ]] for r in reducer_ops ]
390
418
map_reduce_info = MapReduceInfo (
391
419
map_reduce_id = map_reduce_id ,
0 commit comments