16
16
#ifndef FLASHINFER_GROUP_GEMM_CUTLASS_CUH_
17
17
#define FLASHINFER_GROUP_GEMM_CUTLASS_CUH_
18
18
19
+ #include < cuda_bf16.h>
20
+ #include < cuda_fp16.h>
21
+ #include < cuda_fp8.h>
22
+
19
23
#include " cutlass/cutlass.h"
20
24
#include " cutlass/gemm/device/gemm_grouped.h"
21
25
#include " cutlass/gemm/kernel/default_gemm_grouped.h"
22
26
#include " cutlass/layout/matrix.h"
23
27
#include " cutlass/numeric_types.h"
28
+ #include " cutlass/util/packed_stride.hpp"
24
29
25
30
namespace flashinfer {
26
31
@@ -41,21 +46,49 @@ struct cutlass_dtype<nv_bfloat16> {
41
46
using type = cutlass::bfloat16_t ;
42
47
};
43
48
44
- template <typename T>
45
- __global__ void compute_cutlass_group_gemm_args (cutlass::gemm::GemmCoord* all_problems, T** ptr_x,
46
- T** ptr_w, T** ptr_y, int64_t * ld_x, int64_t * ld_w,
47
- int64_t * ld_y, T* x, T* w, T* y, int64_t * xy_indptr,
48
- int64_t * w_indices, size_t d_in, size_t d_out,
49
- bool w_column_major) {
49
// Map the CUDA 8-bit float type __nv_fp8_e4m3 to its CUTLASS counterpart.
template <>
struct cutlass_dtype<__nv_fp8_e4m3> {
  using type = cutlass::float_e4m3_t;
};
53
+
54
// Map the CUDA 8-bit float type __nv_fp8_e5m2 to its CUTLASS counterpart.
template <>
struct cutlass_dtype<__nv_fp8_e5m2> {
  using type = cutlass::float_e5m2_t;
};
58
+
59
// Fill the per-group argument arrays consumed by the SM80 CUTLASS grouped-GEMM
// kernel: one problem shape, one pointer triple, and one leading-dimension
// triple per group.
//
// Launch with gridDim.x == number of groups; block i writes only slot i, so a
// single thread per block is sufficient.
//
// Segment i of x/y spans rows [xy_indptr[i], xy_indptr[i+1]); each weight
// matrix occupies d_in * d_out contiguous elements of w, selected either by
// the group index itself or remapped through w_indices when it is non-null.
template <typename DTypeIn, typename DTypeOut>
__global__ void compute_sm80_cutlass_group_gemm_args(
    cutlass::gemm::GemmCoord* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr,
    int64_t* x_ld, int64_t* w_ld, int64_t* y_ld, DTypeIn* x, DTypeIn* w, DTypeOut* y,
    int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) {
  const int i = blockIdx.x;
  const int k = d_in;
  const int n = d_out;
  const int m = xy_indptr[i + 1] - xy_indptr[i];
  // Weight-matrix index for this group (identity unless remapped). Kept in
  // 64 bits so the byte offset below cannot overflow 32-bit arithmetic.
  const int64_t w_sel = (w_indices == nullptr) ? i : w_indices[i];

  all_problems[i] = cutlass::gemm::GemmCoord(m, n, k);
  x_ptr[i] = x + xy_indptr[i] * k;
  w_ptr[i] = w + w_sel * k * n;
  y_ptr[i] = y + xy_indptr[i] * n;
  x_ld[i] = k;                       // m * k
  w_ld[i] = w_column_major ? k : n;  // k * n if column major, n * k if row major
  y_ld[i] = n;                       // m * n
}
74
+
75
// SM90 analogue of compute_sm80_cutlass_group_gemm_args: block i fills the
// i-th grouped-GEMM descriptor — a CuTe problem shape, the operand pointers,
// and packed (contiguous) strides instead of raw leading dimensions.
//
// Launch with gridDim.x == number of groups; block i writes only slot i, so a
// single thread per block is sufficient.
template <typename DTypeIn, typename DTypeOut, typename ProblemShape, typename StrideA,
          typename StrideB, typename StrideCD>
__global__ void compute_sm90_cutlass_group_gemm_args(
    ProblemShape* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr,
    StrideA* x_stride, StrideB* w_stride, StrideCD* y_stride, DTypeIn* x, DTypeIn* w, DTypeOut* y,
    int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) {
  const int i = blockIdx.x;
  const int k = d_in;
  const int n = d_out;
  const int m = xy_indptr[i + 1] - xy_indptr[i];
  // Weight-matrix index for this group (identity unless remapped); int64 keeps
  // the element offset below in 64-bit arithmetic.
  const int64_t w_sel = (w_indices == nullptr) ? i : w_indices[i];

  all_problems[i] = ProblemShape(m, n, k);
  x_ptr[i] = x + xy_indptr[i] * k;
  w_ptr[i] = w + w_sel * k * n;
  y_ptr[i] = y + xy_indptr[i] * n;

  // Packed strides for contiguous storage; B's extent order flips with its
  // declared layout (k-major vs n-major).
  x_stride[i] = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
  w_stride[i] = w_column_major ? cutlass::make_cute_packed_stride(StrideB{}, {k, n, 1})
                               : cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
  y_stride[i] = cutlass::make_cute_packed_stride(StrideCD{}, {m, n, 1});
}
60
93
61
94
} // namespace group_gemm
0 commit comments