@@ -1029,23 +1029,33 @@ extern "C" {
1029
1029
1030
1030
LLAMA_API void llama_sampling_free (struct llama_sampling * smpl);
1031
1031
1032
+ // Copies the internal state of the sampler (rng, prev, params, grammar, etc.)
1032
1033
LLAMA_API struct llama_sampling * llama_sampling_cp (const struct llama_sampling * smpl);
1033
1034
1034
1035
// - clear prev token
1035
1036
// - reset grammar state
1036
1037
LLAMA_API void llama_sampling_reset (struct llama_sampling * smpl);
1037
1038
1038
- LLAMA_API void llama_sampling_set_rng_seed (struct llama_sampling * smpl, uint32_t seed);
1039
+ // Sampling parameter mutation
1040
+ // TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable
1039
1041
LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
1040
1042
LLAMA_API void llama_sampling_set_logit_bias (struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);
1041
1043
1044
+ // Set the logits from which to sample.
1045
+ // This call initializes the internal token candidates array.
1046
+ // The internal candidates are implicitly used by the sampling API below when no candidates are provided.
1042
1047
LLAMA_API void llama_sampling_set_logits (
1043
1048
struct llama_sampling * smpl,
1044
1049
const float * logits);
1045
1050
1051
+ // / @details Returns the current candidate tokens.
1046
1052
LLAMA_API llama_token_data_array * llama_sampling_get_candidates (
1047
1053
struct llama_sampling * smpl);
1048
1054
1055
+ // The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object.
1056
+ // Each function can accept an array of token candidates. If the candidates are not provided, the internal
1057
+ // candidates are used. The internal candidates are initialized by llama_sampling_set_logits().
1058
+
1049
1059
// / @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
1050
1060
LLAMA_API void llama_sampling_softmax (
1051
1061
struct llama_sampling * smpl,
@@ -1108,17 +1118,22 @@ extern "C" {
1108
1118
struct llama_sampling * smpl,
1109
1119
llama_token_data_array * candidates);
1110
1120
1111
- // / @details Sample a token using the configured samplers.
1121
+ // / @details Sample a token using the configured samplers (see "llama_sampling_params.samplers") .
1112
1122
LLAMA_API llama_token llama_sampling_sample (
1113
1123
struct llama_sampling * smpl,
1114
1124
llama_token_data_array * candidates);
1115
1125
1116
- // / @details Accepts the sampled token into the sampling context
1126
+ // / @details Accepts the sampled token into the sampling context.
1127
+ // / - adds it to "prev" tokens
1128
+ // / - updates the grammar state (if apply_grammar is true)
1117
1129
LLAMA_API void llama_sampling_accept (
1118
1130
struct llama_sampling * smpl,
1119
1131
llama_token token,
1120
1132
bool apply_grammar);
1121
1133
1134
+ // / @details Get the number of accepted tokens so far (max of n_prev)
1135
+ LLAMA_API int llama_sampling_n_prev (const struct llama_sampling * smpl);
1136
+
1122
1137
// / @details Get the ith accepted token
1123
1138
// / @param ith [0, n_prev), ith == 0 is the last accepted token.
1124
1139
// / returns LLAMA_TOKEN_NULL if ith is out of bounds
@@ -1131,9 +1146,6 @@ extern "C" {
1131
1146
// / returns LLAMA_TOKEN_NULL if there are no accepted tokens
1132
1147
LLAMA_API llama_token llama_sampling_last (const struct llama_sampling * smpl);
1133
1148
1134
- // / @details Get the number of accepted tokens (max of n_prev)
1135
- LLAMA_API int llama_sampling_n_prev (const struct llama_sampling * smpl);
1136
-
1137
1149
//
1138
1150
// Model split
1139
1151
//
0 commit comments