Commit f9d2b2c

[Not for landing] piggy back on titan for scale init test
ghstack-source-id: c6a5b0d8f912999a10d7bf43e17ebf7bd11348d4
Pull Request resolved: #841
1 parent b291ad6 commit f9d2b2c

2 files changed: +36 -3 lines changed
torchtitan/models/llama/train_configs/llama3_405b.toml (+1, -1)

@@ -8,7 +8,7 @@ description = "Llama 3 405B training"
 [profiling]
 enable_profiling = true
 save_traces_folder = "profile_trace"
-profile_freq = 100
+profile_freq = 1

 [metrics]
 log_freq = 10

torchtitan/train.py (+35, -2)

@@ -447,7 +447,40 @@ def main(job_config: JobConfig):

 if __name__ == "__main__":
     init_logger()
+    warmup = False
+    from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT, PrefixStore
+    # The first one is just for warm up.
+    local_rank = int(os.environ["LOCAL_RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    global_rank = int(os.environ["RANK"])
+    index = 0
+    rendezvous_iterator = torch.distributed.rendezvous(
+        "env://", global_rank, world_size, timeout=_DEFAULT_PG_NCCL_TIMEOUT
+    )
+    tcp_store, rank, world_size = next(rendezvous_iterator)
+    tcp_store.set_timeout(_DEFAULT_PG_NCCL_TIMEOUT)
     config = JobConfig()
     config.parse_args()
-    main(config)
-    torch.distributed.destroy_process_group()
+    with maybe_enable_profiling(
+        config, global_step=train_state.step
+    ) as torch_profiler:
+        for root_size in [128, 128]:
+            os.environ["TORCH_NCCL_RANKS_PER_ROOT"] = str(root_size)
+            iter_size = 5 if warmup else 1
+            delta = 0.0
+            for i in range(iter_size):
+                start = time.perf_counter()
+                store = PrefixStore(f"default_pg_{index}", tcp_store)
+                index += 1
+                torch.distributed.init_process_group(store=store, backend="nccl", world_size=world_size, rank=global_rank)
+                torch.cuda.set_device(local_rank)
+                torch.distributed.barrier()
+                end = time.perf_counter()
+                torch.distributed.destroy_process_group()
+                delta += (end - start)
+                print(f"Time to init process group: {end - start:.6f} seconds for {root_size} ranks per roots")
+            if warmup:
+                print(f"Average time to init process group: {delta / float(iter_size):.6f} seconds for {root_size} ranks per roots")
+            warmup = True
+    # main(config)
+    # torch.distributed.destroy_process_group()
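
For reference, the patch above piggybacks on torchtitan's entry point to time torch.distributed.init_process_group under different TORCH_NCCL_RANKS_PER_ROOT settings, wrapping a shared TCPStore in a fresh PrefixStore per trial so repeated process groups do not clash on store keys. Below is a minimal standalone sketch of the same measurement, written without torchtitan's JobConfig or profiler; it assumes a torchrun-style launch (RANK, WORLD_SIZE, LOCAL_RANK set in the environment), and the bench_pg_ prefix and print format are illustrative, not part of the commit.

# Minimal standalone sketch (assumption: launched via torchrun, so RANK,
# WORLD_SIZE, and LOCAL_RANK are present in the environment).
import os
import time

import torch
import torch.distributed as dist
from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT, PrefixStore

local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
global_rank = int(os.environ["RANK"])

# One shared TCPStore from env:// rendezvous; each trial wraps it in its own
# PrefixStore so keys written by earlier process groups do not collide.
tcp_store, _, _ = next(
    dist.rendezvous("env://", global_rank, world_size, timeout=_DEFAULT_PG_NCCL_TIMEOUT)
)
tcp_store.set_timeout(_DEFAULT_PG_NCCL_TIMEOUT)
torch.cuda.set_device(local_rank)

for index, ranks_per_root in enumerate([128, 128]):  # first pass is warm-up
    os.environ["TORCH_NCCL_RANKS_PER_ROOT"] = str(ranks_per_root)
    start = time.perf_counter()
    dist.init_process_group(
        store=PrefixStore(f"bench_pg_{index}", tcp_store),
        backend="nccl",
        world_size=world_size,
        rank=global_rank,
    )
    dist.barrier()  # first collective: every rank finishes init before the clock stops
    elapsed = time.perf_counter() - start
    dist.destroy_process_group()
    if global_rank == 0:
        print(f"init: {elapsed:.6f}s with TORCH_NCCL_RANKS_PER_ROOT={ranks_per_root}")

The sketch would be launched the same way train.py is, e.g. torchrun --nnodes=<N> --nproc_per_node=<G> <script>.py. The first iteration warms up the store and CUDA context, so the second timing is the one to compare across TORCH_NCCL_RANKS_PER_ROOT values.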
