
Commit 92ac440

feat: allow the cascade kernels to be executed using varying sequence lengths (#627)
The cascade kernels can take a dynamic sequence length in order to allow the number of tokens to vary when executed under CUDA graphs. This is the first step towards implementing CUDA graph support for arbitrary `qo_indptr` contents, as tracked by #626.
1 parent f5842b8 commit 92ac440
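
Because a CUDA graph freezes kernel launch parameters at capture time, the token count cannot be varied through ordinary kernel arguments; instead, the kernels read their effective sequence length from device memory, which the host overwrites before each replay. The toy program below is a minimal sketch of that pattern only; it is not code from this commit, and the kernel, buffer names, and launch shape are illustrative:

#include <cstdint>
#include <cuda_runtime.h>

// Toy kernel illustrating the pattern: the effective length is read from device memory
// instead of being baked into the kernel arguments, so a graph captured once can process
// a different number of elements on every replay.
__global__ void ScaleKernel(float* x, const uint32_t* len_ptr, uint32_t max_len) {
  const uint32_t len = len_ptr ? *len_ptr : max_len;
  for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < len;
       i += gridDim.x * blockDim.x) {
    x[i] *= 2.f;
  }
}

int main() {
  constexpr uint32_t max_len = 1024;
  float* x;
  uint32_t* d_len;
  cudaMalloc(&x, max_len * sizeof(float));
  cudaMalloc(&d_len, sizeof(uint32_t));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Capture a single launch sized for max_len; the grid shape is frozen here.
  cudaGraph_t graph;
  cudaGraphExec_t exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  ScaleKernel<<<(max_len + 255) / 256, 256, 0, stream>>>(x, d_len, max_len);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);

  // Replay with different effective lengths by rewriting the device-side counter.
  for (uint32_t len : {128u, 777u, 1024u}) {
    cudaMemcpyAsync(d_len, &len, sizeof(uint32_t), cudaMemcpyHostToDevice, stream);
    cudaGraphLaunch(exec, stream);
  }
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(x);
  cudaFree(d_len);
  return 0;
}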

File tree

4 files changed, +128 −31 lines changed


include/flashinfer/attention/cascade.cuh

+25-15
@@ -325,6 +325,10 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
 /*!
  * \brief The CUDA kernel to merge self-attention states of multiple index sets, the number of index
  * sets at each position might vary.
+ *
+ * For CUDA graph support, the kernel can be built with a maximum sequence length and executed
+ * using a truncated, dynamic sequence length passed through `seq_len_ptr`.
+ *
  * \tparam vec_size The vector size used in the kernel.
  * \tparam bdx The blockDim.x used in the kernel.
  * \tparam bdy The blockDim.y used in the kernel.
@@ -336,20 +340,22 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
  * \param indptr The start offsets of each position in the variable length array.
  * \param v_merged The merged v of index sets union. (n, h, d)
  * \param s_merged The merged logsumexp value of index sets union. (n, h)
+ * \param max_seq_len The maximum sequence length supported by the kernel.
+ * \param seq_len_ptr The current sequence length (number of positions populated in indptr).
  * \param num_heads The number of heads of v.
  * \param head_dim The dimension of each head.
  * \note s are logsumexp values with base 2.
  */
 template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stages, typename DTypeIn,
           typename DTypeO, typename IdType>
-__global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__ V,
-                                                          float* __restrict__ S, IdType* indptr,
-                                                          DTypeO* __restrict__ v_merged,
-                                                          float* __restrict__ s_merged,
-                                                          uint32_t seq_len, uint32_t num_heads) {
+__global__ void PersistentVariableLengthMergeStatesKernel(
+    DTypeIn* __restrict__ V, float* __restrict__ S, IdType* indptr, DTypeO* __restrict__ v_merged,
+    float* __restrict__ s_merged, uint32_t max_seq_len, uint32_t* __restrict__ seq_len_ptr,
+    uint32_t num_heads) {
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
   uint32_t cta_id = blockIdx.x;
   uint32_t num_ctas = gridDim.x;
+  const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len;
   uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
   constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
   constexpr uint32_t head_dim = vec_size * bdx;
@@ -437,10 +443,13 @@ template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stage
           typename DTypeO, typename IdType>
 __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ V, IdType* indptr,
                                                             DTypeO* __restrict__ v_sum,
-                                                            uint32_t seq_len, uint32_t num_heads) {
+                                                            uint32_t max_seq_len,
+                                                            uint32_t* __restrict__ seq_len_ptr,
+                                                            uint32_t num_heads) {
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
   uint32_t cta_id = blockIdx.x;
   uint32_t num_ctas = gridDim.x;
+  const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len;
   uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
   constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
   constexpr uint32_t head_dim = vec_size * bdx;
@@ -641,8 +650,9 @@ cudaError_t AttentionSum(DTypeIn* v, DTypeO* v_sum, uint32_t num_index_sets, uin

 template <typename DTypeIn, typename DTypeO, typename IdType>
 cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeO* v_merged,
-                                      float* s_merged, uint32_t seq_len, uint32_t num_heads,
-                                      uint32_t head_dim, cudaStream_t stream = nullptr) {
+                                      float* s_merged, uint32_t max_seq_len, uint32_t* seq_len,
+                                      uint32_t num_heads, uint32_t head_dim,
+                                      cudaStream_t stream = nullptr) {
   int dev_id = 0;
   int num_sms = 0;
   int num_blocks_per_sm = 0;
@@ -661,11 +671,11 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp
                                                      DTypeIn, DTypeO, IdType>;
   FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
                                                                      num_threads, smem_size));
-  num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms));
+  num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms));

   dim3 nblks(num_sms * num_blocks_per_sm);
   dim3 nthrs(bdx, bdy);
-  void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &seq_len, &num_heads};
+  void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &max_seq_len, &seq_len, &num_heads};
   FLASHINFER_CUDA_CALL(
       cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
   FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
@@ -674,9 +684,9 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp
 }

 template <typename DTypeIn, typename DTypeO, typename IdType>
-cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum, uint32_t seq_len,
-                                       uint32_t num_heads, uint32_t head_dim,
-                                       cudaStream_t stream = nullptr) {
+cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum,
+                                       uint32_t max_seq_len, uint32_t* seq_len, uint32_t num_heads,
+                                       uint32_t head_dim, cudaStream_t stream = nullptr) {
   int dev_id = 0;
   int num_sms = 0;
   int num_blocks_per_sm = 0;
@@ -694,11 +704,11 @@ cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum
                                                      DTypeIn, DTypeO, IdType>;
   FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
                                                                      num_threads, smem_size));
-  num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms));
+  num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms));

   dim3 nblks(num_sms * num_blocks_per_sm);
   dim3 nthrs(bdx, bdy);
-  void* args[] = {&v, &indptr, &v_sum, &seq_len, &num_heads};
+  void* args[] = {&v, &indptr, &v_sum, &max_seq_len, &seq_len, &num_heads};
   FLASHINFER_CUDA_CALL(
       cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
   FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
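
For context, here is a sketch of how a caller might drive the updated VariableLengthMergeStates signature with a device-resident sequence length. The wrapper name, the synchronous host-to-device copy, and the `flashinfer::` qualification are assumptions for illustration, not code from this commit:

#include <cstdint>
#include <cuda_runtime.h>

#include <flashinfer/attention/cascade.cuh>

// Illustrative wrapper (not part of this commit): launch VariableLengthMergeStates with a
// device-resident sequence length so the launch can be captured in a CUDA graph once and
// replayed with varying lengths. All device buffers are assumed to be allocated by the
// caller and sized for max_seq_len.
template <typename DType, typename IdType>
cudaError_t MergeWithDynamicSeqLen(DType* v, float* s, IdType* indptr, DType* v_merged,
                                   float* s_merged, uint32_t max_seq_len, uint32_t cur_seq_len,
                                   uint32_t* seq_len_device, uint32_t num_heads, uint32_t head_dim,
                                   cudaStream_t stream) {
  // Publish the current length to device memory; a real integration would stage this copy
  // asynchronously (e.g. from pinned memory) so that it can sit inside the captured graph.
  cudaMemcpy(seq_len_device, &cur_seq_len, sizeof(uint32_t), cudaMemcpyHostToDevice);

  // Graph-friendly path: the grid is sized from max_seq_len, while the kernel truncates its
  // work to the value behind seq_len_device at run time.
  return flashinfer::VariableLengthMergeStates(v, s, indptr, v_merged, s_merged, max_seq_len,
                                               seq_len_device, num_heads, head_dim, stream);

  // Eager path (previous behavior): pass nullptr and the kernel uses max_seq_len directly.
  //   flashinfer::VariableLengthMergeStates(v, s, indptr, v_merged, s_merged, cur_seq_len,
  //                                         nullptr, num_heads, head_dim, stream);
}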

include/flashinfer/attention/decode.cuh

+6-6
@@ -764,12 +764,12 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::Par
           cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
       if constexpr (AttentionVariant::use_softmax) {
         FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.o_indptr, o, lse,
-                                                       params.paged_kv.batch_size, num_qo_heads,
-                                                       HEAD_DIM, stream));
+                                                       params.paged_kv.batch_size, nullptr,
+                                                       num_qo_heads, HEAD_DIM, stream));
       } else {
         FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, o, params.o_indptr,
-                                                        params.paged_kv.batch_size, num_qo_heads,
-                                                        HEAD_DIM, stream));
+                                                        params.paged_kv.batch_size, nullptr,
+                                                        num_qo_heads, HEAD_DIM, stream));
       }
     }
   });
@@ -1087,8 +1087,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatchedMLA(typename AttentionVariant::
       dim3 nthrs(bdx, bdy, bdz);
       FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
       FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.o_indptr, o, lse,
-                                                     params.paged_kv.batch_size, num_qo_heads,
-                                                     HEAD_DIM_CKV, stream));
+                                                     params.paged_kv.batch_size, nullptr,
+                                                     num_qo_heads, HEAD_DIM_CKV, stream));
     }
   });
   return cudaSuccess;

include/flashinfer/attention/prefill.cuh

+10-8
@@ -2199,11 +2199,12 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P
         cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
     if constexpr (AttentionVariant::use_softmax) {
       FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse,
-                                                     total_num_rows, num_qo_heads, HEAD_DIM,
-                                                     stream));
+                                                     total_num_rows, nullptr, num_qo_heads,
+                                                     HEAD_DIM, stream));
     } else {
-      FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
-          tmp_v, params.merge_indptr, o, total_num_rows, num_qo_heads, HEAD_DIM, stream));
+      FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o,
+                                                      total_num_rows, nullptr, num_qo_heads,
+                                                      HEAD_DIM, stream));
     }
   }
 }
@@ -2300,11 +2301,12 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa
         cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
     if constexpr (AttentionVariant::use_softmax) {
       FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse,
-                                                     total_num_rows, num_qo_heads, HEAD_DIM,
-                                                     stream));
+                                                     total_num_rows, nullptr, num_qo_heads,
+                                                     HEAD_DIM, stream));
     } else {
-      FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
-          tmp_v, params.merge_indptr, o, total_num_rows, num_qo_heads, HEAD_DIM, stream));
+      FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o,
+                                                      total_num_rows, nullptr, num_qo_heads,
+                                                      HEAD_DIM, stream));
     }
   }
 }

src/test_cascade.cu

+87-2
@@ -100,8 +100,8 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
                             thrust::raw_pointer_cast(S_ragged_device.data()),
                             thrust::raw_pointer_cast(indptr_device.data()),
                             thrust::raw_pointer_cast(V_merged_1_device.data()),
-                            thrust::raw_pointer_cast(S_merged_1_device.data()), seq_len, num_heads,
-                            head_dim);
+                            thrust::raw_pointer_cast(S_merged_1_device.data()), seq_len, nullptr,
+                            num_heads, head_dim);

   thrust::host_vector<T> V_merged_0_host(V_merged_0_device), V_merged_1_host(V_merged_1_device);
   thrust::host_vector<float> S_merged_0_host(S_merged_0_device), S_merged_1_host(S_merged_1_device);
@@ -133,6 +133,81 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
   EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed.";
 }

+template <typename T>
+void _TestVariableLengthMergeKernelPaddedCorrectness(size_t max_seq_len, size_t seq_len) {
+  ASSERT_LE(seq_len, max_seq_len);
+
+  const size_t num_heads = 4;
+  const size_t head_dim = 64;
+  const uint32_t max_num_index_sets = 512;
+
+  std::vector<int32_t> lengths(max_seq_len);
+  utils::vec_randint_(lengths, 1, max_num_index_sets);
+  std::vector<int32_t> indptr(max_seq_len + 1, 0);
+  for (size_t i = 0; i < seq_len; ++i) {
+    indptr[i + 1] = indptr[i] + lengths[i];
+  }
+
+  uint32_t last_indptr = indptr[seq_len];
+  std::vector<T> V_ragged_host(last_indptr * num_heads * head_dim);
+  std::vector<float> S_ragged_host(last_indptr * num_heads);
+
+  utils::vec_normal_(V_ragged_host);
+  utils::vec_uniform_(S_ragged_host, -10, 10);
+
+  thrust::device_vector<T> V_ragged_device(V_ragged_host);
+  thrust::device_vector<float> S_ragged_device(S_ragged_host);
+  thrust::device_vector<int32_t> indptr_device(indptr);
+  thrust::device_vector<T> V_merged_0_device(max_seq_len * num_heads * head_dim);
+  thrust::device_vector<T> V_merged_1_device(max_seq_len * num_heads * head_dim);
+  thrust::device_vector<float> S_merged_0_device(max_seq_len * num_heads);
+  thrust::device_vector<float> S_merged_1_device(max_seq_len * num_heads);
+  thrust::device_vector<uint32_t> seq_len_device(
+      std::vector<uint32_t>{static_cast<uint32_t>(seq_len)});
+
+  // Reference: use VariableLengthMergeStates on the precisely-sized input.
+  VariableLengthMergeStates(thrust::raw_pointer_cast(V_ragged_device.data()),
+                            thrust::raw_pointer_cast(S_ragged_device.data()),
+                            thrust::raw_pointer_cast(indptr_device.data()),
+                            thrust::raw_pointer_cast(V_merged_0_device.data()),
+                            thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, nullptr,
+                            num_heads, head_dim);
+  // Expected: use VariableLengthMergeStates on a padded input
+  VariableLengthMergeStates(thrust::raw_pointer_cast(V_ragged_device.data()),
+                            thrust::raw_pointer_cast(S_ragged_device.data()),
+                            thrust::raw_pointer_cast(indptr_device.data()),
+                            thrust::raw_pointer_cast(V_merged_1_device.data()),
+                            thrust::raw_pointer_cast(S_merged_1_device.data()), max_seq_len,
+                            thrust::raw_pointer_cast(seq_len_device.data()), num_heads, head_dim);
+
+  thrust::host_vector<T> V_merged_0_host(V_merged_0_device), V_merged_1_host(V_merged_1_device);
+  thrust::host_vector<float> S_merged_0_host(S_merged_0_device), S_merged_1_host(S_merged_1_device);
+
+  // Compare results
+  size_t num_V_result_errors_atol_1e_3_rtol_1e_3 = 0, num_S_result_errors_atol_1e_3_rtol_1e_3 = 0;
+  for (size_t i = 0; i < seq_len * num_heads * head_dim; ++i) {
+    EXPECT_FALSE(std::isnan(float(V_merged_1_host[i]))) << "V_merged_1_host[" << i << "] is nan";
+    num_V_result_errors_atol_1e_3_rtol_1e_3 +=
+        (!utils::isclose(float(V_merged_0_host[i]), float(V_merged_1_host[i]), 1e-3, 1e-3));
+  }
+  for (size_t i = 0; i < seq_len * num_heads; ++i) {
+    EXPECT_FALSE(std::isnan(float(S_merged_0_host[i]))) << "S_merged_0_host[" << i << "] is nan";
+    EXPECT_FALSE(std::isnan(float(S_merged_1_host[i]))) << "S_merged_1_host[" << i << "] is nan";
+    num_S_result_errors_atol_1e_3_rtol_1e_3 +=
+        (!utils::isclose(float(S_merged_0_host[i]), float(S_merged_1_host[i]), 1e-3, 1e-3));
+  }
+  float V_result_accuracy =
+      1.0 - float(num_V_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads * head_dim);
+  float S_result_accuracy =
+      1.0 - float(num_S_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads);
+  std::cout << "seq_len=" << seq_len << ", num_heads=" << num_heads << ", head_dim=" << head_dim
+            << ", V accuracy (atol=1e-3, rtol=1e-3)=" << V_result_accuracy
+            << ", S accuracy (atol=1e-3, rtol=1e-3)=" << S_result_accuracy << std::endl;
+
+  EXPECT_GT(V_result_accuracy, 0.99) << "V result correctness test failed.";
+  EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed.";
+}
+
 template <typename T>
 void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t num_heads,
                                  size_t head_dim, bool sparse_s) {
@@ -515,6 +590,12 @@ void TestVariableLengthMergeKernelCorrectness() {
   }
 }

+template <typename T>
+void TestVariableLengthMergeKernelPaddedCorrectness() {
+  _TestVariableLengthMergeKernelPaddedCorrectness<T>(8, 1);
+  _TestVariableLengthMergeKernelPaddedCorrectness<T>(128, 77);
+}
+
 template <typename T>
 void TestTwoLevelSinglePrefixCascadeDecodeCorrectness() {
   for (size_t batch_size : {1, 8, 16, 64, 128}) {
@@ -563,6 +644,10 @@ TEST(FlashInferCorrectnessTest, VariableLengthMergeKernelCorrectnessTestFP16) {
   TestVariableLengthMergeKernelCorrectness<half>();
 }

+TEST(FlashInferCorrectnessTest, VariableLengthMergeKernelPaddedCorrectnessTestFP16) {
+  TestVariableLengthMergeKernelPaddedCorrectness<half>();
+}
+
 TEST(FlashInferCorrectnessTest, TwoLevelSinglePrefixCascadeDecodeTestFP16) {
   TestTwoLevelSinglePrefixCascadeDecodeCorrectness<half>();
 }
