@@ -419,21 +419,41 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in
   return cudaSuccess;
 }

+inline uint32_t DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) {
+  if (avg_packed_qo_len > 64 && head_dim < 256) {
+    return 128;
+  } else {
+    auto compute_capacity = GetCudaComputeCapability();
+    if (compute_capacity.first >= 8) {
+      // Ampere or newer
+      if (avg_packed_qo_len > 16) {
+        // avg_packed_qo_len <= 64
+        return 64;
+      } else {
+        // avg_packed_qo_len <= 16
+        return 16;
+      }
+    } else {
+      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
+      return 64;
+    }
+  }
+}
+
 template <typename IdType>
-inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
-                                   uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
-                                   uint32_t page_size, uint32_t max_batch_size_if_split,
-                                   bool enable_cuda_graph) {
+inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
+                                   uint32_t total_num_rows, uint32_t max_seq_len,
+                                   uint32_t batch_size, uint32_t num_qo_heads,
+                                   uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size,
+                                   uint32_t max_batch_size_if_split, bool enable_cuda_graph) {
   std::vector<IdType> request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr;
   merge_indptr.push_back(0);
   o_indptr.push_back(0);

   const uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
-  uint32_t total_num_rows = qo_indptr_h[batch_size];

-  // step 1: compute qo_chunk_size
+  // step 1: determine packed_qo_len_arr and verify qo_indptr contents.
   std::vector<int64_t> packed_qo_len_arr(batch_size), kv_len_arr(batch_size);
-  int64_t sum_packed_qo_len = 0;
   for (uint32_t i = 0; i < batch_size; ++i) {
     packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size);
     if (packed_qo_len_arr[i] < 0) {
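Note on the DetermineCtaTileQ helper added above: it picks the CTA tile size along the query dimension from the packed query length (query length times the GQA group size) and the head dimension. The standalone sketch below mirrors that decision tree so it can be checked on the host; the sm_major argument stands in for GetCudaComputeCapability().first and the function name is illustrative, not part of this patch.

#include <cassert>
#include <cstdint>

// Mirror of the tile-selection heuristic; sm_major replaces the
// GetCudaComputeCapability() call so the sketch compiles on its own.
inline uint32_t CtaTileQSketch(int64_t packed_qo_len, uint32_t head_dim, int sm_major) {
  if (packed_qo_len > 64 && head_dim < 256) return 128;
  if (sm_major < 8) return 64;  // Turing: not enough shared memory for the 1x4 warp layout
  return packed_qo_len > 16 ? 64 : 16;
}

int main() {
  assert(CtaTileQSketch(/*32 tokens, group size 4 =*/128, /*head_dim=*/128, /*sm=*/8) == 128);
  assert(CtaTileQSketch(/*2 tokens, group size 4 =*/8, 128, 8) == 16);  // short, decode-like prefill
  assert(CtaTileQSketch(256, /*head_dim=*/256, 8) == 64);               // large head_dim skips the 128 tile
  return 0;
}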
@@ -449,41 +469,43 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
               << kv_indptr_h[i] << " should be non-negative";
       FLASHINFER_ERROR(err_msg.str());
     }
-    sum_packed_qo_len += packed_qo_len_arr[i];
   }
-  int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
+
+  // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q
   uint32_t cta_tile_q;
-  if (avg_packed_qo_len > 64 && head_dim < 256) {
-    cta_tile_q = 128;
+  uint32_t total_num_tiles_q;
+  bool split_kv;
+  int64_t kv_chunk_size, new_batch_size;
+  if (enable_cuda_graph) {
+    // When CUDA graphs are enabled, the lengths of sequences determined by
+    // qo_indptr_h can vary. We assume that the dummy data based on which
+    // the CUDA graph is created fixes the maximum number of tokens.
+    uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size;
+    cta_tile_q = DetermineCtaTileQ(max_qo_len, head_dim);
+
+    // Find an upper bound for the number of tiles, derived from the total
+    // number of rows and the batch size. The sum of qo lengths rounded
+    // up to cta_tile_q will not exceed this number derived from the total
+    // number of rows.
+    total_num_tiles_q = ceil_div(total_num_rows, cta_tile_q) + batch_size;
+
+    split_kv = true;
+    kv_chunk_size = max_batch_size_if_split;
+    new_batch_size = max_batch_size_if_split;
   } else {
-    auto compute_capacity = GetCudaComputeCapability();
-    if (compute_capacity.first >= 8) {
-      // Ampere or newer
-      if (avg_packed_qo_len > 16) {
-        // avg_packed_qo_len <= 64
-        cta_tile_q = 64;
-      } else {
-        // avg_packed_qo_len <= 16
-        cta_tile_q = 16;
-      }
-    } else {
-      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
-      cta_tile_q = 64;
+    total_num_tiles_q = 0;
+    int64_t sum_packed_qo_len = 0;
+    for (uint32_t i = 0; i < batch_size; ++i) {
+      total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q);
+      sum_packed_qo_len += packed_qo_len_arr[i];
     }
-  }

-  uint32_t total_num_tiles_q = 0;
-  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
-    total_num_tiles_q += ceil_div(packed_qo_len_arr[request_idx], cta_tile_q);
-  }
+    const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
+    cta_tile_q = DetermineCtaTileQ(avg_packed_qo_len, head_dim);

-  // step 2: determine kv_chunk_size
-  auto [split_kv, kv_chunk_size, new_batch_size] = PrefillBinarySearchKVChunkSize(
-      max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q,
-      /*min_kv_chunk_size=*/std::max((128 / page_size), 1U));
-
-  if (enable_cuda_graph) {
-    split_kv = total_num_tiles_q < max_batch_size_if_split;
+    std::tie(split_kv, kv_chunk_size, new_batch_size) = PrefillBinarySearchKVChunkSize(
+        max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q,
+        /*min_kv_chunk_size=*/std::max((128 / page_size), 1U));
   }

   // step 3: split qo_indptr and kv_indptr
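A quick check of the rounding bound used in the CUDA-graph branch above: for each request, ceil_div(len, cta_tile_q) is at most len / cta_tile_q + 1, so the per-request sum never exceeds ceil_div(total, cta_tile_q) + batch_size, which is the shape of the total_num_tiles_q upper bound taken when the real lengths are unknown at graph-capture time. The snippet below verifies that inequality on an arbitrary example batch; the lengths are illustrative only, not from the patch.

#include <cassert>
#include <cstdint>
#include <vector>

// Same rounding helper the scheduler uses.
static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  const uint32_t cta_tile_q = 64;
  const std::vector<uint32_t> qo_len = {1, 63, 64, 65, 200};  // example per-request lengths
  uint32_t total = 0, exact_tiles = 0;
  for (uint32_t len : qo_len) {
    total += len;
    exact_tiles += ceil_div(len, cta_tile_q);  // what the non-graph path sums per request
  }
  // Bound of the same form used at graph-capture time, when only totals are known.
  const uint32_t bound = ceil_div(total, cta_tile_q) + static_cast<uint32_t>(qo_len.size());
  assert(exact_tiles <= bound);  // 1 + 1 + 1 + 2 + 4 = 9  <=  7 + 5 = 12
  return 0;
}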
@@ -511,7 +533,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
   kv_chunk_size *= page_size;

   return std::make_tuple(split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size,
-                         total_num_rows, std::move(request_indices), std::move(qo_tile_indices),
+                         std::move(request_indices), std::move(qo_tile_indices),
                          std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr));
 }

@@ -597,9 +619,10 @@ template <typename IdType>
 inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes,
                                void* int_buffer, void* page_locked_int_buffer,
                                size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info,
-                               IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
-                               uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
-                               uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o,
+                               IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t total_num_rows,
+                               uint32_t max_seq_len, uint32_t batch_size, uint32_t num_qo_heads,
+                               uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size,
+                               bool enable_cuda_graph, uint32_t sizeof_dtype_o,
                                cudaStream_t stream) {
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
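With the new signature, the planner no longer derives the row count inside PrefillSplitQOKVIndptr (the removed `uint32_t total_num_rows = qo_indptr_h[batch_size];` line); the caller passes total_num_rows and max_seq_len explicitly. A possible host-side derivation is sketched below, assuming max_seq_len means the largest per-request query length (consistent with the CUDA-graph branch multiplying it by the GQA group size); the helper name is hypothetical, not part of the patch.

#include <algorithm>
#include <cstdint>
#include <utility>

// Hypothetical caller-side helper: derive (total_num_rows, max_seq_len) from a
// host-resident qo_indptr array of length batch_size + 1.
template <typename IdType>
std::pair<uint32_t, uint32_t> DeriveQoStats(const IdType* qo_indptr_h, uint32_t batch_size) {
  // Matches what the removed line read directly: the last indptr entry.
  uint32_t total_num_rows = static_cast<uint32_t>(qo_indptr_h[batch_size]);
  uint32_t max_seq_len = 0;
  for (uint32_t i = 0; i < batch_size; ++i) {
    uint32_t len = static_cast<uint32_t>(qo_indptr_h[i + 1] - qo_indptr_h[i]);
    max_seq_len = std::max(max_seq_len, len);
  }
  return {total_num_rows, max_seq_len};
}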
@@ -618,17 +641,18 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
   uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads;

   // step 2: determine kv_chunk_size
-  auto [split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, total_num_rows,
-        request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec,
-        o_indptr_vec] =
-      PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, batch_size, num_qo_heads, num_kv_heads,
-                             head_dim, page_size, max_batch_size_if_split, enable_cuda_graph);
+  auto [split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec,
+        qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] =
+      PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, max_seq_len, batch_size,
+                             num_qo_heads, num_kv_heads, head_dim, page_size,
+                             max_batch_size_if_split, enable_cuda_graph);
   plan_info.cta_tile_q = cta_tile_q;
   plan_info.total_num_rows = total_num_rows;

   plan_info.enable_cuda_graph = enable_cuda_graph;
   size_t padded_batch_size =
       enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size;
+
   plan_info.padded_batch_size = padded_batch_size;
   plan_info.split_kv = split_kv;

@@ -679,6 +703,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
         sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "batch_prefill_merge_indptr");
     plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset(
         sizeof(bool) * padded_batch_size, 16, "batch_prefill_block_valid_mask");
+
     IdType* merge_indptr_h =
         GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.merge_indptr_offset);
     bool* block_valid_mask_h =