
Commit 4f40420

feat: support huggingface transformer style rope interface (#568)
Previously, our RoPE APIs assumed that the position indices of each request are contiguous, which is not appropriate for applications such as speculative decoding. This PR fixes the issue by supporting the HuggingFace transformers-style interface, which uses a `pos_ids` argument to specify per-token positions. This PR implements part of the feature set of #530; the remaining pieces will come in later PRs. cc @dreaming-panda @abcdabcd987 @ByronHsu
1 parent cdc12c3 · commit 4f40420
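For context, here is a minimal sketch (all names and shapes are assumptions for illustration, not part of this commit) of why explicit positions are needed: in tree-style speculative decoding, several draft tokens can sit at the same tree depth, so their RoPE positions repeat and cannot be described as one contiguous range per request, which is all the old indptr/offsets interface could express.

#include <torch/torch.h>

// Hypothetical helper: RoPE positions for a draft tree
//   root -> {child_a, child_b}, child_a -> grandchild.
// Token order in the batch: [root, child_a, child_b, grandchild].
// Each token's position is its depth in the tree plus the prompt length;
// child_a and child_b both sit at depth 1, so positions repeat and the
// sequence {L, L+1, L+1, L+2} is not a contiguous range.
torch::Tensor tree_draft_pos_ids(int64_t prompt_len) {
  auto depths = torch::tensor({0, 1, 1, 2}, torch::kInt64);
  return (depths + prompt_len).to(torch::kCUDA);
}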

File tree

7 files changed: +521 -211 lines


flashinfer-aot/csrc_aot/flashinfer_ops.cu (+23 -16)
@@ -61,10 +61,11 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
 torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
                                 unsigned int top_k_val);
 
-torch::Tensor chain_speculative_sampling(
-    torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
-    torch::Tensor target_probs, torch::Tensor output_accepted_token_num,
-    torch::Tensor output_emitted_token_num, bool deterministic);
+torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
+                                         torch::Tensor uniform_samples, torch::Tensor target_probs,
+                                         torch::Tensor output_accepted_token_num,
+                                         torch::Tensor output_emitted_token_num,
+                                         bool deterministic);
 
 void rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps);
 
@@ -82,24 +83,30 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
-void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
-                        torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);
-
-void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
-                                torch::Tensor offsets, bool interleave, float rope_scale,
-                                float rope_theta, float low_freq_factor, float high_freq_factor,
-                                float old_context_length);
-
-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
+std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                      torch::Tensor k_rope, torch::Tensor indptr,
                                       torch::Tensor offsets, bool interleave, float rope_scale,
                                       float rope_theta);
 
 std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
+                                              torch::Tensor q_rope, torch::Tensor k_rope,
                                               torch::Tensor indptr, torch::Tensor offsets,
                                               bool interleave, float rope_scale, float rope_theta,
                                               float low_freq_factor, float high_freq_factor,
                                               float old_context_length);
 
+std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
+                                              torch::Tensor q_rope, torch::Tensor k_rope,
+                                              torch::Tensor pos_ids, bool interleave,
+                                              float rope_scale, float rope_theta);
+
+std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
+                                                      torch::Tensor q_rope, torch::Tensor k_rope,
+                                                      torch::Tensor pos_ids, bool interleave,
+                                                      float rope_scale, float rope_theta,
+                                                      float low_freq_factor, float high_freq_factor,
+                                                      float old_context_length);
+
 torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);
 
 torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
@@ -141,11 +148,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
   m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
   m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul");
-  m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
-  m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
-        "Apply Llama 3.1 style RoPE in-place");
   m.def("apply_rope", &apply_rope, "Apply RoPE");
   m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
+  m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids");
+  m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids,
+        "Apply Llama 3.1 style RoPE with positional ids");
   m.def("packbits", &packbits, "GPU packbits operator");
   m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
   m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator");
