Commit 9bf916f
feat: torch.compile and custom_op support (#554)
Follow-up of #552. This PR adds torch library annotations to all FlashInfer kernels so that torch.compile can recognize them. Most of the changes are mechanical. I manually ran subsets of the pytest test cases while making these changes, but since there are too many of them, and some did not pass even before this change, I cannot guarantee that everything works. To run the tests with torch.compile, set the `FLASHINFER_TEST_TORCH_COMPILE=1` environment variable:

```bash
# With torch.compile
FLASHINFER_TEST_TORCH_COMPILE=1 pytest -svx tests/test_norm.py
# Without torch.compile
pytest -svx tests/test_norm.py
```

Notable changes:

* For the prefill and decode pybind, the return type used to be `Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]` depending on `return_lse`, which causes trouble for `torch.compile`. I changed the pybind interface to accept a `maybe_lse: Optional[torch.Tensor]` and return a single tensor; the allocation of the lse tensor moves to the Python side. The Python API does not change.
* `chain_speculative_sampling` pybind: move the allocation of `accepted` and `emitted` from C++ to Python, because `torch.compile` doesn't like an input tensor being returned as an output tensor. The Python API does not change.

Piggyback changes:

* `BatchPrefillWithRaggedKVCacheWrapper.plan`: bugfix, `qo_indptr` was not on CPU.
* `merge_state`: fix a typo in the docs.
* Change `run_return_lse(...)` to `run(..., return_lse=True)` because torch.compile does not recognize `functools.partial`.
* In tests, change `flashinfer.xxx()` to `flashinfer.<module>.xxx()` so that the monkeypatch works.

Unsupported for torch.compile:

* `flashinfer.quantization.segment_packbits`: because it is data dependent.

Untouched:

* `sparse.py`: tests didn't pass beforehand, so I skipped it. It also doesn't seem to need custom_op annotations, as it does not have CUDA kernels.

Failed test cases:

* batch_decode with non-contiguous kv: `test_batch_decode_with_paged_kv_cache[False-kv_dtype0-q_dtype0-True-0.0-NONE-NHD-128-4-4-1-54-12]`
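The annotation pattern used across the kernels (see the `activation.py` hunk below) wraps each CUDA entry point in a `torch.library` custom op plus a fake (meta) implementation, so Dynamo can trace calls without running the kernel. A minimal sketch of what the `register_custom_op` / `register_fake_op` helpers could look like, assuming torch >= 2.4 where `torch.library.custom_op` and `torch.library.register_fake` exist; the actual helpers live in `python/flashinfer/utils.py` and may differ (e.g. in how they fall back on older torch):

```python
from typing import Callable, Sequence, Union

import torch


def register_custom_op(name: str, *, mutates_args: Union[str, Sequence[str]]) -> Callable:
    # Newer torch: register a real custom op so torch.compile treats the kernel
    # as an opaque node instead of graph-breaking on it.
    if hasattr(torch.library, "custom_op"):
        return torch.library.custom_op(name, mutates_args=mutates_args)
    # Older torch: identity decorator; eager mode keeps working, no compile support.
    return lambda fn: fn


def register_fake_op(name: str) -> Callable:
    # The fake (meta) implementation only describes output shapes/dtypes.
    if hasattr(torch.library, "register_fake"):
        return torch.library.register_fake(name)
    return lambda fn: fn
```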
1 parent 2989556 commit 9bf916f

24 files changed, +1291 -279 lines

flashinfer-aot/csrc_aot/batch_decode.cu (+9, -11)

@@ -85,13 +85,13 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
   return plan_info.ToVector();
 }
 
-std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
+torch::Tensor BatchDecodeWithPagedKVCacheRun(
     torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
     std::vector<int64_t> plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache,
     torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
     torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> alibi_slopes,
     unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale,
-    float rope_scale, float rope_theta, bool return_lse) {
+    float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse) {
   DecodePlanInfo plan_info;
   plan_info.FromVector(plan_info_vec);
   QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
@@ -111,9 +111,11 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
 
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q);
-  torch::Tensor lse;
-  if (return_lse) {
-    lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32)));
+  if (maybe_lse) {
+    const auto& lse = *maybe_lse;
+    TORCH_CHECK(lse.size(0) == batch_size, lse.size(0), q.size(0));
+    TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q.size(1));
+    TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
   }
 
   TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
@@ -160,7 +162,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
         static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
     ParamsT params(static_cast<DTypeQ*>(q.data_ptr()),
                    /*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o.data_ptr()),
-                   /*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr),
+                   /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
                    /*alibi_slopes=*/nullptr, num_qo_heads, q_stride_n, q_stride_h, window_left,
                    logits_soft_cap, sm_scale, rope_scale, rope_theta);
 
@@ -194,9 +196,5 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
     });
   });
 
-  if (return_lse) {
-    return {o, lse};
-  } else {
-    return {o};
-  }
+  return o;
 }
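On the Python side the wrapper now owns the `lse` allocation, so this entry point always returns exactly one tensor regardless of whether the log-sum-exp is requested. A rough sketch of the calling convention (names are illustrative, not the exact wrapper code):

```python
from typing import Optional, Tuple, Union

import torch


def run_decode(run_op, q: torch.Tensor, kernel_args: tuple,
               return_lse: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # run_op stands in for the bound BatchDecodeWithPagedKVCacheRun.
    batch_size, num_qo_heads = q.shape[0], q.shape[1]
    # lse is allocated in Python and handed to C++ as maybe_lse; the kernel only
    # validates its shape/dtype and writes into it.
    maybe_lse: Optional[torch.Tensor] = (
        torch.empty(batch_size, num_qo_heads, dtype=torch.float32, device=q.device)
        if return_lse
        else None
    )
    out = run_op(*kernel_args, maybe_lse)  # always a single tensor now
    return (out, maybe_lse) if return_lse else out
```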

flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu (+2, -2)

@@ -29,13 +29,13 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
     torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads,
     unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph);
 
-std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
+torch::Tensor BatchDecodeWithPagedKVCacheRun(
     torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
     std::vector<int64_t> plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache,
     torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
     torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> alibi_slopes,
     unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale,
-    float rope_scale, float rope_theta, bool return_lse);
+    float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache,

flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu (+7, -6)

@@ -15,36 +15,37 @@
  */
 #include <torch/extension.h>
 
-std::vector<torch::Tensor> single_prefill_with_kv_cache(
+torch::Tensor single_prefill_with_kv_cache(
     unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v,
     std::optional<torch::Tensor> maybe_packed_custom_mask, torch::Tensor tmp,
     std::optional<torch::Tensor> maybe_alibi_slopes, unsigned int layout, int32_t window_left,
-    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse);
+    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
+    std::optional<torch::Tensor> maybe_lse);
 
 std::vector<int64_t> BatchPrefillWithKVCachePlan(
     unsigned int head_dim, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
     torch::Tensor page_locked_int_workspace_buffer, torch::Tensor qo_indptr,
     torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads,
     unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph);
 
-std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
+torch::Tensor BatchPrefillWithRaggedKVCacheRun(
     unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
     torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
     torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_custom_mask,
     std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr,
     torch::Tensor kv_indptr, std::optional<torch::Tensor> maybe_qk_indptr, unsigned int layout,
     int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
-    bool return_lse);
+    std::optional<torch::Tensor> maybe_lse);
 
-std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
+torch::Tensor BatchPrefillWithPagedKVCacheRun(
     unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
     torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
     torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
     std::optional<torch::Tensor> maybe_custom_mask, std::optional<torch::Tensor> maybe_alibi_slopes,
     torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
     torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> maybe_qk_indptr,
     unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
-    float rope_scale, float rope_theta, bool return_lse);
+    float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache,

flashinfer-aot/csrc_aot/single_prefill.cu (+9, -11)

@@ -32,12 +32,12 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params
 
 }  // namespace flashinfer
 
-std::vector<torch::Tensor> single_prefill_with_kv_cache(
+torch::Tensor single_prefill_with_kv_cache(
     unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v,
     std::optional<torch::Tensor> maybe_packed_custom_mask, torch::Tensor tmp,
     std::optional<torch::Tensor> maybe_alibi_slopes, unsigned int layout,
     int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
-    bool return_lse) {
+    std::optional<torch::Tensor> maybe_lse) {
   auto device = q.device();
   unsigned int head_dim = q.size(2);
   unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads;
@@ -58,9 +58,11 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
   }
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q, q.options());
-  torch::Tensor lse = torch::empty({0});
-  if (return_lse) {
-    lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32));
+  if (maybe_lse) {
+    const auto& lse = *maybe_lse;
+    TORCH_CHECK(lse.size(0) == qo_len, lse.size(0), q.size(0));
+    TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q.size(1));
+    TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
   }
 
   constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;
@@ -90,7 +92,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
             ? static_cast<uint8_t*>(maybe_packed_custom_mask->data_ptr())
             : nullptr,
         static_cast<DTypeO*>(o.data_ptr()),
-        /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
+        /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
         /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len,
         q_stride_n, q_stride_h, kv_stride_n, kv_stride_h, head_dim, window_left,
         logits_soft_cap, sm_scale, rope_scale, rope_theta);
@@ -109,9 +111,5 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
     });
   });
 
-  if (return_lse) {
-    return {o, lse};
-  } else {
-    return {o};
-  }
+  return o;
 }
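Because the op now writes `lse` into a caller-provided buffer and returns a single tensor, the fake (meta) implementation that torch.compile needs is just a shape rule. A sketch under a throwaway namespace, assuming torch >= 2.4 (illustrative only; FlashInfer's real registration goes through its own helpers and op names):

```python
from typing import Optional

import torch


@torch.library.custom_op("flashinfer_demo::single_prefill", mutates_args=("maybe_lse",))
def single_prefill_demo(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        maybe_lse: Optional[torch.Tensor]) -> torch.Tensor:
    # stand-in for the CUDA kernel: plain attention, filling lse in place if given
    scores = q @ k.transpose(-1, -2)
    if maybe_lse is not None:
        maybe_lse.copy_(torch.logsumexp(scores, dim=-1))
    return torch.softmax(scores, dim=-1) @ v


@single_prefill_demo.register_fake
def _(q, k, v, maybe_lse):
    # torch.compile only needs output metadata: same shape/dtype as q
    return torch.empty_like(q)
```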

python/csrc/flashinfer_sampling_ops.cu (+3, -3)

@@ -47,10 +47,10 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
 torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
                                 unsigned int top_k_val);
 
-std::vector<torch::Tensor> chain_speculative_sampling(
+torch::Tensor chain_speculative_sampling(
     torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
-    torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
-    std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic);
+    torch::Tensor target_probs, torch::Tensor output_accepted_token_num,
+    torch::Tensor output_emitted_token_num, bool deterministic);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities");

python/csrc/sampling.cu (+6, -16)

@@ -314,10 +314,10 @@ torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tenso
   return mask_logits;
 }
 
-std::vector<torch::Tensor> chain_speculative_sampling(
+torch::Tensor chain_speculative_sampling(
     torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
-    torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
-    std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic) {
+    torch::Tensor target_probs, torch::Tensor output_accepted_token_num,
+    torch::Tensor output_emitted_token_num, bool deterministic) {
   CHECK_INPUT(draft_probs);
   CHECK_INPUT(draft_token_ids);
   CHECK_INPUT(uniform_samples);
@@ -339,6 +339,8 @@ std::vector<torch::Tensor> chain_speculative_sampling(
   CHECK_EQ(num_speculate_tokens + 1, uniform_samples.size(1));
   CHECK_EQ(num_speculate_tokens + 1, target_probs.size(1));
   CHECK_EQ(vocab_size, target_probs.size(2));
+  CHECK_EQ(batch_size, output_accepted_token_num.size(0));
+  CHECK_EQ(batch_size, output_emitted_token_num.size(0));
 
   draft_probs = draft_probs.to(torch::kFloat32);
   draft_token_ids = draft_token_ids.to(torch::kInt32);
@@ -349,18 +351,6 @@ std::vector<torch::Tensor> chain_speculative_sampling(
   auto output_token_ids = torch::empty({batch_size, num_speculate_tokens + 1},
                                        torch::dtype(torch::kInt32).device(device));
 
-  bool has_output_accepted_token_num = maybe_output_accepted_token_num.has_value();
-  bool has_output_emitted_token_num = maybe_output_emitted_token_num.has_value();
-  auto output_accepted_token_num = maybe_output_accepted_token_num.value_or(
-      torch::zeros({batch_size}, torch::dtype(torch::kInt32).device(device)));
-  auto output_emitted_token_num = maybe_output_emitted_token_num.value_or(
-      torch::zeros({batch_size}, torch::dtype(torch::kInt32).device(device)));
-  if (has_output_accepted_token_num) {
-    CHECK_EQ(has_output_emitted_token_num, true);
-    CHECK_EQ(batch_size, output_accepted_token_num.size(0));
-    CHECK_EQ(batch_size, output_emitted_token_num.size(0));
-  }
-
   cudaError_t status = sampling::ChainSpeculativeSampling<float, int>(
       static_cast<float*>(draft_probs.data_ptr()), static_cast<int*>(draft_token_ids.data_ptr()),
       static_cast<float*>(uniform_samples.data_ptr()), static_cast<float*>(target_probs.data_ptr()),
@@ -372,5 +362,5 @@ std::vector<torch::Tensor> chain_speculative_sampling(
   TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
                                          std::string(cudaGetErrorString(status)));
 
-  return {output_token_ids, output_accepted_token_num, output_emitted_token_num};
+  return output_token_ids;
 }
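The Python wrapper now allocates the two counters itself and passes them in, which keeps the op's outputs disjoint from its inputs (torch.compile does not like an op that returns one of its own input tensors). A rough sketch of the wrapper side (illustrative names, not the exact flashinfer.sampling code):

```python
import torch


def chain_speculative_sampling_py(op, draft_probs, draft_token_ids,
                                  uniform_samples, target_probs,
                                  deterministic: bool = True):
    # op stands in for the bound chain_speculative_sampling pybind function.
    batch_size = draft_probs.size(0)
    device = draft_probs.device
    # Allocated in Python; the kernel fills them in place and only returns the
    # freshly allocated output_token_ids tensor.
    accepted = torch.zeros(batch_size, dtype=torch.int32, device=device)
    emitted = torch.zeros(batch_size, dtype=torch.int32, device=device)
    output_token_ids = op(draft_probs, draft_token_ids, uniform_samples,
                          target_probs, accepted, emitted, deterministic)
    return output_token_ids, accepted, emitted
```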

python/flashinfer/activation.py (+25, -8)

@@ -14,16 +14,17 @@
  limitations under the License.
 """
 
-from typing import Optional
+from types import SimpleNamespace
+
+import torch
+
 from .jit import (
-    load_cuda_ops,
     FLASHINFER_GEN_SRC_DIR,
     gen_act_and_mul_cu,
     has_prebuilt_ops,
+    load_cuda_ops,
 )
-
-import torch
-
+from .utils import register_custom_op, register_fake_op
 
 silu_def_cu_str = r"""
 __device__ __forceinline__ float silu(const float& val) {
@@ -73,15 +74,31 @@ def get_act_and_mul_module(act_func_name: str):
         if has_prebuilt_ops:
             from . import _kernels
 
-            _jit_modules[act_func_name] = _kernels
+            module = _kernels
         else:
-            _jit_modules[act_func_name] = compile_act_and_mul_module(
+            module = compile_act_and_mul_module(
                 act_func_name, act_func_def_str[act_func_name]
             )
+
+        # torch library for act_and_mul
+        fname = f"{act_func_name}_and_mul"
+        fn = getattr(module, fname)
+
+        @register_custom_op(f"flashinfer::{fname}", mutates_args=("out",))
+        def _act_and_mul(out: torch.Tensor, input: torch.Tensor) -> None:
+            fn(out, input)
+
+        @register_fake_op(f"flashinfer::{fname}")
+        def _fake_act_and_mul(out: torch.Tensor, input: torch.Tensor) -> None:
+            pass
+
+        # Register the module
+        _jit_modules[act_func_name] = SimpleNamespace(**{fname: _act_and_mul})
+
     return _jit_modules[act_func_name]
 
 
-def _check_shape(input: torch.Tensor, output: torch.Tensor):
+def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
     assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
     assert (
         input.shape[:-1] == output.shape[:-1]
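With the activation kernels registered as `flashinfer::<act>_and_mul` custom ops (plus fake implementations), a call through the public API should trace cleanly under torch.compile, which is what the `FLASHINFER_TEST_TORCH_COMPILE=1` path exercises. A usage sketch, assuming a CUDA device and that the public `silu_and_mul` signature is unchanged:

```python
import torch

import flashinfer


def ffn_act(x: torch.Tensor) -> torch.Tensor:
    # silu_and_mul splits the last dimension in half: silu(x[..., :d]) * x[..., d:]
    return flashinfer.activation.silu_and_mul(x)


x = torch.randn(16, 2 * 4096, dtype=torch.float16, device="cuda")
compiled = torch.compile(ffn_act, fullgraph=True)  # custom op: no graph break expected
torch.testing.assert_close(compiled(x), ffn_act(x))
```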