Skip to content

Commit c7dc921

Browse files
authored
feat: improve the precision of the FusedAddRMSNormKernel function (#587)
When `sizeof(T) == 2`, the sum of the values read from `input` and `residual` (the float `x`) is split into two parts, its high and low 16 bits, which are saved to `input` and `residual` respectively. Later, `input` and `residual` are read back and recombined into `x`, with the aim of improving the precision of the subsequent `x * rms_rcp` operation. This allows the test tolerances (`rtol`/`atol`) to be tightened from 1e-2 to 1e-3.
1 parent d7300c4 commit c7dc921

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

include/flashinfer/norm.cuh

+4-5
Original file line number | Diff line number | Diff line change
@@ -143,6 +143,7 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
143143
x += float(residual_vec[j]);
144144
sum_sq += x * x;
145145
residual_vec[j] = (T)x;
146+
smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j] = x;
146147
}
147148
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
148149
residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -173,17 +174,15 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
173174
for (uint32_t i = 0; i < rounds; i++) {
174175
vec_t<T, VEC_SIZE> input_vec;
175176
vec_t<T, VEC_SIZE> weight_vec;
176-
vec_t<T, VEC_SIZE> residual_vec;
177177
input_vec.fill(0.f);
178178
weight_vec.fill(0.f);
179-
residual_vec.fill(0.f);
180179
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
181180
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
182-
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
183181
}
184182
#pragma unroll
185183
for (uint32_t j = 0; j < VEC_SIZE; j++) {
186-
input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
184+
float x = smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j];
185+
input_vec[j] = x * rms_rcp * float(weight_vec[j]);
187186
}
188187
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
189188
input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -200,7 +199,7 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
200199
const uint32_t num_warps = ceil_div(block_size, 32);
201200
dim3 nblks(batch_size);
202201
dim3 nthrs(32, num_warps);
203-
const uint32_t smem_size = num_warps * sizeof(float);
202+
const uint32_t smem_size = (num_warps + d) * sizeof(float);
204203
void* args[] = {&input, &residual, &weight, &d, &eps};
205204

206205
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {

tests/test_norm.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -98,8 +98,8 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
9898
residual_fused = residual.clone()
9999
flashinfer.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
100100

101-
torch.testing.assert_close(x_fused, x_native, rtol=1e-2, atol=1e-2)
102-
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-2, atol=1e-2)
101+
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
102+
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
103103

104104

105105
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])

0 commit comments

Comments
 (0)