
Commit b592727

authored Dec 30, 2023
Add files via upload
1 parent ff9a785 commit b592727

14 files changed: +53350 -2 lines
 

README.md (+54 -2)
@@ -1,2 +1,54 @@
-# RankingGPT
-code for paper 《RankingGPT: Empowering Large Language Models in Text Ranking with Progressive Enhancement》

# Description

This is the official code for the paper [RankingGPT: Empowering Large Language Models in Text Ranking with Progressive Enhancement](https://arxiv.org/abs/2311.16720).

# Requirements

```
transformers==4.28.1
datasets
pyserini
torch==1.13.1
```

# Data

- ./datasets/text_pairs.json: Weakly supervised text pairs for the pretraining stage
- ./datasets/msmarco.json: Supervised fine-tuning data
- ./rankdata/trec19: Top-1000 query-document pairs recalled by BM25, used for evaluation
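The data files store one JSON record per line. As a quick sanity check, they can be inspected with the `datasets` library listed in the requirements (a minimal sketch; the field names follow the examples shipped in ./datasets/text_pairs.json):

```
# Minimal sketch: peek at the weakly supervised pretraining pairs.
from datasets import load_dataset

pairs = load_dataset("json", data_files="./datasets/text_pairs.json")["train"]
print(len(pairs))            # number of (document, query) pairs
print(pairs[0]["query"])     # prompt side: "Document: ... Query:"
print(pairs[0]["response"])  # target side: the query text
```

./datasets/msmarco.json is loaded the same way by sft.py; its records additionally carry `positive_passages` and `negative_passages` lists that are sampled into groups of `train_group_size` candidates during SFT.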
# Two-stage Training

## Pretrain

The arguments are the base model, a short model name used in the output directory, and the transformer block class to wrap for FSDP (see pretrain.sh):

```
bash pretrain.sh bigscience/bloom-560m bloom-560m BloomBlock
```

## SFT

The arguments are the pretrained checkpoint directory, a short model name, the number of top transformer layers to unfreeze, and the FSDP block class (see sft.sh):

```
bash sft.sh ./outputs_pretrain_bloom-560m bloom-560m 16 BloomBlock
```

# Evaluation

The arguments are the fine-tuned checkpoint directory, the dataset name under ./rankdata, and a tag for the output run files (see eval.sh):

```
bash eval.sh ./outputs_sft_bloom-560m trec19 bloom-560m
```
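For reference, eval.py scores each candidate by the average log-probability the model assigns to the query tokens conditioned on the prompt `Document: {doc} Query:`, writes a TREC-style run file (`qid Q0 pid rank score tag`), and eval.sh then computes NDCG@10 with pyserini's trec_eval. A single-pair version of that score might look like the following sketch (it assumes a standard Hugging Face causal LM; the helper name `rank_score` is illustrative, not part of the released code):

```
import torch

def rank_score(model, tokenizer, query, document):
    """Average log-likelihood of the query given the document prompt."""
    prompt = f"Document: {document} Query:"
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    full_ids = tokenizer(prompt + query, return_tensors="pt").input_ids
    with torch.no_grad():
        logits = model(full_ids).logits                    # [1, L, vocab]
    logprobs = torch.log_softmax(logits[:, :-1], dim=-1)   # positions predicting tokens 1..L-1
    targets = full_ids[:, 1:]
    token_lp = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    query_lp = token_lp[:, prompt_ids.shape[1] - 1:]       # keep only the query tokens
    return query_lp.mean().item()
```

eval.py batches this computation over the BM25 top-1000 candidates in ./rankdata/<data_name>/top1000.json; the re-ranked run file is then scored with `python -m pyserini.eval.trec_eval -c -m ndcg_cut.10`.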
# Results

*Ranking results (NDCG@10) on the top-1000 candidate documents recalled by BM25.*

| Model | DL19 | DL20 | BEIR | Checkpoints |
|---------|------|------|------|-----------------|
| MonoBERT-340M | 72.3 | 70.3 | 50.5 | [huggingface](https://huggingface.co/veneres/monobert-msmarco) |
| MonoT5-220M | 71.5 | 69.7 | 49.3 | [huggingface](https://huggingface.co/castorini/monot5-base-msmarco) |
| MonoT5-770M | 73.2 | 71.2 | 53.1 | [huggingface](https://huggingface.co/castorini/monot5-large-msmarco) |
| MonoT5-3B | 72.8 | 74.5 | 54.6 | [huggingface](https://huggingface.co/castorini/monot5-3b-msmarco) |
| RankT5-770M | - | - | 53.7 | [huggingface](https://huggingface.co/bergum/rank-T5-flan) |
| RankLLaMA | 74.6 | 76.6 | 52.5 | [huggingface](https://huggingface.co/castorini/rankllama-v1-7b-lora-passage) |
| RankingGPT-bloom-560m | 75.3 | 73.2 | 53.7 | [huggingface](https://huggingface.co/zyznull/RankingGPT-bloom-560m) [modelscope](https://modelscope.cn/models/damo/RankingGPT-bloom-560m) |
| RankingGPT-bloom-1b1 | 75.6 | 73.2 | 54.5 | [huggingface](https://huggingface.co/zyznull/RankingGPT-bloom-1b1) [modelscope](https://modelscope.cn/models/damo/RankingGPT-bloom-1b1) |
| RankingGPT-bloom-3b | 76.8 | 73.6 | 56.2 | [huggingface](https://huggingface.co/zyznull/RankingGPT-bloom-3b) [modelscope](https://modelscope.cn/models/damo/RankingGPT-bloom-3b) |
| RankingGPT-bloom-7b | 77.3 | 74.6 | 56.6 | [huggingface](https://huggingface.co/zyznull/RankingGPT-bloom-7b) [modelscope](https://modelscope.cn/models/damo/RankingGPT-bloom-7b) |
| RankingGPT-llama2-7b | 76.2 | 76.3 | 57.8 | [huggingface](https://huggingface.co/zyznull/RankingGPT-llama2-7b) [modelscope](https://modelscope.cn/models/damo/RankingGPT-llama2-7b) |
| RankingGPT-baichuan2-7b | 75.9 | 74.3 | 57.5 | [huggingface](https://huggingface.co/zyznull/RankingGPT-baichuan2-7b) [modelscope](https://modelscope.cn/models/damo/RankingGPT-baichuan2-7b) |
| RankingGPT-qwen-7b | 75.8 | 74.3 | 58.3 | [huggingface](https://huggingface.co/zyznull/RankingGPT-qwen-7b) [modelscope](https://modelscope.cn/models/damo/RankingGPT-qwen-7b) |
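The released checkpoints can be plugged into the same scoring scheme directly. A minimal re-ranking sketch (it reuses the `rank_score` helper sketched in the Evaluation section; the query and candidate texts are illustrative):

```
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "zyznull/RankingGPT-bloom-560m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval()

query = "how far is nashville tn from closest beach"
candidates = [
    "According to Google Earth, the nearest beach to Nashville TN is Myrtle Beach, South Carolina.",
    "The Ford Fusion Titanium has optional 19-inch wheels that add $695 to the price of the car.",
]
ranked = sorted(candidates, key=lambda doc: rank_score(model, tokenizer, query, doc), reverse=True)
print(ranked[0])  # the candidate judged most relevant to the query
```

For the larger checkpoints, load the model in float16 on a GPU and move the tokenized inputs to the same device, as eval.py does; the qwen and baichuan variants also need `trust_remote_code=True`.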

datasets/msmarco.json (+100; large diff not rendered)

datasets/text_pairs.json (+10)
@@ -0,0 +1,10 @@
{"query": "Document: The name \"aardvark\" is Afrikaans (Afrikaans pronunciation: [\u02c8\u0251\u02d0rtfark]), comes from earlier Afrikaans (erdvark) and means \"earth pig\" or \"ground pig\" (aarde: earth/ground, vark: pig), because of its burrowing habits. The name Orycteropus means burrowing foot, and the name afer refers to Africa. Query:", "response": "is aardvark derived from afrikaans?"}
{"query": "Document: There is indeed a transform called discrete Laplace transform and it is of course closely related to the $\\mathcal{Z}$-transform. The (unilateral) discrete Laplace transform of a sequence $f_n$ is defined by (cf. link)\n\n$$\\mathcal{L}_T\\{f_n\\}=\\sum_{n=0}^{\\infty}f_ne^{-snT}$$\n\nwith some $T>0$. The discrete Laplace transform can be interpreted as the Laplace transform of a sampled function $f(t)\\cdot\\sum_n\\delta(t-nT)$ with $f_n=f(nT)$.\n\nIn practice it is not convenient to have the factor $e^{sT}$ appear everywhere. If one substitutes $z=e^{sT}$, the discrete Laplace transform is called (unilateral) $\\mathcal{Z}$-transform:\n\n$$\\mathcal{Z}\\{f_n\\}=\\sum_{n=0}^{\\infty}f_nz^{-n}$$\n\nThe same can obviously be done for the bilateral versions of the transforms, where the integrals and the sums start at $-\\infty$. Query:", "response": "Why Z-transform is considered as separate transform?"}
{"query": "Document: I added a new condition to org-return, very similar to how it handles a link containing the org-link face property ..., with the new condition being triggered if the org-footnote face property is present. We use org-footnote-action to jump to/from. The code is from Org version 8.2.10 that ships with Emacs 25.2.1.\n\n (require 'org)\n\n(setq org-return-follows-link t)\n\n(defun org-return (&optional indent)\n \"Goto next table row or insert a newline.\nCalls `org-table-next-row' or `newline', depending on context.\nSee the individual commands for more information.\"\n (interactive)\n (let (org-ts-what)\n (cond\n ((or (bobp) (org-in-src-block-p))\n (if indent (newline-and-indent) (newline)))\n ((org-at-table-p)\n (org-table-justify-field-maybe)\n (call-interactively 'org-table-next-row))\n ;; when `newline-and-indent' is called within a list, make sure\n ;; text moved stays inside the item.\n ((and (org-in-item-p) indent)\n (if (and (org-at-item-p) (>= (point) (match-end 0)))\n (progn\n (save-match-data (newline))\n (org-indent-line-to (length (match-string 0))))\n (let ((ind (org-get-indentation)))\n (newline)\n (if (org-looking-back org-list-end-re)\n (org-indent-line)\n (org-indent-line-to ind)))))\n ((and org-return-follows-link\n (org-at-timestamp-p t)\n (not (eq org-ts-what 'after)))\n (org-follow-timestamp-link))\n ((and org-return-follows-link\n (let ((tprop (get-text-property (point) 'face)))\n (or (eq tprop 'org-link)\n (and (listp tprop) (memq 'org-link tprop)))))\n (call-interactively 'org-open-at-point))\n ;;; NEW CONDITION: `org-footnote' face property => `org-footnote-action'\n ((and org-return-follows-link\n (let ((tprop (get-text-property (point) 'face)))\n (or (eq tprop 'org-footnote)\n (and (listp tprop) (memq 'org-footnote tprop)))))\n (org-footnote-action))\n ((and (org-at-heading-p)\n (looking-at\n (org-re \"\\\\([ \\t]+\\\\(:[[:alnum:]_@#%:]+:\\\\)\\\\)[ \\t]*$\")))\n (org-show-entry)\n (end-of-line 1)\n (newline))\n (t\n (if indent\n (newline-and-indent)\n (newline)))))) Query:", "response": "org-return-follows-link with footnotes?"}
{"query": "Document: The Ford Fusion Titanium has optional 19-inch wheels available in machined or dark stainless aluminum that add $695 to the price of the car. But the cost of replacing one tire (Continental ContiProContact) would be $244 to $292. Query:", "response": "how much do aluminum wheels cost"}
{"query": "Document: In the following episode \"My Brother's Keeper\", Elena tells Damon he is the reason she and Stefan broke up and at the end of the episode Damon and Elena finally have sex. Query:", "response": "vampire diaries episode where elena and damon sleep together?"}
{"query": "Document: Heavy manga spoilers ahead read at your own risk:\n\n\n In manga ch 45, Kohaku asks if they are related to which Senku answers\nthat he and Byakuya are NOT blood related BUT they (he and the villagers)\nare hundreds of generations apart so it doesn't matter.\nYes they had children, no they are not related to Senku.\n(Note that they don't specifically show that Byakuya and Lillian were together but\nthere was the already married couple and they showed Connie and Shamil getting married so most\nprobably Byakuya ended up with Lillian) Query:", "response": "Did Senku's father have children with Lillian in Dr. Stone?"}
{"query": "Document: Asker's rating. 1 The average life span of the female mosquito is 3 to 100 days; the male's is 10 to 20 days. 2 females lie eggs, little worms come out of the eggs and the worms become those little flying pests... Query:", "response": "what's the life expectancy of a mosquito"}
{"query": "Document: So this thread is dedicated to posting of poor critiques of Marxian economics. First is diamonds in the desert. So it starts off with a man, a man of means, but not unlimited means, stranded in the desert. He is dehydrated and close to death. At this point another man comes along with apparently only carrying two items with him, and looking to make a deal. He sees the dehydrated man has a sack of diamonds all of good quality. Being a good capitalist he wants to make a deal; \"All your diamonds and you can have the water\". The dehydrated man gives him all he has for water. This example is one libertarian/liberal/capitalist apologists will use to prove value is subjective. How could it not? When you are dehydrated isn't water more valuable than a sack of diamonds? Such a great and simple thought experiment to show how Marx was wrong and capitalism is the best system ever. Well there are problems with this example:\n\n1. The first problem is probably the least obvious: Value isn't determined by extreme cases, value is an aggregate. One instance of a high price doesn't mean water is worth more than diamonds. \n\n2. In the aggregate trade is a zero sum gain. Rarely does one trade items of greatly mismatched value, especially with a currency as a medium of exchange. Pointing to a case of someone getting ripped off doesn't change this.\n\n3. It changes value into only the result of exchange without any context. Is it possible that diamonds literally litter the ground? That any person can just pick one up and trade it for anything? And is water a rare resource? If this is the cause then it is a fine example of LTV. One can't assume that this abstraction is like the real world in all cases except where it comes to the exchange itself. \n\n\nAnyone else have examples of poor liberal/libertarian/capitalist thought experiments that they believe disproves Marxian economics? Query:", "response": "Diamonds in the desert, and other badly thought out critiques of Marxian economic theory."}
{"query": "Document: The oxford dictionary defines pilot as:\n\n\n A person who operates the flying controls of an aircraft\n\n\nSo, technically, the drone operator should be called a pilot.\n\nFAA National Policy Order 8130.34C Airworthiness Certification of Unmanned Aircraft Systems and Optionally Piloted Aircraft Section 6 specifically calls the person operating the UAS (only if it has been issued with an airworthiness certificate ) Pilot.:\n\n\n \n UA Pilots and Observers.\n \n \n a. PIC Roles and Responsibilities.\n \n (1) The PIC must perform crew duties for only one UA at a time.\n \n (2) All UA flight operations must have a designated PIC. The PIC has responsibility over each flight conducted and is accountable for the UA flight operation.\n \n (3) The PIC is responsible for the safety of the UA as well as persons and property along the UA flight path. This includes, but is not limited to, collision avoidance and the safety of persons and property in the air and on the ground.\n \n (4) The PIC must avoid densely populated areas and congested airways in accordance with \u00a7 91.319.\n\n\nThe order requires the PIC to have a minimum of FAA PPL:\n\n\n b. UA PIC Certification and Ratings Requirements.\n \n (1) The PIC must hold and be in possession of, at a minimum, an FAA private pilot certificate, with either an airplane, rotorcraft, or powered-lift category; with single- or multiengine class ratings, appropriate to the type of UA being operated.\n \n (2) The PIC must have and be in possession of a valid second-class (or higher) airman medical certificate issued under 14 CFR part 67, Medical Standards and Certification.\n\n\nUK CAA also talks about 'Pilot Qualifications required to operate Unmanned Aircraft'. ICAO also uses the term 'pilot' for people controlling an UAV.\n\n\n\nBoth USAF and RAF call the UAV operators pilots- RAF calls them Remotely Piloted Aircraft System Pilots and USAF, Remotely Piloted Aircraft Pilots and they do get the 'wings'\n\n\n\n\"United States Air Force Unmanned Aircraft Operator Badge\" by SSgt Austin May of the USAF - http://www.af.mil/news/story.asp?id=123170151http://www.mildenhall.af.mil/news/story.asp?id=123170577. Licensed under Public Domain via Commons.\n\nHowever, there has been no UAV pilot license issued as far as I know. Query:", "response": "Should people flying UAVs be called \"Operators\" or \"Pilots\"?"}
{"query": "Document: According to Google Earth, the nearest beach to Nashville TN is Myrtle Beach, South Carolina. According to the directions and route provided by Google Earth Myrtle Beach is 585 Miles from Nashville by car, aproximatly 10 hours and 42 minutes driving at the speed limit, not including time for travel plaza stops. Query:", "response": "how far is nashville tn from closest beach"}

eval.py (+162)
@@ -0,0 +1,162 @@
# Re-rank BM25 candidates with a causal LM: the relevance score is the average
# log-probability of the query tokens given the prompt "Document: {doc} Query:".
# Writes a raw TREC run file and a rank-sorted *_post.res file.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
import argparse
import json
from tqdm import tqdm
import os
import copy

def get_model_tokenizer(model_path):
    if 'llama' in model_path.lower():
        tokenizer = LlamaTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True)
    model.eval()
    return model, tokenizer

parser = argparse.ArgumentParser()
parser.add_argument('--model_path', default="", required=False)
parser.add_argument('--res_path', default="", required=False)
parser.add_argument('--rank_path', default="", required=False)
parser.add_argument('--data_name', default='msmarco')

args = parser.parse_args()

model_path = args.model_path
data_name = args.data_name

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"

bsz = 8

prompt = 'Document: {doc} Query:'

model, tokenizer = get_model_tokenizer(model_path)
if 'qwen' in model_path.lower():
    tokenizer.pad_token_id = tokenizer.eod_id

def get_num_token(text):
    return len(tokenizer.encode(text))

prompt_len = get_num_token(prompt)
print(f"prompt_len: {prompt_len}")


def truncation(text, length):
    text = tokenizer.decode(tokenizer.encode(text, max_length=length, add_special_tokens=False))
    return text

def _tokenize_fn(strings):
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
        )['input_ids']
        for text in strings
    ]
    input_ids = labels = [tokenized[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )

all_examples = []
all_sources = []
all_qpids = []
all_queries = []

# Load the BM25 top-1000 candidates and build "Document: {doc} Query:{query}" examples.
for line in tqdm(open(args.rank_path), desc='load data'):
    ex = json.loads(line)
    all_qpids.append((ex['qid'], ex['pid']))
    if data_name != 'arguana':
        query = ex["query"].replace(DEFAULT_PAD_TOKEN, 'PAD')
        query_len = get_num_token(query)
        passage_max_len = 512 - prompt_len - query_len - 10
        source = prompt.format(doc=truncation(ex['passage'], passage_max_len)).replace(DEFAULT_PAD_TOKEN, 'PAD')
    else:
        source = prompt.format(doc=truncation(ex['passage'], 256)).replace(DEFAULT_PAD_TOKEN, 'PAD')
        query = truncation(ex['query'], 256)
    all_examples.append(source + query)
    all_sources.append(source)
    all_queries.append(query)


with open(args.res_path, "w") as fw:
    for index in tqdm(range(0, len(all_examples), bsz)):
        examples = all_examples[index:index + bsz]
        sources = all_sources[index:index + bsz]
        qpids = all_qpids[index:index + bsz]
        queries = all_queries[index:index + bsz]
        qid, pid = qpids[0]

        examples_tokenized, sources_tokenized = [_tokenize_fn(strings) for strings in (examples, sources)]
        input_ids = examples_tokenized["input_ids"]

        # Mask the document prompt so only query tokens contribute to the score.
        labels = copy.deepcopy(input_ids)
        for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
            label[:source_len] = IGNORE_INDEX

        for index in range(len(input_ids)):
            input_ids[index] = input_ids[index][:-1]

        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id).cuda()

        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).cuda()
        labels = labels[..., 1:].contiguous()  # B L

        with torch.no_grad():
            lm_logits = model(input_ids=input_ids, attention_mask=input_ids.ne(tokenizer.pad_token_id))[0]
        preds = torch.nn.functional.log_softmax(lm_logits, dim=-1)
        label_no_ignore = torch.where(labels == -100, torch.ones(labels.shape).long().cuda(), labels)
        logprobs = torch.gather(preds, -1, label_no_ignore.unsqueeze(dim=-1)).squeeze(dim=-1)  # B L
        indexs = (labels != -100).long()
        scores = (logprobs * indexs).sum(dim=-1) / indexs.sum(dim=-1)
        scores = scores.cpu().tolist()

        # Raw run file: rank is filled in during post-processing below.
        for index, score in enumerate(scores):
            qid, pid = qpids[index]
            print(" ".join([qid, "Q0", pid, "-1", str(score), model_path]), file=fw)
        del lm_logits


# Post-process: sort candidates per query by score and write a TREC run file.
results = {}
for line in open(args.res_path):
    line = line.strip().split()
    qid = line[0]
    pid = line[2]

    score = float(line[4])
    if qid not in results:
        results[qid] = []
    results[qid].append((pid, score))

with open(args.res_path[:-4] + "_post.res", "w") as fw:
    for qid in results:
        res = results[qid]
        sorted_res = sorted(res, key=lambda x: -x[1])
        for i, item in enumerate(sorted_res):
            print(" ".join([qid, "Q0", item[0], str(i), str(item[1]), 'llm']), file=fw)

eval.sh (+26)
@@ -0,0 +1,26 @@
model=$1
data_name=$2
rank_name=$3

output_dir="./rankdata/${data_name}/${rank_name}/"
mkdir -p ${output_dir}
echo $output_dir

recall_path="./rankdata/${data_name}/top1000.json"
qrel_path="./rankdata/${data_name}/qrels.txt"


echo ${rank_name}
echo ${recall_path}
echo ${qrel_path}


python eval.py \
    --model_path $model \
    --res_path "${output_dir}/$rank_name.res" \
    --rank_path $recall_path \
    --data_name $data_name

# ndcg
python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 ${qrel_path} ${output_dir}/${rank_name}_post.res > ${output_dir}/${rank_name}_score.txt
cat ${output_dir}/${rank_name}_score.txt

img/indomain.png (1.64 MB)

img/outdomain1.png (1.67 MB)

img/outdomain2.png (1.24 MB)

pretrain.py (+265)
@@ -0,0 +1,265 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Stage-1 pretraining: next-token prediction of the query conditioned on the
# "Document: ... Query:" prompt built from the weakly supervised text pairs.
import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import numpy as np
import torch
import transformers
from torch.utils.data import Dataset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import datasets


IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="bigscience/bloom-560m")
    tokenizer_name_or_path: Optional[str] = field(default="bigscience/bloom-560m")


@dataclass
class DataArguments:
    train_data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    dev_data_path: str = field(default=None, metadata={"help": "Path to the dev data."})
    mask_input: bool = field(default=False)


@dataclass
class TrainingArguments(transformers.Seq2SeqTrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    train_with_peft: bool = field(default=False, metadata={"help": "Train with PEFT (LoRA)."})


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, save_peft=False):
    """Collects the state dict and dump to disk."""
    if save_peft:
        trainer.model = trainer.model.cpu()
        trainer.model.save_pretrained(output_dir)
    else:
        state_dict = trainer.model.state_dict()
        if trainer.args.should_save:
            cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
            del state_dict
            trainer._save(output_dir, state_dict=cpu_state_dict)


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    mask_input: bool
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    if mask_input:
        # Only the target (query) tokens contribute to the loss.
        for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
            label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, mask_input: bool):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        self.list_data_dict = datasets.load_dataset('json', data_files=data_path)['train']
        self.tokenizer = tokenizer
        self.mask_input = mask_input

    def __len__(self):
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        example = self.list_data_dict[i]
        source = example['query']
        if 'qwen' in self.tokenizer.name_or_path.lower():
            target = f"{example['response']}"
        else:
            target = f"{example['response']}{self.tokenizer.eos_token}"

        data_dict = preprocess([source], [target], self.tokenizer, self.mask_input)

        input_ids = data_dict["input_ids"][0]
        labels = data_dict["labels"][0]
        return dict(input_ids=input_ids, labels=labels)


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.train_data_path, mask_input=data_args.mask_input)
    # The dev set is optional; pretrain.sh runs with evaluation_strategy "no".
    dev_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.dev_data_path, mask_input=data_args.mask_input) if data_args.dev_data_path else None
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=dev_dataset, data_collator=data_collator)


def postprocess_text(preds, labels):
    # Unused helper; requires `import nltk` if enabled.
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    if 'qwen' not in model_args.model_name_or_path.lower():
        training_args.predict_with_generate = True

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True
    )

    print("load model")
    if training_args.train_with_peft:
        print("lora")
        from peft import get_peft_model, LoraConfig, TaskType
        peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
        model = get_peft_model(model, peft_config)

    try:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.tokenizer_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
            trust_remote_code=True
        )
    except:
        tokenizer = transformers.LlamaTokenizer.from_pretrained(
            model_args.tokenizer_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False
        )
    if 'qwen' not in model_args.model_name_or_path.lower():
        special_tokens_dict = dict()
        if tokenizer.pad_token is None:
            special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
        if tokenizer.eos_token is None:
            special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
        if tokenizer.bos_token is None:
            special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
        if tokenizer.unk_token is None:
            special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=tokenizer,
            model=model,
        )
    else:
        tokenizer.pad_token_id = tokenizer.eod_id

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Seq2SeqTrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, save_peft=training_args.train_with_peft)


if __name__ == "__main__":
    train()

pretrain.sh (+44)
@@ -0,0 +1,44 @@
model_name_or_path=$1
modelname=$2
layername=$3

data_name="text_pairs.json"
mask_input=true
ep=1
save_steps=1000
lr=5e-5
bsz=16
gas=4
card=4
worker=64

out_dir=outputs_pretrain_${modelname}
mkdir -p ${out_dir}
echo ${out_dir}

torchrun --nproc_per_node=${card} --master_port=28039 pretrain.py \
    --model_name_or_path ${model_name_or_path} \
    --tokenizer_name_or_path ${model_name_or_path} \
    --train_data_path ./datasets/${data_name} \
    --model_max_length 512 \
    --output_dir ${out_dir} \
    --num_train_epochs ${ep} \
    --per_device_train_batch_size ${bsz} \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps ${gas} \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps ${save_steps} \
    --learning_rate ${lr} \
    --weight_decay 0. \
    --warmup_ratio 0.1 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --mask_input ${mask_input} \
    --dataloader_num_workers ${worker} \
    --bf16 True \
    --tf32 True \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap ${layername}

echo ${out_dir}

rankdata/trec19/qrels.txt (+9,260; large diff not rendered)

rankdata/trec19/top1000.json (+43,000; large diff not rendered)

sft.py (+377)
@@ -0,0 +1,377 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Stage-2 SFT: listwise ranking loss + generation loss on the positive passage
# + KL regularization against the frozen stage-1 reference model.
import copy
import json
import logging
import os
import random
import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import datasets
import numpy as np
import torch
import transformers
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.utils.data import Dataset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, BloomForCausalLM, LlamaTokenizer
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
# import evaluate
# import utils  # not shipped in this commit and unused below
# metric = evaluate.load("rouge")

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    ref_path: Optional[str] = field(default="facebook/opt-125m")
    tokenizer_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    temperature: Optional[float] = field(default=1.0)

    top: int = field(default=24)

    w_frozen: Optional[bool] = field(default=True)


@dataclass
class DataArguments:
    train_data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    train_group_size: int = field(default=-1)
    len_query: int = field(default=64)
    len_doc: int = field(default=438)


@dataclass
class TrainingArguments(transformers.Seq2SeqTrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    labels_gen: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **deprecated_arguments,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
        `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
        are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
    """
    if deprecated_arguments.pop("position_ids", False) is not False:
        # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
        warnings.warn(
            "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
            " passing `position_ids`.",
            FutureWarning,
        )
    if len(deprecated_arguments) > 0:
        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    transformer_outputs = self.transformer(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    hidden_states = transformer_outputs[0]

    lm_logits = self.lm_head(hidden_states)

    # Logits of the frozen reference model, used by the KL term below.
    with torch.no_grad():
        init_lm_logits = self.init_model(input_ids=input_ids, attention_mask=attention_mask)[0]

    loss = None
    if labels is not None:
        # move labels to correct device to enable model parallelism
        device = lm_logits.device
        labels = labels.to(device)
        labels_gen = labels_gen.to(device)
        indexs = (labels != -100).long()
        label_no_ignore = torch.where(labels == -100, torch.ones(labels.shape).long().to(device), labels)

        preds = torch.nn.functional.log_softmax(lm_logits, dim=-1)  # B L V
        logprobs = torch.gather(preds, -1, label_no_ignore.unsqueeze(dim=-1)).squeeze(dim=-1)  # B L
        scores = (logprobs * indexs).sum(dim=-1) / indexs.sum(dim=-1)  # B -> bsz*group

        scores = torch.exp(scores).view(-1, self.train_group_size) / self.temperature  # bsz, group

        # ranking loss: the positive passage is always placed first in each group
        target_label = torch.zeros(scores.shape[0], dtype=torch.long).to(device)
        loss1 = self.cross_entropy(scores, target_label)

        # generation loss
        _, seq_length, vocab_size = lm_logits.shape
        pos_labels = labels_gen.view(-1, self.train_group_size, seq_length)[:, 0]  # B L
        pos_lm_logits = lm_logits.view(-1, self.train_group_size, seq_length, vocab_size)[:, 0]

        loss2 = self.cross_entropy(
            pos_lm_logits.reshape(-1, vocab_size), pos_labels.reshape(-1)
        )

        # kl
        loss3 = self.kl_loss(input=preds.reshape([-1, vocab_size]), target=init_lm_logits.softmax(dim=-1).reshape([-1, vocab_size]))

        loss = loss1 + loss2 + loss3

    if not return_dict:
        output = (lm_logits,) + transformer_outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return CausalLMOutputWithCrossAttentions(
        loss=loss,
        logits=lm_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
    )


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


class SupervisedDataset(Dataset):
    def __init__(self, data, train_group_size, tokenizer, len_query, len_doc):
        self.data = data
        self.train_group_size = train_group_size
        self.tokenizer = tokenizer
        self.len_query = len_query
        self.len_doc = len_doc

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ex = self.data[idx]
        all_qd = []

        # Sample train_group_size-1 negatives (with replacement if too few are available).
        if len(ex['negative_passages']) < self.train_group_size - 1:
            all_qd = random.choices(ex['negative_passages'], k=self.train_group_size - 1)
        else:
            all_qd = random.sample(ex['negative_passages'], self.train_group_size - 1)

        all_qd = [random.choice(ex['positive_passages'])] + all_qd

        def truncation(text, length):
            text = self.tokenizer.decode(self.tokenizer.encode(text, max_length=length, add_special_tokens=False))
            return text

        query = truncation(ex['query'], self.len_query).replace(self.tokenizer.pad_token, 'PAD')
        all_doc = [truncation(qd['text'], self.len_doc).replace(self.tokenizer.pad_token, 'PAD') for qd in all_qd]

        input_prompt = 'Document: {passage} Query:'

        sources = [input_prompt.format(passage=doc) for doc in all_doc]
        targets = [query for _ in sources]

        # Preprocess the data by tokenizing.
        examples = [s + t for s, t in zip(sources, targets)]
        examples_tokenized, sources_tokenized = [self._tokenize_fn(strings) for strings in (examples, sources)]
        input_ids = examples_tokenized["input_ids"]
        labels = copy.deepcopy(input_ids)
        labels_gen = copy.deepcopy(input_ids)
        for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
            label[:source_len] = IGNORE_INDEX
        assert len(input_ids) == len(labels)

        return dict(input_ids=input_ids, labels=labels, labels_gen=labels_gen)

    def _tokenize_fn(self, strings: Sequence[str]) -> Dict:
        """Tokenize a list of strings."""
        tokenized_list = [
            self.tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
            )
            for text in strings
        ]
        input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
        input_ids_lens = labels_lens = [
            tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
        ]
        return dict(
            input_ids=input_ids,
            labels=labels,
            input_ids_lens=input_ids_lens,
            labels_lens=labels_lens,
        )


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels, labels_gen = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "labels_gen"))
        input_ids = [item for sublist in input_ids for item in sublist]
        labels = [item for sublist in labels for item in sublist]
        labels_gen = [item for sublist in labels_gen for item in sublist]

        for index in range(len(input_ids)):
            input_ids[index] = input_ids[index][:-1]

        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        labels_gen = torch.nn.utils.rnn.pad_sequence(labels_gen, batch_first=True, padding_value=IGNORE_INDEX)

        labels = labels[..., 1:].contiguous()  # B L
        labels_gen = labels_gen[..., 1:].contiguous()  # B L
        return dict(
            input_ids=input_ids,
            labels=labels,
            labels_gen=labels_gen,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    training_args.predict_with_generate = True

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    model.bsz = training_args.per_device_train_batch_size
    model.train_group_size = data_args.train_group_size
    model.cross_entropy = torch.nn.CrossEntropyLoss(reduction='mean')
    model.kl_loss = torch.nn.KLDivLoss(reduction="batchmean")

    model.temperature = model_args.temperature

    # Frozen copy of the stage-1 model, used as the KL reference inside forward().
    model.init_model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.ref_path,
        cache_dir=training_args.cache_dir
    ).eval()

    if model_args.w_frozen:
        # Freeze everything except the top `top` transformer blocks.
        for name, param in model.named_parameters():
            param.requires_grad = False

        for name, param in model.transformer.h[-1 * model_args.top:].named_parameters():
            param.requires_grad = True

    # Replace the model's forward with the ranking-aware forward defined above.
    from functools import partial
    model.forward = partial(forward, model)

    if 'llama' in model_args.tokenizer_name_or_path.lower():
        tokenizer = LlamaTokenizer.from_pretrained(
            model_args.tokenizer_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
        )
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.tokenizer_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
        )

    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )

    data = datasets.load_dataset('json', data_files=data_args.train_data_path)['train']

    train_dataset = SupervisedDataset(data=data, train_group_size=data_args.train_group_size, tokenizer=tokenizer, len_query=data_args.len_query, len_doc=data_args.len_doc)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    trainer = Seq2SeqTrainer(model=model, tokenizer=tokenizer, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()

sft.sh (+52)
@@ -0,0 +1,52 @@
ep=1
lr=3e-5
bsz=1
group=48
gas=4
card=8
workers=64
save_steps=1000
data_name=msmarco.json
temperature=0.001
ref_path=$1
ref_name=$2
top=$3
layername=$4

out_dir="outputs_sft_${ref_name}"

echo ${out_dir}
mkdir -p ${out_dir}

torchrun --nproc_per_node=${card} --master_port=29405 sft.py \
    --model_name_or_path ${ref_path} \
    --tokenizer_name_or_path ${ref_path} \
    --train_data_path ./datasets/${data_name} \
    --model_max_length 512 \
    --output_dir ${out_dir} \
    --num_train_epochs ${ep} \
    --per_device_train_batch_size ${bsz} \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps ${gas} \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps ${save_steps} \
    --learning_rate ${lr} \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --train_group_size ${group} \
    --dataloader_num_workers ${workers} \
    --temperature ${temperature} \
    --len_query 32 \
    --len_doc 128 \
    --ref_path ${ref_path} \
    --top ${top} \
    --bf16 True \
    --tf32 True \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap ${layername}

echo ${out_dir}
