@@ -873,16 +873,16 @@ struct LLM_TN {
 // gguf helpers
 //
 
-static const std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = {
+static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
     { LLAMA_ROPE_SCALING_TYPE_NONE,   "none"   },
     { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
     { LLAMA_ROPE_SCALING_TYPE_YARN,   "yarn"   },
 };
 
-static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
+static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
     for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
         if (kv.second == name) {
-            return kv.first;
+            return (llama_rope_scaling_type) kv.first;
         }
     }
 
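For reference, a self-contained sketch of the lookup pattern the hunk above moves to typed enums. The enum values mirror llama.h at the time of this change; the unprefixed names and the small main driver are illustrative only, not part of the patch.

// Standalone illustration of the enum-keyed reverse lookup used by
// llama_rope_scaling_type_from_string above (sketch, not the library code).
#include <cstdio>
#include <map>
#include <string>

enum llama_rope_scaling_type {
    LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1,
    LLAMA_ROPE_SCALING_TYPE_NONE        = 0,
    LLAMA_ROPE_SCALING_TYPE_LINEAR      = 1,
    LLAMA_ROPE_SCALING_TYPE_YARN        = 2,
};

static const std::map<llama_rope_scaling_type, const char *> ROPE_SCALING_TYPES = {
    { LLAMA_ROPE_SCALING_TYPE_NONE,   "none"   },
    { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
    { LLAMA_ROPE_SCALING_TYPE_YARN,   "yarn"   },
};

// Reverse lookup by name; with the enum-keyed map, kv.first already has the
// right type, and unknown strings fall back to UNSPECIFIED.
static llama_rope_scaling_type rope_scaling_type_from_string(const std::string & name) {
    for (const auto & kv : ROPE_SCALING_TYPES) {
        if (kv.second == name) {
            return kv.first;
        }
    }
    return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
}

int main() {
    printf("%d\n", rope_scaling_type_from_string("yarn"));    // prints 2
    printf("%d\n", rope_scaling_type_from_string("unknown")); // prints -1
    return 0;
}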
@@ -1612,16 +1612,16 @@ struct llama_hparams {
     float rope_freq_base_train;
     float rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;
-    int32_t rope_scaling_type_train;
 
     float f_clamp_kqv      = 0.0f;
     float f_max_alibi_bias = 0.0f;
 
     bool causal_attn = true;
     bool need_kq_pos = false;
 
-    enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
-    enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
+    enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE;
+    enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
+    enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
@@ -1670,8 +1670,8 @@ struct llama_cparams {
     uint32_t n_threads;       // number of threads to use for generation
     uint32_t n_threads_batch; // number of threads to use for batch processing
 
-    float    rope_freq_base;
-    float    rope_freq_scale;
+    float rope_freq_base;
+    float rope_freq_scale;
 
     uint32_t n_yarn_orig_ctx;
     // These hyperparameters are not exposed in GGUF, because all
@@ -1683,7 +1683,7 @@ struct llama_cparams {
     float defrag_thold;
 
     bool offload_kqv;
-    bool do_pooling;
+    enum llama_pooling_type pooling_type;
 
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
@@ -2933,7 +2933,11 @@ template<>
 bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
     uint32_t tmp;
     const bool found = get_key(kid, tmp, required);
-    result = (enum llama_pooling_type) tmp;
+    if (found) {
+        result = (enum llama_pooling_type) tmp;
+    } else {
+        result = LLAMA_POOLING_TYPE_UNSPECIFIED;
+    }
     return found;
 }
 
@@ -3210,7 +3214,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
                 ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
 
                 switch (hparams.n_layer) {
                     case 3:
@@ -5175,7 +5179,7 @@ struct llm_build_context {
         n_kv             (worst_case ? n_ctx            : kv_self.n),
         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
-        pooling_type     (cparams.do_pooling ? hparams.pooling_type : LLAMA_POOLING_TYPE_NONE),
+        pooling_type     (cparams.pooling_type),
         rope_type        (hparams.rope_type),
         cb               (cb),
         buf_compute_meta (lctx.buf_compute_meta) {
@@ -8015,7 +8019,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
@@ -8043,7 +8047,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -11846,6 +11850,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads           =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
         /*.n_threads_batch     =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type   =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        /*.pooling_type        =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
         /*.rope_freq_base      =*/ 0.0f,
         /*.rope_freq_scale     =*/ 0.0f,
         /*.yarn_ext_factor     =*/ -1.0f,
@@ -11861,7 +11866,6 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all          =*/ false,
         /*.embedding           =*/ false,
         /*.offload_kqv         =*/ true,
-        /*.do_pooling          =*/ true,
         /*.abort_callback      =*/ nullptr,
         /*.abort_callback_data =*/ nullptr,
     };
@@ -12012,7 +12016,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
     cparams.defrag_thold     = params.defrag_thold;
     cparams.offload_kqv      = params.offload_kqv;
-    cparams.do_pooling       = params.do_pooling;
+    cparams.pooling_type     = params.pooling_type;
 
     cparams.n_ctx            = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
     cparams.rope_freq_base   = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -12038,6 +12042,14 @@ struct llama_context * llama_new_context_with_model(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }
 
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+        } else {
+            cparams.pooling_type = hparams.pooling_type;
+        }
+    }
+
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
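To illustrate the resulting API, here is a hedged usage sketch, not part of the patch: a caller that wants pooled embeddings now sets llama_context_params::pooling_type instead of the removed do_pooling flag, and leaving it at LLAMA_POOLING_TYPE_UNSPECIFIED defers to the model's GGUF metadata and finally to NONE, per the last hunk. Function and field names are taken from llama.h of this era; the command-line handling and error paths are illustrative only.

// Sketch: creating an embedding context with an explicit pooling type.
#include <cstdio>
#include "llama.h"

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
        return 1;
    }

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file(argv[1], mparams);
    if (model == NULL) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.embedding    = true;                    // embedding mode
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // explicit override; replaces do_pooling
    // Leaving pooling_type at LLAMA_POOLING_TYPE_UNSPECIFIED (the new default) lets
    // llama_new_context_with_model pick the model's own pooling type, or NONE if unset.

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (ctx == NULL) {
        llama_free_model(model);
        return 1;
    }

    // ... tokenize, llama_decode, then read pooled output via llama_get_embeddings(ctx) ...

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}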