Commit ea5497d

gpt2 : Add gpt2 architecture integration (#4555)
1 parent f679349 commit ea5497d

7 files changed: +281 -14 lines changed

README.md (+1)

@@ -103,6 +103,7 @@ as the main playground for developing new features for the [ggml](https://github
 - [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
 - [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
 - [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557)
+- [x] [GPT-2](https://huggingface.co/gpt2)

 **Multimodal models:**

convert-hf-to-gguf.py (+66)

@@ -182,6 +182,8 @@ def from_model_architecture(model_architecture):
             return QwenModel
         if model_architecture == "MixtralForCausalLM":
             return MixtralModel
+        if model_architecture == "GPT2LMHeadModel":
+            return GPT2Model
         if model_architecture == "PhiForCausalLM":
             return Phi2Model
         if model_architecture == "PlamoForCausalLM":
@@ -225,6 +227,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.QWEN
         if arch == "MixtralForCausalLM":
            return gguf.MODEL_ARCH.LLAMA
+        if arch == "GPT2LMHeadModel":
+            return gguf.MODEL_ARCH.GPT2
         if arch == "PhiForCausalLM":
             return gguf.MODEL_ARCH.PHI2
         if arch == "PlamoForCausalLM":
@@ -993,6 +997,68 @@ def write_tensors(self):
             self.gguf_writer.add_tensor(new_name, data)


+class GPT2Model(Model):
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_context_length(self.hparams["n_ctx"])
+        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
+        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
+        self.gguf_writer.add_head_count(self.hparams["n_head"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
+        for name, data_torch in self.get_tensors():
+            # we don't need these
+            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq", ".attn.bias")):
+                continue
+
+            if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_proj.weight")):
+                data_torch = data_torch.transpose(1, 0)
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+            # note: GPT2 output is tied to (same as) wte in original model
+            if new_name == "token_embd.weight":
+                print(f"output.weight, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+                self.gguf_writer.add_tensor("output.weight", data)
+
+
 class Phi2Model(Model):
     def set_gguf_parameters(self):
         block_count = self.hparams["n_layer"]
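Two details of `write_tensors` deserve a note: the `transpose(1, 0)` exists because Hugging Face GPT-2 stores `c_attn`/`c_proj`/`c_fc` as `Conv1D` weights with shape `(n_in, n_out)` rather than the `(n_out, n_in)` layout a Linear-style/ggml weight uses, and the extra `output.weight` tensor is emitted because GPT-2 ties `lm_head` to `wte`. Below is a hedged sketch, not part of this commit, that checks both claims against the stock `gpt2` checkpoint (assumes `torch` and `transformers` are installed):

```python
# Sketch, not part of the commit: sanity-check the two conversion details above
# against the Hugging Face "gpt2" (124M) checkpoint.
import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")

# 1) Conv1D stores W as (n_in, n_out) and computes y = x @ W + b, so the converter
#    transposes it to the (n_out, n_in) layout of a regular Linear/ggml weight.
attn = model.transformer.h[0].attn.c_attn
x = torch.randn(1, 768)                                # n_embd = 768
y_conv1d = x @ attn.weight + attn.bias                 # what Conv1D.forward computes
y_linear = torch.nn.functional.linear(x, attn.weight.t(), attn.bias)
print(attn.weight.shape)                               # torch.Size([768, 2304])
print(torch.allclose(y_conv1d, y_linear, atol=1e-5))   # True

# 2) lm_head shares its storage with wte, which is why token_embd.weight is
#    re-emitted as output.weight in the GGUF file.
print(model.lm_head.weight.data_ptr() == model.transformer.wte.weight.data_ptr())  # True
```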

gguf-py/gguf/constants.py (+10 -1)

@@ -370,7 +370,16 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP,
     ],
     MODEL_ARCH.GPT2: [
-        # TODO
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
     ],
     MODEL_ARCH.PHI2: [
         MODEL_TENSOR.TOKEN_EMBD,
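With the `# TODO` placeholder filled in, the GPT-2 tensor set can be inspected straight from the gguf package. A small sketch, assuming gguf-py from this repo is importable and using the existing `MODEL_TENSORS`/`TENSOR_NAMES` tables in this file:

```python
# Sketch, assuming the gguf-py package from this repo is on the path:
# list the GGUF tensor name templates enabled by the new MODEL_ARCH.GPT2 entry.
import gguf

for tensor in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.GPT2]:
    print(gguf.TENSOR_NAMES[tensor])  # e.g. token_embd, output_norm, blk.{bid}.attn_qkv, ...
```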

gguf-py/gguf/tensor_mapping.py (+9 -1)

@@ -17,6 +17,7 @@ class TensorNameMap:
         "tok_embeddings",                           # llama-pth
         "embeddings.word_embeddings",               # bert
         "language_model.embedding.word_embeddings", # persimmon
+        "wte",                                      # gpt2
         "transformer.embd.wte",                     # phi2
     ),

@@ -34,6 +35,7 @@
     MODEL_TENSOR.POS_EMBD: (
         "transformer.wpe",                # gpt2
         "embeddings.position_embeddings", # bert
+        "wpe",                            # gpt2
     ),

     # Output
@@ -53,7 +55,7 @@
         "norm",                                   # llama-pth
         "embeddings.LayerNorm",                   # bert
         "transformer.norm_f",                     # mpt
-        "ln_f",                                   # refact bloom qwen
+        "ln_f",                                   # refact bloom qwen gpt2
         "language_model.encoder.final_layernorm", # persimmon
         "lm_head.ln",                             # phi2
     ),
@@ -78,6 +80,7 @@
         "encoder.layer.{bid}.attention.output.LayerNorm",      # bert
         "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
         "model.layers.{bid}.ln1",                               # yi
+        "h.{bid}.ln_1",                                         # gpt2
         "transformer.h.{bid}.ln",                               # phi2
         "model.layers.layers.{bid}.norm",                       # plamo
     ),
@@ -95,6 +98,7 @@
         "transformer.h.{bid}.self_attention.query_key_value",                 # falcon
         "h.{bid}.self_attention.query_key_value",                             # bloom
         "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
+        "h.{bid}.attn.c_attn",                                                 # gpt2
         "transformer.h.{bid}.mixer.Wqkv",                                      # phi2
     ),

@@ -137,6 +141,7 @@
         "encoder.layer.{bid}.attention.output.dense",               # bert
         "transformer.h.{bid}.attn.out_proj",                        # gpt-j
         "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
+        "h.{bid}.attn.c_proj",                                      # gpt2
         "transformer.h.{bid}.mixer.out_proj",                       # phi2
         "model.layers.layers.{bid}.self_attn.o_proj",               # plamo
     ),
@@ -159,6 +164,7 @@
         "encoder.layer.{bid}.output.LayerNorm",                         # bert
         "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
         "model.layers.{bid}.ln2",                                       # yi
+        "h.{bid}.ln_2",                                                 # gpt2
     ),

     MODEL_TENSOR.FFN_GATE_INP: (
@@ -179,6 +185,7 @@
         "transformer.h.{bid}.mlp.fc_in",                         # gpt-j
         "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
         "transformer.h.{bid}.mlp.w1",                            # qwen
+        "h.{bid}.mlp.c_fc",                                      # gpt2
         "transformer.h.{bid}.mlp.fc1",                           # phi2
         "model.layers.layers.{bid}.mlp.up_proj",                 # plamo
     ),
@@ -218,6 +225,7 @@
         "encoder.layer.{bid}.output.dense",                      # bert
         "transformer.h.{bid}.mlp.fc_out",                        # gpt-j
         "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
+        "h.{bid}.mlp.c_proj",                                    # gpt2
         "transformer.h.{bid}.mlp.fc2",                           # phi2
         "model.layers.layers.{bid}.mlp.down_proj",               # plamo
     ),
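The net effect of these additions is that raw GPT-2 checkpoint names, which carry no `transformer.` prefix, now resolve through `TensorNameMap`. A hedged sketch of the lookup, assuming gguf-py from this repo is importable and 12 blocks as in the 124M model; the expected outputs in the comments are illustrative:

```python
# Sketch, assuming the gguf-py package from this repo is importable.
import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.GPT2, 12)  # 12 blocks in the 124M GPT-2

# Raw GPT-2 checkpoint names (no "transformer." prefix) now map to GGUF names:
print(tmap.get_name("wte.weight", try_suffixes=(".weight", ".bias")))               # token_embd.weight
print(tmap.get_name("h.0.attn.c_attn.weight", try_suffixes=(".weight", ".bias")))   # blk.0.attn_qkv.weight
print(tmap.get_name("h.0.mlp.c_proj.bias", try_suffixes=(".weight", ".bias")))      # blk.0.ffn_down.bias
```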
