
Commit 5f95dcc

server : add rerank endpoint
ggml-ci
1 parent f03bcd8 commit 5f95dcc

3 files changed: +209, −14 lines

common/arg.cpp (+1, −1)
@@ -1103,7 +1103,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
             else { throw std::invalid_argument("invalid value"); }
         }
-    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
     add_opt(llama_arg(
         {"--attention"}, "{causal,non-causal}",
         "attention type for embeddings, use model default if unspecified",

examples/server/server.cpp (+184, −12)
@@ -92,6 +92,7 @@ enum server_task_type {
 enum server_task_cmpl_type {
     SERVER_TASK_CMPL_TYPE_NORMAL,
     SERVER_TASK_CMPL_TYPE_EMBEDDING,
+    SERVER_TASK_CMPL_TYPE_RERANK,
     SERVER_TASK_CMPL_TYPE_INFILL,
 };

@@ -172,6 +173,7 @@ struct server_slot {
     std::vector<completion_token_output> generated_token_probs;

     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -942,8 +944,17 @@ struct server_context {
                 slot.prompt = *prompt;
             } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
                 slot.prompt = prompt->at(0);
+            } else if (prompt->is_array() && prompt->size() > 1) {
+                // array of strings
+                for (const auto & el : *prompt) {
+                    if (!el.is_string()) {
+                        send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                        return false;
+                    }
+                }
+                slot.prompt = *prompt;
             } else {
-                send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
         }
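Note: after this hunk the slot loader accepts four "prompt" shapes: a single string, an array of token ids, a singleton array wrapping an id array, and (new here) an array of strings. A minimal illustration with nlohmann::json, values made up:

#include <nlohmann/json.hpp>
using json = nlohmann::json;

const json p1 = "lorem ipsum";                          // single string
const json p2 = json::array({1, 2, 3});                 // array of token ids
const json p3 = json::array({json::array({1, 2, 3})});  // singleton array of an id array
const json p4 = json::array({"query", "doc0", "doc1"}); // array of strings (new)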
@@ -1368,6 +1379,7 @@ struct server_context {

             res.data = json {
                 {"embedding", std::vector<float>(n_embd, 0.0f)},
+                {"index", slot.index},
             };

             continue;
@@ -1386,6 +1398,44 @@ struct server_context {
         queue_results.send(res);
     }

+    void send_rank(const server_slot & slot, const llama_batch & batch) {
+        server_task_result res;
+        res.id    = slot.id_task;
+        res.error = false;
+        res.stop  = true;
+
+        for (int i = 0; i < batch.n_tokens; ++i) {
+            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+                continue;
+            }
+
+            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            if (embd == NULL) {
+                embd = llama_get_embeddings_ith(ctx, i);
+            }
+
+            if (embd == NULL) {
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+
+                res.data = json {
+                    {"index", slot.index},
+                    {"rank",  -1e6},
+                };
+
+                continue;
+            }
+
+            res.data = json {
+                {"index", slot.index},
+                {"rank",  embd[0]},
+            };
+        }
+
+        SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str());
+
+        queue_results.send(res);
+    }
+
     //
     // Functions to create new task(s) and receive result(s)
     //
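Note on send_rank(): with rank pooling the back-end reduces each sequence to a single value from the model's classification head, which is why only embd[0] is read. A minimal standalone sketch of the same lookup against the public llama.h API, assuming a context created with LLAMA_POOLING_TYPE_RANK (the helper name is hypothetical, not part of the commit):

#include "llama.h"

// Hypothetical helper: read the relevance score for one decoded
// query/document sequence.
static float get_rank_score(llama_context * ctx, llama_seq_id seq_id) {
    const float * embd = llama_get_embeddings_seq(ctx, seq_id);
    if (embd == NULL) {
        return -1e6f; // same sentinel send_rank() reports on failure
    }
    return embd[0];  // RANK pooling yields a single score per sequence
}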
@@ -1421,13 +1471,23 @@ struct server_context {
         // otherwise, it's a multiple-prompt task, we break it into smaller tasks
         else if (prompt.is_array()) {
             std::vector<json> prompts = prompt;
-            for (size_t i = 0; i < prompts.size(); i++) {
-                const auto & e = prompts[i];
-                if (e.is_string() || json_is_array_of_numbers(e)) {
-                    data["index"] = i;
-                    create_task(data, true, e);
-                } else {
-                    throw std::runtime_error(error_msg);
+            if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                for (size_t i = 1; i < prompts.size(); i++) {
+                    json qd;
+                    qd.push_back(prompts[0]);
+                    qd.push_back(prompts[i]);
+                    data["index"] = i - 1;
+                    create_task(data, true, qd);
+                }
+            } else {
+                for (size_t i = 0; i < prompts.size(); i++) {
+                    const auto & e = prompts[i];
+                    if (e.is_string() || json_is_array_of_numbers(e)) {
+                        data["index"] = i;
+                        create_task(data, true, e);
+                    } else {
+                        throw std::runtime_error(error_msg);
+                    }
                 }
             }
         }
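In effect, a rerank prompt [query, doc0, ..., docN-1] fans out into N two-element tasks, with data["index"] = i - 1 keying each document's slot in the results array. A minimal trace of the branch above (assuming nlohmann::json):

// prompts = ["q", "d0", "d1"] produces:
//   index 0 -> task prompt ["q", "d0"]
//   index 1 -> task prompt ["q", "d1"]
json prompts = json::array({"q", "d0", "d1"});
for (size_t i = 1; i < prompts.size(); i++) {
    const json qd = json::array({prompts[0], prompts[i]});
    // create_task(data, true, qd); // with data["index"] = i - 1
}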
@@ -1471,7 +1531,9 @@ struct server_context {
                 break;
             }

-            size_t idx = result.data["index"];
+            const size_t idx = result.data["index"];
+            GGML_ASSERT(idx < results.size() && "index out of range");
+
             results[idx] = result;
         }
         result_handler(results);
@@ -1922,6 +1984,29 @@ struct server_context {
                     }

                     prompt_tokens = embd_inp;
+                } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                    // require slot.prompt to be array of 2 strings
+                    if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
+                        SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
+                        slot.release();
+                        send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
+                        continue;
+                    }
+
+                    // prompt: <s>query</s><s>doc</s>
+                    prompt_tokens.clear();
+                    prompt_tokens.push_back(llama_token_bos(model));
+                    {
+                        const auto part = tokenize(slot.prompt[0], false);
+                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                    }
+                    prompt_tokens.push_back(llama_token_eos(model));
+                    prompt_tokens.push_back(llama_token_bos(model));
+                    {
+                        const auto part = tokenize(slot.prompt[1], false);
+                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                    }
+                    prompt_tokens.push_back(llama_token_eos(model));
                 } else {
                     prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
                 }
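The BOS/EOS framing above follows the cross-encoder convention used by reranker models such as bge-reranker: query and document are each wrapped as <s>query</s><s>doc</s>. A hedged standalone sketch of the same layout via the public llama.h tokenizer, with a hypothetical helper name and simplified error handling:

#include "llama.h"

#include <string>
#include <vector>

static std::vector<llama_token> format_rerank_pair(const llama_model * model,
                                                   const std::string & query,
                                                   const std::string & doc) {
    auto tok = [&](const std::string & txt) {
        // with add_special = false the token count is bounded by the byte count
        std::vector<llama_token> out(txt.size() + 1);
        const int32_t n = llama_tokenize(model, txt.c_str(), (int32_t) txt.size(),
                                         out.data(), (int32_t) out.size(),
                                         /*add_special*/ false, /*parse_special*/ false);
        out.resize(n > 0 ? n : 0);
        return out;
    };

    std::vector<llama_token> result;
    result.push_back(llama_token_bos(model));    // <s>
    const auto q = tok(query);
    result.insert(result.end(), q.begin(), q.end());
    result.push_back(llama_token_eos(model));    // </s>

    result.push_back(llama_token_bos(model));    // <s>
    const auto d = tok(doc);
    result.insert(result.end(), d.begin(), d.end());
    result.push_back(llama_token_eos(model));    // </s>

    return result;
}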
@@ -1941,7 +2026,7 @@ struct server_context {
                         continue;
                     }

-                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                         // this prompt is too large to process - discard it
                         if (slot.n_prompt_tokens > n_ubatch) {
                             slot.release();
@@ -2011,15 +2096,18 @@ struct server_context {
                     slot.n_prompt_tokens_processed = 0;
                 }

-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                     // cannot fit the prompt in the current batch - will try next iter
                     if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                         continue;
                     }
                 }

                 // check that we are in the right batch_type, if not defer the slot
-                bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0;
+                const bool slot_type =
+                    slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
+                    slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+
                 if (batch_type == -1) {
                     batch_type = slot_type;
                 } else if (batch_type != slot_type) {
@@ -2192,6 +2280,13 @@ struct server_context {
                         continue; // continue loop of slots
                     }

+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                        send_rank(slot, batch_view);
+                        slot.release();
+                        slot.i_batch = -1;
+                        continue; // continue loop of slots
+                    }
+
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
                 } else if (slot.state != SLOT_STATE_GENERATING) {
@@ -2974,6 +3069,82 @@ int main(int argc, char ** argv) {
         res_ok(res, root);
     };

+    const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        const json body = json::parse(req.body);
+
+        // TODO: implement
+        //int top_n = 1;
+        //if (body.count("top_n") != 1) {
+        //    top_n = body.at("top_n");
+        //} else {
+        //    res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+        //    return;
+        //}
+
+        json query;
+        if (body.count("query") == 1) {
+            query = body.at("query");
+            if (!query.is_string()) {
+                res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        } else {
+            res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        json documents;
+        if (body.count("documents") != 0) {
+            documents = body.at("documents");
+            if (!documents.is_array() || documents.size() == 0) {
+                res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        } else {
+            res_error(res, format_error_response("\"documents\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        // construct prompt object: array of ["query", "doc0", "doc1", ...]
+        json prompt;
+        prompt.push_back(query);
+        for (const auto & doc : documents) {
+            prompt.push_back(doc);
+        }
+
+        LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+
+        // create and queue the task
+        json responses = json::array();
+        bool error = false;
+        {
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(tasks);
+
+            // get the result
+            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+                for (const auto & res : results) {
+                    responses.push_back(res.data);
+                }
+            }, [&](const json & error_data) {
+                res_error(res, error_data);
+                error = true;
+            });
+        }
+
+        if (error) {
+            return;
+        }
+
+        // write JSON response
+        json root = format_response_rerank(body, responses);
+        res_ok(res, root);
+    };
+

     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
         for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
@@ -3070,6 +3241,7 @@ int main(int argc, char ** argv) {
     svr->Post("/embedding",     handle_embeddings); // legacy
     svr->Post("/embeddings",    handle_embeddings);
     svr->Post("/v1/embeddings", handle_embeddings);
+    svr->Post("/v1/rerank",     handle_rerank);
     svr->Post("/tokenize",      handle_tokenize);
     svr->Post("/detokenize",    handle_detokenize);
     // LoRA adapters hotswap
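A hedged usage sketch (not part of the commit): exercising the new endpoint with cpp-httplib, which the server already vendors. Host and example strings are illustrative; the field names ("query", "documents", "results", "relevance_score") are the ones this commit introduces.

#include <httplib.h>
#include <nlohmann/json.hpp>

#include <cstdio>

using json = nlohmann::json;

int main() {
    // assumes a running llama-server loaded with a reranker model
    httplib::Client cli("http://localhost:8080");

    const json body = {
        {"query",     "what is panda?"},
        {"documents", {"hi", "The giant panda is a bear species endemic to China."}},
    };

    const auto r = cli.Post("/v1/rerank", body.dump(), "application/json");
    if (!r || r->status != 200) {
        fprintf(stderr, "request failed\n");
        return 1;
    }

    // each result carries the document index and its relevance score
    for (const auto & item : json::parse(r->body).at("results")) {
        printf("doc %d -> relevance_score %f\n",
               item.at("index").get<int>(),
               item.at("relevance_score").get<double>());
    }
    return 0;
}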

examples/server/utils.hpp (+24, −1)
@@ -534,7 +534,7 @@ static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json {
+        {"usage", json { // TODO: fill
             {"prompt_tokens", 0},
             {"total_tokens", 0}
         }},
@@ -544,6 +544,29 @@ static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
     return res;
 }

+static json format_response_rerank(const json & request, const json & ranks) {
+    json data = json::array();
+    int i = 0;
+    for (const auto & rank : ranks) {
+        data.push_back(json{
+            {"index", i++},
+            {"relevance_score", json_value(rank, "rank", 0.0)},
+        });
+    }
+
+    json res = json {
+        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"object", "list"},
+        {"usage", json { // TODO: fill
+            {"prompt_tokens", 0},
+            {"total_tokens", 0}
+        }},
+        {"results", data}
+    };
+
+    return res;
+}
+
 static bool is_valid_utf8(const std::string & str) {
     const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
     const unsigned char* end = bytes + str.length();
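Note: the body format_response_rerank() builds resembles Jina-style rerank APIs ("results" entries carrying "index" and "relevance_score"). A hedged example of a full response for two documents, all values illustrative, expressed with nlohmann::json to mirror the formatter:

// Illustrative only: scores and model name are made up.
const json example_response = {
    {"model",  "some-reranker.gguf"},
    {"object", "list"},
    {"usage",  {{"prompt_tokens", 0}, {"total_tokens", 0}}}, // TODO in commit: not yet filled
    {"results", json::array({
        {{"index", 0}, {"relevance_score", -4.2}},
        {{"index", 1}, {"relevance_score",  6.8}},
    })},
};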
