Skip to content

Commit 8fb4545

Browse files
committed
Fix chunk creation of subtask graph generation
1 parent 40373a8 commit 8fb4545

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

mars/services/task/analyzer/analyzer.py

+38-10
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def __init__(
128128
if graph_assigner_cls is None:
129129
graph_assigner_cls = GraphAssigner
130130
self._graph_assigner_cls = graph_assigner_cls
131+
self._chunk_to_copied = dict()
131132
self._logic_key_generator = LogicKeyGenerator()
132133

133134
@classmethod
@@ -226,6 +227,7 @@ def _gen_subtask_info(
226227
result_chunks_set = set()
227228
chunk_graph = ChunkGraph(result_chunks)
228229
out_of_scope_chunks = []
230+
chunk_to_copied = self._chunk_to_copied
229231
update_meta_chunks = []
230232
# subtask properties
231233
band = None
@@ -271,11 +273,13 @@ def _gen_subtask_info(
271273
chunk_priority = chunk.op.priority
272274
# process input chunks
273275
inp_chunks = []
276+
input_changed = False
274277
build_fetch_index_to_chunks = dict()
275278
for i, inp_chunk in enumerate(chunk.inputs):
276279
if inp_chunk in chunks_set:
277-
inp_chunks.append(inp_chunk)
280+
inp_chunks.append(chunk_to_copied[inp_chunk])
278281
else:
282+
input_changed = True
279283
build_fetch_index_to_chunks[i] = inp_chunk
280284
inp_chunks.append(None)
281285
if not isinstance(inp_chunk.op, Fetch):
@@ -285,14 +289,31 @@ def _gen_subtask_info(
285289
)
286290
for i, fetch_chunk in zip(build_fetch_index_to_chunks, fetch_chunks):
287291
inp_chunks[i] = fetch_chunk
288-
for out_chunk in chunk.op.outputs:
292+
293+
if input_changed:
294+
copied_op = chunk.op.copy()
295+
copied_op._key = chunk.op.key
296+
out_chunks = [
297+
c.data
298+
for c in copied_op.new_chunks(
299+
inp_chunks, kws=[c.params.copy() for c in chunk.op.outputs]
300+
)
301+
]
302+
else:
303+
out_chunks = chunk.op.outputs
289304
# Note: `dtypes`, `index_value`, and `columns_value` are lazily
290305
# initialized, so we should call property `params` to initialize
291306
# these fields.
292-
out_chunk.params
293-
processed.add(out_chunk)
307+
[o.params for o in out_chunks]
308+
309+
for src_chunk, out_chunk in zip(chunk.op.outputs, out_chunks):
310+
processed.add(src_chunk)
311+
out_chunk._key = src_chunk.key
294312
chunk_graph.add_node(out_chunk)
295-
if out_chunk in self._final_result_chunks_set:
313+
# cannot be copied twice
314+
assert src_chunk not in chunk_to_copied
315+
chunk_to_copied[src_chunk] = out_chunk
316+
if src_chunk in self._final_result_chunks_set:
296317
if out_chunk not in result_chunks_set:
297318
# add to result chunks
298319
result_chunks.append(out_chunk)
@@ -320,12 +341,18 @@ def _gen_subtask_info(
320341
if out_of_scope_chunks:
321342
inp_subtasks = []
322343
for out_of_scope_chunk in out_of_scope_chunks:
344+
copied_out_of_scope_chunk = chunk_to_copied[out_of_scope_chunk]
323345
inp_subtask = chunk_to_subtask[out_of_scope_chunk]
324-
if out_of_scope_chunk not in inp_subtask.chunk_graph.result_chunks:
346+
if (
347+
copied_out_of_scope_chunk
348+
not in inp_subtask.chunk_graph.result_chunks
349+
):
325350
# make sure the chunk that is out of scope
326351
# is in the input subtask's results,
327352
# or the meta may be lost
328-
inp_subtask.chunk_graph.result_chunks.append(out_of_scope_chunk)
353+
inp_subtask.chunk_graph.result_chunks.append(
354+
copied_out_of_scope_chunk
355+
)
329356
inp_subtasks.append(inp_subtask)
330357
depth = max(st.priority[0] for st in inp_subtasks) + 1
331358
else:
@@ -383,9 +410,10 @@ def _gen_map_reduce_info(
383410
# record analyzer map reduce id for mapper op
384411
# copied chunk exists because map chunk must have
385412
# been processed before shuffle proxy
386-
if not hasattr(map_chunk, "extra_params"): # pragma: no cover
387-
map_chunk.extra_params = dict()
388-
map_chunk.extra_params["analyzer_map_reduce_id"] = map_reduce_id
413+
copied_map_chunk = self._chunk_to_copied[map_chunk]
414+
if not hasattr(copied_map_chunk, "extra_params"): # pragma: no cover
415+
copied_map_chunk.extra_params = dict()
416+
copied_map_chunk.extra_params["analyzer_map_reduce_id"] = map_reduce_id
389417
reducer_bands = [assign_results[r.outputs[0]] for r in reducer_ops]
390418
map_reduce_info = MapReduceInfo(
391419
map_reduce_id=map_reduce_id,

mars/services/task/supervisor/tests/task_preprocessor.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,11 @@ def analyze(
180180
map_reduce_id_to_infos=self.map_reduce_id_to_infos,
181181
)
182182
subtask_graph = analyzer.gen_subtask_graph()
183-
results = set(c for c in chunk_graph.results if not isinstance(c.op, Fetch))
183+
results = set(
184+
analyzer._chunk_to_copied[c]
185+
for c in chunk_graph.results
186+
if not isinstance(c.op, Fetch)
187+
)
184188
for subtask in subtask_graph:
185189
if subtask.extra_config is None:
186190
subtask.extra_config = dict()

0 commit comments

Comments
 (0)