
Commit ca190bc

server : re-enable completion and embedded at the same time (#3876)
1 parent 71e3718 commit ca190bc

2 files changed: +11, -6 lines changed

Diff for: .gitignore (+1)

@@ -15,6 +15,7 @@
 .DS_Store
 .build/
 .cache/
+.ccls-cache/
 .direnv/
 .envrc
 .swiftpm

Diff for: examples/server/server.cpp (+10, -6)

@@ -149,6 +149,7 @@ struct task_server {
     task_type type;
     json data;
     bool infill_mode = false;
+    bool embedding_mode = false;
 };
 
 struct task_result {

@@ -371,6 +372,7 @@ struct llama_client_slot
     std::vector<completion_token_output> generated_token_probs;
 
     bool infill = false;
+    bool embedding = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;

@@ -1244,13 +1246,14 @@ struct llama_server_context
         queue_results.push_back(res);
     }
 
-    int request_completion(json data, bool infill)
+    int request_completion(json data, bool infill, bool embedding)
     {
         std::lock_guard<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.data = data;
         task.infill_mode = infill;
+        task.embedding_mode = embedding;
         task.type = COMPLETION_TASK;
         queue_tasks.push_back(task);
         return task.id;
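
The new argument travels with the queued work item: request_completion() now records the caller's embedding flag on the task_server before pushing it onto the queue, just as it already did for infill_mode. A minimal, self-contained sketch of that enqueue pattern, using hypothetical names trimmed down from the real structs:

```cpp
// Sketch only: every HTTP request enqueues a task that carries its own mode
// flags, so completion and embedding requests can share one queue and one
// worker loop. Names are illustrative, not the server's real types.
#include <mutex>
#include <vector>

struct task {
    int  id             = 0;
    bool infill_mode    = false;
    bool embedding_mode = false;
};

struct task_queue {
    std::mutex        mutex_tasks;
    std::vector<task> queue_tasks;
    int               id_gen = 0;

    // mirrors request_completion(): lock, record the per-request flags, enqueue
    int request(bool infill, bool embedding) {
        std::lock_guard<std::mutex> lock(mutex_tasks);
        task t;
        t.id             = id_gen++;
        t.infill_mode    = infill;
        t.embedding_mode = embedding;
        queue_tasks.push_back(t);
        return t.id; // the handler later polls for the result by this id
    }
};
```

Keeping the flag on the task, rather than in the process-wide params, is what lets requests of both kinds sit in the same queue at the same time.
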
@@ -1376,7 +1379,7 @@ struct llama_server_context
                 {
                     LOG_TEE("slot unavailable\n");
                     // send error result
-                    send_error(task.id, "slot unavaliable");
+                    send_error(task.id, "slot unavailable");
                     return;
                 }
 

@@ -1388,6 +1391,7 @@ struct llama_server_context
                 slot->reset();
 
                 slot->infill = task.infill_mode;
+                slot->embedding = task.embedding_mode;
                 slot->task_id = task.id;
 
                 if (!launch_slot_with_data(slot, task.data))

@@ -1695,7 +1699,7 @@ struct llama_server_context
                 }
 
                 // prompt evaluated for embedding
-                if (params.embedding)
+                if (slot.embedding)
                 {
                     send_embedding(slot);
                     slot.release();
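
This hunk is the behavioral core of the commit. Previously the branch tested the process-wide params.embedding setting, so a server configured for embeddings answered every request with an embedding and one configured without it never produced any; now each slot carries the flag of the request it is serving (copied from the task in the hunk above), and the decision is made per request. A sketch of that worker-side dispatch, again with illustrative names rather than the server's real code:

```cpp
// Sketch only: the worker copies the per-task flags onto the slot it assigns,
// then dispatches on the slot's own flag once the prompt has been evaluated.
// send_embedding() and token generation are stubbed out here.
struct queued_task {
    bool infill_mode    = false;
    bool embedding_mode = false;
};

struct client_slot {
    bool infill    = false;
    bool embedding = false;
    bool released  = false;
};

// mirrors "slot->embedding = task.embedding_mode;" from the hunk above
void assign(client_slot & slot, const queued_task & t) {
    slot.infill    = t.infill_mode;
    slot.embedding = t.embedding_mode;
}

// mirrors the "if (slot.embedding)" branch: an embedding request returns its
// vector and frees the slot; anything else keeps generating tokens.
void on_prompt_evaluated(client_slot & slot) {
    if (slot.embedding) {
        // send_embedding(slot);   // stub: hand back the pooled embedding
        slot.released = true;      // the real code calls slot.release()
    } else {
        // fall through to token generation for completion / infill requests
    }
}
```
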
@@ -2274,7 +2278,7 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false);
+                const int task_id = llama.request_completion(data, false, false);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);

@@ -2329,7 +2333,7 @@ int main(int argc, char **argv)
     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true);
+                const int task_id = llama.request_completion(data, true, false);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);

@@ -2433,7 +2437,7 @@ int main(int argc, char **argv)
                 {
                     prompt = "";
                 }
-                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false);
+                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true);
                 task_result result = llama.next_result(task_id);
                 return res.set_content(result.result_json.dump(), "application/json");
             });
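
Taken together, the handlers now choose the flag pair themselves when they enqueue work; the last hunk sits in what appears to be the /embedding handler, which submits a prompt-only request with n_predict set to 0 so nothing is generated. A small sketch of the mapping (the endpoint enum and flags_for() are hypothetical; the boolean pairs are the ones passed in the hunks above):

```cpp
// Sketch only: which (infill, embedding) pair each endpoint passes to
// request_completion(data, infill, embedding) after this commit.
#include <utility>

enum class endpoint { completion, infill, embedding }; // hypothetical tag

std::pair<bool, bool> flags_for(endpoint ep) {
    switch (ep) {
        case endpoint::completion: return { false, false }; // POST /completion
        case endpoint::infill:     return { true,  false }; // POST /infill
        case endpoint::embedding:  return { false, true  }; // POST /embedding, n_predict = 0
    }
    return { false, false }; // unreachable; keeps compilers quiet
}
```

Because the flag is carried per request rather than per process, one running server instance can keep answering /completion and /infill while also serving /embedding, which is what the commit title refers to.
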
