
Commit d7b800b

llama : pad KV cache size (#4280)
* llama : pad KV cache size to 32
* metal : try to improve batched decoding
1 parent 5a7d312 commit d7b800b

2 files changed (+2, -3 lines)

ggml-metal.m (+1, -1)

@@ -1083,7 +1083,7 @@ void ggml_metal_graph_compute(
 
                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
-                int ne11_mm_min = 1;
+                int ne11_mm_min = src0t == GGML_TYPE_F16 ? 1 : 16;
 
 #if 0
                 // the numbers below are measured on M2 Ultra for 7B and 13B models
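
For context, ne11 is roughly the batch size (the number of src1 rows), and ne11_mm_min is the break-even point at which the matrix-matrix kernel is dispatched instead of the matrix-vector one. Below is a minimal C++ sketch of that selection logic, assuming a simplified dispatch; the real code in ggml_metal_graph_compute also checks GPU family and tensor shapes, and kernel_kind / pick_mul_mat_kernel are hypothetical names used only for illustration.

    // Illustrative sketch only, not the actual Metal dispatch code.
    enum kernel_kind { KERNEL_MUL_MV, KERNEL_MUL_MM };

    static kernel_kind pick_mul_mat_kernel(int ne11, bool src0_is_f16) {
        // F16 weights switch to the matrix-matrix kernel immediately;
        // quantized weights wait for a larger batch before it pays off
        const int ne11_mm_min = src0_is_f16 ? 1 : 16;
        return ne11 > ne11_mm_min ? KERNEL_MUL_MM : KERNEL_MUL_MV;
    }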

llama.cpp (+1, -2)

@@ -5744,8 +5744,7 @@ static int llama_decode_internal(
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
-    //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA?
-    kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
+    kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
 
     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
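
The llama.cpp change keeps the existing clamp to n_ctx but rounds the attended KV range up to a multiple of 32 rather than using the exact cell count, which suits the Metal batched kernels better. A self-contained C++ sketch of the arithmetic, assuming made-up values for n_ctx and the cell count and a local round_up helper standing in for ggml's GGML_PAD:

    #include <algorithm>
    #include <cstdio>

    // round x up to the next multiple of n (same effect as ggml's GGML_PAD)
    static int round_up(int x, int n) {
        return ((x + n - 1) / n) * n;
    }

    int main() {
        const int n_ctx    = 4096; // assumed context size
        const int cell_max = 70;   // pretend llama_kv_cache_cell_max() returned 70

        // before this commit: attend to exactly cell_max cells (at least 32)
        const int n_before = std::min(n_ctx, std::max(32, cell_max));               // 70
        // after this commit: pad up to a multiple of 32, still clamped to n_ctx
        const int n_after  = std::min(n_ctx, std::max(32, round_up(cell_max, 32))); // 96
        printf("kv_self.n before = %d, after = %d\n", n_before, n_after);
        return 0;
    }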
