Commit 52dab1d

feat: modify group-gemm stage number (#497)
The current group-gemm configuration raises the following error on the NVIDIA 3090:

```shell
RuntimeError: cutlass group_gemm.initialize failed: Error Internal
```

This change lowers the group-gemm stage count from 8 to 4, which reduces the amount of dynamic shared memory (smem) the kernel requests, so that it can be called on GPUs like the 3090.

I also ran a simple comparison on the A800: setting the stage count to 4 still slightly improves group-gemm performance there.

Reference: https://github.com/NVIDIA/cutlass/blob/main/test/unit/gemm/device/gemm_grouped_sm80.cu
1 parent 2de16b0 commit 52dab1d

1 file changed: +1 -1 lines changed

include/flashinfer/group_gemm/wrapper.cuh

@@ -85,7 +85,7 @@ cudaError_t CutlassSegmentGEMMWrapper(CutlassSegmentGEMMHandler* handler, DType*
       cutlass::gemm::GemmShape<16, 8, 16>,                                    // Instruction Shape
       cutlass::epilogue::thread::LinearCombination<DType, 8, float, float>,   // Epilogue
       cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,      // Swizzling Operator
-      8                                                                       // Stages
+      4                                                                       // Stages
       >::GemmKernel;

   using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp;
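
The commit message attributes the 3090 failure to the size of dynamic shared memory. The following is a rough sketch of that arithmetic, not the exact CUTLASS accounting. It assumes a threadblock tile of 128x128x32 and 2-byte (half-precision) elements. Neither value appears in this diff, so both are illustrative assumptions, and `mainloop_smem_bytes` is a hypothetical helper. It models only the multistage main loop, where each stage buffers one A tile and one B tile. The per-block limits are the documented maximums for compute capability 8.6 (3090) and 8.0 (A800).

```cpp
#include <cstddef>

// Hypothetical helper (not part of flashinfer or CUTLASS): estimates the
// main-loop shared memory of a multistage GEMM pipeline, where every stage
// holds one tile_m x tile_k slice of A and one tile_n x tile_k slice of B.
constexpr std::size_t mainloop_smem_bytes(int tile_m, int tile_n, int tile_k,
                                          std::size_t elem_bytes, int stages) {
  return static_cast<std::size_t>(stages) *
         (static_cast<std::size_t>(tile_m) * tile_k +
          static_cast<std::size_t>(tile_n) * tile_k) *
         elem_bytes;
}

// Maximum shared memory per thread block from the CUDA programming guide.
constexpr std::size_t kSm86MaxSmemPerBlock = 99 * 1024;   // e.g. RTX 3090
constexpr std::size_t kSm80MaxSmemPerBlock = 163 * 1024;  // e.g. A800

constexpr bool fits(std::size_t bytes, std::size_t limit) {
  return bytes <= limit;
}

// ASSUMPTION: 128x128x32 threadblock tile and 2-byte elements.
constexpr std::size_t kSmemStages8 = mainloop_smem_bytes(128, 128, 32, 2, 8);
constexpr std::size_t kSmemStages4 = mainloop_smem_bytes(128, 128, 32, 2, 4);

static_assert(kSmemStages8 == 128 * 1024, "8 stages need 128 KiB");
static_assert(kSmemStages4 == 64 * 1024, "4 stages need 64 KiB");
static_assert(!fits(kSmemStages8, kSm86MaxSmemPerBlock), "8 stages exceed sm_86");
static_assert(fits(kSmemStages4, kSm86MaxSmemPerBlock), "4 stages fit on sm_86");
static_assert(fits(kSmemStages8, kSm80MaxSmemPerBlock), "8 stages fit on sm_80");
```

Under these assumptions, 8 stages request 128 KiB, which exceeds the 99 KiB per-block limit on the 3090 but fits on the A800, while 4 stages request 64 KiB and fit on both. This matches the behavior described in the commit message, though the real kernel's shared storage may differ from this estimate.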
