@@ -3502,11 +3502,24 @@ static bool llama_kv_cache_init(
3502
3502
return true;
3503
3503
}
3504
3504
3505
// a structure that holds information about the slot found by llama_kv_cache_find_slot
struct llama_kv_cache_slot_info {
    std::pair<uint32_t, uint32_t> boundaries; // half-open cell range [begin, end) of the slot
    bool found = false;                       // whether a usable slot was located

    // found/not-found result without an associated cell range (e.g. recurrent caches)
    explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
    // successful result carrying the cell range the slot occupies
    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}

    // enables `if (slot) { ... }` style success checks
    operator bool() const { return found; }
};
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
3505
3517
// find an empty slot of size "n_tokens" in the cache
3506
3518
// updates the cache head
3519
+ // returns a structure holding information about the slot found
3507
3520
// Note: On success, it's important that cache.head points
3508
3521
// to the first cell of the slot.
3509
- static bool llama_kv_cache_find_slot(
3522
+ static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
3510
3523
struct llama_kv_cache & cache,
3511
3524
const struct llama_ubatch & batch) {
3512
3525
const uint32_t n_tokens = batch.n_tokens;
@@ -3534,7 +3547,7 @@ static bool llama_kv_cache_find_slot(
3534
3547
// too big seq_id
3535
3548
// TODO: would it be possible to resize the cache instead?
3536
3549
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
3537
- return false ;
3550
+ return llama_kv_cache_slot_info_failed ;
3538
3551
}
3539
3552
if (j > 0) {
3540
3553
llama_kv_cell & seq = cache.cells[seq_id];
@@ -3669,15 +3682,17 @@ static bool llama_kv_cache_find_slot(
3669
3682
// allow getting the range of used cells, from head to head + n
3670
3683
cache.head = min;
3671
3684
cache.n = max - min + 1;
3685
+ cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
3686
+ [](const llama_kv_cell& cell){ return !cell.is_empty(); });
3672
3687
3673
3688
// sanity check
3674
- return cache.n >= n_seqs;
3689
+ return llama_kv_cache_slot_info( cache.n >= n_seqs) ;
3675
3690
}
3676
3691
// otherwise, one cell per token.
3677
3692
3678
3693
if (n_tokens > cache.size) {
3679
3694
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
3680
- return false ;
3695
+ return llama_kv_cache_slot_info_failed ;
3681
3696
}
3682
3697
3683
3698
uint32_t n_tested = 0;
@@ -3705,7 +3720,7 @@ static bool llama_kv_cache_find_slot(
3705
3720
3706
3721
if (n_tested >= cache.size) {
3707
3722
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
3708
- return false ;
3723
+ return llama_kv_cache_slot_info_failed ;
3709
3724
}
3710
3725
}
3711
3726
@@ -3722,7 +3737,7 @@ static bool llama_kv_cache_find_slot(
3722
3737
3723
3738
cache.used += n_tokens;
3724
3739
3725
- return true ;
3740
+ return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens) ;
3726
3741
}
3727
3742
3728
3743
// find how many cells are currently in use
@@ -3998,6 +4013,53 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
3998
4013
return cparams.flash_attn ? 256u : 32u;
3999
4014
}
4000
4015
4016
+ // saves the kv_cache state for future recovery.
4017
+ // used to rollback llama_kv_cache_find_slot changes.
4018
+ struct llama_kv_slot_restorer {
4019
+ struct llama_kv_cache_state {
4020
+ uint32_t head = 0;
4021
+ uint32_t n = 0;
4022
+ } old_state;
4023
+
4024
+ // for non-recurrent models only
4025
+ // list of slots to restore
4026
+ std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
4027
+
4028
+ bool do_restore = false;
4029
+
4030
+ explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
4031
+ old_state.head = cache.head;
4032
+ old_state.n = cache.n;
4033
+ }
4034
+
4035
+ // saves a slot information for future restoration
4036
+ void save(const struct llama_kv_cache_slot_info & slot) {
4037
+ if (slot) {
4038
+ do_restore = true;
4039
+ if (slot.boundaries.first != slot.boundaries.second) {
4040
+ slot_boundaries.push_back(slot.boundaries);
4041
+ }
4042
+ }
4043
+ }
4044
+
4045
+ // must be explicitly called to restore the kv_cache state
4046
+ // and rollback changes from all llama_kv_cache_find_slot calls
4047
+ void restore(struct llama_kv_cache & cache) {
4048
+ if (do_restore) {
4049
+ cache.head = old_state.head;
4050
+ cache.n = old_state.n;
4051
+
4052
+ if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
4053
+ llama_kv_cache_seq_rm(cache, -1, -1, -1);
4054
+ } else {
4055
+ for (auto & slot : slot_boundaries) {
4056
+ llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
4057
+ }
4058
+ }
4059
+ }
4060
+ }
4061
+ };
4062
+
4001
4063
//
4002
4064
// model loading and saving
4003
4065
//
@@ -17181,7 +17243,8 @@ static void llama_output_reorder(struct llama_context * ctx) {
17181
17243
}
17182
17244
}
17183
17245
17184
- static void llama_graph_compute(
17246
+ // returns the result of ggml_backend_sched_graph_compute_async execution
17247
+ static enum ggml_status llama_graph_compute(
17185
17248
llama_context & lctx,
17186
17249
ggml_cgraph * gf,
17187
17250
int n_threads,
@@ -17196,15 +17259,20 @@ static void llama_graph_compute(
17196
17259
set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
17197
17260
}
17198
17261
17199
- auto err = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
17200
- if (err != GGML_STATUS_SUCCESS) {
17201
- LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err );
17262
+ auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
17263
+ if (status != GGML_STATUS_SUCCESS) {
17264
+ LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status );
17202
17265
}
17203
17266
17204
17267
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
17268
+
17269
+ return status;
17205
17270
}
17206
17271
17207
17272
// decode a batch of tokens by evaluating the transformer
17273
+ // in case of unsuccessful decoding (error or warning),
17274
+ // the kv_cache state will be returned to its original state
17275
+ // (for non-recurrent models) or cleaned (for recurrent models)
17208
17276
//
17209
17277
// - lctx: llama context
17210
17278
// - batch: batch to evaluate
@@ -17254,6 +17322,7 @@ static int llama_decode_internal(
17254
17322
lctx.n_queued_tokens += n_tokens_all;
17255
17323
17256
17324
auto & kv_self = lctx.kv_self;
17325
+ llama_kv_slot_restorer kv_slot_restorer(kv_self);
17257
17326
17258
17327
const int64_t n_embd = hparams.n_embd;
17259
17328
const int64_t n_vocab = hparams.n_vocab;
@@ -17338,9 +17407,11 @@ static int llama_decode_internal(
17338
17407
kv_self.head = 0;
17339
17408
}
17340
17409
17341
- if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
17410
+ const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
17411
+ if (!slot) {
17342
17412
return 1;
17343
17413
}
17414
+ kv_slot_restorer.save(slot);
17344
17415
17345
17416
if (!kv_self.recurrent) {
17346
17417
// a heuristic, to avoid attending the full cache if it is not yet utilized
@@ -17387,7 +17458,19 @@ static int llama_decode_internal(
17387
17458
17388
17459
llama_set_inputs(lctx, ubatch);
17389
17460
17390
- llama_graph_compute(lctx, gf, n_threads, threadpool);
17461
+ const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
17462
+ if (compute_status != GGML_STATUS_SUCCESS) {
17463
+ kv_slot_restorer.restore(kv_self);
17464
+ switch (compute_status) {
17465
+ case GGML_STATUS_ABORTED:
17466
+ return 2;
17467
+ case GGML_STATUS_ALLOC_FAILED:
17468
+ return -2;
17469
+ case GGML_STATUS_FAILED:
17470
+ default:
17471
+ return -3;
17472
+ }
17473
+ }
17391
17474
17392
17475
// update the kv ring buffer
17393
17476
{
@@ -17624,7 +17707,18 @@ static int llama_encode_internal(
17624
17707
17625
17708
llama_set_inputs(lctx, ubatch);
17626
17709
17627
- llama_graph_compute(lctx, gf, n_threads, threadpool);
17710
+ const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
17711
+ switch (compute_status) {
17712
+ case GGML_STATUS_SUCCESS:
17713
+ break;
17714
+ case GGML_STATUS_ABORTED:
17715
+ return 2;
17716
+ case GGML_STATUS_ALLOC_FAILED:
17717
+ return -2;
17718
+ case GGML_STATUS_FAILED:
17719
+ default:
17720
+ return -3;
17721
+ }
17628
17722
17629
17723
// extract embeddings
17630
17724
if (embd) {
0 commit comments