Commit ff05155

perf: improve parallelism in RoPE with pos_ids (#609)
The previous kernel was not sufficiently parallelized at low batch sizes. As in the regular rotary kernel, all qo/kv heads are now split across separate blocks, which makes the pos_ids kernel faster in decode mode.
1 parent 32d9510 commit ff05155
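
To make the launch math concrete, here is a worked sketch of the old and new grid sizes for a single decoded token. The head counts and block shape below are assumed for illustration; they are not taken from the commit.

#include <cstdint>
#include <cstdio>

// Old vs. new grid size for one decoded token (all values assumed):
// a GQA model with 32 qo heads + 8 kv heads, and bdy = 8 thread rows per block.
int main() {
  const uint32_t nnz = 1;            // decode mode: one token position
  const uint32_t num_qo_heads = 32;  // assumed
  const uint32_t num_kv_heads = 8;   // assumed
  const uint32_t bdy = 8;            // assumed value of num_threads / bdx

  // Old launch: one work item per position; all heads looped over serially.
  const uint32_t old_nblks = (nnz + bdy - 1) / bdy;
  // New launch: one work item per (position, head) pair.
  const uint32_t new_nblks = (nnz + bdy - 1) / bdy * (num_qo_heads + num_kv_heads);

  std::printf("old: %u block(s), new: %u block(s)\n", old_nblks, new_nblks);  // 1 vs. 40
}

With one token, the old kernel keeps a single block busy and walks 40 heads serially; the new kernel spreads those heads across 40 blocks.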

File tree

1 file changed: +44 -48
include/flashinfer/pos_enc.cuh

@@ -229,65 +229,61 @@ __global__ void BatchQKApplyRotaryPosIdsKernel(
     float smooth_b, float rope_rcp_scale, float rope_rcp_theta) {
   // NOTE: q and q_rope may be the same ptr, so do k and k_rope
   uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
-  const uint32_t bdy = blockDim.y;
-  vec_t<float, vec_size> freq;
+
+  const uint32_t idx = bx * blockDim.y + ty;
+  const uint32_t pos_idx = idx / (num_qo_heads + num_kv_heads);
+  if (pos_idx >= nnz) {
+    return;
+  }
+
+  const IdType pos = pos_ids[pos_idx];
+
+  vec_t<float, vec_size> cos, sin;
   if (tx * vec_size < rotary_dim) {
-#pragma unroll
+#pragma unroll
     for (uint32_t i = 0; i < vec_size; ++i) {
+      float freq;
       if constexpr (interleave) {
-        freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim));
+        freq = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim));
       } else {
-        freq[i] = __powf(rope_rcp_theta,
-                         float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim));
+        freq = __powf(rope_rcp_theta,
+                      float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim));
       }
 
-      float smooth = freq[i] * smooth_a + smooth_b;
+      float smooth = freq * smooth_a + smooth_b;
       smooth = max(0.0f, min(1.0f, smooth));  // clamp to [0, 1]
-      freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
-    }
-  }
-
-  vec_t<float, vec_size> cos, sin;
+      freq = (1 - smooth) * (freq * rope_rcp_scale) + smooth * freq;
 
-  if (bx * bdy + ty < nnz) {
-    const uint32_t idx = bx * bdy + ty;
-    const IdType pos = pos_ids[idx];
-
-    if (tx * vec_size < rotary_dim) {
-#pragma unroll
-      for (uint32_t i = 0; i < vec_size; ++i) {
-        float embed = float(pos) * freq[i];
-        __sincosf(embed, &sin[i], &cos[i]);
-      }
+      const float embed = float(pos) * freq;
+      __sincosf(embed, &sin[i], &cos[i]);
     }
+  }
 
-#pragma unroll 1
-    for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
-      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
-      DType* q_rope_ptr =
-          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
-      vec_t<float, vec_size> q_vec;
-      if constexpr (interleave) {
-        q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
-      } else {
-        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
-      }
-      q_vec.cast_store(q_rope_ptr + tx * vec_size);
+  const uint32_t head_idx = idx % (num_qo_heads + num_kv_heads);
+  if (head_idx < num_qo_heads) {
+    const uint32_t qo_head_idx = head_idx;
+    DType* q_ptr = q + get_elem_offset_impl(pos_idx, qo_head_idx, 0, q_stride_n, q_stride_h);
+    DType* q_rope_ptr =
+        q_rope + get_elem_offset_impl(pos_idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
+    vec_t<float, vec_size> q_vec;
+    if constexpr (interleave) {
+      q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+    } else {
+      q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
     }
-
-#pragma unroll 1
-    for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
-      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
-      DType* k_rope_ptr =
-          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
-      vec_t<float, vec_size> k_vec;
-      if constexpr (interleave) {
-        k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
-      } else {
-        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
-      }
-      k_vec.cast_store(k_rope_ptr + tx * vec_size);
+    q_vec.cast_store(q_rope_ptr + tx * vec_size);
+  } else {
+    const uint32_t kv_head_idx = head_idx - num_qo_heads;
+    DType* k_ptr = k + get_elem_offset_impl(pos_idx, kv_head_idx, 0, k_stride_n, k_stride_h);
+    DType* k_rope_ptr =
+        k_rope + get_elem_offset_impl(pos_idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
+    vec_t<float, vec_size> k_vec;
+    if constexpr (interleave) {
+      k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+    } else {
+      k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
     }
+    k_vec.cast_store(k_rope_ptr + tx * vec_size);
   }
 }

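The first hunk contains the heart of the change: the flat work index is decomposed into a (position, head) pair, so each row of threads handles one head at one position instead of looping over every head. Below is a minimal standalone sketch of that decomposition pattern; the kernel and all names are illustrative (it scales each head's vector rather than applying RoPE) and are not part of the flashinfer API.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Illustrative kernel: one (position, head) pair per (blockIdx.x, threadIdx.y) row.
__global__ void PerPosHeadKernel(float* data, uint32_t nnz, uint32_t num_heads,
                                 uint32_t head_dim) {
  const uint32_t idx = blockIdx.x * blockDim.y + threadIdx.y;
  const uint32_t pos_idx = idx / num_heads;   // which token position
  const uint32_t head_idx = idx % num_heads;  // which head at that position
  if (pos_idx >= nnz) return;                 // grid may be over-provisioned

  // threadIdx.x strides across the head dimension, mirroring tx in the real kernel.
  float* row = data + (pos_idx * num_heads + head_idx) * head_dim;
  for (uint32_t d = threadIdx.x; d < head_dim; d += blockDim.x) {
    row[d] *= 2.0f;  // stand-in for the rotary transform
  }
}

int main() {
  const uint32_t nnz = 3, num_heads = 40, head_dim = 128;  // assumed shapes
  float* d;
  cudaMalloc(&d, nnz * num_heads * head_dim * sizeof(float));
  cudaMemset(d, 0, nnz * num_heads * head_dim * sizeof(float));

  // Same launch shape as the commit: bdy thread rows per block, grid scaled by head count.
  const dim3 nthrs(32, 4);  // bdx = 32, bdy = 4
  const dim3 nblks((nnz + nthrs.y - 1) / nthrs.y * num_heads);
  PerPosHeadKernel<<<nblks, nthrs>>>(d, nnz, num_heads, head_dim);
  cudaDeviceSynchronize();
  cudaFree(d);
  std::printf("done\n");
  return 0;
}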
@@ -610,7 +606,7 @@ cudaError_t BatchQKApplyLlama31RotaryPosIds(
     constexpr uint32_t bdx = HEAD_DIM / vec_size;
     uint32_t num_threads = std::max(128U, bdx);
     uint32_t bdy = num_threads / bdx;
-    dim3 nblks((nnz + bdy - 1) / bdy);
+    dim3 nblks((nnz + bdy - 1) / bdy * (num_qo_heads + num_kv_heads));
     dim3 nthrs(bdx, bdy);
     auto kernel =
         BatchQKApplyRotaryPosIdsKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
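
One subtlety in the new launch above: ceil(nnz / bdy) * (num_qo_heads + num_kv_heads) blocks can slightly exceed the minimum ceil(nnz * (num_qo_heads + num_kv_heads) / bdy) whenever bdy does not divide nnz. The surplus work items decode to pos_idx >= nnz and exit through the early return at the top of the kernel, so the over-provisioning is harmless.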
