Commit 3613a5b

feat: JIT compilation (#507)
This PR implements JIT compilation (#170) for flashinfer. After this PR, flashinfer compiles kernels just-in-time for different input data types and shapes and caches the compiled kernels on disk, instead of pre-compiling a fixed set of kernels into the wheel.

# Motivation

The pip wheel size is exploding as we add support for more data types, more head dimensions, more attention variants, and more kernel implementations. Pre-compiling everything is not sustainable and impedes development speed. This PR refactors the codebase to use torch's [JIT Compiling Extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html#jit-compiling-extensions) feature instead of pre-compiling kernels into the wheel.

## Attention Variants

Following [FlexAttention](https://pytorch.org/blog/flexattention/), we describe every attention variant as a template class; each instance of the struct can carry closure variables defined in local or shared memory. Below are two examples (logits soft cap and ALiBi attention; the programming interface is tentative and will be updated as we improve the programmability of the JIT template):

```cuda
template <typename ParamsT>
struct LogitsSoftCap {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ LogitsSoftCap(const ParamsT& params, uint32_t batch_idx,
                                    uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::ptx_rcp(params.logits_soft_cap);
  }

  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return params.logits_soft_cap * math::log2e * float(math::tanh(logits));
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx,
                                             uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return true;
  }
};

template <typename ParamsT>
struct ALIBIAttention {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;
  using IdType = typename ParamsT::IdType;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ ALIBIAttention(const ParamsT& params, uint32_t batch_idx,
                                     uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::log2e;
  }

  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return logits + params.alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx));
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx,
                                             uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return true;
  }
};
```

Users can customize their own `ParamsT` class and variant classes to define their own attention variants; we hope this refactor makes the codebase more concise and extensible. A minimal example of such a user-defined variant is sketched below.
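For illustration only, here is a minimal "vanilla" variant written against the same interface. The struct name `StandardAttention` is ours, and it assumes `ParamsT` exposes the same `get_qo_len`, `get_kv_len`, and `sm_scale` members used by the examples above; the real interface may differ as the JIT template evolves:

```cuda
template <typename ParamsT>
struct StandardAttention {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  // No extra closure state is needed for plain softmax attention; only the
  // per-request sequence lengths are cached, mirroring the examples above.
  __device__ __host__ StandardAttention(const ParamsT& params, uint32_t batch_idx,
                                        uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

  // Fold the softmax scale into the query once (the log2e factor matches the
  // ALIBIAttention example above).
  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::log2e;
  }

  // Identity transform: the attention scores are used as-is.
  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return logits;
  }

  // No additional masking beyond what the kernel already applies.
  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx,
                                             uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return true;
  }
};
```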
# Roadmap

After this PR, we will add support for:

1. PyPI wheels (#153)
2. fp8 tensor cores attention (#502)
3. different head dimensions (#142, #454, #455)
4. FlashAttention-3 (#369)
5. multi-head latent attention (#237)
6. generating `ParamsT` and attention variant descriptions from a Python DSL

The development of these features has been blocked by the wheel size limit (a binary size >= 2GB triggers linking issues); I hope this PR makes future development easier.
1 parent 2043692 commit 3613a5b

File tree

137 files changed: +6986 −6122 lines changed


.gitignore

+2
@@ -12,6 +12,8 @@ src/dispatch.inc
 src/generated/
 python/csrc/generated/
 python/flashinfer/_build_meta.py
+python/flashinfer/jit/aot_config.py
+flashinfer-aot/csrc_aot/generated/
 
 # Generated documentation files
 docs/generated

3rdparty/cutlass

Submodule cutlass updated 360 files

CMakeLists.txt

+114-128
Large diffs are not rendered by default.

cmake/config.cmake

-1
@@ -24,7 +24,6 @@ set(FLASHINFER_FASTDEQUANT_TEST ON)
 set(FLASHINFER_DISTRIBUTED ON)
 # The following configurations can impact the binary
 # size of the generated library
-set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
 set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
 set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
 set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)

flashinfer-aot/3rdparty

+1
@@ -0,0 +1 @@
+../3rdparty

flashinfer-aot/MANIFEST.in

+12
@@ -0,0 +1,12 @@
+# sdist & wheel
+include version.txt
+recursive-include include *
+recursive-include csrc *
+recursive-include 3rdparty/cutlass *
+
+# wheel-only
+exclude flashinfer/_build_meta.py
+
+# Unneeded files
+prune */__pycache__
+global-exclude *.so

flashinfer-aot/csrc

+1
@@ -0,0 +1 @@
+../python/csrc

python/csrc/activation.cu → flashinfer-aot/csrc_aot/activation.cu

+18-4
@@ -18,11 +18,25 @@
 
 #include <flashinfer/activation.cuh>
 
-#include "flashinfer_ops.h"
 #include "pytorch_extension_utils.h"
 
 using namespace flashinfer;
 
+__device__ __forceinline__ float silu(const float& val) {
+  return val / (1.0f + __expf(-val));
+}
+
+__device__ __forceinline__ float gelu(const float& val) {
+  constexpr float kAlpha = M_SQRT1_2;
+  return val * 0.5f * (1.0f + ::erf(val * kAlpha));
+}
+
+__device__ __forceinline__ float gelu_tanh(const float& val) {
+  const float cdf =
+      0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val))));
+  return val * cdf;
+}
+
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
   int d = input.size(-1) / 2;
   int64_t num_tokens = input.numel() / input.size(-1);
@@ -33,7 +47,7 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     uint32_t vec_size = 16 / sizeof(c_type);
     dim3 block(std::min(d / vec_size, 1024U));
-    flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::silu_kernel>
+    flashinfer::activation::act_and_mul_kernel<c_type, silu>
         <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
                                      static_cast<c_type*>(input.data_ptr()), d);
 
@@ -51,7 +65,7 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     uint32_t vec_size = 16 / sizeof(c_type);
     dim3 block(std::min(d / vec_size, 1024U));
-    flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::gelu_tanh_kernel>
+    flashinfer::activation::act_and_mul_kernel<c_type, gelu_tanh>
         <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
                                      static_cast<c_type*>(input.data_ptr()), d);
 
@@ -69,7 +83,7 @@ void gelu_and_mul(torch::Tensor& out, torch::Tensor& input) {
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     uint32_t vec_size = 16 / sizeof(c_type);
     dim3 block(std::min(d / vec_size, 1024U));
-    flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::gelu_kernel>
+    flashinfer::activation::act_and_mul_kernel<c_type, gelu>
         <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
                                      static_cast<c_type*>(input.data_ptr()), d);
 
+205
@@ -0,0 +1,205 @@
/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/extension.h>

#include <flashinfer/attention/decode_params.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/variants.cuh>
#include <optional>

#include "pytorch_extension_utils.h"

namespace flashinfer {

template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, typename AttentionVariant>
cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params,
                                                  typename AttentionVariant::DTypeO* tmp_v,
                                                  float* tmp_s, cudaStream_t stream);

}  // namespace flashinfer

std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
    bool use_logits_soft_cap, unsigned int head_dim, torch::Tensor empty_q_data,
    torch::Tensor empty_kv_data, torch::Tensor float_workspace_buffer,
    torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer,
    torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads,
    unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph) {
  size_t float_workspace_size_in_bytes =
      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
  size_t int_workspace_size_in_bytes =
      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
  auto device = float_workspace_buffer.device();
  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
  indptr = indptr.to(torch::kCPU);

  DecodePlanInfo plan_info;

  using IdType = int32_t;
  // check indptr has idtype int32
  TORCH_CHECK(indptr.scalar_type() == torch::kInt32, "indptr must be int32");
  constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

  auto q_scalar_type = empty_q_data.scalar_type();
  auto kv_scalar_type = empty_kv_data.scalar_type();

  DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
    using DTypeQ = q_type;
    using DTypeKV = kv_type;
    using DTypeO = DTypeQ;
    return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
      return DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
        using ParamsT = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
        using AttentionVariant =
            ComposedAttention<ParamsT, get_variant_code(/*use_custom_mask=*/false,
                                                        /*use_sliding_window=*/true,
                                                        USE_LOGITS_SOFT_CAP, /*use_alibi=*/false)>;

        cudaError_t status = DecodePlan<HEAD_DIM, POS_ENCODING_MODE, AttentionVariant>(
            static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
            static_cast<void*>(int_workspace_buffer.data_ptr()),
            static_cast<void*>(page_locked_int_workspace_buffer.data_ptr()),
            int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(indptr.data_ptr()),
            batch_size, num_qo_heads, num_kv_heads, page_size, enable_cuda_graph,
            /*stream=*/torch_current_stream);

        TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
                    cudaGetErrorString(status));
        return true;
      });
    });
  });

  return plan_info.ToVector();
}

std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
    torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
    std::vector<int64_t> plan_info_vec, torch::Tensor q,
    std::optional<torch::Tensor> paged_kv_cache, std::optional<torch::Tensor> paged_k_cache,
    std::optional<torch::Tensor> paged_v_cache, torch::Tensor paged_kv_indptr,
    torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
    std::optional<torch::Tensor> alibi_slopes, unsigned int kv_layout_code, int window_left,
    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
  DecodePlanInfo plan_info;
  plan_info.FromVector(plan_info_vec);
  QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
  bool paged_kv_defined = paged_kv_cache.has_value();
  auto device = q.device();
  int64_t batch_size = q.size(0);
  int64_t num_qo_heads = q.size(1);
  int64_t num_kv_heads, page_size;
  if (paged_kv_defined) {
    if (kv_layout == QKVLayout::kHND) {
      num_kv_heads = paged_kv_cache->size(2);
      page_size = paged_kv_cache->size(3);
    } else {
      page_size = paged_kv_cache->size(2);
      num_kv_heads = paged_kv_cache->size(3);
    }
  } else {
    if (kv_layout == QKVLayout::kHND) {
      num_kv_heads = paged_k_cache->size(1);
      page_size = paged_k_cache->size(2);
    } else {
      page_size = paged_k_cache->size(1);
      num_kv_heads = paged_k_cache->size(2);
    }
  }
  uint32_t head_dim = q.size(2);

  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
  torch::Tensor o = torch::empty_like(q);
  torch::Tensor lse;
  if (return_lse) {
    lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32)));
  }

  TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");

  void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr());
  void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr());

  using IdType = int32_t;
  constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

  // get q_scalar_type and kv_scalar_type
  auto q_scalar_type = q.scalar_type();
  auto kv_scalar_type =
      paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type();

  DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
    using DTypeQ = q_type;
    using DTypeKV = kv_type;
    using DTypeO = DTypeQ;
    return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
      return DISPATCH_LOGITS_SOFT_CAP(logits_soft_cap > 0, USE_LOGITS_SOFT_CAP, [&] {
        using ParamsT = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
        using AttentionVariant =
            ComposedAttention<ParamsT, get_variant_code(/*use_custom_mask=*/false,
                                                        /*use_sliding_window=*/true,
                                                        USE_LOGITS_SOFT_CAP, /*use_alibi=*/false)>;

        paged_kv_t<DTypeKV, IdType> paged_kv(
            num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
            static_cast<DTypeKV*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
                                                             : nullptr),
            static_cast<DTypeKV*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr),
            static_cast<DTypeKV*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr),
            static_cast<IdType*>(paged_kv_indices.data_ptr()),
            static_cast<IdType*>(paged_kv_indptr.data_ptr()),
            static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
        ParamsT params(static_cast<DTypeQ*>(q.data_ptr()),
                       /*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o.data_ptr()),
                       /*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr),
                       /*alibi_slopes=*/nullptr, num_qo_heads, window_left, logits_soft_cap,
                       sm_scale, rope_scale, rope_theta);

        DTypeO* tmp_v = nullptr;
        float* tmp_s = nullptr;
        params.request_indices =
            GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.request_indices_offset);
        params.kv_tile_indices =
            GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_tile_indices_offset);
        params.o_indptr = GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.o_indptr_offset);
        params.kv_chunk_size_ptr =
            GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_chunk_size_ptr_offset);
        if (plan_info.split_kv) {
          tmp_v = GetPtrFromBaseOffset<DTypeO>(float_buffer, plan_info.v_offset);
          tmp_s = GetPtrFromBaseOffset<float>(float_buffer, plan_info.s_offset);
          if (plan_info.enable_cuda_graph) {
            params.block_valid_mask =
                GetPtrFromBaseOffset<bool>(int_buffer, plan_info.block_valid_mask_offset);
          }
        }
        params.padded_batch_size = plan_info.padded_batch_size;

        cudaError_t status =
            flashinfer::BatchDecodeWithPagedKVCacheDispatched<HEAD_DIM, POS_ENCODING_MODE,
                                                              AttentionVariant>(
                params, tmp_v, tmp_s, /*stream=*/torch_current_stream);
        TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
                    cudaGetErrorString(status));
        return true;
      });
    });
  });

  if (return_lse) {
    return {o, lse};
  } else {
    return {o};
  }
}
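A hedged usage sketch (not part of the diff): the split above means a caller plans once per batch layout and then runs repeatedly against the cached plan. Everything below is illustrative; the wrapper name, the chosen constants (layout code 0, no sliding window, no soft cap, default RoPE parameters), and the way the empty dtype-probe tensors are built are assumptions, not this PR's actual Python-side wiring.

```cuda
// Hypothetical caller of the two functions defined above (same translation unit assumed).
std::vector<torch::Tensor> DecodeOnce(
    torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
    torch::Tensor page_locked_int_workspace_buffer, torch::Tensor q,
    torch::Tensor paged_kv_cache, torch::Tensor paged_kv_indptr,
    torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
    unsigned int num_kv_heads, unsigned int page_size, float sm_scale) {
  unsigned int batch_size = q.size(0);
  unsigned int num_qo_heads = q.size(1);
  unsigned int head_dim = q.size(2);

  // Plan once: the returned vector is an opaque, serializable DecodePlanInfo that
  // Run re-hydrates, so it can be reused as long as the batch layout is unchanged.
  std::vector<int64_t> plan_vec = BatchDecodeWithPagedKVCachePlan(
      /*use_logits_soft_cap=*/false, head_dim,
      /*empty_q_data=*/torch::empty({0}, q.options()),
      /*empty_kv_data=*/torch::empty({0}, paged_kv_cache.options()), float_workspace_buffer,
      int_workspace_buffer, page_locked_int_workspace_buffer, paged_kv_indptr, batch_size,
      num_qo_heads, num_kv_heads, page_size, /*enable_cuda_graph=*/false);

  // Run against the plan; window_left = -1 is assumed to mean "no sliding window".
  return BatchDecodeWithPagedKVCacheRun(
      float_workspace_buffer, int_workspace_buffer, plan_vec, q, paged_kv_cache,
      /*paged_k_cache=*/std::nullopt, /*paged_v_cache=*/std::nullopt, paged_kv_indptr,
      paged_kv_indices, paged_kv_last_page_len, /*alibi_slopes=*/std::nullopt,
      /*kv_layout_code=*/0, /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale,
      /*rope_scale=*/1.f, /*rope_theta=*/1e4f, /*return_lse=*/true);
}
```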
