@@ -212,6 +212,190 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
  return cudaSuccess;
}

+ template <uint32_t VEC_SIZE, typename T>
+ __global__ void GemmaRMSNormKernel(T* __restrict__ input, T* __restrict__ weight,
+                                    T* __restrict__ output, const uint32_t d, float eps) {
+   const uint32_t bx = blockIdx.x;
+   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+   constexpr uint32_t warp_size = 32;
+   const uint32_t num_warps = blockDim.y;
+   const uint32_t thread_id = tx + ty * warp_size;
+   const uint32_t num_threads = num_warps * warp_size;
+   const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
+   extern __shared__ float smem[];
+
+   float sum_sq = 0.f;
+
+   for (uint32_t i = 0; i < rounds; i++) {
+     vec_t<T, VEC_SIZE> input_vec;
+     input_vec.fill(0.f);
+     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+     }
+ #pragma unroll
+     for (uint32_t j = 0; j < VEC_SIZE; j++) {
+       sum_sq += float(input_vec[j]) * float(input_vec[j]);
+     }
+   }
+
+   // first, warp reduce sum
+ #pragma unroll
+   for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+     sum_sq += math::shfl_xor_sync(sum_sq, offset);
+   }
+
+   smem[ty] = sum_sq;
+   __syncthreads();
+   // then, cross warp reduce sum using only the first warp
+   if (ty == 0) {
+     sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
+ #pragma unroll
+     for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+       sum_sq += math::shfl_xor_sync(sum_sq, offset);
+     }
+     smem[0] = sum_sq;
+   }
+   __syncthreads();
+
+   float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
+
+   for (uint32_t i = 0; i < rounds; i++) {
+     vec_t<T, VEC_SIZE> input_vec;
+     vec_t<T, VEC_SIZE> weight_vec;
+     vec_t<T, VEC_SIZE> output_vec;
+     input_vec.fill(0.f);
+     weight_vec.fill(0.f);
+     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+       weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+     }
+ #pragma unroll
+     for (uint32_t j = 0; j < VEC_SIZE; j++) {
+       output_vec[j] = float(input_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
+     }
+     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+       output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+     }
+   }
+ }
+
+ template <typename T>
+ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
+                          float eps = 1e-5, cudaStream_t stream = 0) {
+   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
+   const uint32_t num_warps = ceil_div(block_size, 32);
+   dim3 nblks(batch_size);
+   dim3 nthrs(32, num_warps);
+   const uint32_t smem_size = num_warps * sizeof(float);
+   void* args[] = {&input, &weight, &output, &d, &eps};
+
+   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+     auto kernel = GemmaRMSNormKernel<VEC_SIZE, T>;
+     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+   });
+   return cudaSuccess;
+ }
+
+ template <uint32_t VEC_SIZE, typename T>
+ __global__ void GemmaFusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
+                                            T* __restrict__ weight, const uint32_t d, float eps) {
+   const uint32_t bx = blockIdx.x;
+   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+   constexpr uint32_t warp_size = 32;
+   const uint32_t num_warps = blockDim.y;
+   const uint32_t thread_id = tx + ty * warp_size;
+   const uint32_t num_threads = num_warps * warp_size;
+   const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
+   extern __shared__ float smem[];
+
+   float sum_sq = 0.f;
+
+   for (uint32_t i = 0; i < rounds; i++) {
+     vec_t<T, VEC_SIZE> input_vec;
+     input_vec.fill(0.f);
+     vec_t<T, VEC_SIZE> residual_vec;
+     residual_vec.fill(0.f);
+     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+       residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+     }
+ #pragma unroll
+     for (uint32_t j = 0; j < VEC_SIZE; j++) {
+       float x = float(input_vec[j]);
+       x += float(residual_vec[j]);
+       sum_sq += x * x;
+       residual_vec[j] = (T)x;
+     }
+     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+       residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+     }
+   }
+
+   // first, warp reduce sum
+ #pragma unroll
+   for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+     sum_sq += math::shfl_xor_sync(sum_sq, offset);
+   }
+
+   smem[ty] = sum_sq;
+   __syncthreads();
+   // then, cross warp reduce sum using only the first warp
+   if (ty == 0) {
+     sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
+ #pragma unroll
+     for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+       sum_sq += math::shfl_xor_sync(sum_sq, offset);
+     }
+     smem[0] = sum_sq;
+   }
+   __syncthreads();
+
+   float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
+
+   for (uint32_t i = 0; i < rounds; i++) {
+     vec_t<T, VEC_SIZE> input_vec;
+     vec_t<T, VEC_SIZE> weight_vec;
+     vec_t<T, VEC_SIZE> residual_vec;
+     input_vec.fill(0.f);
+     weight_vec.fill(0.f);
+     residual_vec.fill(0.f);
+     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+       weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+       residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+     }
+ #pragma unroll
+     for (uint32_t j = 0; j < VEC_SIZE; j++) {
+       input_vec[j] = float(residual_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
+     }
+     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+       input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+     }
+   }
+ }
+
+ template <typename T>
+ cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
+                                  float eps = 1e-5, cudaStream_t stream = 0) {
+   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
+   const uint32_t num_warps = ceil_div(block_size, 32);
+   dim3 nblks(batch_size);
+   dim3 nthrs(32, num_warps);
+   const uint32_t smem_size = num_warps * sizeof(float);
+   void* args[] = {&input, &residual, &weight, &d, &eps};
+
+   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+     auto kernel = GemmaFusedAddRMSNormKernel<VEC_SIZE, T>;
+     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+   });
+
+   return cudaSuccess;
+ }
+
} // namespace norm

} // namespace flashinfer
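Note on usage: the two new entry points mirror the existing RMSNorm/FusedAddRMSNorm wrappers; the Gemma variants differ only in applying the weight as (1 + w) rather than w. Below is a minimal host-side sketch of how GemmaRMSNorm might be invoked. It is not part of this commit; the header path flashinfer/norm.cuh, the fp16 element type, and the toy sizes are assumptions for illustration.

// Usage sketch (illustrative only, not part of this diff).
// Assumes GemmaRMSNorm is reachable via flashinfer/norm.cuh and that the device
// buffers are filled elsewhere; d is a multiple of 8 so the 16-byte vectorized
// path is taken for half-precision inputs.
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "flashinfer/norm.cuh"

int main() {
  const uint32_t batch_size = 4, d = 4096;
  half *input, *weight, *output;
  cudaMalloc(&input, batch_size * d * sizeof(half));
  cudaMalloc(&weight, d * sizeof(half));
  cudaMalloc(&output, batch_size * d * sizeof(half));
  // ... populate input and weight on the device ...
  cudaError_t status =
      flashinfer::norm::GemmaRMSNorm(input, weight, output, batch_size, d);
  if (status != cudaSuccess) {
    return 1;  // kernel launch or vec-size dispatch failed
  }
  cudaDeviceSynchronize();
  cudaFree(input);
  cudaFree(weight);
  cudaFree(output);
  return 0;
}

GemmaFusedAddRMSNorm is driven the same way, except that it takes (input, residual, weight, batch_size, d, ...) and, as in the kernel above, writes input + residual back into residual and the normalized result back into input.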