@@ -32,42 +32,52 @@ def convert_hf_checkpoint(
     print(f"Model config {config.__dict__}")
 
     weight_map = {
-        "tok_embeddings.weight": "tok_embeddings.weight",
-        "layers.{}.attention.wq.weight": "layers.{}.attention.wq.weight",
-        "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight",
-        "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight",
-        "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight",
-        "layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1",
-        "layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2",
-        "layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3",
-        "layers.{}.block_sparse_moe.gate.weight": "layers.{}.block_sparse_moe.gate.weight",
-        "layers.{}.attention_norm.weight": "layers.{}.attention_norm.weight",
-        "layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight",
-        "norm.weight": "norm.weight",
-        "output.weight": "output.weight",
+        "model.embed_tokens.weight": "tok_embeddings.weight",
+        "model.layers.{}.attn.q_proj.weight": "layers.{}.attention.wq.weight",
+        "model.layers.{}.attn.k_proj.weight": "layers.{}.attention.wk.weight",
+        "model.layers.{}.attn.v_proj.weight": "layers.{}.attention.wv.weight",
+        "model.layers.{}.attn.o_proj.weight": "layers.{}.attention.wo.weight",
+ # "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight",
41
+ # "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight",
42
+ # "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight",
43
+ "model.layers.{}.moe_block.experts.{}.linear.weight" : "layers.{}.block_sparse_moe.cond_ffn.w1.{}" ,
44
+ "model.layers.{}.moe_block.experts.{}.linear_1.weight" : "layers.{}.block_sparse_moe.cond_ffn.w2.{}" ,
45
+ "model.layers.{}.moe_block.experts.{}.linear_v.weight" : "layers.{}.block_sparse_moe.cond_ffn.w3.{}" ,
46
+ "model.layers.{}.moe_block.gate.weight" : "layers.{}.block_sparse_moe.gate.weight" ,
47
+ "model.layers.{}.pre_attn_norm.scale" : "layers.{}.pre_attn_norm.weight" ,
48
+ "model.layers.{}.post_attn_norm.scale" : "layers.{}.post_attn_norm.weight" ,
49
+ "model.layers.{}.pre_moe_norm.scale" : "layers.{}.pre_moe_norm.weight" ,
50
+ "model.layers.{}.post_moe_norm.scale" : "layers.{}.post_moe_norm.weight" ,
51
+ "model.norm.scale" : "norm.weight" ,
52
+ "lm_head.weight" : "output.weight" ,
48
53
}
 
-    pt_files = glob.glob(str(checkpoint_dir / "*.pt"))
+    pt_files = glob.glob(str(checkpoint_dir / "*.bin"))
 
     merged_result = {}
     for file in sorted(pt_files):
         state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
         merged_result.update(state_dict)
     final_result = {}
-    for key, value in merged_result.items():
+    for key, value in list(merged_result.items()):
         if "layers" in key:
-            abstract_key = re.sub(r'.(\d+).', '.{}.', key)
-            layer_num = re.search(r'\d+', key).group(0)
+            abstract_key = re.sub(r'\.(\d+)\.', '.{}.', key)
+            nums = re.findall(r'\d+', key)
+            if abstract_key not in weight_map:
+                continue
             new_key = weight_map[abstract_key]
             if new_key is None:
                 continue
-            new_key = new_key.format(layer_num)
+            new_key = new_key.format(*nums)
         else:
+            if key not in weight_map:
+                continue
             new_key = weight_map[key]
-
         final_result[new_key] = value
+        del merged_result[key]
 
     for key in tuple(final_result.keys()):
+        print(key)
         if "wq" in key:
             q = final_result[key]
             k = final_result[key.replace("wq", "wk")]
@@ -77,9 +87,21 @@ def convert_hf_checkpoint(
             del final_result[key.replace("wq", "wk")]
             del final_result[key.replace("wq", "wv")]
         elif "w1" in key or "w3" in key:
-            final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous()
+            if not key.endswith('0'):
+                continue
+            full_keys = [key[:-1] + str(i) for i in range(8)]
+            results = [final_result[k] for k in full_keys]
+            final_result[key[:-2]] = torch.stack(results, dim=0)
+            for k in full_keys:
+                del final_result[k]
         elif "w2" in key:
-            final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous()
+            if not key.endswith('0'):
+                continue
+            full_keys = [key[:-1] + str(i) for i in range(8)]
+            results = [final_result[k] for k in full_keys]
+            final_result[key[:-2]] = torch.stack(results, dim=0)
+            for k in full_keys:
+                del final_result[k]
         elif "gate" in key:
             final_result[key] = final_result[key].contiguous()
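
As a sanity check on the remapping pass in the first hunk, here is a small standalone sketch (the key below is a made-up example in the new Hugging Face naming scheme) of how one checkpoint key is abstracted and then re-formatted:

import re

weight_map = {
    "model.layers.{}.moe_block.experts.{}.linear.weight":
        "layers.{}.block_sparse_moe.cond_ffn.w1.{}",
}

key = "model.layers.3.moe_block.experts.5.linear.weight"
abstract_key = re.sub(r'\.(\d+)\.', '.{}.', key)  # both ".3." and ".5." become ".{}."
nums = re.findall(r'\d+', key)                    # ['3', '5'] -- layer index, expert index
print(weight_map[abstract_key].format(*nums))     # layers.3.block_sparse_moe.cond_ffn.w1.5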
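
The per-expert stacking in the second hunk then collects those eight suffixed entries into one tensor per projection; a minimal sketch of the same idea, with made-up shapes and the eight experts hard-coded as in the loop above:

import torch

# Eight per-expert w1 tensors, keyed "...w1.0" .. "...w1.7".
final_result = {f"layers.0.block_sparse_moe.cond_ffn.w1.{i}": torch.randn(16, 32)
                for i in range(8)}

key = "layers.0.block_sparse_moe.cond_ffn.w1.0"
full_keys = [key[:-1] + str(i) for i in range(8)]      # "...w1.0" .. "...w1.7"
final_result[key[:-2]] = torch.stack([final_result[k] for k in full_keys], dim=0)
for k in full_keys:                                    # drop the per-expert entries
    del final_result[k]
print(final_result["layers.0.block_sparse_moe.cond_ffn.w1"].shape)  # torch.Size([8, 16, 32])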