Skip to content

Commit 1f17ea6

Browse files
committed
speculative: fix KV cache management
1 parent 7c1bdd0 commit 1f17ea6

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

examples/speculative/speculative.cpp

+4
Original file line number | Diff line number | Diff line change
@@ -172,6 +172,7 @@ int main(int argc, char ** argv) {
172172
LOG("out of drafted tokens\n");
173173
}
174174

175+
llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx);
175176
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
176177
++n_past_dft;
177178

@@ -217,6 +218,7 @@ int main(int argc, char ** argv) {
217218

218219
// sample n_draft tokens from the draft model using greedy decoding
219220
int n_past_cur = n_past_dft;
221+
220222
for (int i = 0; i < n_draft; ++i) {
221223
float * logits = llama_get_logits(ctx_dft);
222224

@@ -256,6 +258,7 @@ int main(int argc, char ** argv) {
256258
}
257259

258260
// evaluate the drafted token on the draft model
261+
llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx);
259262
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
260263
++n_past_cur;
261264

@@ -265,6 +268,7 @@ int main(int argc, char ** argv) {
265268
}
266269

267270
// evaluate the target model on the drafted tokens
271+
llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx);
268272
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
269273
++n_past_tgt;
270274

0 commit comments

Comments (0)