Skip to content

Commit eaf73fd

Browse files
authored
perf: accelerate JIT compilation speed (#618)
Current JIT compilation is slow because we rely on a huge header, `<torch/extension.h>`, which is too heavy for our use case. This PR refactors the codebase to include only the headers necessary for pybind, and moves most of the torch runtime API calls from C++ to Python. The compilation time was reduced from 48 seconds to 18 seconds for lightweight operators such as norm.
1 parent dd3c836 commit eaf73fd

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

76 files changed

+2420
-2101
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ src/generated/
1313
python/csrc/generated/
1414
python/flashinfer/_build_meta.py
1515
python/flashinfer/jit/aot_config.py
16-
python/csrc_aot/generated/
16+
python/csrc/generated/
1717

1818
# Package files
1919
python/flashinfer/data/

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ repos:
5252
- id: clang-format
5353
types_or: [c++, c, cuda]
5454
exclude: |
55-
(?x)^(3rdparty/.* src/generated/.* python/flashinfer/jit/aot_config.py python/csrc_aot/generated/.*)$
55+
(?x)^(3rdparty/.* src/generated/.* python/flashinfer/jit/aot_config.py)$
5656
5757
- repo: https://github.com/cheshirekow/cmake-format-precommit
5858
rev: v0.6.13

include/flashinfer/allocator.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
#include <memory>
2020
#include <sstream>
21-
#include <stdexcept>
21+
22+
#include "exception.h"
2223

2324
namespace flashinfer {
2425

@@ -44,7 +45,7 @@ struct AlignedAllocator {
4445
std::ostringstream oss;
4546
oss << "Failed to allocate memory for " << name << " with size " << size << " and alignment "
4647
<< alignment << " in AlignedAllocator";
47-
throw std::runtime_error(oss.str());
48+
FLASHINFER_ERROR(oss.str());
4849
}
4950
return nullptr;
5051
}

include/flashinfer/attention/decode.cuh

+1-1
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ cudaError_t SingleDecodeWithKVCacheDispatched(typename AttentionVariant::ParamsT
687687
if (nblks.x == 0 || nblks.y == 0) {
688688
std::ostringstream err_msg;
689689
err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")";
690-
throw std::runtime_error(err_msg.str());
690+
FLASHINFER_ERROR(err_msg.str());
691691
}
692692
dim3 nthrs = dim3(bdx, bdy, bdz);
693693
float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM);

include/flashinfer/attention/prefill.cuh

+4-4
Original file line numberDiff line numberDiff line change
@@ -1375,7 +1375,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params
13751375
err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal "
13761376
"to qo_len, got kv_len"
13771377
<< kv_len << " and qo_len " << qo_len;
1378-
throw std::invalid_argument(err_msg.str());
1378+
FLASHINFER_ERROR(err_msg.str());
13791379
}
13801380

