Skip to content

Commit c6c0afd

Browse files
committed
refactor to avoid code duplication
1 parent 309534d commit c6c0afd

File tree

3 files changed

+88
-76
lines changed

3 files changed

+88
-76
lines changed

expose.h

+7-7
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ const int stop_token_max = 10;
44
// match kobold's sampler list and order
55
enum samplers
66
{
7-
KCPP_SAMPLER_TOP_K,
8-
KCPP_SAMPLER_TOP_A,
9-
KCPP_SAMPLER_TOP_P,
10-
KCPP_SAMPLER_TFS,
11-
KCPP_SAMPLER_TYP,
12-
KCPP_SAMPLER_TEMP,
13-
KCPP_SAMPLER_REP_PEN,
7+
KCPP_SAMPLER_TOP_K=0,
8+
KCPP_SAMPLER_TOP_A=1,
9+
KCPP_SAMPLER_TOP_P=2,
10+
KCPP_SAMPLER_TFS=3,
11+
KCPP_SAMPLER_TYP=4,
12+
KCPP_SAMPLER_TEMP=5,
13+
KCPP_SAMPLER_REP_PEN=6,
1414
KCPP_SAMPLER_MAX
1515
};
1616
struct load_model_inputs

gpttype_adapter.cpp

+78-66
Original file line numberDiff line numberDiff line change
@@ -219,16 +219,31 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
219219
candidates->size = last_idx;
220220
}
221221

222-
void apply_penalties(int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array & candidates_p)
222+
void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array * candidates_p)
223223
{
224224
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
225-
llama_sample_repetition_penalty(nullptr, &candidates_p,
225+
llama_sample_repetition_penalty(nullptr, candidates_p,
226226
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
227227
last_n_repeat, rep_pen);
228228
}
229229

230+
void sample_temperature(llama_token_data_array * candidates_p, float temp)
231+
{
232+
if (temp <= 0)
233+
{
234+
// Imitate greedy sampling
235+
temp = 0.01f; //cannot be zero else div0
236+
llama_sample_temperature(nullptr, candidates_p, temp);
237+
llama_sample_top_k(nullptr, candidates_p, 1, 1); //only want first candidate
238+
}
239+
else
240+
{
241+
llama_sample_temperature(nullptr, candidates_p, temp);
242+
}
243+
}
244+
230245
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
231-
int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const samplers sampler_order[KCPP_SAMPLER_MAX])
246+
int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order)
232247
{
233248
int id = 0;
234249
std::vector<llama_token_data> candidates;
@@ -239,78 +254,54 @@ int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const sa
239254

240255
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
241256

242-
// Run this except for when we are going to do the sampler reordering case below
243-
if (temp <= 0 || mirostat > 0 || sampler_len == 0)
244-
{
245-
apply_penalties(n_ctx, rep_pen_range, rep_pen, candidates_p);
246-
}
247-
248-
// llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p,
249-
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
250-
// last_n_repeat, alpha_frequency, alpha_presence);
251-
252-
if (temp <= 0)
253-
{
254-
// Greedy sampling
255-
id = llama_sample_token_greedy(nullptr, &candidates_p);
256-
}
257-
else
257+
if (mirostat == 1 || mirostat == 2)
258258
{
259+
static float mirostat_mu = 2.0f * mirostat_tau;
260+
const int mirostat_m = 100;
261+
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, &candidates_p);
262+
sample_temperature(&candidates_p, temp);
259263
if (mirostat == 1)
260264
{
261-
static float mirostat_mu = 2.0f * mirostat_tau;
262-
const int mirostat_m = 100;
263-
llama_sample_temperature(nullptr, &candidates_p, temp);
264265
id = sample_token_mirostat(n_vocab, &candidates_p, rng, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
265266
}
266-
else if (mirostat == 2)
267+
else
267268
{
268-
static float mirostat_mu = 2.0f * mirostat_tau;
269-
llama_sample_temperature(nullptr, &candidates_p, temp);
270269
id = sample_token_mirostat_v2(&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu);
271270
}
272-
else if (sampler_len > 0)
271+
}
272+
else
273+
{
274+
for (int i = 0; i < sampler_order.size(); i++)
273275
{
274-
for (int i = 0; i < sampler_len; i++) {
275-
switch (sampler_order[i]) {
276-
case KCPP_SAMPLER_TOP_K:
277-
llama_sample_top_k(nullptr, &candidates_p, top_k,1);
278-
break;
279-
case KCPP_SAMPLER_TOP_A:
280-
sample_top_a(&candidates_p,top_a,1);
281-
break;
282-
case KCPP_SAMPLER_TOP_P:
283-
llama_sample_top_p(nullptr, &candidates_p, top_p,1);
284-
break;
285-
case KCPP_SAMPLER_TFS:
286-
llama_sample_tail_free(nullptr, &candidates_p, tfs,1);
287-
break;
288-
case KCPP_SAMPLER_TYP:
289-
llama_sample_typical(nullptr, &candidates_p, typical_p,1);
290-
break;
291-
case KCPP_SAMPLER_TEMP:
292-
llama_sample_temperature(nullptr, &candidates_p, temp);
293-
break;
294-
case KCPP_SAMPLER_REP_PEN:
295-
apply_penalties(n_ctx, rep_pen_range, rep_pen, candidates_p);
296-
break;
297-
default:
298-
break;
299-
}
276+
switch (sampler_order[i])
277+
{
278+
case KCPP_SAMPLER_TOP_K:
279+
llama_sample_top_k(nullptr, &candidates_p, top_k,1);
280+
break;
281+
case KCPP_SAMPLER_TOP_A:
282+
sample_top_a(&candidates_p,top_a,1);
283+
break;
284+
case KCPP_SAMPLER_TOP_P:
285+
llama_sample_top_p(nullptr, &candidates_p, top_p,1);
286+
break;
287+
case KCPP_SAMPLER_TFS:
288+
llama_sample_tail_free(nullptr, &candidates_p, tfs,1);
289+
break;
290+
case KCPP_SAMPLER_TYP:
291+
llama_sample_typical(nullptr, &candidates_p, typical_p,1);
292+
break;
293+
case KCPP_SAMPLER_TEMP:
294+
sample_temperature(&candidates_p, temp);
295+
break;
296+
case KCPP_SAMPLER_REP_PEN:
297+
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, &candidates_p);
298+
break;
299+
default:
300+
printf("\nSampleLogits: Unknown Sampler : %d",sampler_order[i]);
301+
break;
300302
}
301-
id = sample_token(&candidates_p, rng);
302-
}
303-
else
304-
{
305-
// Temperature sampling
306-
llama_sample_top_k(nullptr, &candidates_p, top_k,1);
307-
sample_top_a(&candidates_p,top_a,1);
308-
llama_sample_tail_free(nullptr, &candidates_p, tfs,1);
309-
llama_sample_typical(nullptr, &candidates_p, typical_p,1);
310-
llama_sample_top_p(nullptr, &candidates_p, top_p,1);
311-
llama_sample_temperature(nullptr, &candidates_p, temp);
312-
id = sample_token(&candidates_p, rng);
313303
}
304+
id = sample_token(&candidates_p, rng);
314305
}
315306

316307
return id;
@@ -952,6 +943,28 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
952943
std::mt19937 rng(params.seed);
953944
concat_output = "";
954945

946+
//prepare sampler order
947+
std::vector<samplers> sampler_order;
948+
if(inputs.sampler_len<=0) //list by value
949+
{
950+
sampler_order = {
951+
KCPP_SAMPLER_REP_PEN,
952+
KCPP_SAMPLER_TOP_K,
953+
KCPP_SAMPLER_TOP_A,
954+
KCPP_SAMPLER_TFS,
955+
KCPP_SAMPLER_TYP,
956+
KCPP_SAMPLER_TOP_P,
957+
KCPP_SAMPLER_TEMP
958+
};
959+
}
960+
else
961+
{
962+
for(int i=0;i<inputs.sampler_len;++i)
963+
{
964+
sampler_order.push_back(inputs.sampler_order[i]);
965+
}
966+
}
967+
955968
bool startedsampling = false;
956969
bool use_scratch = true; //for normal inference always use scratch
957970

@@ -1274,8 +1287,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
12741287

12751288
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
12761289
top_k, top_a, top_p, typical_p, tfs_z, temp, rng,
1277-
params.mirostat, params.mirostat_tau, params.mirostat_eta,
1278-
inputs.sampler_len, inputs.sampler_order);
1290+
params.mirostat, params.mirostat_tau, params.mirostat_eta, sampler_order);
12791291

