@@ -164,6 +164,7 @@ struct cmd_params {
164
164
std::vector<int > n_prompt;
165
165
std::vector<int > n_gen;
166
166
std::vector<int > n_batch;
167
+ std::vector<int > n_ubatch;
167
168
std::vector<ggml_type> type_k;
168
169
std::vector<ggml_type> type_v;
169
170
std::vector<int > n_threads;
@@ -183,7 +184,8 @@ static const cmd_params cmd_params_defaults = {
183
184
/* model */ {" models/7B/ggml-model-q4_0.gguf" },
184
185
/* n_prompt */ {512 },
185
186
/* n_gen */ {128 },
186
- /* n_batch */ {512 },
187
+ /* n_batch */ {2048 },
188
+ /* n_ubatch */ {512 },
187
189
/* type_k */ {GGML_TYPE_F16},
188
190
/* type_v */ {GGML_TYPE_F16},
189
191
/* n_threads */ {get_num_physical_cores ()},
@@ -208,6 +210,7 @@ static void print_usage(int /* argc */, char ** argv) {
208
210
printf (" -p, --n-prompt <n> (default: %s)\n " , join (cmd_params_defaults.n_prompt , " ," ).c_str ());
209
211
printf (" -n, --n-gen <n> (default: %s)\n " , join (cmd_params_defaults.n_gen , " ," ).c_str ());
210
212
printf (" -b, --batch-size <n> (default: %s)\n " , join (cmd_params_defaults.n_batch , " ," ).c_str ());
213
+ printf (" -ub N, --ubatch-size <n> (default: %s)\n " , join (cmd_params_defaults.n_ubatch , " ," ).c_str ());
211
214
printf (" -ctk <t>, --cache-type-k <t> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.type_k , ggml_type_name), " ," ).c_str ());
212
215
printf (" -ctv <t>, --cache-type-v <t> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.type_v , ggml_type_name), " ," ).c_str ());
213
216
printf (" -t, --threads <n> (default: %s)\n " , join (cmd_params_defaults.n_threads , " ," ).c_str ());
@@ -217,7 +220,7 @@ static void print_usage(int /* argc */, char ** argv) {
217
220
printf (" -nkvo, --no-kv-offload <0|1> (default: %s)\n " , join (cmd_params_defaults.no_kv_offload , " ," ).c_str ());
218
221
printf (" -mmp, --mmap <0|1> (default: %s)\n " , join (cmd_params_defaults.use_mmap , " ," ).c_str ());
219
222
printf (" -embd, --embeddings <0|1> (default: %s)\n " , join (cmd_params_defaults.embeddings , " ," ).c_str ());
220
- printf (" -ts, --tensor_split <ts0/ts1/..> (default: 0)\n " );
223
+ printf (" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n " );
221
224
printf (" -r, --repetitions <n> (default: %d)\n " , cmd_params_defaults.reps );
222
225
printf (" -o, --output <csv|json|md|sql> (default: %s)\n " , output_format_str (cmd_params_defaults.output_format ));
223
226
printf (" -v, --verbose (default: %s)\n " , cmd_params_defaults.verbose ? " 1" : " 0" );
@@ -297,6 +300,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
297
300
}
298
301
auto p = split<int >(argv[i], split_delim);
299
302
params.n_batch .insert (params.n_batch .end (), p.begin (), p.end ());
303
+ } else if (arg == " -ub" || arg == " --ubatch-size" ) {
304
+ if (++i >= argc) {
305
+ invalid_param = true ;
306
+ break ;
307
+ }
308
+ auto p = split<int >(argv[i], split_delim);
309
+ params.n_ubatch .insert (params.n_ubatch .end (), p.begin (), p.end ());
300
310
} else if (arg == " -ctk" || arg == " --cache-type-k" ) {
301
311
if (++i >= argc) {
302
312
invalid_param = true ;
@@ -455,6 +465,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
455
465
if (params.n_prompt .empty ()) { params.n_prompt = cmd_params_defaults.n_prompt ; }
456
466
if (params.n_gen .empty ()) { params.n_gen = cmd_params_defaults.n_gen ; }
457
467
if (params.n_batch .empty ()) { params.n_batch = cmd_params_defaults.n_batch ; }
468
+ if (params.n_ubatch .empty ()) { params.n_ubatch = cmd_params_defaults.n_ubatch ; }
458
469
if (params.type_k .empty ()) { params.type_k = cmd_params_defaults.type_k ; }
459
470
if (params.type_v .empty ()) { params.type_v = cmd_params_defaults.type_v ; }
460
471
if (params.n_gpu_layers .empty ()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers ; }
@@ -474,6 +485,7 @@ struct cmd_params_instance {
474
485
int n_prompt;
475
486
int n_gen;
476
487
int n_batch;
488
+ int n_ubatch;
477
489
ggml_type type_k;
478
490
ggml_type type_v;
479
491
int n_threads;
@@ -511,6 +523,7 @@ struct cmd_params_instance {
511
523
512
524
cparams.n_ctx = n_prompt + n_gen;
513
525
cparams.n_batch = n_batch;
526
+ cparams.n_ubatch = n_ubatch;
514
527
cparams.type_k = type_k;
515
528
cparams.type_v = type_v;
516
529
cparams.offload_kqv = !no_kv_offload;
@@ -532,6 +545,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
532
545
for (const auto & mmp : params.use_mmap )
533
546
for (const auto & embd : params.embeddings )
534
547
for (const auto & nb : params.n_batch )
548
+ for (const auto & nub : params.n_ubatch )
535
549
for (const auto & tk : params.type_k )
536
550
for (const auto & tv : params.type_v )
537
551
for (const auto & nkvo : params.no_kv_offload )
@@ -545,6 +559,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
545
559
/* .n_prompt = */ n_prompt,
546
560
/* .n_gen = */ 0 ,
547
561
/* .n_batch = */ nb,
562
+ /* .n_ubatch = */ nub,
548
563
/* .type_k = */ tk,
549
564
/* .type_v = */ tv,
550
565
/* .n_threads = */ nt,
@@ -568,6 +583,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
568
583
/* .n_prompt = */ 0 ,
569
584
/* .n_gen = */ n_gen,
570
585
/* .n_batch = */ nb,
586
+ /* .n_ubatch = */ nub,
571
587
/* .type_k = */ tk,
572
588
/* .type_v = */ tv,
573
589
/* .n_threads = */ nt,
@@ -604,6 +620,7 @@ struct test {
604
620
uint64_t model_size;
605
621
uint64_t model_n_params;
606
622
int n_batch;
623
+ int n_ubatch;
607
624
int n_threads;
608
625
ggml_type type_k;
609
626
ggml_type type_v;
@@ -627,6 +644,7 @@ struct test {
627
644
model_size = llama_model_size (lmodel);
628
645
model_n_params = llama_model_n_params (lmodel);
629
646
n_batch = inst.n_batch ;
647
+ n_ubatch = inst.n_ubatch ;
630
648
n_threads = inst.n_threads ;
631
649
type_k = inst.type_k ;
632
650
type_v = inst.type_v ;
@@ -705,7 +723,8 @@ struct test {
705
723
" cuda" , " opencl" , " vulkan" , " kompute" , " metal" , " sycl" , " gpu_blas" , " blas" ,
706
724
" cpu_info" , " gpu_info" ,
707
725
" model_filename" , " model_type" , " model_size" , " model_n_params" ,
708
- " n_batch" , " n_threads" , " type_k" , " type_v" ,
726
+ " n_batch" , " n_ubatch" ,
727
+ " n_threads" , " type_k" , " type_v" ,
709
728
" n_gpu_layers" , " split_mode" ,
710
729
" main_gpu" , " no_kv_offload" ,
711
730
" tensor_split" , " use_mmap" , " embeddings" ,
@@ -719,7 +738,8 @@ struct test {
719
738
enum field_type {STRING, BOOL, INT, FLOAT};
720
739
721
740
static field_type get_field_type (const std::string & field) {
722
- if (field == " build_number" || field == " n_batch" || field == " n_threads" ||
741
+ if (field == " build_number" || field == " n_batch" || field == " n_ubatch" ||
742
+ field == " n_threads" ||
723
743
field == " model_size" || field == " model_n_params" ||
724
744
field == " n_gpu_layers" || field == " main_gpu" ||
725
745
field == " n_prompt" || field == " n_gen" ||
@@ -759,7 +779,8 @@ struct test {
759
779
std::to_string (metal), std::to_string (sycl), std::to_string (gpu_blas), std::to_string (blas),
760
780
cpu_info, gpu_info,
761
781
model_filename, model_type, std::to_string (model_size), std::to_string (model_n_params),
762
- std::to_string (n_batch), std::to_string (n_threads), ggml_type_name (type_k), ggml_type_name (type_v),
782
+ std::to_string (n_batch), std::to_string (n_ubatch),
783
+ std::to_string (n_threads), ggml_type_name (type_k), ggml_type_name (type_v),
763
784
std::to_string (n_gpu_layers), split_mode_str (split_mode),
764
785
std::to_string (main_gpu), std::to_string (no_kv_offload),
765
786
tensor_split_str, std::to_string (use_mmap), std::to_string (embeddings),
@@ -957,6 +978,9 @@ struct markdown_printer : public printer {
957
978
if (params.n_batch .size () > 1 || params.n_batch != cmd_params_defaults.n_batch ) {
958
979
fields.emplace_back (" n_batch" );
959
980
}
981
+ if (params.n_ubatch .size () > 1 || params.n_ubatch != cmd_params_defaults.n_ubatch ) {
982
+ fields.emplace_back (" n_ubatch" );
983
+ }
960
984
if (params.type_k .size () > 1 || params.type_k != cmd_params_defaults.type_k ) {
961
985
fields.emplace_back (" type_k" );
962
986
}
@@ -1096,25 +1120,32 @@ struct sql_printer : public printer {
1096
1120
};
1097
1121
1098
1122
/**
 * Prompt-processing benchmark step.
 *
 * Decodes `n_prompt` BOS tokens in chunks of at most `n_batch` tokens,
 * placing them at positions `n_past`, `n_past + 1`, ... in sequence 0, and
 * then calls llama_synchronize() once after all chunks have been submitted.
 *
 * @param ctx        context to run the decode calls on
 * @param n_prompt   total number of tokens to process (<= 0 does nothing)
 * @param n_past     position of the first token
 * @param n_batch    maximum number of tokens submitted per llama_decode() call;
 *                   must be positive
 * @param n_threads  thread count applied to both arguments of
 *                   llama_set_n_threads()
 *
 * Errors are reported on stderr; the function stops submitting work as soon
 * as a decode call reports a non-zero status.
 */
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
    llama_set_n_threads(ctx, n_threads, n_threads);

    // A non-positive batch size would make n_tokens 0 below and the loop would
    // never advance (or, if negative, request an absurd vector size).
    if (n_batch <= 0) {
        fprintf(stderr, "%s: error: invalid batch size %d\n", __func__, n_batch);
        return;
    }

    // The token content is irrelevant for timing, so every slot holds BOS.
    std::vector<llama_token> tokens(n_batch, llama_token_bos(llama_get_model(ctx)));
    int n_processed = 0;

    while (n_processed < n_prompt) {
        const int n_tokens = std::min(n_prompt - n_processed, n_batch);
        const int res = llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
        if (res != 0) {
            // Do not keep timing work that is known to have failed.
            fprintf(stderr, "%s: error: failed to decode prompt batch, res = %d\n", __func__, res);
            break;
        }
        n_processed += n_tokens;
    }

    // Synchronize once at the end so the caller's measurement covers all
    // submitted batches (presumably decode may return before the backend has
    // finished - confirm against llama.h).
    llama_synchronize(ctx);
}
1110
1140
1111
1141
/**
 * Token-generation benchmark step.
 *
 * Decodes a single BOS token `n_gen` times, at positions `n_past`,
 * `n_past + 1`, ... in sequence 0, calling llama_synchronize() after every
 * decode.
 *
 * @param ctx        context to run the decode calls on
 * @param n_gen      number of single-token decode steps (<= 0 does nothing)
 * @param n_past     position of the first generated token
 * @param n_threads  thread count applied to both arguments of
 *                   llama_set_n_threads()
 *
 * A non-zero status from llama_decode() is reported on stderr and ends the
 * loop early.
 */
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
    llama_set_n_threads(ctx, n_threads, n_threads);

    // The same BOS token is fed on every step; no sampling is performed.
    llama_token token = llama_token_bos(llama_get_model(ctx));

    for (int i = 0; i < n_gen; i++) {
        const int res = llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
        if (res != 0) {
            // Do not keep timing work that is known to have failed.
            fprintf(stderr, "%s: error: failed to decode generation batch, res = %d\n", __func__, res);
            break;
        }
        // Synchronize per token, unlike the prompt path which synchronizes
        // once (presumably to mimic real generation, where each token is
        // needed before the next step - confirm intent with the author).
        llama_synchronize(ctx);
    }
}
1120
1151
@@ -1203,7 +1234,8 @@ int main(int argc, char ** argv) {
1203
1234
1204
1235
// warmup run
1205
1236
if (t.n_prompt > 0 ) {
1206
- test_prompt (ctx, std::min (2 , t.n_batch ), 0 , t.n_batch , t.n_threads );
1237
+ // test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
1238
+ test_prompt (ctx, t.n_prompt , 0 , t.n_batch , t.n_threads );
1207
1239
}
1208
1240
if (t.n_gen > 0 ) {
1209
1241
test_gen (ctx, 1 , 0 , t.n_threads );
@@ -1219,6 +1251,7 @@ int main(int argc, char ** argv) {
1219
1251
if (t.n_gen > 0 ) {
1220
1252
test_gen (ctx, t.n_gen , t.n_prompt , t.n_threads );
1221
1253
}
1254
+
1222
1255
uint64_t t_ns = get_time_ns () - t_start;
1223
1256
t.samples_ns .push_back (t_ns);
1224
1257
}
0 commit comments