@@ -301,85 +301,94 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
 */
 template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stages, typename DTypeIn,
           typename DTypeOut, typename IdType>
-__global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S,
-                                                IdType* indptr, DTypeOut* __restrict__ v_merged,
-                                                float* __restrict__ s_merged, uint32_t num_heads) {
+__global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__ V,
+                                                          float* __restrict__ S, IdType* indptr,
+                                                          DTypeOut* __restrict__ v_merged,
+                                                          float* __restrict__ s_merged,
+                                                          uint32_t seq_len, uint32_t num_heads) {
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
-  uint32_t pos = blockIdx.x;
-  uint32_t head_idx = blockIdx.y;
-  state_t<vec_size> st;
+  uint32_t cta_id = blockIdx.x;
+  uint32_t num_ctas = gridDim.x;
+  uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
   constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
   constexpr uint32_t head_dim = vec_size * bdx;
-
   extern __shared__ uint8_t smem[];
   DTypeIn* v_smem = (DTypeIn*)smem;
   float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
-  const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
 
-  if (num_index_sets == 0) {
-    vec_t<DTypeOut, vec_size> v;
-    v.fill(DTypeOut(0));
-    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
-    if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = -5e4;
+#pragma unroll 1
+  for (uint32_t i = cta_id; i < seq_len * num_heads; i += num_ctas) {
+    uint32_t pos = i / num_heads;
+    uint32_t head_idx = i % num_heads;
+    state_t<vec_size> st;
+    const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
+
+    if (num_index_sets == 0) {
+      vec_t<DTypeOut, vec_size> v;
+      v.fill(DTypeOut(0));
+      v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+      if (s_merged != nullptr) {
+        s_merged[pos * num_heads + head_idx] = -5e4;
+      }
+      continue;
     }
-    return;
-  }
 
-  if (num_index_sets == 1) {
-    vec_t<DTypeOut, vec_size> v;
-    v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
-    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
-    if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
+    if (num_index_sets == 1) {
+      vec_t<DTypeOut, vec_size> v;
+      v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
+      v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+      if (s_merged != nullptr) {
+        s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
+      }
+      continue;
     }
-  }
 
 #pragma unroll
-  for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
-    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
-        v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
-        V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
-        (iter * bdy + ty) < num_index_sets);
-    cp_async::commit_group();
-  }
+    for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
+      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
+          v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
+          V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
+          (iter * bdy + ty) < num_index_sets);
+      cp_async::commit_group();
+    }
 #pragma unroll 4
-  for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
-    if (iter % bdx == 0) {
-      s_smem[ty * bdx + tx] =
-          iter * bdy + (ty * bdx + tx) < num_index_sets
-              ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
-              : 0.f;
+    for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
+      if (iter % bdx == 0) {
+        s_smem[ty * bdx + tx] =
+            iter * bdy + (ty * bdx + tx) < num_index_sets
+                ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
+                : 0.f;
+        __syncthreads();
+      }
+      cp_async::wait_group<num_smem_stages - 1>();
       __syncthreads();
+      vec_t<float, vec_size> v;
+      v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
+      if (iter * bdy + ty < num_index_sets) {
+        float s = s_smem[(iter % bdx) * bdy + ty];
+        st.merge(v, s, 1);
+      }
+      __syncthreads();
+      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
+          v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
+          V +
+              ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
+                  head_dim +
+              tx * vec_size,
+          (iter + num_smem_stages) * bdy + ty < num_index_sets);
+      cp_async::commit_group();
     }
-    cp_async::wait_group<num_smem_stages - 1>();
-    __syncthreads();
-    vec_t<float, vec_size> v;
-    v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
-    if (iter * bdy + ty < num_index_sets) {
-      float s = s_smem[(iter % bdx) * bdy + ty];
-      st.merge(v, s, 1);
-    }
+    cp_async::wait_group<0>();
     __syncthreads();
-    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
-        v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
-        V +
-            ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
-                head_dim +
-            tx * vec_size,
-        (iter + num_smem_stages) * bdy + ty < num_index_sets);
-    cp_async::commit_group();
-  }
-  cp_async::wait_group<0>();
-  __syncthreads();
 
-  st.normalize();
-  threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
-  st.normalize();
+    st.normalize();
+    threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
+    st.normalize();
 
-  st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
-  if (s_merged != nullptr) {
-    s_merged[pos * num_heads + head_idx] = st.get_lse();
+    st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    if (s_merged != nullptr) {
+      s_merged[pos * num_heads + head_idx] = st.get_lse();
+    }
   }
 }
 
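The hunk above replaces the one-CTA-per-(position, head) grid with a persistent kernel: a fixed number of CTAs is launched and each CTA strides through the flattened (pos, head) index space. Below is a minimal standalone sketch of that work-partitioning pattern only; the kernel name and the `owner_cta` buffer are illustrative, not part of the diff or of FlashInfer.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Illustrative persistent kernel: each CTA walks the flattened (pos, head)
// space with stride gridDim.x, mirroring the loop structure added above.
__global__ void PersistentPartitionSketch(uint32_t seq_len, uint32_t num_heads,
                                          uint32_t* __restrict__ owner_cta) {
  uint32_t cta_id = blockIdx.x;
  uint32_t num_ctas = gridDim.x;
  for (uint32_t i = cta_id; i < seq_len * num_heads; i += num_ctas) {
    uint32_t pos = i / num_heads;       // request / position index
    uint32_t head_idx = i % num_heads;  // attention head index
    if (threadIdx.x == 0) {
      // Record which CTA handled this (pos, head) pair.
      owner_cta[pos * num_heads + head_idx] = cta_id;
    }
  }
}

int main() {
  const uint32_t seq_len = 7, num_heads = 4;
  uint32_t* owner_cta;
  cudaMallocManaged(&owner_cta, seq_len * num_heads * sizeof(uint32_t));
  // Launch fewer CTAs than work items; each CTA picks up several pairs.
  PersistentPartitionSketch<<<3, 32>>>(seq_len, num_heads, owner_cta);
  cudaDeviceSynchronize();
  for (uint32_t i = 0; i < seq_len * num_heads; ++i) {
    printf("(pos=%u, head=%u) -> CTA %u\n", i / num_heads, i % num_heads, owner_cta[i]);
  }
  cudaFree(owner_cta);
  return 0;
}

The hunk that follows shows the matching host-side change: the grid size is no longer (seq_len, num_heads) but is derived from the SM count and kernel occupancy.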
@@ -502,19 +511,29 @@ template <typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeOut* v_merged,
                                       float* s_merged, uint32_t seq_len, uint32_t num_heads,
                                       uint32_t head_dim, cudaStream_t stream = nullptr) {
+  int dev_id = 0;
+  int num_sms = 0;
+  int num_blocks_per_sm = 0;
+  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
+  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
+
   DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
     constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U);
     constexpr uint32_t bdx = HEAD_DIM / vec_size;
     constexpr uint32_t num_threads = 128;
     constexpr uint32_t bdy = num_threads / bdx;
-    dim3 nblks(seq_len, num_heads);
-    dim3 nthrs(bdx, bdy);
     constexpr uint32_t num_smem_stages = 4;
-    auto kernel = VariableLengthMergeStatesKernel<vec_size, bdx, bdy, num_smem_stages, DTypeIn,
-                                                  DTypeOut, IdType>;
-    void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &num_heads};
     uint32_t smem_size =
        num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + num_threads * sizeof(float);
+    auto kernel = PersistentVariableLengthMergeStatesKernel<vec_size, bdx, bdy, num_smem_stages,
+                                                            DTypeIn, DTypeOut, IdType>;
+    FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
+                                                                       num_threads, smem_size));
+    num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms));
+
+    dim3 nblks(num_sms * num_blocks_per_sm);
+    dim3 nthrs(bdx, bdy);
+    void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &seq_len, &num_heads};
     FLASHINFER_CUDA_CALL(
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
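For context, the launch-configuration change above sizes the persistent grid from the device's SM count and the kernel's measured occupancy, capped by the amount of available work. The following is a hedged host-side sketch of that sizing logic under the same assumptions; the helper name `SizePersistentGrid` is an illustration, not part of the FlashInfer API.

#include <algorithm>
#include <cuda_runtime.h>

// Illustrative helper: pick a persistent grid size for `kernel` given its
// block size, dynamic shared memory, and the number of work items
// (seq_len * num_heads in the diff above).
template <typename KernelT>
cudaError_t SizePersistentGrid(KernelT kernel, int num_threads, size_t smem_size,
                               int num_work_items, dim3* nblks) {
  int dev_id = 0, num_sms = 0, num_blocks_per_sm = 0;
  cudaError_t err;
  if ((err = cudaGetDevice(&dev_id)) != cudaSuccess) return err;
  if ((err = cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)) !=
      cudaSuccess)
    return err;
  // How many CTAs of this kernel can be resident on one SM at once.
  if ((err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, num_threads,
                                                           smem_size)) != cudaSuccess)
    return err;
  // Never launch more CTAs than there are work items, but keep at least one block per SM.
  int work_limited = (num_work_items + num_sms - 1) / num_sms;
  num_blocks_per_sm = std::max(1, std::min(num_blocks_per_sm, work_limited));
  *nblks = dim3(num_sms * num_blocks_per_sm);
  return cudaSuccess;
}

With the values in the diff, such a helper would be invoked with num_threads = 128, the computed smem_size, and seq_len * num_heads work items before launching the persistent kernel on the resulting grid.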