12801292
last_n_tokens.erase(last_n_tokens.begin());
12811293
last_n_tokens.push_back(id);

koboldcpp.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def load_model(model_filename):
189189
ret = handle.load_model(inputs)
190190
return ret
191191

192-
def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=None, seed=-1, stop_sequence=[], stream_sse=False):
192+
def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], stream_sse=False):
193193
inputs = generation_inputs()
194194
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
195195
inputs.prompt = prompt.encode("UTF-8")
@@ -289,7 +289,7 @@ def run_blocking():
289289
mirostat=genparams.get('mirostat', 0),
290290
mirostat_tau=genparams.get('mirostat_tau', 5.0),
291291
mirostat_eta=genparams.get('mirostat_eta', 0.1),
292-
sampler_order=genparams.get('sampler_order', None),
292+
sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
293293
seed=genparams.get('sampler_seed', -1),
294294
stop_sequence=genparams.get('stop_sequence', []),
295295
stream_sse=stream_flag)
@@ -309,7 +309,7 @@ def run_blocking():
309309
mirostat=genparams.get('mirostat', 0),
310310
mirostat_tau=genparams.get('mirostat_tau', 5.0),
311311
mirostat_eta=genparams.get('mirostat_eta', 0.1),
312-
sampler_order=genparams.get('sampler_order', None),
312+
sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
313313
seed=genparams.get('sampler_seed', -1),
314314
stop_sequence=genparams.get('stop_sequence', []),
315315
stream_sse=stream_flag)

0 commit comments

Comments
 (0)