
Commit f1c0b68

zhyncs and yzh119 authored
feat: support bmm fp8 (#469)
`torch.bmm` doesn't support fp8 and `torch._scaled_mm` doesn't support 3D inputs, so I wrote this one. @yzh119 cc @merrymercy @Ying1123 @ispobock Thanks @yzh119 for assisting with debugging.

AType: fp8 e4m3, fp8 e5m2
BType: fp8 e4m3, fp8 e5m2
DType: bf16, fp16

AType and BType cannot both be fp8 e5m2.

ref https://docs.nvidia.com/cuda/cublas/#cublasltmatmul

```python3
pytest python/tests/test_bmm_fp8.py
```

works on H100:

```
=========================== test session starts ===========================
platform linux -- Python 3.12.4, pytest-8.3.2, pluggy-1.5.0
rootdir: /flashinfer
collected 8 items

python/tests/test_bmm_fp8.py ...s...s                               [100%]

====================== 6 passed, 2 skipped in 2.16s =======================
```

---------

Co-authored-by: Zihao Ye <[email protected]>
1 parent 2ba3f1c commit f1c0b68
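
For context, a hedged usage sketch of the new op follows. Only the C++ binding signature `bmm_fp8(A, B, D, A_scale, B_scale)` and the test path appear in this commit, so the exact shape of the Python-level call below is an assumption for illustration, not the committed API.

```python
# Hedged sketch: mirrors the C++ binding bmm_fp8(A, B, D, A_scale, B_scale);
# the exact flashinfer.bmm_fp8 Python signature is assumed, not shown in this diff.
import torch
import flashinfer

b, m, k, n = 16, 48, 64, 80
a = torch.randn(b, m, k, device="cuda") / k**0.5
w = torch.randn(b, k, n, device="cuda") / k**0.5

# Per-tensor dequantization scales so the quantized values use the e4m3 range.
a_scale = a.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
w_scale = w.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
a_fp8 = (a / a_scale).to(torch.float8_e4m3fn)
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)

d = torch.empty(b, m, n, device="cuda", dtype=torch.bfloat16)
flashinfer.bmm_fp8(a_fp8, w_fp8, d, a_scale.float(), w_scale.float())

# Reference: dequantize and run an ordinary batched matmul.
ref = torch.bmm(a_fp8.float() * a_scale, w_fp8.float() * w_scale)
torch.testing.assert_close(d.float(), ref, rtol=2e-2, atol=2e-2)
```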

File tree

10 files changed: +371, -9 lines changed
File renamed without changes.

docs/index.rst (+1, -1)

```diff
@@ -33,7 +33,7 @@ FlashInfer is a library for Large Language Models that provides high-performance
    api/python/sparse
    api/python/page
    api/python/sampling
-   api/python/group_gemm
+   api/python/gemm
    api/python/norm
    api/python/rope
    api/python/quantization
```

include/flashinfer/bmm_fp8.cuh (new file, +200 lines)

```cpp
/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_BMM_FP8_CUH_
#define FLASHINFER_BMM_FP8_CUH_

#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <cublasLt.h>
#include <cuda_fp8.h>
#include <torch/extension.h>

#include <stdexcept>
#include <type_traits>

namespace flashinfer {

namespace bmm_fp8 {

template <typename T, cublasStatus_t (*destructor)(T*)>
struct CuBlasLtDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      TORCH_CUDABLAS_CHECK(destructor(x));
    }
  }
};

template <typename T, cublasStatus_t (*destructor)(T*)>
class CuBlasLtDescriptor {
 public:
  T* descriptor() const { return descriptor_.get(); }
  T* descriptor() { return descriptor_.get(); }

 protected:
  std::unique_ptr<T, CuBlasLtDeleter<T, destructor>> descriptor_;
};

class CuBlasLtMatmulDescriptor
    : public CuBlasLtDescriptor<cublasLtMatmulDescOpaque_t, &cublasLtMatmulDescDestroy> {
 public:
  CuBlasLtMatmulDescriptor(cublasComputeType_t compute_type, cudaDataType_t scale_type) {
    cublasLtMatmulDesc_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

class CuBlasLtMatrixLayout
    : public CuBlasLtDescriptor<cublasLtMatrixLayoutOpaque_t, &cublasLtMatrixLayoutDestroy> {
 public:
  CuBlasLtMatrixLayout(cudaDataType_t type, uint64_t rows, uint64_t cols, int64_t ld,
                       bool t = false) {
    cublasLtMatrixLayout_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(
        cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatrixLayoutAttribute_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(::cublasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<cublasLtMatmulPreferenceOpaque_t,
                                                           &cublasLtMatmulPreferenceDestroy> {
 public:
  CuBlasLtMatmulPreference() {
    cublasLtMatmulPreference_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatmulPreferenceAttributes_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(
        ::cublasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

template <typename T>
cudaDataType_t get_cuda_data_type() {
  if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
    return CUDA_R_8F_E4M3;
  } else if constexpr (std::is_same_v<T, __nv_fp8_e5m2>) {
    return CUDA_R_8F_E5M2;
  } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
    return CUDA_R_16BF;
  } else if constexpr (std::is_same_v<T, half>) {
    return CUDA_R_16F;
  } else {
    throw std::runtime_error("Unsupported type");
  }
}

template <typename AT, typename BT, typename DT>
void bmm_fp8_internal_cublaslt(const AT* A, const BT* B, DT* D, int batch_size, int m, int n, int k,
                               const float* A_scale, const float* B_scale) {
  const void* A_scale_ptr = static_cast<const void*>(A_scale);
  const void* B_scale_ptr = static_cast<const void*>(B_scale);
  auto matmul_desp = CuBlasLtMatmulDescriptor(CUBLAS_COMPUTE_32F, CUDA_R_32F);
  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T);
  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N);
  int8_t fast_accum = 1;
  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fast_accum);

  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, A_scale_ptr);
  matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, B_scale_ptr);

  cudaDataType_t a_type = get_cuda_data_type<AT>();
  cudaDataType_t b_type = get_cuda_data_type<BT>();
  cudaDataType_t d_type = get_cuda_data_type<DT>();
  if (std::is_same_v<AT, __nv_fp8_e5m2> && std::is_same_v<BT, __nv_fp8_e5m2>) {
    throw std::runtime_error("Unsupported combination: both A and B are e5m2");
  }

  auto a_desp = CuBlasLtMatrixLayout(a_type, m, k, k, true);
  auto b_desp = CuBlasLtMatrixLayout(b_type, k, n, k);
  auto d_desp = CuBlasLtMatrixLayout(d_type, m, n, m);

  if (batch_size > 1) {
    int64_t stride_a = m * k;
    int64_t stride_b = k * n;
    int64_t stride_d = m * n;
    a_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size);
    a_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_a);
    b_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size);
    b_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_b);
    d_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size);
    d_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_d);
  }

  CuBlasLtMatmulPreference preference;
  size_t workspace_size = 1024 * 1024;  // 1 MiB
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size);
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto workspace = allocator.allocate(workspace_size);
  cublasLtMatmulHeuristicResult_t heuristic_result = {};
  int returned_result = 0;
  auto lt_handle = at::cuda::getCurrentCUDABlasLtHandle();
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      lt_handle, matmul_desp.descriptor(), a_desp.descriptor(), b_desp.descriptor(),
      d_desp.descriptor(), d_desp.descriptor(), preference.descriptor(), 1, &heuristic_result,
      &returned_result));
  if (returned_result == 0) {
    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
  }

  const float alpha = 1.0f;
  const float beta = 0.0f;
  cublasStatus_t status = cublasLtMatmul(
      lt_handle, matmul_desp.descriptor(), &alpha, A, a_desp.descriptor(), B, b_desp.descriptor(),
      &beta, nullptr, d_desp.descriptor(), D, d_desp.descriptor(), &heuristic_result.algo,
      workspace.mutable_get(), workspace_size, at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, at::cuda::blas::_cublasGetErrorEnum(status));
}

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>(
    const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n,
    int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, half>(
    const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, half* D, int batch_size, int m, int n, int k,
    const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, __nv_bfloat16>(
    const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, __nv_bfloat16* D, int batch_size, int m, int n,
    int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, half>(
    const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, half* D, int batch_size, int m, int n, int k,
    const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, __nv_bfloat16>(
    const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n,
    int k, const float* A_scale, const float* B_scale);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, half>(
    const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, half* D, int batch_size, int m, int n, int k,
    const float* A_scale, const float* B_scale);

}  // namespace bmm_fp8
}  // namespace flashinfer

#endif  // FLASHINFER_BMM_FP8_CUH_
```
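
Numerically, `bmm_fp8_internal_cublaslt` runs a batched matmul with fp32 fast accumulation, applies the per-tensor `A_scale`/`B_scale` factors via `CUBLASLT_MATMUL_DESC_A_SCALE_POINTER`/`CUBLASLT_MATMUL_DESC_B_SCALE_POINTER`, and writes bf16/fp16 output. A minimal PyTorch reference sketch of the same semantics (illustrative helper, not part of this commit):

```python
import torch

def bmm_fp8_reference(a_fp8, b_fp8, a_scale, b_scale, out_dtype=torch.bfloat16):
    """Dequantize-and-bmm reference for the scaled fp8 batched matmul above."""
    a = a_fp8.float() * a_scale  # per-tensor scalar float32 scale
    b = b_fp8.float() * b_scale
    return torch.bmm(a, b).to(out_dtype)
```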

python/csrc/bmm_fp8.cu (new file, +68 lines)

```cpp
/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include <flashinfer/bmm_fp8.cuh>

#include "flashinfer_ops.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
             torch::Tensor& A_scale, torch::Tensor& B_scale) {
  TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
  TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor");
  TORCH_CHECK(D.is_cuda(), "D must be a CUDA tensor");
  TORCH_CHECK(A.dim() == 3, "Expected 3D tensor for A");
  TORCH_CHECK(B.dim() == 3, "Expected 3D tensor for B");
  TORCH_CHECK(D.dim() == 3, "Expected 3D tensor for D");
  TORCH_CHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0), "Batch sizes must match");
  TORCH_CHECK(A.size(2) == B.size(1), "Incompatible matrix sizes");
  TORCH_CHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2),
              "Result tensor has incorrect shape");
  TORCH_CHECK(A.scalar_type() == torch::kFloat8_e4m3fn || A.scalar_type() == torch::kFloat8_e5m2,
              "A must be Float8_e4m3fn or Float8_e5m2");
  TORCH_CHECK(B.scalar_type() == torch::kFloat8_e4m3fn || B.scalar_type() == torch::kFloat8_e5m2,
              "B must be Float8_e4m3fn or Float8_e5m2");
  TORCH_CHECK(D.scalar_type() == torch::kBFloat16 || D.scalar_type() == torch::kHalf,
              "D must be BFloat16 or Half");

  TORCH_CHECK(A_scale.scalar_type() == torch::kFloat32 && B_scale.scalar_type() == torch::kFloat32,
              "A_scale and B_scale must be Float32");

  auto batch_size = A.size(0);
  auto m = A.size(1);
  auto k = A.size(2);
  auto n = B.size(2);

  // PyTorch is row major by default. cuBLASLt is column major by default.
  // We need row major D as expected.
  // A ^ T * B = D, so D ^ T = B ^ T * A
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(B.scalar_type(), b_type, [&] {
    return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(A.scalar_type(), a_type, [&] {
      return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(D.scalar_type(), d_type, [&] {
        flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
            static_cast<b_type*>(B.data_ptr()), static_cast<a_type*>(A.data_ptr()),
            static_cast<d_type*>(D.data_ptr()), batch_size, n, m, k,
            static_cast<float*>(B_scale.data_ptr()), static_cast<float*>(A_scale.data_ptr()));
        return true;
      });
    });
  });
}
```
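
The layout comment in the wrapper above deserves one extra sentence: PyTorch tensors are row major while cuBLASLt assumes column major, so the binding passes B as the first operand, A as the second, and swaps m and n; a row-major D buffer read as column major is D^T, so the call effectively computes D^T = B^T * A^T and leaves A @ B in the row-major output. In plain PyTorch terms, the identity being exploited looks like this (illustrative check, not part of this commit):

```python
import torch

A = torch.randn(2, 3, 4)
B = torch.randn(2, 4, 5)

# (A @ B) equals (B^T @ A^T)^T per batch, which is what the column-major trick relies on.
D_rowmajor = torch.bmm(A, B)
D_via_swap = torch.bmm(B.transpose(1, 2), A.transpose(1, 2)).transpose(1, 2)
assert torch.allclose(D_rowmajor, D_via_swap, atol=1e-5)
```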

python/csrc/flashinfer_ops.cu (+1)

```diff
@@ -48,6 +48,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
   m.def("packbits", &packbits, "GPU packbits operator");
   m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
+  m.def("bmm_fp8", &bmm_fp8, "BMM FP8");
   py::class_<CutlassSegmentGEMMPyTorchWrapper>(m, "CutlassSegmentGEMMPyTorchWrapper")
       .def(py::init<torch::Tensor>())
       .def("register_workspace", &CutlassSegmentGEMMPyTorchWrapper::RegisterWorkspaceBuffer)
```

python/csrc/flashinfer_ops.h (+3)

```diff
@@ -104,6 +104,9 @@ torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);
 torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
                                torch::Tensor output_indptr, const std::string& bitorder);
 
+void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
+             torch::Tensor& A_scale, torch::Tensor& B_scale);
+
 class CutlassSegmentGEMMPyTorchWrapper {
  public:
   void RegisterWorkspaceBuffer(torch::Tensor workspace_buffer);
```

python/flashinfer/__init__.py (+6, -6)

```diff
@@ -14,10 +14,11 @@
 limitations under the License.
 """
 
+from .activation import gelu_tanh_and_mul, silu_and_mul
 from .cascade import (
-    MultiLevelCascadeAttentionWrapper,
     BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
     BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
+    MultiLevelCascadeAttentionWrapper,
     merge_state,
     merge_state_in_place,
     merge_states,
@@ -27,8 +28,7 @@
     CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
     single_decode_with_kv_cache,
 )
-from .activation import gelu_tanh_and_mul, silu_and_mul
-from .group_gemm import SegmentGEMMWrapper
+from .gemm import SegmentGEMMWrapper, bmm_fp8
 from .norm import fused_add_rmsnorm, rmsnorm
 from .page import append_paged_kv_cache
 from .prefill import (
@@ -46,15 +46,15 @@
 )
 from .sampling import (
     chain_speculative_sampling,
+    min_p_sampling_from_probs,
     sampling_from_probs,
-    top_k_renorm_prob,
     top_k_mask_logits,
+    top_k_renorm_prob,
     top_k_sampling_from_probs,
-    top_k_top_p_sampling_from_probs,
     top_k_top_p_sampling_from_logits,
+    top_k_top_p_sampling_from_probs,
     top_p_renorm_prob,
     top_p_sampling_from_probs,
-    min_p_sampling_from_probs,
 )
 from .sparse import BlockSparseAttentionWrapper
 
```
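
With the `group_gemm` to `gemm` rename, the new op is re-exported at package level alongside the segment GEMM wrapper, so (assuming the extension is built from this commit) the import below should resolve:

```python
from flashinfer import SegmentGEMMWrapper, bmm_fp8
```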
