
Commit bcc0eb4

ggerganov and slaren authored
llama : per-layer KV cache + quantum K cache (#4309)
* per-layer KV * remove unnecessary copies * less code duplication, offload k and v separately * llama : offload KV cache per-layer * llama : offload K shift tensors * llama : offload for rest of the model arches * llama : enable offload debug temporarily * llama : keep the KV related layers on the device * llama : remove mirrors, perform Device -> Host when partial offload * common : add command-line arg to disable KV cache offloading * llama : update session save/load * llama : support quantum K cache (#4312) * llama : support quantum K cache (wip) * metal : add F32 -> Q8_0 copy kernel * cuda : add F32 -> Q8_0 copy kernel ggml-ci * cuda : use mmv kernel for quantum cache ops * llama : pass KV cache type through API * llama : fix build ggml-ci * metal : add F32 -> Q4_0 copy kernel * metal : add F32 -> Q4_1 copy kernel * cuda : wip * cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels * llama-bench : support type_k/type_v * metal : use mm kernel only for quantum KV cache * cuda : add comment * llama : remove memory_f16 and kv_f16 flags --------- Co-authored-by: slaren <[email protected]> * readme : add API change notice --------- Co-authored-by: slaren <[email protected]>
1 parent 81bc921 commit bcc0eb4
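The API change summarized above replaces the old f16_kv/memory_f16 boolean with explicit per-tensor cache types (type_k, type_v) and an offload toggle (offload_kqv) on llama_context_params. A minimal sketch of the new setup, using only field and function names that appear in this diff (the n_ctx value is illustrative):

```cpp
#include "llama.h"

// Sketch of the post-change context setup: f16_kv is gone, the cache data type
// is chosen per tensor via type_k / type_v, and offload_kqv controls whether
// the KV cache and related ops stay on the device.
static llama_context * make_ctx(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx       = 4096;            // illustrative context size
    cparams.type_k      = GGML_TYPE_Q8_0;  // "quantum" (quantized) K cache
    cparams.type_v      = GGML_TYPE_F16;   // V cache kept at f16
    cparams.offload_kqv = true;            // equivalent to not passing -nkvo
    return llama_new_context_with_model(model, cparams);
}
```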

File tree

11 files changed: +747 -287 lines

README.md

+1
@@ -10,6 +10,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
 
 ### Hot topics
 
+- **llama.h API change for handling KV cache offloading and data type: https://github.com/ggerganov/llama.cpp/pull/4309**
 - Using `llama.cpp` with AWS instances: https://github.com/ggerganov/llama.cpp/discussions/4225
 - Looking for contributions to improve and maintain the `server` example: https://github.com/ggerganov/llama.cpp/issues/4216
 - Collecting Apple Silicon performance stats: https://github.com/ggerganov/llama.cpp/discussions/4167

common/common.cpp

+39 -6
@@ -278,8 +278,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.yarn_beta_slow = std::stof(argv[i]);
-        } else if (arg == "--memory-f32") {
-            params.memory_f16 = false;
         } else if (arg == "--samplers") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -510,6 +508,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             params.infill = true;
         } else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
             params.dump_kv_cache = true;
+        } else if (arg == "-nkvo" || arg == "--no-kv-offload") {
+            params.no_kv_offload = true;
+        } else if (arg == "-ctk" || arg == "--cache-type-k") {
+            params.cache_type_k = argv[++i];
+        } else if (arg == "-ctv" || arg == "--cache-type-v") {
+            params.cache_type_v = argv[++i];
         } else if (arg == "--multiline-input") {
             params.multiline_input = true;
         } else if (arg == "--simple-io") {
@@ -858,8 +862,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
     printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     printf(" --no-penalize-nl do not penalize newline token\n");
-    printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
-    printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
     printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
     printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
     printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
@@ -900,6 +902,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" --verbose-prompt print prompt before generation\n");
     printf(" -dkvc, --dump-kv-cache\n");
     printf(" verbose print of the KV cache\n");
+    printf(" -nkvo, --no-kv-offload\n");
+    printf(" disable KV offload\n");
+    printf(" -ctk TYPE, --cache-type-k TYPE\n");
+    printf(" KV cache data type for K (default: %s)\n", params.cache_type_k.c_str());
+    printf(" -ctv TYPE, --cache-type-v TYPE\n");
+    printf(" KV cache data type for V (default: %s)\n", params.cache_type_v.c_str());
     printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
     printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
     printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
@@ -1015,6 +1023,29 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
     return mparams;
 }
 
+static ggml_type kv_cache_type_from_str(const std::string & s) {
+    if (s == "f16") {
+        return GGML_TYPE_F16;
+    }
+    if (s == "q8_0") {
+        return GGML_TYPE_Q8_0;
+    }
+    if (s == "q4_0") {
+        return GGML_TYPE_Q4_0;
+    }
+    if (s == "q4_1") {
+        return GGML_TYPE_Q4_1;
+    }
+    if (s == "q5_0") {
+        return GGML_TYPE_Q5_0;
+    }
+    if (s == "q5_1") {
+        return GGML_TYPE_Q5_1;
+    }
+
+    throw std::runtime_error("Invalid cache type: " + s);
+}
+
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     auto cparams = llama_context_default_params();
 
@@ -1024,7 +1055,6 @@ struct llama_context_params llama_context_params_from_gpt_param
     cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.mul_mat_q = params.mul_mat_q;
     cparams.seed = params.seed;
-    cparams.f16_kv = params.memory_f16;
     cparams.logits_all = params.logits_all;
     cparams.embedding = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
@@ -1035,6 +1065,10 @@ struct llama_context_params llama_context_params_from_gpt_param
     cparams.yarn_beta_fast = params.yarn_beta_fast;
     cparams.yarn_beta_slow = params.yarn_beta_slow;
     cparams.yarn_orig_ctx = params.yarn_orig_ctx;
+    cparams.offload_kqv = !params.no_kv_offload;
+
+    cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
+    cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
 
     return cparams;
 }
@@ -1447,7 +1481,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     }
     fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
     fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
-    fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
     fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
     fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);

common/common.h

+5 -2
@@ -100,7 +100,6 @@ struct gpt_params {
     size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
 
     bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
-    bool memory_f16 = true; // use f16 instead of f32 for memory kv
     bool random_prompt = false; // do not randomize prompt if none provided
     bool use_color = false; // use color to distinguish generations and inputs
     bool interactive = false; // interactive mode
@@ -125,10 +124,14 @@ struct gpt_params {
     bool verbose_prompt = false; // print prompt tokens before generation
     bool infill = false; // use infill mode
     bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
+    bool no_kv_offload = false; // disable KV offloading
+
+    std::string cache_type_k = "f16"; // KV cache data type for the K
+    std::string cache_type_v = "f16"; // KV cache data type for the V
 
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector
-    std::string image = ""; // path to an image file
+    std::string image = ""; // path to an image file
 };
 
 bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);

examples/llama-bench/llama-bench.cpp

+91 -20
@@ -53,6 +53,13 @@ static std::vector<T> split(const std::string & str, char delim) {
     return values;
 }
 
+template<typename T, typename F>
+static std::vector<std::string> transform_to_str(const std::vector<T> & values, F f) {
+    std::vector<std::string> str_values;
+    std::transform(values.begin(), values.end(), std::back_inserter(str_values), f);
+    return str_values;
+}
+
 template<typename T>
 static T avg(const std::vector<T> & v) {
     if (v.empty()) {
@@ -126,7 +133,8 @@ struct cmd_params {
     std::vector<int> n_prompt;
     std::vector<int> n_gen;
     std::vector<int> n_batch;
-    std::vector<bool> f32_kv;
+    std::vector<ggml_type> type_k;
+    std::vector<ggml_type> type_v;
     std::vector<int> n_threads;
     std::vector<int> n_gpu_layers;
     std::vector<int> main_gpu;
@@ -142,7 +150,8 @@ static const cmd_params cmd_params_defaults = {
     /* n_prompt */ {512},
     /* n_gen */ {128},
     /* n_batch */ {512},
-    /* f32_kv */ {false},
+    /* type_k */ {GGML_TYPE_F16},
+    /* type_v */ {GGML_TYPE_F16},
     /* n_threads */ {get_num_physical_cores()},
     /* n_gpu_layers */ {99},
     /* main_gpu */ {0},
@@ -162,7 +171,8 @@ static void print_usage(int /* argc */, char ** argv) {
     printf(" -p, --n-prompt <n> (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
     printf(" -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
     printf(" -b, --batch-size <n> (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
-    printf(" --memory-f32 <0|1> (default: %s)\n", join(cmd_params_defaults.f32_kv, ",").c_str());
+    printf(" -ctk <t>, --cache-type-k <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
+    printf(" -ctv <t>, --cache-type-v <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
     printf(" -t, --threads <n> (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
     printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
     printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
@@ -173,9 +183,32 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
     printf("\n");
     printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
+}
 
+static ggml_type ggml_type_from_name(const std::string & s) {
+    if (s == "f16") {
+        return GGML_TYPE_F16;
+    }
+    if (s == "q8_0") {
+        return GGML_TYPE_Q8_0;
+    }
+    if (s == "q4_0") {
+        return GGML_TYPE_Q4_0;
+    }
+    if (s == "q4_1") {
+        return GGML_TYPE_Q4_1;
+    }
+    if (s == "q5_0") {
+        return GGML_TYPE_Q5_0;
+    }
+    if (s == "q5_1") {
+        return GGML_TYPE_Q5_1;
+    }
+
+    return GGML_TYPE_COUNT;
 }
 
+
 static cmd_params parse_cmd_params(int argc, char ** argv) {
     cmd_params params;
     std::string arg;
@@ -224,13 +257,38 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<int>(argv[i], split_delim);
             params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
-        } else if (arg == "--memory-f32") {
+        } else if (arg == "-ctk" || arg == "--cache-type-k") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            auto p = split<int>(argv[i], split_delim);
-            params.f32_kv.insert(params.f32_kv.end(), p.begin(), p.end());
+            auto p = split<std::string>(argv[i], split_delim);
+            std::vector<ggml_type> types;
+            for (const auto & t : p) {
+                ggml_type gt = ggml_type_from_name(t);
+                if (gt == GGML_TYPE_COUNT) {
+                    invalid_param = true;
+                    break;
+                }
+                types.push_back(gt);
+            }
+            params.type_k.insert(params.type_k.end(), types.begin(), types.end());
+        } else if (arg == "-ctv" || arg == "--cache-type-v") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = split<std::string>(argv[i], split_delim);
+            std::vector<ggml_type> types;
+            for (const auto & t : p) {
+                ggml_type gt = ggml_type_from_name(t);
+                if (gt == GGML_TYPE_COUNT) {
+                    invalid_param = true;
+                    break;
+                }
+                types.push_back(gt);
+            }
+            params.type_v.insert(params.type_v.end(), types.begin(), types.end());
         } else if (arg == "-t" || arg == "--threads") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -321,7 +379,8 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; }
     if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; }
     if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; }
-    if (params.f32_kv.empty()) { params.f32_kv = cmd_params_defaults.f32_kv; }
+    if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; }
+    if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; }
     if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
     if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
     if (params.mul_mat_q.empty()) { params.mul_mat_q = cmd_params_defaults.mul_mat_q; }
@@ -336,7 +395,8 @@ struct cmd_params_instance {
     int n_prompt;
     int n_gen;
    int n_batch;
-    bool f32_kv;
+    ggml_type type_k;
+    ggml_type type_v;
     int n_threads;
     int n_gpu_layers;
     int main_gpu;
@@ -365,7 +425,8 @@ struct cmd_params_instance {
 
     cparams.n_ctx = n_prompt + n_gen;
     cparams.n_batch = n_batch;
-    cparams.f16_kv = !f32_kv;
+    cparams.type_k = type_k;
+    cparams.type_v = type_v;
     cparams.mul_mat_q = mul_mat_q;
 
     return cparams;
@@ -380,15 +441,17 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
     for (const auto & mg : params.main_gpu)
     for (const auto & ts : params.tensor_split)
     for (const auto & nb : params.n_batch)
-    for (const auto & fk : params.f32_kv)
+    for (const auto & tk : params.type_k)
+    for (const auto & tv : params.type_v)
     for (const auto & mmq : params.mul_mat_q)
     for (const auto & nt : params.n_threads) {
         cmd_params_instance instance = {
             /* .model = */ m,
             /* .n_prompt = */ n_prompt,
             /* .n_gen = */ n_gen,
             /* .n_batch = */ nb,
-            /* .f32_kv = */ fk,
+            /* .type_k = */ tk,
+            /* .type_v = */ tv,
             /* .n_threads = */ nt,
             /* .n_gpu_layers = */ nl,
             /* .main_gpu = */ mg,
@@ -410,7 +473,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & mg : params.main_gpu)
     for (const auto & ts : params.tensor_split)
     for (const auto & nb : params.n_batch)
-    for (const auto & fk : params.f32_kv)
+    for (const auto & tk : params.type_k)
+    for (const auto & tv : params.type_v)
     for (const auto & mmq : params.mul_mat_q)
     for (const auto & nt : params.n_threads) {
         for (const auto & n_prompt : params.n_prompt) {
@@ -422,7 +486,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .n_prompt = */ n_prompt,
             /* .n_gen = */ 0,
             /* .n_batch = */ nb,
-            /* .f32_kv = */ fk,
+            /* .type_k = */ tk,
+            /* .type_v = */ tv,
             /* .n_threads = */ nt,
             /* .n_gpu_layers = */ nl,
             /* .main_gpu = */ mg,
@@ -441,7 +506,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .n_prompt = */ 0,
             /* .n_gen = */ n_gen,
             /* .n_batch = */ nb,
-            /* .f32_kv = */ fk,
+            /* .type_k = */ tk,
+            /* .type_v = */ tv,
             /* .n_threads = */ nt,
             /* .n_gpu_layers = */ nl,
             /* .main_gpu = */ mg,
@@ -489,7 +555,8 @@ struct test {
     uint64_t model_n_params;
     int n_batch;
     int n_threads;
-    bool f32_kv;
+    ggml_type type_k;
+    ggml_type type_v;
     int n_gpu_layers;
     int main_gpu;
     bool mul_mat_q;
@@ -508,7 +575,8 @@ struct test {
         model_n_params = llama_model_n_params(lmodel);
         n_batch = inst.n_batch;
         n_threads = inst.n_threads;
-        f32_kv = inst.f32_kv;
+        type_k = inst.type_k;
+        type_v = inst.type_v;
         n_gpu_layers = inst.n_gpu_layers;
         main_gpu = inst.main_gpu;
         mul_mat_q = inst.mul_mat_q;
@@ -571,7 +639,7 @@ struct test {
         "cuda", "opencl", "metal", "gpu_blas", "blas",
         "cpu_info", "gpu_info",
         "model_filename", "model_type", "model_size", "model_n_params",
-        "n_batch", "n_threads", "f16_kv",
+        "n_batch", "n_threads", "type_k", "type_v",
         "n_gpu_layers", "main_gpu", "mul_mat_q", "tensor_split",
         "n_prompt", "n_gen", "test_time",
         "avg_ns", "stddev_ns",
@@ -621,7 +689,7 @@ struct test {
         std::to_string(cuda), std::to_string(opencl), std::to_string(metal), std::to_string(gpu_blas), std::to_string(blas),
         cpu_info, gpu_info,
         model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
-        std::to_string(n_batch), std::to_string(n_threads), std::to_string(!f32_kv),
+        std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
         std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), tensor_split_str,
         std::to_string(n_prompt), std::to_string(n_gen), test_time,
         std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -805,8 +873,11 @@ struct markdown_printer : public printer {
         if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) {
             fields.push_back("n_batch");
         }
-        if (params.f32_kv.size() > 1 || params.f32_kv != cmd_params_defaults.f32_kv) {
-            fields.push_back("f16_kv");
+        if (params.type_k.size() > 1 || params.type_k != cmd_params_defaults.type_k) {
+            fields.push_back("type_k");
+        }
+        if (params.type_v.size() > 1 || params.type_v != cmd_params_defaults.type_v) {
+            fields.push_back("type_v");
         }
         if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) {
             fields.push_back("main_gpu");

examples/quantize-stats/quantize-stats.cpp

-1
@@ -321,7 +321,6 @@ int main(int argc, char ** argv) {
     auto cparams = llama_context_default_params();
     cparams.n_ctx = 256;
     cparams.seed = 1;
-    cparams.f16_kv = false;
 
     ctx = llama_new_context_with_model(model, cparams);
 
