Commit 81c00c2

add refact model
1 parent c091cdf commit 81c00c2

File tree

3 files changed: +659 -6 lines changed


convert-refact-hf-to-gguf.py

+269
@@ -0,0 +1,269 @@
#!/usr/bin/env python3
# HF refact --> gguf conversion

from __future__ import annotations

import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import Any

import numpy as np
import torch
from transformers import AutoTokenizer  # type: ignore[import]

if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf


def bytes_to_unicode():
    # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
    """
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    This also avoids mapping to whitespace/control characters that the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    return dict(zip(bs, (chr(n) for n in cs)))
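
# Illustrative note (not part of the original script): the mapping above sends
# every non-printable byte to a printable stand-in, e.g. the space byte 0x20
# becomes 'Ġ' and the newline byte 0x0A becomes 'Ċ', which is why those
# characters show up in GPT-2 style vocab files. A quick sanity check:
#   byte_encoder = bytes_to_unicode()
#   assert byte_encoder[ord(' ')] == 'Ġ'
#   assert byte_encoder[ord('\n')] == 'Ċ'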


def count_model_parts(dir_model: Path) -> int:
    num_parts = 0
    for filename in os.listdir(dir_model):
        if filename.startswith("pytorch_model-"):
            num_parts += 1

    if num_parts > 0:
        print("gguf: found " + str(num_parts) + " model parts")
    return num_parts


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Convert a Refact model to a GGML compatible file")
    parser.add_argument(
        "--vocab-only", action="store_true",
        help="extract only the vocab",
    )
    parser.add_argument(
        "--outfile", type=Path,
        help="path to write to; default: based on input",
    )
    parser.add_argument(
        "model", type=Path,
        help="directory containing model file, or model file itself (*.bin)",
    )
    parser.add_argument(
        "ftype", type=int, choices=[0, 1], default=1, nargs='?',
        help="output format - use 0 for float32, 1 for float16",
    )
    return parser.parse_args()

args = parse_args()

dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
    print(f'Error: {args.model} is not a directory', file = sys.stderr)
    sys.exit(1)

# possible tensor data types
# ftype == 0 -> float32
# ftype == 1 -> float16

# map from ftype to string
ftype_str = ["f32", "f16"]

if args.outfile is not None:
    fname_out = args.outfile
else:
    # output in the same directory as the model by default
    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

print("gguf: loading model "+dir_model.name)

with open(dir_model / "config.json", "r", encoding="utf-8") as f:
    hparams = json.load(f)

if hparams["architectures"][0] != "GPTRefactForCausalLM":
    print("Model architecture not supported: " + hparams["architectures"][0])
    sys.exit(1)

# get number of model parts
num_parts = count_model_parts(dir_model)

ARCH = gguf.MODEL_ARCH.REFACT
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])

print("gguf: get model metadata")

# Get refact feed forward dimension
hidden_dim = hparams["n_embd"]
inner_dim = 4 * hidden_dim
hidden_dim = int(2 * inner_dim / 3)
multiple_of = 256
ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
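# Worked example (illustrative only, the value below is an assumption rather
# than anything read from a real config.json): with n_embd = 2048 this
# SwiGLU-style sizing gives inner_dim = 4 * 2048 = 8192,
# hidden_dim = int(2 * 8192 / 3) = 5461, and rounding up to the next multiple
# of 256 yields ff_dim = 22 * 256 = 5632.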

block_count = hparams["n_layer"]

gguf_writer.add_name("Refact")
# Refact uses ALiBi, so the context length is taken from n_positions in
# config.json, which is presumably the context length used during training.
gguf_writer.add_context_length(hparams["n_positions"])
gguf_writer.add_embedding_length(hparams["n_embd"])

gguf_writer.add_feed_forward_length(ff_dim)
gguf_writer.add_block_count(block_count)
gguf_writer.add_head_count(hparams["n_head"])
gguf_writer.add_head_count_kv(1)
gguf_writer.add_layer_norm_rms_eps(hparams["layer_norm_epsilon"])
gguf_writer.add_file_type(ftype)

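# For reference (a sketch, not something stated in the commit): the fields this
# script reads from config.json are architectures, n_embd, n_layer, n_head,
# n_positions, layer_norm_epsilon and (optionally) vocab_size, so a minimal
# hypothetical config would look like:
#   {"architectures": ["GPTRefactForCausalLM"], "n_embd": 2048, "n_layer": 32,
#    "n_head": 32, "n_positions": 4096, "layer_norm_epsilon": 1e-5,
#    "vocab_size": 49216}
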
# TOKENIZATION

print("gguf: get tokenizer metadata")

tokens: list[bytearray] = []
scores: list[float] = []
toktypes: list[int] = []

tokenizer_json_file = dir_model / 'tokenizer.json'
if not tokenizer_json_file.is_file():
    print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr)
    sys.exit(1)

# gpt2 tokenizer
gguf_writer.add_tokenizer_model("gpt2")

with open(tokenizer_json_file, "r", encoding="utf-8") as f:
    tokenizer_json = json.load(f)

print("gguf: get gpt2 tokenizer vocab")

# The number of tokens in tokenizer.json can differ from the expected vocab size.
# This causes downstream issues with mismatched tensor sizes when running inference.
vocab_size = hparams["vocab_size"] if "vocab_size" in hparams else len(tokenizer_json["model"]["vocab"])

tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)

reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

for i in range(vocab_size):
    if i in reverse_vocab:
        text = reverse_vocab[i]
        try:
            text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
        except KeyError:
            text = bytearray()
            for c in reverse_vocab[i]:
                if ord(c) < 256:  # single byte character
                    text.append(byte_decoder[ord(c)])
                else:  # multibyte special token character
                    text.extend(c.encode('utf-8'))
    else:
        print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
        pad_token = f"[PAD{i}]".encode("utf8")
        text = bytearray(pad_token)

    tokens.append(text)
    scores.append(0.0)  # dummy
    toktypes.append(gguf.TokenType.NORMAL)  # dummy

gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
special_vocab.add_to_gguf(gguf_writer)

196+
# TENSORS
197+
198+
tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
199+
200+
# params for qkv transform
201+
n_head = hparams["n_head"]
202+
n_head_kv = 1
203+
204+
head_dim = hparams["n_embd"] // n_head
205+
206+
# tensor info
207+
print("gguf: get tensor metadata")
208+
209+
if num_parts == 0:
210+
part_names = iter(("pytorch_model.bin",))
211+
else:
212+
part_names = (
213+
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
214+
)
215+
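# For example (illustrative): with num_parts == 3 the generator above yields
# "pytorch_model-00001-of-00003.bin", "pytorch_model-00002-of-00003.bin" and
# "pytorch_model-00003-of-00003.bin", matching the sharded HF checkpoint naming.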
for part_name in part_names:
    if args.vocab_only:
        break
    print("gguf: loading model part '" + part_name + "'")
    model_part = torch.load(dir_model / part_name, map_location="cpu")

    for name in model_part.keys():
        data = model_part[name]

        old_dtype = data.dtype

        # convert any unsupported data types to float32
        if data.dtype != torch.float16 and data.dtype != torch.float32:
            data = data.to(torch.float32)

        data = data.squeeze().numpy()

        # map tensor names
        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ))
        if new_name is None:
            print("Cannot map tensor '" + name + "'")
            sys.exit()

        n_dims = len(data.shape)
        data_dtype = data.dtype

        # if f32 desired, convert any float16 to float32
        if ftype == 0 and data_dtype == np.float16:
            data = data.astype(np.float32)

        # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
        if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
            data = data.astype(np.float32)

        # if f16 desired, convert any float32 2-dim weight tensors to float16
        if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
            data = data.astype(np.float16)

        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

        gguf_writer.add_tensor(new_name, data)


print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
if not args.vocab_only:
    print("gguf: write tensors")
    gguf_writer.write_tensors_to_file()

gguf_writer.close()

print(f"gguf: model successfully exported to '{fname_out}'")
print("")

gguf-py/gguf/gguf.py

+27-5
@@ -85,6 +85,7 @@ class MODEL_ARCH(IntEnum):
     GPTNEOX   : int = auto()
     MPT       : int = auto()
     STARCODER : int = auto()
+    REFACT    : int = auto()


 class MODEL_TENSOR(IntEnum):
@@ -116,6 +117,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GPTNEOX: "gptneox",
     MODEL_ARCH.MPT: "mpt",
     MODEL_ARCH.STARCODER: "starcoder",
+    MODEL_ARCH.REFACT: "refact",
 }

 MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
@@ -185,6 +187,20 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
         MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
     },
+    MODEL_ARCH.REFACT: {
+        MODEL_TENSOR.TOKEN_EMBD: "token_embd",
+        MODEL_TENSOR.OUTPUT_NORM: "output_norm",
+        MODEL_TENSOR.OUTPUT: "output",
+        MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
+        MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
+        MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
+        MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
+        MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
+        MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
+        MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
+        MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
+        MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
+    },
     MODEL_ARCH.GPT2: {
         # TODO
     },
@@ -209,7 +225,7 @@ class TensorNameMap:
         # Token embeddings
         MODEL_TENSOR.TOKEN_EMBD: (
             "gpt_neox.embed_in",           # gptneox
-            "transformer.wte",             # gpt2 mpt
+            "transformer.wte",             # gpt2 mpt refact
             "transformer.word_embeddings", # falcon
             "model.embed_tokens",          # llama-hf
             "tok_embeddings",              # llama-pth
@@ -233,6 +249,7 @@ class TensorNameMap:
             "transformer.ln_f", # gpt2 falcon
             "model.norm",       # llama-hf baichuan
             "norm",             # llama-pth
+            "ln_f",             # refact
         ),

         # Rope frequencies
@@ -245,7 +262,7 @@ class TensorNameMap:
         # Attention norm
         MODEL_TENSOR.ATTN_NORM: (
             "gpt_neox.layers.{bid}.input_layernorm", # gptneox
-            "transformer.h.{bid}.ln_1",              # gpt2
+            "transformer.h.{bid}.ln_1",              # gpt2 refact
             "transformer.blocks.{bid}.norm_1",       # mpt
             "transformer.h.{bid}.input_layernorm",   # falcon7b
             "transformer.h.{bid}.ln_mlp",            # falcon40b
@@ -269,25 +286,28 @@ class TensorNameMap:
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
             "model.layers.{bid}.self_attn.q_proj", # llama-hf
+            "transformer.h.{bid}.attn.q",          # refact
             "layers.{bid}.attention.wq",           # llama-pth
         ),

         # Attention key
         MODEL_TENSOR.ATTN_K: (
             "model.layers.{bid}.self_attn.k_proj", # llama-hf
+            "transformer.h.{bid}.attn.k",          # refact
             "layers.{bid}.attention.wk",           # llama-pth
         ),

         # Attention value
         MODEL_TENSOR.ATTN_V: (
             "model.layers.{bid}.self_attn.v_proj", # llama-hf
+            "transformer.h.{bid}.attn.v",          # refact
             "layers.{bid}.attention.wv",           # llama-pth
         ),

         # Attention output
         MODEL_TENSOR.ATTN_OUT: (
             "gpt_neox.layers.{bid}.attention.dense",    # gptneox
-            "transformer.h.{bid}.attn.c_proj",          # gpt2
+            "transformer.h.{bid}.attn.c_proj",          # gpt2 refact
             "transformer.blocks.{bid}.attn.out_proj",   # mpt
             "transformer.h.{bid}.self_attention.dense", # falcon
             "model.layers.{bid}.self_attn.o_proj",      # llama-hf
@@ -303,7 +323,7 @@ class TensorNameMap:
         # Feed-forward norm
         MODEL_TENSOR.FFN_NORM: (
             "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
-            "transformer.h.{bid}.ln_2",                        # gpt2
+            "transformer.h.{bid}.ln_2",                        # gpt2 refact
             "transformer.blocks.{bid}.norm_2",                 # mpt
             "model.layers.{bid}.post_attention_layernorm",     # llama-hf
             "layers.{bid}.ffn_norm",                           # llama-pth
@@ -317,18 +337,20 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
             "model.layers.{bid}.mlp.up_proj",        # llama-hf
             "layers.{bid}.feed_forward.w3",          # llama-pth
+            "transformer.h.{bid}.mlp.linear_3",      # refact
         ),

         # Feed-forward gate
         MODEL_TENSOR.FFN_GATE: (
             "model.layers.{bid}.mlp.gate_proj", # llama-hf
+            "transformer.h.{bid}.mlp.linear_1", # refact
             "layers.{bid}.feed_forward.w1",     # llama-pth
         ),

         # Feed-forward down
         MODEL_TENSOR.FFN_DOWN: (
             "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
-            "transformer.h.{bid}.mlp.c_proj",          # gpt2
+            "transformer.h.{bid}.mlp.c_proj",          # gpt2 refact
             "transformer.blocks.{bid}.ffn.down_proj",  # mpt
             "transformer.h.{bid}.mlp.dense_4h_to_h",   # falcon
             "model.layers.{bid}.mlp.down_proj",        # llama-hf
