Skip to content

llama : fix kv shift bug #3835

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 29, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1554,14 +1554,14 @@ static void llama_kv_cache_seq_shift(

for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.cells[i].pos += delta;
cache.has_shift = true;
cache.cells[i].pos += delta;
cache.cells[i].delta += delta;

if (cache.cells[i].pos < 0) {
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
} else {
cache.has_shift = true;
cache.cells[i].delta = delta;
}
}
}
Expand Down Expand Up @@ -6075,11 +6075,20 @@ static int llama_decode_internal(
#endif

// update the kv ring buffer
lctx.kv_self.has_shift = false;
lctx.kv_self.head += n_tokens;
// Ensure kv cache head points to a valid index.
if (lctx.kv_self.head >= lctx.kv_self.size) {
lctx.kv_self.head = 0;
{
if (kv_self.has_shift) {
kv_self.has_shift = false;
for (uint32_t i = 0; i < kv_self.size; ++i) {
kv_self.cells[i].delta = 0;
}
}

kv_self.head += n_tokens;

// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {
kv_self.head = 0;
}
}

#ifdef GGML_PERF
Expand Down