
Commit 2043ca2

perf: reduce the read and write of shared memory in the FusedAddRMSNormKernel (#592)
Use `vec_t<float, VEC_SIZE> x_vec` to reduce the number of read and write operations to shared memory.
Parent: 1058d1e · Commit: 2043ca2
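The gist of the change, sketched below: instead of writing each freshly computed float to shared memory inside the per-element loop, the kernel stages the VEC_SIZE values of a round in a register-resident `vec_t<float, VEC_SIZE>` and touches shared memory with a single vectorized store (and, symmetrically, a single vectorized load in the second pass). The fragment is a minimal illustration, not code from the kernel; it assumes flashinfer's `vec_t` interface (`fill`, `load`, `store`, `operator[]`) exactly as it appears in the diff, and `base` / `compute_x` are placeholder names introduced only for the sketch.

    // Before: VEC_SIZE scalar shared-memory writes per thread per round.
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      float val = compute_x(j);  // placeholder for the fused add + square-sum update
      smem[base + j] = val;      // one shared-memory write per element
    }

    // After: stage the values in registers, then flush them with one vectorized store.
    vec_t<float, VEC_SIZE> x_vec;
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      x_vec[j] = compute_x(j);   // register traffic only
    }
    x_vec.store(smem + base);    // single vectorized shared-memory store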

File tree: 2 files changed (+63, −3 lines)


benchmarks/bench_fused_add_rmsnorm.py

+55
@@ -0,0 +1,55 @@
+import argparse
+from typing import cast
+
+import torch
+from triton.testing import do_bench
+
+import flashinfer
+
+@torch.inference_mode()
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch-sizes", nargs='+', type=int, default=[1, 19, 99, 989])
+    parser.add_argument("--hidden-sizes", nargs='+', type=int, default=[111, 500, 1024, 3072, 4096, 8192])
+    parser.add_argument("--dtypes", nargs='+', choices=["float16", "bfloat16"], default=["float16"])
+    args = parser.parse_args()
+
+    eps = 1e-6
+
+    # Loop over each combination of batch_size, hidden_size, and dtype
+    for batch_size in args.batch_sizes:
+        for hidden_size in args.hidden_sizes:
+            for dtype_str in args.dtypes:
+                dtype = getattr(torch, dtype_str)
+
+                # Define tensors with the correct dtype
+                x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda")
+                residual = torch.randn_like(x)
+                weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
+
+                @torch.cuda.nvtx.range(f"fused_add_rmsnorm batch_size={batch_size}, hidden_size={hidden_size}, dtype={dtype_str}")
+                def fn() -> None:
+                    flashinfer.fused_add_rmsnorm(x, residual, weight, eps)
+
+                # Run benchmarking
+                latency_ms = cast(float, do_bench(fn))
+                throughput = (
+                    (x.numel() * x.element_size() * 2
+                     + residual.numel() * residual.element_size() * 2
+                     + weight.numel() * weight.element_size())
+                    / (latency_ms * 1e-3)
+                )
+                print(
+                    f"batch_size: {batch_size:3},",
+                    f"hidden_size: {hidden_size:5},",
+                    f"dtype: {dtype_str:8},",
+                    f"latency: {latency_ms*1e3:2.0f}us,",
+                    f"throughput: {throughput*1e-9:7.3f}GB/s",
+                )
+
+        print("---")
+
+    torch.cuda.profiler.stop()
+
+if __name__ == "__main__":
+    main()
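The byte count in the throughput estimate reflects how fused_add_rmsnorm touches memory: x and residual are each read and then overwritten in place (hence the factor of 2 on both), while weight is only read. As a rough worked example, at batch_size 989, hidden_size 8192, and float16 (2 bytes per element), the kernel moves about 4 × 989 × 8192 × 2 B + 8192 × 2 B ≈ 64.8 MB per call, so a measured latency of 1 ms would be reported as roughly 64.8 GB/s.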

include/flashinfer/norm.cuh

+8-3
@@ -133,6 +133,8 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
     input_vec.fill(0.f);
     vec_t<T, VEC_SIZE> residual_vec;
     residual_vec.fill(0.f);
+    vec_t<float, VEC_SIZE> x_vec;
+    x_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
       residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -143,10 +145,11 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
       x += float(residual_vec[j]);
       sum_sq += x * x;
       residual_vec[j] = (T)x;
-      smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j] = x;
+      x_vec[j] = x;
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      x_vec.store(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
   }

@@ -174,15 +177,17 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
   for (uint32_t i = 0; i < rounds; i++) {
     vec_t<T, VEC_SIZE> input_vec;
     vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<float, VEC_SIZE> x_vec;
     input_vec.fill(0.f);
     weight_vec.fill(0.f);
+    x_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      x_vec.load(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
     #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      float x = smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j];
-      input_vec[j] = x * rms_rcp * float(weight_vec[j]);
+      input_vec[j] = x_vec[j] * rms_rcp * float(weight_vec[j]);
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
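The second pass mirrors the first: the per-element reads of `smem[num_warps + ... + j]` inside the `#pragma unroll` loop are replaced by a single `x_vec.load` ahead of the loop, so each thread now performs one vectorized shared-memory access per round in each direction instead of VEC_SIZE scalar ones.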
