@@ -1210,7 +1210,7 @@ struct llama_server_context
         queue_results.send(res);
     }

-    void send_embedding(server_slot &slot)
+    void send_embedding(server_slot & slot, const llama_batch & batch)
     {
         task_result res;
         res.id = slot.task_id;
@@ -1219,6 +1219,7 @@ struct llama_server_context
         res.stop = true;

         const int n_embd = llama_n_embd(model);
+
         if (!params.embedding)
         {
             LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
@@ -1229,12 +1230,29 @@ struct llama_server_context
         }
         else
         {
-            const float *data = llama_get_embeddings(ctx);
-            std::vector<float> embedding(data, data + n_embd);
-            res.result_json = json
-            {
-                {"embedding", embedding},
-            };
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                    if (embd == NULL) {
+                        LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
+                        res.result_json = json
+                        {
+                            {"embedding", std::vector<float>(n_embd, 0.0f)},
+                        };
+                        continue;
+                    }
+                }
+
+                res.result_json = json
+                {
+                    {"embedding", std::vector<float>(embd, embd + n_embd)},
+                };
+            }
         }
         queue_results.send(res);
     }
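The new loop above scans the decoded batch for the logits-enabled token that belongs to the requesting slot, prefers the pooled per-sequence embedding from llama_get_embeddings_seq(), and falls back to the per-position embedding from llama_get_embeddings_ith() when the context runs without pooling. A minimal standalone helper that mirrors that fallback might look like the sketch below; only the two llama.h getters are taken from the diff, the helper name and shape are illustrative, not part of the patch.

// Illustrative sketch, not part of the patch: extract one slot's embedding
// from the batch that was just decoded.
#include "llama.h"
#include <vector>

static std::vector<float> slot_embedding(llama_context * ctx, const llama_batch & batch,
                                         int slot_id, int n_embd) {
    for (int i = 0; i < batch.n_tokens; ++i) {
        if (!batch.logits[i] || batch.seq_id[i][0] != slot_id) {
            continue; // not the output token of this slot's sequence
        }
        // pooled embedding for the whole sequence (mean/cls pooling)
        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
        if (embd == NULL) {
            // no pooling: use the embedding computed at this batch position
            embd = llama_get_embeddings_ith(ctx, i);
        }
        if (embd != NULL) {
            return std::vector<float>(embd, embd + n_embd);
        }
    }
    // nothing found for this slot: zero vector, matching the error path above
    return std::vector<float>(n_embd, 0.0f);
}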
@@ -1845,7 +1863,7 @@ struct llama_server_context
                            ga_i += ga_w/ga_n;
                        }
                    }
-                    llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id}, false);
+                    llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
                    slot_npast++;
                }

@@ -1881,7 +1899,7 @@ struct llama_server_context

        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
        {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

            for (auto & slot : slots)
            {
@@ -1954,7 +1972,7 @@ struct llama_server_context
                // prompt evaluated for embedding
                if (slot.embedding)
                {
-                    send_embedding(slot);
+                    send_embedding(slot, batch_view);
                    slot.release();
                    slot.i_batch = -1;
                    continue;
@@ -2036,6 +2054,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
    printf("  --yarn-attn-factor N  YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
    printf("  --yarn-beta-slow N    YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
    printf("  --yarn-beta-fast N    YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+    printf("  --pooling {none,mean,cls}\n");
+    printf("                        pooling type for embeddings, use model default if unspecified\n");
    printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
    printf("  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
    printf("                        not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -2276,6 +2296,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
            }
            params.yarn_beta_slow = std::stof(argv[i]);
        }
+        else if (arg == "--pooling")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
+            else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
+            else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+            else { invalid_param = true; break; }
+        }
        else if (arg == "--threads" || arg == "-t")
        {
            if (++i >= argc)
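For the new flag to take effect, the parsed params.pooling_type still has to reach the llama context when it is created. A minimal sketch of that wiring, assuming the usual pattern of this period where the value is copied from gpt_params into llama_context_params; the surrounding variable names are illustrative and this snippet is not part of the diff:

// Sketch (assumption): forwarding the parsed pooling choice to the context.
llama_context_params cparams = llama_context_default_params();
cparams.pooling_type = params.pooling_type; // LLAMA_POOLING_TYPE_NONE / MEAN / CLS
llama_context * ctx = llama_new_context_with_model(model, cparams);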
@@ -2330,7 +2362,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                break;
            }
            params.n_batch = std::stoi(argv[i]);
-            params.n_batch = std::min(512, params.n_batch);
        }
        else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
        {