Skip to content

add barrier for ccl ops in micro_perf. #97

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion byte_micro_perf/backends/AMD/backend_amd.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,7 @@ def setup_2d_group(self):
torch.distributed.barrier()

def destroy_process_group(self):
    """Tear down the torch.distributed process group for this backend."""
    # Single call only — the scraped diff showed this line twice (old + new
    # diff rows); calling destroy_process_group() twice would raise.
    dist.destroy_process_group()

def barier(self):
    """Block until every rank in ``self.group`` reaches this point.

    Used around collective-communication (ccl) op timing so all ranks start
    the measured region together.

    NOTE(review): the name is a misspelling of "barrier", but callers in
    backend.py invoke ``self.barier()``, so it is kept for compatibility.
    """
    dist.barrier(self.group)
5 changes: 4 additions & 1 deletion byte_micro_perf/backends/GPU/backend_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,7 @@ def setup_2d_group(self):
torch.distributed.barrier()

def destroy_process_group(self):
    """Tear down the torch.distributed process group for this backend."""
    # Single call only — the scraped diff showed this line twice (old + new
    # diff rows); calling destroy_process_group() twice would raise.
    dist.destroy_process_group()

def barier(self):
    """Block until every rank in ``self.group`` reaches this point.

    Used around collective-communication (ccl) op timing so all ranks start
    the measured region together.

    NOTE(review): the name is a misspelling of "barrier", but callers in
    backend.py invoke ``self.barier()``, so it is kept for compatibility.
    """
    dist.barrier(self.group)
14 changes: 13 additions & 1 deletion byte_micro_perf/backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def setup_2d_group(self):
def destroy_process_group(self):
    """Tear down the distributed process group; no-op in the base class, concrete backends override."""
    pass

@abstractmethod
def barier(self):
    """Synchronize all ranks in the communication group; each backend must implement.

    NOTE(review): name is a misspelling of "barrier", kept because concrete
    backends and call sites in this PR all use ``barier``.
    """
    pass


# communication ops
Expand Down Expand Up @@ -229,6 +232,11 @@ def perf(self, input_shapes: List[List[int]], dtype):
for _ in range(num_warm_up):
self._run_operation(self.op, tensor_list[0])


# ccl ops need barrier
if self.op_name in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast", "p2p"]:
self.barier()

# test perf
num_test_perf = 5
self.device_synchronize()
Expand All @@ -241,7 +249,6 @@ def perf(self, input_shapes: List[List[int]], dtype):
self.device_synchronize()
end_time = time.perf_counter_ns()


prefer_iterations = self.iterations
max_perf_seconds = 10.0
op_duration = (end_time - start_time) / num_test_perf / 1e9
Expand All @@ -250,6 +257,11 @@ def perf(self, input_shapes: List[List[int]], dtype):
else:
prefer_iterations = min(max(int(max_perf_seconds // op_duration), 10), self.iterations)


# ccl ops need barrier
if self.op_name in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast", "p2p"]:
self.barier()

# perf
self.device_synchronize()
start_time = time.perf_counter_ns()
Expand Down