
Commit 0161372

parallel : example for serving multiple users in parallel
1 parent 1f17ea6 commit 0161372

File tree

9 files changed: +262 additions, -13 deletions


common/common.cpp

+4-4
@@ -454,8 +454,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             if (params.logdir.back() != DIRECTORY_SEPARATOR) {
                 params.logdir += DIRECTORY_SEPARATOR;
             }
-        } else if (arg == "--perplexity") {
-            params.perplexity = true;
+        } else if (arg == "--perplexity" || arg == "--all-logits") {
+            params.logits_all = true;
         } else if (arg == "--ppl-stride") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -653,7 +653,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
     printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
     printf(" --temp N temperature (default: %.1f)\n", (double)params.temp);
-    printf(" --perplexity compute perplexity over each ctx window of the prompt\n");
+    printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
     printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
     printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
     printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
@@ -735,7 +735,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     lparams.f16_kv = params.memory_f16;
     lparams.use_mmap = params.use_mmap;
     lparams.use_mlock = params.use_mlock;
-    lparams.logits_all = params.perplexity;
+    lparams.logits_all = params.logits_all;
     lparams.embedding = params.embedding;
     lparams.rope_freq_base = params.rope_freq_base;
     lparams.rope_freq_scale = params.rope_freq_scale;

common/common.h

+1-1
@@ -113,7 +113,7 @@ struct gpt_params {
     bool ignore_eos = false; // ignore generated EOS tokens
     bool instruct = false; // instruction mode (used for Alpaca models)
     bool penalize_nl = true; // consider newlines as a repeatable token
-    bool perplexity = false; // compute perplexity over the prompt
+    bool logits_all = false; // return logits for all tokens in the batch
     bool use_mmap = true; // use mmap for faster loads
     bool use_mlock = false; // use mlock to keep model in memory
     bool numa = false; // attempt optimizations that help on some NUMA systems
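
With the flag renamed from perplexity to logits_all, a caller that needs logits for every token of an evaluated batch (rather than only the last one) now requests this explicitly before the context is created. Below is a minimal, hedged sketch of that pattern, assuming only the common.h helpers shown in this commit; the function name run_with_all_logits is illustrative and not part of the codebase.

// sketch: request logits for all tokens before creating the context
#include "common.h"
#include "llama.h"

#include <tuple>

static int run_with_all_logits(int argc, char ** argv) {
    gpt_params params;
    if (gpt_params_parse(argc, argv, params) == false) {
        return 1;
    }

    params.logits_all = true; // formerly: params.perplexity = true

    llama_backend_init(params.numa);

    llama_model * model = NULL;
    llama_context * ctx = NULL;
    std::tie(model, ctx) = llama_init_from_gpt_params(params);

    // ... evaluate a batch here, then read per-token logits via llama_get_logits(ctx) ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();

    return 0;
}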

examples/CMakeLists.txt

+1
@@ -24,6 +24,7 @@ else()
     add_subdirectory(convert-llama2c-to-ggml)
     add_subdirectory(simple)
     add_subdirectory(speculative)
+    add_subdirectory(parallel)
     add_subdirectory(embd-input)
     add_subdirectory(llama-bench)
     add_subdirectory(beam-search)

examples/main/main.cpp

+1-1
@@ -124,7 +124,7 @@ int main(int argc, char ** argv) {
     console::init(params.simple_io, params.use_color);
     atexit([]() { console::cleanup(); });

-    if (params.perplexity) {
+    if (params.logits_all) {
         printf("\n************\n");
         printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
         printf("************\n\n");

examples/parallel/CMakeLists.txt

+8
@@ -0,0 +1,8 @@
+set(TARGET parallel)
+add_executable(${TARGET} parallel.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+  add_dependencies(${TARGET} BUILD_INFO)
+endif()

examples/parallel/parallel.cpp

+244
@@ -0,0 +1,244 @@
+// A basic application simulating a server with multiple clients.
+// The clients submit requests to the server and they are processed in parallel.
+
+#include "build-info.h"
+
+#include "common.h"
+#include "llama.h"
+
+#include <cmath>
+#include <cstdio>
+#include <string>
+#include <vector>
+
+// trim whitespace from the beginning and end of a string
+static std::string trim(const std::string & str) {
+    size_t start = 0;
+    size_t end = str.size();
+
+    while (start < end && isspace(str[start])) {
+        start += 1;
+    }
+
+    while (end > start && isspace(str[end - 1])) {
+        end -= 1;
+    }
+
+    return str.substr(start, end - start);
+}
+
+static std::string k_system = R"(
+Transcript of a dialog, where the User interacts with an Assistant.
+The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
+
+User: Hello, what is the temperature outside?
+Assistant: It is 72 degrees Fahrenheit.
+User: What is the definition of a prime number?
+Assistant: A prime number is a number that is divisible only by itself and 1.
+User: )";
+
+static std::vector<std::string> k_prompts = {
+    "What is the meaning of life?",
+    "What is the population of Europe?",
+    "List all planets in the Solar System.",
+    "What is the capital of France?",
+    "Tell me an interesting fact about llamas.",
+    "What is the best way to cook a steak?",
+    "Are you familiar with the Special Theory of Relativity and can you explain it to me?",
+    "Recommend some interesting books to read.",
+    "What is the best way to learn a new language?",
+    "How to get a job at Google?",
+    "If you could have any superpower, what would it be?",
+    "I want to learn how to play the piano.",
+};
+
+struct client {
+    int32_t id = 0;
+
+    llama_seq_id seq_id = -1;
+
+    llama_token sampled;
+
+    int32_t n_prompt = 0;
+    int32_t n_decoded = 0;
+    int32_t i_batch = -1;
+
+    std::string input;
+    std::string prompt;
+    std::string response;
+
+    std::vector<llama_token> last_tokens;
+};
+
+int main(int argc, char ** argv) {
+    gpt_params params;
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    const int n_clients = 16;
+
+#ifndef LOG_DISABLE_LOGS
+    log_set_target(log_filename_generator("parallel", "log"));
+    LOG_TEE("Log start\n");
+    log_dump_cmdline(argc, argv);
+#endif // LOG_DISABLE_LOGS
+
+    // init llama.cpp
+    llama_backend_init(params.numa);
+
+    llama_model * model = NULL;
+
+    llama_context * ctx = NULL;
+
+    // load the target model
+    params.logits_all = true;
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+
+    fprintf(stderr, "\n\n");
+    fflush(stderr);
+
+    const int n_ctx = llama_n_ctx(ctx);
+    const int n_vocab = llama_n_vocab(ctx);
+
+    std::vector<client> clients(n_clients);
+    for (size_t i = 0; i < clients.size(); ++i) {
+        auto & client = clients[i];
+        client.id = i;
+        client.last_tokens.resize(n_ctx);
+        std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
+    }
+
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+
+    auto t_main_start = ggml_time_us();
+
+    int64_t n_tokens_total = 0;
+
+    llama_seq_id g_seq_id = 0;
+
+    std::vector<llama_token> batch_token;
+    std::vector<llama_pos> batch_pos;
+    std::vector<llama_seq_id> batch_seq_id;
+    std::vector<client *> batch_clients;
+
+    while (true) {
+        uint32_t n_tokens = 0;
+
+        batch_token.clear();
+        batch_pos.clear();
+        batch_seq_id.clear();
+
+        for (auto & client : clients) {
+            if (client.seq_id == -1) {
+                client.seq_id = g_seq_id;
+                client.input = k_prompts[rand() % k_prompts.size()];
+                client.prompt = k_system + client.input + "\nAssistant:";
+                client.response = "";
+                std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
+
+                std::vector<llama_token> prompt_tokens;
+                prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
+
+                for (size_t i = 0; i < prompt_tokens.size(); ++i) {
+                    batch_token.push_back(prompt_tokens[i]);
+                    batch_pos.push_back(i);
+                    batch_seq_id.push_back(client.seq_id);
+                    batch_clients.push_back(&client);
+                }
+                client.n_prompt = prompt_tokens.size();
+                client.n_decoded = prompt_tokens.size();
+                client.i_batch = batch_token.size() - 1;
+
+                g_seq_id += 1;
+            } else {
+                batch_token.push_back(client.sampled);
+                batch_pos.push_back(client.n_decoded);
+                batch_seq_id.push_back(client.seq_id);
+                batch_clients.push_back(&client);
+                client.n_decoded += 1;
+                client.i_batch = batch_token.size() - 1;
+            }
+        }
+
+        // process in chunks of params.n_batch
+        for (size_t i = 0; i < batch_token.size(); i += params.n_batch) {
+            n_tokens = std::min(params.n_batch, (int32_t) (batch_token.size() - i));
+
+            llama_batch batch = {
+                n_tokens,
+                batch_token.data() + i,
+                nullptr,
+                batch_pos.data() + i,
+                batch_seq_id.data() + i,
+                0, 0, 0, // unused
+            };
+
+            if (llama_decode(ctx, batch, params.n_threads)) {
+                LOG_TEE("%s : failed to decode batch\n", __func__);
+                return 1;
+            }
+
+            for (auto & client : clients) {
+                if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
+                    continue;
+                }
+
+                const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
+
+                // remember which tokens were sampled - used for repetition penalties during sampling
+                client.last_tokens.erase(client.last_tokens.begin());
+                client.last_tokens.push_back(id);
+
+                const std::string token_str = llama_token_to_piece(ctx, id);
+                client.response += token_str;
+                client.sampled = id;
+
+                //printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
+                //        client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
+
+                if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || client.response.find("User:") != std::string::npos) {
+                    const size_t pos = client.response.find("User:");
+                    if (pos != std::string::npos) {
+                        client.response = client.response.substr(0, pos);
+                    }
+
+                    llama_kv_cache_rm_seq(ctx, client.seq_id, 0, n_ctx);
+
+                    const auto t_main_end = ggml_time_us();
+
+                    n_tokens_total += client.n_decoded - client.n_prompt;
+
+                    printf("\033[1mClient %d, seq %d, prompt %d t, response %d t, speed: %.2f t/s\033[0m: \n\nInput: %s\nResponse: %s\n\n",
+                            client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
+                            (double) n_tokens_total / (t_main_end - t_main_start) * 1e6,
+                            client.input.c_str(), ::trim(client.response).c_str());
+
+                    client.seq_id = -1;
+                }
+            }
+        }
+
+        static bool is_first = true;
+        if (is_first) {
+            t_main_start = ggml_time_us();
+            n_tokens_total = 0;
+            is_first = false;
+        }
+    }
+
+    LOG_TEE("\n\n");
+
+    llama_print_timings(ctx);
+
+    llama_free(ctx);
+    llama_free_model(model);
+
+    llama_backend_free();
+
+    fprintf(stderr, "\n\n");
+
+    return 0;
+}
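
The core pattern of the example above is to pack tokens from several independent sequences into a single llama_batch, evaluate them with one llama_decode call, and then sample for each client from the logits row it recorded in i_batch. A condensed, hedged sketch of that batch construction follows; the variable names are illustrative and the llama_batch field order simply mirrors parallel.cpp above.

// condensed sketch of the multi-sequence batch used in parallel.cpp above
std::vector<llama_token> batch_token;    // one entry per scheduled token
std::vector<llama_pos> batch_pos;        // position of the token within its sequence
std::vector<llama_seq_id> batch_seq_id;  // which client/sequence the token belongs to

// ... append the pending prompt tokens or the last sampled token of every active client ...

llama_batch batch = {
    (uint32_t) batch_token.size(), // number of tokens in this batch
    batch_token.data(),
    nullptr,                       // token input, so no embeddings
    batch_pos.data(),
    batch_seq_id.data(),
    0, 0, 0,                       // unused here, as in parallel.cpp
};

if (llama_decode(ctx, batch, params.n_threads) != 0) {
    // a non-zero return means the batch could not be evaluated
}

// each client then samples from its own logits row, e.g.:
// const llama_token id = llama_sample_token(ctx, NULL, NULL, params,
//                                           client.last_tokens, candidates, client.i_batch);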

examples/perplexity/perplexity.cpp

+1-1
@@ -681,7 +681,7 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    params.perplexity = true;
+    params.logits_all = true;
     params.n_batch = std::min(params.n_batch, params.n_ctx);

     if (params.ppl_stride > 0) {

examples/speculative/speculative.cpp

+1-5
@@ -37,7 +37,7 @@ int main(int argc, char ** argv) {
     llama_context * ctx_dft = NULL;

     // load the target model
-    params.perplexity = true; // HACK: enable logits_all = true
+    params.logits_all = true;
     std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);

     // load the draft model
@@ -172,7 +172,6 @@
             LOG("out of drafted tokens\n");
         }

-        llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx);
         llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
         ++n_past_dft;

@@ -218,7 +217,6 @@

         // sample n_draft tokens from the draft model using greedy decoding
         int n_past_cur = n_past_dft;
-
         for (int i = 0; i < n_draft; ++i) {
             float * logits = llama_get_logits(ctx_dft);

@@ -258,7 +256,6 @@
             }

             // evaluate the drafted token on the draft model
-            llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx);
             llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
             ++n_past_cur;

@@ -268,7 +265,6 @@
         }

         // evaluate the target model on the drafted tokens
-        llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx);
         llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
         ++n_past_tgt;

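
The llama_kv_cache_rm_seq calls removed above cleared a position range of a sequence from the KV cache right before decoding again at those positions; the new parallel example uses the same call once a client's response is finished so that its sequence can be recycled. A hedged sketch of that cleanup step, following the signature as it appears in this commit:

// sketch: recycle a finished client's sequence (signature as used in this commit)
const int n_ctx = llama_n_ctx(ctx);

// drop every cached position belonging to this sequence from the KV cache ...
llama_kv_cache_rm_seq(ctx, client.seq_id, 0, n_ctx);

// ... and mark the client as idle so the main loop assigns it a fresh prompt
client.seq_id = -1;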

llama.cpp

+1-1
@@ -6673,7 +6673,7 @@ struct llama_context * llama_new_context_with_model(
         ctx->alloc = ggml_allocr_new_measure(tensor_alignment);

         // build worst-case graph
-        uint32_t n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
+        uint32_t n_tokens = std::max((int)hparams.n_ctx, params.n_batch);
         llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
         ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, 0, 0));

