Commit 3e104bc

feat: torch custom_op fix for rope (#569)
Fix after the changes made in #568: torch.compile doesn't like returning input arguments, so the return type of the pybind functions is changed to `void`, given that the op is already in-place. A PyTorch Library annotation is applied to a wrapper function for each pybind op. The Python API doesn't change: both the in-place and non-in-place versions call the annotated wrapper function.
1 parent 4f40420 commit 3e104bc
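
The pattern the commit message describes can be sketched in Python. Below is a minimal, hypothetical version of the annotated wrapper for one op, assuming PyTorch's `torch.library.custom_op` decorator (PyTorch 2.4+) and a pybind extension handle named `_kernels`; the names and the exact registration mechanism are illustrative, not necessarily what the commit ships.

import torch

# Hypothetical wrapper: the pybind op mutates q_rope/k_rope in place and
# returns void, so the wrapper declares the mutation and returns None.
@torch.library.custom_op("flashinfer::apply_rope",
                         mutates_args=("q_rope", "k_rope"))
def _apply_rope(q: torch.Tensor, k: torch.Tensor, q_rope: torch.Tensor,
                k_rope: torch.Tensor, indptr: torch.Tensor,
                offsets: torch.Tensor, interleave: bool,
                rope_scale: float, rope_theta: float) -> None:
    # No tensor is returned here, so torch.compile never sees an input
    # argument flowing back out as a return value.
    _kernels.apply_rope(q, k, q_rope, k_rope, indptr, offsets,
                        interleave, rope_scale, rope_theta)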

File tree

5 files changed: +346 −147

flashinfer-aot/csrc_aot/flashinfer_ops.cu

+17-23
@@ -83,29 +83,23 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
-                                      torch::Tensor k_rope, torch::Tensor indptr,
-                                      torch::Tensor offsets, bool interleave, float rope_scale,
-                                      float rope_theta);
-
-std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor indptr, torch::Tensor offsets,
-                                              bool interleave, float rope_scale, float rope_theta,
-                                              float low_freq_factor, float high_freq_factor,
-                                              float old_context_length);
-
-std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor pos_ids, bool interleave,
-                                              float rope_scale, float rope_theta);
-
-std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                                      torch::Tensor q_rope, torch::Tensor k_rope,
-                                                      torch::Tensor pos_ids, bool interleave,
-                                                      float rope_scale, float rope_theta,
-                                                      float low_freq_factor, float high_freq_factor,
-                                                      float old_context_length);
+void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
+                torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
+                float rope_theta);
+
+void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length);
+
+void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                        float rope_scale, float rope_theta);
+
+void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float high_freq_factor, float old_context_length);
 
 torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);
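
With the declarations above now returning `void`, a hedged smoke test of the wrapper sketched earlier shows the intended torch.compile behavior: the outputs are preallocated and returned from Python code rather than from the custom op itself. Names continue the earlier sketch.

def run(q, k, indptr, offsets):
    # Allocate outputs outside the op; the op only fills them in.
    q_rope, k_rope = torch.empty_like(q), torch.empty_like(k)
    _apply_rope(q, k, q_rope, k_rope, indptr, offsets, False, 1.0, 1e4)
    return q_rope, k_rope

compiled = torch.compile(run)  # the op returns None, not an input tensor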

python/csrc/flashinfer_rope_ops.cu

+14-20
@@ -17,29 +17,23 @@
 
 #include <vector>
 
-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
-                                      torch::Tensor k_rope, torch::Tensor indptr,
-                                      torch::Tensor offsets, bool interleave, float rope_scale,
-                                      float rope_theta);
+void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
+                torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
+                float rope_theta);
 
-std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor indptr, torch::Tensor offsets,
-                                              bool interleave, float rope_scale, float rope_theta,
-                                              float low_freq_factor, float high_freq_factor,
-                                              float old_context_length);
+void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length);
 
-std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor pos_ids, bool interleave,
-                                              float rope_scale, float rope_theta);
+void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                        float rope_scale, float rope_theta);
 
-std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                                      torch::Tensor q_rope, torch::Tensor k_rope,
-                                                      torch::Tensor pos_ids, bool interleave,
-                                                      float rope_scale, float rope_theta,
-                                                      float low_freq_factor, float high_freq_factor,
-                                                      float old_context_length);
+void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float high_freq_factor, float old_context_length);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("apply_rope", &apply_rope, "Apply RoPE");

python/csrc/rope.cu

+14-28
@@ -19,10 +19,9 @@
 
 using namespace flashinfer;
 
-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
-                                      torch::Tensor k_rope, torch::Tensor indptr,
-                                      torch::Tensor offsets, bool interleave, float rope_scale,
-                                      float rope_theta) {
+void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
+                torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
+                float rope_theta) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(indptr);
@@ -65,14 +64,11 @@ std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::T
                 std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
 
-std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor pos_ids, bool interleave,
-                                              float rope_scale, float rope_theta) {
+void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                        float rope_scale, float rope_theta) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(pos_ids);
@@ -109,16 +105,12 @@ std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
                 std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
 
-std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor indptr, torch::Tensor offsets,
-                                              bool interleave, float rope_scale, float rope_theta,
-                                              float low_freq_factor, float high_freq_factor,
-                                              float old_context_length) {
+void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(indptr);
@@ -162,16 +154,12 @@ std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
                 std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
 
-std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                                      torch::Tensor q_rope, torch::Tensor k_rope,
-                                                      torch::Tensor pos_ids, bool interleave,
-                                                      float rope_scale, float rope_theta,
-                                                      float low_freq_factor, float high_freq_factor,
-                                                      float old_context_length) {
+void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float high_freq_factor, float old_context_length) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(pos_ids);
@@ -209,6 +197,4 @@ std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Te
                 std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
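
Since the kernels no longer return tensors, the unchanged Python API can be pictured as two thin entry points over the annotated wrapper, per the commit message ("both inplace and non-inplace versions call the annotated wrapper function"). A hedged sketch: the `apply_rope`/`apply_rope_inplace` names and signatures are assumed for illustration, not copied from flashinfer.

def apply_rope_inplace(q, k, indptr, offsets, interleave=False,
                       rope_scale=1.0, rope_theta=1e4):
    # In-place variant: the outputs alias the inputs; nothing is returned.
    _apply_rope(q, k, q, k, indptr, offsets, interleave,
                rope_scale, rope_theta)

def apply_rope(q, k, indptr, offsets, interleave=False,
               rope_scale=1.0, rope_theta=1e4):
    # Non-in-place variant: fresh outputs, same wrapper underneath.
    q_rope, k_rope = torch.empty_like(q), torch.empty_like(k)
    _apply_rope(q, k, q_rope, k_rope, indptr, offsets, interleave,
                rope_scale, rope_theta)
    return q_rope, k_rope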
