@@ -149,6 +149,7 @@ struct task_server {
     task_type type;
     json data;
     bool infill_mode = false;
+    bool embedding_mode = false;
 };

 struct task_result {
@@ -371,6 +372,7 @@ struct llama_client_slot
     std::vector<completion_token_output> generated_token_probs;

     bool infill = false;
+    bool embedding = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -1244,13 +1246,14 @@ struct llama_server_context
         queue_results.push_back(res);
     }

-    int request_completion(json data, bool infill)
+    int request_completion(json data, bool infill, bool embedding)
     {
         std::lock_guard<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.data = data;
         task.infill_mode = infill;
+        task.embedding_mode = embedding;
         task.type = COMPLETION_TASK;
         queue_tasks.push_back(task);
         return task.id;
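
With the extra flag, an embedding request rides through the same task queue as completion and infill requests instead of depending on the process-wide params.embedding switch. The call pattern expected of callers mirrors the HTTP handlers further down; a minimal sketch, where `data` is a placeholder request body and `llama` is the llama_server_context instance owned by main():

    // build a request body and enqueue it as an embedding task
    json data = { {"prompt", "Hello"}, {"n_predict", 0} };
    const int task_id = llama.request_completion(data, /*infill=*/false, /*embedding=*/true);
    task_result result = llama.next_result(task_id);  // blocks until a slot queues the result
    // result.result_json carries the embedding (or completion) payload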
@@ -1376,7 +1379,7 @@ struct llama_server_context
             {
                 LOG_TEE("slot unavailable\n");
                 // send error result
-                send_error(task.id, "slot unavaliable");
+                send_error(task.id, "slot unavailable");
                 return;
             }

@@ -1388,6 +1391,7 @@ struct llama_server_context
                 slot->reset();

                 slot->infill = task.infill_mode;
+                slot->embedding = task.embedding_mode;
                 slot->task_id = task.id;

                 if (!launch_slot_with_data(slot, task.data))
@@ -1695,7 +1699,7 @@ struct llama_server_context
                 }

                 // prompt evaluated for embedding
-                if (params.embedding)
+                if (slot.embedding)
                 {
                     send_embedding(slot);
                     slot.release();
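
Branching on slot.embedding rather than params.embedding is what lets one server instance answer /completion, /infill and /embedding requests side by side. send_embedding() itself is untouched by this commit and not shown here; as a rough sketch of what such a helper has to do, assuming the queue_results plumbing visible above, a mutex_results counterpart to mutex_tasks, and the public llama.cpp API calls llama_n_embd() / llama_get_embeddings():

    // Hypothetical sketch, not the code from server.cpp: copy the pooled
    // embedding out of the context and queue it as this slot's final result.
    void send_embedding(llama_client_slot & slot)
    {
        std::lock_guard<std::mutex> lock(mutex_results); // assumed results mutex
        task_result res;
        res.id    = slot.task_id;
        res.error = false;
        res.stop  = true;

        const int     n_embd = llama_n_embd(model);        // embedding size of the loaded model
        const float * emb    = llama_get_embeddings(ctx);  // pooled embedding of the evaluated prompt
        res.result_json = json{ {"embedding", std::vector<float>(emb, emb + n_embd)} };

        queue_results.push_back(res);
    }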
@@ -2277,7 +2281,7 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false);
+                const int task_id = llama.request_completion(data, false, false);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2332,7 +2336,7 @@ int main(int argc, char **argv)
     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true);
+                const int task_id = llama.request_completion(data, true, false);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2436,7 +2440,7 @@ int main(int argc, char **argv)
                 {
                     prompt = "";
                 }
-                const int task_id = llama.request_completion({ {"prompt", prompt}, {"n_predict", 0} }, false);
+                const int task_id = llama.request_completion({ {"prompt", prompt}, {"n_predict", 0} }, false, true);
                 task_result result = llama.next_result(task_id);
                 return res.set_content(result.result_json.dump(), "application/json");
             });
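
The handler above returns result.result_json verbatim. A hedged client-side sketch, not part of the commit, using the same cpp-httplib and nlohmann::json headers the server already depends on; the "content" request field, the default port 8080 and the "embedding" key in the response are assumptions about handler code that is not visible in this hunk:

    #include <cstdio>
    #include "httplib.h"
    #include "json.hpp"

    int main()
    {
        httplib::Client cli("localhost", 8080);  // assumed default server port
        const nlohmann::json body = { {"content", "Hello, world!"} };

        auto res = cli.Post("/embedding", body.dump(), "application/json");
        if (res && res->status == 200)
        {
            const auto parsed = nlohmann::json::parse(res->body);
            // assumed response shape: {"embedding": [ ... floats ... ]}
            std::printf("embedding dimensions: %zu\n", parsed["embedding"].size());
        }
        return 0;
    }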