Skip to content

Commit 522a1d3

Browse files
ggerganovNexesenex
authored andcommitted
llama : fix kv shift bug (ggml-org#3835)
ggml-ci
1 parent ebca7f2 commit 522a1d3

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

Diff for: llama.cpp

+18-9
Original file line numberDiff line numberDiff line change
@@ -1552,14 +1552,14 @@ static void llama_kv_cache_seq_shift(
15521552

15531553
for (uint32_t i = 0; i < cache.size; ++i) {
15541554
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
1555-
cache.cells[i].pos += delta;
1555+
cache.has_shift = true;
1556+
cache.cells[i].pos += delta;
1557+
cache.cells[i].delta += delta;
1558+
15561559
if (cache.cells[i].pos < 0) {
15571560
cache.cells[i].pos = -1;
15581561
cache.cells[i].seq_id.clear();
15591562
if (new_head == cache.size) new_head = i;
1560-
} else {
1561-
cache.has_shift = true;
1562-
cache.cells[i].delta = delta;
15631563
}
15641564
}
15651565
}
@@ -6073,11 +6073,20 @@ static int llama_decode_internal(
60736073
#endif
60746074

60756075
// update the kv ring buffer
6076-
lctx.kv_self.has_shift = false;
6077-
lctx.kv_self.head += n_tokens;
6078-
// Ensure kv cache head points to a valid index.
6079-
if (lctx.kv_self.head >= lctx.kv_self.size) {
6080-
lctx.kv_self.head = 0;
6076+
{
6077+
if (kv_self.has_shift) {
6078+
kv_self.has_shift = false;
6079+
for (uint32_t i = 0; i < kv_self.size; ++i) {
6080+
kv_self.cells[i].delta = 0;
6081+
}
6082+
}
6083+
6084+
kv_self.head += n_tokens;
6085+
6086+
// Ensure kv cache head points to a valid index.
6087+
if (kv_self.head >= kv_self.size) {
6088+
kv_self.head = 0;
6089+
}
60816090
}
60826091

60836092
#ifdef GGML_PERF

0 commit comments

Comments
 (0)