
Commit 5d22a42

arki05 authored and ggerganov committed
llama : add grok-1 support (ggml-org#6204)
* Add support for Grok model architecture
* Revert convert-hf-to-gguf to default options
* Fixed f_norm_rms_eps bug
* Fix whitespaces
* llama : fix grok rope type
* llama : minor

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent ec2bd69 commit 5d22a42

File tree

4 files changed: +384 -20 lines


convert-hf-to-gguf.py (+26)
@@ -93,31 +93,42 @@ def set_gguf_parameters(self):

         if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
             self.gguf_writer.add_context_length(n_ctx)
+            print(f"gguf: context length = {n_ctx}")

         n_embd = self.find_hparam(["hidden_size", "n_embd"])
         self.gguf_writer.add_embedding_length(n_embd)
+        print(f"gguf: embedding length = {n_embd}")

         if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
             self.gguf_writer.add_feed_forward_length(n_ff)
+            print(f"gguf: feed forward length = {n_ff}")

         n_head = self.find_hparam(["num_attention_heads", "n_head"])
         self.gguf_writer.add_head_count(n_head)
+        print(f"gguf: head count = {n_head}")

         if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
             self.gguf_writer.add_head_count_kv(n_head_kv)
+            print(f"gguf: key-value head count = {n_head_kv}")

         if (rope_theta := self.hparams.get("rope_theta")) is not None:
             self.gguf_writer.add_rope_freq_base(rope_theta)
+            print(f"gguf: rope theta = {rope_theta}")
         if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
             self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
+            print(f"gguf: rms norm epsilon = {f_rms_eps}")
         if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
             self.gguf_writer.add_layer_norm_eps(f_norm_eps)
+            print(f"gguf: layer norm epsilon = {f_norm_eps}")
         if (n_experts := self.hparams.get("num_local_experts")) is not None:
             self.gguf_writer.add_expert_count(n_experts)
+            print(f"gguf: expert count = {n_experts}")
         if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
             self.gguf_writer.add_expert_used_count(n_experts_used)
+            print(f"gguf: experts used count = {n_experts_used}")

         self.gguf_writer.add_file_type(self.ftype)
+        print(f"gguf: file type = {self.ftype}")

     def write_tensors(self):
         block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))

@@ -1051,6 +1062,21 @@ def set_vocab(self):
         self._set_vocab_sentencepiece()


+@Model.register("GrokForCausalLM")
+class GrokModel(Model):
+    model_arch = gguf.MODEL_ARCH.GROK
+
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_name("Grok")
+
+
 @Model.register("MiniCPMForCausalLM")
 class MiniCPMModel(Model):
     model_arch = gguf.MODEL_ARCH.MINICPM

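Note: the following is an illustrative, self-contained sketch (not part of this commit) of the registration pattern that the @Model.register("GrokForCausalLM") decorator above relies on. The Model class here is a deliberately simplified stand-in for the one in convert-hf-to-gguf.py, shown only to make the lookup flow concrete; names such as _model_classes and from_model_architecture are assumptions for illustration, not the script's exact internals.

# Simplified sketch of the architecture registry used by convert-hf-to-gguf.py.
# A registered class is looked up from the "architectures" string found in a
# checkpoint's config.json, so "GrokForCausalLM" resolves to the Grok handler.
from __future__ import annotations


class Model:
    _model_classes: dict[str, type[Model]] = {}  # assumed registry attribute

    @classmethod
    def register(cls, *names: str):
        def decorator(model_cls: type[Model]) -> type[Model]:
            for name in names:
                cls._model_classes[name] = model_cls
            return model_cls
        return decorator

    @classmethod
    def from_model_architecture(cls, arch: str) -> type[Model]:
        try:
            return cls._model_classes[arch]
        except KeyError:
            raise NotImplementedError(f"Architecture {arch!r} not supported") from None


@Model.register("GrokForCausalLM")
class GrokModel(Model):
    pass


# Usage: the converter would do the equivalent of this with the "architectures"
# entry read from the checkpoint's config.json.
print(Model.from_model_architecture("GrokForCausalLM").__name__)  # GrokModel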

gguf-py/gguf/constants.py (+24)
@@ -100,6 +100,7 @@ class MODEL_ARCH(IntEnum):
     LLAMA    = auto()
     FALCON   = auto()
     BAICHUAN = auto()
+    GROK     = auto()
     GPT2     = auto()
     GPTJ     = auto()
     GPTNEOX  = auto()

@@ -167,6 +168,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.LLAMA:    "llama",
     MODEL_ARCH.FALCON:   "falcon",
     MODEL_ARCH.BAICHUAN: "baichuan",
+    MODEL_ARCH.GROK:     "grok",
     MODEL_ARCH.GPT2:     "gpt2",
     MODEL_ARCH.GPTJ:     "gptj",
     MODEL_ARCH.GPTNEOX:  "gptneox",

@@ -251,6 +253,28 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.GROK: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.GPTNEOX: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,

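As a quick illustration (again, not part of the commit), the new MODEL_ARCH.GROK entry can be inspected from the gguf Python package. The sketch below assumes gguf-py is installed and that MODEL_ARCH_NAMES, MODEL_TENSORS and TENSOR_NAMES are exposed by gguf.constants as in this revision; the printed patterns are the canonical GGUF tensor names that the per-architecture list above enables.

# Sketch: enumerate the GGUF tensor-name patterns declared for Grok.
from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES, MODEL_TENSORS, TENSOR_NAMES

print(MODEL_ARCH_NAMES[MODEL_ARCH.GROK])  # expected: "grok"
for tensor in MODEL_TENSORS[MODEL_ARCH.GROK]:
    # e.g. MODEL_TENSOR.ATTN_Q -> "blk.{bid}.attn_q"
    print(f"{tensor.name:15s} -> {TENSOR_NAMES[tensor]}")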

gguf-py/gguf/tensor_mapping.py (+35 -20)
@@ -23,6 +23,7 @@ class TensorNameMap:
            "model.embedding",          # mamba-qbert
            "backbone.embedding",       # mamba
            "backbone.embeddings",      # mamba-hf
+            "transformer.in_out_embed", # Grok
        ),

        # Token type embeddings

@@ -66,6 +67,7 @@ class TensorNameMap:
            "lm_head.ln",           # phi2
            "model.norm_f",         # mamba-qbert
            "backbone.norm_f",      # mamba
+            "transformer.rms_norm", # Grok
        ),

        # Rope frequencies

@@ -93,6 +95,7 @@ class TensorNameMap:
            "model.layers.{bid}.attention_norm",        # internlm2
            "model.layers.{bid}.norm",                  # mamba-qbert
            "backbone.layers.{bid}.norm",               # mamba
+            "transformer.decoder_layer.{bid}.rms_norm", # Grok
        ),

        # Attention norm 2

@@ -116,32 +119,35 @@ class TensorNameMap:

        # Attention query
        MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj",        # llama-hf
-            "layers.{bid}.attention.wq",                  # llama-pth
-            "encoder.layer.{bid}.attention.self.query",   # bert
-            "transformer.h.{bid}.attn.q_proj",            # gpt-j
-            "model.layers.layers.{bid}.self_attn.q_proj", # plamo
-            "model.layers.{bid}.attention.wq"             # internlm2
+            "model.layers.{bid}.self_attn.q_proj",                       # llama-hf
+            "layers.{bid}.attention.wq",                                 # llama-pth
+            "encoder.layer.{bid}.attention.self.query",                  # bert
+            "transformer.h.{bid}.attn.q_proj",                           # gpt-j
+            "model.layers.layers.{bid}.self_attn.q_proj",                # plamo
+            "model.layers.{bid}.attention.wq",                           # internlm2
+            "transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
        ),

        # Attention key
        MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj",        # llama-hf
-            "layers.{bid}.attention.wk",                  # llama-pth
-            "encoder.layer.{bid}.attention.self.key",     # bert
-            "transformer.h.{bid}.attn.k_proj",            # gpt-j
-            "model.layers.layers.{bid}.self_attn.k_proj", # plamo
-            "model.layers.{bid}.attention.wk"             # internlm2
+            "model.layers.{bid}.self_attn.k_proj",                     # llama-hf
+            "layers.{bid}.attention.wk",                               # llama-pth
+            "encoder.layer.{bid}.attention.self.key",                  # bert
+            "transformer.h.{bid}.attn.k_proj",                         # gpt-j
+            "model.layers.layers.{bid}.self_attn.k_proj",              # plamo
+            "model.layers.{bid}.attention.wk",                         # internlm2
+            "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
        ),

        # Attention value
        MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj",        # llama-hf
-            "layers.{bid}.attention.wv",                  # llama-pth
-            "encoder.layer.{bid}.attention.self.value",   # bert
-            "transformer.h.{bid}.attn.v_proj",            # gpt-j
-            "model.layers.layers.{bid}.self_attn.v_proj", # plamo
-            "model.layers.{bid}.attention.wv"             # internlm2
+            "model.layers.{bid}.self_attn.v_proj",                       # llama-hf
+            "layers.{bid}.attention.wv",                                 # llama-pth
+            "encoder.layer.{bid}.attention.self.value",                  # bert
+            "transformer.h.{bid}.attn.v_proj",                           # gpt-j
+            "model.layers.layers.{bid}.self_attn.v_proj",                # plamo
+            "model.layers.{bid}.attention.wv",                           # internlm2
+            "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
        ),

        # Attention output

@@ -162,12 +168,14 @@ class TensorNameMap:
            "model.layers.layers.{bid}.self_attn.o_proj",                 # plamo
            "model.layers.{bid}.attention.wo",                            # internlm2
            "encoder.layers.{bid}.attn.out_proj",                         # nomic-bert
+            "transformer.decoder_layer.{bid}.multi_head_attention.linear" # Grok
        ),

        # Attention output norm
        MODEL_TENSOR.ATTN_OUT_NORM: (
            "encoder.layer.{bid}.attention.output.LayerNorm", # bert
            "encoder.layers.{bid}.norm1",                     # nomic-bert
+            "transformer.decoder_layer.{bid}.rms_norm_1",     # Grok
        ),

        # Rotary embeddings

@@ -190,11 +198,13 @@ class TensorNameMap:
            "model.layers.{bid}.ln2",                     # yi
            "h.{bid}.ln_2",                               # gpt2
            "model.layers.{bid}.ffn_norm",                # internlm2
+            "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
        ),

        MODEL_TENSOR.FFN_GATE_INP: (
            "layers.{bid}.feed_forward.gate",           # mixtral
            "model.layers.{bid}.block_sparse_moe.gate", # mixtral
+            "transformer.decoder_layer.{bid}.router"    # Grok
        ),

        # Feed-forward up

@@ -223,6 +233,7 @@ class TensorNameMap:
        MODEL_TENSOR.FFN_UP_EXP: (
            "layers.{bid}.feed_forward.experts.{xid}.w3",           # mixtral
            "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
+            "transformer.decoder_layer.{bid}.moe.{xid}.linear_v",   # Grok
        ),

        # AWQ-activation gate

@@ -243,6 +254,7 @@ class TensorNameMap:
        MODEL_TENSOR.FFN_GATE_EXP: (
            "layers.{bid}.feed_forward.experts.{xid}.w1",           # mixtral
            "model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral
+            "transformer.decoder_layer.{bid}.moe.{xid}.linear"      # Grok
        ),

        # Feed-forward down

@@ -270,6 +282,8 @@ class TensorNameMap:
        MODEL_TENSOR.FFN_DOWN_EXP: (
            "layers.{bid}.feed_forward.experts.{xid}.w2",           # mixtral
            "model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral
+            "transformer.decoder_layer.{bid}.moe.{xid}.linear_1",   # Grok
+
        ),

        MODEL_TENSOR.ATTN_Q_NORM: (

@@ -287,8 +301,9 @@ class TensorNameMap:
        ),

        MODEL_TENSOR.LAYER_OUT_NORM: (
-            "encoder.layer.{bid}.output.LayerNorm", # bert
-            "encoder.layers.{bid}.norm2",           # nomic-bert
+            "encoder.layer.{bid}.output.LayerNorm",       # bert
+            "encoder.layers.{bid}.norm2",                 # nomic-bert
+            "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
        ),

        MODEL_TENSOR.SSM_IN: (

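To see the new mappings end to end, here is a sketch (not part of the commit) of how the converter's write_tensors path would resolve one of the Grok checkpoint names added above to its GGUF name. It assumes get_tensor_name_map() and TensorNameMap.get_name(..., try_suffixes=...) behave as in this revision of gguf-py; the block count is a placeholder chosen for illustration.

# Sketch: map a raw Grok tensor name to its canonical GGUF base name.
from gguf.constants import MODEL_ARCH
from gguf.tensor_mapping import get_tensor_name_map

n_blocks = 64  # placeholder layer count, not taken from a real Grok config
tensor_map = get_tensor_name_map(MODEL_ARCH.GROK, n_blocks)

raw_name = "transformer.decoder_layer.0.multi_head_attention.query.weight"
base_name = tensor_map.get_name(raw_name, try_suffixes=(".weight", ".bias"))
print(base_name)  # expected: "blk.0.attn_q" (the converter re-appends ".weight")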
