Commit e640cbe

llama : add n_expert and n_expert_used to hparams + change quants
1 parent d1259b7 commit e640cbe

6 files changed, +110 −53 lines

convert.py

+34-17
@@ -151,14 +151,16 @@ def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType:
 
 @dataclass
 class Params:
-    n_vocab:    int
-    n_embd:     int
-    n_layer:    int
-    n_ctx:      int
-    n_ff:       int
-    n_head:     int
-    n_head_kv:  int
-    f_norm_eps: float
+    n_vocab:        int
+    n_embd:         int
+    n_layer:        int
+    n_ctx:          int
+    n_ff:           int
+    n_head:         int
+    n_head_kv:      int
+    n_experts:      int | None = None
+    n_experts_used: int | None = None
+    f_norm_eps:     float | None = None
 
     rope_scaling_type: gguf.RopeScalingType | None = None
     f_rope_freq_base: float | None = None

@@ -255,27 +257,30 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
 def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
     config = json.load(open(config_path))
 
+    n_experts      = None
+    n_experts_used = None
+
     # hack to determine LLaMA v1 vs v2 vs CodeLlama
     if config.get("rope_theta") == 1000000:
         # CodeLlama
         n_ctx = 16384
     elif config["norm_eps"] == 1e-05:
         # LLaMA v2
         n_ctx = 4096
+    elif config["moe"]:
+        # Mixtral
+        n_ctx = 32768
     else:
         # LLaMA v1
         n_ctx = 2048
 
-    # print model keys
-    for k in model.keys():
-        print(k)
+    if "layers.0.feed_forward.w1.weight" in model:
+        n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
 
-    # check if MoE
-    if "layers.0.feed_forward.experts.0.w1.weight" in model:
+    if config.get("moe"):
         n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
-        n_ctx = 32768
-    else:
-        n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
+        n_experts      = config["moe"]["num_experts"]
+        n_experts_used = config["moe"]["num_experts_per_tok"]
 
     return Params(
         n_vocab          = model["tok_embeddings.weight"].shape[0],

@@ -285,6 +290,8 @@ def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
         n_ff             = n_ff,
         n_head           = (n_head := config["n_heads"]),
         n_head_kv        = config.get("n_kv_heads", n_head),
+        n_experts        = n_experts,
+        n_experts_used   = n_experts_used,
         f_norm_eps       = config["norm_eps"],
         f_rope_freq_base = config.get("rope_theta"),
     )

@@ -843,7 +850,17 @@ def add_meta_arch(self, params: Params) -> None:
         self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
         self.gguf.add_head_count          (params.n_head)
         self.gguf.add_head_count_kv       (params.n_head_kv)
-        self.gguf.add_layer_norm_rms_eps  (params.f_norm_eps)
+
+        if params.n_experts:
+            self.gguf.add_expert_count(params.n_experts)
+
+        if params.n_experts_used:
+            self.gguf.add_expert_used_count(params.n_experts_used)
+
+        if params.f_norm_eps:
+            self.gguf.add_layer_norm_rms_eps(params.f_norm_eps)
+        else:
+            raise ValueError('f_norm_eps is None')
 
         if params.f_rope_freq_base is not None:
             self.gguf.add_rope_freq_base(params.f_rope_freq_base)
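
The new loadOriginalParamsJson path expects a Mixtral-style params.json whose "moe" block carries the expert counts. A minimal sketch of just that lookup (the example config values are assumptions for illustration, not taken from this commit):

from __future__ import annotations
import json

# Illustrative Mixtral-style params.json (values assumed):
# {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05,
#  "rope_theta": 1000000.0, "moe": {"num_experts": 8, "num_experts_per_tok": 2}}

def read_moe_counts(config_path: str) -> tuple[int | None, int | None]:
    config = json.load(open(config_path))
    n_experts = n_experts_used = None
    if config.get("moe"):
        # same keys the converter reads above
        n_experts      = config["moe"]["num_experts"]
        n_experts_used = config["moe"]["num_experts_per_tok"]
    return n_experts, n_experts_used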

ggml.c

+1-1
@@ -4075,7 +4075,7 @@ struct ggml_tensor * ggml_mul_mat(
 
 struct ggml_tensor * ggml_mul_mat_id(
         struct ggml_context * ctx,
-        struct ggml_tensor  * as[],
+        struct ggml_tensor  * const as[],
         int                   n_as,
         struct ggml_tensor  * ids,
         int                   id,

ggml.h

+1-1
@@ -1051,7 +1051,7 @@ extern "C" {
     // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
-            struct ggml_tensor  * as[],
+            struct ggml_tensor  * const as[],
             int                   n_as,
             struct ggml_tensor  * ids,
             int                   id,
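
The ggml.c/ggml.h change is only const-correctness on the expert matrix array; the operation itself, per the header comment (ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)), multiplies b with one matrix picked out of as by an index read from ids. A rough NumPy analogue of that comment, ignoring ggml's transposition conventions (names and shapes are illustrative only):

import numpy as np

def mul_mat_id(as_list, ids, id_, b):
    # pick the matrix selected by ids[id_], then multiply: as[ids[id_]] @ b
    return as_list[int(ids[id_])] @ b

n_as, n, m = 4, 8, 3
as_list = [np.random.randn(n, n) for _ in range(n_as)]    # the "experts"
ids     = np.array([2, 0], dtype=np.int32)                # selected expert indices
b       = np.random.randn(n, m)
out     = mul_mat_id(as_list, ids, 0, b)                  # uses as_list[2]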

gguf-py/gguf/constants.py

+2
@@ -38,6 +38,8 @@ class LLM:
         FEED_FORWARD_LENGTH   = "{arch}.feed_forward_length"
         USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
         TENSOR_DATA_LAYOUT    = "{arch}.tensor_data_layout"
+        EXPERT_COUNT          = "{arch}.expert_count"
+        EXPERT_USED_COUNT     = "{arch}.expert_used_count"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"

gguf-py/gguf/gguf_writer.py

+6
@@ -339,6 +339,12 @@ def add_max_alibi_bias(self, bias: float) -> None:
     def add_clamp_kqv(self, value: float) -> None:
         self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
 
+    def add_expert_count(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
+
+    def add_expert_used_count(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)
+
     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
 
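
Together with the constants above, the two new writer methods emit the {arch}.expert_count and {arch}.expert_used_count key-value pairs. A hypothetical writer session (the file name and the 8/2 values are illustrative, matching a Mixtral-style 8-expert model):

from gguf import GGUFWriter

writer = GGUFWriter("model.gguf", arch="llama")
writer.add_expert_count(8)         # writes "llama.expert_count" = 8
writer.add_expert_used_count(2)    # writes "llama.expert_used_count" = 2
# a real conversion would then also add the tensors and write header, KV data and tensor data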

llama.cpp

+66-34
@@ -91,7 +91,8 @@
 #define LLAMA_ATTRIBUTE_FORMAT(...)
 #endif
 
-#define LLAMA_MAX_NODES 8192
+#define LLAMA_MAX_NODES   8192
+#define LLAMA_MAX_EXPERTS 8
 
 //
 // logging

@@ -231,6 +232,8 @@ enum llm_kv {
     LLM_KV_FEED_FORWARD_LENGTH,
     LLM_KV_USE_PARALLEL_RESIDUAL,
     LLM_KV_TENSOR_DATA_LAYOUT,
+    LLM_KV_EXPERT_COUNT,
+    LLM_KV_EXPERT_USED_COUNT,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,

@@ -281,6 +284,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
     { LLM_KV_FEED_FORWARD_LENGTH,    "%s.feed_forward_length"   },
     { LLM_KV_USE_PARALLEL_RESIDUAL,  "%s.use_parallel_residual" },
     { LLM_KV_TENSOR_DATA_LAYOUT,     "%s.tensor_data_layout"    },
+    { LLM_KV_EXPERT_COUNT,           "%s.expert_count"          },
+    { LLM_KV_EXPERT_USED_COUNT,      "%s.expert_used_count"     },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,    "%s.attention.head_count"    },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },

@@ -1176,6 +1181,8 @@ struct llama_hparams {
     uint32_t n_layer;
     uint32_t n_rot;
     uint32_t n_ff;
+    uint32_t n_expert = 0;
+    uint32_t n_expert_used = 0;
 
     float f_norm_eps;
     float f_norm_rms_eps;

@@ -1190,15 +1197,18 @@ struct llama_hparams {
     float f_max_alibi_bias;
 
     bool operator!=(const llama_hparams & other) const {
-        if (this->vocab_only  != other.vocab_only)  return true;
-        if (this->n_vocab     != other.n_vocab)     return true;
-        if (this->n_ctx_train != other.n_ctx_train) return true;
-        if (this->n_embd      != other.n_embd)      return true;
-        if (this->n_head      != other.n_head)      return true;
-        if (this->n_head_kv   != other.n_head_kv)   return true;
-        if (this->n_layer     != other.n_layer)     return true;
-        if (this->n_rot       != other.n_rot)       return true;
-        if (this->n_ff        != other.n_ff)        return true;
+        if (this->vocab_only    != other.vocab_only)    return true;
+        if (this->n_vocab       != other.n_vocab)       return true;
+        if (this->n_ctx_train   != other.n_ctx_train)   return true;
+        if (this->n_embd        != other.n_embd)        return true;
+        if (this->n_head        != other.n_head)        return true;
+        if (this->n_head_kv     != other.n_head_kv)     return true;
+        if (this->n_layer       != other.n_layer)       return true;
+        if (this->n_rot         != other.n_rot)         return true;
+        if (this->n_ff          != other.n_ff)          return true;
+        if (this->n_expert      != other.n_expert)      return true;
+        if (this->n_expert_used != other.n_expert_used) return true;
+
         if (this->rope_finetuned  != other.rope_finetuned)  return true;
         if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
 

@@ -1282,9 +1292,9 @@ struct llama_layer {
 
     // ff MoE
     struct ggml_tensor * ffn_gate_inp;
-    struct ggml_tensor * ffn_gate_exp[8];
-    struct ggml_tensor * ffn_down_exp[8];
-    struct ggml_tensor * ffn_up_exp[8];
+    struct ggml_tensor * ffn_gate_exp[LLAMA_MAX_EXPERTS];
+    struct ggml_tensor * ffn_down_exp[LLAMA_MAX_EXPERTS];
+    struct ggml_tensor * ffn_up_exp  [LLAMA_MAX_EXPERTS];
 
     // ff bias
     struct ggml_tensor * ffn_down_b; // b2

@@ -2458,6 +2468,16 @@ static void llm_load_hparams(
         ml.get_key (LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff);
         ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
         ml.get_key (LLM_KV_BLOCK_COUNT,          hparams.n_layer);
+        ml.get_key (LLM_KV_EXPERT_COUNT,         hparams.n_expert,      false);
+        ml.get_key (LLM_KV_EXPERT_USED_COUNT,    hparams.n_expert_used, false);
+
+        GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
+        GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
+        if (hparams.n_expert > 0) {
+            GGML_ASSERT(hparams.n_expert_used > 0);
+        } else {
+            GGML_ASSERT(hparams.n_expert_used == 0);
+        }
 
         // n_head_kv is optional, default to n_head
         hparams.n_head_kv = hparams.n_head;

@@ -2889,6 +2909,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
     LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
     LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
+    LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
+    LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
     LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type.c_str());
     LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);

@@ -3046,10 +3068,16 @@ static void llm_load_tensors(
                         layer.ffn_gate_inp = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, backend, false);
 
                         if (layer.ffn_gate_inp == nullptr) {
+                            GGML_ASSERT(hparams.n_expert      == 0);
+                            GGML_ASSERT(hparams.n_expert_used == 0);
+
                             layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, backend_split);
                             layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
                             layer.ffn_up   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
                         } else {
+                            GGML_ASSERT(hparams.n_expert      > 0);
+                            GGML_ASSERT(hparams.n_expert_used > 0);
+
                             // MoE branch
                             for (int x = 0; x < 8; ++x) {
                                 layer.ffn_gate_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd, n_ff}, backend_split);

@@ -3073,7 +3101,7 @@ static void llm_load_tensors(
                                 ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
                         } else {
                             vram_weights += ggml_nbytes(layer.ffn_gate_inp);
-                            for (int x = 0; x < 8; ++x) {
+                            for (uint32_t x = 0; x < hparams.n_expert; ++x) {
                                 vram_weights +=
                                     ggml_nbytes(layer.ffn_gate_exp[x]) + ggml_nbytes(layer.ffn_down_exp[x]) + ggml_nbytes(layer.ffn_up_exp[x]);
                             }

@@ -4058,6 +4086,8 @@ struct llm_build_context {
     const int64_t n_head_kv;
     const int64_t n_embd_head;
    const int64_t n_embd_gqa;
+    const int64_t n_expert;
+    const int64_t n_expert_used;
 
     const float freq_base;
     const float freq_scale;

@@ -4099,6 +4129,8 @@ struct llm_build_context {
         n_head_kv     (hparams.n_head_kv),
         n_embd_head   (hparams.n_embd_head()),
         n_embd_gqa    (hparams.n_embd_gqa()),
+        n_expert      (hparams.n_expert),
+        n_expert_used (hparams.n_expert_used),
         freq_base     (cparams.rope_freq_base),
         freq_scale    (cparams.rope_freq_scale),
         ext_factor    (cparams.yarn_ext_factor),

@@ -4242,25 +4274,21 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            // TODO: param
-            const int n_experts = 8;
-            const int n_experts_per_tok = 2;
-
             ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
             cb(logits, "ffn_moe_logits", il);
 
             ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
             cb(probs, "ffn_moe_probs", il);
 
             // select experts
-            ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
+            ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
             cb(selected_experts->src[0], "ffn_moe_argsort", il);
 
             ggml_tensor * weights = ggml_get_rows(ctx0,
-                    ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts);
+                    ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
             cb(weights, "ffn_moe_weights", il);
 
-            weights = ggml_reshape_2d(ctx0, weights, n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
+            weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
 
             ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
             cb(weights_sum, "ffn_moe_weights_sum", il);

@@ -4271,18 +4299,13 @@ struct llm_build_context {
             // compute expert outputs
             ggml_tensor * moe_out = nullptr;
 
-            for (int i = 0; i < n_experts_per_tok; ++i) {
+            for (int i = 0; i < n_expert_used; ++i) {
                 ggml_tensor * cur_expert;
 
-                // TODO: fix
-                ggml_tensor ** ffn_up_exp   = (ggml_tensor **) model.layers[il].ffn_up_exp;
-                ggml_tensor ** ffn_gate_exp = (ggml_tensor **) model.layers[il].ffn_gate_exp;
-                ggml_tensor ** ffn_down_exp = (ggml_tensor **) model.layers[il].ffn_down_exp;
-
-                ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, ffn_up_exp, n_experts, selected_experts, i, cur);
+                ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur);
                 cb(cur_up, "ffn_moe_up", il);
 
-                ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, ffn_gate_exp, n_experts, selected_experts, i, cur);
+                ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur);
                 cb(cur_gate, "ffn_moe_gate", il);
 
                 cur_gate = ggml_silu(ctx0, cur_gate);

@@ -4291,7 +4314,7 @@ struct llm_build_context {
                 cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
                 cb(cur_expert, "ffn_moe_gate_par", il);
 
-                cur_expert = ggml_mul_mat_id(ctx0, ffn_down_exp, n_experts, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+                cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd]
                 cb(cur_expert, "ffn_moe_down", il);
 
                 cur_expert = ggml_mul(ctx0, cur_expert,
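
For reference, the per-token routing these hunks build, written out in plain NumPy. This is a sketch of the math only, assuming the usual Mixtral-style top-k routing; the names, shapes and dense weight matrices are illustrative stand-ins for the ggml tensors above, not the actual implementation:

import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def moe_ffn(x, gate_inp, w_gate, w_up, w_down, n_expert_used=2):
    # x: [n_embd] for one token; gate_inp: [n_expert, n_embd] router; w_*: per-expert matrices
    logits = gate_inp @ x                         # router logits      (ffn_moe_logits)
    probs  = np.exp(logits - logits.max())
    probs /= probs.sum()                          # softmax            (ffn_moe_probs)
    sel    = np.argsort(-probs)[:n_expert_used]   # top-k experts      (ffn_moe_argsort / top_k)
    w      = probs[sel]
    w     /= w.sum()                              # renormalize the selected weights
    out = np.zeros_like(x)
    for weight, e in zip(w, sel):
        h    = silu(w_gate[e] @ x) * (w_up[e] @ x)   # SwiGLU expert FFN
        out += weight * (w_down[e] @ h)              # weighted sum of expert outputs
    return out

# Toy usage with made-up dimensions:
n_embd, n_ff, n_expert = 8, 16, 4
rng      = np.random.default_rng(0)
gate_inp = rng.standard_normal((n_expert, n_embd))
w_gate   = [rng.standard_normal((n_ff, n_embd)) for _ in range(n_expert)]
w_up     = [rng.standard_normal((n_ff, n_embd)) for _ in range(n_expert)]
w_down   = [rng.standard_normal((n_embd, n_ff)) for _ in range(n_expert)]
y = moe_ffn(rng.standard_normal(n_embd), gate_inp, w_gate, w_up, w_down)
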
@@ -8192,11 +8215,9 @@ static void llama_convert_tensor_internal(
     workers.clear();
 }
 
-static ggml_type get_k_quant_type(
-    quantize_state_internal & qs,
-    ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype
-) {
+static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
     const std::string name = ggml_get_name(tensor);
+
     // TODO: avoid hardcoded tensor names - use the TN_* constants
     const llm_arch arch = qs.model.arch;
     const auto tn = LLM_TN(arch);

@@ -8230,7 +8251,18 @@ static ggml_type get_k_quant_type(
             // nearly negligible increase in model size by quantizing this tensor with more bits:
             if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
         }
+        if (qs.model.hparams.n_expert == 8) {
+            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
+            // TODO: explore better strategies
+            new_type = GGML_TYPE_Q8_0;
+        }
         ++qs.i_attention_wv;
+    } else if (name.find("attn_k.weight") != std::string::npos) {
+        if (qs.model.hparams.n_expert == 8) {
+            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
+            // TODO: explore better strategies
+            new_type = GGML_TYPE_Q8_0;
+        }
     } else if (name.find("ffn_down.weight") != std::string::npos) {
         if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
         else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
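
A rough check of the "~128MB" note, treating it as covering the K and V projections together and assuming Mixtral-8x7B-like dimensions (n_embd = 4096, n_embd_gqa = 1024, 32 layers) and a ~4.5 bits-per-weight baseline quant — all assumptions, not values stated in this diff:

# attn_k.weight and attn_v.weight are n_embd x n_embd_gqa per layer (assumed shapes)
n_embd, n_embd_gqa, n_layer = 4096, 1024, 32
n_params   = 2 * n_layer * n_embd * n_embd_gqa    # K + V projections: ~268M weights
extra_bits = 8.5 - 4.5                            # Q8_0 (~8.5 bpw) vs an assumed ~4.5 bpw K-quant
extra_mib  = n_params * extra_bits / 8 / 1024**2
print(f"~{extra_mib:.0f} MiB")                    # ~128 MiB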
