@@ -325,6 +325,10 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
 /*!
  * \brief The CUDA kernel to merge self-attention states of multiple index sets, the number of index
  * sets at each position might vary.
+ *
+ * For CUDA graph support, the kernel can be built with a maximum sequence length and executed
+ * using a truncated, dynamic sequence length passed through `seq_len_ptr`.
+ *
  * \tparam vec_size The vector size used in the kernel.
  * \tparam bdx The blockDim.x used in the kernel.
  * \tparam bdy The blockDim.y used in the kernel.
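The new doc comment is the heart of the change: a CUDA graph freezes by-value kernel arguments at capture time, so a dynamic length must be read through a device pointer whose pointee can be rewritten between replays. Below is a minimal, self-contained sketch (not part of this PR; the kernel, grid size, and lengths are hypothetical stand-ins) of that capture/replay pattern:

```cpp
#include <cuda_runtime.h>
#include <cstdint>

__global__ void ReadLengthKernel(const uint32_t* seq_len_ptr, uint32_t max_seq_len,
                                 uint32_t* out) {
  // The grid was sized for max_seq_len; the actual work count is read at run time.
  uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len;
  if (blockIdx.x == 0 && threadIdx.x == 0) *out = seq_len;
}

int main() {
  uint32_t *seq_len_d, *out_d;
  cudaMalloc(&seq_len_d, sizeof(uint32_t));
  cudaMalloc(&out_d, sizeof(uint32_t));
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Capture once with the maximal grid; the device *address* is baked into the graph.
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  ReadLengthKernel<<<128, 128, 0, stream>>>(seq_len_d, /*max_seq_len=*/4096, out_d);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&graph_exec, graph, 0);  // CUDA 12 signature

  // Replay with different dynamic lengths by rewriting only the device value.
  const uint32_t lens[] = {1024, 2048};
  for (uint32_t len : lens) {
    cudaMemcpyAsync(seq_len_d, &len, sizeof(uint32_t), cudaMemcpyHostToDevice, stream);
    cudaGraphLaunch(graph_exec, stream);
  }
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(seq_len_d);
  cudaFree(out_d);
  return 0;
}
```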
@@ -336,20 +340,22 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
  * \param indptr The start offsets of each position in the variable length array.
  * \param v_merged The merged v of index sets union. (n, h, d)
  * \param s_merged The merged logsumexp value of index sets union. (n, h)
+ * \param max_seq_len The maximum sequence length supported by the kernel.
+ * \param seq_len_ptr The current sequence length (number of positions populated in indptr).
  * \param num_heads The number of heads of v.
  * \param head_dim The dimension of each head.
  * \note s are logsumexp values with base 2.
  */
 template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stages, typename DTypeIn,
           typename DTypeO, typename IdType>
-__global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__ V,
-                                                          float* __restrict__ S, IdType* indptr,
-                                                          DTypeO* __restrict__ v_merged,
-                                                          float* __restrict__ s_merged,
-                                                          uint32_t seq_len, uint32_t num_heads) {
+__global__ void PersistentVariableLengthMergeStatesKernel(
+    DTypeIn* __restrict__ V, float* __restrict__ S, IdType* indptr, DTypeO* __restrict__ v_merged,
+    float* __restrict__ s_merged, uint32_t max_seq_len, uint32_t* __restrict__ seq_len_ptr,
+    uint32_t num_heads) {
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
   uint32_t cta_id = blockIdx.x;
   uint32_t num_ctas = gridDim.x;
+  const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len;
   uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
   constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
   constexpr uint32_t head_dim = vec_size * bdx;
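The only change inside the kernel body is that `seq_len` is now resolved at run time: when a device pointer is supplied the length is read from it, otherwise the kernel falls back to `max_seq_len`. Because `num_iters` is derived from that run-time value, a graph captured with the maximal grid simply does less work per CTA when replayed with a shorter length; the same read appears in `PersistentVariableLengthAttentionSumKernel` below. For intuition, here is a host-side illustration of the strided work partition implied by the `num_iters` computation (an assumption for illustration only; the loop body is not shown in this hunk, so the exact `(position, head)` decomposition may differ):

```cpp
#include <cstdint>

constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// A fixed grid of num_ctas persistent blocks strides over seq_len * num_heads
// work items; trailing slots past the end are simply skipped.
void EnumerateWorkItems(uint32_t seq_len, uint32_t num_heads, uint32_t num_ctas) {
  const uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
  for (uint32_t cta_id = 0; cta_id < num_ctas; ++cta_id) {
    for (uint32_t iter = 0; iter < num_iters; ++iter) {
      const uint32_t idx = iter * num_ctas + cta_id;
      if (idx >= seq_len * num_heads) continue;  // trailing CTAs may get one fewer item
      const uint32_t pos = idx / num_heads;      // row into indptr
      const uint32_t head = idx % num_heads;     // attention head
      // ... merge the states in [indptr[pos], indptr[pos + 1]) for this head ...
    }
  }
}
```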
@@ -437,10 +443,13 @@ template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stage
           typename DTypeO, typename IdType>
 __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ V, IdType* indptr,
                                                            DTypeO* __restrict__ v_sum,
-                                                           uint32_t seq_len, uint32_t num_heads) {
+                                                           uint32_t max_seq_len,
+                                                           uint32_t* __restrict__ seq_len_ptr,
+                                                           uint32_t num_heads) {
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
   uint32_t cta_id = blockIdx.x;
   uint32_t num_ctas = gridDim.x;
+  const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len;
   uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
   constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
   constexpr uint32_t head_dim = vec_size * bdx;
@@ -641,8 +650,9 @@ cudaError_t AttentionSum(DTypeIn* v, DTypeO* v_sum, uint32_t num_index_sets, uin
 
 template <typename DTypeIn, typename DTypeO, typename IdType>
 cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeO* v_merged,
-                                      float* s_merged, uint32_t seq_len, uint32_t num_heads,
-                                      uint32_t head_dim, cudaStream_t stream = nullptr) {
+                                      float* s_merged, uint32_t max_seq_len, uint32_t* seq_len,
+                                      uint32_t num_heads, uint32_t head_dim,
+                                      cudaStream_t stream = nullptr) {
   int dev_id = 0;
   int num_sms = 0;
   int num_blocks_per_sm = 0;
@@ -661,11 +671,11 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp
                                                           DTypeIn, DTypeO, IdType>;
   FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
                                                                      num_threads, smem_size));
-  num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms));
+  num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms));
 
   dim3 nblks(num_sms * num_blocks_per_sm);
   dim3 nthrs(bdx, bdy);
-  void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &seq_len, &num_heads};
+  void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &max_seq_len, &seq_len, &num_heads};
   FLASHINFER_CUDA_CALL(
       cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
   FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
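Two details of the wrapper are worth noting. First, occupancy and grid size are now derived from `max_seq_len`, so a grid captured into a graph is guaranteed to be large enough for any replayed length up to that bound. Second, since `seq_len` is now a `uint32_t*`, the `args` entry `&seq_len` is a pointer-to-pointer: `cudaLaunchKernel` copies the device address into the launch (and hence into any captured graph), while the value at that address stays mutable. A hedged usage sketch (the buffers `v`, `s`, `indptr`, `v_merged`, `s_merged` and the scalars `num_heads`, `head_dim`, `stream` are assumed to be set up elsewhere):

```cpp
uint32_t max_seq_len = 4096;       // hypothetical upper bound chosen at build time
uint32_t* seq_len_d = nullptr;     // device-resident current length
cudaMalloc(&seq_len_d, sizeof(uint32_t));

// Write the current length, then launch; under graph capture, only the
// address seq_len_d is frozen, not the value it points to.
uint32_t cur_len = 1024;
cudaMemcpyAsync(seq_len_d, &cur_len, sizeof(uint32_t), cudaMemcpyHostToDevice, stream);
cudaError_t status = VariableLengthMergeStates(v, s, indptr, v_merged, s_merged,
                                               max_seq_len, seq_len_d,
                                               num_heads, head_dim, stream);
```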
@@ -674,9 +684,9 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp
 }
 
 template <typename DTypeIn, typename DTypeO, typename IdType>
-cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum, uint32_t seq_len,
-                                       uint32_t num_heads, uint32_t head_dim,
-                                       cudaStream_t stream = nullptr) {
+cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum,
+                                       uint32_t max_seq_len, uint32_t* seq_len, uint32_t num_heads,
+                                       uint32_t head_dim, cudaStream_t stream = nullptr) {
   int dev_id = 0;
   int num_sms = 0;
   int num_blocks_per_sm = 0;
@@ -694,11 +704,11 @@ cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum
                                                            DTypeIn, DTypeO, IdType>;
   FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
                                                                      num_threads, smem_size));
-  num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms));
+  num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms));
 
   dim3 nblks(num_sms * num_blocks_per_sm);
   dim3 nthrs(bdx, bdy);
-  void* args[] = {&v, &indptr, &v_sum, &seq_len, &num_heads};
+  void* args[] = {&v, &indptr, &v_sum, &max_seq_len, &seq_len, &num_heads};
   FLASHINFER_CUDA_CALL(
       cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
   FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
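Finally, because both kernels fall back to `max_seq_len` when the length pointer is null, callers that do not use CUDA graphs keep the old fixed-length behavior without a device allocation. A sketch of that backward-compatible call (`seq_len_value` is a hypothetical host variable holding the static length):

```cpp
// Passing nullptr for seq_len makes the kernel read max_seq_len (see the
// `seq_len_ptr ? *seq_len_ptr : max_seq_len` fallback above), matching the
// pre-change fixed-length semantics.
cudaError_t status = VariableLengthAttentionSum(v, indptr, v_sum,
                                                /*max_seq_len=*/seq_len_value,
                                                /*seq_len=*/nullptr,
                                                num_heads, head_dim, stream);
```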