Commit a059586
perf: speedup jit compilation of prefill attention kernels (#632)
Follow-up to #628: this PR splits the prefill attention JIT templates so that different mask modes are compiled in separate files. With this change, JIT compilation of the prefill kernels for a given configuration (shape, dtype, etc.) can be reduced to about 10 seconds.
1 parent 5bf36ce commit a059586
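For illustration only, here is a minimal sketch (not the actual FlashInfer code) of the splitting idea the description above refers to: instead of generating one large .cu file that instantiates a kernel for every mask mode, the generator emits one small .cu per mask mode, so each translation unit stays small and can be compiled independently (and in parallel). The template string, file names, and kernel name below are hypothetical.

import pathlib

# The three mask modes that the real templates instantiate separately.
MASK_MODES = ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]


def kernel_inst_source(mask_mode: str) -> str:
    # Hypothetical template: one explicit instantiation per generated file,
    # parameterized by the mask mode, mirroring the per-mask-mode .cu files
    # this commit introduces.
    return (
        '#include "prefill_kernel.cuh"\n\n'
        "template cudaError_t PrefillKernelDispatched<"
        f"{mask_mode}>(Params params, cudaStream_t stream);\n"
    )


def write_instantiation_files(out_dir: str) -> list:
    # One small .cu per mask mode instead of a single big .cu for all of them;
    # each file can then be handed to the compiler as a separate translation unit.
    out = pathlib.Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    paths = []
    for code, mode in enumerate(MASK_MODES):
        path = out / f"_kernel_mask_{code}.cu"
        path.write_text(kernel_inst_source(mode))
        paths.append(path)
    return paths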

File tree: 2 files changed, +224 -18 lines


python/flashinfer/jit/batch_prefill_templ.py

+123 -6
@@ -14,15 +14,102 @@
 limitations under the License.
 """
 
-import itertools
-
 batch_prefill_suffix = [
     "_plan.cu",
+    *[f"_ragged_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]],
     "_ragged_run.cu",
+    *[f"_paged_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]],
     "_paged_run.cu",
     "_pybind.cc",
 ]
 
+
+def ragged_prefill_inst_templ(mask_mode: str) -> str:
+    return (
+        r"""#include <flashinfer/attention/prefill.cuh>
+#include <flashinfer/attention/prefill_params.cuh>
+#include <flashinfer/attention/variants.cuh>
+
+namespace flashinfer {
+
+{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
+using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
+constexpr bool use_custom_mask = """
+        + mask_mode
+        + r""" == MaskMode::kCustom;
+using RaggedAttentionVariant = ComposedAttention<RaggedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
+
+template
+cudaError_t BatchPrefillWithRaggedKVCacheDispatched</*cta_tile_q=*/16, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+        + mask_mode
+        + r""", RaggedAttentionVariant>(
+    typename RaggedAttentionVariant::ParamsT params,
+    typename RaggedAttentionVariant::DTypeO* tmp_v,
+    float* tmp_s, cudaStream_t stream);
+
+template
+cudaError_t BatchPrefillWithRaggedKVCacheDispatched</*cta_tile_q=*/64, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+        + mask_mode
+        + r""", RaggedAttentionVariant>(
+    typename RaggedAttentionVariant::ParamsT params,
+    typename RaggedAttentionVariant::DTypeO* tmp_v,
+    float* tmp_s, cudaStream_t stream);
+
+template
+cudaError_t BatchPrefillWithRaggedKVCacheDispatched</*cta_tile_q=*/128, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+        + mask_mode
+        + r""", RaggedAttentionVariant>(
+    typename RaggedAttentionVariant::ParamsT params,
+    typename RaggedAttentionVariant::DTypeO* tmp_v,
+    float* tmp_s, cudaStream_t stream);
+}
+"""
+    )
+
+
+def paged_prefill_inst_templ(mask_mode: str) -> str:
+    return (
+        r"""#include <flashinfer/attention/prefill.cuh>
+#include <flashinfer/attention/prefill_params.cuh>
+#include <flashinfer/attention/variants.cuh>
+
+namespace flashinfer {
+
+{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
+using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
+constexpr bool use_custom_mask = """
+        + mask_mode
+        + r""" == MaskMode::kCustom;
+using PagedAttentionVariant = ComposedAttention<PagedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
+
+template
+cudaError_t BatchPrefillWithPagedKVCacheDispatched</*cta_tile_q=*/16, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+        + mask_mode
+        + r""", PagedAttentionVariant>(
+    typename PagedAttentionVariant::ParamsT params,
+    typename PagedAttentionVariant::DTypeO* tmp_v,
+    float* tmp_s, cudaStream_t stream);
+
+template
+cudaError_t BatchPrefillWithPagedKVCacheDispatched</*cta_tile_q=*/64, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+        + mask_mode
+        + r""", PagedAttentionVariant>(
+    typename PagedAttentionVariant::ParamsT params,
+    typename PagedAttentionVariant::DTypeO* tmp_v,
+    float* tmp_s, cudaStream_t stream);
+
+template
+cudaError_t BatchPrefillWithPagedKVCacheDispatched</*cta_tile_q=*/128, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+        + mask_mode
+        + r""", PagedAttentionVariant>(
+    typename PagedAttentionVariant::ParamsT params,
+    typename PagedAttentionVariant::DTypeO* tmp_v,
+    float* tmp_s, cudaStream_t stream);
+}
+"""
+    )
+
+
 batch_prefill_templ = [
     r"""#include <flashinfer/attention/scheduler.cuh>
 #include "pytorch_extension_utils.h"
@@ -60,10 +147,15 @@
   return plan_info.ToVector();
 }
 """,
+    *[
+        ragged_prefill_inst_templ(mask_mode)
+        for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
+    ],
     r"""
 #include <optional>
+#include <flashinfer/pos_enc.cuh>
 #include <flashinfer/attention/scheduler.cuh>
-#include <flashinfer/attention/prefill.cuh>
+#include <flashinfer/attention/mask.cuh>
 #include <flashinfer/attention/prefill_params.cuh>
 #include <flashinfer/attention/variants.cuh>
 #include "pytorch_extension_utils.h"
@@ -73,6 +165,16 @@
 {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
 using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
 
+namespace flashinfer {
+
+template <uint32_t CTA_TILE_Q, uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE,
+          bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename AttentionVariant>
+cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::ParamsT params,
+                                                    typename AttentionVariant::DTypeO* tmp_v,
+                                                    float* tmp_s, cudaStream_t stream);
+
+};
+
 void BatchPrefillWithRaggedKVCacheRun(
     unsigned int mask_mode_code,
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
@@ -153,7 +255,7 @@
   constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
   using RaggedAttentionVariant = ComposedAttention<RaggedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
   DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
-    status = BatchPrefillWithRaggedKVCacheDispatched<
+    status = flashinfer::BatchPrefillWithRaggedKVCacheDispatched<
         CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, RaggedAttentionVariant>(
         params, tmp_v, tmp_s, stream);
   });
@@ -162,9 +264,14 @@
   TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", cudaGetErrorString(status));
 }
 """,
+    *[
+        paged_prefill_inst_templ(mask_mode)
+        for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
+    ],
     r"""#include <optional>
+#include <flashinfer/pos_enc.cuh>
 #include <flashinfer/attention/scheduler.cuh>
-#include <flashinfer/attention/prefill.cuh>
+#include <flashinfer/attention/mask.cuh>
 #include <flashinfer/attention/prefill_params.cuh>
 #include <flashinfer/attention/variants.cuh>
 #include "pytorch_extension_utils.h"
@@ -174,6 +281,16 @@
 {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
 using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
 
+namespace flashinfer {
+
+template <uint32_t CTA_TILE_Q, uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE,
+          bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename AttentionVariant>
+cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params,
+                                                   typename AttentionVariant::DTypeO* tmp_v,
+                                                   float* tmp_s, cudaStream_t stream);
+
+};
+
 void BatchPrefillWithPagedKVCacheRun(
     unsigned int mask_mode_code,
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
@@ -274,7 +391,7 @@
   constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
   using PagedAttentionVariant = ComposedAttention<PagedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
   DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
-    status = BatchPrefillWithPagedKVCacheDispatched<
+    status = flashinfer::BatchPrefillWithPagedKVCacheDispatched<
         CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, PagedAttentionVariant>(
         params, tmp_v, tmp_s, stream);
   });

python/flashinfer/jit/single_prefill_templ.py

+101 -12
@@ -15,19 +15,12 @@
 """
 
 single_prefill_suffix = [
+    *[f"_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]],
     ".cu",
     "_pybind.cc",
 ]
 
-customizable_single_prefill_templ = [
-    r"""
-#include <optional>
-#include <flashinfer/attention/prefill.cuh>
-#include "pytorch_extension_utils.h"
-
-using namespace flashinfer;
-
-
+customizable_struct_templ = r"""
 struct SinglePrefillParams {
   using DTypeQ = {{ dtype_q }};
   using DTypeKV = {{ dtype_kv }};
@@ -82,10 +75,63 @@
     return kv_len;
   }
 };
+"""
+
+
+def customizable_single_prefill_inst_templ(mask_mode: str) -> str:
+    return (
+        r"""#include <flashinfer/attention/prefill.cuh>
+
+using namespace flashinfer;
+"""
+        + customizable_struct_templ
+        + r"""{{ variant_decl }}
+using ParamsT = SinglePrefillParams;
+using AttentionVariant = {{ variant_name }}<ParamsT>;
+
+namespace flashinfer {
+
+template
+cudaError_t SinglePrefillWithKVCacheDispatched<{{ head_dim }}, PosEncodingMode::kNone, false, """
+        f"{mask_mode}"
+        r""", AttentionVariant>(
+    typename AttentionVariant::ParamsT params,
+    typename AttentionVariant::DTypeO* tmp,
+    cudaStream_t stream);
+
+};
+"""
+    )
+
+
+customizable_single_prefill_templ = [
+    *[
+        customizable_single_prefill_inst_templ(mask_mode)
+        for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
+    ],
+    r"""
+#include <optional>
+#include <flashinfer/pos_enc.cuh>
+#include <flashinfer/attention/mask.cuh>
+#include "pytorch_extension_utils.h"
 
+using namespace flashinfer;
 
+"""
+    + customizable_struct_templ
+    + r"""
 {{ variant_decl }}
 
+namespace flashinfer {
+
+template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION,
+          MaskMode MASK_MODE, typename AttentionVariant>
+cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT params,
+                                               typename AttentionVariant::DTypeO* tmp,
+                                               cudaStream_t stream);
+
+}
+
 at::Tensor single_prefill_with_kv_cache(
     unsigned int mask_mode_code, at::Tensor q, at::Tensor k, at::Tensor v,
     at::Tensor tmp, at::Tensor o, unsigned int layout, int32_t window_left,
@@ -155,10 +201,43 @@
 """,
 ]
 
+
+def single_prefill_inst_templ(mask_mode: str) -> str:
+    return (
+        r"""#include <flashinfer/attention/prefill.cuh>
+#include <flashinfer/attention/prefill_params.cuh>
+#include <flashinfer/attention/variants.cuh>
+
+namespace flashinfer {
+
+{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
+using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>;
+constexpr bool use_custom_mask = """
+        f"{mask_mode}"
+        r"""== MaskMode::kCustom;
+using AttentionVariant = ComposedAttention<ParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
+
+template
+cudaError_t SinglePrefillWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+        f"{mask_mode}"
+        r""", AttentionVariant>(
+    typename AttentionVariant::ParamsT params,
+    typename AttentionVariant::DTypeO* tmp,
+    cudaStream_t stream);
+
+}
+"""
+    )
+
+
 single_prefill_templ = [
-    r"""
-#include <optional>
-#include <flashinfer/attention/prefill.cuh>
+    *[
+        single_prefill_inst_templ(mask_mode)
+        for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
+    ],
+    r"""#include <optional>
+#include <flashinfer/pos_enc.cuh>
+#include <flashinfer/attention/mask.cuh>
 #include <flashinfer/attention/variants.cuh>
 #include <flashinfer/attention/prefill_params.cuh>
 #include "pytorch_extension_utils.h"
@@ -168,6 +247,16 @@
 {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
 using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>;
 
+namespace flashinfer {
+
+template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION,
+          MaskMode MASK_MODE, typename AttentionVariant>
+cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT params,
+                                               typename AttentionVariant::DTypeO* tmp,
+                                               cudaStream_t stream);
+
+}
+
 void single_prefill_with_kv_cache(
     unsigned int mask_mode_code,
     at::Tensor q, at::Tensor k, at::Tensor v, std::optional<at::Tensor> maybe_packed_custom_mask,
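A side note on the numeric suffixes (an inference from the ordering in the lists above, not something the diff states explicitly): the _kernel_mask_{0,1,2}.cu file names appear to pair up, in order, with the three MaskMode enumerators passed to the *_inst_templ helpers. A tiny sketch of that assumed mapping:

# Assumed pairing of suffix codes with mask modes (kNone=0, kCausal=1, kCustom=2);
# the diff lists both sequences in the same order, but the mapping is an inference.
mask_mode_names = ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
for code, name in enumerate(mask_mode_names):
    print(f"_kernel_mask_{code}.cu  <-  single_prefill_inst_templ({name!r})")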
