
Commit 83e541d

feat: support cached cos/sin in rope APIs (#585)
As requested in #530, this PR implements RoPE with cached cos/sin embeddings, which is more flexible in some use cases. In our previous RoPE implementations, cos/sin values are computed on the fly inside the kernels in float32 rather than read from a cache. While working on this PR we found that an f16 cos/sin cache produces results with a large discrepancy from our original implementation `flashinfer.apply_rope` (which computes cos/sin in fp32), so we require `cos_cache` and `sin_cache` to use the fp32 data type. cc @dreaming-panda @ByronHsu
1 parent 7557dc8 commit 83e541d
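
For context, here is a minimal usage sketch of the new API. The exact Python-level signature is not shown on this page, so the argument order below simply mirrors the C++ binding (q, k, cos_cache, sin_cache, pos_ids, interleave) and should be treated as an assumption; shapes and constants are arbitrary.

import torch
import flashinfer

nnz, num_qo_heads, num_kv_heads, head_dim = 128, 32, 8, 128
max_seq_len, rope_theta = 4096, 1e4

# Build the cache once, in float32: the binding rejects f16 caches because
# they were found to drift too far from flashinfer.apply_rope.
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(max_seq_len).float(), inv_freq)  # (max_seq_len, D/2)
angles = torch.cat([angles, angles], dim=-1).to("cuda")            # (max_seq_len, D)
cos_cache, sin_cache = angles.cos(), angles.sin()                  # both float32

q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
pos_ids = torch.arange(nnz, dtype=torch.int32, device="cuda")

q_rope, k_rope = flashinfer.apply_rope_with_cos_sin_cache(
    q, k, cos_cache, sin_cache, pos_ids, interleave=False
)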

File tree

8 files changed: +465 -208 lines changed

benchmarks/bench_append_paged_kv_cache.py (+2 -1)
@@ -2,10 +2,11 @@
 import dataclasses
 from typing import cast
 
-import flashinfer
 import torch
 from triton.testing import do_bench
 
+import flashinfer
+
 
 @dataclasses.dataclass(kw_only=True)
 class ModelConfig:

include/flashinfer/pos_enc.cuh (+92)
@@ -18,6 +18,7 @@
 
 #include <cmath>
 #include <cstdint>
+#include <iostream>
 #include <string>
 
 #include "layout.cuh"
@@ -156,6 +157,55 @@ __device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_cos_sin_i
   return vec;
 }
 
+template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
+          typename IdType>
+__global__ void BatchQKApplyRotaryPosIdsCosSinCacheKernel(
+    DType* q, DType* k, DType* q_rope, DType* k_rope, float* __restrict__ cos_cache,
+    float* __restrict__ sin_cache, IdType* __restrict__ pos_ids, uint32_t nnz,
+    uint32_t num_qo_heads, uint32_t num_kv_heads, size_t q_stride_n, size_t q_stride_h,
+    size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h,
+    size_t k_rope_stride_n, size_t k_rope_stride_h) {
+  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
+  const uint32_t bdy = blockDim.y;
+
+  vec_t<float, vec_size> cos, sin;
+  if (bx * bdy + ty < nnz) {
+    const uint32_t idx = bx * bdy + ty;
+    const IdType pos = pos_ids[idx];
+
+    cos.load(cos_cache + pos * head_dim + tx * vec_size);
+    sin.load(sin_cache + pos * head_dim + tx * vec_size);
+
+#pragma unroll 1
+    for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
+      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
+      DType* q_rope_ptr =
+          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
+      vec_t<float, vec_size> q_vec;
+      if constexpr (interleave) {
+        q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin);
+      } else {
+        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin);
+      }
+      q_vec.cast_store(q_rope_ptr + tx * vec_size);
+    }
+
+#pragma unroll 1
+    for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
+      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
+      DType* k_rope_ptr =
+          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
+      vec_t<float, vec_size> k_vec;
+      if constexpr (interleave) {
+        k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin);
+      } else {
+        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin);
+      }
+      k_vec.cast_store(k_rope_ptr + tx * vec_size);
+    }
+  }
+}
+
 template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
           typename IdType>
 __global__ void BatchQKApplyRotaryPosIdsKernel(
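
The structure of the new kernel is the core of the change: each (block, ty) pair owns one token, loads that token's cos/sin rows from the fp32 cache exactly once, then reuses them across all query and key heads. As a hedged reference for what vec_apply_llama_rope_cos_sin / vec_apply_llama_rope_cos_sin_interleave compute per head, here is a NumPy sketch; the helper below is illustrative, not a flashinfer API, and it assumes the cache rows are laid out to match each mode (each angle repeated per pair for interleave, halves duplicated otherwise).

import numpy as np

def apply_rope_cos_sin(x, cos, sin, interleave):
    # x, cos, sin: (head_dim,) vectors; cos/sin already indexed by the
    # token's position, as the kernel does with pos_ids[idx].
    d = x.shape[0]
    if interleave:
        # GPT-J style: rotate adjacent pairs (x0, x1), (x2, x3), ...
        x_rot = np.empty_like(x)
        x_rot[0::2] = -x[1::2]
        x_rot[1::2] = x[0::2]
    else:
        # Neox / "rotate half" style: pair element i with i + d/2.
        x_rot = np.concatenate([-x[d // 2:], x[:d // 2]])
    return x * cos + x_rot * sin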
@@ -309,6 +359,48 @@ __global__ void BatchQKApplyRotaryKernel(
     __VA_ARGS__ \
   }
 
+template <typename DType, typename IdType>
+cudaError_t BatchQKApplyRotaryPosIdsCosSinCache(
+    DType* q, DType* k, DType* q_rope, DType* k_rope, float* cos_cache, float* sin_cache,
+    IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
+    size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
+    size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h,
+    bool interleave, cudaStream_t stream = nullptr) {
+  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
+      constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
+      constexpr uint32_t bdx = HEAD_DIM / vec_size;
+      uint32_t num_threads = std::max(128U, bdx);
+      uint32_t bdy = num_threads / bdx;
+      dim3 nblks((nnz + bdy - 1) / bdy);
+      dim3 nthrs(bdx, bdy);
+      auto kernel = BatchQKApplyRotaryPosIdsCosSinCacheKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx,
+                                                              DType, IdType>;
+      void* args[] = {(void*)&q,
+                      (void*)&k,
+                      (void*)&q_rope,
+                      (void*)&k_rope,
+                      (void*)&cos_cache,
+                      (void*)&sin_cache,
+                      (void*)&pos_ids,
+                      (void*)&nnz,
+                      (void*)&num_qo_heads,
+                      (void*)&num_kv_heads,
+                      (void*)&q_stride_n,
+                      (void*)&q_stride_h,
+                      (void*)&k_stride_n,
+                      (void*)&k_stride_h,
+                      (void*)&q_rope_stride_n,
+                      (void*)&q_rope_stride_h,
+                      (void*)&k_rope_stride_n,
+                      (void*)&k_rope_stride_h};
+      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
+    });
+  });
+
+  return cudaSuccess;
+}
+
 template <typename DType, typename IdType>
 cudaError_t BatchQKApplyRotaryPosIds(DType* q, DType* k, DType* q_rope, DType* k_rope,
                                      IdType* __restrict__ pos_ids, uint32_t nnz,
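
The dispatch above picks the launch shape from the head dimension and element size. Worked out in Python for head_dim=128 and fp16 (sizeof(DType) == 2), as an illustration:

head_dim, sizeof_dtype, nnz = 128, 2, 1000
vec_size = max(16 // sizeof_dtype, head_dim // 32)  # = 8 elements per 16-byte load
bdx = head_dim // vec_size                          # = 16 threads span one head
num_threads = max(128, bdx)                         # at least 128 threads per block
bdy = num_threads // bdx                            # = 8 tokens handled per block
nblks = (nnz + bdy - 1) // bdy                      # = 125 blocks for 1000 tokens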

python/csrc/flashinfer_rope_ops.cu (+7)
@@ -35,10 +35,17 @@ void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor
                                 float rope_scale, float rope_theta, float low_freq_factor,
                                 float high_freq_factor, float old_context_length);
 
+void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                      torch::Tensor k_rope, torch::Tensor cos_cache,
+                                      torch::Tensor sin_cache, torch::Tensor pos_ids,
+                                      bool interleave);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("apply_rope", &apply_rope, "Apply RoPE");
   m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
   m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids");
   m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids,
         "Apply Llama 3.1 style RoPE with positional ids");
+  m.def("apply_rope_pos_ids_cos_sin_cache", &apply_rope_pos_ids_cos_sin_cache,
+        "Apply RoPE with positional ids and cosine/sine cache");
 }
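
Once registered, the raw binding is callable from Python. The sketch below assumes the extension module is built as flashinfer._kernels (the module name is not shown in this diff); note the raw op is out-of-place and writes into preallocated q_rope/k_rope tensors, and the identity-valued cache here is only a placeholder.

import torch
from flashinfer import _kernels  # extension name is an assumption

nnz, h_q, h_k, d = 4, 32, 8, 128
q = torch.randn(nnz, h_q, d, dtype=torch.float16, device="cuda")
k = torch.randn(nnz, h_k, d, dtype=torch.float16, device="cuda")
cos_cache = torch.ones(1024, d, dtype=torch.float32, device="cuda")   # cos=1, sin=0:
sin_cache = torch.zeros(1024, d, dtype=torch.float32, device="cuda")  # identity rotation
pos_ids = torch.arange(nnz, dtype=torch.int32, device="cuda")

q_rope, k_rope = torch.empty_like(q), torch.empty_like(k)  # outputs preallocated
_kernels.apply_rope_pos_ids_cos_sin_cache(
    q, k, q_rope, k_rope, cos_cache, sin_cache, pos_ids, False)  # interleave=False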

python/csrc/rope.cu (+58 -4)
@@ -22,8 +22,8 @@ using namespace flashinfer;
 void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
                 torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
                 float rope_theta) {
-  CHECK_CUDA(q);  // not necessarily contiguous
-  CHECK_CUDA(k);  // not necessarily contiguous
+  CHECK_LAST_DIM_CONTIGUOUS(q);
+  CHECK_LAST_DIM_CONTIGUOUS(k);
   CHECK_INPUT(indptr);
   CHECK_INPUT(offsets);
 
@@ -69,8 +69,8 @@ void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::T
 void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
                         torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
                         float rope_scale, float rope_theta) {
-  CHECK_CUDA(q);  // not necessarily contiguous
-  CHECK_CUDA(k);  // not necessarily contiguous
+  CHECK_LAST_DIM_CONTIGUOUS(q);
+  CHECK_LAST_DIM_CONTIGUOUS(k);
   CHECK_INPUT(pos_ids);
 
   auto device = q.device();
@@ -107,6 +107,60 @@ void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
   });
 }
 
+void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                      torch::Tensor k_rope, torch::Tensor cos_cache,
+                                      torch::Tensor sin_cache, torch::Tensor pos_ids,
+                                      bool interleave) {
+  CHECK_LAST_DIM_CONTIGUOUS(q);
+  CHECK_LAST_DIM_CONTIGUOUS(k);
+  CHECK_INPUT(cos_cache);
+  CHECK_INPUT(sin_cache);
+  CHECK_INPUT(pos_ids);
+  auto device = q.device();
+  CHECK_EQ(k.device(), device);
+  CHECK_EQ(cos_cache.device(), device);
+  CHECK_EQ(sin_cache.device(), device);
+  CHECK_EQ(pos_ids.device(), device);
+  CHECK_DIM(3, q);          // q: (nnz, H_Q, D)
+  CHECK_DIM(3, k);          // k: (nnz, H_K, D)
+  CHECK_DIM(2, cos_cache);  // cos_cache: (max_seq_len, D)
+  CHECK_DIM(2, sin_cache);  // sin_cache: (max_seq_len, D)
+  CHECK_EQ(q.size(0), k.size(0));
+  CHECK_EQ(q.size(2), k.size(2));
+  CHECK_EQ(cos_cache.size(1), q.size(2));
+  CHECK_EQ(sin_cache.size(1), q.size(2));
+  CHECK_EQ(cos_cache.dtype(), torch::kFloat32);
+  CHECK_EQ(sin_cache.dtype(), torch::kFloat32);
+  unsigned int num_qo_heads = q.size(1);
+  unsigned int num_kv_heads = k.size(1);
+  unsigned int head_dim = q.size(2);
+  unsigned int nnz = q.size(0);
+  size_t q_stride_n = q.stride(0);
+  size_t q_stride_h = q.stride(1);
+  size_t k_stride_n = k.stride(0);
+  size_t k_stride_h = k.stride(1);
+  size_t q_rope_stride_n = q_rope.stride(0);
+  size_t q_rope_stride_h = q_rope.stride(1);
+  size_t k_rope_stride_n = k_rope.stride(0);
+  size_t k_rope_stride_h = k_rope.stride(1);
+  pos_ids = pos_ids.to(torch::kInt32);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
+    cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
+        static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
+        static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()),
+        static_cast<float*>(cos_cache.data_ptr()), static_cast<float*>(sin_cache.data_ptr()),
+        static_cast<int32_t*>(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, head_dim,
+        q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h,
+        k_rope_stride_n, k_rope_stride_h, interleave, torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess,
+                "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " +
+                    std::string(cudaGetErrorString(status)));
+    return true;
+  });
+}
+
 void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
                         torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
                         bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
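
Restating the guards above as a Python-side mirror can make the tensor contract easier to scan; the function below is illustrative, not part of flashinfer.

import torch

def check_rope_cache_inputs(q, k, cos_cache, sin_cache):
    # Mirrors the CHECK_* guards in apply_rope_pos_ids_cos_sin_cache.
    assert q.dim() == 3 and k.dim() == 3                  # (nnz, H_Q, D) / (nnz, H_K, D)
    assert q.size(0) == k.size(0) and q.size(2) == k.size(2)
    assert cos_cache.dim() == 2 and sin_cache.dim() == 2  # (max_seq_len, D)
    assert cos_cache.size(1) == q.size(2) and sin_cache.size(1) == q.size(2)
    assert cos_cache.dtype == torch.float32 and sin_cache.dtype == torch.float32
    assert q.stride(2) == 1 and k.stride(2) == 1          # only the last dim must be contiguous
    # pos_ids is cast to int32 inside the op, so any integer dtype is accepted.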

python/flashinfer/__init__.py (+8)
@@ -60,10 +60,18 @@
 from .quantization import segment_packbits as segment_packbits
 from .rope import apply_llama31_rope as apply_llama31_rope
 from .rope import apply_llama31_rope_inplace as apply_llama31_rope_inplace
+from .rope import apply_llama31_rope_pos_ids as apply_llama31_rope_pos_ids
+from .rope import (
+    apply_llama31_rope_pos_ids_inplace as apply_llama31_rope_pos_ids_inplace,
+)
 from .rope import apply_rope as apply_rope
 from .rope import apply_rope_inplace as apply_rope_inplace
 from .rope import apply_rope_pos_ids as apply_rope_pos_ids
 from .rope import apply_rope_pos_ids_inplace as apply_rope_pos_ids_inplace
+from .rope import apply_rope_with_cos_sin_cache as apply_rope_with_cos_sin_cache
+from .rope import (
+    apply_rope_with_cos_sin_cache_inplace as apply_rope_with_cos_sin_cache_inplace,
+)
 from .sampling import chain_speculative_sampling as chain_speculative_sampling
 from .sampling import min_p_sampling_from_probs as min_p_sampling_from_probs
 from .sampling import sampling_from_probs as sampling_from_probs