@@ -19,10 +19,9 @@
 
 using namespace flashinfer;
 
-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
-                                      torch::Tensor k_rope, torch::Tensor indptr,
-                                      torch::Tensor offsets, bool interleave, float rope_scale,
-                                      float rope_theta) {
+void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
+                torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
+                float rope_theta) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(indptr);
@@ -65,14 +64,11 @@ std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::T
                         std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
 
-std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor pos_ids, bool interleave,
-                                              float rope_scale, float rope_theta) {
+void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                        float rope_scale, float rope_theta) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(pos_ids);
@@ -109,16 +105,12 @@ std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
                         std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
 
-std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor indptr, torch::Tensor offsets,
-                                              bool interleave, float rope_scale, float rope_theta,
-                                              float low_freq_factor, float high_freq_factor,
-                                              float old_context_length) {
+void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(indptr);
@@ -162,16 +154,12 @@ std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
                         std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
 
-std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                                      torch::Tensor q_rope, torch::Tensor k_rope,
-                                                      torch::Tensor pos_ids, bool interleave,
-                                                      float rope_scale, float rope_theta,
-                                                      float low_freq_factor, float high_freq_factor,
-                                                      float old_context_length) {
+void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float high_freq_factor, float old_context_length) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(pos_ids);
@@ -209,6 +197,4 @@ std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Te
                         std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
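The change is the same across all four bindings: `q_rope` and `k_rope` were already caller-provided output buffers, so returning `{q_rope, k_rope}` duplicated handles the caller already held; the functions now return `void` and the in-place contract is explicit. A minimal caller-side sketch of the new convention, using the `apply_rope_pos_ids` signature from this diff — the wrapper name `rotate_qk` and the `rope_scale`/`rope_theta` constants are illustrative assumptions, not taken from this change:

```cpp
#include <torch/torch.h>

// Prototype as it appears after this change (declared here for illustration;
// in practice you link against the compiled extension that defines it).
void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
                        float rope_scale, float rope_theta);

// Hypothetical caller: outputs are pre-allocated by the caller and filled
// in place; nothing is returned anymore.
void rotate_qk(torch::Tensor q, torch::Tensor k, torch::Tensor pos_ids) {
  auto q_rope = torch::empty_like(q);  // caller-owned output buffer
  auto k_rope = torch::empty_like(k);
  // rope_scale / rope_theta below are illustrative defaults, not from this diff.
  apply_rope_pos_ids(q, k, q_rope, k_rope, pos_ids,
                     /*interleave=*/false, /*rope_scale=*/1.0f,
                     /*rope_theta=*/1e4f);
  // q_rope / k_rope now hold the rotated embeddings.
}
```

Dropping the `std::vector<torch::Tensor>` return also avoids constructing a vector of tensor handles on every call, which matters for bindings invoked once per decode step.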