Commit 86ca89a

feat: fix the maximal grid dimension in prefill planning with CUDA graphs (#639)
Previously, differences in the contents of qo_indptr could lead to block sizes varying across CUDA graph invocations, which in turn caused illegal memory accesses. This PR alters the block-size calculation to use a reasonable maximum derived from the longest sequence. The maximum token count is fixed in `plan` on the Python side and passed along to `scheduler.cuh` to derive the other parameters. While this ensures correctness under CUDA graphs, split-kv is now always used whenever CUDA graphs are enabled, which may degrade performance if CUDA graphs are used with a fixed `qo_indptr`. For varying `qo_indptr`, however, CUDA graphs deliver a 4x performance improvement for prefill on models such as Llama 3.2-1B.
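For context, a minimal sketch (plain Python with made-up qo_indptr values, not flashinfer API) of the two host-side quantities that plan now derives and forwards to the scheduler: the total number of query rows and the longest per-request query length.

import torch

# Hypothetical ragged batch of 3 requests with 5, 1 and 9 query tokens.
qo_indptr_host = torch.tensor([0, 5, 6, 15], dtype=torch.int32)

# Fixed at plan time; under CUDA graphs they act as upper bounds for later replans.
total_num_rows = int(qo_indptr_host[-1])                              # 15
max_seq_len = int((qo_indptr_host[1:] - qo_indptr_host[:-1]).max())   # 9

Under CUDA graphs, the values captured by the first plan call are treated as maxima; the prefill.py changes below reject later plans whose row totals or sequence lengths exceed them.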
1 parent 5fe9f7d commit 86ca89a

12 files changed: +197 -100 lines

include/flashinfer/attention/scheduler.cuh (+70 -45)

@@ -419,21 +419,41 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in
   return cudaSuccess;
 }
 
+inline uint32_t DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) {
+  if (avg_packed_qo_len > 64 && head_dim < 256) {
+    return 128;
+  } else {
+    auto compute_capacity = GetCudaComputeCapability();
+    if (compute_capacity.first >= 8) {
+      // Ampere or newer
+      if (avg_packed_qo_len > 16) {
+        // avg_packed_qo_len <= 64
+        return 64;
+      } else {
+        // avg_packed_qo_len <= 16
+        return 16;
+      }
+    } else {
+      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
+      return 64;
+    }
+  }
+}
+
 template <typename IdType>
-inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
-                                   uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
-                                   uint32_t page_size, uint32_t max_batch_size_if_split,
-                                   bool enable_cuda_graph) {
+inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
+                                   uint32_t total_num_rows, uint32_t max_seq_len,
+                                   uint32_t batch_size, uint32_t num_qo_heads,
+                                   uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size,
+                                   uint32_t max_batch_size_if_split, bool enable_cuda_graph) {
   std::vector<IdType> request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr;
   merge_indptr.push_back(0);
   o_indptr.push_back(0);
 
   const uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
-  uint32_t total_num_rows = qo_indptr_h[batch_size];
 
-  // step 1: compute qo_chunk_size
+  // step 1: determine packed_qo_len_arr and verify qo_indptr contents.
   std::vector<int64_t> packed_qo_len_arr(batch_size), kv_len_arr(batch_size);
-  int64_t sum_packed_qo_len = 0;
   for (uint32_t i = 0; i < batch_size; ++i) {
     packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size);
     if (packed_qo_len_arr[i] < 0) {
@@ -449,41 +469,43 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
               << kv_indptr_h[i] << " should be non-negative";
       FLASHINFER_ERROR(err_msg.str());
     }
-    sum_packed_qo_len += packed_qo_len_arr[i];
   }
-  int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
+
+  // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q
   uint32_t cta_tile_q;
-  if (avg_packed_qo_len > 64 && head_dim < 256) {
-    cta_tile_q = 128;
+  uint32_t total_num_tiles_q;
+  bool split_kv;
+  int64_t kv_chunk_size, new_batch_size;
+  if (enable_cuda_graph) {
+    // When CUDA graphs are enabled, the lengths of sequences determined by
+    // qo_indptr_h can vary. We assume that the dummy data based on which
+    // the CUDA graph is created fixes the maximum number of tokens.
+    uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size;
+    cta_tile_q = DetermineCtaTileQ(max_qo_len, head_dim);
+
+    // Find an upper bound for the number of tiles, derived from the total
+    // number of rows and the batch size. The sum of qo lengths rounded
+    // up to cta_tile_q will not exceed this number derived from the total
+    // number of rows.
+    total_num_tiles_q = ceil_div(total_num_rows, cta_tile_q) + batch_size;
+
+    split_kv = true;
+    kv_chunk_size = max_batch_size_if_split;
+    new_batch_size = max_batch_size_if_split;
   } else {
-    auto compute_capacity = GetCudaComputeCapability();
-    if (compute_capacity.first >= 8) {
-      // Ampere or newer
-      if (avg_packed_qo_len > 16) {
-        // avg_packed_qo_len <= 64
-        cta_tile_q = 64;
-      } else {
-        // avg_packed_qo_len <= 16
-        cta_tile_q = 16;
-      }
-    } else {
-      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
-      cta_tile_q = 64;
+    total_num_tiles_q = 0;
+    int64_t sum_packed_qo_len = 0;
+    for (uint32_t i = 0; i < batch_size; ++i) {
+      total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q);
+      sum_packed_qo_len += packed_qo_len_arr[i];
     }
-  }
 
-  uint32_t total_num_tiles_q = 0;
-  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
-    total_num_tiles_q += ceil_div(packed_qo_len_arr[request_idx], cta_tile_q);
-  }
+    const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
+    cta_tile_q = DetermineCtaTileQ(avg_packed_qo_len, head_dim);
 
-  // step 2: determine kv_chunk_size
-  auto [split_kv, kv_chunk_size, new_batch_size] = PrefillBinarySearchKVChunkSize(
-      max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q,
-      /*min_kv_chunk_size=*/std::max((128 / page_size), 1U));
-
-  if (enable_cuda_graph) {
-    split_kv = total_num_tiles_q < max_batch_size_if_split;
+    std::tie(split_kv, kv_chunk_size, new_batch_size) = PrefillBinarySearchKVChunkSize(
+        max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q,
+        /*min_kv_chunk_size=*/std::max((128 / page_size), 1U));
   }
 
   // step 3: split qo_indptr and kv_indptr
@@ -511,7 +533,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
   kv_chunk_size *= page_size;
 
   return std::make_tuple(split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size,
-                         total_num_rows, std::move(request_indices), std::move(qo_tile_indices),
+                         std::move(request_indices), std::move(qo_tile_indices),
                          std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr));
 }
 
@@ -597,9 +619,10 @@ template <typename IdType>
 inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes,
                                void* int_buffer, void* page_locked_int_buffer,
                                size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info,
-                               IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
-                               uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
-                               uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o,
+                               IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t total_num_rows,
+                               uint32_t max_seq_len, uint32_t batch_size, uint32_t num_qo_heads,
+                               uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size,
+                               bool enable_cuda_graph, uint32_t sizeof_dtype_o,
                                cudaStream_t stream) {
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
@@ -618,17 +641,18 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
   uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads;
 
   // step 2: determine kv_chunk_size
-  auto [split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, total_num_rows,
-        request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec,
-        o_indptr_vec] =
-      PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, batch_size, num_qo_heads, num_kv_heads,
-                             head_dim, page_size, max_batch_size_if_split, enable_cuda_graph);
+  auto [split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec,
+        qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] =
+      PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, max_seq_len, batch_size,
+                             num_qo_heads, num_kv_heads, head_dim, page_size,
+                             max_batch_size_if_split, enable_cuda_graph);
   plan_info.cta_tile_q = cta_tile_q;
   plan_info.total_num_rows = total_num_rows;
 
   plan_info.enable_cuda_graph = enable_cuda_graph;
   size_t padded_batch_size =
       enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size;
+
   plan_info.padded_batch_size = padded_batch_size;
   plan_info.split_kv = split_kv;
 
@@ -679,6 +703,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
       sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "batch_prefill_merge_indptr");
   plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset(
       sizeof(bool) * padded_batch_size, 16, "batch_prefill_block_valid_mask");
+
   IdType* merge_indptr_h =
       GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.merge_indptr_offset);
   bool* block_valid_mask_h =
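The CUDA-graph branch above sizes the grid from an upper bound on the query-tile count instead of the actual qo_indptr contents. The bound ceil_div(total_num_rows, cta_tile_q) + batch_size works because each request can waste at most one partially filled tile, so per-request rounding adds at most batch_size tiles on top of the total rows divided by the tile size. A small sanity check (plain Python, hypothetical lengths):

from math import ceil

def tiles_exact(qo_lens, cta_tile_q):
    # Per-request tile count, as in the non-graph path.
    return sum(ceil(l / cta_tile_q) for l in qo_lens)

def tiles_upper_bound(qo_lens, cta_tile_q):
    # Bound used in the CUDA-graph path: depends only on the total number
    # of rows and the batch size, not on how rows split across requests.
    return ceil(sum(qo_lens) / cta_tile_q) + len(qo_lens)

for lens in ([1, 127, 64, 3], [0, 0, 256, 1], [128] * 8):
    assert tiles_exact(lens, 128) <= tiles_upper_bound(lens, 128)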

python/csrc/batch_prefill.cu (+5 -4)

@@ -42,8 +42,9 @@ using namespace flashinfer;
 std::vector<int64_t> BatchPrefillWithKVCachePlan(
     unsigned int head_dim, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
-    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
-    unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream) {
+    unsigned int total_num_rows, unsigned int max_seq_len, unsigned int batch_size,
+    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
+    bool enable_cuda_graph, int64_t cuda_stream) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
   size_t int_workspace_size_in_bytes =
@@ -58,8 +59,8 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
         float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
         int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
         int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
-        kv_indptr.data_ptr<IdType>(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size,
-        enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
+        kv_indptr.data_ptr<IdType>(), total_num_rows, max_seq_len, batch_size, num_qo_heads,
+        num_kv_heads, head_dim, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
 
     TORCH_CHECK(status == cudaSuccess,
                 "Failed to plan prefill with error: ", cudaGetErrorString(status));

python/csrc/flashinfer_ops.cu (+3 -2)

@@ -99,8 +99,9 @@ void single_prefill_with_kv_cache(unsigned int mask_mode_code, at::Tensor q, at:
 std::vector<int64_t> BatchPrefillWithKVCachePlan(
     unsigned int head_dim, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
-    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
-    unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream);
+    unsigned total_num_rows, unsigned int max_seq_len, unsigned int batch_size,
+    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
+    bool enable_cuda_graph, int64_t cuda_stream);
 
 void BatchPrefillWithRaggedKVCacheRun(
     unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,

python/flashinfer/decode.py (+2)

@@ -776,7 +776,9 @@ def plan(
                 self._pin_memory_int_workspace_buffer,
                 qo_indptr_host,
                 indptr_host,
+                batch_size,  # total_num_rows
                 batch_size,
+                1,  # max_seq_len
                 num_qo_heads,
                 num_kv_heads,
                 page_size,
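The decode wrapper can pass constants here because, for decode, every request contributes exactly one query row: the total row count equals the batch size and the longest per-request query length is 1 (hence the # total_num_rows and # max_seq_len comments). A small sketch, assuming the arange-style qo_indptr_host the decode path builds:

import torch

batch_size = 4
# One query token per request, so qo_indptr is simply an arange.
qo_indptr_host = torch.arange(batch_size + 1, dtype=torch.int32)  # [0, 1, 2, 3, 4]

assert int(qo_indptr_host[-1]) == batch_size                        # total_num_rows
assert int((qo_indptr_host[1:] - qo_indptr_host[:-1]).max()) == 1   # max_seq_len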

python/flashinfer/jit/batch_prefill_templ.py (+7 -2)

@@ -139,6 +139,8 @@ def paged_prefill_inst_templ(mask_mode: str) -> str:
     at::Tensor page_locked_int_workspace_buffer,
     at::Tensor qo_indptr,
     at::Tensor kv_indptr,
+    unsigned int total_num_rows,
+    unsigned int max_seq_len,
     unsigned int batch_size,
     unsigned int num_qo_heads,
     unsigned int num_kv_heads,
@@ -156,8 +158,9 @@ def paged_prefill_inst_templ(mask_mode: str) -> str:
       float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
       int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
       int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<{{dtype_idx}}>(),
-      kv_indptr.data_ptr<{{dtype_idx}}>(), batch_size, num_qo_heads, num_kv_heads, {{head_dim}},
-      page_size, enable_cuda_graph, sizeof({{dtype_o}}), stream);
+      kv_indptr.data_ptr<{{dtype_idx}}>(), total_num_rows, max_seq_len,
+      batch_size, num_qo_heads, num_kv_heads, {{head_dim}}, page_size,
+      enable_cuda_graph, sizeof({{dtype_o}}), stream);
 
   TORCH_CHECK(status == cudaSuccess,
               "Failed to plan prefill with error: ", cudaGetErrorString(status));
@@ -457,6 +460,8 @@ def paged_prefill_inst_templ(mask_mode: str) -> str:
     at::Tensor page_locked_int_workspace_buffer,
     at::Tensor qo_indptr,
     at::Tensor kv_indptr,
+    unsigned int total_num_rows,
+    unsigned int max_seq_len,
     unsigned int batch_size,
     unsigned int num_qo_heads,
     unsigned int num_kv_heads,

python/flashinfer/prefill.py (+59 -8)

@@ -833,6 +833,8 @@ def __init__(
         self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buf
         self._custom_mask_buf = custom_mask_buf
         self._qk_indptr_buf = qk_indptr_buf
+        self._max_total_num_rows = None
+        self._max_seq_len = None
 
     @property
     def is_cuda_graph_enabled(self) -> bool:
@@ -993,7 +995,33 @@ def plan(
                 bitorder="little",
             )
 
+        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
+        qo_indptr_host = qo_indptr.to("cpu")
+        paged_kv_indptr_host = paged_kv_indptr.to("cpu")
+
+        total_num_rows = qo_indptr_host[-1]
+        max_seq_len = torch.max(qo_indptr_host[1:] - qo_indptr_host[:-1]).item()
+
         if self.is_cuda_graph_enabled:
+            if self._max_total_num_rows is None:
+                self._max_total_num_rows = total_num_rows
+            elif total_num_rows > self._max_total_num_rows:
+                raise ValueError(
+                    "The total number of rows in qo_indptr {} in cuda graph mode cannot "
+                    "exceed the number of rows set during initialization {}.".format(
+                        total_num_rows, self._max_total_num_rows
+                    )
+                )
+            if self._max_seq_len is None:
+                self._max_seq_len = max_seq_len
+            elif max_seq_len > self._max_seq_len:
+                raise ValueError(
+                    "The maximum sequence length in qo_indptr {} in cuda graph mode cannot "
+                    "exceed the sequence length set during initialization {}.".format(
+                        max_seq_len, self._max_seq_len
+                    )
+                )
+
             if batch_size != self._fixed_batch_size:
                 raise ValueError(
                     "The batch size should be fixed during the lifecycle of the wrapper in "
@@ -1049,10 +1077,6 @@ def plan(
                 self.device, non_blocking=non_blocking
             )
 
-        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
-        qo_indptr_host = qo_indptr.to("cpu")
-        paged_kv_indptr_host = paged_kv_indptr.to("cpu")
-
         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
         self._cached_module = get_batch_prefill_module(
@@ -1073,6 +1097,8 @@ def plan(
             self._pin_memory_int_workspace_buffer,
             qo_indptr_host,
             paged_kv_indptr_host,
+            total_num_rows,
+            max_seq_len,
             batch_size,
             num_qo_heads,
             num_kv_heads,
@@ -1463,6 +1489,7 @@ def __init__(
         self._kv_indptr_buf = kv_indptr_buf
         self._custom_mask_buf = custom_mask_buf
         self._qk_indptr_buf = qk_indptr_buf
+        self._max_total_num_rows = None
 
     @property
     def is_cuda_graph_enabled(self) -> bool:
@@ -1610,7 +1637,33 @@ def plan(
                 bitorder="little",
             )
 
+        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
+        qo_indptr_host = qo_indptr.to("cpu")
+        paged_kv_indptr_host = paged_kv_indptr.to("cpu")
+
+        total_num_rows = qo_indptr_host[-1]
+        max_seq_len = torch.max(qo_indptr_host[1:] - qo_indptr_host[:-1]).item()
+
         if self.is_cuda_graph_enabled:
+            if self._max_total_num_rows is None:
+                self._max_total_num_rows = total_num_rows
+            elif total_num_rows > self._max_total_num_rows:
+                raise ValueError(
+                    "The total number of rows in qo_indptr {} in cuda graph mode cannot "
+                    "exceed the number of rows set during initialization {}.".format(
+                        total_num_rows, self._max_total_num_rows
+                    )
+                )
+            if self._max_seq_len is None:
+                self._max_seq_len = max_seq_len
+            elif max_seq_len > self._max_seq_len:
+                raise ValueError(
+                    "The maximum sequence length in qo_indptr {} in cuda graph mode cannot "
+                    "exceed the sequence length set during initialization {}.".format(
+                        max_seq_len, self._max_seq_len
+                    )
+                )
+
             if batch_size != self._fixed_batch_size:
                 raise ValueError(
                     "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
@@ -1638,10 +1691,6 @@ def plan(
                 self._custom_mask_buf = packed_custom_mask.to(self.device)
                 self._qk_indptr_buf = qk_indptr.to(self.device)
 
-        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
-        qo_indptr_host = qo_indptr.to("cpu")
-        kv_indptr_host = kv_indptr.to("cpu")
-
         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
         self._cached_module = get_batch_prefill_module(
@@ -1662,6 +1711,8 @@ def plan(
             self._pin_memory_int_workspace_buffer,
             qo_indptr_host,
            kv_indptr_host,
+            total_num_rows,
+            max_seq_len,
            batch_size,
            num_qo_heads,
            num_kv_heads,
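Both wrappers add the same record-then-enforce guard around these maxima: the first plan call under CUDA graph mode stores them, and later calls may not exceed them. A standalone sketch of that pattern (the class name _GraphPlanGuard is illustrative, not part of flashinfer):

class _GraphPlanGuard:
    def __init__(self):
        self.max_total_num_rows = None
        self.max_seq_len = None

    def check(self, total_num_rows, max_seq_len):
        # The first call records the bounds captured with the CUDA graph.
        if self.max_total_num_rows is None:
            self.max_total_num_rows = total_num_rows
        elif total_num_rows > self.max_total_num_rows:
            raise ValueError("total rows exceed the value set during initialization")
        if self.max_seq_len is None:
            self.max_seq_len = max_seq_len
        elif max_seq_len > self.max_seq_len:
            raise ValueError("max sequence length exceeds the value set during initialization")

guard = _GraphPlanGuard()
guard.check(total_num_rows=1024, max_seq_len=512)  # first plan() fixes the bounds
guard.check(total_num_rows=256, max_seq_len=128)   # shorter batches replay fine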

src/bench_batch_decode.cu (+5 -5)

@@ -144,11 +144,11 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) {
   size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
   thrust::device_vector<char> int_buffer(int_workspace_size_in_bytes);
 
-  handler.Plan<T, int32_t>((void*)thrust::raw_pointer_cast(float_buffer.data()),
-                           float_workspace_size_in_bytes,
-                           (void*)thrust::raw_pointer_cast(int_buffer.data()),
-                           int_workspace_size_in_bytes, qo_indptr_h.data(), kv_indptr_host.data(),
-                           batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
+  handler.Plan<T, int32_t>(
+      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+      qo_indptr_h.data(), kv_indptr_host.data(), /*total_num_rows=*/batch_size, /*max_seq_len=*/1,
+      batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
 
   state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {
     cudaError_t status = BatchPrefillWithPagedKVCacheWrapper<T, TKV, T, int32_t>(