@@ -672,6 +672,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
672
672
if (params.logdir .back () != DIRECTORY_SEPARATOR) {
673
673
params.logdir += DIRECTORY_SEPARATOR;
674
674
}
675
+ } else if (arg == " --save-all-logits" || arg == " --kl-divergence-base" ) {
676
+ if (++i >= argc) {
677
+ invalid_param = true ;
678
+ break ;
679
+ }
680
+ params.logits_file = argv[i];
675
681
} else if (arg == " --perplexity" || arg == " --all-logits" ) {
676
682
params.logits_all = true ;
677
683
} else if (arg == " --ppl-stride" ) {
@@ -716,6 +722,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
716
722
break ;
717
723
}
718
724
params.multiple_choice_tasks = std::stoi (argv[i]);
725
+ } else if (arg == " --kl-divergence" ) {
726
+ params.kl_divergence = true ;
719
727
} else if (arg == " --ignore-eos" ) {
720
728
params.ignore_eos = true ;
721
729
} else if (arg == " --no-penalize-nl" ) {
@@ -967,6 +975,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
967
975
printf (" --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: %zu)\n " , params.winogrande_tasks );
968
976
printf (" --multiple-choice compute multiple choice score over random tasks from datafile supplied with -f\n " );
969
977
printf (" --multiple-choice-tasks N number of tasks to use when computing the multiple choice score (default: %zu)\n " , params.multiple_choice_tasks ); // NOTE(review): %zu assumes multiple_choice_tasks has the same type as winogrande_tasks (size_t) — confirm
978
+ printf (" --kl-divergence computes KL-divergence to logits provided via --kl-divergence-base\n " );
970
979
printf (" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n " , params.n_keep );
971
980
printf (" --draft N number of tokens to draft for speculative decoding (default: %d)\n " , params.n_draft );
972
981
printf (" --chunks N max number of chunks to process (default: %d, -1 = all)\n " , params.n_chunks );
0 commit comments