Commit 1a6b17e

feat: add gemma_rmsnorm and gemma_fused_add_rmsnorm (#477)
for gemma2 cc @yzh119
1 parent 9ee26e7 commit 1a6b17e
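
The two new ops mirror the existing rmsnorm / fused_add_rmsnorm, except that the learned scale is applied as (1 + weight), following Gemma's RMSNorm. A minimal PyTorch sketch of the intended semantics (it mirrors the reference helpers added to python/tests/test_norm.py below; the shapes (batch_size, hidden_size) for input and (hidden_size,) for weight are taken from the docstrings in this commit):

import torch

def gemma_rmsnorm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Accumulate in float32, then cast back to the original dtype (e.g. float16).
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Gemma scales by (1 + w) rather than w.
    return (x * (1.0 + w.float())).to(orig_dtype)

The fused variant additionally adds residual into input first, writes that pre-norm sum back into residual, and overwrites input with the normalized result.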

7 files changed: +345 -1 lines changed

include/flashinfer/norm.cuh

+184
@@ -212,6 +212,190 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
   return cudaSuccess;
 }
 
+template <uint32_t VEC_SIZE, typename T>
+__global__ void GemmaRMSNormKernel(T* __restrict__ input, T* __restrict__ weight,
+                                   T* __restrict__ output, const uint32_t d, float eps) {
+  const uint32_t bx = blockIdx.x;
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  constexpr uint32_t warp_size = 32;
+  const uint32_t num_warps = blockDim.y;
+  const uint32_t thread_id = tx + ty * warp_size;
+  const uint32_t num_threads = num_warps * warp_size;
+  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
+  extern __shared__ float smem[];
+
+  float sum_sq = 0.f;
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    input_vec.fill(0.f);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      sum_sq += float(input_vec[j]) * float(input_vec[j]);
+    }
+  }
+
+  // first, warp reduce sum
+#pragma unroll
+  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+    sum_sq += math::shfl_xor_sync(sum_sq, offset);
+  }
+
+  smem[ty] = sum_sq;
+  __syncthreads();
+  // then, cross warp reduce sum using only the first warp
+  if (ty == 0) {
+    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
+#pragma unroll
+    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+      sum_sq += math::shfl_xor_sync(sum_sq, offset);
+    }
+    smem[0] = sum_sq;
+  }
+  __syncthreads();
+
+  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<T, VEC_SIZE> output_vec;
+    input_vec.fill(0.f);
+    weight_vec.fill(0.f);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      output_vec[j] = float(input_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
+    }
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+  }
+}
+
+template <typename T>
+cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
+                         float eps = 1e-5, cudaStream_t stream = 0) {
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
+  const uint32_t num_warps = ceil_div(block_size, 32);
+  dim3 nblks(batch_size);
+  dim3 nthrs(32, num_warps);
+  const uint32_t smem_size = num_warps * sizeof(float);
+  void* args[] = {&input, &weight, &output, &d, &eps};
+
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = GemmaRMSNormKernel<VEC_SIZE, T>;
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
+
+template <uint32_t VEC_SIZE, typename T>
+__global__ void GemmaFusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
+                                           T* __restrict__ weight, const uint32_t d, float eps) {
+  const uint32_t bx = blockIdx.x;
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  constexpr uint32_t warp_size = 32;
+  const uint32_t num_warps = blockDim.y;
+  const uint32_t thread_id = tx + ty * warp_size;
+  const uint32_t num_threads = num_warps * warp_size;
+  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
+  extern __shared__ float smem[];
+
+  float sum_sq = 0.f;
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    input_vec.fill(0.f);
+    vec_t<T, VEC_SIZE> residual_vec;
+    residual_vec.fill(0.f);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      float x = float(input_vec[j]);
+      x += float(residual_vec[j]);
+      sum_sq += x * x;
+      residual_vec[j] = (T)x;
+    }
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+  }
+
+  // first, warp reduce sum
+#pragma unroll
+  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+    sum_sq += math::shfl_xor_sync(sum_sq, offset);
+  }
+
+  smem[ty] = sum_sq;
+  __syncthreads();
+  // then, cross warp reduce sum using only the first warp
+  if (ty == 0) {
+    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
+#pragma unroll
+    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+      sum_sq += math::shfl_xor_sync(sum_sq, offset);
+    }
+    smem[0] = sum_sq;
+  }
+  __syncthreads();
+
+  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<T, VEC_SIZE> residual_vec;
+    input_vec.fill(0.f);
+    weight_vec.fill(0.f);
+    residual_vec.fill(0.f);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      input_vec[j] = float(residual_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
+    }
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+  }
+}
+
+template <typename T>
+cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
+                                 float eps = 1e-5, cudaStream_t stream = 0) {
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
+  const uint32_t num_warps = ceil_div(block_size, 32);
+  dim3 nblks(batch_size);
+  dim3 nthrs(32, num_warps);
+  const uint32_t smem_size = num_warps * sizeof(float);
+  void* args[] = {&input, &residual, &weight, &d, &eps};
+
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = GemmaFusedAddRMSNormKernel<VEC_SIZE, T>;
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+
+  return cudaSuccess;
+}
+
 }  // namespace norm
 
 }  // namespace flashinfer

python/csrc/flashinfer_ops.cu

+3
@@ -39,6 +39,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Speculative sampling from sequence of probabilities");
   m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
   m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
+  m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization");
+  m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm,
+        "Gemma Fused add root mean square normalization");
   m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
   m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
   m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul");

python/csrc/flashinfer_ops.h

+5
@@ -77,6 +77,11 @@ torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);
 void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
                        double eps);
 
+torch::Tensor gemma_rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);
+
+void gemma_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
+                             double eps);
+
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

python/csrc/norm.cu

+54
@@ -73,3 +73,57 @@ void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tenso
     return true;
   });
 }
+
+torch::Tensor gemma_rmsnorm(torch::Tensor input, torch::Tensor weight, double eps) {
+  CHECK_INPUT(input);
+  CHECK_INPUT(weight);
+  auto device = input.device();
+  CHECK_EQ(weight.device(), device);
+  CHECK_DIM(2, input);   // input: (batch_size, hidden_size)
+  CHECK_DIM(1, weight);  // weight: (hidden_size)
+  CHECK_EQ(input.size(1), weight.size(0));
+  unsigned int batch_size = input.size(0);
+  unsigned int hidden_size = input.size(1);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
+  auto output = torch::empty_like(input);
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+    cudaError_t status = norm::GemmaRMSNorm(static_cast<c_type*>(input.data_ptr()),
+                                            static_cast<c_type*>(weight.data_ptr()),
+                                            static_cast<c_type*>(output.data_ptr()), batch_size,
+                                            hidden_size, eps, torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess,
+                "GemmaRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
+    return true;
+  });
+  return output;
+}
+
+void gemma_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
+                             double eps) {
+  CHECK_INPUT(input);
+  CHECK_INPUT(residual);
+  CHECK_INPUT(weight);
+  auto device = input.device();
+  CHECK_EQ(residual.device(), device);
+  CHECK_EQ(weight.device(), device);
+  CHECK_DIM(2, input);     // input: (batch_size, hidden_size)
+  CHECK_DIM(2, residual);  // residual: (batch_size, hidden_size)
+  CHECK_DIM(1, weight);    // weight: (hidden_size)
+  CHECK_EQ(input.size(0), residual.size(0));
+  CHECK_EQ(input.size(1), residual.size(1));
+  CHECK_EQ(input.size(1), weight.size(0));
+  unsigned int batch_size = input.size(0);
+  unsigned int hidden_size = input.size(1);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+    cudaError_t status = norm::GemmaFusedAddRMSNorm(
+        static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
+        static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps,
+        torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess, "GemmaFusedAddRMSNorm failed with error code " +
+                                           std::string(cudaGetErrorString(status)));
+    return true;
+  });
+}

python/flashinfer/__init__.py

+1 -1
@@ -29,7 +29,7 @@
     single_decode_with_kv_cache,
 )
 from .gemm import SegmentGEMMWrapper, bmm_fp8
-from .norm import fused_add_rmsnorm, rmsnorm
+from .norm import fused_add_rmsnorm, gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
 from .page import append_paged_kv_cache
 from .prefill import (
     BatchPrefillWithPagedKVCacheWrapper,

python/flashinfer/norm.py

+39
@@ -69,3 +69,42 @@ def fused_add_rmsnorm(
         Epsilon for numerical stability.
     """
     _kernels.fused_add_rmsnorm(input, residual, weight, eps)
+
+
+def gemma_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
+    r"""Gemma Root mean square normalization.
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+
+    Returns
+    -------
+    output: torch.Tensor
+        Gemma Normalized tensor, shape (batch_size, hidden_size).
+    """
+    return _kernels.gemma_rmsnorm(input, weight, eps)
+
+
+def gemma_fused_add_rmsnorm(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+):
+    r"""Gemma Fused add root mean square normalization.
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    residual: torch.Tensor
+        Residual tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    """
+    _kernels.gemma_fused_add_rmsnorm(input, residual, weight, eps)
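
A short usage sketch of the two wrappers as exposed through the flashinfer package (the shapes, float16 dtype, and CUDA device below are illustrative assumptions; gemma_rmsnorm returns a new tensor, while gemma_fused_add_rmsnorm updates its arguments in place):

import torch
import flashinfer

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.randn(4096, dtype=torch.float16, device="cuda")

# Out-of-place: returns the Gemma-normalized tensor of shape (4, 4096).
y = flashinfer.gemma_rmsnorm(x, weight, eps=1e-6)

# In place: x becomes the normalized output; residual becomes the pre-norm sum (old x + residual).
flashinfer.gemma_fused_add_rmsnorm(x, residual, weight, eps=1e-6)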

python/tests/test_norm.py

+59
@@ -29,6 +29,28 @@ def _norm(x):
     return output * w
 
 
+def gemma_rms_norm(x, w, eps=1e-6):
+    orig_dtype = x.dtype
+    x = x.float()
+    variance = x.pow(2).mean(dim=-1, keepdim=True)
+    x = x * torch.rsqrt(variance + eps)
+    x = x * (1.0 + w)
+    x = x.to(orig_dtype)
+    return x
+
+
+def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
+    orig_dtype = x.dtype
+    x = x + residual
+    residual = x
+    x = x.float()
+    variance = x.pow(2).mean(dim=-1, keepdim=True)
+    x = x * torch.rsqrt(variance + eps)
+    x = x * (1.0 + w)
+    x = x.to(orig_dtype)
+    return x, residual
+
+
 def fused_add_rms_norm(x, residual, weight, eps):
     orig_dtype = x.dtype
     x = x.to(torch.float32)
@@ -76,3 +98,40 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
 
     torch.testing.assert_close(x_fused, x_native, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(residual_fused, residual_native, rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+@pytest.mark.parametrize("dtype", [torch.float16])
+def test_gemma_norm(batch_size, hidden_size, dtype):
+    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
+    w = torch.randn(hidden_size).to(0).to(dtype)
+
+    y_ref = gemma_rms_norm(x, w)
+    y = flashinfer.norm.gemma_rmsnorm(x, w)
+
+    numpy.testing.assert_allclose(
+        y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
+    )
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+@pytest.mark.parametrize("dtype", [torch.float16])
+def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
+    eps = 1e-6
+
+    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
+    residual = torch.randn_like(x)
+    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
+
+    x_native, residual_native = gemma_fused_add_rms_norm(
+        x.clone(), residual.clone(), weight, eps
+    )
+
+    x_fused = x.clone()
+    residual_fused = residual.clone()
+    flashinfer.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
+
+    torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
