From 9b3b07cc5ce2cae67b5da1bce8657a15e0daf39a Mon Sep 17 00:00:00 2001 From: Ivan Stepanov Date: Sat, 22 Apr 2023 14:31:08 +0300 Subject: [PATCH 1/8] Sample interface, new samplers. New samplers: - locally typical sampling - tail free sampling - frequency and presence penalty - mirostat Ignore EOS fix: -inf should be used. --- CMakeLists.txt | 2 +- Makefile | 2 +- examples/common.cpp | 28 +++ examples/main/main.cpp | 62 +++++- llama.cpp | 450 +++++++++++++++++++++++++++++++---------- llama.h | 34 ++-- tests/CMakeLists.txt | 3 + 7 files changed, 450 insertions(+), 131 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5fdbeddfca443..9d7c9d1ed35ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,7 +76,7 @@ option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) # Compile flags # -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED true) set(CMAKE_C_STANDARD 11) set(CMAKE_C_STANDARD_REQUIRED true) diff --git a/Makefile b/Makefile index 0715e857bc346..b4af18c0e9b82 100644 --- a/Makefile +++ b/Makefile @@ -35,7 +35,7 @@ endif # keep standard at C11 and C++11 CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC -CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC +CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++20 -fPIC LDFLAGS = # warnings diff --git a/examples/common.cpp b/examples/common.cpp index 9f10dc268558b..a8f57360a18ac 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -114,6 +114,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.temp = std::stof(argv[i]); + } else if (arg == "--tfs") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.tfs_z = std::stof(argv[i]); + } else if (arg == "--typical") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.typical_p = std::stof(argv[i]); } else if (arg == "--repeat_last_n") { if (++i >= argc) { invalid_param = true; @@ -126,6 +138,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.repeat_penalty = std::stof(argv[i]); + } else if (arg == "--alpha_frequency") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.alpha_frequency = std::stof(argv[i]); + } else if (arg == "--alpha_presence") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.alpha_presence = std::stof(argv[i]); } else if (arg == "-b" || arg == "--batch_size") { if (++i >= argc) { invalid_param = true; @@ -242,6 +266,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict); fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k); fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p); + fprintf(stderr, " --tfs N tail free sampling (default: %.1f)\n", (double)params.tfs_z); + fprintf(stderr, " --typical N locally typical sampling (default: %.1f)\n", (double)params.typical_p); + fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %d)\n", params.alpha_presence); + fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f)\n", (double)params.alpha_frequency); fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n); fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty); fprintf(stderr, " -c N, --ctx_size N size of the 
prompt context (default: %d)\n", params.n_ctx); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index fda65574fad7a..9b795bd3a6915 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -276,8 +276,8 @@ int main(int argc, char ** argv) { fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str()); } } - fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", - params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); + fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n", + params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp); fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "\n\n"); @@ -387,10 +387,15 @@ int main(int argc, char ** argv) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) { // out of user input, sample next token - const int32_t top_k = params.top_k; - const float top_p = params.top_p; const float temp = params.temp; + const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k; + const float top_p = params.top_p; + const float tfs_z = params.tfs_z; + const float typical_p = params.typical_p; + const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; const float repeat_penalty = params.repeat_penalty; + const float alpha_presence = params.alpha_presence; + const float alpha_frequency = params.alpha_frequency; // optionally save the session on first sample (for faster prompt loading next time) if (!path_session.empty() && need_to_save_session) { @@ -402,14 +407,55 @@ int main(int argc, char ** argv) { { auto logits = llama_get_logits(ctx); + auto n_vocab = llama_n_vocab(ctx); if (params.ignore_eos) { - logits[llama_token_eos()] = 0; + logits[llama_token_eos()] = -INFINITY; + } + + std::vector candidates; + candidates.reserve(n_vocab); + for (size_t i = 0; i < n_vocab; i++) { + candidates.emplace_back(i, logits[i], 0.0f); } - id = llama_sample_top_p_top_k(ctx, - last_n_tokens.data() + n_ctx - params.repeat_last_n, - params.repeat_last_n, top_k, top_p, temp, repeat_penalty); + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + + // Apply penalties + auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); + llama_sample_repetition_penalty(&candidates_p, + last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, + last_n_repeat, repeat_penalty); + llama_sample_frequency_and_presence_penalties(&candidates_p, + last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, + last_n_repeat, alpha_frequency, alpha_presence); + + +#if 1 + if (temp <= 0) { + // Greedy sampling + id = llama_sample_token_greedy(ctx, &candidates_p); + } else { + // Temperature sampling + llama_sample_top_k(&candidates_p, top_k); + llama_sample_tail_free(&candidates_p, tfs_z); + llama_sample_typical(&candidates_p, typical_p); + llama_sample_top_p(&candidates_p, top_p); + + llama_sample_temperature(&candidates_p, temp); + // printf("`%d`", candidates_p.size); + id = llama_sample_token(ctx, &candidates_p); + } +#else + const float tau = 5.0f; + static float mu = 2.0f * tau; + static int k = 40; + const float eta = 0.1f; + const int m = 
100; + const float N = n_vocab; + id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu); + // id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu); +#endif last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); diff --git a/llama.cpp b/llama.cpp index dca017db62503..64debd715ca87 100644 --- a/llama.cpp +++ b/llama.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #define LLAMA_USE_SCRATCH #define LLAMA_MAX_SCRATCH_BUFFERS 16 @@ -1478,109 +1479,369 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co // sampling // -static void sample_top_k(std::vector> & logits_id, int top_k) { - // find the top k tokens - std::partial_sort( - logits_id.begin(), - logits_id.begin() + top_k, logits_id.end(), - [](const std::pair & a, const std::pair & b) { - return a.first > b.first; - }); +void llama_sample_softmax(llama_token_data_array * candidates) { + assert(candidates->size > 0); + std::span tokens(candidates->data, candidates->size); + + // Sort the logits in descending order + if (!candidates->sorted) { + std::sort(tokens.begin(), tokens.end(), [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + candidates->sorted = true; + } - logits_id.resize(top_k); + float max_l = tokens[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < tokens.size(); ++i) { + // printf("llama_sample_softmax: i: %d, logit: %f\n", i, tokens[i].logit); + float p = expf(tokens[i].logit - max_l); + tokens[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < tokens.size(); ++i) { + tokens[i].p /= cum_sum; + } } -static llama_vocab::id llama_sample_top_p_top_k( - llama_context & lctx, - const std::vector & last_n_tokens, - int top_k, - float top_p, - float temp, - float repeat_penalty) { - auto & rng = lctx.rng; - - const int n_logits = lctx.model.hparams.n_vocab; - - const auto & logits = lctx.logits; - const auto * plogits = logits.data() + logits.size() - n_logits; - - if (temp <= 0) { - // select the token with the highest logit directly - float max_logit = plogits[0]; - llama_vocab::id max_id = 0; - - for (int i = 1; i < n_logits; ++i) { - if (plogits[i] > max_logit) { - max_logit = plogits[i]; - max_id = i; - } +void llama_sample_top_k(llama_token_data_array * candidates_p, int k) { + assert(k > 0); + std::span candidates(candidates_p->data, candidates_p->size); + + // Sort scores in descending order + if (!candidates_p->sorted) { + if (k >= candidates_p->size) { + std::sort(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + } else { + std::partial_sort(candidates.begin(), candidates.begin() + k, candidates.end(), + [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); } - return max_id; + candidates_p->sorted = true; } + candidates_p->size = std::min(k, (int) candidates.size()); +} - std::vector> logits_id; - logits_id.reserve(n_logits); +void llama_sample_top_p(llama_token_data_array * candidates_p, float p, size_t min_keep) { + if (p >= 1.0f) { + return; + } - { - const float scale = 1.0f/temp; - for (int i = 0; i < n_logits; ++i) { - // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858) - // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main - if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) { - // if score < 0 then repetition penalty has to multiplied to reduce the previous 
token probability - if (plogits[i] < 0.0f) { - logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i)); - } else { - logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i)); - } - } else { - logits_id.push_back(std::make_pair(plogits[i]*scale, i)); - } + llama_sample_softmax(candidates_p); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = candidates_p->size; + + for (size_t i = 0; i < candidates_p->size; ++i) { + cum_sum += candidates_p->data[i].p; + + // Check if the running sum is greater than p or if we have kept at least min_keep tokens + if (cum_sum > p && i >= min_keep) { + last_idx = i; + break; } } - sample_top_k(logits_id, top_k > 0 ? std::min(top_k, n_logits) : n_logits); + // Resize the output vector to keep only the top-p tokens + candidates_p->size = last_idx; +} - // compute probs for the top k tokens - std::vector probs; - probs.reserve(logits_id.size()); +// https://www.trentonbricken.com/Tail-Free-Sampling/ +void llama_sample_tail_free(llama_token_data_array * candidates_p, float z, size_t min_keep) { + if (z >= 1.0f || candidates_p->size <= 2) { + return; + } + + llama_sample_softmax(candidates_p); - float maxl = logits_id[0].first; - double sum = 0.0; - for (const auto & kv : logits_id) { - const float p = expf(kv.first - maxl); - probs.push_back(p); - sum += p; + // Compute the first and second derivatives + std::vector first_derivatives(candidates_p->size - 1); + std::vector second_derivatives(candidates_p->size - 2); + + for (size_t i = 0; i < first_derivatives.size(); ++i) { + first_derivatives[i] = candidates_p->data[i].p - candidates_p->data[i + 1].p; + } + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; } - // normalize the probs - for (auto & p : probs) { - p /= sum; + // Calculate absolute value of second derivatives + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = abs(second_derivatives[i]); } - if (top_p < 1.0) { - double cumsum = 0.0; - for (int i = 0; i < (int) probs.size(); i++) { - cumsum += probs[i]; - if (cumsum >= top_p) { - probs.resize(i + 1); - logits_id.resize(i + 1); - break; - } + // Normalize the second derivatives + float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); + for (float & value : second_derivatives) { + value /= second_derivatives_sum; + } + + float cum_sum = 0.0f; + size_t last_idx = candidates_p->size; + for (size_t i = 0; i < second_derivatives.size(); ++i) { + cum_sum += second_derivatives[i]; + + // Check if the running sum is greater than z or if we have kept at least min_keep tokens + if (cum_sum > z && i >= min_keep) { + last_idx = i; + break; } } - //printf("\n"); - //for (int i = 0; i < (int) 10; i++) { - // printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]); - //} - //printf("\n\n"); - //exit(0); + // Resize the output vector to keep only the tokens above the tail location + candidates_p->size = last_idx; +} + + +// https://arxiv.org/pdf/2202.00666.pdf +// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr +void llama_sample_typical(llama_token_data_array * candidates_p, float typical_p, size_t min_keep) { + if (typical_p >= 1.0f) { + return; + } + + // Compute the softmax of logits and calculate entropy + llama_sample_softmax(candidates_p); + + std::span 
candidates(candidates_p->data, candidates_p->size); + + float entropy = 0.0f; + for (const auto & candidate : candidates) { + entropy += -candidate.p * logf(candidate.p); + } + + // Compute the absolute difference between negative log probability and entropy for each candidate + std::vector<float> shifted_scores; + for (const auto & candidate : candidates) { + float shifted_score = fabsf(-logf(candidate.p) - entropy); + shifted_scores.push_back(shifted_score); + } + + // Sort candidates based on the shifted_scores and their corresponding indices + std::vector<size_t> indices(candidates.size()); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { + return shifted_scores[a] < shifted_scores[b]; + }); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = indices.size(); + + for (size_t i = 0; i < indices.size(); ++i) { + size_t idx = indices[i]; + cum_sum += candidates[idx].p; + + // Check if the running sum is greater than typical_p and we have kept at least min_keep tokens + if (cum_sum > typical_p && i >= min_keep - 1) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the locally typical tokens + std::vector<llama_token_data> new_candidates; + for (size_t i = 0; i < last_idx; ++i) { + size_t idx = indices[i]; + new_candidates.push_back(candidates[idx]); + } + + // Replace the data in candidates_p with the new_candidates data + std::copy(new_candidates.begin(), new_candidates.end(), candidates_p->data); + candidates_p->size = new_candidates.size(); +} + + +void llama_sample_temperature(llama_token_data_array * candidates_p, float temp) { + std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size); + for (auto & candidate : candidates) { + candidate.logit /= temp; + } +} + +void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty) { + if (last_tokens_size == 0 || penalty == 1.0f) { + return; + } + + // CTRL paper: https://arxiv.org/pdf/1909.05858.pdf + std::span<llama_token_data> candidates(candidates_p->data, candidates_p->size); + std::span<llama_token> last_tokens(last_tokens_p, last_tokens_size); + + for (size_t i = 0; i < candidates.size(); ++i) { + auto token_iter = std::find(last_tokens.begin(), last_tokens.end(), candidates[i].id); + if (token_iter == last_tokens.end()) { + continue; + } + + // The academic publication that described this technique only ever divides by the penalty, but dividing a negative logit would make that token more likely, which is obviously wrong. + // A common fix, used here, is to multiply negative logits by the penalty instead of dividing them. + if (candidates[i].logit <= 0) { + candidates[i].logit *= penalty; + } else { + candidates[i].logit /= penalty; + } + + // However, this still barely penalizes tokens whose logits are near zero, which is a problem. + // Another solution is to convert the logits to probabilities, apply the penalty, and then convert back to logits. 
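+ // For example (illustrative numbers): with penalty = 1.1, a logit of -2.0 has p = exp(-2.0) ~= 0.135; dividing p by the penalty gives ~0.123, and log(0.123) ~= -2.095. + // In other words, every penalized logit is shifted down by log(penalty) ~= 0.095, so a near-zero logit (0.0 -> ~ -0.095) is penalized as well, unlike in the multiply/divide scheme above.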
+ // float probability = std::exp(candidates[i].logit); + // probability /= penalty; + // candidates[i].logit = std::log(probability); + } + + candidates_p->sorted = false; +} + +void llama_sample_frequency_and_presence_penalties(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { + if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) { + return; + } + + std::span candidates(candidates_p->data, candidates_p->size); + std::span last_tokens(last_tokens_p, last_tokens_size); + + // Create a frequency map to count occurrences of each token in last_tokens + std::unordered_map token_count; + for (const auto & token : last_tokens) { + token_count[token]++; + } + + // Apply frequency and presence penalties to the candidates + for (size_t i = 0; i < candidates.size(); ++i) { + auto token_iter = token_count.find(candidates[i].id); + if (token_iter == token_count.end()) { + continue; + } + + int count = token_iter->second; + candidates[i].logit -= count * alpha_frequency + float(count > 0) * alpha_presence; + } + + candidates_p->sorted = false; +} + +/// @brief Mirostat 1.0 implementation. +/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. +/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. +/// @param N The size of the vocabulary. This is used in the calculation of the `k` value. +/// @param k A reference to the integer variable used to store the calculated top-k value. The top-k value determines how many of the most probable tokens are considered for sampling. +/// @param mu A reference to the floating-point variable that represents the maximum cross-entropy value. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
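/// @return The sampled token id. /// A minimal usage sketch (the constants mirror the disabled mirostat branch in examples/main/main.cpp): /// const float tau = 5.0f; const float eta = 0.1f; /// static float mu = 2.0f * tau; /// static int k = 40; /// llama_token id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, 100, float(n_vocab), &k, &mu);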
+llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, int m, float N, int * k, float * mu) { + // https://arxiv.org/abs/2007.14966 + std::span candidates(candidates_p->data, candidates_p->size); + + // printf("llama_sample_mirostat: candidates.size() = %d, m = %d, N = %f, tau = %f, eta = %f, *k = %d, *mu = %f\n", candidates.size(), m, N, tau, eta, *k, *mu); + + llama_sample_softmax(candidates_p); + + // Estimate s_hat using the most probable m tokens + float s_hat = 0.0; + float sum_ti_bi = 0.0; + float sum_ti_sq = 0.0; + for (int i = 0; i < m - 1 && i < candidates.size() - 1; ++i) { + float t_i = logf((i + 2) / float(i + 1)); + float b_i = logf(candidates[i].p / candidates[i + 1].p); + sum_ti_bi += t_i * b_i; + sum_ti_sq += t_i * t_i; + } + s_hat = sum_ti_bi / sum_ti_sq; + + // Compute k from the estimated s_hat and target surprise value + float epsilon_hat = s_hat - 1; + // printf("llama_sample_mirostat: s_hat = %f, epsilon_hat = %f, *mu = %f, N = %f\n", s_hat, epsilon_hat, *mu, N); + float new_k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat); + *k = std::min(new_k, float(candidates.size())); + + // Sample the next word X using top-k sampling + // printf("llama_sample_mirostat *k = %d\n", *k); + llama_sample_top_k(candidates_p, *k); + llama_token X = llama_sample_token(ctx, candidates_p); + + // Compute error as the difference between observed surprise and target surprise value + int X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + *mu = *mu - eta * e; + + return X; +} + +llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, float * mu) { + std::span candidates(candidates_p->data, candidates_p->size); + + llama_sample_softmax(candidates_p); + + // Truncate the words with surprise values greater than mu + candidates_p->size = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > *mu; + })); + + // Normalize the probabilities of the remaining words + llama_sample_softmax(candidates_p); + + // Sample the next word X from the remaining words + llama_token X = llama_sample_token(ctx, candidates_p); + + // Compute error as the difference between observed surprise and target surprise value + int X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + *mu = *mu - eta * e; + + return X; +} + + +llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates_p) { + // Find max element + std::span candidates(candidates_p->data, candidates_p->size); + auto max_iter = std::max_element(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) { + return a.logit < b.logit; + }); + + llama_token result = max_iter->id; + ctx->n_sample++; + return result; +} + + +llama_token llama_sample_token(struct llama_context * ctx, 
llama_token_data_array * candidates_p) { + // const int64_t t_start_sample_us = ggml_time_us(); + llama_sample_softmax(candidates_p); + + std::span candidates(candidates_p->data, candidates_p->size); + + std::vector probs; + probs.reserve(candidates.size()); + for (auto & candidate : candidates) { + probs.push_back(candidate.p); + } std::discrete_distribution<> dist(probs.begin(), probs.end()); + auto & rng = ctx->rng; int idx = dist(rng); - return logits_id[idx].second; + llama_token result = candidates[idx].id; + + // ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; + + return result; } // @@ -2352,35 +2613,6 @@ llama_token llama_token_eos() { return 2; } -llama_token llama_sample_top_p_top_k( - llama_context * ctx, - const llama_token * last_n_tokens_data, - int last_n_tokens_size, - int top_k, - float top_p, - float temp, - float repeat_penalty) { - const int64_t t_start_sample_us = ggml_time_us(); - - llama_token result = 0; - - // TODO: avoid this ... - const auto last_n_tokens = std::vector(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size); - - result = llama_sample_top_p_top_k( - *ctx, - last_n_tokens, - top_k, - top_p, - temp, - repeat_penalty); - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - ctx->n_sample++; - - return result; -} - void llama_print_timings(struct llama_context * ctx) { const int64_t t_end_us = ggml_time_us(); diff --git a/llama.h b/llama.h index 86a7d279a9ef4..129574eed0921 100644 --- a/llama.h +++ b/llama.h @@ -39,12 +39,16 @@ extern "C" { typedef struct llama_token_data { llama_token id; // token id - + float logit; // log-odds of the token float p; // probability of the token - float plog; // log probability of the token - } llama_token_data; + typedef struct llama_token_data_array { + llama_token_data * data; + size_t size; + bool sorted; + } llama_token_data_array; + typedef void (*llama_progress_callback)(float progress, void *ctx); struct llama_context_params { @@ -182,15 +186,21 @@ extern "C" { LLAMA_API llama_token llama_token_bos(); LLAMA_API llama_token llama_token_eos(); - // TODO: improve the last_n_tokens interface ? 
- LLAMA_API llama_token llama_sample_top_p_top_k( - struct llama_context * ctx, - const llama_token * last_n_tokens_data, - int last_n_tokens_size, - int top_k, - float top_p, - float temp, - float repeat_penalty); + // Sampling functions + LLAMA_API void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty); + LLAMA_API void llama_sample_frequency_and_presence_penalties(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence); + + LLAMA_API void llama_sample_softmax(llama_token_data_array * candidates); + LLAMA_API void llama_sample_top_k(llama_token_data_array * candidates, int k); + LLAMA_API void llama_sample_top_p(llama_token_data_array * candidates, float p, size_t min_keep = 1); + LLAMA_API void llama_sample_tail_free(llama_token_data_array * candidates, float z, size_t min_keep = 1); + LLAMA_API void llama_sample_typical(llama_token_data_array * candidates, float p, size_t min_keep = 1); + LLAMA_API void llama_sample_temperature(llama_token_data_array * candidates, float temp); + + LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu); + LLAMA_API llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu); // Performance information LLAMA_API void llama_print_timings(struct llama_context * ctx); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 81eadbc4db0a4..9bc5ea036c812 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -3,9 +3,12 @@ function(llama_add_test source) add_executable(${TEST_TARGET} ${source}) target_link_libraries(${TEST_TARGET} PRIVATE llama) add_test(NAME ${TEST_TARGET} COMMAND $ ${ARGN}) + target_compile_options(${TEST_TARGET} PRIVATE -fsanitize=address) + target_link_options(${TEST_TARGET} PRIVATE -fsanitize=address) endfunction() # llama_add_test(test-double-float.c) # SLOW llama_add_test(test-quantize-fns.cpp) llama_add_test(test-quantize-perf.cpp) +llama_add_test(test-sampling.cpp) llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin) From f01c67fe55d4c48b7903394416303aafc20e3f3b Mon Sep 17 00:00:00 2001 From: Ivan Stepanov Date: Sat, 22 Apr 2023 21:23:10 +0300 Subject: [PATCH 2/8] mirostat --- examples/common.cpp | 37 +++++-- examples/common.h | 17 ++- examples/main/main.cpp | 53 +++++----- llama.cpp | 158 ++++++++++++++++++++------- llama.h | 24 ++--- tests/CMakeLists.txt | 2 - tests/test-sampling.cpp | 229 ++++++++++++++++++++++++++++++++++++++++ 7 files changed, 427 insertions(+), 93 deletions(-) create mode 100644 tests/test-sampling.cpp diff --git a/examples/common.cpp b/examples/common.cpp index a8f57360a18ac..7e62be356b1cf 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -150,6 +150,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.alpha_presence = std::stof(argv[i]); + } else if (arg == "--mirostat") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.mirostat = std::stoi(argv[i]); + } else if (arg == "--mirostat_eta") { + if (++i >= argc) { + 
invalid_param = true; + break; + } + params.mirostat_eta = std::stof(argv[i]); + } else if (arg == "--mirostat_tau") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.mirostat_tau = std::stof(argv[i]); } else if (arg == "-b" || arg == "--batch_size") { if (++i >= argc) { invalid_param = true; @@ -264,14 +282,17 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -f FNAME, --file FNAME\n"); fprintf(stderr, " prompt file to start generation.\n"); fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict); - fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k); - fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p); - fprintf(stderr, " --tfs N tail free sampling (default: %.1f)\n", (double)params.tfs_z); - fprintf(stderr, " --typical N locally typical sampling (default: %.1f)\n", (double)params.typical_p); - fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %d)\n", params.alpha_presence); - fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f)\n", (double)params.alpha_frequency); - fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n); - fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty); + fprintf(stderr, " --top_k N top-k sampling (default: %d, disabled: 0)\n", params.top_k); + fprintf(stderr, " --top_p N top-p sampling (default: %.1f, disabled: 1.0)\n", (double)params.top_p); + fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, disabled: 1.0)\n", (double)params.tfs_z); + fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, disabled: 1.0)\n", (double)params.typical_p); + fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, disabled: 0)\n", params.repeat_last_n); + fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, disabled: 1.0)\n", (double)params.repeat_penalty); + fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %.1f, disabled: 0.0)\n", (double)params.alpha_presence); + fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f, disabled: 0.0)\n", (double)params.alpha_frequency); + fprintf(stderr, " --mirostat N use mirostat sampling (default: %d, disabled: 0, mirostat: 1, mirostat 2.0: 2)\n", params.mirostat); + fprintf(stderr, " --mirostat_eta N mirostat learning rate (default: %.1f)\n", (double)params.mirostat_eta); + fprintf(stderr, " --mirostat_tau N mirostat target entropy (default: %.1f)\n", (double)params.mirostat_tau); fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n"); fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n"); diff --git a/examples/common.h b/examples/common.h index 9d3697d793eff..de25e6435397e 100644 --- a/examples/common.h +++ b/examples/common.h @@ -17,17 +17,24 @@ struct gpt_params { int32_t seed = -1; // RNG seed int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t n_predict = 128; // new tokens to predict - int32_t repeat_last_n = 64; // last n tokens to penalize int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions) int32_t n_ctx = 
512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt // sampling parameters - int32_t top_k = 40; - float top_p = 0.95f; - float temp = 0.80f; - float repeat_penalty = 1.10f; + int32_t top_k = 0; // <= 0 to use vocab size + float top_p = 1.0f; // 1.0 = disabled + float tfs_z = 1.0f; // 1.0 = disabled + float typical_p = 1.0f; // 1.0 = disabled + float temp = 1.0f; // 1.0 = disabled + float repeat_penalty = 1.0f; // 1.0 = disabled + int32_t repeat_last_n = -1; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float alpha_frequency = 0.0f; // 0.0 = disabled + float alpha_presence = 0.0f; // 0.0 = disabled + int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.0f; // target entropy + float mirostat_eta = 0.1f; // learning rate std::string model = "models/lamma-7B/ggml-model.bin"; // model path std::string prompt = ""; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 9b795bd3a6915..a6de98fedfc61 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -276,8 +276,8 @@ int main(int argc, char ** argv) { fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str()); } } - fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n", - params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp); + fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f\n", + params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "\n\n"); @@ -396,6 +396,9 @@ int main(int argc, char ** argv) { const float repeat_penalty = params.repeat_penalty; const float alpha_presence = params.alpha_presence; const float alpha_frequency = params.alpha_frequency; + const int mirostat = params.mirostat; + const float mirostat_tau = params.mirostat_tau; + const float mirostat_eta = params.mirostat_eta; // optionally save the session on first sample (for faster prompt loading next time) if (!path_session.empty() && need_to_save_session) { @@ -415,47 +418,45 @@ int main(int argc, char ** argv) { std::vector candidates; candidates.reserve(n_vocab); - for (size_t i = 0; i < n_vocab; i++) { + for (size_t i = 0; i < (size_t) n_vocab; i++) { candidates.emplace_back(i, logits[i], 0.0f); } - llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; // Apply penalties auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); - llama_sample_repetition_penalty(&candidates_p, + llama_sample_repetition_penalty(ctx, &candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, last_n_repeat, repeat_penalty); - llama_sample_frequency_and_presence_penalties(&candidates_p, + 
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, last_n_repeat, alpha_frequency, alpha_presence); -#if 1 if (temp <= 0) { // Greedy sampling id = llama_sample_token_greedy(ctx, &candidates_p); } else { - // Temperature sampling - llama_sample_top_k(&candidates_p, top_k); - llama_sample_tail_free(&candidates_p, tfs_z); - llama_sample_typical(&candidates_p, typical_p); - llama_sample_top_p(&candidates_p, top_p); - - llama_sample_temperature(&candidates_p, temp); - // printf("`%d`", candidates_p.size); - id = llama_sample_token(ctx, &candidates_p); + if (mirostat == 1) { + static float mirostat_mu = 2.0f * mirostat_tau; + static int mirostat_k = 40; + const int mirostat_m = 100; + id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, float(n_vocab), &mirostat_k, &mirostat_mu); + } else if (mirostat == 2) { + static float mirostat_mu = 2.0f * mirostat_tau; + id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); + } else { + // Temperature sampling + llama_sample_top_k(ctx, &candidates_p, top_k); + llama_sample_tail_free(ctx, &candidates_p, tfs_z); + llama_sample_typical(ctx, &candidates_p, typical_p); + llama_sample_top_p(ctx, &candidates_p, top_p); + llama_sample_temperature(ctx, &candidates_p, temp); + id = llama_sample_token(ctx, &candidates_p); + } } -#else - const float tau = 5.0f; - static float mu = 2.0f * tau; - static int k = 40; - const float eta = 0.1f; - const int m = 100; - const float N = n_vocab; - id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu); - // id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu); -#endif + // printf("`%d`", candidates_p.size); last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); diff --git a/llama.cpp b/llama.cpp index 64debd715ca87..4da4df1f2c09e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1479,8 +1479,11 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co // sampling // -void llama_sample_softmax(llama_token_data_array * candidates) { +void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { assert(candidates->size > 0); + + const int64_t t_start_sample_us = ggml_time_us(); + std::span tokens(candidates->data, candidates->size); // Sort the logits in descending order @@ -1502,35 +1505,47 @@ void llama_sample_softmax(llama_token_data_array * candidates) { for (size_t i = 0; i < tokens.size(); ++i) { tokens[i].p /= cum_sum; } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } -void llama_sample_top_k(llama_token_data_array * candidates_p, int k) { - assert(k > 0); +void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates_p, int k, size_t min_keep) { + const int64_t t_start_sample_us = ggml_time_us(); + + k = std::max(k, (int) min_keep); + k = std::min(k, (int) candidates_p->size); + std::span candidates(candidates_p->data, candidates_p->size); // Sort scores in descending order if (!candidates_p->sorted) { - if (k >= candidates_p->size) { - std::sort(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); + auto comp = [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }; + if (k == (int) candidates_p->size) { + std::sort(candidates.begin(), candidates.end(), comp); } else { - std::partial_sort(candidates.begin(), 
candidates.begin() + k, candidates.end(), - [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); + std::partial_sort(candidates.begin(), candidates.begin() + k, candidates.end(), comp); } candidates_p->sorted = true; } - candidates_p->size = std::min(k, (int) candidates.size()); + candidates_p->size = k; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } -void llama_sample_top_p(llama_token_data_array * candidates_p, float p, size_t min_keep) { +void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates_p, float p, size_t min_keep) { if (p >= 1.0f) { return; } - llama_sample_softmax(candidates_p); + const int64_t t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(ctx, candidates_p); // Compute the cumulative probabilities float cum_sum = 0.0f; @@ -1548,15 +1563,21 @@ void llama_sample_top_p(llama_token_data_array * candidates_p, float p, size_t m // Resize the output vector to keep only the top-p tokens candidates_p->size = last_idx; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } // https://www.trentonbricken.com/Tail-Free-Sampling/ -void llama_sample_tail_free(llama_token_data_array * candidates_p, float z, size_t min_keep) { +void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates_p, float z, size_t min_keep) { if (z >= 1.0f || candidates_p->size <= 2) { return; } - llama_sample_softmax(candidates_p); + const int64_t t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(nullptr, candidates_p); // Compute the first and second derivatives std::vector first_derivatives(candidates_p->size - 1); @@ -1594,18 +1615,23 @@ void llama_sample_tail_free(llama_token_data_array * candidates_p, float z, size // Resize the output vector to keep only the tokens above the tail location candidates_p->size = last_idx; -} + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} // https://arxiv.org/pdf/2202.00666.pdf // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr -void llama_sample_typical(llama_token_data_array * candidates_p, float typical_p, size_t min_keep) { +void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates_p, float typical_p, size_t min_keep) { if (typical_p >= 1.0f) { return; } + const int64_t t_start_sample_us = ggml_time_us(); + // Compute the softmax of logits and calculate entropy - llama_sample_softmax(candidates_p); + llama_sample_softmax(nullptr, candidates_p); std::span candidates(candidates_p->data, candidates_p->size); @@ -1654,21 +1680,32 @@ void llama_sample_typical(llama_token_data_array * candidates_p, float typical_p // Replace the data in candidates_p with the new_candidates data std::copy(new_candidates.begin(), new_candidates.end(), candidates_p->data); candidates_p->size = new_candidates.size(); + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } +void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { + const int64_t t_start_sample_us = ggml_time_us(); -void llama_sample_temperature(llama_token_data_array * candidates_p, float temp) { std::span candidates(candidates_p->data, candidates_p->size); for (auto & candidate : candidates) { candidate.logit /= temp; } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } -void 
llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty) { +void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty) { if (last_tokens_size == 0 || penalty == 1.0f) { return; } + const int64_t t_start_sample_us = ggml_time_us(); + // CTRL paper: https://arxiv.org/pdf/1909.05858.pdf std::span candidates(candidates_p->data, candidates_p->size); std::span last_tokens(last_tokens_p, last_tokens_size); @@ -1695,13 +1732,19 @@ void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llam } candidates_p->sorted = false; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } -void llama_sample_frequency_and_presence_penalties(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { +void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) { return; } + const int64_t t_start_sample_us = ggml_time_us(); + std::span candidates(candidates_p->data, candidates_p->size); std::span last_tokens(last_tokens_p, last_tokens_size); @@ -1723,6 +1766,10 @@ void llama_sample_frequency_and_presence_penalties(llama_token_data_array * cand } candidates_p->sorted = false; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } /// @brief Mirostat 1.0 implementation. @@ -1733,20 +1780,26 @@ void llama_sample_frequency_and_presence_penalties(llama_token_data_array * cand /// @param N The size of the vocabulary. This is used in the calculation of the `k` value. /// @param k A reference to the integer variable used to store the calculated top-k value. The top-k value determines how many of the most probable tokens are considered for sampling. /// @param mu A reference to the floating-point variable that represents the maximum cross-entropy value. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
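/// @details Implements Algorithm 1 of the Mirostat paper (https://arxiv.org/abs/2007.14966). The caller keeps `k` and `mu` alive between calls (e.g. as statics) so the controller state persists across sampled tokens.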
-llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, int m, float N, int * k, float * mu) { +llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, int m, float N, int * k, float * mu) { + assert(ctx); + + int64_t t_start_sample_us; + t_start_sample_us = ggml_time_us(); + // https://arxiv.org/abs/2007.14966 + // Algorithm 1 std::span candidates(candidates_p->data, candidates_p->size); // printf("llama_sample_mirostat: candidates.size() = %d, m = %d, N = %f, tau = %f, eta = %f, *k = %d, *mu = %f\n", candidates.size(), m, N, tau, eta, *k, *mu); - llama_sample_softmax(candidates_p); + llama_sample_softmax(nullptr, candidates_p); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; float sum_ti_bi = 0.0; float sum_ti_sq = 0.0; - for (int i = 0; i < m - 1 && i < candidates.size() - 1; ++i) { - float t_i = logf((i + 2) / float(i + 1)); + for (size_t i = 0; i < size_t(m - 1) && i < candidates.size() - 1; ++i) { + float t_i = logf(float(i + 2) / float(i + 1)); float b_i = logf(candidates[i].p / candidates[i + 1].p); sum_ti_bi += t_i * b_i; sum_ti_sq += t_i * t_i; @@ -1757,15 +1810,20 @@ llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_a float epsilon_hat = s_hat - 1; // printf("llama_sample_mirostat: s_hat = %f, epsilon_hat = %f, *mu = %f, N = %f\n", s_hat, epsilon_hat, *mu, N); float new_k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat); - *k = std::min(new_k, float(candidates.size())); + // printf("llama_sample_mirostat: new_k = %f\n", new_k); + *k = int(std::min(new_k, float(candidates.size()))); // Sample the next word X using top-k sampling // printf("llama_sample_mirostat *k = %d\n", *k); - llama_sample_top_k(candidates_p, *k); + llama_sample_top_k(nullptr, candidates_p, *k); + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } llama_token X = llama_sample_token(ctx, candidates_p); + t_start_sample_us = ggml_time_us(); // Compute error as the difference between observed surprise and target surprise value - int X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { + size_t X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { return candidate.id == X; })); float observed_surprise = -log2f(candidates[X_idx].p); @@ -1774,13 +1832,23 @@ llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_a // Update mu using the learning rate and error *mu = *mu - eta * e; + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; + } return X; } -llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, float * mu) { +llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, float * mu) { + assert(ctx); + int64_t t_start_sample_us; + t_start_sample_us = ggml_time_us(); + + // https://arxiv.org/abs/2007.14966 + // Algorithm 2 std::span candidates(candidates_p->data, candidates_p->size); - llama_sample_softmax(candidates_p); + llama_sample_softmax(ctx, candidates_p); // Truncate the words with surprise values greater than mu candidates_p->size = std::distance(candidates.begin(), std::find_if(candidates.begin(), 
candidates.end(), [&](const llama_token_data & candidate) { @@ -1788,13 +1856,17 @@ llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_dat })); // Normalize the probabilities of the remaining words - llama_sample_softmax(candidates_p); + llama_sample_softmax(ctx, candidates_p); // Sample the next word X from the remaining words + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } llama_token X = llama_sample_token(ctx, candidates_p); + t_start_sample_us = ggml_time_us(); // Compute error as the difference between observed surprise and target surprise value - int X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { + size_t X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { return candidate.id == X; })); float observed_surprise = -log2f(candidates[X_idx].p); @@ -1803,11 +1875,15 @@ llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_dat // Update mu using the learning rate and error *mu = *mu - eta * e; + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } return X; } - llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates_p) { + const int64_t t_start_sample_us = ggml_time_us(); + // Find max element std::span candidates(candidates_p->data, candidates_p->size); auto max_iter = std::max_element(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) { @@ -1815,14 +1891,17 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da }); llama_token result = max_iter->id; - ctx->n_sample++; + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; + } return result; } - llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates_p) { - // const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax(candidates_p); + assert(ctx); + const int64_t t_start_sample_us = ggml_time_us(); + llama_sample_softmax(nullptr, candidates_p); std::span candidates(candidates_p->data, candidates_p->size); @@ -1838,9 +1917,8 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra llama_token result = candidates[idx].id; - // ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; ctx->n_sample++; - return result; } diff --git a/llama.h b/llama.h index 129574eed0921..4f72c273c48d2 100644 --- a/llama.h +++ b/llama.h @@ -187,20 +187,20 @@ extern "C" { LLAMA_API llama_token llama_token_eos(); // Sampling functions - LLAMA_API void llama_sample_repetition_penalty(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty); - LLAMA_API void llama_sample_frequency_and_presence_penalties(llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence); - - LLAMA_API void llama_sample_softmax(llama_token_data_array * candidates); - LLAMA_API void llama_sample_top_k(llama_token_data_array * candidates, int k); - LLAMA_API void llama_sample_top_p(llama_token_data_array * candidates, float p, size_t min_keep = 1); - LLAMA_API void llama_sample_tail_free(llama_token_data_array * candidates, float z, size_t min_keep = 1); - LLAMA_API void llama_sample_typical(llama_token_data_array 
* candidates, float p, size_t min_keep = 1); - LLAMA_API void llama_sample_temperature(llama_token_data_array * candidates, float temp); - + LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty); + LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence); + + LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1); + LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1); + LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1); + LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1); + LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); + + LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu); + LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu); LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates); LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); - LLAMA_API llama_token llama_sample_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu); - LLAMA_API llama_token llama_sample_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu); // Performance information LLAMA_API void llama_print_timings(struct llama_context * ctx); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9bc5ea036c812..645648585ab3d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -3,8 +3,6 @@ function(llama_add_test source) add_executable(${TEST_TARGET} ${source}) target_link_libraries(${TEST_TARGET} PRIVATE llama) add_test(NAME ${TEST_TARGET} COMMAND $ ${ARGN}) - target_compile_options(${TEST_TARGET} PRIVATE -fsanitize=address) - target_link_options(${TEST_TARGET} PRIVATE -fsanitize=address) endfunction() # llama_add_test(test-double-float.c) # SLOW diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp new file mode 100644 index 0000000000000..0a23c80c5c342 --- /dev/null +++ b/tests/test-sampling.cpp @@ -0,0 +1,229 @@ +#include "ggml.h" +#include "llama.h" +#include +#include +#include +#include +#include +#include +#include + +void dump(const llama_token_data_array * candidates) { + for (size_t i = 0; i < candidates->size; i++) { + printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit); + } +} + +#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0) + + +void test_top_k(const std::vector & probs, + const std::vector & expected_probs, + int k) { + 
size_t n_vocab = probs.size(); + std::vector candidates; + candidates.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + float logit = log(probs[i]); + candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + llama_sample_softmax(nullptr, &candidates_p); + // DUMP(&candidates_p); + llama_sample_top_k(nullptr, &candidates_p, k); + // DUMP(&candidates_p); + + assert(candidates_p.size == expected_probs.size()); + for (size_t i = 0; i < candidates_p.size; i++) { + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5); + } +} + + +void test_top_p(const std::vector & probs, + const std::vector & expected_probs, + float p) { + + size_t n_vocab = probs.size(); + std::vector candidates; + candidates.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + float logit = log(probs[i]); + candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + // DUMP(&candidates_p); + llama_sample_top_p(nullptr, &candidates_p, p); + // DUMP(&candidates_p); + + assert(candidates_p.size == expected_probs.size()); + for (size_t i = 0; i < candidates_p.size; i++) { + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5); + } +} + + +void test_tfs(const std::vector & probs, + const std::vector & expected_probs, + float z) { + size_t n_vocab = probs.size(); + std::vector candidates; + candidates.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + float logit = log(probs[i]); + candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + // DUMP(&candidates_p); + llama_sample_tail_free(nullptr, &candidates_p, z); + // DUMP(&candidates_p); + + assert(candidates_p.size == expected_probs.size()); + for (size_t i = 0; i < candidates_p.size; i++) { + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); + } +} + + +void test_typical(const std::vector & probs, + const std::vector & expected_probs, + float p) { + size_t n_vocab = probs.size(); + std::vector candidates; + candidates.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + float logit = log(probs[i]); + candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + // DUMP(&candidates_p); + llama_sample_typical(nullptr, &candidates_p, p); + // DUMP(&candidates_p); + + assert(candidates_p.size == expected_probs.size()); + for (size_t i = 0; i < candidates_p.size; i++) { + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); + } +} + + +void test_repetition_penalty( + const std::vector & probs, + const std::vector & last_tokens, + const std::vector & expected_probs, + float penalty) { + assert(probs.size() == expected_probs.size()); + + size_t n_vocab = probs.size(); + std::vector candidates; + candidates.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + float logit = log(probs[i]); + candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + llama_sample_softmax(nullptr, &candidates_p); + DUMP(&candidates_p); + llama_sample_repetition_penalty(nullptr, &candidates_p, (llama_token *)last_tokens.data(), last_tokens.size(), penalty); + llama_sample_softmax(nullptr, &candidates_p); + DUMP(&candidates_p); + + assert(candidates_p.size == expected_probs.size()); + for 
(size_t i = 0; i < candidates_p.size; i++) { + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); + } +} + + +void test_frequency_presence_penalty( + const std::vector & probs, + const std::vector & last_tokens, + const std::vector & expected_probs, + float alpha_frequency, float alpha_presence) { + assert(probs.size() == expected_probs.size()); + + size_t n_vocab = probs.size(); + std::vector candidates; + candidates.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + float logit = log(probs[i]); + candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + llama_sample_softmax(nullptr, &candidates_p); + // DUMP(&candidates_p); + llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, (llama_token *)last_tokens.data(), last_tokens.size(), alpha_frequency, alpha_presence); + llama_sample_softmax(nullptr, &candidates_p); + // DUMP(&candidates_p); + + assert(candidates_p.size == expected_probs.size()); + for (size_t i = 0; i < candidates_p.size; i++) { + assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); + } +} + + +void test_mirostat() { + std::vector probs = {0.1, 0.2, 0.3, 0.4}; + std::vector expected_probs = {0.1, 0.2, 0.3, 0.4}; + + size_t n_vocab = probs.size(); + std::vector candidates; + candidates.reserve(n_vocab); + for (int i = 0; i < n_vocab; i++) { + float logit = log(probs[i]); + candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + DUMP(&candidates_p); + + float tau = 5.0f; + float mu = 2.0f * tau; + int k = 0; + float eta = 0.1f; + int m = 100; + // float N = 32000; + float N = 4; + // llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu); + DUMP(&candidates_p); + + // assert(candidates_p.size == expected_probs.size()); + // for (size_t i = 0; i < candidates_p.size; i++) { + // assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); + // } +} + +int main(void) { + test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4}, 1); + test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2}, 3); + + test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4}, 0); + test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3}, 0.7); + test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2, 0.1}, 1); + + test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3}, 0.25); + test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3, 0.25}, 0.75); + test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3, 0.25}, 0.99); + + test_typical({0.97, 0.01, 0.01, 0.01}, {0.97}, 0.5); + test_typical({0.4, 0.2, 0.2, 0.2}, {0.2, 0.2, 0.2}, 0.5); + + test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0, 0.25, 0.25, 0.25, 0.25}, 50.0); + test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0, 0, 0, 0.5, 0.5}, 50.0); + test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0, 0, 0, 0.5, 0.5}, 50.0); + + test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.249997, 0.249997, 0.249997, 0.249997, 0.000011}, 5.0, 5.0); + test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.499966, 0.499966, 0.000023, 0.000023, 0.000023}, 5.0, 5.0); + test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.499977, 0.499977, 0.000023, 0.000023, 0.000000}, 5.0, 5.0); + + // test_mirostat(); + + printf("OK\n"); +} From 61f822f63b7a6add79340aaef0ab9f073cddc0f6 Mon Sep 17 00:00:00 2001 From: Ivan Stepanov Date: Fri, 28 Apr 2023 03:12:49 +0300 Subject: [PATCH 3/8] Added --logit-bias and --no-penalize-nl, 
removed std::span --- Makefile | 2 +- examples/common.cpp | 57 ++++++++--- examples/common.h | 26 ++--- examples/main/main.cpp | 21 ++-- llama.cpp | 220 +++++++++++++++++------------------------ llama.h | 24 ++++- 6 files changed, 185 insertions(+), 165 deletions(-) diff --git a/Makefile b/Makefile index b4af18c0e9b82..0715e857bc346 100644 --- a/Makefile +++ b/Makefile @@ -35,7 +35,7 @@ endif # keep standard at C11 and C++11 CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC -CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++20 -fPIC +CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC LDFLAGS = # warnings diff --git a/examples/common.cpp b/examples/common.cpp index 7e62be356b1cf..a4938b4846136 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -6,6 +6,8 @@ #include #include #include +#include +#include #if defined (_WIN32) #include @@ -138,18 +140,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.repeat_penalty = std::stof(argv[i]); - } else if (arg == "--alpha_frequency") { + } else if (arg == "--frequency_penalty") { if (++i >= argc) { invalid_param = true; break; } - params.alpha_frequency = std::stof(argv[i]); - } else if (arg == "--alpha_presence") { + params.frequency_penalty = std::stof(argv[i]); + } else if (arg == "--presence_penalty") { if (++i >= argc) { invalid_param = true; break; } - params.alpha_presence = std::stof(argv[i]); + params.presence_penalty = std::stof(argv[i]); } else if (arg == "--mirostat") { if (++i >= argc) { invalid_param = true; @@ -227,7 +229,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else if (arg == "--perplexity") { params.perplexity = true; } else if (arg == "--ignore-eos") { - params.ignore_eos = true; + params.logit_bias[llama_token_eos()] = -INFINITY; + } else if (arg == "--no-penalize-nl") { + params.penalize_nl = false; + } else if (arg == "-l" || arg == "--logit-bias") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::stringstream ss(argv[i]); + llama_token key; + char sign; + std::string value_str; + try { + if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-' || sign == '=' || sign == ':')) { + params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? 
-1.0f : 1.0f); + } else { + throw std::exception(); + } + } catch (const std::exception &e) { + invalid_param = true; + break; + } } else if (arg == "--n_parts") { if (++i >= argc) { invalid_param = true; @@ -282,19 +305,23 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -f FNAME, --file FNAME\n"); fprintf(stderr, " prompt file to start generation.\n"); fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict); - fprintf(stderr, " --top_k N top-k sampling (default: %d, disabled: 0)\n", params.top_k); - fprintf(stderr, " --top_p N top-p sampling (default: %.1f, disabled: 1.0)\n", (double)params.top_p); - fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, disabled: 1.0)\n", (double)params.tfs_z); - fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, disabled: 1.0)\n", (double)params.typical_p); - fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, disabled: 0)\n", params.repeat_last_n); - fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, disabled: 1.0)\n", (double)params.repeat_penalty); - fprintf(stderr, " --alpha_presence N repeat alpha presence (default: %.1f, disabled: 0.0)\n", (double)params.alpha_presence); - fprintf(stderr, " --alpha_frequency N repeat alpha frequency (default: %.1f, disabled: 0.0)\n", (double)params.alpha_frequency); - fprintf(stderr, " --mirostat N use mirostat sampling (default: %d, disabled: 0, mirostat: 1, mirostat 2.0: 2)\n", params.mirostat); + fprintf(stderr, " --top_k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k); + fprintf(stderr, " --top_p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p); + fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z); + fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p); + fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n); + fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty); + fprintf(stderr, " --presence_penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty); + fprintf(stderr, " --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty); + fprintf(stderr, " --mirostat N use mirostat sampling (default: %d, 0 = disabled, 1 = mirostat, 2 = mirostat 2.0)\n", params.mirostat); fprintf(stderr, " --mirostat_eta N mirostat learning rate (default: %.1f)\n", (double)params.mirostat_eta); fprintf(stderr, " --mirostat_tau N mirostat target entropy (default: %.1f)\n", (double)params.mirostat_tau); + fprintf(stderr, " -l TOKEN+BIAS, --logit-bias TOKEN+BIAS"); + fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n"); + fprintf(stderr, " i.e. 
`--logit-bias 15043+1` to increase likelihood of token ' Hello'\n"); fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx); - fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n"); + fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2+-inf)\n"); + fprintf(stderr, " --no-penalize-nl do not penalize newline token\n"); fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n"); fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp); fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n"); diff --git a/examples/common.h b/examples/common.h index de25e6435397e..14e6b1ba7c113 100644 --- a/examples/common.h +++ b/examples/common.h @@ -8,6 +8,7 @@ #include #include #include +#include // // CLI argument parsing @@ -23,18 +24,19 @@ struct gpt_params { int32_t n_keep = 0; // number of tokens to keep from initial prompt // sampling parameters - int32_t top_k = 0; // <= 0 to use vocab size - float top_p = 1.0f; // 1.0 = disabled - float tfs_z = 1.0f; // 1.0 = disabled - float typical_p = 1.0f; // 1.0 = disabled - float temp = 1.0f; // 1.0 = disabled + std::unordered_map logit_bias; // logit bias for specific tokens + int32_t top_k = 0; // <= 0 to use vocab size + float top_p = 1.0f; // 1.0 = disabled + float tfs_z = 1.0f; // 1.0 = disabled + float typical_p = 1.0f; // 1.0 = disabled + float temp = 1.0f; // 1.0 = disabled float repeat_penalty = 1.0f; // 1.0 = disabled - int32_t repeat_last_n = -1; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float alpha_frequency = 0.0f; // 0.0 = disabled - float alpha_presence = 0.0f; // 0.0 = disabled - int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.0f; // target entropy - float mirostat_eta = 0.1f; // learning rate + int32_t repeat_last_n = -1; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float frequency_penalty = 0.0f; // 0.0 = disabled + float presence_penalty = 0.0f; // 0.0 = disabled + int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.0f; // target entropy + float mirostat_eta = 0.1f; // learning rate std::string model = "models/lamma-7B/ggml-model.bin"; // model path std::string prompt = ""; @@ -54,7 +56,7 @@ struct gpt_params { bool interactive_first = false; // wait for user input immediately bool instruct = false; // instruction mode (used for Alpaca models) - bool ignore_eos = false; // do not stop generating after eos + bool penalize_nl = true; // consider newlines as a repeatable token bool perplexity = false; // compute perplexity over the prompt bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory diff --git a/examples/main/main.cpp b/examples/main/main.cpp index a6de98fedfc61..da974005705ee 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -276,8 +276,8 @@ int main(int argc, char ** argv) { fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str()); } } - fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f\n", - params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, 
params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); + fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f\n", + params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "\n\n"); @@ -394,11 +394,12 @@ int main(int argc, char ** argv) { const float typical_p = params.typical_p; const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; const float repeat_penalty = params.repeat_penalty; - const float alpha_presence = params.alpha_presence; - const float alpha_frequency = params.alpha_frequency; - const int mirostat = params.mirostat; + const float alpha_presence = params.presence_penalty; + const float alpha_frequency = params.frequency_penalty; + const int mirostat = params.mirostat; const float mirostat_tau = params.mirostat_tau; const float mirostat_eta = params.mirostat_eta; + const bool penalize_nl = params.penalize_nl; // optionally save the session on first sample (for faster prompt loading next time) if (!path_session.empty() && need_to_save_session) { @@ -412,8 +413,9 @@ int main(int argc, char ** argv) { auto logits = llama_get_logits(ctx); auto n_vocab = llama_n_vocab(ctx); - if (params.ignore_eos) { - logits[llama_token_eos()] = -INFINITY; + // Apply params.logit_bias map + for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { + logits[it->first] += it->second; } std::vector candidates; @@ -425,6 +427,7 @@ int main(int argc, char ** argv) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; // Apply penalties + float nl_logit = logits[llama_token_nl()]; auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); llama_sample_repetition_penalty(ctx, &candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, @@ -432,7 +435,9 @@ int main(int argc, char ** argv) { llama_sample_frequency_and_presence_penalties(ctx, &candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, last_n_repeat, alpha_frequency, alpha_presence); - + if (!penalize_nl) { + logits[llama_token_nl()] = nl_logit; + } if (temp <= 0) { // Greedy sampling diff --git a/llama.cpp b/llama.cpp index 4da4df1f2c09e..2ec6d894a810d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -28,7 +28,6 @@ #include #include #include -#include #define LLAMA_USE_SCRATCH #define LLAMA_MAX_SCRATCH_BUFFERS 16 @@ -1484,26 +1483,23 @@ void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * c const int64_t t_start_sample_us = ggml_time_us(); - std::span tokens(candidates->data, candidates->size); - // Sort the logits in descending order if (!candidates->sorted) { - std::sort(tokens.begin(), tokens.end(), [](const llama_token_data & a, const llama_token_data & b) { + std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }); candidates->sorted = true; } - float max_l = tokens[0].logit; + float max_l = candidates->data[0].logit; float cum_sum = 
0.0f; - for (size_t i = 0; i < tokens.size(); ++i) { - // printf("llama_sample_softmax: i: %d, logit: %f\n", i, tokens[i].logit); - float p = expf(tokens[i].logit - max_l); - tokens[i].p = p; + for (size_t i = 0; i < candidates->size; ++i) { + float p = expf(candidates->data[i].logit - max_l); + candidates->data[i].p = p; cum_sum += p; } - for (size_t i = 0; i < tokens.size(); ++i) { - tokens[i].p /= cum_sum; + for (size_t i = 0; i < candidates->size; ++i) { + candidates->data[i].p /= cum_sum; } if (ctx) { @@ -1511,48 +1507,46 @@ void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * c } } -void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates_p, int k, size_t min_keep) { +void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep) { const int64_t t_start_sample_us = ggml_time_us(); k = std::max(k, (int) min_keep); - k = std::min(k, (int) candidates_p->size); - - std::span candidates(candidates_p->data, candidates_p->size); + k = std::min(k, (int) candidates->size); // Sort scores in descending order - if (!candidates_p->sorted) { + if (!candidates->sorted) { auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }; - if (k == (int) candidates_p->size) { - std::sort(candidates.begin(), candidates.end(), comp); + if (k == (int) candidates->size) { + std::sort(candidates->data, candidates->data + candidates->size, comp); } else { - std::partial_sort(candidates.begin(), candidates.begin() + k, candidates.end(), comp); + std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); } - candidates_p->sorted = true; + candidates->sorted = true; } - candidates_p->size = k; + candidates->size = k; if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates_p, float p, size_t min_keep) { +void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { if (p >= 1.0f) { return; } const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax(ctx, candidates_p); + llama_sample_softmax(ctx, candidates); // Compute the cumulative probabilities float cum_sum = 0.0f; - size_t last_idx = candidates_p->size; + size_t last_idx = candidates->size; - for (size_t i = 0; i < candidates_p->size; ++i) { - cum_sum += candidates_p->data[i].p; + for (size_t i = 0; i < candidates->size; ++i) { + cum_sum += candidates->data[i].p; // Check if the running sum is greater than p or if we have kept at least min_keep tokens if (cum_sum > p && i >= min_keep) { @@ -1562,29 +1556,28 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can } // Resize the output vector to keep only the top-p tokens - candidates_p->size = last_idx; + candidates->size = last_idx; if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } } -// https://www.trentonbricken.com/Tail-Free-Sampling/ -void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates_p, float z, size_t min_keep) { - if (z >= 1.0f || candidates_p->size <= 2) { +void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { + if (z >= 1.0f || candidates->size <= 2) { return; } const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax(nullptr, candidates_p); + llama_sample_softmax(nullptr, 
candidates); // Compute the first and second derivatives - std::vector first_derivatives(candidates_p->size - 1); - std::vector second_derivatives(candidates_p->size - 2); + std::vector first_derivatives(candidates->size - 1); + std::vector second_derivatives(candidates->size - 2); for (size_t i = 0; i < first_derivatives.size(); ++i) { - first_derivatives[i] = candidates_p->data[i].p - candidates_p->data[i + 1].p; + first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p; } for (size_t i = 0; i < second_derivatives.size(); ++i) { second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; @@ -1602,7 +1595,7 @@ void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * } float cum_sum = 0.0f; - size_t last_idx = candidates_p->size; + size_t last_idx = candidates->size; for (size_t i = 0; i < second_derivatives.size(); ++i) { cum_sum += second_derivatives[i]; @@ -1614,41 +1607,40 @@ void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * } // Resize the output vector to keep only the tokens above the tail location - candidates_p->size = last_idx; + candidates->size = last_idx; if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } } -// https://arxiv.org/pdf/2202.00666.pdf -// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr -void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates_p, float typical_p, size_t min_keep) { - if (typical_p >= 1.0f) { + +void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + // Reference implementation: + // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr + if (p >= 1.0f) { return; } const int64_t t_start_sample_us = ggml_time_us(); // Compute the softmax of logits and calculate entropy - llama_sample_softmax(nullptr, candidates_p); - - std::span candidates(candidates_p->data, candidates_p->size); + llama_sample_softmax(nullptr, candidates); float entropy = 0.0f; - for (const auto & candidate : candidates) { - entropy += -candidate.p * logf(candidate.p); + for (size_t i = 0; i < candidates->size; ++i) { + entropy += -candidates->data[i].p * logf(candidates->data[i].p); } // Compute the absolute difference between negative log probability and entropy for each candidate std::vector shifted_scores; - for (const auto & candidate : candidates) { - float shifted_score = fabsf(-logf(candidate.p) - entropy); + for (size_t i = 0; i < candidates->size; ++i) { + float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy); shifted_scores.push_back(shifted_score); } - // Sort candidates based on the shifted_scores and their corresponding indices - std::vector indices(candidates.size()); + // Sort tokens based on the shifted_scores and their corresponding indices + std::vector indices(candidates->size); std::iota(indices.begin(), indices.end(), 0); std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { @@ -1661,10 +1653,10 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c for (size_t i = 0; i < indices.size(); ++i) { size_t idx = indices[i]; - cum_sum += candidates[idx].p; + cum_sum += candidates->data[idx].p; // Check if the running sum is greater than typical or if we have kept at least min_keep tokens - if (cum_sum > typical_p && i >= min_keep - 1) { + if (cum_sum > p && i >= min_keep - 1) { last_idx = i 
+ 1; break; } @@ -1674,12 +1666,12 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c std::vector new_candidates; for (size_t i = 0; i < last_idx; ++i) { size_t idx = indices[i]; - new_candidates.push_back(candidates[idx]); + new_candidates.push_back(candidates->data[idx]); } - // Replace the data in candidates_p with the new_candidates data - std::copy(new_candidates.begin(), new_candidates.end(), candidates_p->data); - candidates_p->size = new_candidates.size(); + // Replace the data in candidates with the new_candidates data + std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); + candidates->size = new_candidates.size(); if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; @@ -1689,9 +1681,8 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { const int64_t t_start_sample_us = ggml_time_us(); - std::span candidates(candidates_p->data, candidates_p->size); - for (auto & candidate : candidates) { - candidate.logit /= temp; + for (size_t i = 0; i < candidates_p->size; ++i) { + candidates_p->data[i].logit /= temp; } if (ctx) { @@ -1699,29 +1690,25 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array } } -void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty) { +void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty) { if (last_tokens_size == 0 || penalty == 1.0f) { return; } const int64_t t_start_sample_us = ggml_time_us(); - // CTRL paper: https://arxiv.org/pdf/1909.05858.pdf - std::span candidates(candidates_p->data, candidates_p->size); - std::span last_tokens(last_tokens_p, last_tokens_size); - - for (size_t i = 0; i < candidates.size(); ++i) { - auto token_iter = std::find(last_tokens.begin(), last_tokens.end(), candidates[i].id); - if (token_iter == last_tokens.end()) { + for (size_t i = 0; i < candidates->size; ++i) { + auto token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id); + if (token_iter == last_tokens + last_tokens_size) { continue; } // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. // This is common fix for this problem, which is to multiply by the penalty instead of dividing. - if (candidates[i].logit <= 0) { - candidates[i].logit *= penalty; + if (candidates->data[i].logit <= 0) { + candidates->data[i].logit *= penalty; } else { - candidates[i].logit /= penalty; + candidates->data[i].logit /= penalty; } // But it does not penalize tokens that logits are near zero, which is a problem. 
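(As an aside on the sign handling above: a toy numeric illustration, with values invented for this note rather than taken from the patch.)

    // Hypothetical logits; penalty = 1.3f chosen as an example value.
    float logit_pos =  2.0f;   // a likely token
    float logit_neg = -2.0f;   // an unlikely token
    const float penalty = 1.3f;

    // Dividing both would pull the negative logit toward zero and make that
    // token MORE likely: -2.0f / 1.3f is about -1.54f. Hence the branch above:
    logit_pos /= penalty;      //  2.0f -> ~1.54f : less likely, as intended
    logit_neg *= penalty;      // -2.0f -> -2.6f  : less likely, as intended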
@@ -1731,76 +1718,60 @@ void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_dat // candidates[i].logit = std::log(probability); } - candidates_p->sorted = false; + candidates->sorted = false; if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { +void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) { return; } const int64_t t_start_sample_us = ggml_time_us(); - std::span candidates(candidates_p->data, candidates_p->size); - std::span last_tokens(last_tokens_p, last_tokens_size); - // Create a frequency map to count occurrences of each token in last_tokens std::unordered_map token_count; - for (const auto & token : last_tokens) { - token_count[token]++; + for (size_t i = 0; i < last_tokens_size; ++i) { + token_count[last_tokens_p[i]]++; } // Apply frequency and presence penalties to the candidates - for (size_t i = 0; i < candidates.size(); ++i) { - auto token_iter = token_count.find(candidates[i].id); + for (size_t i = 0; i < candidates->size; ++i) { + auto token_iter = token_count.find(candidates->data[i].id); if (token_iter == token_count.end()) { continue; } int count = token_iter->second; - candidates[i].logit -= count * alpha_frequency + float(count > 0) * alpha_presence; + candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence; } - candidates_p->sorted = false; + candidates->sorted = false; if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } } -/// @brief Mirostat 1.0 implementation. -/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. -/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. -/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. -/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. -/// @param N The size of the vocabulary. This is used in the calculation of the `k` value. -/// @param k A reference to the integer variable used to store the calculated top-k value. The top-k value determines how many of the most probable tokens are considered for sampling. -/// @param mu A reference to the floating-point variable that represents the maximum cross-entropy value. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
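(To make the parameters documented above concrete, here is a minimal caller sketch for the signature as it stands at this point in the series; `ctx` and a populated `candidates_p` array are assumed, and `N` and `*k` are removed again in a later patch of this series.)

    float tau = 5.0f;                   // target surprise (cross-entropy)
    float eta = 0.1f;                   // learning rate for mu
    float mu  = 2.0f * tau;             // initialized to twice the target cross-entropy
    int   k   = 0;                      // receives the computed top-k cutoff
    const int   m = 100;                // tokens used to estimate s_hat
    const float N = llama_n_vocab(ctx); // vocabulary size

    llama_token id = llama_sample_token_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);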
-llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, int m, float N, int * k, float * mu) { + +llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu) { assert(ctx); int64_t t_start_sample_us; t_start_sample_us = ggml_time_us(); - // https://arxiv.org/abs/2007.14966 - // Algorithm 1 - std::span candidates(candidates_p->data, candidates_p->size); - - // printf("llama_sample_mirostat: candidates.size() = %d, m = %d, N = %f, tau = %f, eta = %f, *k = %d, *mu = %f\n", candidates.size(), m, N, tau, eta, *k, *mu); - - llama_sample_softmax(nullptr, candidates_p); + llama_sample_softmax(nullptr, candidates); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; float sum_ti_bi = 0.0; float sum_ti_sq = 0.0; - for (size_t i = 0; i < size_t(m - 1) && i < candidates.size() - 1; ++i) { + for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { float t_i = logf(float(i + 2) / float(i + 1)); - float b_i = logf(candidates[i].p / candidates[i + 1].p); + float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); sum_ti_bi += t_i * b_i; sum_ti_sq += t_i * t_i; } @@ -1808,25 +1779,23 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_ // Compute k from the estimated s_hat and target surprise value float epsilon_hat = s_hat - 1; - // printf("llama_sample_mirostat: s_hat = %f, epsilon_hat = %f, *mu = %f, N = %f\n", s_hat, epsilon_hat, *mu, N); float new_k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat); - // printf("llama_sample_mirostat: new_k = %f\n", new_k); - *k = int(std::min(new_k, float(candidates.size()))); + *k = int(std::min(new_k, float(candidates->size))); // Sample the next word X using top-k sampling // printf("llama_sample_mirostat *k = %d\n", *k); - llama_sample_top_k(nullptr, candidates_p, *k); + llama_sample_top_k(nullptr, candidates, *k); if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } - llama_token X = llama_sample_token(ctx, candidates_p); + llama_token X = llama_sample_token(ctx, candidates); t_start_sample_us = ggml_time_us(); // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { return candidate.id == X; })); - float observed_surprise = -log2f(candidates[X_idx].p); + float observed_surprise = -log2f(candidates->data[X_idx].p); float e = observed_surprise - tau; // Update mu using the learning rate and error @@ -1839,37 +1808,33 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_ return X; } -llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates_p, float tau, float eta, float * mu) { +llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { assert(ctx); int64_t t_start_sample_us; t_start_sample_us = ggml_time_us(); - // https://arxiv.org/abs/2007.14966 - // Algorithm 2 - std::span candidates(candidates_p->data, candidates_p->size); - - llama_sample_softmax(ctx, candidates_p); + 
llama_sample_softmax(ctx, candidates); // Truncate the words with surprise values greater than mu - candidates_p->size = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { + candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { return -log2f(candidate.p) > *mu; })); // Normalize the probabilities of the remaining words - llama_sample_softmax(ctx, candidates_p); + llama_sample_softmax(ctx, candidates); // Sample the next word X from the remaining words if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } - llama_token X = llama_sample_token(ctx, candidates_p); + llama_token X = llama_sample_token(ctx, candidates); t_start_sample_us = ggml_time_us(); // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates.begin(), std::find_if(candidates.begin(), candidates.end(), [&](const llama_token_data & candidate) { + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { return candidate.id == X; })); - float observed_surprise = -log2f(candidates[X_idx].p); + float observed_surprise = -log2f(candidates->data[X_idx].p); float e = observed_surprise - tau; // Update mu using the learning rate and error @@ -1881,12 +1846,11 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok return X; } -llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates_p) { +llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { const int64_t t_start_sample_us = ggml_time_us(); // Find max element - std::span candidates(candidates_p->data, candidates_p->size); - auto max_iter = std::max_element(candidates.begin(), candidates.end(), [](const llama_token_data & a, const llama_token_data & b) { + auto max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit < b.logit; }); @@ -1898,24 +1862,22 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da return result; } -llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates_p) { +llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { assert(ctx); const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax(nullptr, candidates_p); - - std::span candidates(candidates_p->data, candidates_p->size); + llama_sample_softmax(nullptr, candidates); std::vector probs; - probs.reserve(candidates.size()); - for (auto & candidate : candidates) { - probs.push_back(candidate.p); + probs.reserve(candidates->size); + for (size_t i = 0; i < candidates->size; ++i) { + probs.push_back(candidates->data[i].p); } std::discrete_distribution<> dist(probs.begin(), probs.end()); auto & rng = ctx->rng; int idx = dist(rng); - llama_token result = candidates[idx].id; + llama_token result = candidates->data[idx].id; ctx->t_sample_us += ggml_time_us() - t_start_sample_us; ctx->n_sample++; @@ -2691,6 +2653,10 @@ llama_token llama_token_eos() { return 2; } +llama_token llama_token_nl() { + return 13; +} + void llama_print_timings(struct llama_context * ctx) { const int64_t t_end_us = ggml_time_us(); diff --git 
a/llama.h b/llama.h index 4f72c273c48d2..5f61971ceb949 100644 --- a/llama.h +++ b/llama.h @@ -185,18 +185,38 @@ extern "C" { // Special tokens LLAMA_API llama_token llama_token_bos(); LLAMA_API llama_token llama_token_eos(); + LLAMA_API llama_token llama_token_nl(); // Sampling functions - LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float penalty); - LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates_p, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence); + + /// @brief Repetition penalty + /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/pdf/1909.05858.pdf with negative logit fix + LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty); + /// @brief Frequency and presence repetition penalties + /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details + LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence); LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1); LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1); + + /// @brief Tail Free Sampling https://www.trentonbricken.com/Tail-Free-Sampling/ LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1); + + /// @brief Locally Typical Sampling https://arxiv.org/pdf/2202.00666.pdf LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1); LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); + /// @brief Mirostat implementation. + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// @param ctx The llama context. + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. 
In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. + /// @param N The size of the vocabulary. This is used in the calculation of the `k` value. + /// @param k A reference to the integer variable used to store the calculated top-k value. The top-k value determines how many of the most probable tokens are considered for sampling. + /// @param mu A reference to the floating-point variable that represents the maximum cross-entropy value. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu); LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu); LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates); From 6c4c88d54fd725c2719b18a1e7879b1d17c1e415 Mon Sep 17 00:00:00 2001 From: Ivan Stepanov Date: Fri, 28 Apr 2023 19:53:24 +0300 Subject: [PATCH 4/8] Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k) Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k) --- CMakeLists.txt | 2 +- examples/common.cpp | 23 +++++++----- examples/main/main.cpp | 11 +++--- llama.cpp | 16 ++------ llama.h | 34 +++++++++++------ tests/test-sampling.cpp | 83 ++++++++++++----------------------------- 6 files changed, 70 insertions(+), 99 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9d7c9d1ed35ef..5fdbeddfca443 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,7 +76,7 @@ option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) # Compile flags # -set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED true) set(CMAKE_C_STANDARD 11) set(CMAKE_C_STANDARD_REQUIRED true) diff --git a/examples/common.cpp b/examples/common.cpp index a4938b4846136..6c712c713db9b 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -158,13 +158,13 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.mirostat = std::stoi(argv[i]); - } else if (arg == "--mirostat_eta") { + } else if (arg == "--mirostat_lr") { if (++i >= argc) { invalid_param = true; break; } params.mirostat_eta = std::stof(argv[i]); - } else if (arg == "--mirostat_tau") { + } else if (arg == "--mirostat_ent") { if (++i >= argc) { invalid_param = true; break; @@ -242,7 +242,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { char sign; std::string value_str; try { - if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-' || sign == '=' || sign == ':')) { + if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? 
-1.0f : 1.0f); } else { throw std::exception(); @@ -309,18 +309,21 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --top_p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p); fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z); fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p); - fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n); + fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n); fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty); fprintf(stderr, " --presence_penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty); fprintf(stderr, " --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty); - fprintf(stderr, " --mirostat N use mirostat sampling (default: %d, 0 = disabled, 1 = mirostat, 2 = mirostat 2.0)\n", params.mirostat); - fprintf(stderr, " --mirostat_eta N mirostat learning rate (default: %.1f)\n", (double)params.mirostat_eta); - fprintf(stderr, " --mirostat_tau N mirostat target entropy (default: %.1f)\n", (double)params.mirostat_tau); - fprintf(stderr, " -l TOKEN+BIAS, --logit-bias TOKEN+BIAS"); + fprintf(stderr, " --mirostat N use Mirostat sampling.\n"); + fprintf(stderr, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); + fprintf(stderr, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat); + fprintf(stderr, " --mirostat_lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta); + fprintf(stderr, " --mirostat_ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau); + fprintf(stderr, " -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n"); fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n"); - fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello'\n"); + fprintf(stderr, " i.e. 
`--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); + fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx); - fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2+-inf)\n"); + fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); fprintf(stderr, " --no-penalize-nl do not penalize newline token\n"); fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n"); fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index da974005705ee..674920b8a04c5 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -276,7 +276,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str()); } } - fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_eta = %f, mirostat_tau = %f\n", + fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "\n\n"); @@ -420,8 +420,8 @@ int main(int argc, char ** argv) { std::vector candidates; candidates.reserve(n_vocab); - for (size_t i = 0; i < (size_t) n_vocab; i++) { - candidates.emplace_back(i, logits[i], 0.0f); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); } llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; @@ -445,11 +445,12 @@ int main(int argc, char ** argv) { } else { if (mirostat == 1) { static float mirostat_mu = 2.0f * mirostat_tau; - static int mirostat_k = 40; const int mirostat_m = 100; - id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, float(n_vocab), &mirostat_k, &mirostat_mu); + llama_sample_temperature(ctx, &candidates_p, temp); + id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); } else if (mirostat == 2) { static float mirostat_mu = 2.0f * mirostat_tau; + llama_sample_temperature(ctx, &candidates_p, temp); id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); } else { // Temperature sampling diff --git a/llama.cpp b/llama.cpp index 2ec6d894a810d..5645a22e2e940 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1710,12 +1710,6 @@ void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_dat } else { candidates->data[i].logit /= penalty; } - - // But it does not penalize tokens that logits are near zero, which is a problem. - // Another solution is to convert the logits to probabilities, apply the penalty, and then convert back to logits. 
- // float probability = std::exp(candidates[i].logit); - // probability /= penalty; - // candidates[i].logit = std::log(probability); } candidates->sorted = false; @@ -1757,9 +1751,9 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l } -llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu) { +llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { assert(ctx); - + auto N = float(llama_n_vocab(ctx)); int64_t t_start_sample_us; t_start_sample_us = ggml_time_us(); @@ -1779,12 +1773,10 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_ // Compute k from the estimated s_hat and target surprise value float epsilon_hat = s_hat - 1; - float new_k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat); - *k = int(std::min(new_k, float(candidates->size))); + float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat); // Sample the next word X using top-k sampling - // printf("llama_sample_mirostat *k = %d\n", *k); - llama_sample_top_k(nullptr, candidates, *k); + llama_sample_top_k(nullptr, candidates, int(k)); if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } diff --git a/llama.h b/llama.h index 5f61971ceb949..fccce707e7df1 100644 --- a/llama.h +++ b/llama.h @@ -189,37 +189,47 @@ extern "C" { // Sampling functions - /// @brief Repetition penalty - /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/pdf/1909.05858.pdf with negative logit fix + /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty); - /// @brief Frequency and presence repetition penalties - /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details + + /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence); + /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); + + /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1); + + /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1); - /// @brief Tail Free Sampling https://www.trentonbricken.com/Tail-Free-Sampling/ + /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. 
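+    /// z thresholds the cumulative, normalized second-derivative weights of the sorted candidate probabilities: once the running sum exceeds z, the remaining tail is dropped; z = 1.0 disables the filter.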
LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1); - /// @brief Locally Typical Sampling https://arxiv.org/pdf/2202.00666.pdf + /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1); LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); - /// @brief Mirostat implementation. /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. - /// @param ctx The llama context. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. - /// @param N The size of the vocabulary. This is used in the calculation of the `k` value. - /// @param k A reference to the integer variable used to store the calculated top-k value. The top-k value determines how many of the most probable tokens are considered for sampling. - /// @param mu A reference to the floating-point variable that represents the maximum cross-entropy value. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float N, int * k, float * mu); + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu); + + /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. 
+ /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu); + + /// @details Selects the token with the highest probability. LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates); + + /// @details Randomly selects a token from the candidates based on their probabilities. LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); // Performance information diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 0a23c80c5c342..3f3d5d174e83f 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -1,4 +1,3 @@ -#include "ggml.h" #include "llama.h" #include #include @@ -23,12 +22,12 @@ void test_top_k(const std::vector & probs, size_t n_vocab = probs.size(); std::vector candidates; candidates.reserve(n_vocab); - for (int i = 0; i < n_vocab; i++) { - float logit = log(probs[i]); - candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { + float logit = log(probs[token_id]); + candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); } - llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_sample_softmax(nullptr, &candidates_p); // DUMP(&candidates_p); llama_sample_top_k(nullptr, &candidates_p, k); @@ -48,12 +47,12 @@ void test_top_p(const std::vector & probs, size_t n_vocab = probs.size(); std::vector candidates; candidates.reserve(n_vocab); - for (int i = 0; i < n_vocab; i++) { - float logit = log(probs[i]); - candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { + float logit = log(probs[token_id]); + candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); } - llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; // DUMP(&candidates_p); llama_sample_top_p(nullptr, &candidates_p, p); // DUMP(&candidates_p); @@ -71,12 +70,12 @@ void test_tfs(const std::vector & probs, size_t n_vocab = probs.size(); std::vector candidates; candidates.reserve(n_vocab); - for (int i = 0; i < n_vocab; i++) { - float logit = log(probs[i]); - candidates.emplace_back(llama_token_data{i, logit, 0.0f}); + for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { + float logit = log(probs[token_id]); + candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); } - llama_token_data_array candidates_p = { candidates.data(), candidates.size() }; + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; // DUMP(&candidates_p); llama_sample_tail_free(nullptr, &candidates_p, z); // DUMP(&candidates_p); @@ -94,12 +93,12 @@ void test_typical(const std::vector & probs, 
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index 0a23c80c5c342..3f3d5d174e83f 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -1,4 +1,3 @@
-#include "ggml.h"
 #include "llama.h"
 #include <cassert>
 #include <cmath>
@@ -23,12 +22,12 @@ void test_top_k(const std::vector<float> & probs,
     size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
-    for (int i = 0; i < n_vocab; i++) {
-        float logit = log(probs[i]);
-        candidates.emplace_back(llama_token_data{i, logit, 0.0f});
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     llama_sample_softmax(nullptr, &candidates_p);
     // DUMP(&candidates_p);
     llama_sample_top_k(nullptr, &candidates_p, k);
@@ -48,12 +47,12 @@ void test_top_p(const std::vector<float> & probs,
     size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
-    for (int i = 0; i < n_vocab; i++) {
-        float logit = log(probs[i]);
-        candidates.emplace_back(llama_token_data{i, logit, 0.0f});
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     // DUMP(&candidates_p);
     llama_sample_top_p(nullptr, &candidates_p, p);
     // DUMP(&candidates_p);
@@ -71,12 +70,12 @@ void test_tfs(const std::vector<float> & probs,
     size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
-    for (int i = 0; i < n_vocab; i++) {
-        float logit = log(probs[i]);
-        candidates.emplace_back(llama_token_data{i, logit, 0.0f});
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     // DUMP(&candidates_p);
     llama_sample_tail_free(nullptr, &candidates_p, z);
     // DUMP(&candidates_p);
@@ -94,12 +93,12 @@ void test_typical(const std::vector<float> & probs,
     size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
-    for (int i = 0; i < n_vocab; i++) {
-        float logit = log(probs[i]);
-        candidates.emplace_back(llama_token_data{i, logit, 0.0f});
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     // DUMP(&candidates_p);
     llama_sample_typical(nullptr, &candidates_p, p);
     // DUMP(&candidates_p);
@@ -121,12 +120,12 @@ void test_repetition_penalty(
     size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
-    for (int i = 0; i < n_vocab; i++) {
-        float logit = log(probs[i]);
-        candidates.emplace_back(llama_token_data{i, logit, 0.0f});
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     llama_sample_softmax(nullptr, &candidates_p);
     DUMP(&candidates_p);
     llama_sample_repetition_penalty(nullptr, &candidates_p, (llama_token *)last_tokens.data(), last_tokens.size(), penalty);
@@ -150,12 +149,12 @@ void test_frequency_presence_penalty(
     size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
-    for (int i = 0; i < n_vocab; i++) {
-        float logit = log(probs[i]);
-        candidates.emplace_back(llama_token_data{i, logit, 0.0f});
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     llama_sample_softmax(nullptr, &candidates_p);
     // DUMP(&candidates_p);
     llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, (llama_token *)last_tokens.data(), last_tokens.size(), alpha_frequency, alpha_presence);
@@ -168,38 +167,6 @@ void test_frequency_presence_penalty(
     }
 }
 
-
-void test_mirostat() {
-    std::vector<float> probs = {0.1, 0.2, 0.3, 0.4};
-    std::vector<float> expected_probs = {0.1, 0.2, 0.3, 0.4};
-
-    size_t n_vocab = probs.size();
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-    for (int i = 0; i < n_vocab; i++) {
-        float logit = log(probs[i]);
-        candidates.emplace_back(llama_token_data{i, logit, 0.0f});
-    }
-
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
-    DUMP(&candidates_p);
-
-    float tau = 5.0f;
-    float mu = 2.0f * tau;
-    int k = 0;
-    float eta = 0.1f;
-    int m = 100;
-    // float N = 32000;
-    float N = 4;
-    // llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
-    DUMP(&candidates_p);
-
-    // assert(candidates_p.size == expected_probs.size());
-    // for (size_t i = 0; i < candidates_p.size; i++) {
-    //     assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
-    // }
-}
-
 int main(void) {
     test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4}, 1);
     test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2}, 3);
@@ -223,7 +190,5 @@ int main(void) {
     test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.499966, 0.499966, 0.000023, 0.000023, 0.000023}, 5.0, 5.0);
     test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.499977, 0.499977, 0.000023, 0.000023, 0.000000}, 5.0, 5.0);
 
-    // test_mirostat();
-
     printf("OK\n");
 }
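[Editorial note: the removed test_mirostat() never actually invoked the sampler, and it targeted the old signature with N/k. For reference, a sketch of driving the final Mirostat interface, assuming candidates_p is rebuilt from fresh logits before each call; tau = 5.0, eta = 0.1 and m = 100 are the illustrative values from the doc comments above, not required defaults:]

    // mu carries state across sampling steps: it starts at 2 * tau and is
    // nudged by eta after each token, based on the gap between the target
    // and the observed surprise of the sampled token.
    float tau = 5.0f;   // target cross-entropy (surprise)
    float eta = 0.1f;   // learning rate for mu
    float mu  = 2.0f * tau;
    llama_token id = llama_sample_token_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);

    // The 1.0 variant additionally estimates s_hat over the top m candidates:
    // llama_token id = llama_sample_token_mirostat(ctx, &candidates_p, tau, eta, 100, &mu);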
From 416f49182ab9883c26ee26e3d9ec6eaf1a1fb0fd Mon Sep 17 00:00:00 2001
From: Ivan Stepanov
Date: Fri, 28 Apr 2023 20:19:17 +0300
Subject: [PATCH 5/8] Save and load example adjust

---
 examples/save-load-state/save-load-state.cpp | 34 +++++++++++---------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 39aa7f82cae5c..07dfa2c74ed07 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -64,14 +64,15 @@ int main(int argc, char ** argv) {
     // first run
     printf("\n%s", params.prompt.c_str());
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token = llama_sample_top_p_top_k(
-            ctx,
-            &last_n_tokens_data.back() - params.repeat_last_n,
-            params.repeat_last_n,
-            40,
-            1.0,
-            1.0,
-            1.1);
+        auto logits = llama_get_logits(ctx);
+        auto n_vocab = llama_n_vocab(ctx);
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+        auto next_token = llama_sample_token(ctx, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx, next_token);
         last_n_tokens_data.push_back(next_token);
         printf("%s", next_token_str);
@@ -106,14 +107,15 @@ int main(int argc, char ** argv) {
 
     // second run
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token = llama_sample_top_p_top_k(
-            ctx2,
-            &last_n_tokens_data.back() - params.repeat_last_n,
-            params.repeat_last_n,
-            40,
-            1.0,
-            1.0,
-            1.1);
+        auto logits = llama_get_logits(ctx2);
+        auto n_vocab = llama_n_vocab(ctx2);
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+        auto next_token = llama_sample_token(ctx2, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx2, next_token);
         last_n_tokens_data.push_back(next_token);
         printf("%s", next_token_str);
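[Editorial note: the candidate-array boilerplate above is now repeated verbatim for each context. A small wrapper, hypothetical and not part of this series, shows how a caller could keep the pattern in one place:]

    // Hypothetical helper: rebuild the candidate list from a context's current
    // logits and draw a token with the probability-weighted sampler.
    static llama_token sample_next(llama_context * ctx) {
        const float * logits = llama_get_logits(ctx);
        const int    n_vocab = llama_n_vocab(ctx);
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
        }
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
        return llama_sample_token(ctx, &candidates_p);
    }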
From 3bf3a968b6f70013ea94163243080202ddd9c66e Mon Sep 17 00:00:00 2001
From: Ivan Stepanov
Date: Fri, 28 Apr 2023 20:36:53 +0300
Subject: [PATCH 6/8] Tests

---
 tests/test-sampling.cpp | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index 3f3d5d174e83f..c89b569fe3f7f 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -7,6 +7,9 @@
 #include <algorithm>
 #include <cstdio>
 
+#undef assert
+#define assert(__expr) do { if (!(__expr)) { printf("%s:%d (%s) %s\n", __FILE__, __LINE__, __func__, #__expr); exit(1); } } while(0)
+
 void dump(const llama_token_data_array * candidates) {
     for (size_t i = 0; i < candidates->size; i++) {
         printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
     }
@@ -53,13 +56,14 @@ void test_top_p(const std::vector<float> & probs,
     }
 
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+    llama_sample_softmax(nullptr, &candidates_p);
     // DUMP(&candidates_p);
     llama_sample_top_p(nullptr, &candidates_p, p);
     // DUMP(&candidates_p);
 
     assert(candidates_p.size == expected_probs.size());
     for (size_t i = 0; i < candidates_p.size; i++) {
-        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
+        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
     }
 }
 
@@ -82,7 +86,7 @@ void test_tfs(const std::vector<float> & probs,
 
     assert(candidates_p.size == expected_probs.size());
     for (size_t i = 0; i < candidates_p.size; i++) {
-        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
+        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
     }
 }
 
@@ -105,7 +109,7 @@ void test_typical(const std::vector<float> & probs,
 
     assert(candidates_p.size == expected_probs.size());
     for (size_t i = 0; i < candidates_p.size; i++) {
-        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
+        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
     }
 }
 
@@ -163,7 +167,7 @@ void test_frequency_presence_penalty(
 
     assert(candidates_p.size == expected_probs.size());
     for (size_t i = 0; i < candidates_p.size; i++) {
-        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
+        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
     }
 }
 
@@ -182,9 +186,9 @@ int main(void) {
     test_typical({0.97, 0.01, 0.01, 0.01}, {0.97}, 0.5);
     test_typical({0.4, 0.2, 0.2, 0.2}, {0.2, 0.2, 0.2}, 0.5);
 
-    test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0, 0.25, 0.25, 0.25, 0.25}, 50.0);
-    test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0, 0, 0, 0.5, 0.5}, 50.0);
-    test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0, 0, 0, 0.5, 0.5}, 50.0);
+    test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.25, 0.25, 0.25, 0.25, 0}, 50.0);
+    test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.5, 0.5, 0, 0, 0}, 50.0);
+    test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.5, 0.5, 0, 0, 0}, 50.0);
 
     test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.249997, 0.249997, 0.249997, 0.249997, 0.000011}, 5.0, 5.0);
     test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.499966, 0.499966, 0.000023, 0.000023, 0.000023}, 5.0, 5.0);
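[Editorial note on the reordered expectations: the sampling functions leave candidates sorted by descending probability, so a fully penalized token now lands at the tail of the array instead of keeping its original index. A rough check of the first case, my arithmetic, assuming the usual rule that a positive logit is divided by the penalty and a non-positive one multiplied by it: every logit starts at log(0.2), about -1.61; scaling token 0's logit by 50 gives about -80.5, whose softmax weight is effectively 0, so the four untouched tokens share 0.25 each and the sorted result is {0.25, 0.25, 0.25, 0.25, 0}, matching the new expectation.]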
From 4ab7bb77c045ac868aa34cff5708d6f86740094e Mon Sep 17 00:00:00 2001
From: Ivan Stepanov
Date: Fri, 28 Apr 2023 20:42:44 +0300
Subject: [PATCH 7/8] Windows build fix

---
 llama.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llama.cpp b/llama.cpp
index 5645a22e2e940..4335772b53480 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -28,6 +28,7 @@
 #include <atomic>
 #include <mutex>
 #include <sstream>
+#include <numeric>
 
 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16

From f571806da77ee034a07a0792a53eb91eb9b8dda8 Mon Sep 17 00:00:00 2001
From: Ivan Stepanov
Date: Fri, 28 Apr 2023 22:12:25 +0300
Subject: [PATCH 8/8] Windows test fix

---
 tests/test-sampling.cpp | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index c89b569fe3f7f..7eee4f6d3a645 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -1,14 +1,13 @@
 #include "llama.h"
-#include <cassert>
-#include <cmath>
+#include "ggml.h"
+#include <cmath>
+#include <cassert>
 #include <numeric>
 #include <iostream>
 #include <vector>
 #include <algorithm>
 #include <cstdio>
 
-#undef assert
-#define assert(__expr) do { if (!(__expr)) { printf("%s:%d (%s) %s\n", __FILE__, __LINE__, __func__, #__expr); exit(1); } } while(0)
 
 void dump(const llama_token_data_array * candidates) {
     for (size_t i = 0; i < candidates->size; i++) {
         printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
     }
 }
@@ -32,9 +31,9 @@ void test_top_k(const std::vector<float> & probs,
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     llama_sample_softmax(nullptr, &candidates_p);
-    // DUMP(&candidates_p);
+    DUMP(&candidates_p);
     llama_sample_top_k(nullptr, &candidates_p, k);
-    // DUMP(&candidates_p);
+    DUMP(&candidates_p);
 
     assert(candidates_p.size == expected_probs.size());
     for (size_t i = 0; i < candidates_p.size; i++) {
@@ -57,9 +56,9 @@ void test_top_p(const std::vector<float> & probs,
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     llama_sample_softmax(nullptr, &candidates_p);
-    // DUMP(&candidates_p);
+    DUMP(&candidates_p);
     llama_sample_top_p(nullptr, &candidates_p, p);
-    // DUMP(&candidates_p);
+    DUMP(&candidates_p);
 
     assert(candidates_p.size == expected_probs.size());
     for (size_t i = 0; i < candidates_p.size; i++) {
@@ -80,9 +79,9 @@ void test_tfs(const std::vector<float> & probs,
     }
 
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    // DUMP(&candidates_p);
+    DUMP(&candidates_p);
     llama_sample_tail_free(nullptr, &candidates_p, z);
-    // DUMP(&candidates_p);
+    DUMP(&candidates_p);
 
     assert(candidates_p.size == expected_probs.size());
     for (size_t i = 0; i < candidates_p.size; i++) {
@@ -103,9 +102,9 @@ void test_typical(const std::vector<float> & probs,
     }
 
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    // DUMP(&candidates_p);
+    DUMP(&candidates_p);
     llama_sample_typical(nullptr, &candidates_p, p);
-    // DUMP(&candidates_p);
+    DUMP(&candidates_p);
 
     assert(candidates_p.size == expected_probs.size());
     for (size_t i = 0; i < candidates_p.size; i++) {
@@ -172,6 +171,8 @@ void test_frequency_presence_penalty(
 }
 
 int main(void) {
+    ggml_time_init();
+
     test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4}, 1);
     test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2}, 3);