Skip to content

Commit b5ac619

Browse files
authored
Merge pull request #97 from bytedance/add_barrier_to_micro_perf
add barrier for ccl ops in micro_perf.
2 parents 572c388 + 2a49ac9 commit b5ac619

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

byte_micro_perf/backends/AMD/backend_amd.py

+4-1
Original file line number · Diff line number · Diff line change
@@ -296,4 +296,7 @@ def setup_2d_group(self):
296296
torch.distributed.barrier()
297297

298298
def destroy_process_group(self):
299-
dist.destroy_process_group()
299+
dist.destroy_process_group()
300+
301+
def barier(self):
302+
dist.barrier(self.group)

byte_micro_perf/backends/GPU/backend_gpu.py

+4-1
Original file line number · Diff line number · Diff line change
@@ -304,4 +304,7 @@ def setup_2d_group(self):
304304
torch.distributed.barrier()
305305

306306
def destroy_process_group(self):
307-
dist.destroy_process_group()
307+
dist.destroy_process_group()
308+
309+
def barier(self):
310+
dist.barrier(self.group)

byte_micro_perf/backends/backend.py

+13-1
Original file line number · Diff line number · Diff line change
@@ -92,6 +92,9 @@ def setup_2d_group(self):
9292
def destroy_process_group(self):
9393
pass
9494

95+
@abstractmethod
96+
def barier(self):
97+
pass
9598

9699

97100
# communication ops
@@ -229,6 +232,11 @@ def perf(self, input_shapes: List[List[int]], dtype):
229232
for _ in range(num_warm_up):
230233
self._run_operation(self.op, tensor_list[0])
231234

235+
236+
# ccl ops need barrier
237+
if self.op_name in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast", "p2p"]:
238+
self.barier()
239+
232240
# test perf
233241
num_test_perf = 5
234242
self.device_synchronize()
@@ -241,7 +249,6 @@ def perf(self, input_shapes: List[List[int]], dtype):
241249
self.device_synchronize()
242250
end_time = time.perf_counter_ns()
243251

244-
245252
prefer_iterations = self.iterations
246253
max_perf_seconds = 10.0
247254
op_duration = (end_time - start_time) / num_test_perf / 1e9
@@ -250,6 +257,11 @@ def perf(self, input_shapes: List[List[int]], dtype):
250257
else:
251258
prefer_iterations = min(max(int(max_perf_seconds // op_duration), 10), self.iterations)
252259

260+
261+
# ccl ops need barrier
262+
if self.op_name in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast", "p2p"]:
263+
self.barier()
264+
253265
# perf
254266
self.device_synchronize()
255267
start_time = time.perf_counter_ns()

0 commit comments

Comments (0)