// A basic application simulating a server with multiple clients.
// The clients submit requests to the server, which processes them in parallel.

#include "build-info.h"

#include "common.h"
#include "llama.h"

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <tuple>
#include <vector>

// trim whitespace from the beginning and end of a string
static std::string trim(const std::string & str) {
    size_t start = 0;
    size_t end   = str.size();

    while (start < end && std::isspace(static_cast<unsigned char>(str[start]))) {
        start += 1;
    }

    while (end > start && std::isspace(static_cast<unsigned char>(str[end - 1]))) {
        end -= 1;
    }

    return str.substr(start, end - start);
}

static std::string k_system = R"(
Transcript of a dialog, where the User interacts with an Assistant.
The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.

User: Hello, what is the temperature outside?
Assistant: It is 72 degrees Fahrenheit.
User: What is the definition of a prime number?
Assistant: A prime number is a number that is divisible only by itself and 1.
User: )";

static std::vector<std::string> k_prompts = {
    "What is the meaning of life?",
    "What is the population of Europe?",
    "List all planets in the Solar System.",
    "What is the capital of France?",
    "Tell me an interesting fact about llamas.",
    "What is the best way to cook a steak?",
    "Are you familiar with the Special Theory of Relativity and can you explain it to me?",
    "Recommend some interesting books to read.",
    "What is the best way to learn a new language?",
    "How to get a job at Google?",
    "If you could have any superpower, what would it be?",
    "I want to learn how to play the piano.",
};

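// state of a single client slot - an active slot owns one sequence id in the shared
// KV cache and accumulates its own prompt and response text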
struct client {
    int32_t id = 0;

    llama_seq_id seq_id = -1;

    llama_token sampled;

    int32_t n_prompt  = 0;
    int32_t n_decoded = 0;
    int32_t i_batch   = -1;

    std::string input;
    std::string prompt;
    std::string response;

    std::vector<llama_token> last_tokens;
};

int main(int argc, char ** argv) {
    gpt_params params;

    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    const int n_clients = 16;

#ifndef LOG_DISABLE_LOGS
    log_set_target(log_filename_generator("parallel", "log"));
    LOG_TEE("Log start\n");
    log_dump_cmdline(argc, argv);
#endif // LOG_DISABLE_LOGS

    // init llama.cpp
    llama_backend_init(params.numa);

    llama_model * model = NULL;

    llama_context * ctx = NULL;

    // load the model - logits for all tokens in a batch are needed so that each
    // client can sample from its own position
    params.logits_all = true;
    std::tie(model, ctx) = llama_init_from_gpt_params(params);

    fprintf(stderr, "\n\n");
    fflush(stderr);

    const int n_ctx   = llama_n_ctx(ctx);
    const int n_vocab = llama_n_vocab(ctx);

    std::vector<client> clients(n_clients);
    for (size_t i = 0; i < clients.size(); ++i) {
        auto & client = clients[i];
        client.id = i;
        client.last_tokens.resize(n_ctx);
        std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
    }

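    // scratch buffer for sampling candidates - shared by all clients and refilled on each sample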
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);

    auto t_main_start = ggml_time_us();

    int64_t n_tokens_total = 0;

    llama_seq_id g_seq_id = 0;

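    // the pending batch is stored as parallel arrays - one entry per queued token, so
    // tokens that belong to different sequences can be decoded together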
    std::vector<llama_token>  batch_token;
    std::vector<llama_pos>    batch_pos;
    std::vector<llama_seq_id> batch_seq_id;
    std::vector<client *>     batch_clients;

    while (true) {
        uint32_t n_tokens = 0;

        batch_token.clear();
        batch_pos.clear();
        batch_seq_id.clear();
        batch_clients.clear(); // keep in sync with the arrays above

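        // assign a new request to every idle slot and queue the previously sampled
        // token for every slot that is still generating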
        for (auto & client : clients) {
            if (client.seq_id == -1) {
                client.seq_id   = g_seq_id;
                client.input    = k_prompts[rand() % k_prompts.size()];
                client.prompt   = k_system + client.input + "\nAssistant:";
                client.response = "";
                std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);

                std::vector<llama_token> prompt_tokens;
                prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);

                for (size_t i = 0; i < prompt_tokens.size(); ++i) {
                    batch_token.push_back(prompt_tokens[i]);
                    batch_pos.push_back(i);
                    batch_seq_id.push_back(client.seq_id);
                    batch_clients.push_back(&client);
                }
                client.n_prompt  = prompt_tokens.size();
                client.n_decoded = prompt_tokens.size();
                client.i_batch   = batch_token.size() - 1;

                g_seq_id += 1;
            } else {
                batch_token.push_back(client.sampled);
                batch_pos.push_back(client.n_decoded);
                batch_seq_id.push_back(client.seq_id);
                batch_clients.push_back(&client);
                client.n_decoded += 1;
                client.i_batch = batch_token.size() - 1;
            }
        }

        // process in chunks of params.n_batch
        for (size_t i = 0; i < batch_token.size(); i += params.n_batch) {
            n_tokens = std::min(params.n_batch, (int32_t) (batch_token.size() - i));

            llama_batch batch = {
                n_tokens,
                batch_token.data() + i,
                nullptr,
                batch_pos.data() + i,
                batch_seq_id.data() + i,
                0, 0, 0, // unused
            };

            if (llama_decode(ctx, batch, params.n_threads)) {
                LOG_TEE("%s : failed to decode batch\n", __func__);
                return 1;
            }

            for (auto & client : clients) {
                if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
                    continue;
                }

                const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);

                // remember which tokens were sampled - used for repetition penalties during sampling
                client.last_tokens.erase(client.last_tokens.begin());
                client.last_tokens.push_back(id);

                const std::string token_str = llama_token_to_piece(ctx, id);
                client.response += token_str;
                client.sampled = id;

                //printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
                //        client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());

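                // a request is finished when the model emits EOS, hits the generation
                // limit, or starts a new "User:" turn on its own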
                if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || client.response.find("User:") != std::string::npos) {
                    const size_t pos = client.response.find("User:");
                    if (pos != std::string::npos) {
                        client.response = client.response.substr(0, pos);
                    }

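                    // evict the finished sequence from the KV cache so its cells can be
                    // reused by the next request assigned to this slot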
                    llama_kv_cache_rm_seq(ctx, client.seq_id, 0, n_ctx);

                    const auto t_main_end = ggml_time_us();

                    n_tokens_total += client.n_decoded - client.n_prompt;

                    printf("\033[1mClient %d, seq %d, prompt %d t, response %d t, speed: %.2f t/s\033[0m: \n\nInput: %s\nResponse: %s\n\n",
                            client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
                            (double) n_tokens_total / (t_main_end - t_main_start) * 1e6,
                            client.input.c_str(), ::trim(client.response).c_str());

                    client.seq_id = -1;
                }
            }
        }

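        // reset the timing stats after the first iteration - it is dominated by prompt
        // processing for all clients and would skew the reported speed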
        static bool is_first = true;
        if (is_first) {
            t_main_start   = ggml_time_us();
            n_tokens_total = 0;
            is_first = false;
        }
    }

    LOG_TEE("\n\n");

    llama_print_timings(ctx);

    llama_free(ctx);
    llama_free_model(model);

    llama_backend_free();

    fprintf(stderr, "\n\n");

    return 0;
}
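
// Example invocation (hypothetical - the exact flags come from gpt_params_parse in
// common.h and may differ between versions; the model path is a placeholder):
//
//   ./parallel -m models/7B/ggml-model.gguf -t 8 -b 512 -n 128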