13811381
const uint32_t group_size = num_qo_heads / num_kv_heads;
@@ -1442,7 +1442,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params
14421442
<< " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV
14431443
<< " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
14441444
" and report the issue to the developers.";
1445-
throw std::invalid_argument(err_msg.str());
1445+
FLASHINFER_ERROR(err_msg.str());
14461446
} else {
14471447
constexpr uint32_t num_threads = (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE;
14481448
constexpr uint32_t num_rows_per_cta = NUM_FRAGS_Q * NUM_WARPS_Q * 16;
@@ -2165,7 +2165,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P
21652165
<< " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV
21662166
<< " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
21672167
" and report the issue to the developers.";
2168-
throw std::invalid_argument(err_msg.str());
2168+
FLASHINFER_ERROR(err_msg.str());
21692169
} else {
21702170
// TODO(Zihao): fix the following computation
21712171
uint32_t smem_size = (NUM_FRAGS_Q * NUM_WARPS_Q * sizeof(DTypeQ) +
@@ -2267,7 +2267,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa
22672267
<< " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV
22682268
<< " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
22692269
" and report the issue to the developers.";
2270-
throw std::invalid_argument(err_msg.str());
2270+
FLASHINFER_ERROR(err_msg.str());
22712271
} else {
22722272
// TODO(Zihao): fix the following computation
22732273
uint32_t smem_size = (NUM_FRAGS_Q * NUM_WARPS_Q * sizeof(DTypeQ) +

include/flashinfer/attention/scheduler.cuh

+5-5
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ struct DecodePlanInfo {
333333
if (vec.size() != 10) {
334334
std::ostringstream err_msg;
335335
err_msg << "DecodePlanInfo::FromVector: vec.size() should be 10, but got " << vec.size();
336-
throw std::invalid_argument(err_msg.str());
336+
FLASHINFER_ERROR(err_msg.str());
337337
}
338338
padded_batch_size = vec[0];
339339
v_offset = vec[1];
@@ -440,14 +440,14 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
440440
std::ostringstream err_msg;
441441
err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]"
442442
<< qo_indptr_h[i] << " should be non-negative";
443-
throw std::invalid_argument(err_msg.str());
443+
FLASHINFER_ERROR(err_msg.str());
444444
}
445445
kv_len_arr[i] = int64_t(kv_indptr_h[i + 1] - kv_indptr_h[i]);
446446
if (kv_len_arr[i] < 0) {
447447
std::ostringstream err_msg;
448448
err_msg << "kv_indptr[" << i + 1 << "]" << kv_indptr_h[i + 1] << " - kv_indptr[" << i << "]"
449449
<< kv_indptr_h[i] << " should be non-negative";
450-
throw std::invalid_argument(err_msg.str());
450+
FLASHINFER_ERROR(err_msg.str());
451451
}
452452
sum_packed_qo_len += packed_qo_len_arr[i];
453453
}
@@ -570,7 +570,7 @@ struct PrefillPlanInfo {
570570
if (vec.size() != 14) {
571571
std::ostringstream err_msg;
572572
err_msg << "PrefillPlanInfo::FromVector: vec.size() should be 14, but got " << vec.size();
573-
throw std::invalid_argument(err_msg.str());
573+
FLASHINFER_ERROR(err_msg.str());
574574
}
575575
padded_batch_size = vec[0];
576576
total_num_rows = vec[1];
@@ -601,7 +601,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
601601
std::ostringstream err_msg;
602602
err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
603603
<< num_kv_heads;
604-
throw std::invalid_argument(err_msg.str());
604+
FLASHINFER_ERROR(err_msg.str());
605605
}
606606

607607
// step 0: get the number of SMs

include/flashinfer/exception.h

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright (c) 2024 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#ifndef FLASHINFER_EXCEPTION_H_
17+
#define FLASHINFER_EXCEPTION_H_
18+
19+
#include <exception>
20+
#include <sstream>
21+
22+
namespace flashinfer {
23+
24+
class Error : public std::exception {
25+
private:
26+
std::string message_;
27+
28+
public:
29+
Error(const std::string& func, const std::string& file, int line, const std::string& message) {
30+
std::ostringstream oss;
31+
oss << "Error in function '" << func << "' "
32+
<< "at " << file << ":" << line << ": " << message;
33+
message_ = oss.str();
34+
}
35+
36+
virtual const char* what() const noexcept override { return message_.c_str(); }
37+
};
38+
39+
#define FLASHINFER_ERROR(message) throw Error(__FUNCTION__, __FILE__, __LINE__, message)
40+
41+
#define FLASHINFER_CHECK(condition, message) \
42+
if (!(condition)) { \
43+
FLASHINFER_ERROR(message); \
44+
}
45+
46+
} // namespace flashinfer
47+
48+
#endif // FLASHINFER_EXCEPTION_H_

include/flashinfer/gemm/bmm_fp8.cuh

+11-9
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,17 @@
1919
#include <cublasLt.h>
2020
#include <cuda_fp8.h>
2121

22-
#include <stdexcept>
22+
#include <iostream>
23+
#include <memory>
2324
#include <type_traits>
2425

25-
#define FLASHINFER_CUBLAS_CHECK(EXPR) \
26-
{ \
27-
cublasStatus_t e = (EXPR); \
28-
if (e != CUBLAS_STATUS_SUCCESS) { \
29-
throw std::runtime_error("CUBLAS Error: " + std::string(cublasGetStatusString(e))); \
30-
} \
26+
#include "../exception.h"
27+
28+
#define FLASHINFER_CUBLAS_CHECK(EXPR) \
29+
{ \
30+
cublasStatus_t e = (EXPR); \
31+
FLASHINFER_CHECK(e == CUBLAS_STATUS_SUCCESS, \
32+
"CUBLAS Error: " + std::string(cublasGetStatusString(e))); \
3133
}
3234

3335
#ifndef NDEBUG
@@ -131,7 +133,7 @@ cudaDataType_t get_cuda_data_type() {
131133
} else if constexpr (std::is_same_v<T, half>) {
132134
return CUDA_R_16F;
133135
} else {
134-
throw std::runtime_error("Unsupported type");
136+
FLASHINFER_ERROR("Unsupported type");
135137
}
136138
}
137139

@@ -155,7 +157,7 @@ cublasStatus_t bmm_fp8_internal_cublaslt(void* workspace, size_t workspace_size_
155157
cudaDataType_t b_type = get_cuda_data_type<BT>();
156158
cudaDataType_t d_type = get_cuda_data_type<DT>();
157159
if (std::is_same_v<AT, __nv_fp8_e5m2> && std::is_same_v<BT, __nv_fp8_e5m2>) {
158-
throw std::runtime_error("Unsupported combination: both A and B are e5m2");
160+
FLASHINFER_ERROR("Unsupported combination: both A and B are e5m2");
159161
}
160162

161163
auto a_desp = CuBlasLtMatrixLayout(a_type, m, k, k, true);

include/flashinfer/gemm/group_gemm.cuh

+2-2
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,13 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe
7979
if (status != cutlass::Status::kSuccess) {
8080
std::ostringstream err_msg;
8181
err_msg << "cutlass group_gemm.initialize failed: " << cutlassGetStatusString(status);
82-
throw std::runtime_error(err_msg.str());
82+
FLASHINFER_ERROR(err_msg.str());
8383
}
8484
status = gemm.run(stream);
8585
if (status != cutlass::Status::kSuccess) {
8686
std::ostringstream err_msg;
8787
err_msg << "cutlass group_gemm.run failed: " << cutlassGetStatusString(status);
88-
throw std::runtime_error(err_msg.str());
88+
FLASHINFER_ERROR(err_msg.str());
8989
}
9090
});
9191

include/flashinfer/gemm/group_gemm_sm90.cuh

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_si
7373
sizeof(DTypeIn) == 1) {
7474
std::ostringstream err_msg;
7575
err_msg << "Row-major layout is not supported for fp8 data type";
76-
throw std::runtime_error(err_msg.str());
76+
FLASHINFER_ERROR(err_msg.str());
7777
} else {
7878
using LayoutA = cutlass::layout::RowMajor;
7979
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

include/flashinfer/math.cuh

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include <cuda_fp16.h>
2020
#include <cuda_runtime.h>
2121

22+
#include <cstdint>
23+
2224
namespace flashinfer {
2325
namespace math {
2426

include/flashinfer/utils.cuh

+11-11
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323

2424
#include <cstdint>
2525
#include <iostream>
26-
#include <sstream>
27-
#include <stdexcept>
2826
#include <vector>
2927

28+
#include "exception.h"
29+
3030
#define STR_HELPER(x) #x
3131
#define STR(x) STR_HELPER(x)
3232

@@ -57,7 +57,7 @@
5757

5858
#define DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, ...) \
5959
if (allow_fp16_qk_reduction) { \
60-
throw std::runtime_error("FP16_QK_REDUCTION disabled at compile time"); \
60+
FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \
6161
} else { \
6262
constexpr bool ALLOW_FP16_QK_REDUCTION = false; \
6363
__VA_ARGS__ \
@@ -73,7 +73,7 @@
7373
} else { \
7474
std::ostringstream err_msg; \
7575
err_msg << "Unsupported num_frags_q: " << num_frags_q; \
76-
throw std::invalid_argument(err_msg.str()); \
76+
FLASHINFER_ERROR(err_msg.str()); \
7777
}
7878

7979
#define DISPATCH_NUM_FRAGS_KV(max_frags_kv, NUM_FRAGS_KV, ...) \
@@ -92,7 +92,7 @@
9292
} else { \
9393
std::ostringstream err_msg; \
9494
err_msg << "Unsupported max_frags_kv: " << max_frags_kv; \
95-
throw std::invalid_argument(err_msg.str()); \
95+
FLASHINFER_ERROR(err_msg.str()); \
9696
}
9797

9898
#define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) \
@@ -115,7 +115,7 @@
115115
default: { \
116116
std::ostringstream err_msg; \
117117
err_msg << "Unsupported cta_tile_q: " << cta_tile_q; \
118-
throw std::invalid_argument(err_msg.str()); \
118+
FLASHINFER_ERROR(err_msg.str()); \
119119
} \
120120
}
121121

@@ -138,7 +138,7 @@
138138
} else { \
139139
std::ostringstream err_msg; \
140140
err_msg << "Unsupported group_size: " << group_size; \
141-
throw std::invalid_argument(err_msg.str()); \
141+
FLASHINFER_ERROR(err_msg.str()); \
142142
}
143143

