Skip to content

Commit e34dcc4

Browse files
kalomaze authored and ggerganov committed
grammar : check the full vocab only if necessary (opt) (ggml-org#4306)
* Check the full vocab for grammar only if necessary

* Fix missing logit restoration step (?) Does this matter, actually?

* Fix whitespace / formatting

* Adjust comment

* Didn't mean to push test gbnf

* Split sampling into the helper function (?) And also revert the changes made to the header

* common : fix final newline

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 2ea986c commit e34dcc4

File tree

1 file changed

+45
-3
lines changed

1 file changed

+45
-3
lines changed

common/sampling.cpp

+45-3
Original file line number | Diff line number | Diff line change
@@ -149,11 +149,12 @@ static void sampler_queue(
149149
}
150150
}
151151

152-
llama_token llama_sampling_sample(
152+
static llama_token llama_sampling_sample_impl(
153153
struct llama_sampling_context * ctx_sampling,
154154
struct llama_context * ctx_main,
155155
struct llama_context * ctx_cfg,
156-
const int idx) {
156+
const int idx,
157+
bool is_resampling) { // Add a parameter to indicate if we are resampling
157158
const llama_sampling_params & params = ctx_sampling->params;
158159

159160
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -173,8 +174,17 @@ llama_token llama_sampling_sample(
173174

174175
llama_token id = 0;
175176

177+
// Get a pointer to the logits
176178
float * logits = llama_get_logits_ith(ctx_main, idx);
177179

180+
// Declare original_logits at the beginning of the function scope
181+
std::vector<float> original_logits;
182+
183+
if (!is_resampling) {
184+
// Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
185+
original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
186+
}
187+
178188
// apply params.logit_bias map
179189
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
180190
logits[it->first] += it->second;
@@ -210,7 +220,8 @@ llama_token llama_sampling_sample(
210220
}
211221
}
212222

213-
if (ctx_sampling->grammar != NULL) {
223+
// If we are in the resampling phase, apply grammar checks before sampling logic
224+
if (is_resampling && ctx_sampling->grammar != NULL) {
214225
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
215226
}
216227

@@ -252,9 +263,40 @@ llama_token llama_sampling_sample(
252263
}
253264
}
254265

266+
if (ctx_sampling->grammar != NULL && !is_resampling) {
267+
// Create an array with a single token data element for the sampled id
268+
llama_token_data single_token_data = {id, logits[id], 0.0f};
269+
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
270+
271+
// Apply grammar constraints to the single token
272+
llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
273+
274+
// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
275+
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
276+
277+
// If the token is not valid according to the grammar, perform resampling
278+
if (!is_valid) {
279+
LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
280+
281+
// Restore logits from the copy
282+
std::copy(original_logits.begin(), original_logits.end(), logits);
283+
284+
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
285+
}
286+
}
287+
255288
return id;
256289
}
257290

291+
llama_token llama_sampling_sample(
292+
struct llama_sampling_context * ctx_sampling,
293+
struct llama_context * ctx_main,
294+
struct llama_context * ctx_cfg,
295+
const int idx) {
296+
// Call the implementation function with is_resampling set to false by default
297+
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
298+
}
299+
258300
void llama_sampling_accept(
259301
struct llama_sampling_context * ctx_sampling,
260302
struct llama_context * ctx_main,

0 commit comments

Comments (0)