@@ -249,38 +249,34 @@ __global__ void AppendPagedKVCacheDecodeKernel(paged_kv_t<DType, IdType> paged_k
  * \param paged_kv The paged key-value cache
  * \param key The key to be appended
  * \param value The value to be appended
- * \param append_indptr The indptr array of the appended ragged tensor
+ * \param batch_indices The batch indices of elements to be appended
+ * \param positions The positions of elements to be appended
  */
 template <uint32_t head_dim, uint32_t vec_size, typename DType, typename IdType>
-__global__ void AppendPagedKVCachePrefillKernel(paged_kv_t<DType, IdType> paged_kv,
-                                                DType* __restrict__ key, DType* __restrict__ value,
-                                                IdType* __restrict__ append_indptr) {
+__global__ void AppendPagedKVCacheKernel(paged_kv_t<DType, IdType> paged_kv,
+                                         DType* __restrict__ append_key,
+                                         DType* __restrict__ append_value,
+                                         IdType* __restrict__ batch_indices,
+                                         IdType* __restrict__ positions, uint32_t nnz,
+                                         size_t append_k_stride_n, size_t append_k_stride_h,
+                                         size_t append_v_stride_n, size_t append_v_stride_h) {
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
   uint32_t num_heads = paged_kv.num_heads;
-  uint32_t batch_idx = blockIdx.x;
   uint32_t head_idx = ty;
-
-  uint32_t seq_len =
-      (paged_kv.indptr[batch_idx + 1] - paged_kv.indptr[batch_idx] - 1) * paged_kv.page_size +
-      paged_kv.last_page_len[batch_idx];
-  uint32_t append_seq_len = append_indptr[batch_idx + 1] - append_indptr[batch_idx];
-  uint32_t append_start = seq_len - append_seq_len;
-
-#pragma unroll 2
-  for (uint32_t j = 0; j < append_seq_len; ++j) {
-    uint32_t page_seq_idx = j + append_start;
-    uint32_t page_iter = paged_kv.indptr[batch_idx] + page_seq_idx / paged_kv.page_size;
-    uint32_t entry_idx = page_seq_idx % paged_kv.page_size;
-
+  uint32_t cta_id = blockIdx.x;
+  uint32_t num_ctas = gridDim.x;
+
+#pragma unroll 4
+  for (uint32_t i = cta_id; i < nnz; i += num_ctas) {
+    uint32_t page_iter, entry_idx;
+    paged_kv.page_size.divmod(paged_kv.indptr[batch_indices[i]] * paged_kv.page_size + positions[i],
+                              page_iter, entry_idx);
     DType* k_ptr = paged_kv.get_k_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
     DType* v_ptr = paged_kv.get_v_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
     vec_t<DType, vec_size>::memcpy(
-        k_ptr,
-        key + ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + tx * vec_size);
-
+        k_ptr, append_key + i * append_k_stride_n + head_idx * append_k_stride_h + tx * vec_size);
     vec_t<DType, vec_size>::memcpy(
-        v_ptr,
-        value + ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + tx * vec_size);
+        v_ptr, append_value + i * append_v_stride_n + head_idx * append_v_stride_h + tx * vec_size);
   }
 }
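Note on the interface change: the old prefill kernel recovered each appended element's (request, position) pair from `append_indptr` inside the kernel, while the new kernel consumes flat `batch_indices`/`positions` arrays of length `nnz`. A minimal host-side sketch of how such arrays could be built from an `append_indptr`-style layout is shown below; `BuildBatchIndicesPositions` and `seq_lens` (the post-append length of each request) are illustrative names, not part of this commit.

```cpp
#include <cstdint>
#include <vector>

// Illustrative helper (not part of this commit): expand a ragged append layout
// described by append_indptr into the flat batch_indices/positions arrays that
// AppendPagedKVCacheKernel expects. seq_lens[b] is request b's total length
// after the append, matching the seq_len the old kernel derived from the page table.
void BuildBatchIndicesPositions(const std::vector<int32_t>& append_indptr,
                                const std::vector<int32_t>& seq_lens,
                                std::vector<int32_t>& batch_indices,
                                std::vector<int32_t>& positions) {
  batch_indices.clear();
  positions.clear();
  for (size_t b = 0; b < seq_lens.size(); ++b) {
    int32_t append_len = append_indptr[b + 1] - append_indptr[b];
    int32_t append_start = seq_lens[b] - append_len;  // same offset the old kernel computed
    for (int32_t j = 0; j < append_len; ++j) {
      batch_indices.push_back(static_cast<int32_t>(b));
      positions.push_back(append_start + j);  // in-sequence position of this token
    }
  }
}
```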
@@ -327,20 +323,36 @@ cudaError_t AppendPagedKVCacheDecode(paged_kv_t<DType, IdType> paged_kv, DType*
  * \return status Indicates whether CUDA calls are successful
  */
 template <typename DType, typename IdType>
-cudaError_t AppendPagedKVCache(paged_kv_t<DType, IdType> paged_kv, DType* key, DType* value,
-                               IdType* append_indptr, cudaStream_t stream = nullptr) {
+cudaError_t AppendPagedKVCache(paged_kv_t<DType, IdType> paged_kv, DType* append_key,
+                               DType* append_value, IdType* batch_indices, IdType* positions,
+                               uint32_t nnz, size_t append_k_stride_n, size_t append_k_stride_h,
+                               size_t append_v_stride_n, size_t append_v_stride_h,
+                               cudaStream_t stream = nullptr) {
   uint32_t head_dim = paged_kv.head_dim;
-  uint32_t batch_size = paged_kv.batch_size;
   uint32_t num_heads = paged_kv.num_heads;
+  int dev_id = 0;
+  int num_sms = 0;
+  int num_blocks_per_sm = 0;
+  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
+  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
+
   DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
     constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
     uint32_t bdx = HEAD_DIM / vec_size;
     uint32_t bdy = num_heads;
-    // NOTE(Zihao): could be slow for small batch size, will optimize later
-    dim3 nblks(batch_size);
+    uint32_t num_threads = bdx * bdy;
+    uint32_t smem_size = 0;
+    auto kernel = AppendPagedKVCacheKernel<HEAD_DIM, vec_size, DType, IdType>;
+    FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
+                                                                       num_threads, smem_size));
+    num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(int(nnz), num_sms));
+    dim3 nblks(num_blocks_per_sm * num_sms);
     dim3 nthrs(bdx, bdy);
-    auto kernel = AppendPagedKVCachePrefillKernel<HEAD_DIM, vec_size, DType, IdType>;
-    void* args[] = {(void*)&paged_kv, (void*)&key, (void*)&value, (void*)&append_indptr};
+
+    void* args[] = {(void*)&paged_kv, (void*)&append_key, (void*)&append_value,
+                    (void*)&batch_indices, (void*)&positions, (void*)&nnz,
+                    (void*)&append_k_stride_n, (void*)&append_k_stride_h, (void*)&append_v_stride_n,
+                    (void*)&append_v_stride_h};
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
   });
   return cudaSuccess;
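The new stride arguments let `append_key`/`append_value` be strided views; for a plain contiguous `[nnz, num_heads, head_dim]` tensor they reduce to `num_heads * head_dim` and `head_dim`. A call-site sketch under that assumption (device buffers `d_key`, `d_value`, `d_batch_indices`, `d_positions` and the `paged_kv` object are assumed to be prepared elsewhere and are not part of this diff):

```cpp
// Illustrative call site, assuming d_key/d_value point to contiguous
// [nnz, num_heads, head_dim] device tensors and paged_kv, d_batch_indices,
// d_positions were built beforehand.
size_t stride_n = size_t(num_heads) * head_dim;  // step between consecutive tokens
size_t stride_h = head_dim;                      // step between consecutive heads
cudaError_t status = AppendPagedKVCache(
    paged_kv, d_key, d_value, d_batch_indices, d_positions, nnz,
    /*append_k_stride_n=*/stride_n, /*append_k_stride_h=*/stride_h,
    /*append_v_stride_n=*/stride_n, /*append_v_stride_h=*/stride_h, stream);
if (status != cudaSuccess) {
  // handle occupancy-query or kernel-launch failure
}
```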