Skip to content

Commit 5fe9f7d

Browse files
authored
feat: pass a dynamic token count to the cascade kernels (#635)
Under CUDA graph capture, the graph is built with a maximal token count; at replay time the actual number of tokens, read from `qo_indptr`, is passed on to the cascade kernels through a pinned buffer so the merge kernels only process the live rows.
1 parent db9c48d commit 5fe9f7d

File tree

6 files changed

+220
-138
lines changed

6 files changed

+220
-138
lines changed

include/flashinfer/attention/prefill.cuh

+12-14
Original file line numberDiff line numberDiff line change
@@ -2121,7 +2121,6 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P
21212121
const uint32_t num_qo_heads = params.num_qo_heads;
21222122
const uint32_t num_kv_heads = params.num_kv_heads;
21232123
const uint_fastdiv group_size_fastdiv(num_qo_heads / num_kv_heads);
2124-
const uint32_t total_num_rows = params.total_num_rows;
21252124
constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q);
21262125
constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q);
21272126
constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q);
@@ -2198,13 +2197,13 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P
21982197
FLASHINFER_CUDA_CALL(
21992198
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
22002199
if constexpr (AttentionVariant::use_softmax) {
2201-
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse,
2202-
total_num_rows, nullptr, num_qo_heads,
2203-
HEAD_DIM, stream));
2200+
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
2201+
tmp_v, tmp_s, params.merge_indptr, o, lse, params.max_total_num_rows,
2202+
params.total_num_rows, num_qo_heads, HEAD_DIM, stream));
22042203
} else {
2205-
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o,
2206-
total_num_rows, nullptr, num_qo_heads,
2207-
HEAD_DIM, stream));
2204+
FLASHINFER_CUDA_CALL(
2205+
VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, params.max_total_num_rows,
2206+
params.total_num_rows, num_qo_heads, HEAD_DIM, stream));
22082207
}
22092208
}
22102209
}
@@ -2223,7 +2222,6 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa
22232222
const uint32_t num_qo_heads = params.num_qo_heads;
22242223
const uint32_t num_kv_heads = params.paged_kv.num_heads;
22252224
const uint_fastdiv group_size_fastdiv(num_qo_heads / num_kv_heads);
2226-
const uint32_t total_num_rows = params.total_num_rows;
22272225
constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q);
22282226
constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q);
22292227
constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q);
@@ -2300,13 +2298,13 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa
23002298
FLASHINFER_CUDA_CALL(
23012299
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
23022300
if constexpr (AttentionVariant::use_softmax) {
2303-
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse,
2304-
total_num_rows, nullptr, num_qo_heads,
2305-
HEAD_DIM, stream));
2301+
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
2302+
tmp_v, tmp_s, params.merge_indptr, o, lse, params.max_total_num_rows,
2303+
params.total_num_rows, num_qo_heads, HEAD_DIM, stream));
23062304
} else {
2307-
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o,
2308-
total_num_rows, nullptr, num_qo_heads,
2309-
HEAD_DIM, stream));
2305+
FLASHINFER_CUDA_CALL(
2306+
VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, params.max_total_num_rows,
2307+
params.total_num_rows, num_qo_heads, HEAD_DIM, stream));
23102308
}
23112309
}
23122310
}

include/flashinfer/attention/prefill_params.cuh

