Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit a5d664c

Browse files
committedAug 29, 2024··
llama : add comments [no ci]
1 parent b4fbb2a commit a5d664c

File tree

4 files changed

+30
-19
lines changed

4 files changed

+30
-19
lines changed
 

‎include/llama.h

+18-6
Original file line numberDiff line numberDiff line change
@@ -1029,23 +1029,33 @@ extern "C" {
10291029

10301030
LLAMA_API void llama_sampling_free(struct llama_sampling * smpl);
10311031

1032+
// Copies the internal state of the sampler (rng, prev, params, grammar, etc.)
10321033
LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl);
10331034

10341035
// - clear prev token
10351036
// - reset grammar state
10361037
LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl);
10371038

1038-
LLAMA_API void llama_sampling_set_rng_seed (struct llama_sampling * smpl, uint32_t seed);
1039+
// Sampling parameter mutation
1040+
// TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable
10391041
LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
10401042
LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);
10411043

1044+
// Set the logits from which to sample.
1045+
// This call initializes the internal token candidates array.
1046+
// The internal candidates are implicitly used by the sampling API below when no candidates are provided.
10421047
LLAMA_API void llama_sampling_set_logits(
10431048
struct llama_sampling * smpl,
10441049
const float * logits);
10451050

1051+
/// @details Returns the current candidate tokens.
10461052
LLAMA_API llama_token_data_array * llama_sampling_get_candidates(
10471053
struct llama_sampling * smpl);
10481054

1055+
// The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object.
1056+
// Each function can accept an array of token candidates. If the candidates are not provided, the internal
1057+
// candidates are used. The internal candidates are initialized by llama_sampling_set_logits().
1058+
10491059
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
10501060
LLAMA_API void llama_sampling_softmax(
10511061
struct llama_sampling * smpl,
@@ -1108,17 +1118,22 @@ extern "C" {
11081118
struct llama_sampling * smpl,
11091119
llama_token_data_array * candidates);
11101120

1111-
/// @details Sample a token using the configured samplers.
1121+
/// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers").
11121122
LLAMA_API llama_token llama_sampling_sample(
11131123
struct llama_sampling * smpl,
11141124
llama_token_data_array * candidates);
11151125

1116-
/// @details Accepts the sampled token into the sampling context
1126+
/// @details Accepts the sampled token into the sampling context.
1127+
/// - adds it to "prev" tokens
1128+
/// - updates the grammar state (if apply_grammar is true)
11171129
LLAMA_API void llama_sampling_accept(
11181130
struct llama_sampling * smpl,
11191131
llama_token token,
11201132
bool apply_grammar);
11211133

1134+
/// @details Get the number of accepted tokens so far (max of n_prev)
1135+
LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl);
1136+
11221137
/// @details Get the ith accepted token
11231138
/// @param ith [0, n_prev), ith == 0 is the last accepted token.
11241139
/// returns LLAMA_TOKEN_NULL if ith is out of bounds
@@ -1131,9 +1146,6 @@ extern "C" {
11311146
/// returns LLAMA_TOKEN_NULL if there are no accepted tokens
11321147
LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl);
11331148

1134-
/// @details Get the number of accepted tokens (max of n_prev)
1135-
LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl);
1136-
11371149
//
11381150
// Model split
11391151
//

‎src/llama-sampling.cpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,12 @@ void llama_sampling_top_k_impl(llama_token_data_array * candidates, int32_t k, s
186186
int ib = nbuckets - 1;
187187
for ( ; ib >= 0; --ib) {
188188
nhave += histo[ib];
189-
if (nhave >= k) break;
189+
if (nhave >= k) {
190+
break;
191+
}
190192
}
191193
std::vector<llama_token_data> tmp_tokens(nhave);
192-
auto ptr = tmp_tokens.data();
194+
auto * ptr = tmp_tokens.data();
193195
std::vector<llama_token_data*> bucket_ptrs;
194196
bucket_ptrs.reserve(nbuckets - ib);
195197
for (int j = nbuckets - 1; j >= ib; --j) {
@@ -573,6 +575,7 @@ llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array
573575
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
574576
return candidate.id == X;
575577
}));
578+
576579
float observed_surprise = -log2f(candidates->data[X_idx].p);
577580
float e = observed_surprise - tau;
578581

‎src/llama-sampling.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,10 @@ llama_token llama_sampling_sample_mirostat_impl (struct llama_token_data_array
9898
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
9999
llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu);
100100

101-
llama_token llama_sampling_sample_greedy_impl (struct llama_token_data_array * candidates);
102-
llama_token llama_sampling_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng);
101+
llama_token llama_sampling_sample_greedy_impl(struct llama_token_data_array * candidates);
102+
llama_token llama_sampling_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng);
103103

104104
void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar);
105105

106-
llama_token llama_sampling_prev_impl(const struct llama_sampling & smpl, int ith);
106+
llama_token llama_sampling_prev_impl (const struct llama_sampling & smpl, int ith);
107107
int llama_sampling_n_prev_impl(const struct llama_sampling & smpl);

‎src/llama.cpp

+4-8
Original file line numberDiff line numberDiff line change
@@ -20084,10 +20084,6 @@ void llama_sampling_reset(struct llama_sampling * smpl) {
2008420084
llama_sampling_reset_impl(*smpl);
2008520085
}
2008620086

20087-
void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) {
20088-
llama_sampling_set_rng_seed_impl(*smpl, seed);
20089-
}
20090-
2009120087
void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) {
2009220088
llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root);
2009320089
}
@@ -20392,6 +20388,10 @@ void llama_sampling_accept(
2039220388
smpl->n_accept++;
2039320389
}
2039420390

20391+
int llama_sampling_n_prev(const struct llama_sampling * smpl) {
20392+
return llama_sampling_n_prev_impl(*smpl);
20393+
}
20394+
2039520395
llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) {
2039620396
return llama_sampling_prev_impl(*smpl, ith);
2039720397
}
@@ -20400,10 +20400,6 @@ llama_token llama_sampling_last(const struct llama_sampling * smpl) {
2040020400
return llama_sampling_prev_impl(*smpl, 0);
2040120401
}
2040220402

20403-
int llama_sampling_n_prev(const struct llama_sampling * smpl) {
20404-
return llama_sampling_n_prev_impl(*smpl);
20405-
}
20406-
2040720403
//
2040820404
// model split
2040920405
//

0 commit comments

Comments
 (0)
Please sign in to comment.