Skip to content

Commit 0d61871

Browse files
authored
perf: slight optimization on f16->f8 fragment layout swizzling (#453)
Replace the pre-dequantization bfly_exch byte permute with a plain swap of b_frag[1] and b_frag[2] after dequantization.
1 parent fa38b5e commit 0d61871

File tree

3 files changed

+7
-8
lines changed

3 files changed

+7
-8
lines changed

include/flashinfer/attention/prefill.cuh

+1-1
Original file line number | Diff line number | Diff line change
@@ -802,8 +802,8 @@ __device__ __forceinline__ void compute_sfm_v(smem_t<swizzle_mode>* v_smem,
802 802
}
803 803
b_frag_f8[0] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]);
804 804
b_frag_f8[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]);
805-
bfly_exch(b_frag_f8[0], b_frag_f8[1]);
806 805
vec_cast<DTypeQ, DTypeKV>::cast<8>((DTypeQ*)b_frag, (DTypeKV*)b_frag_f8);
806+
swap(b_frag[1], b_frag[2]);
807 807
} else {
808 808
v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag);
809 809
}

include/flashinfer/frag_layout_swizzle.cuh

-7
Original file line number | Diff line number | Diff line change
@@ -39,11 +39,4 @@ __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t
39 39
return x;
40 40
}
41 41

42-
__device__ __forceinline__ void bfly_exch(uint32_t& a, uint32_t& b) {
43-
uint32_t tmp = __byte_perm(a, b, 0x5410);
44-
uint32_t tmp2 = __byte_perm(a, b, 0x7632);
45-
a = tmp;
46-
b = tmp2;
47-
}
48-
49 42
#endif // FLASHINFER_FRAG_LAYOUT_SWIZZLE_CUH_

include/flashinfer/utils.cuh

+6
Original file line number | Diff line number | Diff line change
@@ -253,6 +253,12 @@ __device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, uint32_t
253 253
return (x > y) ? x - y : 0U;
254 254
}
255 255

256+
__device__ __forceinline__ void swap(uint32_t& a, uint32_t& b) {
257+
uint32_t tmp = a;
258+
a = b;
259+
b = tmp;
260+
}
261+
256 262
} // namespace flashinfer
257 263

258 264
#endif // FLASHINFER_UTILS_CUH_

0 commit comments

Comments
 (0)