
Commit de06b53

Added grok-1 support
1 parent 2c33914 commit de06b53

4 files changed: +80 -31 lines changed

mixtral-moe/README.md

+9
@@ -1,3 +1,12 @@
+# Grok-1 Support
+```
+export MODEL_REPO=hpcai-tech/grok-1
+python scripts/download.py --repo_id $MODEL_REPO
+python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8
+
+TOKENIZERS_PARALLELISM=false ENABLE_INTRA_NODE_COMM=1 time torchrun --standalone --nproc_per_node=8 generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --compile --compile_prefill
+```
 # Mixtral 8x7B
 [Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/) is a high-quality sparse mixture of experts (MoE) model that matches or beats GPT3.5 on most benchmarks. This repo is a simple and efficient PyTorch native implementation of Mixtral 8x7B.
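Not part of the commit, but a quick way to sanity-check the pipeline between the convert and quantize steps is to load the converted checkpoint and print a few entries. This is a minimal sketch assuming, as elsewhere in the repo, that convert_hf_checkpoint.py writes a flat state dict to model.pth; the path follows the commands above.

```
# Hypothetical sanity check (not in the commit): inspect the converted grok-1
# checkpoint before quantizing. Assumes model.pth holds a flat state dict.
import torch

state_dict = torch.load(
    "checkpoints/hpcai-tech/grok-1/model.pth", map_location="cpu", mmap=True
)
for name, tensor in list(state_dict.items())[:8]:
    print(name, tuple(tensor.shape), tensor.dtype)
```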

mixtral-moe/generate.py

+6 -4

@@ -131,7 +131,7 @@ def generate(
 def encode_tokens(tokenizer, string, bos=True, device='cuda'):
     tokens = tokenizer.encode(string)
     if bos:
-        tokens = [tokenizer.bos_id()] + tokens
+        tokens = [tokenizer.bos_token_id] + tokens
     return torch.tensor(tokens, dtype=torch.int, device=device)

 def _load_model(checkpoint_path, device, precision, use_tp):
@@ -174,7 +174,7 @@ def main(
     """
     assert checkpoint_path.is_file(), checkpoint_path

-    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+    tokenizer_path = checkpoint_path.parent / "tokenizer.json"
     assert tokenizer_path.is_file(), str(tokenizer_path)

     global print
@@ -196,7 +196,9 @@ def main(
     device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")

-    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+    # tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+    from transformers import AutoTokenizer
+    tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/grok-1", trust_remote_code=True)
     encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
     prompt_length = encoded.size(0)

@@ -235,7 +237,7 @@ def callback(x):
                 if done_generating:
                     return
                 buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
-                if x.item() == tokenizer.eos_id():
+                if x.item() == tokenizer.eos_token_id:
                     done_generating = True
                 if len(buffer) == 4 or done_generating:
                     print(''.join(buffer), end='', flush=True)
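The generate.py changes swap the SentencePiece tokenizer for Hugging Face's AutoTokenizer, so the special-token lookups move from method calls to attributes. A minimal sketch of the interface the diff now relies on (the repo id comes from the diff; the rest is standard transformers API, shown as an assumption rather than a tested path):

```
# Sketch of the tokenizer swap:
#   SentencePieceProcessor.bos_id() -> tokenizer.bos_token_id
#   SentencePieceProcessor.eos_id() -> tokenizer.eos_token_id
#   encode()/decode() keep the same call shape in this file.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/grok-1", trust_remote_code=True)
tokens = [tokenizer.bos_token_id] + tokenizer.encode("Hello, grok")  # mirrors encode_tokens()
print(tokens)
print(tokenizer.decode(tokens))
print(tokenizer.eos_token_id)
```

Note that hard-coding the repo id in main() means the --checkpoint_path argument no longer determines which tokenizer gets loaded.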

mixtral-moe/model.py

+22 -6

@@ -50,9 +50,14 @@ def from_name(cls, name: str):
         assert len(config) == 1, name
         return cls(**transformer_configs[config[0]])

+attn_output_multiplier = 0.08838834764831845
+embedding_multiplier_scale = 78.38367176906169
+output_multiplier_scale = 0.5773502691896257
+max_attn_val = 30.0

 transformer_configs = {
     "Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
+    "grok-1": dict(vocab_size=131072, block_size=8192, n_layer=64, n_head=48, n_local_heads=8, dim=6144, intermediate_size=32768, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
 }

 class KVCache(nn.Module):
@@ -106,11 +111,13 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
         x = self.tok_embeddings(idx)
+        x *= embedding_multiplier_scale

         for i, layer in enumerate(self.layers):
             x = layer(x, input_pos, freqs_cis, mask)
         x = self.norm(x)
         logits = self.output(x)
+        logits *= output_multiplier_scale
         return logits

     @classmethod
@@ -123,12 +130,14 @@ def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.attention = Attention(config)
         self.block_sparse_moe = MOEFeedForward(config)
-        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
-        self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+        self.pre_moe_norm = RMSNorm(config.dim, config.norm_eps)
+        self.post_moe_norm = RMSNorm(config.dim, config.norm_eps)
+        self.post_attn_norm = RMSNorm(config.dim, config.norm_eps)
+        self.pre_attn_norm = RMSNorm(config.dim, config.norm_eps)

     def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
-        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
-        out = h + self.block_sparse_moe(self.ffn_norm(h))
+        h = x + self.post_attn_norm(self.attention(self.pre_attn_norm(x), freqs_cis, mask, input_pos))
+        out = h + self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(h)))
         return out


@@ -160,7 +169,8 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         bsz, seqlen, _ = x.shape

         kv_size = self.n_local_heads * self.head_dim
-        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+        qkv = self.wqkv(x)
+        q, k, v = qkv.split([self.dim, kv_size, kv_size], dim=-1)

         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
@@ -176,7 +186,13 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona

         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        attn_weights = torch.matmul(q, k.transpose(2, 3)).to(torch.float32)
+        attn_weights = attn_weights * attn_output_multiplier
+        attn_weights = max_attn_val * F.tanh(attn_weights / max_attn_val)
+        attn_weights += torch.where(mask, 0, -float("inf"))
+        attn_weights = F.softmax(attn_weights, dim=-1).to(q.dtype)
+        y = torch.matmul(attn_weights, v)
+        # y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
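The explicit matmul/softmax path in the last hunk exists because grok-1 soft-caps its attention logits: scores are scaled by attn_output_multiplier and squashed through max_attn_val * F.tanh(x / max_attn_val) before masking, which the fused F.scaled_dot_product_attention call cannot express, hence the commented-out line. The new module-level constants also line up with the "grok-1" config added above: with dim=6144 and n_head=48 the head dimension is 128, and the multipliers appear to be 1/sqrt(head_dim), sqrt(dim), and 1/sqrt(3). This is an inference from the numbers, not something stated in the commit:

```
# Inferred relations between the new constants and the grok-1 config entry
# (assumption: these relations are not stated anywhere in the commit).
import math

dim, n_head = 6144, 48          # from transformer_configs["grok-1"]
head_dim = dim // n_head        # 128

print(1 / math.sqrt(head_dim))  # ~0.0883883476, matches attn_output_multiplier
print(math.sqrt(dim))           # ~78.3836717691, matches embedding_multiplier_scale
print(1 / math.sqrt(3))         # ~0.5773502692, matches output_multiplier_scale
```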

mixtral-moe/scripts/convert_hf_checkpoint.py

+43 -21

@@ -32,42 +32,52 @@ def convert_hf_checkpoint(
     print(f"Model config {config.__dict__}")

     weight_map = {
-        "tok_embeddings.weight": "tok_embeddings.weight",
-        "layers.{}.attention.wq.weight": "layers.{}.attention.wq.weight",
-        "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight",
-        "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight",
-        "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight",
-        "layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1",
-        "layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2",
-        "layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3",
-        "layers.{}.block_sparse_moe.gate.weight": "layers.{}.block_sparse_moe.gate.weight",
-        "layers.{}.attention_norm.weight": "layers.{}.attention_norm.weight",
-        "layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight",
-        "norm.weight": "norm.weight",
-        "output.weight": "output.weight",
+        "model.embed_tokens.weight": "tok_embeddings.weight",
+        "model.layers.{}.attn.q_proj.weight": "layers.{}.attention.wq.weight",
+        "model.layers.{}.attn.k_proj.weight": "layers.{}.attention.wk.weight",
+        "model.layers.{}.attn.v_proj.weight": "layers.{}.attention.wv.weight",
+        "model.layers.{}.attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        # "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight",
+        # "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight",
+        # "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight",
+        "model.layers.{}.moe_block.experts.{}.linear.weight": "layers.{}.block_sparse_moe.cond_ffn.w1.{}",
+        "model.layers.{}.moe_block.experts.{}.linear_1.weight": "layers.{}.block_sparse_moe.cond_ffn.w2.{}",
+        "model.layers.{}.moe_block.experts.{}.linear_v.weight": "layers.{}.block_sparse_moe.cond_ffn.w3.{}",
+        "model.layers.{}.moe_block.gate.weight": "layers.{}.block_sparse_moe.gate.weight",
+        "model.layers.{}.pre_attn_norm.scale": "layers.{}.pre_attn_norm.weight",
+        "model.layers.{}.post_attn_norm.scale": "layers.{}.post_attn_norm.weight",
+        "model.layers.{}.pre_moe_norm.scale": "layers.{}.pre_moe_norm.weight",
+        "model.layers.{}.post_moe_norm.scale": "layers.{}.post_moe_norm.weight",
+        "model.norm.scale": "norm.weight",
+        "lm_head.weight": "output.weight",
     }

-    pt_files = glob.glob(str(checkpoint_dir / "*.pt"))
+    pt_files = glob.glob(str(checkpoint_dir / "*.bin"))

     merged_result = {}
     for file in sorted(pt_files):
         state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
         merged_result.update(state_dict)
     final_result = {}
-    for key, value in merged_result.items():
+    for key, value in list(merged_result.items()):
         if "layers" in key:
-            abstract_key = re.sub(r'.(\d+).', '.{}.', key)
-            layer_num = re.search(r'\d+', key).group(0)
+            abstract_key = re.sub(r'\.(\d+)\.', '.{}.', key)
+            nums = re.findall(r'\d+', key)
+            if abstract_key not in weight_map:
+                continue
             new_key = weight_map[abstract_key]
             if new_key is None:
                 continue
-            new_key = new_key.format(layer_num)
+            new_key = new_key.format(*nums)
         else:
+            if key not in weight_map:
+                continue
             new_key = weight_map[key]
-
         final_result[new_key] = value
+        del merged_result[key]

     for key in tuple(final_result.keys()):
+        print(key)
         if "wq" in key:
             q = final_result[key]
             k = final_result[key.replace("wq", "wk")]
@@ -77,9 +87,21 @@ def convert_hf_checkpoint(
             del final_result[key.replace("wq", "wk")]
             del final_result[key.replace("wq", "wv")]
         elif "w1" in key or "w3" in key:
-            final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous()
+            if not key.endswith('0'):
+                continue
+            full_keys = [key[:-1] + str(i) for i in range(8)]
+            results = [final_result[k] for k in full_keys]
+            final_result[key[:-2]] = torch.stack(results, dim=0)
+            for k in full_keys:
+                del final_result[k]
         elif "w2" in key:
-            final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous()
+            if not key.endswith('0'):
+                continue
+            full_keys = [key[:-1] + str(i) for i in range(8)]
+            results = [final_result[k] for k in full_keys]
+            final_result[key[:-2]] = torch.stack(results, dim=0)
+            for k in full_keys:
+                del final_result[k]
         elif "gate" in key:
             final_result[key] = final_result[key].contiguous()
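The new w1/w2/w3 branches replace the old per-layer reshape because grok-1's HF checkpoint stores each expert's weight under its own key; the two-placeholder weight_map entries turn those into names ending in .0 through .7, and the loop stacks the eight tensors into a single [num_experts, ...] weight. A toy illustration of that stacking step, with made-up shapes rather than grok-1's real ones:

```
# Toy illustration of the expert-stacking step (shapes are made up).
import torch

num_experts, dim, intermediate = 8, 4, 6
final_result = {
    f"layers.0.block_sparse_moe.cond_ffn.w1.{i}": torch.randn(intermediate, dim)
    for i in range(num_experts)
}

key = "layers.0.block_sparse_moe.cond_ffn.w1.0"  # only the ".0" key triggers stacking
full_keys = [key[:-1] + str(i) for i in range(num_experts)]
stacked = torch.stack([final_result[k] for k in full_keys], dim=0)
print(stacked.shape)  # torch.Size([8, 6, 4]) -> [num_experts, intermediate_size, dim]
```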
