
Commit 21f9bff

Merge pull request #225 from otaviogood/grad_accum
Fix for gradient_accumulation_steps training slow
2 parents d9f4735 + a6a708c

File tree

3 files changed: +9 -4 lines


Diff for: config/train_gpt2.py (+1 -1)

@@ -10,7 +10,7 @@
 # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
 batch_size = 12
 block_size = 1024
-gradient_accumulation_steps = 5
+gradient_accumulation_steps = 5 * 8

 # this makes total number of tokens be 300B
 max_iters = 600000
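A quick check of the arithmetic quoted in the config comment (a minimal sketch in Python, not part of the commit): with the new convention, gradient_accumulation_steps is the total across all GPUs, so the tokens processed per optimizer step come out to the 491,520 in the comment.

# sanity check of the tokens-per-iteration figure quoted in the config comment
batch_size = 12
block_size = 1024
gradient_accumulation_steps = 5 * 8   # now the total across all 8 GPUs
tokens_per_iter = gradient_accumulation_steps * batch_size * block_size
print(f"{tokens_per_iter:,}")         # 491,520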

Diff for: config/train_shakespeare_char.py (+1)

@@ -14,6 +14,7 @@
 wandb_run_name = 'mini-gpt'

 dataset = 'shakespeare_char'
+gradient_accumulation_steps = 1
 batch_size = 64
 block_size = 256 # context of up to 256 previous characters
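
Since the default in train.py now assumes 8 GPUs (5 * 8), this single-GPU config overrides it back to 1. A quick check of the resulting tokens per iteration, using the formula added to train.py below (assuming ddp_world_size = 1):

# values taken from this config; ddp_world_size assumed to be 1 on a single GPU
tokens_per_iter = 1 * 1 * 64 * 256   # grad_accum * world_size * batch_size * block_size
print(tokens_per_iter)               # 16384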

Diff for: train.py (+7 -3)

@@ -45,7 +45,7 @@
 wandb_run_name = 'gpt2' # 'run' + str(time.time())
 # data
 dataset = 'openwebtext'
-gradient_accumulation_steps = 5 # used to simulate larger batch sizes
+gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
 batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
 block_size = 1024
 # model

@@ -84,16 +84,20 @@
     init_process_group(backend=backend)
     ddp_rank = int(os.environ['RANK'])
     ddp_local_rank = int(os.environ['LOCAL_RANK'])
+    ddp_world_size = int(os.environ['WORLD_SIZE'])
     device = f'cuda:{ddp_local_rank}'
     torch.cuda.set_device(device)
     master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
     seed_offset = ddp_rank # each process gets a different seed
+    assert gradient_accumulation_steps % torch.cuda.device_count() == 0
+    gradient_accumulation_steps //= torch.cuda.device_count()
 else:
     # if not ddp, we are running on a single gpu, and one process
     master_process = True
     seed_offset = 0
-    gradient_accumulation_steps *= 8 # simulate 8 gpus
-print("total number of tokens per iteration:", batch_size * block_size * gradient_accumulation_steps)
+    ddp_world_size = 1
+tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
+print(f"tokens per iteration will be: {tokens_per_iter:,}")

 if master_process:
     os.makedirs(out_dir, exist_ok=True)
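To summarize the behavioral change: gradient_accumulation_steps is now interpreted as a global value and divided across the GPUs, so the effective batch size no longer silently multiplies with the number of processes. A minimal sketch of that rule (split_grad_accum is a hypothetical helper, not code from the commit, but it uses the same divisibility check as the assert in the diff):

# minimal sketch of the new scaling rule; not code from the commit
def split_grad_accum(total_steps: int, num_gpus: int) -> int:
    # the configured value is the total across all GPUs; each rank runs its share
    assert total_steps % num_gpus == 0
    return total_steps // num_gpus

per_rank = split_grad_accum(5 * 8, num_gpus=8)   # -> 5 micro-steps per GPU
tokens_per_iter = per_rank * 8 * 12 * 1024       # grad_accum * world_size * batch * block
print(per_rank, f"{tokens_per_iter:,}")          # 5 491,520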
