Skip to content

Commit 871c241

Browse files
kalomaze, ggerganov, and cebtenzzre
authored and committed
samplers : Min-P sampler implementation [alternative to Top P/Top K] (ggml-org#3841)
* Introduce the new Min-P sampler by @kalomaze The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. * Min-P enabled and set to 0.05 default --------- Co-authored-by: Georgi Gerganov <[email protected]> Co-authored-by: cebtenzzre <[email protected]>
1 parent a94264a commit 871c241

File tree

6 files changed

+54
-2
lines changed

6 files changed

+54
-2
lines changed

Diff for: common/common.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
218218
break;
219219
}
220220
sparams.top_p = std::stof(argv[i]);
221+
} else if (arg == "--min-p") {
222+
if (++i >= argc) {
223+
invalid_param = true;
224+
break;
225+
}
226+
sparams.min_p = std::stof(argv[i]);
221227
} else if (arg == "--temp") {
222228
if (++i >= argc) {
223229
invalid_param = true;
@@ -679,6 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
679685
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
680686
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
681687
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
688+
printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
682689
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
683690
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
684691
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
@@ -1275,6 +1282,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
12751282
fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency());
12761283
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
12771284
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
1285+
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
12781286
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
12791287
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
12801288
}

Diff for: common/sampling.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
8989

9090
snprintf(result, sizeof(result),
9191
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
92-
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n"
92+
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
9393
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
9494
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
95-
params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp,
95+
params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
9696
params.mirostat, params.mirostat_eta, params.mirostat_tau);
9797

9898
return std::string(result);
@@ -110,6 +110,7 @@ llama_token llama_sampling_sample(
110110
const float temp = params.temp;
111111
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
112112
const float top_p = params.top_p;
113+
const float min_p = params.min_p;
113114
const float tfs_z = params.tfs_z;
114115
const float typical_p = params.typical_p;
115116
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
@@ -190,6 +191,7 @@ llama_token llama_sampling_sample(
190191
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
191192
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
192193
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
194+
llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
193195
llama_sample_temp (ctx_main, &cur_p, temp);
194196

195197
id = llama_sample_token(ctx_main, &cur_p);

Diff for: common/sampling.h

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ typedef struct llama_sampling_params {
1414
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
1515
int32_t top_k = 40; // <= 0 to use vocab size
1616
float top_p = 0.95f; // 1.0 = disabled
17+
float min_p = 0.05f; // 0.0 = disabled
1718
float tfs_z = 1.00f; // 1.0 = disabled
1819
float typical_p = 1.00f; // 1.0 = disabled
1920
float temp = 0.80f; // 1.0 = disabled

Diff for: examples/main/README.md

+8
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,14 @@ Top-p sampling, also known as nucleus sampling, is another text generation metho
208208

209209
Example usage: `--top-p 0.95`
210210

211+
### Min P Sampling
212+
213+
- `--min-p N`: Sets a minimum base probability threshold for token selection (default: 0.05).
214+
215+
The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with *p*=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.
216+
217+
Example usage: `--min-p 0.05`
218+
211219
### Tail Free Sampling (TFS)
212220

213221
- `--tfs N`: Enable tail free sampling with parameter z (default: 1.0, 1.0 = disabled).

Diff for: llama.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -7368,6 +7368,32 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
73687368
}
73697369
}
73707370

7371+
void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
7372+
if (p <= 0.0f || !candidates->size) {
7373+
return;
7374+
}
7375+
7376+
llama_sample_softmax(ctx, candidates);
7377+
7378+
const int64_t t_start_sample_us = ggml_time_us();
7379+
7380+
float scale = candidates->data[0].p; // scale by max prob
7381+
size_t i = 1; // first token always matches
7382+
7383+
for (; i < candidates->size; ++i) {
7384+
if (candidates->data[i].p < p * scale && i >= min_keep) {
7385+
break; // prob too small
7386+
}
7387+
}
7388+
7389+
// Resize the output vector to keep only the matching tokens
7390+
candidates->size = i;
7391+
7392+
if (ctx) {
7393+
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
7394+
}
7395+
}
7396+
73717397
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
73727398
if (z >= 1.0f || candidates->size <= 2) {
73737399
return;

Diff for: llama.h

+7
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,13 @@ extern "C" {
598598
float p,
599599
size_t min_keep);
600600

601+
/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
602+
LLAMA_API void llama_sample_min_p(
603+
struct llama_context * ctx,
604+
llama_token_data_array * candidates,
605+
float p,
606+
size_t min_keep);
607+
601608
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
602609
LLAMA_API void llama_sample_tail_free(
603610
struct llama_context * ctx,

0 commit comments

Comments
 (0)