
Commit 794bdda

xslingcn and yzh119 authored
feat: support sm90 cutlass group gemm (#509)
Co-authored-by: Zihao Ye <[email protected]>
1 parent 20265d6 commit 794bdda

14 files changed: +528 -56 lines

flashinfer-aot/csrc_aot/flashinfer_ops.cu

-1

@@ -13,7 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#pragma once
 #include <torch/extension.h>
 
 void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu

+26

@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2023 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <torch/extension.h>
+
+
+torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr,
+                                     torch::Tensor weight_indices, torch::Tensor x,
+                                     torch::Tensor weight, unsigned int batch_size,
+                                     bool weight_column_major);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90");
+}
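
For orientation, here is a hypothetical Python-side invocation of the binding declared above. This is a minimal sketch only: the argument order follows the CutlassSegmentGEMMSM90 declaration and the module name comes from the new CUDAExtension in setup.py below, but the shapes, dtypes, and workspace sizes are illustrative assumptions, and the documented entry point would be FlashInfer's Python wrapper rather than this raw op.

# Hypothetical sketch; argument order follows the CutlassSegmentGEMMSM90
# declaration above. Shapes, dtypes, and workspace sizes are illustrative only.
import torch
from flashinfer import _kernels_sm90  # extension registered in setup.py below

float_workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
int_workspace = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda")
seg_indptr = torch.tensor([0, 3, 7], dtype=torch.int64, device="cuda")   # two segments, m = 3 and 4
weight_indices = torch.tensor([0, 1], dtype=torch.int64, device="cuda")  # one weight per segment
x = torch.randn(7, 64, dtype=torch.float16, device="cuda")               # (total rows, d_in)
weight = torch.randn(2, 128, 64, dtype=torch.float16, device="cuda")     # assumed (b, d_out, d_in), column-major

y = _kernels_sm90.cutlass_segment_gemm_sm90(
    float_workspace, int_workspace, seg_indptr, weight_indices, x, weight,
    2,     # batch_size
    True,  # weight_column_major
)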

flashinfer-aot/setup.py

+17 -1

@@ -355,6 +355,7 @@ def __init__(self, *args, **kwargs) -> None:
 include_dirs = [
     str(root.resolve() / "include"),
     str(root.resolve() / "3rdparty" / "cutlass" / "include"),  # for group gemm
+    str(root.resolve() / "3rdparty" / "cutlass" / "tools" / "util" / "include"),
 ]
 extra_compile_args = {
     "cxx": [
@@ -371,6 +372,10 @@ def __init__(self, *args, **kwargs) -> None:
         "-use_fast_math",
     ],
 }
+extra_compile_args_sm90 = extra_compile_args.copy()
+extra_compile_args_sm90["nvcc"].extend(
+    "-gencode arch=compute_90a,code=sm_90a".split()
+)
 ext_modules = []
 ext_modules.append(
     torch_cpp_ext.CUDAExtension(
@@ -385,12 +390,23 @@ def __init__(self, *args, **kwargs) -> None:
             "csrc/quantization.cu",
             "csrc/group_gemm.cu",
             "csrc/bmm_fp8.cu",
-            "csrc_aot/flashinfer_ops.cu",
+            "csrc_aot/flashinfer_ops.cu"
         ],
         include_dirs=include_dirs,
         extra_compile_args=extra_compile_args,
     )
 )
+ext_modules.append(
+    torch_cpp_ext.CUDAExtension(
+        name="flashinfer._kernels_sm90",
+        sources=[
+            "csrc/group_gemm_sm90.cu",
+            "csrc_aot/flashinfer_sm90_ops.cu",
+        ],
+        include_dirs=include_dirs,
+        extra_compile_args=extra_compile_args_sm90,
+    )
+)
 ext_modules.append(
     torch_cpp_ext.CUDAExtension(
         name="flashinfer._decode_kernels",

include/flashinfer/gemm/group_gemm.cuh

+2 -2

@@ -53,7 +53,7 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe
 
   // NOTE(Zihao): I didn't successfully launch the kernel with cudaLaunchKernel API,
   // so I just use the kernel function directly, need to investigate more.
-  auto compute_args_kernel = compute_cutlass_group_gemm_args<DType>;
+  auto compute_args_kernel = compute_sm80_cutlass_group_gemm_args<DType, DType>;
   compute_args_kernel<<<batch_size, 1, 0, stream>>>(
       problem_sizes_device, x_data, w_data, y_data, ld_x, ld_w, ld_y, (DType*)x, (DType*)w,
       (DType*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major);
@@ -116,4 +116,4 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe
 
 }  // namespace flashinfer
 
-#endif  // FLASHINFER_GEMM_GROUP_GEMM_CUH_
+#endif  // FLASHINFER_GEMM_GROUP_GEMM_CUH_

include/flashinfer/gemm/group_gemm_cutlass.cuh

+45 -12

@@ -16,11 +16,16 @@
 #ifndef FLASHINFER_GROUP_GEMM_CUTLASS_CUH_
 #define FLASHINFER_GROUP_GEMM_CUTLASS_CUH_
 
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_fp8.h>
+
 #include "cutlass/cutlass.h"
 #include "cutlass/gemm/device/gemm_grouped.h"
 #include "cutlass/gemm/kernel/default_gemm_grouped.h"
 #include "cutlass/layout/matrix.h"
 #include "cutlass/numeric_types.h"
+#include "cutlass/util/packed_stride.hpp"
 
 namespace flashinfer {
 
@@ -41,21 +46,49 @@ struct cutlass_dtype<nv_bfloat16> {
   using type = cutlass::bfloat16_t;
 };
 
-template <typename T>
-__global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_problems, T** ptr_x,
-                                                T** ptr_w, T** ptr_y, int64_t* ld_x, int64_t* ld_w,
-                                                int64_t* ld_y, T* x, T* w, T* y, int64_t* xy_indptr,
-                                                int64_t* w_indices, size_t d_in, size_t d_out,
-                                                bool w_column_major) {
+template <>
+struct cutlass_dtype<__nv_fp8_e4m3> {
+  using type = cutlass::float_e4m3_t;
+};
+
+template <>
+struct cutlass_dtype<__nv_fp8_e5m2> {
+  using type = cutlass::float_e5m2_t;
+};
+
+template <typename DTypeIn, typename DTypeOut>
+__global__ void compute_sm80_cutlass_group_gemm_args(
+    cutlass::gemm::GemmCoord* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr,
+    int64_t* x_ld, int64_t* w_ld, int64_t* y_ld, DTypeIn* x, DTypeIn* w, DTypeOut* y,
+    int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) {
   int i = blockIdx.x;
   int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out;
   all_problems[i] = cutlass::gemm::GemmCoord(m, n, k);
-  ptr_w[i] = w + (w_indices == nullptr ? i : w_indices[i]) * d_in * d_out;
-  ptr_x[i] = x + xy_indptr[i] * d_in;
-  ptr_y[i] = y + xy_indptr[i] * d_out;
-  ld_x[i] = k;                       // m * k
-  ld_w[i] = w_column_major ? k : n;  // k * n if column major, n * k if row major
-  ld_y[i] = n;                       // m * n
+  w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n;
+  x_ptr[i] = x + xy_indptr[i] * k;
+  y_ptr[i] = y + xy_indptr[i] * n;
+  x_ld[i] = k;                       // m * k
+  w_ld[i] = w_column_major ? k : n;  // k * n if column major, n * k if row major
+  y_ld[i] = n;                       // m * n
+}
+
+template <typename DTypeIn, typename DTypeOut, typename ProblemShape, typename StrideA,
+          typename StrideB, typename StrideCD>
+__global__ void compute_sm90_cutlass_group_gemm_args(
+    ProblemShape* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr,
+    StrideA* x_stride, StrideB* w_stride, StrideCD* y_stride, DTypeIn* x, DTypeIn* w, DTypeOut* y,
+    int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) {
+  int i = blockIdx.x;
+  int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out;
+  all_problems[i] = ProblemShape(m, n, k);
+  w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n;
+  x_ptr[i] = x + xy_indptr[i] * k;
+  y_ptr[i] = y + xy_indptr[i] * n;
+
+  x_stride[i] = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
+  w_stride[i] = w_column_major ? cutlass::make_cute_packed_stride(StrideB{}, {k, n, 1})
+                               : cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
+  y_stride[i] = cutlass::make_cute_packed_stride(StrideCD{}, {m, n, 1});
 }
 
 }  // namespace group_gemm
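
To make the indexing in the two kernels above concrete, here is a small host-side rehearsal in Python (illustrative only, with made-up sizes) of the per-problem arguments each block i derives from xy_indptr: the GEMM shape (m, n, k), the flat pointer offsets into x, w, and y, and the sm80 leading dimensions.

# Illustrative rehearsal (toy sizes, not from the commit) of the indexing in
# compute_sm80_cutlass_group_gemm_args; with w_indices absent, segment i uses
# weight i.
d_in, d_out = 4, 8                 # k and n, shared by every problem
xy_indptr = [0, 3, 7, 12]          # three segments with m = 3, 4, 5
w_column_major = True

for i in range(len(xy_indptr) - 1):
    m, k, n = xy_indptr[i + 1] - xy_indptr[i], d_in, d_out
    x_off = xy_indptr[i] * k       # rows of x consumed by earlier segments
    y_off = xy_indptr[i] * n       # rows of y produced by earlier segments
    w_off = i * k * n              # each weight matrix holds k * n elements
    x_ld, w_ld, y_ld = k, (k if w_column_major else n), n
    print(f"problem {i}: m={m} n={n} k={k}, x+{x_off} w+{w_off} y+{y_off}, "
          f"lds=({x_ld}, {w_ld}, {y_ld})")

The sm90 variant computes the same shapes and offsets but stores packed CuTe strides via cutlass::make_cute_packed_stride instead of plain leading dimensions.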
