Skip to content

Commit 595cf60

Browse files
authored
perf: fix prefill kernel performance degradation (step 1) (#602)
The prefill attention kernel performance has degraded significantly in recent releases (since v0.1.2), especially on A100 when `causal=True`. This is mainly because we added new attention variants (which increase register usage and thus incur register spilling) and moved some parameters from compile time to runtime. This PR alleviates the issue by caching some of the variables related to the GQA group size. In the next PR, we will support another mode, `kv_head_major`, in addition to `qo_head_major`, to further accelerate GQA prefill with query sizes >= 64. cc @AKKamath
1 parent 3dd9405 commit 595cf60

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

include/flashinfer/attention/prefill.cuh

+26-12
Original file line numberDiff line numberDiff line change
@@ -623,18 +623,25 @@ __device__ __forceinline__ void logits_transform(const typename AttentionVariant
623623
const uint_fastdiv group_size,
624624
DTypeQKAccum (*s_frag)[NUM_FRAGS_KV][8]) {
625625
const uint32_t lane_idx = threadIdx.x, kv_head_idx = blockIdx.z;
626+
uint32_t q[NUM_FRAGS_Q][2], r[NUM_FRAGS_Q][2];
627+
#pragma unroll
628+
for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) {
629+
#pragma unroll
630+
for (uint32_t j = 0; j < 2; ++j) {
631+
group_size.divmod(qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * j, q[fq][j], r[fq][j]);
632+
}
633+
}
634+
626635
#pragma unroll
627636
for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) {
628637
#pragma unroll
629638
for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) {
630639
#pragma unroll
631640
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
632-
uint32_t q, r;
633-
group_size.divmod(qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * ((reg_id % 4) / 2), q,
634-
r);
635-
const uint32_t q_idx = q, kv_idx = kv_idx_base + fkv * 16 + 2 * (lane_idx % 4) +
636-
8 * (reg_id / 4) + reg_id % 2;
637-
const uint32_t qo_head_idx = kv_head_idx * group_size + r;
641+
const uint32_t q_idx = q[fq][(reg_id % 4) / 2], kv_idx = kv_idx_base + fkv * 16 +
642+
2 * (lane_idx % 4) +
643+
8 * (reg_id / 4) + reg_id % 2;
644+
const uint32_t qo_head_idx = kv_head_idx * group_size + r[fq][(reg_id % 4) / 2];
638645
s_frag[fq][fkv][reg_id] = variant.LogitsTransform(
639646
params, s_frag[fq][fkv][reg_id], batch_idx, q_idx, kv_idx, qo_head_idx, kv_head_idx);
640647
}
@@ -652,18 +659,25 @@ __device__ __forceinline__ void logits_mask(const typename AttentionVariant::Par
652659
const uint_fastdiv group_size,
653660
DTypeQKAccum (*s_frag)[NUM_FRAGS_KV][8]) {
654661
const uint32_t lane_idx = threadIdx.x, kv_head_idx = blockIdx.z;
662+
uint32_t q[NUM_FRAGS_Q][2], r[NUM_FRAGS_Q][2];
663+
#pragma unroll
664+
for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) {
665+
#pragma unroll
666+
for (uint32_t j = 0; j < 2; ++j) {
667+
group_size.divmod(qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * j, q[fq][j], r[fq][j]);
668+
}
669+
}
670+
655671
#pragma unroll
656672
for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) {
657673
#pragma unroll
658674
for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) {
659675
#pragma unroll
660676
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
661-
uint32_t q, r;
662-
group_size.divmod(qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * ((reg_id % 4) / 2), q,
663-
r);
664-
const uint32_t q_idx = q, kv_idx = kv_idx_base + fkv * 16 + 2 * (lane_idx % 4) +
665-
8 * (reg_id / 4) + reg_id % 2;
666-
const uint32_t qo_head_idx = kv_head_idx * group_size + r;
677+
const uint32_t q_idx = q[fq][(reg_id % 4) / 2], kv_idx = kv_idx_base + fkv * 16 +
678+
2 * (lane_idx % 4) +
679+
8 * (reg_id / 4) + reg_id % 2;
680+
const uint32_t qo_head_idx = kv_head_idx * group_size + r[fq][(reg_id % 4) / 2];
667681
const bool mask =
668682
(!(MASK_MODE == MaskMode::kCausal
669683
? (kv_idx + qo_len > kv_len + q_idx || (kv_idx >= chunk_end))

0 commit comments

Comments
 (0)