@@ -219,16 +219,31 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
219
219
candidates->size = last_idx;
220
220
}
221
221
222
- void apply_penalties (int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array & candidates_p)
222
+ void sample_rep_pen (int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array * candidates_p)
223
223
{
224
224
auto last_n_repeat = std::min (std::min ((int )last_n_tokens.size (), rep_pen_range), n_ctx);
225
- llama_sample_repetition_penalty (nullptr , & candidates_p,
225
+ llama_sample_repetition_penalty (nullptr , candidates_p,
226
226
last_n_tokens.data () + last_n_tokens.size () - last_n_repeat,
227
227
last_n_repeat, rep_pen);
228
228
}
229
229
230
+ void sample_temperature (llama_token_data_array * candidates_p, float temp)
231
+ {
232
+ if (temp <= 0 )
233
+ {
234
+ // Imitate greedy sampling
235
+ temp = 0 .01f ; // cannot be zero else div0
236
+ llama_sample_temperature (nullptr , candidates_p, temp);
237
+ llama_sample_top_k (nullptr , candidates_p, 1 , 1 ); // only want first candidate
238
+ }
239
+ else
240
+ {
241
+ llama_sample_temperature (nullptr , candidates_p, temp);
242
+ }
243
+ }
244
+
230
245
int SampleLogits (const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
231
- int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const samplers sampler_order[KCPP_SAMPLER_MAX] )
246
+ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector< samplers> & sampler_order)
232
247
{
233
248
int id = 0 ;
234
249
std::vector<llama_token_data> candidates;
@@ -239,78 +254,54 @@ int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const sa
239
254
240
255
llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
241
256
242
- // Run this except for when we are going to do the sampler reordering case below
243
- if (temp <= 0 || mirostat > 0 || sampler_len == 0 )
244
- {
245
- apply_penalties (n_ctx, rep_pen_range, rep_pen, candidates_p);
246
- }
247
-
248
- // llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p,
249
- // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
250
- // last_n_repeat, alpha_frequency, alpha_presence);
251
-
252
- if (temp <= 0 )
253
- {
254
- // Greedy sampling
255
- id = llama_sample_token_greedy (nullptr , &candidates_p);
256
- }
257
- else
257
+ if (mirostat == 1 || mirostat == 2 )
258
258
{
259
+ static float mirostat_mu = 2 .0f * mirostat_tau;
260
+ const int mirostat_m = 100 ;
261
+ sample_rep_pen (n_ctx, rep_pen_range, rep_pen, &candidates_p);
262
+ sample_temperature (&candidates_p, temp);
259
263
if (mirostat == 1 )
260
264
{
261
- static float mirostat_mu = 2 .0f * mirostat_tau;
262
- const int mirostat_m = 100 ;
263
- llama_sample_temperature (nullptr , &candidates_p, temp);
264
265
id = sample_token_mirostat (n_vocab, &candidates_p, rng, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
265
266
}
266
- else if (mirostat == 2 )
267
+ else
267
268
{
268
- static float mirostat_mu = 2 .0f * mirostat_tau;
269
- llama_sample_temperature (nullptr , &candidates_p, temp);
270
269
id = sample_token_mirostat_v2 (&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu);
271
270
}
272
- else if (sampler_len > 0 )
271
+ }
272
+ else
273
+ {
274
+ for (int i = 0 ; i < sampler_order.size (); i++)
273
275
{
274
- for ( int i = 0 ; i < sampler_len; i++) {
275
- switch (sampler_order[i]) {
276
- case KCPP_SAMPLER_TOP_K:
277
- llama_sample_top_k (nullptr , &candidates_p, top_k,1 );
278
- break ;
279
- case KCPP_SAMPLER_TOP_A:
280
- sample_top_a (&candidates_p,top_a,1 );
281
- break ;
282
- case KCPP_SAMPLER_TOP_P:
283
- llama_sample_top_p (nullptr , &candidates_p, top_p,1 );
284
- break ;
285
- case KCPP_SAMPLER_TFS:
286
- llama_sample_tail_free (nullptr , &candidates_p, tfs,1 );
287
- break ;
288
- case KCPP_SAMPLER_TYP:
289
- llama_sample_typical (nullptr , &candidates_p, typical_p,1 );
290
- break ;
291
- case KCPP_SAMPLER_TEMP:
292
- llama_sample_temperature ( nullptr , &candidates_p, temp);
293
- break ;
294
- case KCPP_SAMPLER_REP_PEN:
295
- apply_penalties (n_ctx, rep_pen_range, rep_pen, candidates_p);
296
- break ;
297
- default :
298
- break ;
299
- }
276
+ switch (sampler_order[i])
277
+ {
278
+ case KCPP_SAMPLER_TOP_K:
279
+ llama_sample_top_k (nullptr , &candidates_p, top_k,1 );
280
+ break ;
281
+ case KCPP_SAMPLER_TOP_A:
282
+ sample_top_a (&candidates_p,top_a,1 );
283
+ break ;
284
+ case KCPP_SAMPLER_TOP_P:
285
+ llama_sample_top_p (nullptr , &candidates_p, top_p,1 );
286
+ break ;
287
+ case KCPP_SAMPLER_TFS:
288
+ llama_sample_tail_free (nullptr , &candidates_p, tfs,1 );
289
+ break ;
290
+ case KCPP_SAMPLER_TYP:
291
+ llama_sample_typical (nullptr , &candidates_p, typical_p,1 );
292
+ break ;
293
+ case KCPP_SAMPLER_TEMP:
294
+ sample_temperature ( &candidates_p, temp);
295
+ break ;
296
+ case KCPP_SAMPLER_REP_PEN:
297
+ sample_rep_pen (n_ctx, rep_pen_range, rep_pen, & candidates_p);
298
+ break ;
299
+ default :
300
+ printf ( " \n SampleLogits: Unknown Sampler : %d " ,sampler_order[i]) ;
301
+ break ;
300
302
}
301
- id = sample_token (&candidates_p, rng);
302
- }
303
- else
304
- {
305
- // Temperature sampling
306
- llama_sample_top_k (nullptr , &candidates_p, top_k,1 );
307
- sample_top_a (&candidates_p,top_a,1 );
308
- llama_sample_tail_free (nullptr , &candidates_p, tfs,1 );
309
- llama_sample_typical (nullptr , &candidates_p, typical_p,1 );
310
- llama_sample_top_p (nullptr , &candidates_p, top_p,1 );
311
- llama_sample_temperature (nullptr , &candidates_p, temp);
312
- id = sample_token (&candidates_p, rng);
313
303
}
304
+ id = sample_token (&candidates_p, rng);
314
305
}
315
306
316
307
return id;
@@ -952,6 +943,28 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
952
943
std::mt19937 rng (params.seed );
953
944
concat_output = " " ;
954
945
946
+ // prepare sampler order
947
+ std::vector<samplers> sampler_order;
948
+ if (inputs.sampler_len <=0 ) // list by value
949
+ {
950
+ sampler_order = {
951
+ KCPP_SAMPLER_REP_PEN,
952
+ KCPP_SAMPLER_TOP_K,
953
+ KCPP_SAMPLER_TOP_A,
954
+ KCPP_SAMPLER_TFS,
955
+ KCPP_SAMPLER_TYP,
956
+ KCPP_SAMPLER_TOP_P,
957
+ KCPP_SAMPLER_TEMP
958
+ };
959
+ }
960
+ else
961
+ {
962
+ for (int i=0 ;i<inputs.sampler_len ;++i)
963
+ {
964
+ sampler_order.push_back (inputs.sampler_order [i]);
965
+ }
966
+ }
967
+
955
968
bool startedsampling = false ;
956
969
bool use_scratch = true ; // for normal inference always use scratch
957
970
@@ -1274,8 +1287,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
1274
1287
1275
1288
id = SampleLogits (logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
1276
1289
top_k, top_a, top_p, typical_p, tfs_z, temp, rng,
1277
- params.mirostat , params.mirostat_tau , params.mirostat_eta ,
1278
- inputs.sampler_len , inputs.sampler_order );
1290
+ params.mirostat , params.mirostat_tau , params.mirostat_eta , sampler_order);
1279
1291
1280
1292
last_n_tokens.erase (last_n_tokens.begin ());
1281
1293
last_n_tokens.push_back (id);
0 commit comments