Skip to content

Commit 3c46719

Browse files
committed
sampler : API to iterate constraints
ggml-ci
1 parent 23f0802 commit 3c46719

File tree

10 files changed

+69
-48
lines changed

10 files changed

+69
-48
lines changed

common/sampling.cpp

+17-22
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ struct gpt_sampler {
1212
struct llama_sampler * smpl;
1313
};
1414

15-
std::string gpt_sampler_params::print_all() const {
15+
std::string gpt_sampler_params::print() const {
1616
char result[1024];
1717

1818
snprintf(result, sizeof(result),
@@ -26,17 +26,12 @@ std::string gpt_sampler_params::print_all() const {
2626
return std::string(result);
2727
}
2828

29-
std::string gpt_sampler_params::print_constraints() const {
30-
std::string result = "CFG -> Penalties ";
31-
if (mirostat == 0) {
32-
for (const auto & cnstr : constraints) {
33-
const auto name = gpt_constraint_type_to_str(cnstr);
34-
if (!name.empty()) {
35-
result += "-> " + name + " ";
36-
}
37-
}
38-
} else {
39-
result += "-> mirostat ";
29+
std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
30+
std::string result = "\tlogits";
31+
32+
for (int i = 0; i < llama_sampler_n_constraints(gsmpl->smpl); i++) {
33+
const auto * cnstr = llama_sampler_constraint_get(gsmpl->smpl, i);
34+
result += " -> " + std::string(cnstr->iface->name(cnstr)) + " ";
4035
}
4136

4237
return result;
@@ -70,33 +65,33 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
7065
for (const auto & cnstr : params.constraints) {
7166
switch (cnstr) {
7267
case GPT_CONSTRAINT_TYPE_TOP_K:
73-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep));
68+
llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep));
7469
break;
7570
case GPT_CONSTRAINT_TYPE_TOP_P:
76-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep));
71+
llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep));
7772
break;
7873
case GPT_CONSTRAINT_TYPE_MIN_P:
79-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep));
74+
llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep));
8075
break;
8176
case GPT_CONSTRAINT_TYPE_TFS_Z:
82-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
77+
llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
8378
break;
8479
case GPT_CONSTRAINT_TYPE_TYPICAL_P:
85-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep));
80+
llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep));
8681
break;
8782
case GPT_CONSTRAINT_TYPE_TEMPERATURE:
88-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
83+
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
8984
break;
9085
default:
9186
GGML_ASSERT(false && "unknown constraint type");
9287
}
9388
}
9489
} else if (params.mirostat == 1) {
95-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
96-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
90+
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
91+
llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
9792
} else if (params.mirostat == 2) {
98-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
99-
llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
93+
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
94+
llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
10095
} else {
10196
GGML_ASSERT(false && "unknown mirostat version");
10297
}

common/sampling.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,7 @@ struct gpt_sampler_params {
5454
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
5555

5656
// print the parameters into a string
57-
std::string print_all() const;
58-
59-
// print the constraints into a string
60-
std::string print_constraints() const;
57+
std::string print() const;
6158
};
6259

6360
// gpt_sampler extends llama_sampler with additional functionality:
@@ -100,6 +97,9 @@ llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_da
10097

10198
// helpers
10299

100+
// print the constraints into a string
101+
std::string gpt_sampler_print(const struct gpt_sampler * gsmpl);
102+
103103
// get a string representation of the last accepted tokens
104104
std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n);
105105

examples/batched.swift/Sources/main.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ defer {
6161
llama_sampler_free(smpl)
6262
}
6363

64-
llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));
65-
llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.9, 1));
66-
llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.4));
64+
llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40, 1));
65+
llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1));
66+
llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4));
6767

6868
let n_ctx = llama_n_ctx(context)
6969

examples/batched/batched.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ int main(int argc, char ** argv) {
7070

7171
llama_sampler * smpl = llama_sampler_init(model, sparams);
7272

73-
llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep));
74-
llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep));
75-
llama_sampler_add_constraint(smpl, llama_constraint_init_temp (params.sparams.temp));
73+
llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep));
74+
llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep));
75+
llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp));
7676

7777
if (ctx == NULL) {
7878
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);

examples/infill/infill.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ int main(int argc, char ** argv) {
301301
LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
302302
}
303303
}
304-
LOG_TEE("sampling: \n%s\n", sparams.print_all().c_str());
304+
LOG_TEE("sampling: \n%s\n", sparams.print().c_str());
305305
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
306306
LOG_TEE("\n\n");
307307

