@@ -61,10 +61,11 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
61
61
torch::Tensor top_k_mask_logits (torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
62
62
unsigned int top_k_val);
63
63
64
- torch::Tensor chain_speculative_sampling (
65
- torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
66
- torch::Tensor target_probs, torch::Tensor output_accepted_token_num,
67
- torch::Tensor output_emitted_token_num, bool deterministic);
64
+ torch::Tensor chain_speculative_sampling (torch::Tensor draft_probs, torch::Tensor draft_token_ids,
65
+ torch::Tensor uniform_samples, torch::Tensor target_probs,
66
+ torch::Tensor output_accepted_token_num,
67
+ torch::Tensor output_emitted_token_num,
68
+ bool deterministic);
68
69
69
70
void rmsnorm (torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps);
70
71
@@ -82,24 +83,30 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
82
83
83
84
void gelu_and_mul (torch::Tensor& out, torch::Tensor& input);
84
85
85
- void apply_rope_inplace (torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
86
- torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);
87
-
88
- void apply_llama31_rope_inplace (torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
89
- torch::Tensor offsets, bool interleave, float rope_scale,
90
- float rope_theta, float low_freq_factor, float high_freq_factor,
91
- float old_context_length);
92
-
93
- std::vector<torch::Tensor> apply_rope (torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
86
+ std::vector<torch::Tensor> apply_rope (torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
87
+ torch::Tensor k_rope, torch::Tensor indptr,
94
88
torch::Tensor offsets, bool interleave, float rope_scale,
95
89
float rope_theta);
96
90
97
91
std::vector<torch::Tensor> apply_llama31_rope (torch::Tensor q, torch::Tensor k,
92
+ torch::Tensor q_rope, torch::Tensor k_rope,
98
93
torch::Tensor indptr, torch::Tensor offsets,
99
94
bool interleave, float rope_scale, float rope_theta,
100
95
float low_freq_factor, float high_freq_factor,
101
96
float old_context_length);
102
97
98
+ std::vector<torch::Tensor> apply_rope_pos_ids (torch::Tensor q, torch::Tensor k,
99
+ torch::Tensor q_rope, torch::Tensor k_rope,
100
+ torch::Tensor pos_ids, bool interleave,
101
+ float rope_scale, float rope_theta);
102
+
103
+ std::vector<torch::Tensor> apply_llama31_rope_pos_ids (torch::Tensor q, torch::Tensor k,
104
+ torch::Tensor q_rope, torch::Tensor k_rope,
105
+ torch::Tensor pos_ids, bool interleave,
106
+ float rope_scale, float rope_theta,
107
+ float low_freq_factor, float high_freq_factor,
108
+ float old_context_length);
109
+
103
110
torch::Tensor packbits (torch::Tensor x, const std::string& bitorder);
104
111
105
112
torch::Tensor segment_packbits (torch::Tensor x, torch::Tensor input_indptr,
@@ -141,11 +148,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
141
148
m.def (" silu_and_mul" , &silu_and_mul, " Fused SiLU and Mul" );
142
149
m.def (" gelu_tanh_and_mul" , &gelu_tanh_and_mul, " Fused GeLU Tanh and Mul" );
143
150
m.def (" gelu_and_mul" , &gelu_and_mul, " Fused GeLU and Mul" );
144
- m.def (" apply_rope_inplace" , &apply_rope_inplace, " Apply RoPE in-place" );
145
- m.def (" apply_llama31_rope_inplace" , &apply_llama31_rope_inplace,
146
- " Apply Llama 3.1 style RoPE in-place" );
147
151
m.def (" apply_rope" , &apply_rope, " Apply RoPE" );
148
152
m.def (" apply_llama31_rope" , &apply_llama31_rope, " Apply Llama 3.1 style RoPE" );
153
+ m.def (" apply_rope_pos_ids" , &apply_rope_pos_ids, " Apply RoPE with positional ids" );
154
+ m.def (" apply_llama31_rope_pos_ids" , &apply_llama31_rope_pos_ids,
155
+ " Apply Llama 3.1 style RoPE with positional ids" );
149
156
m.def (" packbits" , &packbits, " GPU packbits operator" );
150
157
m.def (" segment_packbits" , &segment_packbits, " GPU segment packbits operator" );
151
158
m.def (" cutlass_segment_gemm" , &CutlassSegmentGEMM, " Cutlass Segment GEMM operator" );
0 commit comments