144144
#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \
@@ -161,7 +161,7 @@
161161
default: { \
162162
std::ostringstream err_msg; \
163163
err_msg << "Unsupported mask_mode: " << int(mask_mode); \
164-
throw std::invalid_argument(err_msg.str()); \
164+
FLASHINFER_ERROR(err_msg.str()); \
165165
} \
166166
}
167167

@@ -190,7 +190,7 @@
190190
default: { \
191191
std::ostringstream err_msg; \
192192
err_msg << "Unsupported head_dim: " << head_dim; \
193-
throw std::invalid_argument(err_msg.str()); \
193+
FLASHINFER_ERROR(err_msg.str()); \
194194
} \
195195
}
196196

@@ -214,7 +214,7 @@
214214
default: { \
215215
std::ostringstream err_msg; \
216216
err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \
217-
throw std::invalid_argument(err_msg.str()); \
217+
FLASHINFER_ERROR(err_msg.str()); \
218218
} \
219219
}
220220

@@ -248,7 +248,7 @@
248248
default: { \
249249
std::ostringstream err_msg; \
250250
err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \
251-
throw std::invalid_argument(err_msg.str()); \
251+
FLASHINFER_ERROR(err_msg.str()); \
252252
} \
253253
}
254254

python/aot_MANIFEST.in

