Skip to content
This repository was archived by the owner on Feb 6, 2024. It is now read-only.

Commit d49648f

Browse files
committed
Fix kv shift bug
* ggml-org/llama.cpp#3835
1 parent 3323149 commit d49648f

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

Diff for: Sources/llmfarm_core_cpp/llama/llama.cpp

+17 -8
Original file line number | Diff line number | Diff line change
@@ -1570,14 +1570,14 @@ static void llama_kv_cache_seq_shift(
15701570

15711571
for (uint32_t i = 0; i < cache.size; ++i) {
15721572
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
1573+
cache.has_shift = true;
15731574
cache.cells[i].pos += delta;
1575+
cache.cells[i].delta += delta;
1576+
15741577
if (cache.cells[i].pos < 0) {
15751578
cache.cells[i].pos = -1;
15761579
cache.cells[i].seq_id.clear();
15771580
if (new_head == cache.size) new_head = i;
1578-
} else {
1579-
cache.has_shift = true;
1580-
cache.cells[i].delta = delta;
15811581
}
15821582
}
15831583
}
@@ -6320,11 +6320,20 @@ static int llama_decode_internal(
63206320
#endif
63216321

63226322
// update the kv ring buffer
6323-
lctx.kv_self.has_shift = false;
6324-
lctx.kv_self.head += n_tokens;
6325-
// Ensure kv cache head points to a valid index.
6326-
if (lctx.kv_self.head >= lctx.kv_self.size) {
6327-
lctx.kv_self.head = 0;
6323+
{
6324+
if (kv_self.has_shift) {
6325+
kv_self.has_shift = false;
6326+
for (uint32_t i = 0; i < kv_self.size; ++i) {
6327+
kv_self.cells[i].delta = 0;
6328+
}
6329+
}
6330+
6331+
kv_self.head += n_tokens;
6332+
6333+
// Ensure kv cache head points to a valid index.
6334+
if (kv_self.head >= kv_self.size) {
6335+
kv_self.head = 0;
6336+
}
63286337
}
63296338

63306339
#ifdef GGML_PERF

0 commit comments

Comments
 (0)