Commit 90e42a7

fix: batch decode kernel redundant store output to gmem (#505)
Hi, this is a minor fix. When bdz is greater than 1, some warps perform redundant stores of the output to gmem. We could also check 'if (tx == 0)' when storing the lse value, but since bdx is 32 most of the time, I think leaving that store unguarded is fine.

Co-authored-by: tsu-bin <[email protected]>
1 parent 33ef957 commit 90e42a7

1 file changed, 6 additions and 4 deletions
include/flashinfer/attention/decode.cuh
@@ -575,10 +575,12 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
   sync_state<vec_size, bdx, bdy, bdz>(st, reinterpret_cast<float*>(smem), smem_md);
   st.normalize();
 
-  st.o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
-  // write lse
-  if (lse != nullptr) {
-    lse[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse();
+  if (tz == 0) {
+    st.o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
+    // write lse
+    if (lse != nullptr) {
+      lse[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse();
+    }
   }
 }
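
To illustrate what the `tz == 0` guard saves, here is a minimal host-side C++ sketch. It is not FlashInfer code: `count_output_stores` and its parameters are hypothetical, and it models a single head (one `ty` slice) under the assumption that every `tz` slice holds the same merged result once `sync_state` has run. It only counts how many vector stores would hit the output buffer with and without the guard.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical host-side model of one thread block; not the real kernel.
// Each (tx, tz) "thread" stores vec_size elements at offset tx * vec_size,
// mirroring the `tx * vec_size` indexing in the diff.
// Returns the number of vector stores issued and records, per output
// element, how many times it was written.
std::size_t count_output_stores(int bdx, int bdz, int vec_size, bool guard_tz,
                                std::vector<int>& writes_per_element) {
  writes_per_element.assign(static_cast<std::size_t>(bdx) * vec_size, 0);
  std::size_t stores = 0;
  for (int tz = 0; tz < bdz; ++tz) {
    // Models the `if (tz == 0)` check added by the fix.
    if (guard_tz && tz != 0) continue;
    for (int tx = 0; tx < bdx; ++tx) {
      for (int i = 0; i < vec_size; ++i) {
        ++writes_per_element[static_cast<std::size_t>(tx) * vec_size + i];
      }
      ++stores;
    }
  }
  return stores;
}
```

In this model, with bdx = 32 and bdz = 4, the unguarded version issues 128 vector stores and writes every output element four times, while the guarded version issues 32 stores and writes each element once.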