-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
prune */__pycache__
44
prune csrc
5-
prune csrc_aot
65
exclude aot_setup.py
76
exclude setup.py
87

python/aot_setup.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def write_if_different(path: pathlib.Path, content: str) -> None:
6464

6565

6666
def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]:
67-
path = root / "python" / "csrc_aot" / "generated"
67+
path = root / "python" / "csrc" / "generated"
6868
path.mkdir(parents=True, exist_ok=True)
6969

7070
head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",")
@@ -423,12 +423,12 @@ def ln(src: str, dst: str, is_dir: bool = False) -> None:
423423
"csrc/quantization.cu",
424424
"csrc/rope.cu",
425425
"csrc/sampling.cu",
426-
"csrc_aot/activation.cu",
427-
"csrc_aot/batch_decode.cu",
428-
"csrc_aot/batch_prefill.cu",
429-
"csrc_aot/flashinfer_ops.cu",
430-
"csrc_aot/single_decode.cu",
431-
"csrc_aot/single_prefill.cu",
426+
"csrc/activation.cu",
427+
"csrc/batch_decode.cu",
428+
"csrc/batch_prefill.cu",
429+
"csrc/single_decode.cu",
430+
"csrc/single_prefill.cu",
431+
"csrc/flashinfer_ops.cu",
432432
]
433433
+ files_decode
434434
+ files_prefill,

0 commit comments

Comments (0)