Skip to content

Commit 99b71c0

Browse files
authored
Server: Use multi-task for embeddings endpoint (#6001)
* use multitask for embd endpoint
* specify types
* remove redundant {"n_predict", 0}
1 parent 306d34b commit 99b71c0

File tree

2 files changed

+38
-50
lines changed

2 files changed

+38
-50
lines changed

examples/server/server.cpp

+27-49
Original file line number | Diff line number | Diff line change
@@ -2763,6 +2763,7 @@ int main(int argc, char ** argv) {
27632763
res.set_header("Access-Control-Allow-Credentials", "true");
27642764
res.set_header("Access-Control-Allow-Methods", "POST");
27652765
res.set_header("Access-Control-Allow-Headers", "*");
2766+
return res.set_content("", "application/json; charset=utf-8");
27662767
});
27672768

27682769
svr->set_logger(log_server_request);
@@ -3371,44 +3372,37 @@ int main(int argc, char ** argv) {
33713372
const json body = json::parse(req.body);
33723373
bool is_openai = false;
33733374

3374-
// an input prompt can string or a list of tokens (integer)
3375-
std::vector<json> prompts;
3375+
// an input prompt can be a string or a list of tokens (integer)
3376+
json prompt;
33763377
if (body.count("input") != 0) {
33773378
is_openai = true;
3378-
if (body["input"].is_array()) {
3379-
// support multiple prompts
3380-
for (const json & elem : body["input"]) {
3381-
prompts.push_back(elem);
3382-
}
3383-
} else {
3384-
// single input prompt
3385-
prompts.push_back(body["input"]);
3386-
}
3379+
prompt = body["input"];
33873380
} else if (body.count("content") != 0) {
3388-
// only support single prompt here
3389-
std::string content = body["content"];
3390-
prompts.push_back(content);
3381+
// with "content", we only support single prompt
3382+
prompt = std::vector<std::string>{body["content"]};
33913383
} else {
33923384
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
33933385
return;
33943386
}
33953387

3396-
// process all prompts
3397-
json responses = json::array();
3398-
for (auto & prompt : prompts) {
3399-
// TODO @ngxson : maybe support multitask for this endpoint?
3400-
// create and queue the task
3388+
// create and queue the task
3389+
json responses;
3390+
{
34013391
const int id_task = ctx_server.queue_tasks.get_new_id();
3402-
34033392
ctx_server.queue_results.add_waiting_task_id(id_task);
3404-
ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
3393+
ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
34053394

34063395
// get the result
34073396
server_task_result result = ctx_server.queue_results.recv(id_task);
34083397
ctx_server.queue_results.remove_waiting_task_id(id_task);
34093398
if (!result.error) {
3410-
// append to the responses
3411-
responses.push_back(result.data);
3399+
if (result.data.count("results")) {
3400+
// result for multi-task
3401+
responses = result.data["results"];
3402+
} else {
3403+
// result for single task
3404+
responses = std::vector<json>{result.data};
3405+
}
34123406
} else {
34133407
// error received, ignore everything else
34143408
res_error(res, result.data);
@@ -3417,24 +3411,19 @@ int main(int argc, char ** argv) {
34173411
}
34183412

34193413
// write JSON response
3420-
json root;
3421-
if (is_openai) {
3422-
json res_oai = json::array();
3423-
int i = 0;
3424-
for (auto & elem : responses) {
3425-
res_oai.push_back(json{
3426-
{"embedding", json_value(elem, "embedding", json::array())},
3427-
{"index", i++},
3428-
{"object", "embedding"}
3429-
});
3430-
}
3431-
root = format_embeddings_response_oaicompat(body, res_oai);
3432-
} else {
3433-
root = responses[0];
3434-
}
3414+
json root = is_openai
3415+
? format_embeddings_response_oaicompat(body, responses)
3416+
: responses[0];
34353417
return res.set_content(root.dump(), "application/json; charset=utf-8");
34363418
};
34373419

3420+
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
3421+
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
3422+
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
3423+
return false;
3424+
};
3425+
};
3426+
34383427
//
34393428
// Router
34403429
//
@@ -3446,17 +3435,6 @@ int main(int argc, char ** argv) {
34463435
}
34473436

34483437
// using embedded static files
3449-
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
3450-
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
3451-
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
3452-
return false;
3453-
};
3454-
};
3455-
3456-
svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
3457-
// TODO @ngxson : I have no idea what it is... maybe this is redundant?
3458-
return res.set_content("", "application/json; charset=utf-8");
3459-
});
34603438
svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
34613439
svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
34623440
svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));

examples/server/utils.hpp

+11-1
Original file line number | Diff line number | Diff line change
@@ -529,14 +529,24 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
529529
}
530530

531531
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
532+
json data = json::array();
533+
int i = 0;
534+
for (auto & elem : embeddings) {
535+
data.push_back(json{
536+
{"embedding", json_value(elem, "embedding", json::array())},
537+
{"index", i++},
538+
{"object", "embedding"}
539+
});
540+
}
541+
532542
json res = json {
533543
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
534544
{"object", "list"},
535545
{"usage", json {
536546
{"prompt_tokens", 0},
537547
{"total_tokens", 0}
538548
}},
539-
{"data", embeddings}
549+
{"data", data}
540550
};
541551

542552
return res;

0 commit comments

Comments
 (0)