@@ -149,6 +149,7 @@ struct task_server {
    task_type type;
    json data;
    bool infill_mode = false;
+   bool embedding_mode = false;
};

struct task_result {
@@ -371,6 +372,7 @@ struct llama_client_slot
    std::vector<completion_token_output> generated_token_probs;

    bool infill = false;
+   bool embedding = false;
    bool has_next_token = true;
    bool truncated = false;
    bool stopped_eos = false;
@@ -1244,13 +1246,14 @@ struct llama_server_context
        queue_results.push_back(res);
    }

-   int request_completion(json data, bool infill)
+   int request_completion(json data, bool infill, bool embedding)
    {
        std::lock_guard<std::mutex> lock(mutex_tasks);
        task_server task;
        task.id = id_gen++;
        task.data = data;
        task.infill_mode = infill;
+       task.embedding_mode = embedding;
        task.type = COMPLETION_TASK;
        queue_tasks.push_back(task);
        return task.id;
@@ -1376,7 +1379,7 @@ struct llama_server_context
                    {
                        LOG_TEE("slot unavailable\n");
                        // send error result
-                       send_error(task.id, "slot unavaliable");
+                       send_error(task.id, "slot unavailable");
                        return;
                    }

@@ -1388,6 +1391,7 @@ struct llama_server_context
                    slot->reset();

                    slot->infill = task.infill_mode;
+                   slot->embedding = task.embedding_mode;
                    slot->task_id = task.id;

                    if (!launch_slot_with_data(slot, task.data))
@@ -1695,7 +1699,7 @@ struct llama_server_context
                }

                // prompt evaluated for embedding
-               if (params.embedding)
+               if (slot.embedding)
                {
                    send_embedding(slot);
                    slot.release();
@@ -2274,7 +2278,7 @@ int main(int argc, char **argv)
    svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
            {
                json data = json::parse(req.body);
-               const int task_id = llama.request_completion(data, false);
+               const int task_id = llama.request_completion(data, false, false);
                if (!json_value(data, "stream", false)) {
                    std::string completion_text;
                    task_result result = llama.next_result(task_id);
@@ -2329,7 +2333,7 @@ int main(int argc, char **argv)
    svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
            {
                json data = json::parse(req.body);
-               const int task_id = llama.request_completion(data, true);
+               const int task_id = llama.request_completion(data, true, false);
                if (!json_value(data, "stream", false)) {
                    std::string completion_text;
                    task_result result = llama.next_result(task_id);
@@ -2433,7 +2437,7 @@ int main(int argc, char **argv)
                {
                    prompt = "";
                }
-               const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0 } }, false);
+               const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0 } }, false, true);
                task_result result = llama.next_result(task_id);
                return res.set_content(result.result_json.dump(), "application/json");
            });
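Taken together, these hunks thread a per-request embedding flag from request_completion through task_server into llama_client_slot, so the embedding branch now checks slot.embedding instead of the global params.embedding set at launch (presumably via --embedding). Below is a minimal client-side sketch of how both request paths could be exercised against one running server instance; the host/port, the /embedding route name, and its "content" body field are not visible in the hunks above and are assumptions based on the server example's defaults, not part of the change itself.

// Hedged usage sketch (not part of the patch): post a completion request and
// an embedding request to the same server process. Assumes the server was
// started with default host/port and that the embedding handler reads its
// prompt from a "content" field; adjust to the actual server configuration.
#include "httplib.h"   // cpp-httplib, bundled with the server example
#include <iostream>

int main() {
    httplib::Client cli("127.0.0.1", 8080);

    // Handled via request_completion(data, /*infill=*/false, /*embedding=*/false)
    auto completion = cli.Post("/completion",
                               R"({"prompt": "Hello, my name is", "n_predict": 16})",
                               "application/json");
    if (completion) {
        std::cout << completion->body << std::endl;
    }

    // Handled via request_completion(..., /*infill=*/false, /*embedding=*/true);
    // with the per-slot flag this no longer depends on launching a separate
    // --embedding server.
    auto embedding = cli.Post("/embedding",
                              R"({"content": "Hello, my name is"})",
                              "application/json");
    if (embedding) {
        std::cout << embedding->body << std::endl;
    }
    return 0;
}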