This repository was archived by the owner on Feb 6, 2024. It is now read-only.

Commit 638c0ff

Allow returning probs w/ greedy sampling (negative temp)

* ggml-org/llama.cpp#3813

1 parent: b2f1d2c

File tree: 3 files changed, +8 −2 lines

Diff for: Sources/llmfarm_core_cpp/ggml/common.cpp (+1)

@@ -52,6 +52,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.top_p = std::stof(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "--temp") {
             params.temp = std::stof(get_next_arg(i, argc, argv, arg, params));
+            params.temp = std::max(params.temp, 0.0f);
         } else if (arg == "--repeat-last-n") {
             params.repeat_last_n = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "--repeat-penalty") {
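
This clamp applies only to the command-line parser: a negative value passed via --temp is raised to 0.0, so CLI users always get the plain greedy path, and the negative-temperature behavior added in sampling.cpp below remains reachable only by setting params.temp programmatically. A minimal standalone sketch of the clamp's effect (illustrative, not the project's actual parser):

#include <algorithm>
#include <cstdio>
#include <string>

int main() {
    // Simulate what gpt_params_parse now does with `--temp -1`:
    float temp = std::stof("-1");   // value as supplied on the command line
    temp = std::max(temp, 0.0f);    // new clamp: the CLI never goes below 0
    std::printf("effective temp = %.1f\n", temp);  // prints 0.0
    return 0;
}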

Diff for: Sources/llmfarm_core_cpp/ggml/sampling.cpp (+6 −2)

@@ -121,8 +121,12 @@ llama_token llama_sampling_sample(
             llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar);
         }
 
-        if (temp <= 0) {
-            // Greedy sampling
+        if (temp < 0.0) {
+            // greedy sampling, with probs
+            llama_sample_softmax(ctx_main, &cur_p);
+            id = cur_p.data[0].id;
+        } else if (temp == 0.0) {
+            // greedy sampling, no probs
             id = llama_sample_token_greedy(ctx, &cur_p);
         } else {
             if (mirostat == 1) {
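
The old code treated every temp <= 0 as plain greedy sampling. The new code splits that into two cases: temp < 0 runs llama_sample_softmax first, which sorts the candidates by probability (descending) and fills in their .p fields, then takes the head of the array, while temp == 0 keeps the cheaper llama_sample_token_greedy, an argmax over logits that computes no probabilities. Both branches return the same token; only the probability bookkeeping differs. A minimal sketch of the dispatch, assuming a candidate array already filled with the current logits (sample_with_temp is a hypothetical helper, and the temp > 0 branch stands in for the unchanged stochastic path):

#include "llama.h"

// Hypothetical helper: dispatch on temperature the way the patched
// llama_sampling_sample does. `ctx` and `cur_p` must already be
// populated by the caller.
static llama_token sample_with_temp(llama_context * ctx,
                                    llama_token_data_array & cur_p,
                                    float temp) {
    if (temp < 0.0f) {
        // Greedy, with probs: softmax sorts candidates by probability
        // (descending) and fills each .p field, so the full distribution
        // stays available to the caller afterwards.
        llama_sample_softmax(ctx, &cur_p);
        return cur_p.data[0].id;
    } else if (temp == 0.0f) {
        // Greedy, no probs: argmax over logits, .p fields stay unset.
        return llama_sample_token_greedy(ctx, &cur_p);
    } else {
        // Stochastic path (unchanged by this commit): scale logits by
        // the temperature, then sample from the distribution.
        llama_sample_temp(ctx, &cur_p, temp);
        return llama_sample_token(ctx, &cur_p);
    }
}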

Diff for: Sources/llmfarm_core_cpp/spm-headers/llama.h (+1)

@@ -661,6 +661,7 @@ extern "C" {
             float * mu);
 
     /// @details Selects the token with the highest probability.
+    ///          Does not compute the token probabilities. Use llama_sample_softmax() instead.
     LLAMA_API llama_token llama_sample_token_greedy(
             struct llama_context * ctx,
           llama_token_data_array * candidates);
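
The amended header comment warns that llama_sample_token_greedy never fills the candidates' .p fields; callers who want the winning token together with its probability should run llama_sample_softmax and read the head of the sorted array instead. A sketch contrasting the two options (pick_top and cands are illustrative names, not part of the API):

#include "llama.h"
#include <cstdio>

// Illustrative only: `ctx` and `cands` are assumed to be prepared by the
// caller with the current logits.
static void pick_top(llama_context * ctx, llama_token_data_array & cands) {
    // 1) Fast path: argmax over raw logits; probabilities are NOT computed.
    llama_token greedy_id = llama_sample_token_greedy(ctx, &cands);

    // 2) With probs: softmax normalizes and sorts (descending), so the
    //    head of the array is the same argmax token, now with a valid .p.
    llama_sample_softmax(ctx, &cands);
    std::printf("token %d, p = %f (greedy picked %d)\n",
                cands.data[0].id, cands.data[0].p, greedy_id);
}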
