@@ -2209,22 +2209,19 @@ static bool llama_kv_cache_find_slot(
2209
2209
// For recurrent state architectures (like Mamba),
2210
2210
// each KV cache cell can store the state for a whole sequence.
2211
2211
2212
- // starting point to find the minimum seq_id used in the batch
2213
- cache.head = cache.size - 1;
2214
- // likewise, to find the max seq_id in the batch
2215
- cache.used = 0;
2212
+ llama_seq_id min = cache.size - 1;
2213
+ llama_seq_id max = 0;
2214
+
2216
2215
for (uint32_t i = 0; i < n_tokens; ++i) {
2217
2216
for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
2218
2217
llama_seq_id seq_id = batch.seq_id[i][j];
2219
2218
// make sure it's a valid seq_id
2220
- if ((uint32_t)seq_id < cache.size) {
2221
- // the number of "used" cells is simply the biggest seq_id
2222
- if (cache.used < (uint32_t)seq_id) {
2223
- cache.used = seq_id;
2219
+ if ((uint32_t) seq_id < cache.size) {
2220
+ if (seq_id > max) {
2221
+ max = seq_id;
2224
2222
}
2225
- // the "head" is the smallest seq_id
2226
- if (cache.head > (uint32_t)seq_id) {
2227
- cache.head = seq_id;
2223
+ if (seq_id < min) {
2224
+ min = seq_id;
2228
2225
}
2229
2226
// Assuming the tokens are in-order
2230
2227
if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
@@ -2233,6 +2230,9 @@ static bool llama_kv_cache_find_slot(
2233
2230
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
2234
2231
__func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
2235
2232
}
2233
+ if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) {
2234
+ cache.used += 1;
2235
+ }
2236
2236
cache.cells[seq_id].pos = batch.pos[i];
2237
2237
// NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
2238
2238
} else {
@@ -2244,9 +2244,12 @@ static bool llama_kv_cache_find_slot(
2244
2244
}
2245
2245
}
2246
2246
2247
- cache.n = cache.used - cache.head + 1;
2248
- // sanity check (max >= min)
2249
- return cache.used >= cache.head;
2247
+ // allow getting the range of used cells, from head to head + n
2248
+ cache.head = min;
2249
+ cache.n = max - min + 1;
2250
+
2251
+ // sanity check
2252
+ return max >= min;
2250
2253
}
2251
2254
// otherwise, one cell per token.
2252
2255
0 commit comments