
Commit 8ffcd4e

a-holexiyb authored and committed
server : re-enable completion and embedded at the same time (ggml-org#3876)
1 parent c80994c commit 8ffcd4e

2 files changed: +11, −6 lines


Diff for: .gitignore (+1)
@@ -15,6 +15,7 @@
 .DS_Store
 .build/
 .cache/
+.ccls-cache/
 .direnv/
 .envrc
 .swiftpm

Diff for: examples/server/server.cpp (+10, −6)
@@ -149,6 +149,7 @@ struct task_server {
     task_type type;
     json data;
     bool infill_mode = false;
+    bool embedding_mode = false;
 };
 
 struct task_result {
@@ -371,6 +372,7 @@ struct llama_client_slot
     std::vector<completion_token_output> generated_token_probs;
 
     bool infill = false;
+    bool embedding = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -1244,13 +1246,14 @@ struct llama_server_context
         queue_results.push_back(res);
     }
 
-    int request_completion(json data, bool infill)
+    int request_completion(json data, bool infill, bool embedding)
     {
         std::lock_guard<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.data = data;
         task.infill_mode = infill;
+        task.embedding_mode = embedding;
         task.type = COMPLETION_TASK;
         queue_tasks.push_back(task);
         return task.id;
@@ -1376,7 +1379,7 @@ struct llama_server_context
            {
                LOG_TEE("slot unavailable\n");
                // send error result
-               send_error(task.id, "slot unavaliable");
+               send_error(task.id, "slot unavailable");
                return;
            }
 
@@ -1388,6 +1391,7 @@ struct llama_server_context
            slot->reset();
 
            slot->infill = task.infill_mode;
+           slot->embedding = task.embedding_mode;
            slot->task_id = task.id;
 
            if (!launch_slot_with_data(slot, task.data))
@@ -1695,7 +1699,7 @@ struct llama_server_context
                }
 
                // prompt evaluated for embedding
-               if (params.embedding)
+               if (slot.embedding)
                {
                    send_embedding(slot);
                    slot.release();
@@ -2277,7 +2281,7 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false);
+                const int task_id = llama.request_completion(data, false, false);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2332,7 +2336,7 @@ int main(int argc, char **argv)
     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true);
+                const int task_id = llama.request_completion(data, true, false);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2436,7 +2440,7 @@ int main(int argc, char **argv)
                 {
                     prompt = "";
                 }
-                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false);
+                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true);
                 task_result result = llama.next_result(task_id);
                 return res.set_content(result.result_json.dump(), "application/json");
             });
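
For context, the change threads a per-request embedding flag from the HTTP handlers through request_completion into the slot, so the embedding path is decided per task rather than by the server-wide params.embedding switch. The sketch below is a minimal client-side illustration of what that enables, not part of the commit: it assumes cpp-httplib (which server.cpp already uses), a server listening on localhost:8080, and an /embedding endpoint whose request fields ("prompt", "n_predict", "content") are inferred from the handler code shown above.

// Minimal client sketch (illustration only, not part of this commit).
// Assumes a llama.cpp server running on localhost:8080 and cpp-httplib.
#include <iostream>
#include "httplib.h"

int main() {
    httplib::Client cli("localhost", 8080);

    // Completion request: routed through request_completion(data, false, false).
    auto completion = cli.Post("/completion",
                               R"({"prompt": "Hello", "n_predict": 8})",
                               "application/json");
    if (completion) {
        std::cout << completion->body << "\n";
    }

    // Embedding request against the same running server: routed through
    // request_completion({...}, false, true), so the check in update_slots
    // now reads slot.embedding instead of the global params.embedding.
    auto embedding = cli.Post("/embedding",
                              R"({"content": "Hello"})",
                              "application/json");
    if (embedding) {
        std::cout << embedding->body << "\n";
    }
    return 0;
}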
