
Commit dbceec8

llama : add StableLM2 12B (ggml-org#6635)
* StableLM2 12B support for huggingface -> GGUF
* StableLM12 tensormapping and constants
* StableLM-2-12b model support
* fix
* Added 12B support
* Removed autoformatting; resolved bug where model_arch was not selecting StableLM2
* Formatting
* Do QK norm stacking in model conversion step
* Converge StableLM and StableLM2 code to simplify graph construction
* Fix accidental removal
* Removed warnings
* Revert formatter
* Move QK norm stack to private function so it's easier to read
* refactor stablelm graph builder to support 1.6, 3b and 12b more efficiently
* Proper check for None type for new_name to avoid crash; formatting; revert change to base class `write_tensors()`
* Format
* Formatting
* format

  Co-authored-by: compilade <[email protected]>

* Fix incorrect check for K norm
* space after commas; Keep indentation multiple of 4 spaces
* Flake8 format
* Removed unnecessary conditional branches
* Removed unused comment
* Fixed incorrect tensor passing
* Format

---------

Co-authored-by: compilade <[email protected]>
1 parent f4dea7d commit dbceec8

File tree: 3 files changed, +134, -12 lines

convert-hf-to-gguf.py (+82)
```diff
@@ -1207,9 +1207,91 @@ def set_gguf_parameters(self):
         rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
         self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
         self.gguf_writer.add_head_count(hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
         self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
         self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))
 
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        n_head = self.hparams.get("num_attention_heads")
+        n_kv_head = self.hparams.get("num_key_value_heads")
+        q_norms = dict()
+        k_norms = dict()
+        for name, data_torch in self.get_tensors():
+            # we don't need these
+            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
+                continue
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+            n_dims = len(data.shape)
+            if name.find("q_layernorm.norms") != -1:
+                q_norms[name] = data
+                if len(q_norms) >= (block_count * n_head):
+                    self._stack_qk_norm(block_count, name, tensor_map, n_head, q_norms, n_dims, layer_name="q_layernorm")
+                continue
+            if name.find("k_layernorm.norms") != -1:
+                k_norms[name] = data
+                if len(k_norms) >= (block_count * n_kv_head):
+                    self._stack_qk_norm(block_count, name, tensor_map, n_kv_head, k_norms, n_dims, layer_name="k_layernorm")
+                continue
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")):
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and not new_name.endswith("_norm.weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+    def _stack_qk_norm(self, block_count, name, tensor_map, n_head, norms, n_dims, layer_name="q_layernorm"):
+        for bid in range(block_count):
+            datas = []
+            for xid in range(n_head):
+                ename = f"model.layers.{bid}.self_attn.{layer_name}.norms.{xid}.weight"
+                datas.append(norms[ename])
+                del norms[ename]
+            data = np.stack(datas, axis=0)
+            data_dtype = data.dtype
+            merged_name = f"model.layers.{bid}.self_attn.{layer_name}.weight"
+            new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+            if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")):
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and not new_name.endswith("_norm.weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
 
 @Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
 class LlamaModel(Model):
```
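For readers unfamiliar with the per-head norms involved here: StableLM 2 12B stores one small LayerNorm weight vector per attention head (`q_layernorm.norms.{i}.weight` / `k_layernorm.norms.{i}.weight`), and the new `_stack_qk_norm` helper merges them into a single 2-D tensor per layer before writing to GGUF. A minimal standalone numpy sketch of that stacking step, using made-up sizes (the head count and head size below are illustrative, not the real model's):

```python
import numpy as np

# Illustrative sizes only; the real per-layer shapes depend on the checkpoint.
n_head, head_dim, bid = 4, 8, 0

# One 1-D LayerNorm weight vector per head, keyed the way the HF checkpoint names them.
norms = {
    f"model.layers.{bid}.self_attn.q_layernorm.norms.{xid}.weight": np.ones(head_dim, dtype=np.float32)
    for xid in range(n_head)
}

# Stack the per-head vectors into one (n_head, head_dim) array, as _stack_qk_norm does above.
stacked = np.stack(
    [norms[f"model.layers.{bid}.self_attn.q_layernorm.norms.{xid}.weight"] for xid in range(n_head)],
    axis=0,
)

print(stacked.shape)  # (4, 8) -> one merged q_layernorm tensor for this layer
```

The merged array is what ends up on the llama.cpp side as `blk.%d.attn_q_norm` (and `blk.%d.attn_k_norm` for K), presumably via the tensor-name mapping added elsewhere in this PR.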

gguf-py/gguf/constants.py (+2)
```diff
@@ -455,6 +455,8 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
     ],
     MODEL_ARCH.QWEN: [
         MODEL_TENSOR.TOKEN_EMBD,
```

llama.cpp (+50, -12)
```diff
@@ -716,6 +716,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES
             { LLM_TENSOR_FFN_GATE,    "blk.%d.ffn_gate" },
             { LLM_TENSOR_FFN_DOWN,    "blk.%d.ffn_down" },
             { LLM_TENSOR_FFN_UP,      "blk.%d.ffn_up" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
         },
     },
     {
```
```diff
@@ -1744,6 +1746,7 @@ enum e_model {
     MODEL_4B,
     MODEL_7B,
     MODEL_8B,
+    MODEL_12B,
     MODEL_13B,
     MODEL_14B,
     MODEL_15B,
@@ -3607,6 +3610,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_3B: return "3B";
         case MODEL_7B: return "7B";
         case MODEL_8B: return "8B";
+        case MODEL_12B: return "12B";
         case MODEL_13B: return "13B";
         case MODEL_14B: return "14B";
         case MODEL_15B: return "15B";
@@ -3898,6 +3902,7 @@ static void llm_load_hparams(
             switch (hparams.n_layer) {
                 case 24: model.type = e_model::MODEL_1B; break;
                 case 32: model.type = e_model::MODEL_3B; break;
+                case 40: model.type = e_model::MODEL_12B; break;
                 default: model.type = e_model::MODEL_UNKNOWN;
             }
         } break;
```
```diff
@@ -5128,8 +5133,13 @@ static bool llm_load_tensors(
                 layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, false);
                 layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, false);
 
-                layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+                // optional q and k layernorms, present in StableLM 2 12B
+                layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head},    false);
+                layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}, false);
+
+                // optional FFN norm, not present in StableLM 2 12B which uses parallel residual
+                layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, false);
+                layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, false);
 
                 layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
                 layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
```
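Two details are worth noting in this hunk: the norm tensors are now created with a trailing `false`, i.e. they are optional, so 1.6B/3B files without `attn_q_norm`/`attn_k_norm` and 12B files without `ffn_norm` both still load; and the declared shape `{hparams.n_embd_head_k, hparams.n_head}` appears to be the ggml view (fastest-varying dimension listed first) of the `(n_head, head_dim)` array produced by the converter. A tiny numpy sketch of that assumed shape correspondence, with placeholder sizes:

```python
import numpy as np

# Placeholder sizes; purely illustrative.
n_head, n_embd_head_k = 4, 8

# The converter writes the stacked per-head norm as an (n_head, head_dim) numpy array...
stacked = np.zeros((n_head, n_embd_head_k), dtype=np.float32)

# ...while ggml declares dimensions fastest-varying first, so the same data would be
# described as ne = {n_embd_head_k, n_head} in the create_tensor call above.
ggml_ne = tuple(reversed(stacked.shape))
print(ggml_ne)  # (8, 4)
```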
```diff
@@ -8197,7 +8207,7 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
+
 
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -8206,6 +8216,8 @@
                     LLM_NORM, cb, il);
             cb(cur, "attn_norm", il);
 
+            struct ggml_tensor * inpSA = cur;
+
             // self-attention
             {
                 // compute Q and K and RoPE them
@@ -8230,15 +8242,36 @@
                 cb(Vcur, "Vcur", il);
             }
 
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                cb(Qcur, "Qcur", il);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                cb(Kcur, "Kcur", il);
+
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
+                            model.layers[il].attn_q_norm,
+                            NULL,
+                            LLM_NORM, cb, il);
+                    cb(Qcur, "Qcur", il);
+                }
+                if (model.layers[il].attn_k_norm) {
+                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
+                            model.layers[il].attn_k_norm,
+                            NULL,
+                            LLM_NORM, cb, il);
+                    cb(Kcur, "Kcur", il);
+                }
+
+
                 Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    ctx0, Qcur, inp_pos,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    ctx0, Kcur, inp_pos,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
```
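Conceptually, the new code reshapes Q and K so each head's features sit in their own slice, applies a per-head LayerNorm using the optional `attn_q_norm`/`attn_k_norm` weights, and only then applies RoPE. A minimal numpy sketch of that per-head normalization, independent of the ggml API (the shapes, epsilon, and `layernorm` helper are illustrative assumptions, not llama.cpp code):

```python
import numpy as np

def layernorm(x, weight, eps=1e-5):
    # Normalize over the last axis (the per-head feature dimension), then scale per head.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * weight

# Placeholder sizes; purely illustrative.
n_tokens, n_head, head_dim = 3, 4, 8
rng = np.random.default_rng(0)

# Q after projection, reshaped to one (head_dim,) feature row per head and token.
q = rng.standard_normal((n_tokens, n_head, head_dim)).astype(np.float32)

# attn_q_norm holds one weight vector per head, i.e. shape (n_head, head_dim).
attn_q_norm = np.ones((n_head, head_dim), dtype=np.float32)

# Each head is normalized with its own weights; RoPE would then be applied to this result.
q_normed = layernorm(q, attn_q_norm)
print(q_normed.shape)  # (3, 4, 8)
```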
```diff
@@ -8253,20 +8286,25 @@
                 // skip computing output for unused tokens
                 struct ggml_tensor * inp_out_ids = build_inp_out_ids();
                 cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpL  = ggml_get_rows(ctx0,  inpL, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
 
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
             {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
+                if (model.layers[il].ffn_norm) {
+                    cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                            model.layers[il].ffn_norm,
+                            model.layers[il].ffn_norm_b,
+                            LLM_NORM, cb, il);
+                    cb(cur, "ffn_norm", il);
+                } else {
+                    // parallel residual
+                    cur = inpSA;
+                }
                 cur = llm_build_ffn(ctx0, cur,
                         model.layers[il].ffn_up,   NULL,
                         model.layers[il].ffn_gate, NULL,
```
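Because `ffn_norm` is now optional, one graph covers both layouts: with `ffn_norm` present (1.6B/3B) the FFN input is the normalized sum of the attention output and the layer input, while without it (12B) the FFN reuses `inpSA`, which after this change holds the `attn_norm` output. Together with the residual add that follows this hunk, that gives the parallel-residual form `x + attn(norm(x)) + ffn(norm(x))`. A compact numpy sketch of the two block layouts, with stand-in `attn`/`ffn` functions (placeholders, not the real sublayers):

```python
import numpy as np

def norm(x):
    # weight-less LayerNorm stand-in
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + 1e-5)

def attn(x): return 0.5 * x       # placeholder for the self-attention sublayer
def ffn(x):  return np.tanh(x)    # placeholder for the MLP sublayer

x = np.ones((2, 6), dtype=np.float32)   # layer input (inpL)

# StableLM 1.6B / 3B: sequential residual with a dedicated ffn_norm.
ffn_inp = x + attn(norm(x))
out_sequential = ffn_inp + ffn(norm(ffn_inp))

# StableLM 2 12B: no ffn_norm; the FFN reuses the attention norm output (inpSA).
shared = norm(x)                        # inpSA
out_parallel = x + attn(shared) + ffn(shared)

print(out_sequential.shape, out_parallel.shape)
```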
