Commit 5b6fd9f

distributed_fused_lamb.py replace c_allreduce_sum (#70737)
1 parent 4f7808b commit 5b6fd9f

File tree

2 files changed: +15 −12 lines


python/paddle/incubate/optimizer/distributed_fused_lamb.py

+7 −4
@@ -89,10 +89,13 @@ def init_communicator(block, rank, ranks, ring_id):
         type='fill_constant', outputs={'Out': tmp_var}, attrs={'value': 1}
     )
     block.append_op(
-        type='c_allreduce_sum',
-        inputs={'X': tmp_var},
-        outputs={'Out': tmp_var},
-        attrs={'ring_id': ring_id, 'use_calc_stream': True},
+        type='all_reduce',
+        inputs={'x': tmp_var},
+        outputs={'out': tmp_var},
+        attrs={
+            'ring_id': ring_id,
+            'reduce_type': paddle.distributed.ReduceOp.SUM,
+        },
     )
     block.append_op(
         type='c_sync_calc_stream',
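
For context, a minimal standalone sketch of the new-style call introduced here, using the same all_reduce op, 'x'/'out' keys, and reduce_type attribute shown in the diff; the helper name append_all_reduce_sum is illustrative and not part of this commit.

import paddle


def append_all_reduce_sum(block, var, ring_id):
    # Illustrative helper (not part of the commit): append the new-style
    # all_reduce op to a static-graph block, mirroring the replacement of
    # c_allreduce_sum in init_communicator above.
    block.append_op(
        type='all_reduce',
        inputs={'x': var},
        outputs={'out': var},
        attrs={
            'ring_id': ring_id,
            'reduce_type': paddle.distributed.ReduceOp.SUM,
        },
    )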

python/paddle/incubate/optimizer/pipeline.py

+8 −8
@@ -165,17 +165,17 @@ def _insert_allreduce_op(self, op_idx, block):
             offset += 1
         block._insert_op(
             op_idx + 1 + offset,
-            type=(
-                'c_allreduce_max'
-                if op.type == "reduce_any"
-                else 'c_allreduce_sum'
-            ),
-            inputs={'X': temp_var if op.type == "reduce_any" else out_var},
-            outputs={'Out': temp_var if op.type == "reduce_any" else out_var},
+            type='all_reduce',
+            inputs={'x': temp_var if op.type == "reduce_any" else out_var},
+            outputs={'out': temp_var if op.type == "reduce_any" else out_var},
             attrs={
                 'ring_id': self.global_ring_id,
                 self._op_role_key: self._op_role.Optimize,
-                'use_calc_stream': True,
+                'reduce_type': (
+                    paddle.distributed.ReduceOp.MAX
+                    if op.type == "reduce_any"
+                    else paddle.distributed.ReduceOp.SUM
+                ),
             },
         )
         offset += 1
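
As in the diff above, the pipeline.py migration can be sketched as a small helper that selects the reduction kind through the reduce_type attribute instead of choosing between c_allreduce_max and c_allreduce_sum; the helper name and parameters below are hypothetical, not part of this commit.

import paddle


def insert_all_reduce(block, index, var, ring_id, use_max=False):
    # Hypothetical helper: insert the new-style all_reduce op at a given
    # position, selecting MAX for reduce_any results and SUM otherwise,
    # as the pipeline.py change above does via the reduce_type attribute.
    block._insert_op(
        index,
        type='all_reduce',
        inputs={'x': var},
        outputs={'out': var},
        attrs={
            'ring_id': ring_id,
            'reduce_type': (
                paddle.distributed.ReduceOp.MAX
                if use_max
                else paddle.distributed.ReduceOp.SUM
            ),
        },
    )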
