|
18 | 18 |
|
19 | 19 | #include <cmath>
|
20 | 20 | #include <cstdint>
|
| 21 | +#include <iostream> |
21 | 22 | #include <string>
|
22 | 23 |
|
23 | 24 | #include "layout.cuh"
|
@@ -156,6 +157,55 @@ __device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_cos_sin_i
|
156 | 157 | return vec;
|
157 | 158 | }
|
158 | 159 |
|
/*!
 * \brief Apply LLaMA-style rotary position embedding (RoPE) to query/key
 *   tensors, reading per-position cos/sin values from precomputed caches.
 *
 * Launch layout: blockDim = (bdx, bdy). Thread column tx owns vec_size
 * contiguous elements of one token's head dimension (head_dim == bdx *
 * vec_size), and block bx together with row ty selects the token
 * idx = bx * blockDim.y + ty.
 *
 * \note cos_cache/sin_cache rows are indexed as pos * head_dim, i.e. one
 *   head_dim-wide float row per position (layout inferred from the loads
 *   below — confirm against the cache producer).
 * \note q_rope/k_rope may alias q/k for in-place application, so the output
 *   pointers are deliberately NOT __restrict__-qualified; the read-only
 *   caches and position ids are const __restrict__ so the compiler may use
 *   the read-only data cache.
 */
template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
          typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheKernel(
    DType* q, DType* k, DType* q_rope, DType* k_rope, const float* __restrict__ cos_cache,
    const float* __restrict__ sin_cache, const IdType* __restrict__ pos_ids, uint32_t nnz,
    uint32_t num_qo_heads, uint32_t num_kv_heads, size_t q_stride_n, size_t q_stride_h,
    size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h,
    size_t k_rope_stride_n, size_t k_rope_stride_h) {
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  const uint32_t bdy = blockDim.y;

  vec_t<float, vec_size> cos, sin;
  // Tail guard: the last block may cover fewer than bdy tokens.
  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;
    const IdType pos = pos_ids[idx];

    // Each thread loads its vec_size-wide slice of the cos/sin row for `pos`.
    cos.load(cos_cache + pos * head_dim + tx * vec_size);
    sin.load(sin_cache + pos * head_dim + tx * vec_size);

    // Unrolling disabled: the head-count loops have runtime trip counts, so
    // partial unrolling would only inflate code size / register pressure.
#pragma unroll 1
    for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
      DType* q_rope_ptr =
          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
      vec_t<float, vec_size> q_vec;
      if constexpr (interleave) {
        q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin);
      } else {
        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin);
      }
      q_vec.cast_store(q_rope_ptr + tx * vec_size);
    }

#pragma unroll 1
    for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
      DType* k_rope_ptr =
          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
      vec_t<float, vec_size> k_vec;
      if constexpr (interleave) {
        k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin);
      } else {
        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin);
      }
      k_vec.cast_store(k_rope_ptr + tx * vec_size);
    }
  }
}
| 208 | + |
159 | 209 | template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
|
160 | 210 | typename IdType>
|
161 | 211 | __global__ void BatchQKApplyRotaryPosIdsKernel(
|
@@ -309,6 +359,48 @@ __global__ void BatchQKApplyRotaryKernel(
|
309 | 359 | __VA_ARGS__ \
|
310 | 360 | }
|
311 | 361 |
|
/*!
 * \brief Host-side launcher for BatchQKApplyRotaryPosIdsCosSinCacheKernel.
 *
 * Lifts the runtime `interleave` flag and `head_dim` value into compile-time
 * template parameters via the DISPATCH_* macros, derives the launch geometry
 * from HEAD_DIM and DType, and enqueues the kernel on `stream`.
 *
 * \return cudaSuccess on success; otherwise the error code of the failing
 *   CUDA call (FLASHINFER_CUDA_CALL returns early on failure).
 */
template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryPosIdsCosSinCache(
    DType* q, DType* k, DType* q_rope, DType* k_rope, float* cos_cache, float* sin_cache,
    IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
    size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
    size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h,
    bool interleave, cudaStream_t stream = nullptr) {
  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
      // 16 bytes per thread per load, but at least HEAD_DIM/32 elements so
      // that bdx (threads per token) never exceeds 32.
      constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
      constexpr uint32_t bdx = HEAD_DIM / vec_size;
      const uint32_t num_threads = std::max(128U, bdx);
      const uint32_t bdy = num_threads / bdx;       // tokens handled per block
      const dim3 nblks((nnz + bdy - 1) / bdy);      // ceil-div over tokens
      const dim3 nthrs(bdx, bdy);
      auto kernel = BatchQKApplyRotaryPosIdsCosSinCacheKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx,
                                                              DType, IdType>;
      // Order must match the kernel's parameter list exactly.
      void* args[] = {(void*)&q,
                      (void*)&k,
                      (void*)&q_rope,
                      (void*)&k_rope,
                      (void*)&cos_cache,
                      (void*)&sin_cache,
                      (void*)&pos_ids,
                      (void*)&nnz,
                      (void*)&num_qo_heads,
                      (void*)&num_kv_heads,
                      (void*)&q_stride_n,
                      (void*)&q_stride_h,
                      (void*)&k_stride_n,
                      (void*)&k_stride_h,
                      (void*)&q_rope_stride_n,
                      (void*)&q_rope_stride_h,
                      (void*)&k_rope_stride_n,
                      (void*)&k_rope_stride_h};
      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
    });
  });

  return cudaSuccess;
}
| 403 | + |
312 | 404 | template <typename DType, typename IdType>
|
313 | 405 | cudaError_t BatchQKApplyRotaryPosIds(DType* q, DType* k, DType* q_rope, DType* k_rope,
|
314 | 406 | IdType* __restrict__ pos_ids, uint32_t nnz,
|
|
0 commit comments