/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/extension.h>

#include <flashinfer/attention/decode_params.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/variants.cuh>
#include <optional>

#include "pytorch_extension_utils.h"

namespace flashinfer {

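// BatchDecodeWithPagedKVCacheDispatched is only forward-declared here; its definition is
// expected to be compiled in a separate kernel translation unit and resolved at link time.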
template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, typename AttentionVariant>
cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params,
                                                  typename AttentionVariant::DTypeO* tmp_v,
                                                  float* tmp_s, cudaStream_t stream);

}  // namespace flashinfer

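// Planning phase of batch decode with a paged KV cache: inspects the (CPU) indptr array,
// partitions the work (including split-KV scheduling) and writes the resulting metadata into
// the float/int workspace buffers. The returned vector of int64_t encodes the plan and is
// passed back verbatim to BatchDecodeWithPagedKVCacheRun.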
std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
    bool use_logits_soft_cap, unsigned int head_dim, torch::Tensor empty_q_data,
    torch::Tensor empty_kv_data, torch::Tensor float_workspace_buffer,
    torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer,
    torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads,
    unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph) {
  size_t float_workspace_size_in_bytes =
      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
  size_t int_workspace_size_in_bytes =
      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
  auto device = float_workspace_buffer.device();
  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
  indptr = indptr.to(torch::kCPU);

  DecodePlanInfo plan_info;

  using IdType = int32_t;
  // check that indptr has dtype int32
  TORCH_CHECK(indptr.scalar_type() == torch::kInt32, "indptr must be int32");
  constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

  auto q_scalar_type = empty_q_data.scalar_type();
  auto kv_scalar_type = empty_kv_data.scalar_type();

  DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
    using DTypeQ = q_type;
    using DTypeKV = kv_type;
    using DTypeO = DTypeQ;
    return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
      return DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
        using ParamsT = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
        using AttentionVariant =
            ComposedAttention<ParamsT, get_variant_code(/*use_custom_mask=*/false,
                                                        /*use_sliding_window=*/true,
                                                        USE_LOGITS_SOFT_CAP, /*use_alibi=*/false)>;

        cudaError_t status = DecodePlan<HEAD_DIM, POS_ENCODING_MODE, AttentionVariant>(
            static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
            static_cast<void*>(int_workspace_buffer.data_ptr()),
            static_cast<void*>(page_locked_int_workspace_buffer.data_ptr()),
            int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(indptr.data_ptr()),
            batch_size, num_qo_heads, num_kv_heads, page_size, enable_cuda_graph,
            /*stream=*/torch_current_stream);

        TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
                    cudaGetErrorString(status));
        return true;
      });
    });
  });

  return plan_info.ToVector();
}

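// Execution phase of batch decode with a paged KV cache: runs the attention kernel using the
// schedule produced by BatchDecodeWithPagedKVCachePlan. The KV cache may be supplied either as
// a single combined tensor (paged_kv_cache) or as separate K/V tensors (paged_k_cache /
// paged_v_cache). Returns {o} or {o, lse} depending on return_lse.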
std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
    torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
    std::vector<int64_t> plan_info_vec, torch::Tensor q,
    std::optional<torch::Tensor> paged_kv_cache, std::optional<torch::Tensor> paged_k_cache,
    std::optional<torch::Tensor> paged_v_cache, torch::Tensor paged_kv_indptr,
    torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
    std::optional<torch::Tensor> alibi_slopes, unsigned int kv_layout_code, int window_left,
    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
  DecodePlanInfo plan_info;
  plan_info.FromVector(plan_info_vec);
  QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
  bool paged_kv_defined = paged_kv_cache.has_value();
  auto device = q.device();
  int64_t batch_size = q.size(0);
  int64_t num_qo_heads = q.size(1);
  int64_t num_kv_heads, page_size;
  if (paged_kv_defined) {
    if (kv_layout == QKVLayout::kHND) {
      num_kv_heads = paged_kv_cache->size(2);
      page_size = paged_kv_cache->size(3);
    } else {
      page_size = paged_kv_cache->size(2);
      num_kv_heads = paged_kv_cache->size(3);
    }
  } else {
    if (kv_layout == QKVLayout::kHND) {
      num_kv_heads = paged_k_cache->size(1);
      page_size = paged_k_cache->size(2);
    } else {
      page_size = paged_k_cache->size(1);
      num_kv_heads = paged_k_cache->size(2);
    }
  }
  uint32_t head_dim = q.size(2);

  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
  torch::Tensor o = torch::empty_like(q);
  torch::Tensor lse;
  if (return_lse) {
    lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype(torch::kFloat32));
  }

  TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");

  void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr());
  void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr());

  using IdType = int32_t;
  constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

  // get q_scalar_type and kv_scalar_type
  auto q_scalar_type = q.scalar_type();
  auto kv_scalar_type =
      paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type();

  DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
    using DTypeQ = q_type;
    using DTypeKV = kv_type;
    using DTypeO = DTypeQ;
    return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
      return DISPATCH_LOGITS_SOFT_CAP(logits_soft_cap > 0, USE_LOGITS_SOFT_CAP, [&] {
        using ParamsT = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
        using AttentionVariant =
            ComposedAttention<ParamsT, get_variant_code(/*use_custom_mask=*/false,
                                                        /*use_sliding_window=*/true,
                                                        USE_LOGITS_SOFT_CAP, /*use_alibi=*/false)>;

        paged_kv_t<DTypeKV, IdType> paged_kv(
            num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
            static_cast<DTypeKV*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
                                                             : nullptr),
            static_cast<DTypeKV*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr),
            static_cast<DTypeKV*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr),
            static_cast<IdType*>(paged_kv_indices.data_ptr()),
            static_cast<IdType*>(paged_kv_indptr.data_ptr()),
            static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
        ParamsT params(static_cast<DTypeQ*>(q.data_ptr()),
                       /*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o.data_ptr()),
                       /*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr),
                       /*alibi_slopes=*/nullptr, num_qo_heads, window_left, logits_soft_cap,
                       sm_scale, rope_scale, rope_theta);

        DTypeO* tmp_v = nullptr;
        float* tmp_s = nullptr;
        params.request_indices =
            GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.request_indices_offset);
        params.kv_tile_indices =
            GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_tile_indices_offset);
        params.o_indptr = GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.o_indptr_offset);
        params.kv_chunk_size_ptr =
            GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_chunk_size_ptr_offset);
        if (plan_info.split_kv) {
          tmp_v = GetPtrFromBaseOffset<DTypeO>(float_buffer, plan_info.v_offset);
          tmp_s = GetPtrFromBaseOffset<float>(float_buffer, plan_info.s_offset);
          if (plan_info.enable_cuda_graph) {
            params.block_valid_mask =
                GetPtrFromBaseOffset<bool>(int_buffer, plan_info.block_valid_mask_offset);
          }
        }
        params.padded_batch_size = plan_info.padded_batch_size;

        cudaError_t status =
            flashinfer::BatchDecodeWithPagedKVCacheDispatched<HEAD_DIM, POS_ENCODING_MODE,
                                                              AttentionVariant>(
                params, tmp_v, tmp_s, /*stream=*/torch_current_stream);
        TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
                    cudaGetErrorString(status));
        return true;
      });
    });
  });

  if (return_lse) {
    return {o, lse};
  } else {
    return {o};
  }
}
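
// A minimal sketch of how these two entry points could be exposed to Python with pybind11.
// This is an illustrative assumption only: in the actual repository the module registration
// lives in a separate binding file, and the exported names ("plan", "run") as well as the
// BUILD_STANDALONE_BINDINGS_SKETCH guard are hypothetical.
#ifdef BUILD_STANDALONE_BINDINGS_SKETCH
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("plan", &BatchDecodeWithPagedKVCachePlan,
        "Plan batch decode with paged KV cache (returns plan info as a list of ints)");
  m.def("run", &BatchDecodeWithPagedKVCacheRun,
        "Run batch decode with paged KV cache using a previously computed plan");
}
#endif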