@@ -1552,14 +1552,14 @@ static void llama_kv_cache_seq_shift(
 
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.cells[i].pos += delta;
+            cache.has_shift = true;
+            cache.cells[i].pos   += delta;
+            cache.cells[i].delta += delta;
+
             if (cache.cells[i].pos < 0) {
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
                 if (new_head == cache.size) new_head = i;
-            } else {
-                cache.has_shift = true;
-                cache.cells[i].delta = delta;
             }
         }
     }
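
Note on the first hunk: the old code overwrote each cell's delta and only raised has_shift for cells whose shifted position stayed non-negative, so a second llama_kv_cache_seq_shift call issued before the next decode could silently discard the first offset. Below is a minimal standalone sketch of why the per-cell delta has to accumulate; kv_cell and seq_shift here are simplified stand-ins for illustration (eviction of cells shifted below position 0 is omitted), not llama.cpp's real types.

#include <cstdio>
#include <vector>

struct kv_cell {
    int pos   = 0; // current position of the cached token
    int delta = 0; // pending shift for the next K-shift pass
};

// shift every cell whose position lies in [p0, p1) by delta
static void seq_shift(std::vector<kv_cell> & cells, int p0, int p1, int delta) {
    for (auto & c : cells) {
        if (c.pos >= p0 && c.pos < p1) {
            c.pos   += delta;
            c.delta += delta; // '+=' composes shifts; '=' would lose the first one
        }
    }
}

int main() {
    std::vector<kv_cell> cells(8);
    for (int i = 0; i < 8; ++i) {
        cells[i].pos = i;
    }

    // two shifts issued before the next decode consumes them
    seq_shift(cells, 4, 8, -2);
    seq_shift(cells, 2, 6, -1);

    for (const auto & c : cells) {
        printf("pos = %2d, delta = %2d\n", c.pos, c.delta);
    }
    return 0;
}

Running this prints delta = -3 for the cells caught by both calls; with the old plain assignment the second call would leave them at -1, and the rotation applied at the next decode would no longer match their new positions.
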
@@ -6073,11 +6073,20 @@ static int llama_decode_internal(
 #endif
 
     // update the kv ring buffer
-    lctx.kv_self.has_shift = false;
-    lctx.kv_self.head += n_tokens;
-    // Ensure kv cache head points to a valid index.
-    if (lctx.kv_self.head >= lctx.kv_self.size) {
-        lctx.kv_self.head = 0;
+    {
+        if (kv_self.has_shift) {
+            kv_self.has_shift = false;
+            for (uint32_t i = 0; i < kv_self.size; ++i) {
+                kv_self.cells[i].delta = 0;
+            }
+        }
+
+        kv_self.head += n_tokens;
+
+        // Ensure kv cache head points to a valid index.
+        if (kv_self.head >= kv_self.size) {
+            kv_self.head = 0;
+        }
     }
 
 #ifdef GGML_PERF
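
Note on the second hunk: this is the matching consumer side in llama_decode_internal. The old code cleared has_shift unconditionally on every decode and never reset the per-cell deltas, which, combined with the accumulation above, would let an already-consumed shift be applied again. Below is a sketch of the reset-after-consume bookkeeping, using the same simplified structs as the previous sketch; the function name update_kv_ring_buffer is illustrative only, and in llama.cpp the actual re-rope of the cached K rows driven by these deltas happens earlier, when the compute graph for the batch is built.

#include <cstdint>
#include <vector>

struct kv_cell {
    int pos   = 0;
    int delta = 0;
};

struct kv_cache {
    bool     has_shift = false;
    uint32_t head      = 0;    // next write position in the ring buffer
    uint32_t size      = 0;    // total number of cells
    std::vector<kv_cell> cells;
};

static void update_kv_ring_buffer(kv_cache & kv, uint32_t n_tokens) {
    // the shift encoded by the deltas was consumed by the graph that just
    // ran, so zero them out or the same shift would be applied twice
    if (kv.has_shift) {
        kv.has_shift = false;
        for (uint32_t i = 0; i < kv.size; ++i) {
            kv.cells[i].delta = 0;
        }
    }

    kv.head += n_tokens;

    // wrap the ring-buffer head back to a valid index
    if (kv.head >= kv.size) {
        kv.head = 0;
    }
}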