
Commit c25105a

mscheong01 authored and jordankanter committed
speculative : implement stochastic speculative sampling (ggml-org#5625)
* (WIP) Implement stochastic speculative decoding
* sample from residual distribution on draft accept failure
* fix ggml-org#5657: force greedy sampling with probs when temp is 0
* remove p_accept parameter
* fix style
* remove unused variables
* add srand() in speculative.cpp
* replace use of rand() with mt19937 sampling
* fixes based on review (@JohannesGaessler)
* fix r random generation
* randomly select next sequence to verify + fix bug in memory freeing
* fix bug in active_seqs sync
* fix uniform int distribution initialization
* remove warnings from comparison between int and size_t
* check grammar in `llama_sample_probability_distribution_impl`
* remove malloc code by utilizing vectors
* add PR link to README
1 parent eebb866 commit c25105a
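For readers of the diff below, the rule this commit implements (as summarized in the commit message above) is: accept a drafted token with probability min(1, p_target / p_draft), and on rejection resample from the normalized residual distribution max(0, p_target - p_draft). The following is a minimal, self-contained C++ sketch of that acceptance step only; the function name, the vector-based distributions, and the RNG handling are illustrative assumptions, not code from this commit.

// Minimal sketch of stochastic draft-token acceptance (NOT the commit's code).
// Assumptions: p_tgt and p_dft are softmax-normalized distributions over the
// same vocabulary, tok is the token id proposed by the draft model, and rng is
// a seeded std::mt19937 (the commit replaces rand() with mt19937 sampling).
#include <algorithm>
#include <random>
#include <vector>

static int stochastic_accept(const std::vector<float> & p_tgt,
                             const std::vector<float> & p_dft,
                             int tok, std::mt19937 & rng) {
    std::uniform_real_distribution<float> unif(0.0f, 1.0f);

    // accept the drafted token with probability min(1, p_tgt[tok] / p_dft[tok])
    if (p_dft[tok] > 0.0f && unif(rng) < std::min(1.0f, p_tgt[tok] / p_dft[tok])) {
        return tok;
    }

    // on rejection, sample a replacement from the residual distribution
    // r(x) proportional to max(0, p_tgt[x] - p_dft[x])
    std::vector<float> residual(p_tgt.size());
    double sum = 0.0;
    for (size_t i = 0; i < p_tgt.size(); ++i) {
        residual[i] = std::max(0.0f, p_tgt[i] - p_dft[i]);
        sum += residual[i];
    }
    if (sum <= 0.0) {
        residual = p_tgt; // degenerate case: fall back to the target distribution
    }
    std::discrete_distribution<int> dist(residual.begin(), residual.end());
    return dist(rng);
}

In the actual change, the full target and draft distributions come from the llama_sampling_probability_distribution() helper added in common/sampling.cpp below.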

6 files changed: +256 additions, -57 deletions

common/common.cpp

-7
@@ -513,12 +513,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_sequences = std::stoi(argv[i]);
-        } else if (arg == "--p-accept" || arg == "-pa") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params.p_accept = std::stof(argv[i]);
         } else if (arg == "--p-split" || arg == "-ps") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1044,7 +1038,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     printf("  -np N, --parallel N   number of parallel sequences to decode (default: %d)\n", params.n_parallel);
     printf("  -ns N, --sequences N  number of sequences to decode (default: %d)\n", params.n_sequences);
-    printf("  -pa N, --p-accept N   speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
     printf("  -ps N, --p-split N    speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
     printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
     printf("  --mmproj MMPROJ_FILE  path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");

common/common.h

+1, -2
@@ -53,11 +53,10 @@ struct gpt_params {
     int32_t n_ctx               = 512;  // context size
     int32_t n_batch             = 512;  // batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep              = 0;    // number of tokens to keep from initial prompt
-    int32_t n_draft             = 8;    // number of tokens to draft during speculative decoding
+    int32_t n_draft             = 5;    // number of tokens to draft during speculative decoding
     int32_t n_chunks            = -1;   // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel          = 1;    // number of parallel sequences to decode
     int32_t n_sequences         = 1;    // number of sequences to decode
-    float   p_accept            = 0.5f; // speculative decoding accept probability
     float   p_split             = 0.1f; // speculative decoding split probability
     int32_t n_gpu_layers        = -1;   // number of layers to store in VRAM (-1 - use default)
     int32_t n_gpu_layers_draft  = -1;   // number of layers to store in VRAM for the draft model (-1 - use default)

common/sampling.cpp

+79
@@ -295,6 +295,77 @@ static llama_token llama_sampling_sample_impl(
     return id;
 }
 
+static llama_token_data_array llama_sample_probability_distribution_impl(
+                  struct llama_sampling_context * ctx_sampling,
+                  struct llama_context * ctx_main,
+                  struct llama_context * ctx_cfg,
+                  const int idx) {
+    const llama_sampling_params & params = ctx_sampling->params;
+
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
+    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
+    const float   penalty_repeat  = params.penalty_repeat;
+    const float   penalty_freq    = params.penalty_freq;
+    const float   penalty_present = params.penalty_present;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    auto & prev = ctx_sampling->prev;
+    auto & cur  = ctx_sampling->cur;
+
+    // Get a pointer to the logits
+    float * logits = llama_get_logits_ith(ctx_main, idx);
+
+    // Declare original_logits at the beginning of the function scope
+    std::vector<float> original_logits;
+
+    // apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    if (ctx_cfg) {
+        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+    }
+
+    cur.clear();
+
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+
+    // apply penalties
+    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+    if (penalty_tokens_used_size) {
+        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+
+        llama_sample_repetition_penalties(ctx_main, &cur_p,
+                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    // apply grammar checks
+    if (ctx_sampling->grammar != NULL) {
+        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+    }
+
+    llama_sample_softmax(ctx_main, &cur_p);
+    return cur_p;
+}
+
 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
@@ -304,6 +375,14 @@ llama_token llama_sampling_sample(
     return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
 }
 
+llama_token_data_array llama_sampling_probability_distribution(
+                  struct llama_sampling_context * ctx_sampling,
+                  struct llama_context * ctx_main,
+                  struct llama_context * ctx_cfg,
+                  const int idx) {
+    return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
+}
+
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
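For orientation, the new helper above builds the full token distribution (logit bias, guidance, penalties, grammar, softmax) rather than sampling a single token. Below is a hedged sketch of how a verification loop might consume it; the function, the parameter names (ctx_tgt, ctx_dft, idx_tgt, idx_dft, id_drafted), and the use of a single sampling context for both models are assumptions for illustration, not the commit's examples/speculative/speculative.cpp.

// Illustrative sketch (not the commit's code): verify one drafted token
// against the target model using the distribution helper shown above.
#include <algorithm>
#include <random>
#include "sampling.h" // common/sampling.h from the diff above (include path assumed)

static bool verify_drafted_token(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx_tgt,   // target-model context (assumed name)
        struct llama_context * ctx_dft,   // draft-model context (assumed name)
        int idx_tgt, int idx_dft,         // logits indices for this position (assumed)
        llama_token id_drafted,           // token proposed by the draft model
        std::mt19937 & rng) {
    // full (penalized, grammar-checked, softmaxed) distributions from both models
    llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, nullptr, idx_tgt);
    llama_token_data_array dist_dft = llama_sampling_probability_distribution(ctx_sampling, ctx_dft, nullptr, idx_dft);

    // llama_sample_softmax sorts candidates by probability, so look the
    // drafted token up by id rather than by index
    float p_tgt = 0.0f, p_dft = 0.0f;
    for (size_t i = 0; i < dist_tgt.size; ++i) {
        if (dist_tgt.data[i].id == id_drafted) { p_tgt = dist_tgt.data[i].p; break; }
    }
    for (size_t i = 0; i < dist_dft.size; ++i) {
        if (dist_dft.data[i].id == id_drafted) { p_dft = dist_dft.data[i].p; break; }
    }

    // accept with probability min(1, p_tgt / p_dft); on rejection the caller
    // would resample from the residual max(0, p_tgt - p_dft) and discard the
    // rest of the draft
    std::uniform_real_distribution<float> unif(0.0f, 1.0f);
    return p_dft > 0.0f && unif(rng) < std::min(1.0f, p_tgt / p_dft);
}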

common/sampling.h

+7
@@ -131,6 +131,13 @@ llama_token llama_sampling_sample(
         struct llama_context * ctx_cfg,
         int idx = 0);
 
+// returns the probability that token of given id will be sampled
+llama_token_data_array llama_sampling_probability_distribution(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        int idx = 0);
+
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,

examples/speculative/README.md

+1
@@ -6,3 +6,4 @@ More info:
 
 - https://github.com/ggerganov/llama.cpp/pull/2926
 - https://github.com/ggerganov/llama.cpp/pull/3624
+- https://github.com/ggerganov/llama.cpp/pull/5625
