@@ -338,10 +338,14 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_NORM,
     LLM_TENSOR_ATTN_NORM_2,
     LLM_TENSOR_ATTN_ROT_EMBD,
+    LLM_TENSOR_FFN_GATE_INP,
+    LLM_TENSOR_FFN_NORM,
     LLM_TENSOR_FFN_GATE,
     LLM_TENSOR_FFN_DOWN,
     LLM_TENSOR_FFN_UP,
-    LLM_TENSOR_FFN_NORM,
+    LLM_TENSOR_FFN_DOWN_EXP,
+    LLM_TENSOR_FFN_GATE_EXP,
+    LLM_TENSOR_FFN_UP_EXP,
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
 };
@@ -360,10 +364,14 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
             { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
             { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
             { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
             { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
         },
     },
     {
@@ -585,6 +593,10 @@ struct LLM_TN {
     std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const {
        return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix;
     }
+
+    std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const {
+        return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid, xid) + "." + suffix;
+    }
 };
 
 //
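
Usage note (a minimal sketch, not part of the patch): the new two-index overload fills the extra %d in the expert name patterns above with the expert id, while the single-index overload keeps producing the dense names. Assuming the LLM_TN helper is constructed with the model arch as elsewhere in llama.cpp:

    const auto tn = LLM_TN(LLM_ARCH_LLAMA);
    tn(LLM_TENSOR_FFN_GATE,     "weight", 2);    // -> "blk.2.ffn_gate.weight"
    tn(LLM_TENSOR_FFN_GATE_EXP, "weight", 2, 5); // -> "blk.2.ffn_gate.5.weight" (layer 2, expert 5)
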
@@ -1268,6 +1280,12 @@ struct llama_layer {
     struct ggml_tensor * ffn_down; // w2
     struct ggml_tensor * ffn_up;   // w3
 
+    // ff MoE
+    struct ggml_tensor * ffn_gate_inp;
+    struct ggml_tensor * ffn_gate_exp[8];
+    struct ggml_tensor * ffn_down_exp[8];
+    struct ggml_tensor * ffn_up_exp[8];
+
     // ff bias
     struct ggml_tensor * ffn_down_b; // b2
     struct ggml_tensor * ffn_up_b;   // b3
@@ -3025,9 +3043,20 @@ static void llm_load_tensors(
 
                     layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
 
-                    layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
-                    layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
-                    layer.ffn_up   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, backend_split);
+                    layer.ffn_gate_inp = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, backend, false);
+
+                    if (layer.ffn_gate_inp == nullptr) {
+                        layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
+                        layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
+                        layer.ffn_up   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, backend_split);
+                    } else {
+                        // MoE branch
+                        for (int x = 0; x < 8; ++x) {
+                            layer.ffn_gate_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd, n_ff}, backend_split);
+                            layer.ffn_down_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd}, backend_split);
+                            layer.ffn_up_exp[x]   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, x), {n_embd, n_ff}, backend_split);
+                        }
+                    }
 
                     if (backend == GGML_BACKEND_GPU) {
                         vram_weights +=
@@ -3037,8 +3066,18 @@ static void llm_load_tensors(
                             (layer.bk ? ggml_nbytes(layer.bk) : 0) +
                             (layer.bv ? ggml_nbytes(layer.bv) : 0) +
                             (layer.bo ? ggml_nbytes(layer.bo) : 0) +
-                            ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_gate) +
-                            ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
+                            ggml_nbytes(layer.ffn_norm);
+
+                        if (layer.ffn_gate_inp == nullptr) {
+                            vram_weights +=
+                                ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
+                        } else {
+                            vram_weights += ggml_nbytes(layer.ffn_gate_inp);
+                            for (int x = 0; x < 8; ++x) {
+                                vram_weights +=
+                                    ggml_nbytes(layer.ffn_gate_exp[x]) + ggml_nbytes(layer.ffn_down_exp[x]) + ggml_nbytes(layer.ffn_up_exp[x]);
+                            }
+                        }
                     }
                 }
             } break;
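
For orientation, a small standalone sketch (hypothetical helper, not part of the patch) of the per-layer tensor names the MoE branch above requests from the model file, mirroring the loader loop and the hard-coded count of 8 experts:

    #include <cstdio>

    // Hypothetical helper: list the MoE tensor names expected for layer i,
    // following the "blk.%d.ffn_*.%d" patterns registered in LLM_TENSOR_NAMES.
    static void print_moe_tensor_names(int i) {
        std::printf("blk.%d.ffn_gate_inp.weight\n", i);       // expert router (optional; absence selects the dense FFN path)
        for (int x = 0; x < 8; ++x) {
            std::printf("blk.%d.ffn_gate.%d.weight\n", i, x); // per-expert gate (w1)
            std::printf("blk.%d.ffn_down.%d.weight\n", i, x); // per-expert down (w2)
            std::printf("blk.%d.ffn_up.%d.weight\n",   i, x); // per-expert up   (w3)
        }
    }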