Skip to content

Commit 6f9939d

Browse files
ikawrakowKawrakow
andauthored
KL-divergence (#5076)
* kl-divergence: be able to save all logits to a file * Add ability to compute KL-divergence --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 780e24a commit 6f9939d

File tree

3 files changed

+329
-2
lines changed

3 files changed

+329
-2
lines changed

common/common.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
672672
if (params.logdir.back() != DIRECTORY_SEPARATOR) {
673673
params.logdir += DIRECTORY_SEPARATOR;
674674
}
675+
} else if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
676+
if (++i >= argc) {
677+
invalid_param = true;
678+
break;
679+
}
680+
params.logits_file = argv[i];
675681
} else if (arg == "--perplexity" || arg == "--all-logits") {
676682
params.logits_all = true;
677683
} else if (arg == "--ppl-stride") {
@@ -716,6 +722,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
716722
break;
717723
}
718724
params.multiple_choice_tasks = std::stoi(argv[i]);
725+
} else if (arg == "--kl-divergence") {
726+
params.kl_divergence = true;
719727
} else if (arg == "--ignore-eos") {
720728
params.ignore_eos = true;
721729
} else if (arg == "--no-penalize-nl") {
@@ -967,6 +975,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
967975
printf(" --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: %zu)\n", params.winogrande_tasks);
968976
printf(" --multiple-choice compute multiple choice score over random tasks from datafile supplied with -f\n");
969977
printf(" --multiple-choice-tasks N number of tasks to use when computing the multiple choice score (default: %zu)\n", params.winogrande_tasks);
978+
printf(" --kl-divergence computes KL-divergence to logits provided via --kl-divergence-base");
970979
printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
971980
printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
972981
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);

common/common.h

+3
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ struct gpt_params {
9191
std::string input_suffix = ""; // string to suffix user inputs with
9292
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
9393
std::string logdir = ""; // directory in which to save YAML log files
94+
std::string logits_file = ""; // file for saving *all* logits
9495

9596
std::vector<llama_model_kv_override> kv_overrides;
9697

@@ -111,6 +112,8 @@ struct gpt_params {
111112
bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt
112113
size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
113114

115+
bool kl_divergence = false; // compute KL-divergence
116+
114117
bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
115118
bool random_prompt = false; // do not randomize prompt if none provided
116119
bool use_color = false; // use color to distinguish generations and inputs

0 commit comments

Comments
 (0)