
Commit e15f7c9

perf: fix the performance issue of append_paged_kv_cache (#588)
The performance of `append_paged_kv_cache` is terrible for small batch sizes. This is a known issue that we haven't fixed for a long time; this PR fixes it. This PR also adds support for non-contiguous append keys/values (which could be sliced from a fused qkv matrix).

We first call a Triton kernel to convert `append_indptr` to `batch_indices` and `positions` (similar to the [CSR2COO conversion](https://docs.nvidia.com/cuda/cusparse/#cusparse-t-csr2coo) for sparse matrices). After the conversion, we can use element parallelism instead of batch parallelism. It's also worth trying Triton for the second `AppendPagedKVCacheKernel` kernel; I think the performance should be fine. I'll leave it for future work.

Some todo items:
1. Add torch.compile support.

After this PR (reference numbers can be found at #583):

```bash
model: l1b       seqlens: [1, 1, 1, 1, 1, 1, 1, 1]                   single_layer: 0.006ms  all_layers: 0.094ms  throughput:    5.563GB/s
model: l1b       seqlens: [4993, 1, 1, 1, 1, 1, 1, 1]                single_layer: 0.014ms  all_layers: 0.216ms  throughput: 1514.280GB/s
model: l1b       seqlens: [5000]                                     single_layer: 0.014ms  all_layers: 0.216ms  throughput: 1517.017GB/s
model: l1b       seqlens: [625, 625, 625, 625, 625, 625, 625, 625]   single_layer: 0.014ms  all_layers: 0.217ms  throughput: 1510.863GB/s
---
model: l3b       seqlens: [1, 1, 1, 1, 1, 1, 1, 1]                   single_layer: 0.006ms  all_layers: 0.165ms  throughput:   11.123GB/s
model: l3b       seqlens: [4993, 1, 1, 1, 1, 1, 1, 1]                single_layer: 0.021ms  all_layers: 0.580ms  throughput: 1975.732GB/s
model: l3b       seqlens: [5000]                                     single_layer: 0.021ms  all_layers: 0.586ms  throughput: 1958.078GB/s
model: l3b       seqlens: [625, 625, 625, 625, 625, 625, 625, 625]   single_layer: 0.021ms  all_layers: 0.581ms  throughput: 1973.174GB/s
---
model: l8b       seqlens: [1, 1, 1, 1, 1, 1, 1, 1]                   single_layer: 0.006ms  all_layers: 0.185ms  throughput:   11.321GB/s
model: l8b       seqlens: [4993, 1, 1, 1, 1, 1, 1, 1]                single_layer: 0.021ms  all_layers: 0.661ms  throughput: 1982.815GB/s
model: l8b       seqlens: [5000]                                     single_layer: 0.021ms  all_layers: 0.662ms  throughput: 1980.227GB/s
model: l8b       seqlens: [625, 625, 625, 625, 625, 625, 625, 625]   single_layer: 0.021ms  all_layers: 0.667ms  throughput: 1964.861GB/s
---
model: l70b-tp8  seqlens: [1, 1, 1, 1, 1, 1, 1, 1]                   single_layer: 0.006ms  all_layers: 0.457ms  throughput:    1.434GB/s
model: l70b-tp8  seqlens: [4993, 1, 1, 1, 1, 1, 1, 1]                single_layer: 0.009ms  all_layers: 0.710ms  throughput:  576.866GB/s
model: l70b-tp8  seqlens: [5000]                                     single_layer: 0.009ms  all_layers: 0.685ms  throughput:  598.366GB/s
model: l70b-tp8  seqlens: [625, 625, 625, 625, 625, 625, 625, 625]   single_layer: 0.009ms  all_layers: 0.690ms  throughput:  593.453GB/s
```

cc @abcdabcd987
1 parent 1328693 commit e15f7c9
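The core idea, for readers skimming the diff below, is to flatten the ragged per-request append loop into a flat list of `nnz` elements. Here is a minimal pure-PyTorch sketch of what the `append_indptr` to (`batch_indices`, `positions`) conversion computes; the shipped conversion is a Triton kernel, and the helper name `ref_batch_indices_positions` and the loop below are illustrative only, with the position offset mirroring the `append_start` logic of the old CUDA kernel further down:

```python
import torch

def ref_batch_indices_positions(append_indptr: torch.Tensor,
                                seq_lens: torch.Tensor,
                                nnz: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative CSR->COO-style expansion of a ragged append layout."""
    # batch_indices[i]: which request row i of the ragged append tensor belongs to
    # positions[i]: absolute slot of that row inside the request's KV sequence
    batch_indices = torch.empty(nnz, dtype=torch.int32)
    positions = torch.empty(nnz, dtype=torch.int32)
    for b in range(append_indptr.numel() - 1):
        lo, hi = int(append_indptr[b]), int(append_indptr[b + 1])
        append_len = hi - lo
        start = int(seq_lens[b]) - append_len  # entries already in the cache before this append
        batch_indices[lo:hi] = b
        positions[lo:hi] = torch.arange(start, start + append_len, dtype=torch.int32)
    return batch_indices, positions
```

Once every appended token carries its own (batch, position) pair, the CUDA kernel can iterate over elements instead of requests, which is where the speedup for small or skewed batches comes from.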

File tree

9 files changed: +285 −93 lines

benchmarks/bench_append_paged_kv_cache.py

+8-1
```diff
@@ -99,12 +99,19 @@ def main():
             dtype=torch.int32,
         )
 
+        batch_indices, positions = flashinfer.get_batch_indices_positions(
+            x_indptr,
+            flashinfer.get_seq_lens(kv_indptr, kv_last_page_len, page_len),
+            k.shape[0],
+        )
+
         @torch.cuda.nvtx.range(f"model={model_name}, seqlens={seqlens}")
         def fn():
             flashinfer.append_paged_kv_cache(
                 k,
                 v,
-                x_indptr,
+                batch_indices,
+                positions,
                 layer_buf,
                 kv_indices,
                 kv_indptr,
```
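The `flashinfer.get_seq_lens(kv_indptr, kv_last_page_len, page_len)` call above recovers each request's total KV length from its page table. A hedged PyTorch equivalent (not the library implementation, but the same formula the old kernel evaluated per request, visible in the `page.cuh` diff below):

```python
import torch

def ref_seq_lens(kv_indptr: torch.Tensor, kv_last_page_len: torch.Tensor,
                 page_len: int) -> torch.Tensor:
    # every allocated page except the last is full; the last page holds
    # kv_last_page_len entries (empty requests are not handled here)
    num_pages = kv_indptr[1:] - kv_indptr[:-1]
    return (num_pages - 1) * page_len + kv_last_page_len
```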

include/flashinfer/page.cuh

+42-30
```diff
@@ -249,38 +249,34 @@ __global__ void AppendPagedKVCacheDecodeKernel(paged_kv_t<DType, IdType> paged_k
  * \param paged_kv The paged key-value cache
  * \param key The key to be appended
  * \param value The value to be appended
- * \param append_indptr The indptr array of the appended ragged tensor
+ * \param batch_indices The batch indices of elements to be appended
+ * \param positions The positions of elements to be appended
  */
 template <uint32_t head_dim, uint32_t vec_size, typename DType, typename IdType>
-__global__ void AppendPagedKVCachePrefillKernel(paged_kv_t<DType, IdType> paged_kv,
-                                                DType* __restrict__ key, DType* __restrict__ value,
-                                                IdType* __restrict__ append_indptr) {
+__global__ void AppendPagedKVCacheKernel(paged_kv_t<DType, IdType> paged_kv,
+                                         DType* __restrict__ append_key,
+                                         DType* __restrict__ append_value,
+                                         IdType* __restrict__ batch_indices,
+                                         IdType* __restrict__ positions, uint32_t nnz,
+                                         size_t append_k_stride_n, size_t append_k_stride_h,
+                                         size_t append_v_stride_n, size_t append_v_stride_h) {
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
   uint32_t num_heads = paged_kv.num_heads;
-  uint32_t batch_idx = blockIdx.x;
   uint32_t head_idx = ty;
-
-  uint32_t seq_len =
-      (paged_kv.indptr[batch_idx + 1] - paged_kv.indptr[batch_idx] - 1) * paged_kv.page_size +
-      paged_kv.last_page_len[batch_idx];
-  uint32_t append_seq_len = append_indptr[batch_idx + 1] - append_indptr[batch_idx];
-  uint32_t append_start = seq_len - append_seq_len;
-
-#pragma unroll 2
-  for (uint32_t j = 0; j < append_seq_len; ++j) {
-    uint32_t page_seq_idx = j + append_start;
-    uint32_t page_iter = paged_kv.indptr[batch_idx] + page_seq_idx / paged_kv.page_size;
-    uint32_t entry_idx = page_seq_idx % paged_kv.page_size;
-
+  uint32_t cta_id = blockIdx.x;
+  uint32_t num_ctas = gridDim.x;
+
+#pragma unroll 4
+  for (uint32_t i = cta_id; i < nnz; i += num_ctas) {
+    uint32_t page_iter, entry_idx;
+    paged_kv.page_size.divmod(paged_kv.indptr[batch_indices[i]] * paged_kv.page_size + positions[i],
+                              page_iter, entry_idx);
     DType* k_ptr = paged_kv.get_k_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
     DType* v_ptr = paged_kv.get_v_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
     vec_t<DType, vec_size>::memcpy(
-        k_ptr,
-        key + ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + tx * vec_size);
-
+        k_ptr, append_key + i * append_k_stride_n + head_idx * append_k_stride_h + tx * vec_size);
     vec_t<DType, vec_size>::memcpy(
-        v_ptr,
-        value + ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + tx * vec_size);
+        v_ptr, append_value + i * append_v_stride_n + head_idx * append_v_stride_h + tx * vec_size);
   }
 }
 
```
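The indexing change is easiest to read outside CUDA: rather than one block walking an entire request, each element `i` locates its page and in-page slot directly from `batch_indices[i]` and `positions[i]`. A plain-Python rendering of the math above (ordinary integer division in place of the kernel's fast-division helper):

```python
def ref_element_page_slot(kv_indptr, page_size, batch_indices, positions, i):
    # flat offset of element i within its request's page list, then split into
    # (page index in the page table, slot within that page)
    offset = kv_indptr[batch_indices[i]] * page_size + positions[i]
    page_iter, entry_idx = divmod(offset, page_size)
    return page_iter, entry_idx
```

This is what removes the per-request serial loop: for a batch like `[4993, 1, 1, ...]`, the work is spread over all 5000 elements instead of leaving one block to copy thousands of tokens while the others finish immediately.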

```diff
@@ -327,20 +323,36 @@ cudaError_t AppendPagedKVCacheDecode(paged_kv_t<DType, IdType> paged_kv, DType*
  * \return status Indicates whether CUDA calls are successful
  */
 template <typename DType, typename IdType>
-cudaError_t AppendPagedKVCache(paged_kv_t<DType, IdType> paged_kv, DType* key, DType* value,
-                               IdType* append_indptr, cudaStream_t stream = nullptr) {
+cudaError_t AppendPagedKVCache(paged_kv_t<DType, IdType> paged_kv, DType* append_key,
+                               DType* append_value, IdType* batch_indices, IdType* positions,
+                               uint32_t nnz, size_t append_k_stride_n, size_t append_k_stride_h,
+                               size_t append_v_stride_n, size_t append_v_stride_h,
+                               cudaStream_t stream = nullptr) {
   uint32_t head_dim = paged_kv.head_dim;
-  uint32_t batch_size = paged_kv.batch_size;
   uint32_t num_heads = paged_kv.num_heads;
+  int dev_id = 0;
+  int num_sms = 0;
+  int num_blocks_per_sm = 0;
+  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
+  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
+
   DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
     constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
     uint32_t bdx = HEAD_DIM / vec_size;
     uint32_t bdy = num_heads;
-    // NOTE(Zihao): could be slow for small batch size, will optimize later
-    dim3 nblks(batch_size);
+    uint32_t num_threads = bdx * bdy;
+    uint32_t smem_size = 0;
+    auto kernel = AppendPagedKVCacheKernel<HEAD_DIM, vec_size, DType, IdType>;
+    FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
+                                                                       num_threads, smem_size));
+    num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(int(nnz), num_sms));
+    dim3 nblks(num_blocks_per_sm * num_sms);
     dim3 nthrs(bdx, bdy);
-    auto kernel = AppendPagedKVCachePrefillKernel<HEAD_DIM, vec_size, DType, IdType>;
-    void* args[] = {(void*)&paged_kv, (void*)&key, (void*)&value, (void*)&append_indptr};
+
+    void* args[] = {(void*)&paged_kv, (void*)&append_key, (void*)&append_value,
+                    (void*)&batch_indices, (void*)&positions, (void*)&nnz,
+                    (void*)&append_k_stride_n, (void*)&append_k_stride_h, (void*)&append_v_stride_n,
+                    (void*)&append_v_stride_h};
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
   });
   return cudaSuccess;
```
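The launch now sizes the grid from occupancy rather than batch size: `cudaOccupancyMaxActiveBlocksPerMultiprocessor` bounds how many blocks can be resident per SM, and `ceil_div(nnz, num_sms)` keeps tiny appends from launching more blocks than there is work. A small arithmetic sketch of that logic (the SM count and occupancy figures are made-up examples, not taken from the diff):

```python
import math

def grid_size(nnz: int, num_sms: int, max_blocks_per_sm: int) -> int:
    # never exceed occupancy, never launch more blocks per SM than needed
    # to give every block at least one element of the grid-stride loop
    blocks_per_sm = min(max_blocks_per_sm, math.ceil(nnz / num_sms))
    return blocks_per_sm * num_sms

print(grid_size(8, 132, 4))       # 132  -- 8 appended tokens: one wave of blocks
print(grid_size(40_000, 132, 4))  # 528  -- large append: occupancy-limited
```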

python/csrc/flashinfer_page_ops.cu

+4-4
```diff
@@ -16,10 +16,10 @@
 #include <torch/extension.h>
 
 void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
-                           torch::Tensor append_indptr, torch::Tensor paged_k_cache,
-                           torch::Tensor paged_v_cache, torch::Tensor kv_indices,
-                           torch::Tensor kv_indptr, torch::Tensor kv_last_page_len,
-                           unsigned int layout);
+                           torch::Tensor batch_indices, torch::Tensor positions,
+                           torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
+                           torch::Tensor kv_indices, torch::Tensor kv_indptr,
+                           torch::Tensor kv_last_page_len, unsigned int layout);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator");
```

python/csrc/page.cu

+30-16
```diff
@@ -20,13 +20,14 @@
 using namespace flashinfer;
 
 void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
-                           torch::Tensor append_indptr, torch::Tensor paged_k_cache,
-                           torch::Tensor paged_v_cache, torch::Tensor kv_indices,
-                           torch::Tensor kv_indptr, torch::Tensor kv_last_page_len,
-                           unsigned int layout) {
-  CHECK_INPUT(append_key);
-  CHECK_INPUT(append_value);
-  CHECK_INPUT(append_indptr);
+                           torch::Tensor batch_indices, torch::Tensor positions,
+                           torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
+                           torch::Tensor kv_indices, torch::Tensor kv_indptr,
+                           torch::Tensor kv_last_page_len, unsigned int layout) {
+  CHECK_LAST_DIM_CONTIGUOUS(append_key);
+  CHECK_LAST_DIM_CONTIGUOUS(append_value);
+  CHECK_INPUT(batch_indices);
+  CHECK_INPUT(positions);
   // NOTE(Zihao): doesn't have to be contiguous
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(paged_k_cache);
   CHECK_LAST_DIM_CONTIGUOUS_INPUT(paged_v_cache);
```

```diff
@@ -35,20 +36,24 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
   CHECK_INPUT(kv_last_page_len);
   CHECK_DIM(3, append_key);
   CHECK_DIM(3, append_value);
-  CHECK_DIM(1, append_indptr);
+  CHECK_DIM(1, batch_indices);
+  CHECK_DIM(1, positions);
   CHECK_DIM(4, paged_k_cache);
   CHECK_DIM(4, paged_v_cache);
   CHECK_DIM(1, kv_indices);
   CHECK_DIM(1, kv_indptr);
   CHECK_DIM(1, kv_last_page_len);
+  unsigned int nnz = append_key.size(0);
   unsigned int batch_size = kv_last_page_len.size(0);
-  CHECK_EQ(append_indptr.size(0), batch_size + 1);
   CHECK_EQ(kv_indptr.size(0), batch_size + 1);
-  CHECK_EQ(append_indptr.scalar_type(), torch::kInt32);
+  CHECK_EQ(batch_indices.size(0), nnz);
+  CHECK_EQ(positions.size(0), nnz);
+  CHECK_EQ(batch_indices.scalar_type(), torch::kInt32);
+  CHECK_EQ(positions.scalar_type(), torch::kInt32);
   CHECK_EQ(kv_indptr.scalar_type(), torch::kInt32);
   CHECK_EQ(kv_indices.scalar_type(), torch::kInt32);
   CHECK_EQ(kv_last_page_len.scalar_type(), torch::kInt32);
-  auto device = append_indptr.device();
+  auto device = append_key.device();
   CHECK_EQ(append_key.device(), device);
   CHECK_EQ(append_value.device(), device);
   CHECK_EQ(paged_k_cache.device(), device);
```

```diff
@@ -76,10 +81,17 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
   TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
   kv_cache_strides = k_strides.data();
 
+  auto append_k_strides = append_key.strides();
+  auto append_k_stride_n = append_k_strides[0];
+  auto append_k_stride_h = append_k_strides[1];
+  auto append_v_strides = append_value.strides();
+  auto append_v_stride_n = append_v_strides[0];
+  auto append_v_stride_h = append_v_strides[1];
+
   CHECK_EQ(append_key.size(1), num_heads);
   CHECK_EQ(append_key.size(2), head_dim);
   CHECK_EQ(append_value.size(1), num_heads);
-  CHECK_EQ(append_key.size(2), head_dim);
+  CHECK_EQ(append_value.size(2), head_dim);
 
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
 
```

```diff
@@ -92,10 +104,12 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
         static_cast<c_type*>(paged_v_cache.data_ptr()), kv_cache_strides,
         static_cast<int32_t*>(kv_indices.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
         static_cast<int32_t*>(kv_last_page_len.data_ptr()));
-    cudaError_t status =
-        AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()),
-                           static_cast<c_type*>(append_value.data_ptr()),
-                           static_cast<int32_t*>(append_indptr.data_ptr()), torch_current_stream);
+    cudaError_t status = AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()),
+                                            static_cast<c_type*>(append_value.data_ptr()),
+                                            static_cast<int32_t*>(batch_indices.data_ptr()),
+                                            static_cast<int32_t*>(positions.data_ptr()), nnz,
+                                            append_k_stride_n, append_k_stride_h, append_v_stride_n,
+                                            append_v_stride_h, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "AppendPagedKVCache failed with error: ", cudaGetErrorString(status));
     return true;
```
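The stride plumbing above is what lets `append_key`/`append_value` be non-contiguous views, for example slices of a fused qkv projection, without a `.contiguous()` copy; only the last (`head_dim`) dimension must stay contiguous. A hedged illustration with made-up sizes:

```python
import torch

# made-up sizes, for illustration only
nnz, num_qo_heads, num_kv_heads, head_dim = 7, 32, 8, 128
qkv = torch.randn(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim,
                  dtype=torch.float16, device="cuda")

# slice K and V out of the fused projection: the row stride stays that of qkv,
# so the views are non-contiguous, but stride(-1) == 1 as required
k = qkv[:, num_qo_heads * head_dim:(num_qo_heads + num_kv_heads) * head_dim]
v = qkv[:, (num_qo_heads + num_kv_heads) * head_dim:]
k = k.view(nnz, num_kv_heads, head_dim)
v = v.view(nnz, num_kv_heads, head_dim)

assert not k.is_contiguous() and k.stride(-1) == 1
# k.stride(0) and k.stride(1) are what page.cu forwards to the kernel as
# append_k_stride_n and append_k_stride_h
```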

python/csrc_aot/flashinfer_ops.cu

+4-4
```diff
@@ -80,10 +80,10 @@ void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torc
 //========== page ==========
 
 void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
-                           torch::Tensor append_indptr, torch::Tensor paged_k_cache,
-                           torch::Tensor paged_v_cache, torch::Tensor kv_indices,
-                           torch::Tensor kv_indptr, torch::Tensor kv_last_page_len,
-                           unsigned int layout);
+                           torch::Tensor batch_indices, torch::Tensor positions,
+                           torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
+                           torch::Tensor kv_indices, torch::Tensor kv_indptr,
+                           torch::Tensor kv_last_page_len, unsigned int layout);
 
 //========== prefill ==========
 
```

python/flashinfer/__init__.py

+2
```diff
@@ -46,6 +46,8 @@
 from .norm import gemma_rmsnorm as gemma_rmsnorm
 from .norm import rmsnorm as rmsnorm
 from .page import append_paged_kv_cache as append_paged_kv_cache
+from .page import get_batch_indices_positions as get_batch_indices_positions
+from .page import get_seq_lens as get_seq_lens
 from .prefill import (
     BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper,
 )
```
