Extend llama_kv_cache_seq_rm to allow matching any sequence #3843

Merged: 2 commits, Oct 29, 2023

2 changes: 1 addition & 1 deletion common/common.cpp
@@ -889,7 +889,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par

std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
- llama_kv_cache_tokens_rm(lctx, -1, -1);
+ llama_kv_cache_clear(lctx);
llama_reset_timings(lctx);
}

2 changes: 1 addition & 1 deletion examples/batched-bench/batched-bench.cpp
@@ -185,7 +185,7 @@ int main(int argc, char ** argv) {

const auto t_pp_start = ggml_time_us();

- llama_kv_cache_tokens_rm(ctx, -1, -1);
+ llama_kv_cache_clear(ctx);

if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
4 changes: 2 additions & 2 deletions examples/llama-bench/llama-bench.cpp
@@ -1037,7 +1037,7 @@ int main(int argc, char ** argv) {

test t(inst, lmodel, ctx);

- llama_kv_cache_tokens_rm(ctx, -1, -1);
+ llama_kv_cache_clear(ctx);

// warmup run
if (t.n_prompt > 0) {
@@ -1048,7 +1048,7 @@
}

for (int i = 0; i < params.reps; i++) {
- llama_kv_cache_tokens_rm(ctx, -1, -1);
+ llama_kv_cache_clear(ctx);

uint64_t t_start = get_time_ns();
if (t.n_prompt > 0) {
2 changes: 1 addition & 1 deletion examples/main/main.cpp
@@ -298,7 +298,7 @@ int main(int argc, char ** argv) {
}

// remove any "future" tokens that we might have inherited from the previous session
- llama_kv_cache_tokens_rm(ctx, n_matching_session_tokens, -1);
+ llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
}

LOGLN(
6 changes: 3 additions & 3 deletions examples/perplexity/perplexity.cpp
@@ -210,7 +210,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const auto t_start = std::chrono::high_resolution_clock::now();

// clear the KV cache
- llama_kv_cache_tokens_rm(ctx, -1, -1);
+ llama_kv_cache_clear(ctx);

for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@@ -339,7 +339,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_start = std::chrono::high_resolution_clock::now();

// clear the KV cache
- llama_kv_cache_tokens_rm(ctx, -1, -1);
+ llama_kv_cache_clear(ctx);

for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@@ -573,7 +573,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
}

// clear the KV cache
- llama_kv_cache_tokens_rm(ctx, -1, -1);
+ llama_kv_cache_clear(ctx);

auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
if (logits.empty()) {
2 changes: 1 addition & 1 deletion examples/server/server.cpp
@@ -857,7 +857,7 @@ struct llama_server_context

void kv_cache_clear() {
// clear the entire KV cache
- llama_kv_cache_tokens_rm(ctx, -1, -1);
+ llama_kv_cache_clear(ctx);
clean_kv_cache = false;
}

29 changes: 15 additions & 14 deletions llama.cpp
@@ -1468,17 +1468,12 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
return 0;
}

- static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
- if (c0 < 0) c0 = 0;
- if (c1 < 0) c1 = cache.size;
-
- for (int32_t i = c0; i < c1; ++i) {
+ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
+ for (int32_t i = 0; i < cache.size; ++i) {
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
}

- // Searching for a free slot can start here since we know it will be empty.
- cache.head = uint32_t(c0);
+ cache.head = 0;
}

static void llama_kv_cache_seq_rm(
@@ -1492,8 +1487,14 @@ static void llama_kv_cache_seq_rm(
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

for (uint32_t i = 0; i < cache.size; ++i) {
- if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
- cache.cells[i].seq_id.erase(seq_id);
+ if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
Collaborator (Author):

Probably not worth the complexity, but we could optimize to:

static void llama_kv_cache_seq_rm(
        struct llama_kv_cache & cache,
                 llama_seq_id   seq_id,
                    llama_pos   p0,
                    llama_pos   p1) {
    uint32_t new_head = cache.size;

    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

    if (seq_id < 0) {
        if (p0 == 0 && p1 >= cache.size) {
            llama_kv_cache_clear(cache);
            return;
        }
        for (uint32_t i = 0; i < cache.size; ++i) {
            if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
                cache.cells[i].seq_id.clear();
                cache.cells[i].pos = -1;
                if (new_head == cache.size) new_head = i;
            }
        }
    } else {
        for (uint32_t i = 0; i < cache.size; ++i) {
            if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
                cache.cells[i].seq_id.erase(seq_id);
                if (cache.cells[i].seq_id.empty()) {
                    cache.cells[i].pos = -1;
                    if (new_head == cache.size) new_head = i;
                }
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != cache.size) cache.head = new_head;
}

This avoids checking if seq_id < 0 each iteration, but a single int test probably wouldn't be noticeable even for huge KV caches.

Member:

Yup, either way would be fine. Either change it and merge directly, or just merge it as it is.

Collaborator (Author):

Since you don't have a preference, I'll just leave it as is. I don't think it's worth the added complexity.

+ if (seq_id < 0) {
+ cache.cells[i].seq_id.clear();
+ } else if (cache.cells[i].has_seq_id(seq_id)) {
+ cache.cells[i].seq_id.erase(seq_id);
+ } else {
+ continue;
+ }
if (cache.cells[i].seq_id.empty()) {
cache.cells[i].pos = -1;
if (new_head == cache.size) new_head = i;
@@ -9204,8 +9205,8 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
return ctx->kv_self.head;
}

- void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) {
- llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1);
+ void llama_kv_cache_clear(struct llama_context * ctx) {
+ llama_kv_cache_clear(ctx->kv_self);
}

void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@@ -9651,7 +9652,7 @@ int llama_eval(
llama_token * tokens,
int32_t n_tokens,
int n_past) {
- llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
+ llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);

const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
if (ret < 0) {
@@ -9666,7 +9667,7 @@ int llama_eval_embd(
float * embd,
int32_t n_tokens,
int n_past) {
- llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
+ llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);

llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };

15 changes: 6 additions & 9 deletions llama.h
@@ -333,17 +333,14 @@ extern "C" {
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
"avoid using this, it will be removed in the future, instead - count the tokens in user code");

- // Remove all tokens data of cells in [c0, c1)
- // c0 < 0 : [0, c1]
- // c1 < 0 : [c0, inf)
- LLAMA_API void llama_kv_cache_tokens_rm(
- struct llama_context * ctx,
- int32_t c0,
- int32_t c1);
+ // Clear the KV cache
+ LLAMA_API void llama_kv_cache_clear(
+ struct llama_context * ctx);

// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
- // p0 < 0 : [0, p1]
- // p1 < 0 : [p0, inf)
+ // seq_id < 0 : match any sequence
+ // p0 < 0 : [0, p1]
+ // p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
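
For reference, a minimal usage sketch of the API after this change. This is not code from the PR: the helper function name and the concrete positions and sequence ids are illustrative only, and it assumes an already-initialized llama_context.

#include "llama.h"

// Illustrative sketch, not part of this PR: independent examples of the
// updated KV-cache calls. Assumes `ctx` was created beforehand (e.g. via
// llama_new_context_with_model) and `n_past` is a caller-chosen cutoff position.
static void kv_cache_usage_sketch(struct llama_context * ctx, llama_pos n_past) {
    // Replaces the old llama_kv_cache_tokens_rm(ctx, -1, -1): wipe the entire cache.
    llama_kv_cache_clear(ctx);

    // New seq_id < 0 semantics: remove cached tokens at positions >= n_past from any sequence.
    llama_kv_cache_seq_rm(ctx, -1, n_past, -1);

    // Existing behavior: remove only sequence 1's tokens with positions in [0, 32).
    llama_kv_cache_seq_rm(ctx, 1, 0, 32);
}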