@@ -14,15 +14,102 @@
 limitations under the License.
 """
 
-import itertools
-
 batch_prefill_suffix = [
     "_plan.cu",
+    *[f"_ragged_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]],
     "_ragged_run.cu",
+    *[f"_paged_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]],
     "_paged_run.cu",
     "_pybind.cc",
 ]
 
+
+def ragged_prefill_inst_templ(mask_mode: str) -> str:
+    return (
+        r"""#include <flashinfer/attention/prefill.cuh>
+#include <flashinfer/attention/prefill_params.cuh>
+#include <flashinfer/attention/variants.cuh>
+
+namespace flashinfer {
+
+{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
+using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
+constexpr bool use_custom_mask = """
+        + mask_mode
+        + r""" == MaskMode::kCustom;
+using RaggedAttentionVariant = ComposedAttention<RaggedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
+
+template
+cudaError_t BatchPrefillWithRaggedKVCacheDispatched</*cta_tile_q=*/16, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+        + mask_mode
+        + r""", RaggedAttentionVariant>(
+    typename RaggedAttentionVariant::ParamsT params,
+    typename RaggedAttentionVariant::DTypeO* tmp_v,
+    float* tmp_s, cudaStream_t stream);
+
+template
+cudaError_t BatchPrefillWithRaggedKVCacheDispatched</*cta_tile_q=*/64, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+        + mask_mode
+        + r""", RaggedAttentionVariant>(
+    typename RaggedAttentionVariant::ParamsT params,
+    typename RaggedAttentionVariant::DTypeO* tmp_v,
+    float* tmp_s, cudaStream_t stream);
+
+template
+cudaError_t BatchPrefillWithRaggedKVCacheDispatched</*cta_tile_q=*/128, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+        + mask_mode
+        + r""", RaggedAttentionVariant>(
+    typename RaggedAttentionVariant::ParamsT params,
+    typename RaggedAttentionVariant::DTypeO* tmp_v,
+    float* tmp_s, cudaStream_t stream);
+}
+"""
+    )
+
+
+def paged_prefill_inst_templ(mask_mode: str) -> str:
+    return (
+        r"""#include <flashinfer/attention/prefill.cuh>
+#include <flashinfer/attention/prefill_params.cuh>
+#include <flashinfer/attention/variants.cuh>
+
+namespace flashinfer {
+
+{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
+using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
+constexpr bool use_custom_mask = """
+        + mask_mode
+        + r""" == MaskMode::kCustom;
+using PagedAttentionVariant = ComposedAttention<PagedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
+
+template
+cudaError_t BatchPrefillWithPagedKVCacheDispatched</*cta_tile_q=*/16, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+        + mask_mode
+        + r""", PagedAttentionVariant>(
+    typename PagedAttentionVariant::ParamsT params,
+    typename PagedAttentionVariant::DTypeO* tmp_v,
+    float* tmp_s, cudaStream_t stream);
+
+template
+cudaError_t BatchPrefillWithPagedKVCacheDispatched</*cta_tile_q=*/64, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+        + mask_mode
+        + r""", PagedAttentionVariant>(
+    typename PagedAttentionVariant::ParamsT params,
+    typename PagedAttentionVariant::DTypeO* tmp_v,
+    float* tmp_s, cudaStream_t stream);
+
+template
+cudaError_t BatchPrefillWithPagedKVCacheDispatched</*cta_tile_q=*/128, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, """
+        + mask_mode
+        + r""", PagedAttentionVariant>(
+    typename PagedAttentionVariant::ParamsT params,
+    typename PagedAttentionVariant::DTypeO* tmp_v,
+    float* tmp_s, cudaStream_t stream);
+}
+"""
+    )
+
+
 batch_prefill_templ = [
     r"""#include <flashinfer/attention/scheduler.cuh>
 #include "pytorch_extension_utils.h"
@@ -60,10 +147,15 @@
   return plan_info.ToVector();
 }
 """,
+    *[
+        ragged_prefill_inst_templ(mask_mode)
+        for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
+    ],
     r"""
 #include <optional>
+#include <flashinfer/pos_enc.cuh>
 #include <flashinfer/attention/scheduler.cuh>
-#include <flashinfer/attention/prefill.cuh>
+#include <flashinfer/attention/mask.cuh>
 #include <flashinfer/attention/prefill_params.cuh>
 #include <flashinfer/attention/variants.cuh>
 #include "pytorch_extension_utils.h"
@@ -73,6 +165,16 @@
 {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
 using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
 
+namespace flashinfer {
+
+template <uint32_t CTA_TILE_Q, uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE,
+          bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename AttentionVariant>
+cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::ParamsT params,
+                                                    typename AttentionVariant::DTypeO* tmp_v,
+                                                    float* tmp_s, cudaStream_t stream);
+
+};
+
 void BatchPrefillWithRaggedKVCacheRun(
     unsigned int mask_mode_code,
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
@@ -153,7 +255,7 @@
   constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
   using RaggedAttentionVariant = ComposedAttention<RaggedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
   DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
-    status = BatchPrefillWithRaggedKVCacheDispatched<
+    status = flashinfer::BatchPrefillWithRaggedKVCacheDispatched<
         CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, RaggedAttentionVariant>(
         params, tmp_v, tmp_s, stream);
   });
@@ -162,9 +264,14 @@
   TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", cudaGetErrorString(status));
 }
 """,
+    *[
+        paged_prefill_inst_templ(mask_mode)
+        for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"]
+    ],
     r"""#include <optional>
+#include <flashinfer/pos_enc.cuh>
 #include <flashinfer/attention/scheduler.cuh>
-#include <flashinfer/attention/prefill.cuh>
+#include <flashinfer/attention/mask.cuh>
 #include <flashinfer/attention/prefill_params.cuh>
 #include <flashinfer/attention/variants.cuh>
 #include "pytorch_extension_utils.h"
@@ -174,6 +281,16 @@
 {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
 using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>;
 
+namespace flashinfer {
+
+template <uint32_t CTA_TILE_Q, uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE,
+          bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename AttentionVariant>
+cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params,
+                                                   typename AttentionVariant::DTypeO* tmp_v,
+                                                   float* tmp_s, cudaStream_t stream);
+
+};
+
 void BatchPrefillWithPagedKVCacheRun(
     unsigned int mask_mode_code,
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
@@ -274,7 +391,7 @@
   constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;
   using PagedAttentionVariant = ComposedAttention<PagedParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>;
   DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
-    status = BatchPrefillWithPagedKVCacheDispatched<
+    status = flashinfer::BatchPrefillWithPagedKVCacheDispatched<
         CTA_TILE_Q, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, PagedAttentionVariant>(
         params, tmp_v, tmp_s, stream);
   });
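
Note on how these templates are consumed: the diff only defines the per-mask-mode instantiation templates and names the generated files in batch_prefill_suffix. Below is a minimal sketch of a render step for the ragged case. The Jinja2 driver, the parameter values, and the output path are illustrative assumptions, not this repository's actual build code; only ragged_prefill_inst_templ and the mask-mode numbering (0/1/2 pairing with kNone/kCausal/kCustom by the order of the two comprehensions above) come from the diff.

# Hypothetical render step (not part of this commit): write one .cu
# translation unit per mask mode from the template defined above.
import jinja2

# Pairing inferred from the diff: suffixes iterate [0, 1, 2], templates
# iterate [kNone, kCausal, kCustom] in the same order.
MASK_MODES = {0: "MaskMode::kNone", 1: "MaskMode::kCausal", 2: "MaskMode::kCustom"}

example_params = {  # placeholder values for the {{ ... }} variables; assumptions
    "dtype_q": "half",
    "dtype_kv": "half",
    "dtype_o": "half",
    "dtype_idx": "int32_t",
    "head_dim": "128",
    "pos_encoding_mode": "PosEncodingMode::kNone",
    "use_fp16_qk_reduction": "false",
    "use_sliding_window": "false",
    "use_logits_soft_cap": "false",
}

for code, mask_mode in MASK_MODES.items():
    # mask_mode is spliced into the template by plain string concatenation,
    # so MASK_MODE is fixed at C++ compile time in each generated file.
    source = jinja2.Template(ragged_prefill_inst_templ(mask_mode)).render(**example_params)
    with open(f"batch_prefill_ragged_kernel_mask_{code}.cu", "w") as f:  # hypothetical path
        f.write(source)

The apparent motivation is build time: each (kernel family, mask mode) pair becomes its own small translation unit that can be compiled in parallel, while the run templates only keep extern-style declarations of the dispatched functions.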