@@ -172,6 +172,7 @@ int main(int argc, char ** argv) {
                 LOG("out of drafted tokens\n");
             }
 
+            llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx);
             llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
             ++n_past_dft;
 
@@ -217,6 +218,7 @@ int main(int argc, char ** argv) {
 
         // sample n_draft tokens from the draft model using greedy decoding
         int n_past_cur = n_past_dft;
+
         for (int i = 0; i < n_draft; ++i) {
             float * logits = llama_get_logits(ctx_dft);
 
@@ -256,6 +258,7 @@ int main(int argc, char ** argv) {
             }
 
             // evaluate the drafted token on the draft model
+            llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx);
             llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
             ++n_past_cur;
 
@@ -265,6 +268,7 @@ int main(int argc, char ** argv) {
         }
 
         // evaluate the target model on the drafted tokens
+        llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx);
         llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
         ++n_past_tgt;
 
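Each added call follows the same pattern: before `llama_decode` evaluates tokens at a given position, the KV cache entries for sequence 0 from that position up to `n_ctx` are removed, so entries left behind by previously drafted (and possibly rejected) tokens do not collide with the new evaluation. A minimal sketch of that pattern, assuming the llama.cpp API as it appears in this diff (`llama_kv_cache_rm_seq`, `llama_batch_get_one`, and a `llama_decode` overload taking `n_threads`); the helper name `decode_at` is hypothetical:

```cpp
#include "llama.h"

// Hypothetical helper mirroring the pattern added above: clear the tail of
// sequence 0's KV cache starting at `pos`, then decode `n_tokens` tokens there.
static void decode_at(llama_context * ctx, llama_token * tokens, int n_tokens,
                      int pos, int n_ctx, int n_threads) {
    // drop any stale cache entries for sequence 0 in [pos, n_ctx)
    llama_kv_cache_rm_seq(ctx, 0, pos, n_ctx);

    // evaluate the tokens at position `pos` on sequence 0
    llama_decode(ctx, llama_batch_get_one(tokens, n_tokens, pos, 0), n_threads);
}
```

With this in place, re-decoding at `n_past_dft`, `n_past_cur`, or `n_past_tgt` always starts from a cache whose sequence 0 ends exactly at that position.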