
Commit fe4f898

feat: simplify prefill JIT compilation (#605)
Compile all three mask modes (causal / non-causal / custom) together instead of compiling them one by one; the mask mode is now passed as a run-time `mask_mode_code` argument and dispatched to the matching compile-time specialization inside the generated source.
1 parent bb67144 commit fe4f898

10 files changed, +190 -122 lines
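The core of the change is replacing the Jinja-time `mask_mode` template parameter with a run-time `mask_mode_code` argument that is dispatched to compile-time specializations inside the generated source. Below is a minimal, self-contained sketch of how such a dispatch macro can work; the names mirror `DISPATCH_MASK_MODE` and `MaskMode` from the diff, but the macro body is an illustration, not the library's exact source.

```cpp
#include <cstdio>

enum class MaskMode { kNone = 0, kCausal = 1, kCustom = 2 };

// Runtime-to-compile-time dispatch: each case binds MASK_MODE as a constexpr
// constant, so the body can instantiate templates on it. All three branches
// are compiled into the same translation unit.
#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...)     \
  switch (mask_mode) {                                    \
    case MaskMode::kNone: {                               \
      constexpr MaskMode MASK_MODE = MaskMode::kNone;     \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    case MaskMode::kCausal: {                             \
      constexpr MaskMode MASK_MODE = MaskMode::kCausal;   \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    case MaskMode::kCustom: {                             \
      constexpr MaskMode MASK_MODE = MaskMode::kCustom;   \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
  }

// Stand-in for a kernel specialized on the mask mode at compile time.
template <MaskMode MASK_MODE>
void run_kernel() {
  std::printf("running specialization %d\n", static_cast<int>(MASK_MODE));
}

int main() {
  unsigned int mask_mode_code = 1;  // arrives as a plain integer at run time
  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
  // All three run_kernel<> specializations exist in the binary; one is chosen here.
  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { run_kernel<MASK_MODE>(); });
  return 0;
}
```

Because every `case` branch instantiates the templated kernel, a single JIT compilation now contains causal, non-causal, and custom-mask kernels, and switching masks no longer triggers a recompile.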

python/flashinfer/decode.py (+7 -5)

@@ -400,12 +400,12 @@ def single_decode_with_kv_cache(
         q.dtype,
         head_dim,
         PosEncodingMode[pos_encoding_mode].value,
-        MaskMode.NON_CAUSAL.value,
         window_left != -1,  # use_sliding_window
         logits_soft_cap > 0,  # use_logits_soft_cap
         False,  # allow_fp16_qk_reduction
     )
     .run(
+        MaskMode.NON_CAUSAL.value,
         q.unsqueeze(0),
         k,
         v,
@@ -418,7 +418,7 @@ def single_decode_with_kv_cache(
         sm_scale,
         rope_scale,
         rope_theta,
-        False,  # return_lse
+        None,  # maybe_lse
     )[0]
     .squeeze(0)
 )
@@ -743,8 +743,10 @@ def plan(
 
         indptr_host = indptr.to("cpu")
         if data_type is not None:
-            q_data_type = data_type
-            kv_data_type = data_type
+            if q_data_type is None:
+                q_data_type = data_type
+            if kv_data_type is None:
+                kv_data_type = data_type
 
         q_data_type = canonicalize_torch_dtype(q_data_type)
         if kv_data_type is None:
@@ -761,7 +763,6 @@ def plan(
             indptr.dtype,
             head_dim,
             PosEncodingMode[pos_encoding_mode].value,
-            MaskMode.NON_CAUSAL.value,
             window_left != -1,  # use_sliding_window
             logits_soft_cap > 0,  # use_logits_soft_cap
             False,  # allow_fp16_qk_reduction
@@ -938,6 +939,7 @@ def run(
 
         if self.use_tensor_cores:
             out = self._cached_module.paged_run(
+                MaskMode.NON_CAUSAL.value,
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._plan_info,
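Two details in this file follow from the template changes below: decode attends a single query position to the whole KV cache, so the mask mode is fixed and `MaskMode.NON_CAUSAL.value` is now simply passed as the first run-time argument to `run`/`paged_run` rather than baked into the JIT compile key; and `False,  # return_lse` becomes `None,  # maybe_lse` because the bound function now accepts an optional caller-allocated lse tensor instead of a boolean flag (see single_prefill_templ.py below).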

python/flashinfer/jit/attention.py (-11)

@@ -33,7 +33,6 @@
 from .utils import (
     dtype_map,
     filename_safe_dtype_map,
-    mask_mode_literal,
     pos_encoding_mode_literal,
     write_if_different,
 )
@@ -216,7 +215,6 @@ def get_single_prefill_cu_str(
     dtype_o: torch.dtype,
     head_dim: int,
     pos_encoding_mode: int,
-    mask_mode: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
     use_fp16_qk_reduction: bool,
@@ -228,7 +226,6 @@
     dtype_o=dtype_map[dtype_o],
     head_dim=head_dim,
     pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode],
-    mask_mode=mask_mode_literal[mask_mode],
     use_sliding_window="true" if use_sliding_window else "false",
     use_logits_soft_cap="true" if use_logits_soft_cap else "false",
     use_fp16_qk_reduction="true" if use_fp16_qk_reduction else "false",
@@ -241,7 +238,6 @@ def get_single_prefill_uri(
     dtype_o: torch.dtype,
     head_dim: int,
     pos_encoding_mode: int,
-    mask_mode: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
     use_fp16_qk_reduction: bool,
@@ -252,7 +248,6 @@
     f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
     f"head_dim_{head_dim}_"
     f"posenc_{pos_encoding_mode}_"
-    f"mask_{mask_mode}_"
     f"use_swa_{use_sliding_window}_"
     f"use_logits_cap_{use_logits_soft_cap}_"
     f"f16qk_{use_fp16_qk_reduction}"
@@ -280,7 +275,6 @@ def get_batch_prefill_cu_str(
     dtype_idx: torch.dtype,
     head_dim: int,
     pos_encoding_mode: int,
-    mask_mode: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
     use_fp16_qk_reduction: bool,
@@ -293,7 +287,6 @@
     dtype_idx=dtype_map[dtype_idx],
     head_dim=head_dim,
     pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode],
-    mask_mode=mask_mode_literal[mask_mode],
     use_sliding_window="true" if use_sliding_window else "false",
     use_logits_soft_cap="true" if use_logits_soft_cap else "false",
     use_fp16_qk_reduction="true" if use_fp16_qk_reduction else "false",
@@ -307,7 +300,6 @@ def get_batch_prefill_uri(
     dtype_idx: torch.dtype,
     head_dim: int,
     pos_encoding_mode: int,
-    mask_mode: int,
     use_sliding_window: bool,
     use_logits_soft_cap: bool,
     use_fp16_qk_reduction: bool,
@@ -319,7 +311,6 @@
     f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
     f"head_dim_{head_dim}_"
     f"posenc_{pos_encoding_mode}_"
-    f"mask_{mask_mode}_"
     f"use_swa_{use_sliding_window}_"
     f"use_logits_cap_{use_logits_soft_cap}_"
     f"f16qk_{use_fp16_qk_reduction}"
@@ -424,7 +415,6 @@ def get_customize_single_prefill_cu_str(
     dtype_kv: torch.dtype,
     dtype_o: torch.dtype,
     head_dim: int,
-    mask_mode: int,
     additional_input_tensor_var_names: List[str],
     additional_input_tensor_var_types: List[str],
     additional_input_scalar_var_names: List[str],
@@ -489,7 +479,6 @@
     dtype_kv=dtype_map[dtype_kv],
     dtype_o=dtype_map[dtype_o],
     head_dim=head_dim,
-    mask_mode=mask_mode_literal[mask_mode],
     additional_params_decl=additional_params_decl,
     additional_params=additional_params,
     additional_params_init=additional_params_init,
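The practical effect of dropping `mask_{mask_mode}` from these URIs is that the JIT cache key no longer varies with mask mode: where a given configuration could previously produce up to three compiled modules (keys ending in `mask_0_...`, `mask_1_...`, `mask_2_...` — values shown for illustration), it now produces one, whose kernel variant is selected by the run-time `mask_mode_code`.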

python/flashinfer/jit/batch_prefill_templ.py (+26 -15)

@@ -25,12 +25,9 @@
 
 using namespace flashinfer;
 
-{% set use_custom_mask = "true" if mask_mode == "MaskMode::kCustom" else "false" %}
 {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
 using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
-using RaggedAttentionVariant = ComposedAttention<RaggedParamsT, get_variant_code({{ use_custom_mask }}, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
 using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
-using PagedAttentionVariant = ComposedAttention<PagedParamsT, get_variant_code({{ use_custom_mask }}, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
 
 std::vector<int64_t> BatchPrefillWithKVCachePlan(
     torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
@@ -68,6 +65,7 @@
 }
 
 torch::Tensor BatchPrefillWithRaggedKVCacheRun(
+    unsigned int mask_mode_code,
     torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
     std::vector<int64_t> plan_info_vec,
     torch::Tensor q, torch::Tensor k, torch::Tensor v,
@@ -109,10 +107,10 @@
   RaggedParamsT params(
       static_cast<{{ dtype_q }}*>(q.data_ptr()), static_cast<{{ dtype_kv }}*>(k.data_ptr()),
       static_cast<{{ dtype_kv }}*>(v.data_ptr()),
-      {% if mask_mode == "MaskMode::kCustom" %}static_cast<uint8_t*>(maybe_custom_mask->data_ptr()){% else %}nullptr{% endif %},
+      /*custom_mask=*/(maybe_custom_mask ? static_cast<uint8_t*>(maybe_custom_mask->data_ptr()) : nullptr),
      static_cast<{{ dtype_idx }}*>(qo_indptr.data_ptr()),
      static_cast<{{ dtype_idx }}*>(kv_indptr.data_ptr()),
-      {% if mask_mode == "MaskMode::kCustom" %}static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()){% else %}nullptr{% endif %},
+      /*qk_indptr=*/(maybe_qk_indptr ? static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()) : nullptr),
      /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr,
      static_cast<{{ dtype_o }}*>(o.data_ptr()),
      /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
@@ -141,10 +139,16 @@
 
   cudaError_t status = cudaSuccess;
 
-  DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
-    status = BatchPrefillWithRaggedKVCacheDispatched<
-        CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, RaggedAttentionVariant>(
-        params, tmp_v, tmp_s, torch_current_stream);
+  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
+
+  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
+    constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
+    using RaggedAttentionVariant = ComposedAttention<RaggedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
+    DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
+      status = BatchPrefillWithRaggedKVCacheDispatched<
+          CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, RaggedAttentionVariant>(
+          params, tmp_v, tmp_s, torch_current_stream);
+    });
   });
 
   TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", cudaGetErrorString(status));
@@ -153,6 +157,7 @@
 }
 
 torch::Tensor BatchPrefillWithPagedKVCacheRun(
+    unsigned int mask_mode_code,
     torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
     std::vector<int64_t> plan_info_vec,
     torch::Tensor q,
@@ -215,9 +220,9 @@
 
   PagedParamsT params(
       static_cast<{{ dtype_q }}*>(q.data_ptr()), paged_kv,
-      {% if mask_mode == "MaskMode::kCustom" %}static_cast<uint8_t*>(maybe_custom_mask->data_ptr()){% else %}nullptr{% endif %},
+      /*custom_mask=*/(maybe_custom_mask ? static_cast<uint8_t*>(maybe_custom_mask->data_ptr()) : nullptr),
      static_cast<{{ dtype_idx }}*>(qo_indptr.data_ptr()),
-      {% if mask_mode == "MaskMode::kCustom" %}static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()){% else %}nullptr{% endif %},
+      /*qk_indptr=*/(maybe_qk_indptr ? static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()) : nullptr),
      /*q_offset=*/nullptr,
      static_cast<{{ dtype_o }}*>(o.data_ptr()),
      /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
@@ -245,10 +250,16 @@
 
   cudaError_t status = cudaSuccess;
 
-  DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
-    status = BatchPrefillWithPagedKVCacheDispatched<
-        CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, PagedAttentionVariant>(
-        params, tmp_v, tmp_s, torch_current_stream);
+  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
+
+  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
+    constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
+    using PagedAttentionVariant = ComposedAttention<PagedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
+    DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
+      status = BatchPrefillWithPagedKVCacheDispatched<
+          CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, PagedAttentionVariant>(
+          params, tmp_v, tmp_s, torch_current_stream);
+    });
   });
 
   TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", cudaGetErrorString(status));
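Note the two Jinja conditionals that became runtime ternaries: `custom_mask` and `qk_indptr` may now be `nullptr` even though the same binary contains the custom-mask kernel. That is safe because the dereference is guarded by the compile-time `MASK_MODE` inside the dispatched kernel. A hypothetical device-side sketch of the pattern, using the `MaskMode` enum from the sketch near the top — `apply_mask`, `qk_index`, and the bit-packed mask layout (one bit per q/k pair) are assumptions, not the library's exact code:

```cpp
#include <cstddef>

// Only the kCustom instantiation touches params.custom_mask, so passing
// nullptr for the other mask modes is never dereferenced. MASK_MODE is the
// compile-time constant bound by DISPATCH_MASK_MODE.
template <MaskMode MASK_MODE, typename ParamsT>
__device__ __forceinline__ float apply_mask(const ParamsT& params, float logit,
                                            size_t qk_index, bool out_of_causal_range) {
  if constexpr (MASK_MODE == MaskMode::kCausal) {
    if (out_of_causal_range) logit = -5e4f;  // effectively -inf after softmax
  } else if constexpr (MASK_MODE == MaskMode::kCustom) {
    // Assumed layout: bit i of the packed mask gates the i-th (q, k) pair.
    if (!((params.custom_mask[qk_index / 8] >> (qk_index % 8)) & 1)) {
      logit = -5e4f;
    }
  }
  return logit;
}
```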

python/flashinfer/jit/single_prefill_templ.py (+35 -27)

@@ -81,9 +81,9 @@
 
 {{ variant_decl }}
 
-std::vector<torch::Tensor> single_prefill_with_kv_cache(
-    torch::Tensor q, torch::Tensor k, torch::Tensor v,
-    torch::Tensor tmp, unsigned int layout, int32_t window_left, bool return_lse{{ additional_func_params }}) {
+torch::Tensor single_prefill_with_kv_cache(
+    unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v,
+    torch::Tensor tmp, unsigned int layout, int32_t window_left, std::optional<torch::Tensor> maybe_lse{{ additional_func_params }}) {
   auto device = q.device();
   unsigned int head_dim = q.size(2);
   unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads;
@@ -104,9 +104,11 @@
   }
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q, q.options());
-  torch::Tensor lse = torch::empty({0});
-  if (return_lse) {
-    lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32));
+  if (maybe_lse) {
+    const auto& lse = *maybe_lse;
+    TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0));
+    TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1));
+    TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
   }
 
   using ParamsT = SinglePrefillParams;
@@ -115,22 +117,22 @@
   ParamsT params(
       static_cast<{{ dtype_q }}*>(q.data_ptr()), static_cast<{{ dtype_kv }}*>(k.data_ptr()),
      static_cast<{{ dtype_kv }}*>(v.data_ptr()),
      static_cast<{{ dtype_o }}*>(o.data_ptr()),
-      /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
+      /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
      num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h,
      kv_stride_n, kv_stride_h, head_dim, window_left{{ additional_params_data }});
 
-  cudaError_t status =
-      SinglePrefillWithKVCacheDispatched<{{ head_dim }}, PosEncodingMode::kNone, false, {{ mask_mode }}, AttentionVariant>(
-          params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream);
-  TORCH_CHECK(status == cudaSuccess,
-              "SinglePrefillWithKVCache kernel launch failed, error: " +
-                  std::string(cudaGetErrorString(status)));
+  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
 
-  if (return_lse) {
-    return {o, lse};
-  } else {
-    return {o};
-  }
+  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
+    cudaError_t status =
+        SinglePrefillWithKVCacheDispatched<{{ head_dim }}, PosEncodingMode::kNone, false, MASK_MODE, AttentionVariant>(
+            params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess,
+                "SinglePrefillWithKVCache kernel launch failed, error: " +
+                    std::string(cudaGetErrorString(status)));
+  });
+
+  return o;
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -149,12 +151,11 @@
 
 using namespace flashinfer;
 
-{% set use_custom_mask = "true" if mask_mode == "MaskMode::kCustom" else "false" %}
 {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
 using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>;
-using AttentionVariant = ComposedAttention<ParamsT, get_variant_code({{ use_custom_mask }}, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
 
 torch::Tensor single_prefill_with_kv_cache(
+    unsigned int mask_mode_code,
     torch::Tensor q, torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_packed_custom_mask,
     torch::Tensor tmp, std::optional<torch::Tensor> maybe_alibi_slopes, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
     float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse) {
@@ -188,20 +189,27 @@
   ParamsT params(
       static_cast<{{ dtype_q }}*>(q.data_ptr()), static_cast<{{ dtype_kv }}*>(k.data_ptr()),
      static_cast<{{ dtype_kv }}*>(v.data_ptr()),
-      {% if mask_mode == "MaskMode::kCustom" %}static_cast<uint8_t*>(maybe_packed_custom_mask->data_ptr()){% else %}nullptr{% endif %},
+      /*custom_mask=*/(maybe_packed_custom_mask ? static_cast<uint8_t*>(maybe_packed_custom_mask->data_ptr()) : nullptr),
      static_cast<{{ dtype_o }}*>(o.data_ptr()),
      /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
      {% if use_alibi == "true" %}static_cast<float*>(maybe_alibi_slopes->data_ptr()){% else %}nullptr{% endif %},
      num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h,
      kv_stride_n, kv_stride_h, head_dim, window_left, logits_soft_cap, sm_scale,
      rope_scale, rope_theta);
 
-  cudaError_t status =
-      SinglePrefillWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, AttentionVariant>(
-          params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream);
-  TORCH_CHECK(status == cudaSuccess,
-              "SinglePrefillWithKVCache kernel launch failed, error: " +
-                  std::string(cudaGetErrorString(status)));
+
+  MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
+
+  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {
+    constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
+    using AttentionVariant = ComposedAttention<ParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
+    cudaError_t status =
+        SinglePrefillWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, AttentionVariant>(
+            params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess,
+                "SinglePrefillWithKVCache kernel launch failed, error: " +
+                    std::string(cudaGetErrorString(status)));
+  });
 
   return o;
 }
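For callers, the lse buffer is now allocated on the caller side and validated by the TORCH_CHECKs added above, rather than returned in a `{o, lse}` vector. A usage sketch against the first (customizable) signature in this diff; the surrounding setup (`q` of shape [qo_len, num_qo_heads, head_dim], `need_lse`, `layout`) is assumed for illustration:

```cpp
// Sketch, not library code: pre-allocate lse and pass it as std::optional.
std::optional<torch::Tensor> maybe_lse;
if (need_lse) {
  // Must be float32 with shape [qo_len, num_qo_heads] to pass the checks above.
  maybe_lse = torch::empty({qo_len, num_qo_heads},
                           q.options().dtype(torch::kFloat32));
}
torch::Tensor o = single_prefill_with_kv_cache(
    /*mask_mode_code=*/static_cast<unsigned int>(MaskMode::kCausal),
    q, k, v, tmp, /*layout=*/0, /*window_left=*/-1, maybe_lse);
```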
