Commit f30ea47

slaren, compilade, and ggerganov authored
llama : add pipeline parallelism support (ggml-org#6017)
* llama : add pipeline parallelism support for batch processing with multiple CUDA GPUs (ggml-ci)
* server : add -ub, --ubatch-size parameter
* fix server embedding test
* llama : fix Mamba inference for pipeline parallelism (tested to work correctly with both `main` and `parallel` examples)
* llama : limit max batch size to n_batch
* add LLAMA_SCHED_MAX_COPIES to configure the number of input copies for pipeline parallelism; default increased to 4 (from 2); changing this value may improve performance for some systems, but increases memory usage
* fix hip build
* fix sycl build (disable cpy_tensor_async)
* fix hip build
* llama : limit n_batch and n_ubatch to n_ctx during context creation
* llama : fix norm backend
* batched-bench : sync after decode
* swiftui : sync after decode
* ggml : allow ggml_get_rows to use multiple threads if they are available
* check n_ubatch >= n_tokens with non-causal attention
* llama : do not limit n_batch to n_ctx with non-causal attn
* server : construct batch with size of llama_n_batch
* ggml_backend_cpu_graph_compute : fix return value when alloc fails
* llama : better n_batch and n_ubatch comment
* fix merge
* small fix
* reduce default n_batch to 2048

Co-authored-by: Francis Couture-Harpin <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent d8fd0cc commit f30ea47

25 files changed: +1426 -846 lines
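The central distinction introduced here is between the logical batch size (n_batch, the most tokens a single llama_decode call may carry) and the physical batch size (n_ubatch, the most tokens actually evaluated at once). Below is a minimal sketch, not part of this commit, of setting both through llama_context_params; the model path and sizes are illustrative placeholders, and backend init/cleanup is omitted for brevity:

// Minimal sketch (not from this commit): configuring the logical vs. physical
// batch size through llama_context_params. Model path and sizes are placeholders.
#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("models/7B/ggml-model-q4_0.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx    = 4096; // context size
    cparams.n_batch  = 2048; // logical maximum batch size accepted by llama_decode
    cparams.n_ubatch = 512;  // physical maximum batch size evaluated per micro-batch

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (ctx == NULL) {
        llama_free_model(model);
        return 1;
    }

    // ... tokenize the prompt, submit batches of up to n_batch tokens with llama_decode() ...

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}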

CMakeLists.txt (+3)

@@ -118,6 +118,7 @@ option(LLAMA_SYCL "llama: use SYCL"
 option(LLAMA_SYCL_F16 "llama: use 16 bit floats for sycl calculations" OFF)
 set(LLAMA_SYCL_TARGET "INTEL" CACHE STRING "llama: sycl target device")
 option(LLAMA_CPU_HBM "llama: use memkind for CPU HBM" OFF)
+set(LLAMA_SCHED_MAX_COPIES "4" CACHE STRING "llama: max input copies for pipeline parallelism")
 
 option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
@@ -147,6 +148,8 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)
 find_package(Threads REQUIRED)
 include(CheckCXXCompilerFlag)
 
+add_compile_definitions(GGML_SCHED_MAX_COPIES=${LLAMA_SCHED_MAX_COPIES})
+
 # enable libstdc++ assertions for debug builds
 if (CMAKE_SYSTEM_NAME MATCHES "Linux")
     add_compile_definitions($<$<CONFIG:Debug>:_GLIBCXX_ASSERTIONS>)
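Since LLAMA_SCHED_MAX_COPIES is a CMake cache variable, the copy count can be overridden at configure time; for example (the value 8 is illustrative, not from this commit):

    cmake -B build -DLLAMA_SCHED_MAX_COPIES=8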

Makefile (+4)

@@ -167,6 +167,10 @@ ifeq ($(UNAME_S),OpenBSD)
 MK_CPPFLAGS += -D_BSD_SOURCE
 endif
 
+ifdef LLAMA_SCHED_MAX_COPIES
+MK_CPPFLAGS += -DGGML_SCHED_MAX_COPIES=$(LLAMA_SCHED_MAX_COPIES)
+endif
+
 ifdef LLAMA_DEBUG
 MK_CFLAGS += -O0 -g
 MK_CXXFLAGS += -O0 -g
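The Makefile build accepts the same knob on the command line, e.g. (illustrative value):

    make LLAMA_SCHED_MAX_COPIES=8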

common/common.cpp (+12 -2)

@@ -483,6 +483,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             break;
         }
         params.n_batch = std::stoi(argv[i]);
+    } else if (arg == "-ub" || arg == "--ubatch-size") {
+        if (++i >= argc) {
+            invalid_param = true;
+            break;
+        }
+        params.n_ubatch = std::stoi(argv[i]);
     } else if (arg == "--keep") {
         if (++i >= argc) {
             invalid_param = true;
@@ -977,7 +983,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  binary file containing multiple choice tasks.\n");
     printf("  -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
     printf("  -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
-    printf("  -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
+    printf("  -b N, --batch-size N logical maximum batch size (default: %d)\n", params.n_batch);
+    printf("  -ub N, --ubatch-size N\n");
+    printf("  physical maximum batch size (default: %d)\n", params.n_ubatch);
     printf("  --samplers samplers that will be used for generation in the order, separated by \';\'\n");
     printf("  (default: %s)\n", sampler_type_names.c_str());
     printf("  --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sampler_type_chars.c_str());
@@ -1287,8 +1295,9 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     auto cparams = llama_context_default_params();
 
     cparams.n_ctx = params.n_ctx;
-    cparams.n_batch = params.n_batch;
     cparams.n_seq_max = params.n_parallel;
+    cparams.n_batch = params.n_batch;
+    cparams.n_ubatch = params.n_ubatch;
     cparams.n_threads = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.seed = params.seed;
@@ -1379,6 +1388,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
     llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
     llama_kv_cache_clear(lctx);
+    llama_synchronize(lctx);
     llama_reset_timings(lctx);
 }
 
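With the option wired into gpt_params, the two batch sizes can be set independently from the command line of the common examples; the model path, prompt, and values below are placeholders:

    ./main -m models/7B/ggml-model-q4_0.gguf -p "Hello" -c 4096 -b 2048 -ub 512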

common/common.h (+2 -1)

@@ -51,7 +51,8 @@ struct gpt_params {
     int32_t n_threads_batch_draft = -1;
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_ctx = 512; // context size
-    int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_draft = 5; // number of tokens to draft during speculative decoding
     int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)

examples/batched-bench/batched-bench.cpp (+2)

@@ -138,6 +138,8 @@ int main(int argc, char ** argv) {
             LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
             return false;
         }
+
+        llama_synchronize(ctx);
     }
 
     return true;
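The explicit synchronization matters because, with pipeline parallelism, llama_decode() may return before the backends have finished the submitted work, so a benchmark that reads the clock right after decoding would under-report the elapsed time. A small sketch of the pattern (the helper name and error handling are illustrative, not from this commit):

#include "ggml.h"   // ggml_time_us()
#include "llama.h"

// Sketch only: measure how long a decode actually takes when the scheduler may
// still be running asynchronously after llama_decode() returns.
static int64_t timed_decode_us(llama_context * ctx, llama_batch batch) {
    const int64_t t_start = ggml_time_us();

    if (llama_decode(ctx, batch) != 0) {
        return -1; // decode failed
    }
    llama_synchronize(ctx); // wait until all queued backend work has completed

    return ggml_time_us() - t_start;
}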

examples/embedding/embedding.cpp (+1 -1)

@@ -107,7 +107,7 @@ int main(int argc, char ** argv) {
 
     // max batch size
     const uint64_t n_batch = params.n_batch;
-    GGML_ASSERT(params.n_batch == params.n_ctx);
+    GGML_ASSERT(params.n_batch >= params.n_ctx);
 
     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;

examples/llama-bench/llama-bench.cpp (+43 -10)

@@ -164,6 +164,7 @@ struct cmd_params {
     std::vector<int> n_prompt;
     std::vector<int> n_gen;
     std::vector<int> n_batch;
+    std::vector<int> n_ubatch;
     std::vector<ggml_type> type_k;
     std::vector<ggml_type> type_v;
     std::vector<int> n_threads;
@@ -183,7 +184,8 @@ static const cmd_params cmd_params_defaults = {
     /* model */ {"models/7B/ggml-model-q4_0.gguf"},
     /* n_prompt */ {512},
     /* n_gen */ {128},
-    /* n_batch */ {512},
+    /* n_batch */ {2048},
+    /* n_ubatch */ {512},
     /* type_k */ {GGML_TYPE_F16},
     /* type_v */ {GGML_TYPE_F16},
     /* n_threads */ {get_num_physical_cores()},
@@ -208,6 +210,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -p, --n-prompt <n> (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
     printf("  -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
     printf("  -b, --batch-size <n> (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
+    printf("  -ub N, --ubatch-size <n> (default: %s)\n", join(cmd_params_defaults.n_ubatch, ",").c_str());
     printf("  -ctk <t>, --cache-type-k <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
     printf("  -ctv <t>, --cache-type-v <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
     printf("  -t, --threads <n> (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
@@ -217,7 +220,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
     printf("  -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
     printf("  -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
-    printf("  -ts, --tensor_split <ts0/ts1/..> (default: 0)\n");
+    printf("  -ts, --tensor-split <ts0/ts1/..> (default: 0)\n");
     printf("  -r, --repetitions <n> (default: %d)\n", cmd_params_defaults.reps);
     printf("  -o, --output <csv|json|md|sql> (default: %s)\n", output_format_str(cmd_params_defaults.output_format));
     printf("  -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
@@ -297,6 +300,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
         }
         auto p = split<int>(argv[i], split_delim);
         params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
+    } else if (arg == "-ub" || arg == "--ubatch-size") {
+        if (++i >= argc) {
+            invalid_param = true;
+            break;
+        }
+        auto p = split<int>(argv[i], split_delim);
+        params.n_ubatch.insert(params.n_ubatch.end(), p.begin(), p.end());
     } else if (arg == "-ctk" || arg == "--cache-type-k") {
         if (++i >= argc) {
             invalid_param = true;
@@ -455,6 +465,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; }
     if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; }
     if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; }
+    if (params.n_ubatch.empty()) { params.n_ubatch = cmd_params_defaults.n_ubatch; }
     if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; }
     if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; }
     if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
@@ -474,6 +485,7 @@ struct cmd_params_instance {
     int n_prompt;
     int n_gen;
     int n_batch;
+    int n_ubatch;
     ggml_type type_k;
     ggml_type type_v;
     int n_threads;
@@ -511,6 +523,7 @@ struct cmd_params_instance {
 
     cparams.n_ctx = n_prompt + n_gen;
     cparams.n_batch = n_batch;
+    cparams.n_ubatch = n_ubatch;
     cparams.type_k = type_k;
     cparams.type_v = type_v;
     cparams.offload_kqv = !no_kv_offload;
@@ -532,6 +545,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & mmp : params.use_mmap)
     for (const auto & embd : params.embeddings)
     for (const auto & nb : params.n_batch)
+    for (const auto & nub : params.n_ubatch)
     for (const auto & tk : params.type_k)
     for (const auto & tv : params.type_v)
     for (const auto & nkvo : params.no_kv_offload)
@@ -545,6 +559,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
         /* .n_prompt = */ n_prompt,
         /* .n_gen = */ 0,
         /* .n_batch = */ nb,
+        /* .n_ubatch = */ nub,
         /* .type_k = */ tk,
         /* .type_v = */ tv,
         /* .n_threads = */ nt,
@@ -568,6 +583,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
         /* .n_prompt = */ 0,
         /* .n_gen = */ n_gen,
         /* .n_batch = */ nb,
+        /* .n_ubatch = */ nub,
         /* .type_k = */ tk,
         /* .type_v = */ tv,
         /* .n_threads = */ nt,
@@ -604,6 +620,7 @@ struct test {
     uint64_t model_size;
     uint64_t model_n_params;
     int n_batch;
+    int n_ubatch;
     int n_threads;
     ggml_type type_k;
     ggml_type type_v;
@@ -627,6 +644,7 @@ struct test {
     model_size = llama_model_size(lmodel);
     model_n_params = llama_model_n_params(lmodel);
     n_batch = inst.n_batch;
+    n_ubatch = inst.n_ubatch;
     n_threads = inst.n_threads;
     type_k = inst.type_k;
     type_v = inst.type_v;
@@ -705,7 +723,8 @@ struct test {
     "cuda", "opencl", "vulkan", "kompute", "metal", "sycl", "gpu_blas", "blas",
     "cpu_info", "gpu_info",
     "model_filename", "model_type", "model_size", "model_n_params",
-    "n_batch", "n_threads", "type_k", "type_v",
+    "n_batch", "n_ubatch",
+    "n_threads", "type_k", "type_v",
     "n_gpu_layers", "split_mode",
     "main_gpu", "no_kv_offload",
     "tensor_split", "use_mmap", "embeddings",
@@ -719,7 +738,8 @@ struct test {
     enum field_type {STRING, BOOL, INT, FLOAT};
 
     static field_type get_field_type(const std::string & field) {
-        if (field == "build_number" || field == "n_batch" || field == "n_threads" ||
+        if (field == "build_number" || field == "n_batch" || field == "n_ubatch" ||
+            field == "n_threads" ||
             field == "model_size" || field == "model_n_params" ||
             field == "n_gpu_layers" || field == "main_gpu" ||
             field == "n_prompt" || field == "n_gen" ||
@@ -759,7 +779,8 @@ struct test {
     std::to_string(metal), std::to_string(sycl), std::to_string(gpu_blas), std::to_string(blas),
     cpu_info, gpu_info,
     model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
-    std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
+    std::to_string(n_batch), std::to_string(n_ubatch),
+    std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
     std::to_string(n_gpu_layers), split_mode_str(split_mode),
     std::to_string(main_gpu), std::to_string(no_kv_offload),
     tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
@@ -957,6 +978,9 @@ struct markdown_printer : public printer {
     if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) {
         fields.emplace_back("n_batch");
     }
+    if (params.n_ubatch.size() > 1 || params.n_ubatch != cmd_params_defaults.n_ubatch) {
+        fields.emplace_back("n_ubatch");
+    }
     if (params.type_k.size() > 1 || params.type_k != cmd_params_defaults.type_k) {
         fields.emplace_back("type_k");
     }
@@ -1096,25 +1120,32 @@ struct sql_printer : public printer {
 };
 
 static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
+    llama_set_n_threads(ctx, n_threads, n_threads);
+
+    //std::vector<llama_token> tokens(n_prompt, llama_token_bos(llama_get_model(ctx)));
+    //llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt, n_past, 0));
+    //GGML_UNUSED(n_batch);
+
     std::vector<llama_token> tokens(n_batch, llama_token_bos(llama_get_model(ctx)));
     int n_processed = 0;
 
-    llama_set_n_threads(ctx, n_threads, n_threads);
-
     while (n_processed < n_prompt) {
         int n_tokens = std::min(n_prompt - n_processed, n_batch);
         llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
         n_processed += n_tokens;
     }
+
+    llama_synchronize(ctx);
 }
 
 static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
-    llama_token token = llama_token_bos(llama_get_model(ctx));
-
     llama_set_n_threads(ctx, n_threads, n_threads);
 
+    llama_token token = llama_token_bos(llama_get_model(ctx));
+
     for (int i = 0; i < n_gen; i++) {
         llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
+        llama_synchronize(ctx);
     }
 }
 
@@ -1203,7 +1234,8 @@ int main(int argc, char ** argv) {
 
     // warmup run
     if (t.n_prompt > 0) {
-        test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
+        //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
+        test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
     }
     if (t.n_gen > 0) {
         test_gen(ctx, 1, 0, t.n_threads);
@@ -1219,6 +1251,7 @@ int main(int argc, char ** argv) {
     if (t.n_gen > 0) {
         test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
     }
+
     uint64_t t_ns = get_time_ns() - t_start;
     t.samples_ns.push_back(t_ns);
 }
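Since -ub accepts a comma-separated list like the other llama-bench options, several physical batch sizes can be swept in a single run; the values below are illustrative:

    ./llama-bench -m models/7B/ggml-model-q4_0.gguf -b 2048 -ub 256,512,1024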

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift (+2)

@@ -221,6 +221,7 @@ actor LlamaContext {
         if llama_decode(context, batch) != 0 {
             print("llama_decode() failed during prompt")
         }
+        llama_synchronize(context)
 
         let t_pp_end = ggml_time_us()
 
@@ -240,6 +241,7 @@ actor LlamaContext {
             if llama_decode(context, batch) != 0 {
                 print("llama_decode() failed during text generation")
             }
+            llama_synchronize(context)
         }
 
         let t_tg_end = ggml_time_us()

examples/perplexity/perplexity.cpp (+2 -1)

@@ -589,9 +589,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         }
     }
 
-    const auto t_end = std::chrono::high_resolution_clock::now();
 
     if (i == 0) {
+        llama_synchronize(ctx);
+        const auto t_end = std::chrono::high_resolution_clock::now();
         const float t_total = std::chrono::duration<float>(t_end - t_start).count();
         fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
         int total_seconds = (int)(t_total*n_chunk/n_seq);
