Skip to content

Commit fad8848

Browse files
committed
mamba : more correctly update the "used" field of the KV cache
1 parent 27d5dcf commit fad8848

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

llama.cpp

+17 -14
Original file line number | Diff line number | Diff line change
@@ -2209,22 +2209,19 @@ static bool llama_kv_cache_find_slot(
22092209
// For recurrent state architectures (like Mamba),
22102210
// each KV cache cell can store the state for a whole sequence.
22112211

2212-
// starting point to find the minimum seq_id used in the batch
2213-
cache.head = cache.size - 1;
2214-
// likewise, to find the max seq_id in the batch
2215-
cache.used = 0;
2212+
llama_seq_id min = cache.size - 1;
2213+
llama_seq_id max = 0;
2214+
22162215
for (uint32_t i = 0; i < n_tokens; ++i) {
22172216
for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
22182217
llama_seq_id seq_id = batch.seq_id[i][j];
22192218
// make sure it's a valid seq_id
2220-
if ((uint32_t)seq_id < cache.size) {
2221-
// the number of "used" cells is simply the biggest seq_id
2222-
if (cache.used < (uint32_t)seq_id) {
2223-
cache.used = seq_id;
2219+
if ((uint32_t) seq_id < cache.size) {
2220+
if (seq_id > max) {
2221+
max = seq_id;
22242222
}
2225-
// the "head" is the smallest seq_id
2226-
if (cache.head > (uint32_t)seq_id) {
2227-
cache.head = seq_id;
2223+
if (seq_id < min) {
2224+
min = seq_id;
22282225
}
22292226
// Assuming the tokens are in-order
22302227
if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
@@ -2233,6 +2230,9 @@ static bool llama_kv_cache_find_slot(
22332230
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
22342231
__func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
22352232
}
2233+
if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) {
2234+
cache.used += 1;
2235+
}
22362236
cache.cells[seq_id].pos = batch.pos[i];
22372237
// NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
22382238
} else {
@@ -2244,9 +2244,12 @@ static bool llama_kv_cache_find_slot(
22442244
}
22452245
}
22462246

2247-
cache.n = cache.used - cache.head + 1;
2248-
// sanity check (max >= min)
2249-
return cache.used >= cache.head;
2247+
// allow getting the range of used cells, from head to head + n
2248+
cache.head = min;
2249+
cache.n = max - min + 1;
2250+
2251+
// sanity check
2252+
return max >= min;
22502253
}
22512254
// otherwise, one cell per token.
22522255

0 commit comments

Comments
 (0)