@@ -229,65 +229,61 @@ __global__ void BatchQKApplyRotaryPosIdsKernel(
229
229
float smooth_b, float rope_rcp_scale, float rope_rcp_theta) {
230
230
// NOTE: q and q_rope may be the same pointer (in-place update), and likewise k and k_rope
231
231
uint32_t bx = blockIdx .x , tx = threadIdx .x , ty = threadIdx .y ;
232
- const uint32_t bdy = blockDim .y ;
233
- vec_t <float , vec_size> freq;
232
+
233
+ const uint32_t idx = bx * blockDim .y + ty;
234
+ const uint32_t pos_idx = idx / (num_qo_heads + num_kv_heads);
235
+ if (pos_idx >= nnz) {
236
+ return ;
237
+ }
238
+
239
+ const IdType pos = pos_ids[pos_idx];
240
+
241
+ vec_t <float , vec_size> cos , sin ;
234
242
if (tx * vec_size < rotary_dim) {
235
- #pragma unroll
243
+ #pragma unroll
236
244
for (uint32_t i = 0 ; i < vec_size; ++i) {
245
+ float freq;
237
246
if constexpr (interleave) {
238
- freq[i] = __powf (rope_rcp_theta, float (2 * ((tx * vec_size + i) / 2 )) / float (rotary_dim));
247
+ freq = __powf (rope_rcp_theta, float (2 * ((tx * vec_size + i) / 2 )) / float (rotary_dim));
239
248
} else {
240
- freq[i] = __powf (rope_rcp_theta,
241
- float (2 * ((tx * vec_size + i) % (rotary_dim / 2 ))) / float (rotary_dim));
249
+ freq = __powf (rope_rcp_theta,
250
+ float (2 * ((tx * vec_size + i) % (rotary_dim / 2 ))) / float (rotary_dim));
242
251
}
243
252
244
- float smooth = freq[i] * smooth_a + smooth_b;
253
+ float smooth = freq * smooth_a + smooth_b;
245
254
smooth = max (0 .0f , min (1 .0f , smooth)); // clamp to [0, 1]
246
- freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
247
- }
248
- }
249
-
250
- vec_t <float , vec_size> cos , sin ;
255
+ freq = (1 - smooth) * (freq * rope_rcp_scale) + smooth * freq;
251
256
252
- if (bx * bdy + ty < nnz) {
253
- const uint32_t idx = bx * bdy + ty;
254
- const IdType pos = pos_ids[idx];
255
-
256
- if (tx * vec_size < rotary_dim) {
257
- #pragma unroll
258
- for (uint32_t i = 0 ; i < vec_size; ++i) {
259
- float embed = float (pos) * freq[i];
260
- __sincosf (embed, &sin [i], &cos [i]);
261
- }
257
+ const float embed = float (pos) * freq;
258
+ __sincosf (embed, &sin [i], &cos [i]);
262
259
}
260
+ }
263
261
264
- #pragma unroll 1
265
- for (uint32_t qo_head_idx = 0 ; qo_head_idx < num_qo_heads; ++qo_head_idx) {
266
- DType* q_ptr = q + get_elem_offset_impl (idx, qo_head_idx, 0 , q_stride_n, q_stride_h);
267
- DType* q_rope_ptr =
268
- q_rope + get_elem_offset_impl (idx, qo_head_idx, 0 , q_rope_stride_n, q_rope_stride_h);
269
- vec_t <float , vec_size> q_vec;
270
- if constexpr (interleave) {
271
- q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos , sin , rotary_dim);
272
- } else {
273
- q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos , sin , rotary_dim);
274
- }
275
- q_vec.cast_store (q_rope_ptr + tx * vec_size);
262
+ const uint32_t head_idx = idx % (num_qo_heads + num_kv_heads);
263
+ if (head_idx < num_qo_heads) {
264
+ const uint32_t qo_head_idx = head_idx;
265
+ DType* q_ptr = q + get_elem_offset_impl (pos_idx, qo_head_idx, 0 , q_stride_n, q_stride_h);
266
+ DType* q_rope_ptr =
267
+ q_rope + get_elem_offset_impl (pos_idx, qo_head_idx, 0 , q_rope_stride_n, q_rope_stride_h);
268
+ vec_t <float , vec_size> q_vec;
269
+ if constexpr (interleave) {
270
+ q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos , sin , rotary_dim);
271
+ } else {
272
+ q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos , sin , rotary_dim);
276
273
}
277
-
278
- #pragma unroll 1
279
- for (uint32_t kv_head_idx = 0 ; kv_head_idx < num_kv_heads; ++kv_head_idx) {
280
- DType* k_ptr = k + get_elem_offset_impl (idx, kv_head_idx, 0 , k_stride_n, k_stride_h);
281
- DType* k_rope_ptr =
282
- k_rope + get_elem_offset_impl (idx, kv_head_idx, 0 , k_rope_stride_n, k_rope_stride_h);
283
- vec_t <float , vec_size> k_vec;
284
- if constexpr (interleave) {
285
- k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos , sin , rotary_dim);
286
- } else {
287
- k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos , sin , rotary_dim);
288
- }
289
- k_vec.cast_store (k_rope_ptr + tx * vec_size);
274
+ q_vec.cast_store (q_rope_ptr + tx * vec_size);
275
+ } else {
276
+ const uint32_t kv_head_idx = head_idx - num_qo_heads;
277
+ DType* k_ptr = k + get_elem_offset_impl (pos_idx, kv_head_idx, 0 , k_stride_n, k_stride_h);
278
+ DType* k_rope_ptr =
279
+ k_rope + get_elem_offset_impl (pos_idx, kv_head_idx, 0 , k_rope_stride_n, k_rope_stride_h);
280
+ vec_t <float , vec_size> k_vec;
281
+ if constexpr (interleave) {
282
+ k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos , sin , rotary_dim);
283
+ } else {
284
+ k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos , sin , rotary_dim);
290
285
}
286
+ k_vec.cast_store (k_rope_ptr + tx * vec_size);
291
287
}
292
288
}
293
289
@@ -610,7 +606,7 @@ cudaError_t BatchQKApplyLlama31RotaryPosIds(
610
606
constexpr uint32_t bdx = HEAD_DIM / vec_size;
611
607
uint32_t num_threads = std::max (128U , bdx);
612
608
uint32_t bdy = num_threads / bdx;
613
- dim3 nblks ((nnz + bdy - 1 ) / bdy);
609
+ dim3 nblks ((nnz + bdy - 1 ) / bdy * (num_qo_heads + num_kv_heads) );
614
610
dim3 nthrs (bdx, bdy);
615
611
auto kernel =
616
612
BatchQKApplyRotaryPosIdsKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
0 commit comments