+8-4
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ struct BatchPrefillRaggedParams {
136136
IdType* o_indptr;
137137
IdType* kv_chunk_size_ptr;
138138
bool* block_valid_mask;
139-
uint32_t total_num_rows;
139+
uint32_t max_total_num_rows;
140+
uint32_t* total_num_rows;
140141
uint32_t padded_batch_size;
141142
bool partition_kv;
142143

@@ -178,7 +179,8 @@ struct BatchPrefillRaggedParams {
178179
o_indptr(nullptr),
179180
kv_chunk_size_ptr(nullptr),
180181
block_valid_mask(nullptr),
181-
total_num_rows(0),
182+
max_total_num_rows(0),
183+
total_num_rows(nullptr),
182184
padded_batch_size(0),
183185
partition_kv(false) {}
184186

@@ -227,7 +229,8 @@ struct BatchPrefillPagedParams {
227229
IdType* o_indptr;
228230
bool* block_valid_mask;
229231
IdType* kv_chunk_size_ptr;
230-
uint32_t total_num_rows;
232+
uint32_t max_total_num_rows;
233+
uint32_t* total_num_rows;
231234
uint32_t padded_batch_size;
232235
bool partition_kv;
233236

@@ -261,7 +264,8 @@ struct BatchPrefillPagedParams {
261264
o_indptr(nullptr),
262265
block_valid_mask(nullptr),
263266
kv_chunk_size_ptr(nullptr),
264-
total_num_rows(0),
267+
max_total_num_rows(0),
268+
total_num_rows(nullptr),
265269
padded_batch_size(0),
266270
partition_kv(false) {}
267271

include/flashinfer/attention/scheduler.cuh

+25-13
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
518518
struct PrefillPlanInfo {
519519
int64_t padded_batch_size;
520520
int64_t total_num_rows;
521+
int64_t total_num_rows_offset;
521522
int64_t cta_tile_q;
522523
int64_t request_indices_offset;
523524
int64_t qo_tile_indices_offset;
@@ -534,6 +535,7 @@ struct PrefillPlanInfo {
534535
PrefillPlanInfo()
535536
: padded_batch_size(0),
536537
total_num_rows(0),
538+
total_num_rows_offset(0),
537539
cta_tile_q(0),
538540
request_indices_offset(0),
539541
qo_tile_indices_offset(0),
@@ -551,6 +553,7 @@ struct PrefillPlanInfo {
551553
std::vector<int64_t> ToVector() const {
552554
return {padded_batch_size,
553555
total_num_rows,
556+
total_num_rows_offset,
554557
cta_tile_q,
555558
request_indices_offset,
556559
qo_tile_indices_offset,
@@ -567,25 +570,26 @@ struct PrefillPlanInfo {
567570

568571
// From std::vector<int64_t> to PrefillPlanInfo
569572
void FromVector(const std::vector<int64_t>& vec) {
570-
if (vec.size() != 14) {
573+
if (vec.size() != 15) {
571574
std::ostringstream err_msg;
572575
err_msg << "PrefillPlanInfo::FromVector: vec.size() should be 15, but got " << vec.size();
573576
FLASHINFER_ERROR(err_msg.str());
574577
}
575578
padded_batch_size = vec[0];
576579
total_num_rows = vec[1];
577-
cta_tile_q = vec[2];
578-
request_indices_offset = vec[3];
579-
qo_tile_indices_offset = vec[4];
580-
kv_tile_indices_offset = vec[5];
581-
merge_indptr_offset = vec[6];
582-
o_indptr_offset = vec[7];
583-
kv_chunk_size_ptr_offset = vec[8];
584-
v_offset = vec[9];
585-
s_offset = vec[10];
586-
block_valid_mask_offset = vec[11];
587-
enable_cuda_graph = vec[12];
588-
split_kv = vec[13];
580+
total_num_rows_offset = vec[2];
581+
cta_tile_q = vec[3];
582+
request_indices_offset = vec[4];
583+
qo_tile_indices_offset = vec[5];
584+
kv_tile_indices_offset = vec[6];
585+
merge_indptr_offset = vec[7];
586+
o_indptr_offset = vec[8];
587+
kv_chunk_size_ptr_offset = vec[9];
588+
v_offset = vec[10];
589+
s_offset = vec[11];
590+
block_valid_mask_offset = vec[12];
591+
enable_cuda_graph = vec[13];
592+
split_kv = vec[14];
589593
}
590594
};
591595

@@ -640,6 +644,14 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
640644
plan_info.kv_chunk_size_ptr_offset =
641645
int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr");
642646

647+
if (plan_info.enable_cuda_graph) {
648+
plan_info.total_num_rows_offset =
649+
int_allocator.aligned_alloc_offset(sizeof(uint32_t), 16, "batch_prefill_total_num_rows");
650+
uint32_t* total_num_rows_h =
651+
GetPtrFromBaseOffset<uint32_t>(page_locked_int_buffer, plan_info.total_num_rows_offset);
652+
*total_num_rows_h = qo_indptr_h[batch_size];
653+
}
654+
643655
IdType* request_indices_h =
644656
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.request_indices_offset);
645657
IdType* qo_tile_indices_h =

python/csrc/batch_prefill.cu

+10-2
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,12 @@ void BatchPrefillWithRaggedKVCacheRun(
160160
GetPtrFromBaseOffset<bool>(int_buffer_ptr, plan_info.block_valid_mask_offset);
161161
}
162162
}
163-
params.total_num_rows = plan_info.total_num_rows;
164163
params.padded_batch_size = plan_info.padded_batch_size;
164+
params.max_total_num_rows = plan_info.total_num_rows;
165+
if (plan_info.enable_cuda_graph) {
166+
params.total_num_rows =
167+
GetPtrFromBaseOffset<uint32_t>(int_buffer_ptr, plan_info.total_num_rows_offset);
168+
}
165169

166170
cudaError_t status = cudaSuccess;
167171

@@ -290,8 +294,12 @@ void BatchPrefillWithPagedKVCacheRun(
290294
GetPtrFromBaseOffset<bool>(int_buffer_ptr, plan_info.block_valid_mask_offset);
291295
}
292296
}
293-
params.total_num_rows = plan_info.total_num_rows;
294297
params.padded_batch_size = plan_info.padded_batch_size;
298+
params.max_total_num_rows = plan_info.total_num_rows;
299+
if (plan_info.enable_cuda_graph) {
300+
params.total_num_rows =
301+
GetPtrFromBaseOffset<uint32_t>(int_buffer_ptr, plan_info.total_num_rows_offset);
302+
}
295303

296304
cudaError_t status = cudaSuccess;
297305

0 commit comments

Comments
 (0)