|
81 | 81 |
|
82 | 82 | {{ variant_decl }}
|
83 | 83 |
|
84 |
| -std::vector<torch::Tensor> single_prefill_with_kv_cache( |
85 |
| - torch::Tensor q, torch::Tensor k, torch::Tensor v, |
86 |
| - torch::Tensor tmp, unsigned int layout, int32_t window_left, bool return_lse{{ additional_func_params }}) { |
| 84 | +torch::Tensor single_prefill_with_kv_cache( |
| 85 | + unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v, |
| 86 | + torch::Tensor tmp, unsigned int layout, int32_t window_left, std::optional<torch::Tensor> maybe_lse{{ additional_func_params }}) { |
87 | 87 | auto device = q.device();
|
88 | 88 | unsigned int head_dim = q.size(2);
|
89 | 89 | unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads;
|
|
104 | 104 | }
|
105 | 105 | cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
|
106 | 106 | auto o = torch::empty_like(q, q.options());
|
107 |
| - torch::Tensor lse = torch::empty({0}); |
108 |
| - if (return_lse) { |
109 |
| - lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32)); |
| 107 | + if (maybe_lse) { |
| 108 | + const auto& lse = *maybe_lse; |
| 109 | + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); |
| 110 | + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); |
| 111 | + TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32"); |
110 | 112 | }
|
111 | 113 |
|
112 | 114 | using ParamsT = SinglePrefillParams;
|
|
115 | 117 | static_cast<{{ dtype_q }}*>(q.data_ptr()), static_cast<{{ dtype_kv }}*>(k.data_ptr()),
|
116 | 118 | static_cast<{{ dtype_kv }}*>(v.data_ptr()),
|
117 | 119 | static_cast<{{ dtype_o }}*>(o.data_ptr()),
|
118 |
| - /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, |
| 120 | + /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr), |
119 | 121 | num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h,
|
120 | 122 | kv_stride_n, kv_stride_h, head_dim, window_left{{ additional_params_data }});
|
121 | 123 |
|
122 |
| - cudaError_t status = |
123 |
| - SinglePrefillWithKVCacheDispatched<{{ head_dim }}, PosEncodingMode::kNone, false, {{ mask_mode }}, AttentionVariant>( |
124 |
| - params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream); |
125 |
| - TORCH_CHECK(status == cudaSuccess, |
126 |
| - "SinglePrefillWithKVCache kernel launch failed, error: " + |
127 |
| - std::string(cudaGetErrorString(status))); |
| 124 | + MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code); |
128 | 125 |
|
129 |
| - if (return_lse) { |
130 |
| - return {o, lse}; |
131 |
| - } else { |
132 |
| - return {o}; |
133 |
| - } |
| 126 | + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { |
| 127 | + cudaError_t status = |
| 128 | + SinglePrefillWithKVCacheDispatched<{{ head_dim }}, PosEncodingMode::kNone, false, MASK_MODE, AttentionVariant>( |
| 129 | + params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream); |
| 130 | + TORCH_CHECK(status == cudaSuccess, |
| 131 | + "SinglePrefillWithKVCache kernel launch failed, error: " + |
| 132 | + std::string(cudaGetErrorString(status))); |
| 133 | + }); |
| 134 | +
|
| 135 | + return o; |
134 | 136 | }
|
135 | 137 |
|
136 | 138 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
149 | 151 |
|
150 | 152 | using namespace flashinfer;
|
151 | 153 |
|
152 |
| -{% set use_custom_mask = "true" if mask_mode == "MaskMode::kCustom" else "false" %} |
153 | 154 | {% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %}
|
154 | 155 | using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>;
|
155 |
| -using AttentionVariant = ComposedAttention<ParamsT, get_variant_code({{ use_custom_mask }}, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>; |
156 | 156 |
|
157 | 157 | torch::Tensor single_prefill_with_kv_cache(
|
| 158 | + unsigned int mask_mode_code, |
158 | 159 | torch::Tensor q, torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_packed_custom_mask,
|
159 | 160 | torch::Tensor tmp, std::optional<torch::Tensor> maybe_alibi_slopes, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
|
160 | 161 | float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse) {
|
|
188 | 189 | ParamsT params(
|
189 | 190 | static_cast<{{ dtype_q }}*>(q.data_ptr()), static_cast<{{ dtype_kv }}*>(k.data_ptr()),
|
190 | 191 | static_cast<{{ dtype_kv }}*>(v.data_ptr()),
|
191 |
| - {% if mask_mode == "MaskMode::kCustom" %}static_cast<uint8_t*>(maybe_packed_custom_mask->data_ptr()){% else %}nullptr{% endif %}, |
| 192 | + /*custom_mask=*/(maybe_packed_custom_mask ? static_cast<uint8_t*>(maybe_packed_custom_mask->data_ptr()) : nullptr), |
192 | 193 | static_cast<{{ dtype_o }}*>(o.data_ptr()),
|
193 | 194 | /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
|
194 | 195 | {% if use_alibi == "true" %}static_cast<float*>(maybe_alibi_slopes->data_ptr()){% else %}nullptr{% endif %},
|
195 | 196 | num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h,
|
196 | 197 | kv_stride_n, kv_stride_h, head_dim, window_left, logits_soft_cap, sm_scale,
|
197 | 198 | rope_scale, rope_theta);
|
198 | 199 |
|
199 |
| - cudaError_t status = |
200 |
| - SinglePrefillWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, AttentionVariant>( |
201 |
| - params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream); |
202 |
| - TORCH_CHECK(status == cudaSuccess, |
203 |
| - "SinglePrefillWithKVCache kernel launch failed, error: " + |
204 |
| - std::string(cudaGetErrorString(status))); |
| 200 | +
|
| 201 | + MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code); |
| 202 | +
|
| 203 | + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { |
| 204 | + constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; |
| 205 | + using AttentionVariant = ComposedAttention<ParamsT, get_variant_code(use_custom_mask, {{ use_sliding_window }}, {{ use_logits_soft_cap }}, {{ use_alibi }})>; |
| 206 | + cudaError_t status = |
| 207 | + SinglePrefillWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, MASK_MODE, AttentionVariant>( |
| 208 | + params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream); |
| 209 | + TORCH_CHECK(status == cudaSuccess, |
| 210 | + "SinglePrefillWithKVCache kernel launch failed, error: " + |
| 211 | + std::string(cudaGetErrorString(status))); |
| 212 | + }); |
205 | 213 |
|
206 | 214 | return o;
|
207 | 215 | }
|
|
0 commit comments