examples/main/main.cpp

+9-8
Original file line numberDiff line numberDiff line change
@@ -457,8 +457,15 @@ int main(int argc, char ** argv) {
457457
}
458458
}
459459
}
460-
LOG_TEE("sampling params: \n%s\n", sparams.print_all().c_str());
461-
LOG_TEE("sampling constr: \n%s\n", sparams.print_constraints().c_str());
460+
461+
smpl = gpt_sampler_init(model, sparams);
462+
if (!smpl) {
463+
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
464+
exit(1);
465+
}
466+
467+
LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());
468+
LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
462469
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
463470

464471
// group-attention state
@@ -525,12 +532,6 @@ int main(int argc, char ** argv) {
525532
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
526533
}
527534

528-
smpl = gpt_sampler_init(model, sparams);
529-
if (!smpl) {
530-
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
531-
exit(1);
532-
}
533-
534535
if (llama_model_has_encoder(model)) {
535536
int enc_input_size = embd_inp.size();
536537
llama_token * enc_input_buf = embd_inp.data();

include/llama.h

+6-3
Original file line numberDiff line numberDiff line change
@@ -1010,7 +1010,7 @@ extern "C" {
10101010
// The llama_sampler object contains the entire sampling information:
10111011
//
10121012
// - RNG state (seed and generator)
1013-
// - Custom set of constraints (see llama_sampler_add_constraint)
1013+
// - Custom set of constraints (see llama_sampler_constraint_add)
10141014
// - Sampling method (greedy, dist)
10151015
// - Previous tokens
10161016
//
@@ -1081,7 +1081,7 @@ extern "C" {
10811081

10821082
LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr);
10831083

1084-
// important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_add_constraint)
1084+
// important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_constraint_add)
10851085
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);
10861086

10871087
LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token);
@@ -1100,7 +1100,10 @@ extern "C" {
11001100
LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl);
11011101

11021102
// important: takes ownership of the constraint object and will free it in llama_sampler_free
1103-
LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr);
1103+
LLAMA_API void llama_sampler_constraint_add( struct llama_sampler * smpl, struct llama_constraint * cnstr);
1104+
LLAMA_API int llama_sampler_n_constraints (const struct llama_sampler * smpl);
1105+
LLAMA_API struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i);
1106+
11041107

11051108
LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
11061109
LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p);

src/llama-sampling.cpp

+13-1
Original file line numberDiff line numberDiff line change
@@ -1215,10 +1215,22 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) {
12151215
// TODO: should we reset the timings?
12161216
}
12171217

1218-
void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) {
1218+
void llama_sampler_constraint_add_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) {
12191219
smpl.constraints.push_back(cnstr);
12201220
}
12211221

1222+
int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl) {
1223+
return smpl.constraints.size();
1224+
}
1225+
1226+
struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith) {
1227+
if (ith < 0 || ith >= (int) smpl.constraints.size()) {
1228+
return nullptr;
1229+
}
1230+
1231+
return smpl.constraints[ith];
1232+
}
1233+
12221234
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) {
12231235
smpl.prev.push_back(token);
12241236

src/llama-sampling.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ void llama_sampler_free_impl ( struct llama_sampler * smp
109109
struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl);
110110
void llama_sampler_reset_impl( struct llama_sampler & smpl);
111111

112-
void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr);
112+
void llama_sampler_constraint_add_impl( struct llama_sampler & smpl, struct llama_constraint * cnstr);
113+
int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl);
114+
struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith);
113115

114116
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token);
115117
void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * cur_p);

src/llama.cpp

+10-2
Original file line numberDiff line numberDiff line change
@@ -20699,8 +20699,16 @@ llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smp
2069920699
return &smpl->cur_p;
2070020700
}
2070120701

20702-
void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr) {
20703-
llama_sampler_add_constraint_impl(*smpl, cnstr);
20702+
void llama_sampler_constraint_add(struct llama_sampler * smpl, struct llama_constraint * cnstr) {
20703+
llama_sampler_constraint_add_impl(*smpl, cnstr);
20704+
}
20705+
20706+
int llama_sampler_n_constraints (const struct llama_sampler * smpl) {
20707+
return llama_sampler_n_constraints_impl(*smpl);
20708+
}
20709+
20710+
struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i) {
20711+
return llama_sampler_constraint_get_impl(*smpl, i);
2070420712
}
2070520713

2070620714
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {

0 commit comments

Comments
 (0)