
Commit 475df1d

iamlemec and ggerganov authored
llama : allow for user specified embedding pooling type (#5849)
* allow for user specified pooling type
* llama : use enum types over int

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 87c2e8b commit 475df1d

File tree: 5 files changed, +60 -29 lines changed


common/common.cpp

+13 -0 lines changed
@@ -335,6 +335,16 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.yarn_beta_slow = std::stof(argv[i]);
+        } else if (arg == "--pooling") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
+            else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
+            else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+            else { invalid_param = true; break; }
         } else if (arg == "--defrag-thold" || arg == "-dt") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1014,6 +1024,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --yarn-attn-factor N  YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
     printf("  --yarn-beta-slow N    YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
     printf("  --yarn-beta-fast N    YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+    printf("  --pooling {none,mean,cls}\n");
+    printf("                        pooling type for embeddings, use model default if unspecified\n");
     printf("  -dt N, --defrag-thold N\n");
     printf("                        KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
     printf("  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
@@ -1296,6 +1308,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     cparams.yarn_beta_fast = params.yarn_beta_fast;
     cparams.yarn_beta_slow = params.yarn_beta_slow;
     cparams.yarn_orig_ctx  = params.yarn_orig_ctx;
+    cparams.pooling_type   = params.pooling_type;
     cparams.defrag_thold   = params.defrag_thold;
     cparams.offload_kqv    = !params.no_kv_offload;
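
As a quick illustration (not part of the commit), here is a minimal sketch of how the new flag flows from argument parsing into the context parameters, assuming the helpers declared in common/common.h; an invocation such as "embedding -m model.gguf --pooling mean" (hypothetical paths) would then request mean pooling regardless of what the GGUF stores.

    // Sketch only: parse a command line that may contain "--pooling {none,mean,cls}"
    // and forward the result to llama_context_params.
    #include "common.h"
    #include "llama.h"

    int main(int argc, char ** argv) {
        gpt_params params;  // pooling_type defaults to LLAMA_POOLING_TYPE_UNSPECIFIED
        if (!gpt_params_parse(argc, argv, params)) {
            return 1;
        }
        llama_context_params cparams = llama_context_params_from_gpt_params(params);
        // cparams.pooling_type now carries the user's choice,
        // or UNSPECIFIED to keep the model default
        return 0;
    }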

common/common.h

+5 -2 lines changed
@@ -76,8 +76,11 @@ struct gpt_params {
     float   yarn_beta_slow = 1.0f;  // YaRN high correction dim
     int32_t yarn_orig_ctx  = 0;     // YaRN original context length
     float   defrag_thold   = -1.0f; // KV cache defragmentation threshold
-    int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
-    ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
+
+    ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
+
+    llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
+    llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings

     // // sampling parameters
     struct llama_sampling_params sparams;

convert-hf-to-gguf.py

+9 -9 lines changed
@@ -1644,16 +1644,17 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_causal_attention(False)

         # get pooling path
-        with open(self.dir_model / "modules.json", encoding="utf-8") as f:
-            modules = json.load(f)
         pooling_path = None
-        for mod in modules:
-            if mod["type"] == "sentence_transformers.models.Pooling":
-                pooling_path = mod["path"]
-                break
+        module_path = self.dir_model / "modules.json"
+        if module_path.is_file():
+            with open(module_path, encoding="utf-8") as f:
+                modules = json.load(f)
+            for mod in modules:
+                if mod["type"] == "sentence_transformers.models.Pooling":
+                    pooling_path = mod["path"]
+                    break

         # get pooling type
-        pooling_type = gguf.PoolingType.NONE
         if pooling_path is not None:
             with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
                 pooling = json.load(f)
@@ -1663,8 +1664,7 @@ def set_gguf_parameters(self):
                 pooling_type = gguf.PoolingType.CLS
             else:
                 raise NotImplementedError("Only MEAN and CLS pooling types supported")
-
-        self.gguf_writer.add_pooling_type(pooling_type)
+            self.gguf_writer.add_pooling_type(pooling_type)

     def set_vocab(self):
         path = self.dir_model

llama.cpp

+28 -16 lines changed
@@ -873,16 +873,16 @@ struct LLM_TN {
 // gguf helpers
 //

-static const std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = {
+static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
     { LLAMA_ROPE_SCALING_TYPE_NONE,   "none"   },
     { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
     { LLAMA_ROPE_SCALING_TYPE_YARN,   "yarn"   },
 };

-static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
+static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
     for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
         if (kv.second == name) {
-            return kv.first;
+            return (llama_rope_scaling_type) kv.first;
         }
     }

@@ -1612,16 +1612,16 @@ struct llama_hparams {
     float rope_freq_base_train;
     float rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;
-    int32_t rope_scaling_type_train;

     float f_clamp_kqv = 0.0f;
     float f_max_alibi_bias = 0.0f;

     bool causal_attn = true;
     bool need_kq_pos = false;

-    enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
-    enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
+    enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE;
+    enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
+    enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;

     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
@@ -1670,8 +1670,8 @@ struct llama_cparams {
     uint32_t n_threads;       // number of threads to use for generation
     uint32_t n_threads_batch; // number of threads to use for batch processing

-    float    rope_freq_base;
-    float    rope_freq_scale;
+    float rope_freq_base;
+    float rope_freq_scale;

     uint32_t n_yarn_orig_ctx;
     // These hyperparameters are not exposed in GGUF, because all
@@ -1683,7 +1683,7 @@ struct llama_cparams {
     float defrag_thold;

     bool offload_kqv;
-    bool do_pooling;
+    enum llama_pooling_type pooling_type;

     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
@@ -2933,7 +2933,11 @@ template<>
 bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
     uint32_t tmp;
     const bool found = get_key(kid, tmp, required);
-    result = (enum llama_pooling_type) tmp;
+    if (found) {
+        result = (enum llama_pooling_type) tmp;
+    } else {
+        result = LLAMA_POOLING_TYPE_UNSPECIFIED;
+    }
     return found;
 }

@@ -3210,7 +3214,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
                 ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);

                 switch (hparams.n_layer) {
                     case 3:
@@ -5175,7 +5179,7 @@ struct llm_build_context {
         n_kv             (worst_case ? n_ctx            : kv_self.n),
         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
-        pooling_type     (cparams.do_pooling ? hparams.pooling_type : LLAMA_POOLING_TYPE_NONE),
+        pooling_type     (cparams.pooling_type),
         rope_type        (hparams.rope_type),
         cb               (cb),
         buf_compute_meta (lctx.buf_compute_meta) {
@@ -8015,7 +8019,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }

-    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;

         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
@@ -8043,7 +8047,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }

-    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;

         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -11846,6 +11850,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
         /*.yarn_ext_factor             =*/ -1.0f,
@@ -11861,7 +11866,6 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all                  =*/ false,
         /*.embedding                   =*/ false,
         /*.offload_kqv                 =*/ true,
-        /*.do_pooling                  =*/ true,
         /*.abort_callback              =*/ nullptr,
         /*.abort_callback_data         =*/ nullptr,
     };
@@ -12012,7 +12016,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
     cparams.defrag_thold     = params.defrag_thold;
     cparams.offload_kqv      = params.offload_kqv;
-    cparams.do_pooling       = params.do_pooling;
+    cparams.pooling_type     = params.pooling_type;

     cparams.n_ctx            = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
     cparams.rope_freq_base   = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -12038,6 +12042,14 @@ struct llama_context * llama_new_context_with_model(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }

+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+        } else {
+            cparams.pooling_type = hparams.pooling_type;
+        }
+    }
+
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
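
Restated outside the diff: the fallback added in llama_new_context_with_model gives an explicit user choice precedence, then the pooling type stored in the model's GGUF metadata, and finally disables pooling. A standalone paraphrase of that rule (a sketch, not code from the commit):

    #include "llama.h"

    // Effective pooling type: user override > model metadata > none.
    static enum llama_pooling_type resolve_pooling_type(
            enum llama_pooling_type requested,        // from llama_context_params.pooling_type
            enum llama_pooling_type model_default) {  // from hparams, read out of the GGUF
        if (requested != LLAMA_POOLING_TYPE_UNSPECIFIED) {
            return requested;                         // user asked for a specific type
        }
        if (model_default != LLAMA_POOLING_TYPE_UNSPECIFIED) {
            return model_default;                     // fall back to what the model declares
        }
        return LLAMA_POOLING_TYPE_NONE;               // nothing specified anywhere
    }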

llama.h

+5 -2 lines changed
@@ -129,6 +129,7 @@ extern "C" {
     };

     enum llama_pooling_type {
+        LLAMA_POOLING_TYPE_UNSPECIFIED = -1,
         LLAMA_POOLING_TYPE_NONE = 0,
         LLAMA_POOLING_TYPE_MEAN = 1,
         LLAMA_POOLING_TYPE_CLS  = 2,
@@ -236,7 +237,10 @@ extern "C" {
         uint32_t n_batch;         // prompt processing maximum batch size
         uint32_t n_threads;       // number of threads to use for generation
         uint32_t n_threads_batch; // number of threads to use for batch processing
-        int32_t  rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
+
+        enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
+        enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
+                                                        // (ignored if no pooling layer)

         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base;   // RoPE base frequency, 0 = from model
@@ -258,7 +262,6 @@ extern "C" {
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embedding;   // embedding mode only
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool do_pooling;  // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)

         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
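
For context, a minimal sketch of using the new field through the C API (not from the commit; the model path and the choice of CLS pooling are hypothetical, and error handling is omitted):

    #include "llama.h"

    // Build an embedding context with an explicit pooling type.
    static llama_context * make_embedding_ctx(void) {
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_load_model_from_file("model.gguf", mparams);

        llama_context_params cparams = llama_context_default_params();
        cparams.embedding    = true;                    // embedding mode only
        cparams.pooling_type = LLAMA_POOLING_TYPE_CLS;  // override the model default
        // leaving pooling_type at LLAMA_POOLING_TYPE_UNSPECIFIED keeps the type stored in the GGUF

        return llama_new_context_with_model(model, cparams);
    }

Leaving the field at its default (LLAMA_POOLING_TYPE_UNSPECIFIED) reproduces the old behaviour of following the model, while LLAMA_POOLING_TYPE_NONE now turns pooling off explicitly, which previously required setting do_pooling = false.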
