@@ -447,7 +447,40 @@ def main(job_config: JobConfig):
if __name__ == "__main__":
    init_logger()
+    warmup = False
+    from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT, PrefixStore
+    # The first one is just for warm up.
+    local_rank = int(os.environ["LOCAL_RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    global_rank = int(os.environ["RANK"])
+    index = 0
+    rendezvous_iterator = torch.distributed.rendezvous(
+        "env://", global_rank, world_size, timeout=_DEFAULT_PG_NCCL_TIMEOUT
+    )
+    tcp_store, rank, world_size = next(rendezvous_iterator)
+    tcp_store.set_timeout(_DEFAULT_PG_NCCL_TIMEOUT)
    config = JobConfig()
    config.parse_args()
-    main(config)
-    torch.distributed.destroy_process_group()
+    with maybe_enable_profiling(
+        config, global_step=0  # train_state is only defined inside main(); profile from step 0 here
+    ) as torch_profiler:
+        for root_size in [128, 128]:
+            os.environ["TORCH_NCCL_RANKS_PER_ROOT"] = str(root_size)
+            iter_size = 5 if warmup else 1
+            delta = 0.0
+            for i in range(iter_size):
+                start = time.perf_counter()
+                store = PrefixStore(f"default_pg_{index}", tcp_store)
+                index += 1
+                torch.distributed.init_process_group(store=store, backend="nccl", world_size=world_size, rank=global_rank)
+                torch.cuda.set_device(local_rank)
+                torch.distributed.barrier()
+                end = time.perf_counter()
+                torch.distributed.destroy_process_group()
+                delta += end - start
+                print(f"Time to init process group: {end - start:.6f} seconds for {root_size} ranks per root")
+            if warmup:
+                print(f"Average time to init process group: {delta / float(iter_size):.6f} seconds for {root_size} ranks per root")
+            warmup = True
+    # main(config)
+    # torch.distributed.destroy_process_group()
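
For context outside the torchtitan entry point, a minimal standalone sketch of the same measurement might look like the script below. It rendezvouses once to get the base TCPStore and wraps it in a fresh PrefixStore per iteration, so only NCCL process-group bootstrap (not store setup) lands inside the timed region. The file name benchmark_pg_init.py, the 10-minute timeout, the iteration count, and the torchrun launch line are illustrative assumptions, and the effect of TORCH_NCCL_RANKS_PER_ROOT depends on the PyTorch/NCCL build in use.

# benchmark_pg_init.py (hypothetical standalone sketch)
# Launch example: TORCH_NCCL_RANKS_PER_ROOT=128 torchrun --nproc_per_node=8 benchmark_pg_init.py
import os
import time
from datetime import timedelta

import torch
import torch.distributed as dist


def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    timeout = timedelta(minutes=10)  # assumption: any generous timeout works for this test

    # Rendezvous once to get the base TCPStore; each timed init below reuses it
    # through a uniquely prefixed PrefixStore, so store setup is excluded from the timing.
    tcp_store, _, world_size = next(
        dist.rendezvous("env://", global_rank, world_size, timeout=timeout)
    )
    tcp_store.set_timeout(timeout)
    torch.cuda.set_device(local_rank)

    num_timed_iters = 5
    total = 0.0
    for i in range(num_timed_iters + 1):  # iteration 0 is a discarded warm-up
        start = time.perf_counter()
        dist.init_process_group(
            store=dist.PrefixStore(f"pg_{i}", tcp_store),
            backend="nccl",
            world_size=world_size,
            rank=global_rank,
        )
        dist.barrier()  # force NCCL communicator creation before stopping the clock
        elapsed = time.perf_counter() - start
        dist.destroy_process_group()
        if i > 0:
            total += elapsed
        if global_rank == 0:
            print(f"iter {i}: init + barrier took {elapsed:.6f}s")
    if global_rank == 0:
        print(f"average over {num_timed_iters} timed iterations: {total / num_timed_iters:.6f}s")


if __name__ == "__main__":
    main()

Compared with the diff above, this sketch drops JobConfig and maybe_enable_profiling so it can run against a stock PyTorch install while exercising the same init/destroy loop.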