Commit be6bf5b
perf: use persistent kernel for merging attention states (#459)
As observed by @MasterJH5574, there are cases where our `VariableLengthMergeStatesKernel` launches a large number of CTAs (>=10k) while most of them only perform a small number of merges. This PR fixes the issue by using a persistent kernel. A load-imbalance issue remains, which I plan to resolve inside the scheduler; I'll leave that for later PRs.
1 parent 048560d commit be6bf5b
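For background: the persistent-kernel pattern adopted here launches a fixed-size grid, sized to what the GPU can keep resident, and lets each CTA stride over the work queue, instead of launching one CTA per work item. A minimal standalone sketch of the pattern with a hypothetical kernel (not the code in this commit):

#include <cuda_runtime.h>

// Persistent-kernel sketch: a small resident grid strides over num_tasks
// work items, so many short tasks are amortized across a few CTAs rather
// than each task costing its own CTA.
__global__ void PersistentScaleKernel(const float* in, float* out,
                                      int num_tasks, int task_size) {
  for (int task = blockIdx.x; task < num_tasks; task += gridDim.x) {
    // All threads of the CTA cooperate on one task at a time.
    for (int j = threadIdx.x; j < task_size; j += blockDim.x) {
      out[task * task_size + j] = 2.f * in[task * task_size + j];
    }
  }
}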

File tree

1 file changed: +85 −66 lines changed

include/flashinfer/attention/cascade.cuh

+85 −66
@@ -301,85 +301,94 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
  */
 template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stages, typename DTypeIn,
           typename DTypeOut, typename IdType>
-__global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S,
-                                                IdType* indptr, DTypeOut* __restrict__ v_merged,
-                                                float* __restrict__ s_merged, uint32_t num_heads) {
+__global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__ V,
+                                                          float* __restrict__ S, IdType* indptr,
+                                                          DTypeOut* __restrict__ v_merged,
+                                                          float* __restrict__ s_merged,
+                                                          uint32_t seq_len, uint32_t num_heads) {
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
-  uint32_t pos = blockIdx.x;
-  uint32_t head_idx = blockIdx.y;
-  state_t<vec_size> st;
+  uint32_t cta_id = blockIdx.x;
+  uint32_t num_ctas = gridDim.x;
+  uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
   constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
   constexpr uint32_t head_dim = vec_size * bdx;
-
   extern __shared__ uint8_t smem[];
   DTypeIn* v_smem = (DTypeIn*)smem;
   float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
-  const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
 
-  if (num_index_sets == 0) {
-    vec_t<DTypeOut, vec_size> v;
-    v.fill(DTypeOut(0));
-    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
-    if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = -5e4;
+#pragma unroll 1
+  for (uint32_t i = cta_id; i < seq_len * num_heads; i += num_ctas) {
+    uint32_t pos = i / num_heads;
+    uint32_t head_idx = i % num_heads;
+    state_t<vec_size> st;
+    const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
+
+    if (num_index_sets == 0) {
+      vec_t<DTypeOut, vec_size> v;
+      v.fill(DTypeOut(0));
+      v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+      if (s_merged != nullptr) {
+        s_merged[pos * num_heads + head_idx] = -5e4;
+      }
+      continue;
     }
-    return;
-  }
 
-  if (num_index_sets == 1) {
-    vec_t<DTypeOut, vec_size> v;
-    v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
-    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
-    if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
+    if (num_index_sets == 1) {
+      vec_t<DTypeOut, vec_size> v;
+      v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
+      v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+      if (s_merged != nullptr) {
+        s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
+      }
+      continue;
    }
-  }
 
 #pragma unroll
-  for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
-    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
-        v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
-        V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
-        (iter * bdy + ty) < num_index_sets);
-    cp_async::commit_group();
-  }
+    for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
+      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
+          v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
+          V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
+          (iter * bdy + ty) < num_index_sets);
+      cp_async::commit_group();
+    }
 #pragma unroll 4
-  for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
-    if (iter % bdx == 0) {
-      s_smem[ty * bdx + tx] =
-          iter * bdy + (ty * bdx + tx) < num_index_sets
-              ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
-              : 0.f;
+    for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
+      if (iter % bdx == 0) {
+        s_smem[ty * bdx + tx] =
+            iter * bdy + (ty * bdx + tx) < num_index_sets
+                ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
+                : 0.f;
+        __syncthreads();
+      }
+      cp_async::wait_group<num_smem_stages - 1>();
       __syncthreads();
+      vec_t<float, vec_size> v;
+      v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
+      if (iter * bdy + ty < num_index_sets) {
+        float s = s_smem[(iter % bdx) * bdy + ty];
+        st.merge(v, s, 1);
+      }
+      __syncthreads();
+      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
+          v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
+          V +
+              ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
+                  head_dim +
+              tx * vec_size,
+          (iter + num_smem_stages) * bdy + ty < num_index_sets);
+      cp_async::commit_group();
     }
-    cp_async::wait_group<num_smem_stages - 1>();
-    __syncthreads();
-    vec_t<float, vec_size> v;
-    v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
-    if (iter * bdy + ty < num_index_sets) {
-      float s = s_smem[(iter % bdx) * bdy + ty];
-      st.merge(v, s, 1);
-    }
+    cp_async::wait_group<0>();
     __syncthreads();
-    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
-        v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
-        V +
-            ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
-                head_dim +
-            tx * vec_size,
-        (iter + num_smem_stages) * bdy + ty < num_index_sets);
-    cp_async::commit_group();
-  }
-  cp_async::wait_group<0>();
-  __syncthreads();
 
-  st.normalize();
-  threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
-  st.normalize();
+    st.normalize();
+    threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
+    st.normalize();
 
-  st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
-  if (s_merged != nullptr) {
-    s_merged[pos * num_heads + head_idx] = st.get_lse();
+    st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    if (s_merged != nullptr) {
+      s_merged[pos * num_heads + head_idx] = st.get_lse();
+    }
   }
 }

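In the new kernel above, each (pos, head_idx) pair is flattened into a single work index i = pos * num_heads + head_idx, and CTA c handles i = c, c + num_ctas, c + 2 * num_ctas, and so on. An illustrative host-side check of that decomposition (standalone, with made-up sizes; not part of the commit):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t seq_len = 5, num_heads = 3, num_ctas = 4;
  uint32_t visited[5 * 3] = {};
  // Replay the round-robin assignment from the persistent loop.
  for (uint32_t c = 0; c < num_ctas; ++c) {
    for (uint32_t i = c; i < seq_len * num_heads; i += num_ctas) {
      uint32_t pos = i / num_heads;
      uint32_t head_idx = i % num_heads;
      ++visited[pos * num_heads + head_idx];
    }
  }
  // Every (pos, head) pair is merged exactly once.
  for (uint32_t v : visited) assert(v == 1);
  return 0;
}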
@@ -502,19 +511,29 @@ template <typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeOut* v_merged,
                                       float* s_merged, uint32_t seq_len, uint32_t num_heads,
                                       uint32_t head_dim, cudaStream_t stream = nullptr) {
+  int dev_id = 0;
+  int num_sms = 0;
+  int num_blocks_per_sm = 0;
+  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
+  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
+
   DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
     constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U);
     constexpr uint32_t bdx = HEAD_DIM / vec_size;
     constexpr uint32_t num_threads = 128;
     constexpr uint32_t bdy = num_threads / bdx;
-    dim3 nblks(seq_len, num_heads);
-    dim3 nthrs(bdx, bdy);
     constexpr uint32_t num_smem_stages = 4;
-    auto kernel = VariableLengthMergeStatesKernel<vec_size, bdx, bdy, num_smem_stages, DTypeIn,
-                                                  DTypeOut, IdType>;
-    void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &num_heads};
     uint32_t smem_size =
         num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + num_threads * sizeof(float);
+    auto kernel = PersistentVariableLengthMergeStatesKernel<vec_size, bdx, bdy, num_smem_stages,
+                                                            DTypeIn, DTypeOut, IdType>;
+    FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
+                                                                       num_threads, smem_size));
+    num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms));
+
+    dim3 nblks(num_sms * num_blocks_per_sm);
+    dim3 nthrs(bdx, bdy);
+    void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &seq_len, &num_heads};
     FLASHINFER_CUDA_CALL(
         cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
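On the launch side, the grid is now sized from occupancy rather than from the problem shape: query the SM count, ask the runtime how many CTAs of this kernel fit per SM at the chosen block size and dynamic shared memory, then cap by the amount of work. A self-contained sketch of the same recipe with a placeholder kernel (illustrative, not the library's API):

#include <cuda_runtime.h>
#include <algorithm>

__global__ void PlaceholderKernel() {}

// Compute a persistent grid size: SMs * resident-CTAs-per-SM, capped so we
// never launch more CTAs than there are work items.
int PersistentGridSize(int num_work_items, int num_threads, size_t smem_size) {
  int dev_id = 0, num_sms = 0, num_blocks_per_sm = 0;
  cudaGetDevice(&dev_id);
  cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, PlaceholderKernel,
                                                num_threads, smem_size);
  num_blocks_per_sm = std::min(num_blocks_per_sm,
                               (num_work_items + num_sms - 1) / num_sms);
  return num_sms * num_blocks_per_sm;
}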
