-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
596 lines (487 loc) · 26.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
import os
import time
import math
from argparse import ArgumentParser
from contextlib import nullcontext
from textwrap import dedent
from arguments import BaseArgs
import ray
import wandb
import numpy as np
import torch
import torch.distributed as dist
from ray import serve
from tqdm import tqdm
from dotenv import load_dotenv
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, AdamW, get_scheduler
from accelerate import Accelerator
from accelerate.utils import DummyOptim, DummyScheduler
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
from rewards import format_reward_func, math_reward_func
from utils import \
prepare_deepspeed, get_all_inference_actors, \
call_func_using_actor_handle, stateless_init_process_group, \
get_per_token_logps, extract_answer, extract_numbers, compare_numbers
# Populate os.environ from a .env file before main() reads the
# RAY_MASTER_ADDRESS / RAY_CLIENT_SERVER_PORT / RAY_MASTER_PG_PORT settings.
load_dotenv()
def main():
    """Run GRPO-style (group-normalized advantage) RL fine-tuning for R1.

    Steps, in order:
      1. Parse ``--exp-config-path`` and load the experiment config with
         ``BaseArgs.from_yaml``.
      2. Set up Accelerate, optional wandb / TensorBoard logging, the
         tokenizer and its chat template.
      3. Tokenize the dataset once on the main process (cached at
         ``exp_args.tokenized_dataset_path``) and build the dataloaders.
      4. Build the policy model, a reference model, the optimizer and the LR
         scheduler.
      5. Connect to Ray, look up the ``InferenceWorker`` Serve deployment and,
         on the main process, open a process group to the inference actors
         used for weight broadcasts.
      6. For every training batch:
         - sample ``rollout_per_sample`` completions per prompt;
         - score them with the reward functions;
         - normalize rewards within each prompt's group into advantages;
         - take an optimizer step on the advantage-weighted loss plus a KL
           penalty toward the reference model;
         - broadcast the updated weights to the inference actors.
      7. Every ``log_interval`` / ``eval_interval`` / ``save_interval`` steps,
         log metrics, run a small GSM8K evaluation, and save a checkpoint.

    Environment variables read: ``RAY_MASTER_ADDRESS``,
    ``RAY_CLIENT_SERVER_PORT``, ``RAY_MASTER_PG_PORT``.
    """
    parser = ArgumentParser(description="Train for R1")
    parser.add_argument(
        "--exp-config-path",
        type=str, default="./exps/exp_debug/exp_config.yaml",
        help="Path to the experiment config file"
    )
    args = parser.parse_args()
    exp_args = BaseArgs.from_yaml(args.exp_config_path)

    accelerator = Accelerator(gradient_accumulation_steps=exp_args.gradient_accumulation_steps)
    is_deepspeed = accelerator.state.deepspeed_plugin is not None
    accelerator.print("Using DeepSpeed:", is_deepspeed)

    # Experiment loggers are created on the main process only; tb_writer stays
    # None on every other process.
    tb_writer = None
    is_wandb_logging = "wandb" in exp_args.logging_methods
    is_tb_logging = "tensorboard" in exp_args.logging_methods
    if accelerator.is_main_process:
        if is_wandb_logging:
            wandb.init(
                project=exp_args.wandb_project,
                entity=exp_args.wandb_entity,
                config=exp_args.to_dict(),
                name=exp_args.exp_name
            )
        if is_tb_logging:
            tb_writer = SummaryWriter(f"tb/{exp_args.exp_name}")

    ###############################################################
    # Prepare Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(exp_args.model_name_or_path)
    # Prompts are left-padded so each one ends exactly where its completion is
    # concatenated during the loss computation below.
    tokenizer.padding_side = "left"
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    pad_token_id = tokenizer.pad_token_id

    # Rollouts stop at the end-of-turn marker emitted by the chat template.
    stop_token = "<im_end>"
    # Every tag begins with "{{-" / "{%-", which strips the whitespace in front
    # of it, so the indentation of these lines never reaches a rendered prompt.
    chat_template = dedent("""
        {{- eos_token }}
        {%- for message in messages %}
            {{- "<im_start>" + message["role"] + "\n" + message["content"] + "<im_end>" + "\n" }}
        {%- endfor %}
        {%- if add_generation_prompt %}
            {{- "<im_start>assistant\n" }}
        {%- endif %}""").strip()
    tokenizer.chat_template = chat_template

    def tokenize_function(examples):
        """Build one-shot chat prompts for a batch of question/answer rows.

        ``examples`` is a batched mapping with ``"question"`` and ``"answer"``
        columns; the gold answer is the text after the last ``"####"`` in
        ``"answer"``.

        Returns a dict with ``gold_answer`` and ``solution`` (both the gold
        answer strings), ``user_input_ids`` (prompt token ids as nested lists)
        and ``user_input_text`` (the rendered prompt strings).
        """
        system_prompt = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>\n<answer> answer here </answer>"
        fewshot_question_1 = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
        fewshot_answer_1 = dedent("""
            <think>
            Natalia sold 48/2 = 24 clips in May.
            Natalia sold 48+24 = 72 clips altogether in April and May.
            </think>
            <answer>
            72
            </answer>""").strip()

        gold_answer_list = []
        new_messages_list = []
        for q, a in zip(examples["question"], examples["answer"]):
            # System instructions, one worked example, then the real question.
            new_messages_list.append([
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": fewshot_question_1},
                {"role": "assistant", "content": fewshot_answer_1},
                {"role": "user", "content": q}
            ])
            # The gold answer is whatever follows the final "####" marker.
            gold_answer_list.append(a.split("####")[-1].strip())

        batch = {}
        batch["gold_answer"] = gold_answer_list
        batch["solution"] = gold_answer_list
        batch["user_input_ids"] = tokenizer.apply_chat_template(
            new_messages_list,
            add_generation_prompt=True,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=exp_args.max_length
        ).tolist()
        batch["user_input_text"] = tokenizer.apply_chat_template(
            new_messages_list,
            tokenize=False,
            add_generation_prompt=True
        )
        return batch

    ###############################################################
    # Prepare Dataset
    # Tokenize once on the main process and cache to disk; all processes then
    # load the same cached copy after the barrier.
    is_cache_exist = os.path.exists(exp_args.tokenized_dataset_path)
    if accelerator.is_main_process and (not is_cache_exist or exp_args.overwrite_preprocess):
        dataset = load_dataset(exp_args.dataset_name_or_path, "main")
        tokenized_datasets = dataset.map(
            tokenize_function,
            batched=True,
            batch_size=exp_args.batch_size_for_preproc,
            num_proc=8
        )
        tokenized_datasets.save_to_disk(exp_args.tokenized_dataset_path)
    accelerator.wait_for_everyone()
    tokenized_datasets = load_from_disk(exp_args.tokenized_dataset_path)

    # accelerator.prepare() below shards the training dataloader across
    # processes itself, so the sampler only needs to shuffle. Wrapping a
    # DistributedSampler as well would shard the data a second time, leaving
    # each process with ~1/num_processes of its intended share and making the
    # loop shorter than num_training_steps.
    train_sampler = RandomSampler(tokenized_datasets["train"])
    valid_sampler = SequentialSampler(tokenized_datasets["test"])

    def collate_fn_all(batch):
        """Collate dataset rows into a dict of per-key lists.

        If present, ``user_input_ids`` is left-padded into a single
        ``input_ids`` tensor; every other key stays a plain list.
        """
        keys = list(batch[0].keys())
        data = {key: [item[key] for item in batch] for key in keys}
        if "user_input_ids" in data:
            user_input = tokenizer.pad({"input_ids": data["user_input_ids"]},
                                       return_tensors="pt",
                                       padding=True,
                                       padding_side="left")
            data["user_input_ids"] = user_input.input_ids
        return data

    train_dataloader = DataLoader(tokenized_datasets["train"],
                                  sampler=train_sampler,
                                  batch_size=exp_args.train_batch_size_per_proc,
                                  collate_fn=collate_fn_all,
                                  drop_last=True)
    valid_dataloader = DataLoader(tokenized_datasets["test"],
                                  sampler=valid_sampler,
                                  batch_size=exp_args.eval_batch_size_per_proc,
                                  collate_fn=collate_fn_all)

    ###############################################################
    # Prepare Model
    model = AutoModelForCausalLM.from_pretrained(exp_args.model_name_or_path)
    # Second copy of the initial weights; it is never given to the optimizer
    # and is only used (under no_grad) for the KL penalty.
    ref_model = AutoModelForCausalLM.from_pretrained(exp_args.model_name_or_path)
    ref_model.eval()

    ###############################################################
    # Prepare Optimizer and Scheduler
    # Use DeepSpeed's optimizer (via DummyOptim) when its config defines one.
    optimizer_cls = (
        AdamW
        if accelerator.state.deepspeed_plugin is None
        or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
        else DummyOptim
    )
    optimizer = optimizer_cls(model.parameters(), lr=exp_args.learning_rate)

    num_processes = accelerator.num_processes
    accelerator.print("Number of processes (GPUs):", num_processes)
    num_training_steps = math.ceil(len(tokenized_datasets["train"]) / (exp_args.train_batch_size_per_proc * num_processes)) * exp_args.num_train_epochs
    accelerator.print("Number of training steps:", num_training_steps)

    # Creates Dummy Scheduler if `scheduler` was specified in the config file else creates `args.lr_scheduler_type` Scheduler
    if (
        accelerator.state.deepspeed_plugin is None
        or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
    ):
        lr_scheduler = get_scheduler(
            name=exp_args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=exp_args.num_warmup_steps,
            num_training_steps=num_training_steps,
        )
    else:
        lr_scheduler = DummyScheduler(
            optimizer, total_num_steps=num_training_steps, warmup_num_steps=exp_args.num_warmup_steps
        )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    if accelerator.state.deepspeed_plugin is None:
        ref_model = accelerator.prepare_model(ref_model, evaluation_mode=True)
    else:
        ref_model = prepare_deepspeed(ref_model, accelerator)

    ###############################################################
    # Prepare Reward function
    reward_func_list = [format_reward_func, math_reward_func]

    ###############################################################
    # Prepare Inference Workers
    ray_master_address = os.environ["RAY_MASTER_ADDRESS"]
    ray_client_server_port = int(os.environ["RAY_CLIENT_SERVER_PORT"])
    ray_master_pg_port = int(os.environ["RAY_MASTER_PG_PORT"])
    # The main process joins the local Ray cluster directly; the others go
    # through the Ray client server on the master node.
    if accelerator.is_main_process:
        ray.init(address="auto")
    else:
        ray.init(address=f"ray://{ray_master_address}:{ray_client_server_port}")
    handle = serve.get_deployment_handle("InferenceWorker",
                                         app_name="default")
    # Smoke test: one round trip to the inference deployment.
    print(f"result: {handle.generate_text.remote(['hello']).result()}")

    # Only the main process talks to the inference actors directly; the other
    # processes keep these placeholder values.
    num_infer_workers = -1
    model_update_group = None
    # init weight update group: rank 0 is this process, ranks 1..N the actors.
    if accelerator.is_main_process:
        actor_handle_list = get_all_inference_actors()
        num_infer_workers = len(actor_handle_list)
        accelerator.print(actor_handle_list)
        worker_weight_init_handle_list = []
        for i, actor_handle in enumerate(actor_handle_list):
            worker_weight_init_handle = call_func_using_actor_handle(actor_handle,
                                                                     "init_weight_update_group",
                                                                     master_address=ray_master_address,
                                                                     master_port=ray_master_pg_port,
                                                                     rank=i + 1,
                                                                     world_size=num_infer_workers + 1)
            worker_weight_init_handle_list.append(worker_weight_init_handle)
        model_update_group = stateless_init_process_group(
            ray_master_address,
            ray_master_pg_port,
            rank=0,
            world_size=num_infer_workers + 1,
            device=torch.device("cuda:0")
        )
        ray.get(worker_weight_init_handle_list)
    accelerator.wait_for_everyone()

    global_i = 0
    os.makedirs(exp_args.save_dir, exist_ok=True)
    model.train()
    pbar = tqdm(range(num_training_steps), total=num_training_steps)
    accelerator.print("Start training")
    for epoch in range(exp_args.num_train_epochs):
        for batch in train_dataloader:
            # NOTE(review): this is a no-op context. Accelerate only accumulates
            # gradients inside `accelerator.accumulate(model)`; without it the
            # optimizer steps on every batch, so gradient_accumulation_steps
            # does not accumulate anything here. Confirm this is intended.
            context = nullcontext()
            with context:
                ###############################################################
                # Rollout
                sample_params = {"temperature": exp_args.rollout_temperature,
                                 "max_tokens": exp_args.rollout_max_tokens,
                                 "n": exp_args.rollout_per_sample,
                                 "include_stop_str_in_output": True,
                                 "stop": [stop_token]}
                future_policy_rollout_batch = handle.generate_text.remote(
                    batch["user_input_text"],
                    sample_params=sample_params
                )
                policy_rollout_batch = future_policy_rollout_batch.result()
                text_compl_sample_list_batch = policy_rollout_batch["text"]  # [batch_size, rollout_per_sample]

                ###############################################################
                # Calc Reward
                # [batch_size, rollout_per_sample, num_reward_func]
                reward_list = [
                    [
                        [reward_func(text_compl_sample, solution=solution) for reward_func in reward_func_list]
                        for text_compl_sample in text_compl_sample_list
                    ]
                    for text_compl_sample_list, solution in zip(text_compl_sample_list_batch, batch["solution"])
                ]
                # Explicit float dtype: integer rewards would otherwise give an
                # integer tensor, which torch.mean / torch.std reject.
                rewards = torch.tensor(reward_list, dtype=torch.float32)
                total_reward_by_each_compl = torch.sum(rewards, dim=2)  # [batch_size, rollout_per_sample]
                reward_mean = torch.mean(total_reward_by_each_compl, dim=1)  # [batch_size]
                reward_std = torch.std(total_reward_by_each_compl, dim=1)  # [batch_size]

                ###############################################################
                # Calc Advantages
                # Normalize each completion's reward within its prompt's group.
                # [batch_size, rollout_per_sample]
                advantages = (total_reward_by_each_compl - reward_mean.unsqueeze(1)) / (reward_std.unsqueeze(1) + 1e-4)
                advantages = advantages.to(model.device)
                # [batch_size, rollout_per_sample, not fixed length]
                raw_completion_ids_batch = policy_rollout_batch["token_ids"]

                ###############################################################
                # Calc KL divergence
                batch_size = len(raw_completion_ids_batch)
                rollout_per_sample = len(raw_completion_ids_batch[0])
                # [batch_size * rollout_per_sample, length]
                completion_ids_list = [
                    completion_ids_seq
                    for raw_completion_ids in raw_completion_ids_batch
                    for completion_ids_seq in raw_completion_ids
                ]
                # Completions are right-padded so real tokens follow the prompt directly.
                # [batch_size * rollout_per_sample, max_length]
                completion_padded = tokenizer.pad({
                    "input_ids": completion_ids_list},
                    return_tensors="pt",
                    padding=True,
                    padding_side="right"
                )
                # [batch_size, rollout_per_sample, max_length]
                completion_ids = completion_padded.input_ids.view(batch_size, rollout_per_sample, -1)
                completion_ids = completion_ids.to(model.device)
                # [batch_size, prompt_length]
                user_input_ids = batch["user_input_ids"]
                # [batch_size, rollout_per_sample, prompt_length]
                user_input_ids_expanded = user_input_ids.unsqueeze(1).expand(-1, rollout_per_sample, -1)
                # [batch_size, rollout_per_sample, prompt_length + max_length]
                prompt_completion_ids = torch.cat([user_input_ids_expanded,
                                                   completion_ids], dim=-1)
                logits_to_keep = completion_ids.size(-1)
                # [batch_size * rollout_per_sample, prompt_length + max_length]
                flatten_prompt_completion_ids = prompt_completion_ids.view(batch_size * rollout_per_sample, -1)
                flatten_prompt_completion_attention_mask = (flatten_prompt_completion_ids != pad_token_id).long()

                # Assumes get_per_token_logps returns
                # [batch_size * rollout_per_sample, logits_to_keep] -- TODO confirm in utils.
                with torch.no_grad():
                    ref_per_token_logps = get_per_token_logps(
                        ref_model,
                        flatten_prompt_completion_ids,
                        flatten_prompt_completion_attention_mask,
                        logits_to_keep
                    )
                policy_per_token_logps = get_per_token_logps(
                    model,
                    flatten_prompt_completion_ids,
                    flatten_prompt_completion_attention_mask,
                    logits_to_keep
                )
                # Per-token KL estimate exp(d) - d - 1 with d = ref - policy.
                per_token_kl = torch.exp(ref_per_token_logps - policy_per_token_logps) - (ref_per_token_logps - policy_per_token_logps) - 1
                # x - x.detach() preserves gradients from x while the ratio
                # itself evaluates to 1; equivalent to refreshing the old
                # policy at every step.
                # [batch_size * rollout_per_sample, max_length]
                per_token_loss = torch.exp(policy_per_token_logps - policy_per_token_logps.detach()) * advantages.view(-1, 1)
                # NOTE(review): the correct sign here depends on what
                # utils.get_per_token_logps returns. If it returns
                # log-probabilities, standard GRPO minimizes
                #     -(per_token_loss - kl_coef * per_token_kl)
                # and the form below would push probability away from
                # high-advantage completions. The original author reports the
                # form below trains correctly, which would be consistent with
                # get_per_token_logps returning *negated* log-probabilities.
                # Confirm before changing.
                per_token_loss = per_token_loss + exp_args.kl_coef * per_token_kl

                # [batch_size * rollout_per_sample, max_length]
                completion_attention_mask = (completion_ids != pad_token_id).view(batch_size * rollout_per_sample, -1).long()
                # Mean over each completion's real tokens, then over the batch;
                # the clamp guards against a completion made only of padding.
                completion_token_counts = completion_attention_mask.sum(dim=1).clamp(min=1)
                train_loss = ((per_token_loss * completion_attention_mask).sum(dim=1) / completion_token_counts).mean()
                accelerator.backward(train_loss)
                if not is_deepspeed and accelerator.sync_gradients:
                    accelerator.clip_grad_value_(model.parameters(), exp_args.max_grad_value)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            ###############################################################
            # Update Policy model
            # For each parameter: ask every live inference actor to receive a
            # tensor with this name/dtype/shape, broadcast it from rank 0 of
            # the weight-update group, then wait for the actors' calls to finish.
            actor_handle_list = get_all_inference_actors(class_name="InferenceWorker", state="ALIVE")
            unwrapped_model = accelerator.unwrap_model(model)
            if accelerator.is_main_process:
                start_time = time.time()
                for name, p in unwrapped_model.named_parameters():
                    worker_update_weight_handle_list = [
                        call_func_using_actor_handle(actor_handle,
                                                     "update_weight",
                                                     name=name,
                                                     dtype=p.dtype,
                                                     shape=p.shape)
                        for actor_handle in actor_handle_list
                    ]
                    model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
                    ray.get(worker_update_weight_handle_list)
                print(f"Time for weight update: {time.time() - start_time}")
            accelerator.wait_for_everyone()
            print(f"{accelerator.process_index} Train loss:", train_loss.item())

            ###############################################################
            # Logging
            if global_i % exp_args.log_interval == 0:
                # Collective ops: every process must reach these two calls.
                global_train_loss = accelerator.reduce(train_loss.detach(), reduction="mean").item()
                # [batch_size * world_size, rollout_per_sample, num_reward_func]
                global_rewards = accelerator.gather_for_metrics(rewards.to(model.device)).detach()
                global_reward_mean = torch.mean(global_rewards).item()
                if accelerator.is_main_process:
                    # Completion lengths in characters, from this process's batch only.
                    length_list = [
                        [len(text_compl_sample) for text_compl_sample in text_compl_sample_list]
                        for text_compl_sample_list in text_compl_sample_list_batch
                    ]
                    length_mean = np.mean(length_list)
                    length_std = np.std(length_list)
                    reward_func_to_reward_map = {
                        reward_func.__name__: global_rewards[:, :, i].mean().item()
                        for i, reward_func in enumerate(reward_func_list)
                    }
                    metrics = {"epoch": epoch,
                               "global_step": global_i,
                               "reward_mean": global_reward_mean,
                               "train_loss": global_train_loss,
                               "lr": lr_scheduler.get_last_lr()[0],
                               "length_mean": length_mean,
                               "length_std": length_std,
                               **reward_func_to_reward_map
                               }
                    print(metrics)
                    if is_wandb_logging:
                        wandb.log(metrics)
                    if is_tb_logging:
                        for k, v in metrics.items():
                            tb_writer.add_scalar(f"train/{k}", v, global_i)
                    # Dump every decoded completion of this batch for inspection.
                    print("=" * 60)
                    for item_list in completion_ids:
                        for item in item_list:
                            sample_completion = tokenizer.decode(item.cpu().tolist(),
                                                                 skip_special_tokens=True)
                            print(sample_completion)
                            print("-" * 30)

            ###############################################################
            # Evaluation (main process only) on the first `eval_sample` test items
            if accelerator.is_main_process and global_i % exp_args.eval_interval == 0:
                pred_raw_list = []
                pred_list = []
                gold_list = []
                batch_result_list = []
                eval_sample = 30
                eval_sample_params = {"temperature": 0.1,
                                      "max_tokens": exp_args.rollout_max_tokens,
                                      "n": 1,
                                      "include_stop_str_in_output": True,
                                      "stop": [stop_token]}
                for valid_batch in valid_dataloader:
                    if len(gold_list) >= eval_sample:
                        break
                    # inference
                    gold_list.extend(valid_batch["gold_answer"])
                    batch_result_list.append(handle.generate_text.remote(
                        valid_batch["user_input_text"],
                        sample_params=eval_sample_params
                    ))
                    # Keep up to one request per inference worker in flight
                    # before draining them (in submission order).
                    if len(batch_result_list) < num_infer_workers:
                        continue
                    for pending in batch_result_list:
                        for preds in pending.result()["text"]:
                            pred_raw_list.append(preds[0])
                    batch_result_list = []
                # Drain requests still in flight when the loop ended.
                for pending in batch_result_list:
                    for preds in pending.result()["text"]:
                        pred_raw_list.append(preds[0])

                gold_list = gold_list[:eval_sample]
                pred_raw_list = pred_raw_list[:eval_sample]
                for pred_raw in pred_raw_list:
                    # extract answer from <answer> </answer> tag
                    answer_block = extract_answer(pred_raw)
                    answer_number = extract_numbers(answer_block)
                    pred = answer_number[0] if answer_number else None
                    pred_list.append(pred)

                n_exact_correct = 0
                n_within_tolerance_correct = 0
                n_total = len(pred_list)
                for pred, gold in zip(pred_list, gold_list):
                    result = compare_numbers(pred, gold)
                    if result["exact_match"]:
                        n_exact_correct += 1
                    if result["within_tolerance"]:
                        n_within_tolerance_correct += 1
                # Calc Accuracy (guarded against an empty evaluation set)
                exact_accuracy = n_exact_correct / max(n_total, 1)
                within_tolerance_accuracy = n_within_tolerance_correct / max(n_total, 1)
                metrics = {
                    f"gsm8k_accuracy_exact_{eval_sample}s": exact_accuracy,
                    f"gsm8k_accuracy_within_tolerance_{eval_sample}s": within_tolerance_accuracy,
                }
                if is_wandb_logging:
                    wandb.log(metrics)
                if is_tb_logging:
                    for k, v in metrics.items():
                        tb_writer.add_scalar(f"valid/{k}", v, global_i)
                accelerator.print(
                    f"global_step: {global_i}, epoch: {epoch}, "
                    f"gsm8k_accuracy_exact_{eval_sample}s: {exact_accuracy:0.4f}, "
                    f"gsm8k_accuracy_within_tolerance_{eval_sample}s: {within_tolerance_accuracy:0.4f}"
                )
            accelerator.wait_for_everyone()

            ###############################################################
            # Checkpointing
            if global_i % exp_args.save_interval == 0:
                if accelerator.is_main_process:
                    # NOTE(review): under DeepSpeed ZeRO-3, get_state_dict gathers
                    # parameters collectively and must run on every process;
                    # calling it from the main process only may hang there.
                    ckpt_dir = f"{exp_args.save_dir}/ckpt_{global_i}"
                    unwrapped_model = accelerator.unwrap_model(model)
                    unwrapped_model.save_pretrained(
                        ckpt_dir,
                        is_main_process=accelerator.is_main_process,
                        save_function=accelerator.save,
                        state_dict=accelerator.get_state_dict(model)
                    )
                    tokenizer.save_pretrained(ckpt_dir)
                    torch.save({"epoch": epoch, "global_step": global_i}, f"{ckpt_dir}/training_state.pt")
                accelerator.wait_for_everyone()
            pbar.update(1)
            global_i += 1
    pbar.close()

    # Flush and release the loggers opened on the main process.
    if tb_writer is not None:
        tb_writer.close()
    if accelerator.is_main_process and is_wandb_logging:
        wandb.finish()
if __name__ == "__main__":
    # Train only when run as a script, not when the module is imported.
    main()