/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_BMM_FP8_CUH_
#define FLASHINFER_BMM_FP8_CUH_

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <cublasLt.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <torch/extension.h>

#include <memory>
#include <stdexcept>
#include <type_traits>

namespace flashinfer {

namespace bmm_fp8 {

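// Deleter that releases a cuBLASLt descriptor through its destroy function,
// so descriptors can be owned by std::unique_ptr.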
template <typename T, cublasStatus_t (*destructor)(T*)>
struct CuBlasLtDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      TORCH_CUDABLAS_CHECK(destructor(x));
    }
  }
};

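// RAII base class: owns an opaque cuBLASLt descriptor and exposes the raw
// pointer for use in cuBLASLt API calls.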
template <typename T, cublasStatus_t (*destructor)(T*)>
class CuBlasLtDescriptor {
 public:
  T* descriptor() const { return descriptor_.get(); }
  T* descriptor() { return descriptor_.get(); }

 protected:
  std::unique_ptr<T, CuBlasLtDeleter<T, destructor>> descriptor_;
};

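// Wraps cublasLtMatmulDesc_t: holds the compute/scale types plus matmul
// attributes such as transpose flags, fast accumulation, and scale pointers.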
class CuBlasLtMatmulDescriptor
    : public CuBlasLtDescriptor<cublasLtMatmulDescOpaque_t, &cublasLtMatmulDescDestroy> {
 public:
  CuBlasLtMatmulDescriptor(cublasComputeType_t compute_type, cudaDataType_t scale_type) {
    cublasLtMatmulDesc_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

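// Wraps cublasLtMatrixLayout_t. cuBLASLt layouts are column-major; passing
// t = true swaps rows and cols so a row-major matrix can be described as its
// column-major transpose.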
class CuBlasLtMatrixLayout
    : public CuBlasLtDescriptor<cublasLtMatrixLayoutOpaque_t, &cublasLtMatrixLayoutDestroy> {
 public:
  CuBlasLtMatrixLayout(cudaDataType_t type, uint64_t rows, uint64_t cols, int64_t ld,
                       bool t = false) {
    cublasLtMatrixLayout_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(
        cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatrixLayoutAttribute_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(::cublasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

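// Wraps cublasLtMatmulPreference_t, used to constrain the algorithm search
// (e.g., the maximum workspace size).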
class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<cublasLtMatmulPreferenceOpaque_t,
                                                           &cublasLtMatmulPreferenceDestroy> {
 public:
  CuBlasLtMatmulPreference() {
    cublasLtMatmulPreference_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatmulPreferenceAttributes_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(
        ::cublasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

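// Maps a C++ element type to the corresponding cudaDataType_t enum value.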
template <typename T>
cudaDataType_t get_cuda_data_type() {
  if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
    return CUDA_R_8F_E4M3;
  } else if constexpr (std::is_same_v<T, __nv_fp8_e5m2>) {
    return CUDA_R_8F_E5M2;
  } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
    return CUDA_R_16BF;
  } else if constexpr (std::is_same_v<T, half>) {
    return CUDA_R_16F;
  } else {
    throw std::runtime_error("Unsupported type");
  }
}

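// Batched FP8 matmul via cublasLtMatmul: computes, per batch,
//   D = (A_scale * B_scale) * (A @ B)
// with FP32 accumulation and fast FP8 accumulation enabled. A_scale and
// B_scale are device pointers to single-float dequantization scales.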
template <typename AT, typename BT, typename DT>
void bmm_fp8_internal_cublaslt(const AT* A, const BT* B, DT* D, int batch_size, int m, int n, int k,
                               const float* A_scale, const float* B_scale) {
  const void* A_scale_ptr = static_cast<const void*>(A_scale);
  const void* B_scale_ptr = static_cast<const void*>(B_scale);
  auto matmul_desp = CuBlasLtMatmulDescriptor(CUBLAS_COMPUTE_32F, CUDA_R_32F);
  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T);
  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N);
  int8_t fast_accum = 1;
  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fast_accum);

  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, A_scale_ptr);
  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, B_scale_ptr);

  cudaDataType_t a_type = get_cuda_data_type<AT>();
  cudaDataType_t b_type = get_cuda_data_type<BT>();
  cudaDataType_t d_type = get_cuda_data_type<DT>();
  if constexpr (std::is_same_v<AT, __nv_fp8_e5m2> && std::is_same_v<BT, __nv_fp8_e5m2>) {
    throw std::runtime_error("Unsupported combination: both A and B are e5m2");
  }

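  // cuBLASLt layouts are column-major: a row-major [m, k] A is described as
  // its column-major transpose ([k, m], ld = k) and logically transposed back
  // via CUBLAS_OP_T above; B is taken as column-major [k, n] directly, and D
  // is produced column-major [m, n] (ld = m).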
  auto a_desp = CuBlasLtMatrixLayout(a_type, m, k, k, true);
  auto b_desp = CuBlasLtMatrixLayout(b_type, k, n, k);
  auto d_desp = CuBlasLtMatrixLayout(d_type, m, n, m);

  if (batch_size > 1) {
    // Compute batch strides in 64-bit to avoid int overflow on large shapes.
    int64_t stride_a = static_cast<int64_t>(m) * k;
    int64_t stride_b = static_cast<int64_t>(k) * n;
    int64_t stride_d = static_cast<int64_t>(m) * n;
    a_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size);
    a_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_a);
    b_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size);
    b_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_b);
    d_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size);
    d_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_d);
  }

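  // Cap the algorithm search at a fixed workspace budget and take the top
  // heuristic result; fail with NOT_SUPPORTED if no algorithm qualifies.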
  CuBlasLtMatmulPreference preference;
  size_t workspace_size = 1024 * 1024;  // 1 MiB
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size);
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto workspace = allocator.allocate(workspace_size);
  cublasLtMatmulHeuristicResult_t heuristic_result = {};
  int returned_result = 0;
  auto lt_handle = at::cuda::getCurrentCUDABlasLtHandle();
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      lt_handle, matmul_desp.descriptor(), a_desp.descriptor(), b_desp.descriptor(),
      d_desp.descriptor(), d_desp.descriptor(), preference.descriptor(), 1, &heuristic_result,
      &returned_result));
  if (returned_result == 0) {
    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
  }

  const float alpha = 1.0f;
  const float beta = 0.0f;
  cublasStatus_t status = cublasLtMatmul(
      lt_handle, matmul_desp.descriptor(), &alpha, A, a_desp.descriptor(), B, b_desp.descriptor(),
      &beta, nullptr, d_desp.descriptor(), D, d_desp.descriptor(), &heuristic_result.algo,
      workspace.mutable_get(), workspace_size, at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, at::cuda::blas::_cublasGetErrorEnum(status));
}

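// Explicit instantiations for the supported FP8 input combinations. The
// e5m2 x e5m2 pairing is rejected at runtime above and is not instantiated.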
template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>(
    const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n,
    int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, half>(
    const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, half* D, int batch_size, int m, int n, int k,
    const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, __nv_bfloat16>(
    const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, __nv_bfloat16* D, int batch_size, int m, int n,
    int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, half>(
    const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, half* D, int batch_size, int m, int n, int k,
    const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, __nv_bfloat16>(
    const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n,
    int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, half>(
    const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, half* D, int batch_size, int m, int n, int k,
    const float* A_scale, const float* B_scale);
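
// Illustrative usage sketch (not part of this header). The tensor names below
// are hypothetical caller-side variables, assuming contiguous CUDA tensors:
// A row-major [b, m, k] e4m3, B column-major [b, k, n] e4m3, D [b, m, n] bf16
// receiving the column-major result, and 1-element float32 scale tensors.
//
//   flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
//       static_cast<const __nv_fp8_e4m3*>(A.data_ptr()),
//       static_cast<const __nv_fp8_e4m3*>(B.data_ptr()),
//       static_cast<__nv_bfloat16*>(D.data_ptr()), b, m, n, k,
//       A_scale.data_ptr<float>(), B_scale.data_ptr<float>());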

}  // namespace bmm_fp8
}  // namespace flashinfer

#endif  // FLASHINFER_BMM_FP8_CUH_