|
15 | 15 | */
|
16 | 16 | #include <torch/extension.h>
|
17 | 17 |
|
18 |
| -std::vector<torch::Tensor> single_prefill_with_kv_cache( |
| 18 | +torch::Tensor single_prefill_with_kv_cache( |
19 | 19 | unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v,
|
20 | 20 | std::optional<torch::Tensor> maybe_packed_custom_mask, torch::Tensor tmp,
|
21 | 21 | std::optional<torch::Tensor> maybe_alibi_slopes, unsigned int layout, int32_t window_left,
|
22 |
| - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse); |
| 22 | + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, |
| 23 | + std::optional<torch::Tensor> maybe_lse); |
23 | 24 |
|
24 | 25 | std::vector<int64_t> BatchPrefillWithKVCachePlan(
|
25 | 26 | unsigned int head_dim, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
|
26 | 27 | torch::Tensor page_locked_int_workspace_buffer, torch::Tensor qo_indptr,
|
27 | 28 | torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads,
|
28 | 29 | unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph);
|
29 | 30 |
|
30 |
| -std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun( |
| 31 | +torch::Tensor BatchPrefillWithRaggedKVCacheRun( |
31 | 32 | unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
|
32 | 33 | torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
|
33 | 34 | torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_custom_mask,
|
34 | 35 | std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr,
|
35 | 36 | torch::Tensor kv_indptr, std::optional<torch::Tensor> maybe_qk_indptr, unsigned int layout,
|
36 | 37 | int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
|
37 |
| - bool return_lse); |
| 38 | + std::optional<torch::Tensor> maybe_lse); |
38 | 39 |
|
39 |
| -std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun( |
| 40 | +torch::Tensor BatchPrefillWithPagedKVCacheRun( |
40 | 41 | unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
|
41 | 42 | torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
|
42 | 43 | torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
|
43 | 44 | std::optional<torch::Tensor> maybe_custom_mask, std::optional<torch::Tensor> maybe_alibi_slopes,
|
44 | 45 | torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
|
45 | 46 | torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> maybe_qk_indptr,
|
46 | 47 | unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
|
47 |
| - float rope_scale, float rope_theta, bool return_lse); |
| 48 | + float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse); |
48 | 49 |
|
49 | 50 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
50 | 51 | m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache,
|
|
0 commit comments