@@ -149,11 +149,12 @@ static void sampler_queue(
149
149
}
150
150
}
151
151
152
- llama_token llama_sampling_sample (
152
+ static llama_token llama_sampling_sample_impl (
153
153
struct llama_sampling_context * ctx_sampling,
154
154
struct llama_context * ctx_main,
155
155
struct llama_context * ctx_cfg,
156
- const int idx) {
156
+ const int idx,
157
+ bool is_resampling) { // Add a parameter to indicate if we are resampling
157
158
const llama_sampling_params & params = ctx_sampling->params ;
158
159
159
160
const int n_vocab = llama_n_vocab (llama_get_model (ctx_main));
@@ -173,8 +174,17 @@ llama_token llama_sampling_sample(
173
174
174
175
llama_token id = 0 ;
175
176
177
+ // Get a pointer to the logits
176
178
float * logits = llama_get_logits_ith (ctx_main, idx);
177
179
180
+ // Declare original_logits at the beginning of the function scope
181
+ std::vector<float > original_logits;
182
+
183
+ if (!is_resampling) {
184
+ // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
185
+ original_logits = std::vector<float >(logits, logits + llama_n_vocab (llama_get_model (ctx_main)));
186
+ }
187
+
178
188
// apply params.logit_bias map
179
189
for (auto it = params.logit_bias .begin (); it != params.logit_bias .end (); it++) {
180
190
logits[it->first ] += it->second ;
@@ -210,7 +220,8 @@ llama_token llama_sampling_sample(
210
220
}
211
221
}
212
222
213
- if (ctx_sampling->grammar != NULL ) {
223
+ // If we are in the resampling phase, apply grammar checks before sampling logic
224
+ if (is_resampling && ctx_sampling->grammar != NULL ) {
214
225
llama_sample_grammar (ctx_main, &cur_p, ctx_sampling->grammar );
215
226
}
216
227
@@ -252,9 +263,40 @@ llama_token llama_sampling_sample(
252
263
}
253
264
}
254
265
266
+ if (ctx_sampling->grammar != NULL && !is_resampling) {
267
+ // Create an array with a single token data element for the sampled id
268
+ llama_token_data single_token_data = {id, logits[id], 0 .0f };
269
+ llama_token_data_array single_token_data_array = { &single_token_data, 1 , false };
270
+
271
+ // Apply grammar constraints to the single token
272
+ llama_sample_grammar (ctx_main, &single_token_data_array, ctx_sampling->grammar );
273
+
274
+ // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
275
+ bool is_valid = single_token_data_array.data [0 ].logit != -INFINITY;
276
+
277
+ // If the token is not valid according to the grammar, perform resampling
278
+ if (!is_valid) {
279
+ LOG (" Resampling because token %d: '%s' does not meet grammar rules\n " , id, llama_token_to_piece (ctx_main, id).c_str ());
280
+
281
+ // Restore logits from the copy
282
+ std::copy (original_logits.begin (), original_logits.end (), logits);
283
+
284
+ return llama_sampling_sample_impl (ctx_sampling, ctx_main, ctx_cfg, idx, true ); // Pass true for is_resampling
285
+ }
286
+ }
287
+
255
288
return id;
256
289
}
257
290
291
+ llama_token llama_sampling_sample ( // Public entry point: forwards every argument unchanged to llama_sampling_sample_impl and returns the token id it produces.
292
+ struct llama_sampling_context * ctx_sampling,
293
+ struct llama_context * ctx_main,
294
+ struct llama_context * ctx_cfg,
295
+ const int idx) { // the impl passes idx to llama_get_logits_ith(ctx_main, idx) to select which logits to sample from
296
+ // Start with is_resampling == false: in that mode the impl samples first and only afterwards checks the chosen token against the grammar (if one is set); when that token is rejected it restores the saved logits and calls itself again with is_resampling == true, which applies the grammar to the candidates before sampling.
297
+ return llama_sampling_sample_impl (ctx_sampling, ctx_main, ctx_cfg, idx, false );
298
+ }
299
+
258
300
void llama_sampling_accept (
259
301
struct llama_sampling_context * ctx_sampling,
260
302
struct llama_context * ctx_main,
0 commit comments