Skip to content

Commit 7c397cb

Browse files
authored
perf: slight optimization on fragment layout swizzle (#458)
fuse two byte perm into one.
1 parent 85b4c77 commit 7c397cb

File tree

2 files changed

+1
-3
lines changed

2 files changed

+1
-3
lines changed

include/flashinfer/frag_layout_swizzle.cuh

+1-2
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,8 @@ __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b(uint32_t x) {
2929
}
3030

3131
__device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t x) {
32-
x = __byte_perm(x, x, 0x3120);
3332
uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x4);
34-
x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x5410 : 0x3276);
33+
x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x6420 : 0x3175);
3534
tmp = __shfl_xor_sync(0xffffffff, x, 0x8);
3635
x = __byte_perm(x, tmp, ((threadIdx.x & 0x8) == 0) ? 0x5410 : 0x3276);
3736
tmp = __shfl_xor_sync(0xffffffff, x, 0x10);

include/flashinfer/vec_dtypes.cuh

-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ __device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) {
126126
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
127127
constexpr int MASK3 = MASK2 & 0x7fffffff;
128128
constexpr int MASK = MASK3 | (MASK3 >> 16);
129-
// Final MASK value: 0x7F007F00
130129
q = __byte_perm(q, q, 0x1302);
131130

132131
// Extract and shift FP8 values to FP16 format

0 commit comments

Comments
 (0)