Skip to content

Commit d2002bc

Browse files
KerfuffleV2Nexesenex
authored andcommitted
Extend llama_kv_cache_seq_rm to allow matching any sequence (ggml-org#3843)
* Extend llama_kv_cache_seq_rm to allow matichng any sequence * Replace llama_kv_cache_tokens_rm with llama_kv_cache_clear Use llama_kv_cache_clear for cache clearing Change calls to llama_kv_cache_tokens_rm that want to delete by position to use llama_kv_cache_seq_rm functionality
1 parent aa7f70f commit d2002bc

File tree

8 files changed

+30
-32
lines changed

8 files changed

+30
-32
lines changed

Diff for: common/common.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
889889

890890
std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
891891
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
892-
llama_kv_cache_tokens_rm(lctx, -1, -1);
892+
llama_kv_cache_clear(lctx);
893893
llama_reset_timings(lctx);
894894
}
895895

Diff for: examples/batched-bench/batched-bench.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ int main(int argc, char ** argv) {
185185

186186
const auto t_pp_start = ggml_time_us();
187187

188-
llama_kv_cache_tokens_rm(ctx, -1, -1);
188+
llama_kv_cache_clear(ctx);
189189

190190
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
191191
LOG_TEE("%s: llama_decode() failed\n", __func__);

Diff for: examples/llama-bench/llama-bench.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ int main(int argc, char ** argv) {
10371037

10381038
test t(inst, lmodel, ctx);
10391039

1040-
llama_kv_cache_tokens_rm(ctx, -1, -1);
1040+
llama_kv_cache_clear(ctx);
10411041

10421042
// warmup run
10431043
if (t.n_prompt > 0) {
@@ -1048,7 +1048,7 @@ int main(int argc, char ** argv) {
10481048
}
10491049

10501050
for (int i = 0; i < params.reps; i++) {
1051-
llama_kv_cache_tokens_rm(ctx, -1, -1);
1051+
llama_kv_cache_clear(ctx);
10521052

10531053
uint64_t t_start = get_time_ns();
10541054
if (t.n_prompt > 0) {

Diff for: examples/main/main.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ int main(int argc, char ** argv) {
298298
}
299299

300300
// remove any "future" tokens that we might have inherited from the previous session
301-
llama_kv_cache_tokens_rm(ctx, n_matching_session_tokens, -1);
301+
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
302302
}
303303

304304
LOGLN(

Diff for: examples/perplexity/perplexity.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
210210
const auto t_start = std::chrono::high_resolution_clock::now();
211211

212212
// clear the KV cache
213-
llama_kv_cache_tokens_rm(ctx, -1, -1);
213+
llama_kv_cache_clear(ctx);
214214

215215
for (int j = 0; j < num_batches; ++j) {
216216
const int batch_start = start + j * n_batch;
@@ -339,7 +339,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
339339
const auto t_start = std::chrono::high_resolution_clock::now();
340340

341341
// clear the KV cache
342-
llama_kv_cache_tokens_rm(ctx, -1, -1);
342+
llama_kv_cache_clear(ctx);
343343

344344
for (int j = 0; j < num_batches; ++j) {
345345
const int batch_start = start + j * n_batch;
@@ -573,7 +573,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
573573
}
574574

575575
// clear the KV cache
576-
llama_kv_cache_tokens_rm(ctx, -1, -1);
576+
llama_kv_cache_clear(ctx);
577577

578578
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
579579
if (logits.empty()) {

Diff for: examples/server/server.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -857,7 +857,7 @@ struct llama_server_context
857857

858858
void kv_cache_clear() {
859859
// clear the entire KV cache
860-
llama_kv_cache_tokens_rm(ctx, -1, -1);
860+
llama_kv_cache_clear(ctx);
861861
clean_kv_cache = false;
862862
}
863863

Diff for: llama.cpp

+15-14
Original file line numberDiff line numberDiff line change
@@ -1466,17 +1466,12 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
14661466
return 0;
14671467
}
14681468

1469-
static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
1470-
if (c0 < 0) c0 = 0;
1471-
if (c1 < 0) c1 = cache.size;
1472-
1473-
for (int32_t i = c0; i < c1; ++i) {
1469+
static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
1470+
for (int32_t i = 0; i < cache.size; ++i) {
14741471
cache.cells[i].pos = -1;
14751472
cache.cells[i].seq_id.clear();
14761473
}
1477-
1478-
// Searching for a free slot can start here since we know it will be empty.
1479-
cache.head = uint32_t(c0);
1474+
cache.head = 0;
14801475
}
14811476

14821477
static void llama_kv_cache_seq_rm(
@@ -1490,8 +1485,14 @@ static void llama_kv_cache_seq_rm(
14901485
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
14911486

14921487
for (uint32_t i = 0; i < cache.size; ++i) {
1493-
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
1494-
cache.cells[i].seq_id.erase(seq_id);
1488+
if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
1489+
if (seq_id < 0) {
1490+
cache.cells[i].seq_id.clear();
1491+
} else if (cache.cells[i].has_seq_id(seq_id)) {
1492+
cache.cells[i].seq_id.erase(seq_id);
1493+
} else {
1494+
continue;
1495+
}
14951496
if (cache.cells[i].seq_id.empty()) {
14961497
cache.cells[i].pos = -1;
14971498
if (new_head == cache.size) new_head = i;
@@ -9207,8 +9208,8 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
92079208
return ctx->kv_self.head;
92089209
}
92099210

9210-
void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) {
9211-
llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1);
9211+
void llama_kv_cache_clear(struct llama_context * ctx) {
9212+
llama_kv_cache_clear(ctx->kv_self);
92129213
}
92139214

92149215
void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@@ -9654,7 +9655,7 @@ int llama_eval(
96549655
llama_token * tokens,
96559656
int32_t n_tokens,
96569657
int n_past) {
9657-
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
9658+
llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);
96589659

96599660
const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
96609661
if (ret < 0) {
@@ -9669,7 +9670,7 @@ int llama_eval_embd(
96699670
float * embd,
96709671
int32_t n_tokens,
96719672
int n_past) {
9672-
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
9673+
llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);
96739674

96749675
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
96759676

Diff for: llama.h

+6-9
Original file line numberDiff line numberDiff line change
@@ -334,17 +334,14 @@ extern "C" {
334334
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
335335
"avoid using this, it will be removed in the future, instead - count the tokens in user code");
336336

337-
// Remove all tokens data of cells in [c0, c1)
338-
// c0 < 0 : [0, c1]
339-
// c1 < 0 : [c0, inf)
340-
LLAMA_API void llama_kv_cache_tokens_rm(
341-
struct llama_context * ctx,
342-
int32_t c0,
343-
int32_t c1);
337+
// Clear the KV cache
338+
LLAMA_API void llama_kv_cache_clear(
339+
struct llama_context * ctx);
344340

345341
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
346-
// p0 < 0 : [0, p1]
347-
// p1 < 0 : [p0, inf)
342+
// seq_id < 0 : match any sequence
343+
// p0 < 0 : [0, p1]
344+
// p1 < 0 : [p0, inf)
348345
LLAMA_API void llama_kv_cache_seq_rm(
349346
struct llama_context * ctx,
350347
llama_seq_id seq_id,

0 commit comments

Comments
 (0)