@@ -92,6 +92,7 @@ enum server_task_type {
92
92
// Which kind of result a completion task produces.
enum server_task_cmpl_type {
    SERVER_TASK_CMPL_TYPE_NORMAL,    // regular text completion
    SERVER_TASK_CMPL_TYPE_EMBEDDING, // embedding extraction
    SERVER_TASK_CMPL_TYPE_RERANK,    // query/document relevance scoring
    SERVER_TASK_CMPL_TYPE_INFILL,    // fill-in-the-middle completion
};
97
98
@@ -172,6 +173,7 @@ struct server_slot {
172
173
std::vector<completion_token_output> generated_token_probs;
173
174
174
175
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
176
+
175
177
bool has_next_token = true ;
176
178
bool truncated = false ;
177
179
bool stopped_eos = false ;
@@ -942,8 +944,17 @@ struct server_context {
942
944
slot.prompt = *prompt;
943
945
} else if (prompt->is_array () && prompt->size () == 1 && prompt->at (0 ).is_array ()) {
944
946
slot.prompt = prompt->at (0 );
947
+ } else if (prompt->is_array () && prompt->size () > 1 ) {
948
+ // array of strings
949
+ for (const auto & el : *prompt) {
950
+ if (!el.is_string ()) {
951
+ send_error (task, " \" prompt\" must be a string, an array of strings or an array of integers" , ERROR_TYPE_INVALID_REQUEST);
952
+ return false ;
953
+ }
954
+ }
955
+ slot.prompt = *prompt;
945
956
} else {
946
- send_error (task, " \" prompt\" must be a string or an array of integers" , ERROR_TYPE_INVALID_REQUEST);
957
+ send_error (task, " \" prompt\" must be a string, an array of strings or an array of integers" , ERROR_TYPE_INVALID_REQUEST);
947
958
return false ;
948
959
}
949
960
}
@@ -1368,6 +1379,7 @@ struct server_context {
1368
1379
1369
1380
res.data = json {
1370
1381
{" embedding" , std::vector<float >(n_embd, 0 .0f )},
1382
+ {" index" , slot.index },
1371
1383
};
1372
1384
1373
1385
continue ;
@@ -1386,6 +1398,44 @@ struct server_context {
1386
1398
queue_results.send (res);
1387
1399
}
1388
1400
1401
+ void send_rank (const server_slot & slot, const llama_batch & batch) {
1402
+ server_task_result res;
1403
+ res.id = slot.id_task ;
1404
+ res.error = false ;
1405
+ res.stop = true ;
1406
+
1407
+ for (int i = 0 ; i < batch.n_tokens ; ++i) {
1408
+ if (!batch.logits [i] || batch.seq_id [i][0 ] != slot.id + 1 ) {
1409
+ continue ;
1410
+ }
1411
+
1412
+ const float * embd = llama_get_embeddings_seq (ctx, batch.seq_id [i][0 ]);
1413
+ if (embd == NULL ) {
1414
+ embd = llama_get_embeddings_ith (ctx, i);
1415
+ }
1416
+
1417
+ if (embd == NULL ) {
1418
+ SLT_ERR (slot, " failed to get embeddings, token = %d, seq_id = %d\n " , batch.token [i], batch.seq_id [i][0 ]);
1419
+
1420
+ res.data = json {
1421
+ {" index" , slot.index },
1422
+ {" rank" , -1e6 },
1423
+ };
1424
+
1425
+ continue ;
1426
+ }
1427
+
1428
+ res.data = json {
1429
+ {" index" , slot.index },
1430
+ {" rank" , embd[0 ]},
1431
+ };
1432
+ }
1433
+
1434
+ SLT_DBG (slot, " sending rank, res = '%s'\n " , res.data .dump ().c_str ());
1435
+
1436
+ queue_results.send (res);
1437
+ }
1438
+
1389
1439
//
1390
1440
// Functions to create new task(s) and receive result(s)
1391
1441
//
@@ -1421,13 +1471,23 @@ struct server_context {
1421
1471
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
1422
1472
else if (prompt.is_array ()) {
1423
1473
std::vector<json> prompts = prompt;
1424
- for (size_t i = 0 ; i < prompts.size (); i++) {
1425
- const auto & e = prompts[i];
1426
- if (e.is_string () || json_is_array_of_numbers (e)) {
1427
- data[" index" ] = i;
1428
- create_task (data, true , e);
1429
- } else {
1430
- throw std::runtime_error (error_msg);
1474
+ if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
1475
+ for (size_t i = 1 ; i < prompts.size (); i++) {
1476
+ json qd;
1477
+ qd.push_back (prompts[0 ]);
1478
+ qd.push_back (prompts[i]);
1479
+ data[" index" ] = i - 1 ;
1480
+ create_task (data, true , qd);
1481
+ }
1482
+ } else {
1483
+ for (size_t i = 0 ; i < prompts.size (); i++) {
1484
+ const auto & e = prompts[i];
1485
+ if (e.is_string () || json_is_array_of_numbers (e)) {
1486
+ data[" index" ] = i;
1487
+ create_task (data, true , e);
1488
+ } else {
1489
+ throw std::runtime_error (error_msg);
1490
+ }
1431
1491
}
1432
1492
}
1433
1493
}
@@ -1471,7 +1531,9 @@ struct server_context {
1471
1531
break ;
1472
1532
}
1473
1533
1474
- size_t idx = result.data [" index" ];
1534
+ const size_t idx = result.data [" index" ];
1535
+ GGML_ASSERT (idx < results.size () && " index out of range" );
1536
+
1475
1537
results[idx] = result;
1476
1538
}
1477
1539
result_handler (results);
@@ -1922,6 +1984,29 @@ struct server_context {
1922
1984
}
1923
1985
1924
1986
prompt_tokens = embd_inp;
1987
+ } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
1988
+ // require slot.prompt to be array of 2 strings
1989
+ if (!slot.prompt .is_array () || slot.prompt .size () != 2 ) {
1990
+ SLT_ERR (slot, " %s" , " invalid prompt for rerank task\n " );
1991
+ slot.release ();
1992
+ send_error (slot, " invalid prompt for rerank task" , ERROR_TYPE_INVALID_REQUEST);
1993
+ continue ;
1994
+ }
1995
+
1996
+ // prompt: <s>query</s><s>doc</s>
1997
+ prompt_tokens.clear ();
1998
+ prompt_tokens.push_back (llama_token_bos (model));
1999
+ {
2000
+ const auto part = tokenize (slot.prompt [0 ], false );
2001
+ prompt_tokens.insert (prompt_tokens.end (), part.begin (), part.end ());
2002
+ }
2003
+ prompt_tokens.push_back (llama_token_eos (model));
2004
+ prompt_tokens.push_back (llama_token_bos (model));
2005
+ {
2006
+ const auto part = tokenize (slot.prompt [1 ], false );
2007
+ prompt_tokens.insert (prompt_tokens.end (), part.begin (), part.end ());
2008
+ }
2009
+ prompt_tokens.push_back (llama_token_eos (model));
1925
2010
} else {
1926
2011
prompt_tokens = tokenize (slot.prompt , system_prompt.empty ()); // add BOS if there isn't system prompt
1927
2012
}
@@ -1941,7 +2026,7 @@ struct server_context {
1941
2026
continue ;
1942
2027
}
1943
2028
1944
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
2029
+ if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot. cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ) {
1945
2030
// this prompt is too large to process - discard it
1946
2031
if (slot.n_prompt_tokens > n_ubatch) {
1947
2032
slot.release ();
@@ -2011,15 +2096,18 @@ struct server_context {
2011
2096
slot.n_prompt_tokens_processed = 0 ;
2012
2097
}
2013
2098
2014
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
2099
+ if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot. cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ) {
2015
2100
// cannot fit the prompt in the current batch - will try next iter
2016
2101
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
2017
2102
continue ;
2018
2103
}
2019
2104
}
2020
2105
2021
2106
// check that we are in the right batch_type, if not defer the slot
2022
- bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0 ;
2107
+ const bool slot_type =
2108
+ slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
2109
+ slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0 ;
2110
+
2023
2111
if (batch_type == -1 ) {
2024
2112
batch_type = slot_type;
2025
2113
} else if (batch_type != slot_type) {
@@ -2192,6 +2280,13 @@ struct server_context {
2192
2280
continue ; // continue loop of slots
2193
2281
}
2194
2282
2283
+ if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2284
+ send_rank (slot, batch_view);
2285
+ slot.release ();
2286
+ slot.i_batch = -1 ;
2287
+ continue ; // continue loop of slots
2288
+ }
2289
+
2195
2290
// prompt evaluated for next-token prediction
2196
2291
slot.state = SLOT_STATE_GENERATING;
2197
2292
} else if (slot.state != SLOT_STATE_GENERATING) {
@@ -2974,6 +3069,82 @@ int main(int argc, char ** argv) {
2974
3069
res_ok (res, root);
2975
3070
};
2976
3071
3072
+ const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
3073
+ const json body = json::parse (req.body );
3074
+
3075
+ // TODO: implement
3076
+ // int top_n = 1;
3077
+ // if (body.count("top_n") != 1) {
3078
+ // top_n = body.at("top_n");
3079
+ // } else {
3080
+ // res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
3081
+ // return;
3082
+ // }
3083
+
3084
+ json query;
3085
+ if (body.count (" query" ) == 1 ) {
3086
+ query = body.at (" query" );
3087
+ if (!query.is_string ()) {
3088
+ res_error (res, format_error_response (" \" query\" must be a string" , ERROR_TYPE_INVALID_REQUEST));
3089
+ return ;
3090
+ }
3091
+ } else {
3092
+ exit (0 );
3093
+ res_error (res, format_error_response (" \" query\" must be provided" , ERROR_TYPE_INVALID_REQUEST));
3094
+ return ;
3095
+ }
3096
+
3097
+ json documents;
3098
+ if (body.count (" documents" ) != 0 ) {
3099
+ documents = body.at (" documents" );
3100
+ if (!documents.is_array () || documents.size () == 0 ) {
3101
+ res_error (res, format_error_response (" \" documents\" must be a non-empty string array" , ERROR_TYPE_INVALID_REQUEST));
3102
+ return ;
3103
+ }
3104
+ } else {
3105
+ res_error (res, format_error_response (" \" documents\" must be provided" , ERROR_TYPE_INVALID_REQUEST));
3106
+ return ;
3107
+ }
3108
+
3109
+ // construct prompt object: array of ["query", "doc0", "doc1", ...]
3110
+ json prompt;
3111
+ prompt.push_back (query);
3112
+ for (const auto & doc : documents) {
3113
+ prompt.push_back (doc);
3114
+ }
3115
+
3116
+ LOG_DBG (" rerank prompt: %s\n " , prompt.dump ().c_str ());
3117
+
3118
+ // create and queue the task
3119
+ json responses = json::array ();
3120
+ bool error = false ;
3121
+ {
3122
+ std::vector<server_task> tasks = ctx_server.create_tasks_cmpl ({{" prompt" , prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
3123
+ ctx_server.queue_results .add_waiting_tasks (tasks);
3124
+ ctx_server.queue_tasks .post (tasks);
3125
+
3126
+ // get the result
3127
+ std::unordered_set<int > task_ids = server_task::get_list_id (tasks);
3128
+
3129
+ ctx_server.receive_cmpl_results (task_ids, [&](std::vector<server_task_result> & results) {
3130
+ for (const auto & res : results) {
3131
+ responses.push_back (res.data );
3132
+ }
3133
+ }, [&](const json & error_data) {
3134
+ res_error (res, error_data);
3135
+ error = true ;
3136
+ });
3137
+ }
3138
+
3139
+ if (error) {
3140
+ return ;
3141
+ }
3142
+
3143
+ // write JSON response
3144
+ json root = format_response_rerank (body, responses);
3145
+ res_ok (res, root);
3146
+ };
3147
+
2977
3148
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
2978
3149
json result = json::array ();
2979
3150
for (size_t i = 0 ; i < ctx_server.loras .size (); ++i) {
@@ -3070,6 +3241,7 @@ int main(int argc, char ** argv) {
3070
3241
svr->Post (" /embedding" , handle_embeddings); // legacy
3071
3242
svr->Post (" /embeddings" , handle_embeddings);
3072
3243
svr->Post (" /v1/embeddings" , handle_embeddings);
3244
+ svr->Post (" /v1/rerank" , handle_rerank);
3073
3245
svr->Post (" /tokenize" , handle_tokenize);
3074
3246
svr->Post (" /detokenize" , handle_detokenize);
3075
3247
// LoRA adapters hotswap
0 commit comments