Commit 0cc1a51

Authored by xslingcnyzh119
feat: specify gemm backend (#648)
Add an optional `backend` argument at GEMM wrapper initialization. Usage:

```python
# this will load the cutlass_segment_gemm_sm90 kernel
backend = "sm90"
segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer, backend)
```

Supported values: `sm90`, `sm80`, `auto`; default: `auto`.

Co-authored-by: Zihao Ye <[email protected]>
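For reference, a minimal construction sketch using the new argument (assumes flashinfer is installed and a CUDA device is available; the workspace size follows the test changes below and is illustrative):

```python
import torch
import flashinfer

# 32 MiB workspace buffer on the GPU (same size the tests use)
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0")

# "auto" (the default) resolves to "sm90" on compute capability >= 9, else "sm80"
segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer, backend="auto")
```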
1 parent: 553ace5 · commit: 0cc1a51

4 files changed, 94 insertions(+), 72 deletions(-)

include/flashinfer/gemm/group_gemm_sm90.cuh (-2)

```diff
@@ -16,8 +16,6 @@
 #ifndef FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_
 #define FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_
 
-#include <sstream>
-
 #include "../allocator.h"
 #include "../cutlass_utils.cuh"
 #include "../utils.cuh"
```

python/flashinfer/gemm.py (+75 -65)

```diff
@@ -24,7 +24,7 @@
 from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
 from .utils import (
     _get_cache_buf,
-    get_compute_capability,
+    determine_gemm_backend,
     get_cuda_stream,
     get_indptr,
     register_custom_op,
@@ -480,7 +480,9 @@ class SegmentGEMMWrapper:
     True
     """
 
-    def __init__(self, float_workspace_buffer: torch.Tensor) -> None:
+    def __init__(
+        self, float_workspace_buffer: torch.Tensor, backend: str = "auto"
+    ) -> None:
         r"""Initialize the wrapper.
 
         Parameters
@@ -493,6 +495,7 @@ def __init__(self, float_workspace_buffer: torch.Tensor) -> None:
             (1024 * 1024,), dtype=torch.int8, device=float_workspace_buffer.device
         )
         self._float_workspace_buffer = float_workspace_buffer
+        self.backend = backend
 
     def reset_workspace_buffer(
         self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor
@@ -584,75 +587,82 @@ def run(
         if weight_indices is None:
             # create an empty CPU tensor as placeholder
             weight_indices = torch.empty(0, dtype=torch.int64)
-        major, _ = get_compute_capability(x.device)
         cumulative_batch_size = x.size(0)
         d_out = weights.size(1) if weight_column_major else weights.size(2)
         y = torch.zeros((cumulative_batch_size, d_out), dtype=x.dtype, device=x.device)
         empty_x_data = torch.empty(0, dtype=x.dtype, device=x.device)
 
-        if major >= 9:
-            (
-                all_problems,
-                x_data,
-                w_data,
-                y_data,
-                x_stride_data,
-                w_stride_data,
-                y_stride_data,
-            ) = launch_compute_sm90_group_gemm_args(
-                x,
-                weights,
-                y,
-                weight_column_major,
-                batch_size,
-                seg_indptr,
-                weight_indices,
-            )
-            get_gemm_sm90_module().cutlass_segment_gemm_sm90(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                all_problems,
-                x_data,
-                w_data,
-                y_data,
-                x_stride_data,
-                w_stride_data,
-                y_stride_data,
-                y,  # for torch compile mutates_args
-                empty_x_data,  # for kernel type dispatch
-                weight_column_major,
-            )
+        if self.backend == "auto":
+            backend = determine_gemm_backend(x.device)
         else:
-            (
-                all_problems,
-                x_data,
-                w_data,
-                y_data,
-                x_ld_data,
-                w_ld_data,
-                y_ld_data,
-            ) = launch_compute_sm80_group_gemm_args(
-                x,
-                weights,
-                y,
-                weight_column_major,
-                batch_size,
-                seg_indptr,
-                weight_indices,
-            )
-            get_gemm_module().cutlass_segment_gemm(
-                self._int_workspace_buffer,
-                all_problems,
-                x_data,
-                w_data,
-                y_data,
-                x_ld_data,
-                w_ld_data,
-                y_ld_data,
-                y,
-                empty_x_data,
-                weight_column_major,
-            )
+            backend = self.backend
+
+        match backend:
+            case "sm90":
+                (
+                    all_problems,
+                    x_data,
+                    w_data,
+                    y_data,
+                    x_stride_data,
+                    w_stride_data,
+                    y_stride_data,
+                ) = launch_compute_sm90_group_gemm_args(
+                    x,
+                    weights,
+                    y,
+                    weight_column_major,
+                    batch_size,
+                    seg_indptr,
+                    weight_indices,
+                )
+                get_gemm_sm90_module().cutlass_segment_gemm_sm90(
+                    self._float_workspace_buffer,
+                    self._int_workspace_buffer,
+                    all_problems,
+                    x_data,
+                    w_data,
+                    y_data,
+                    x_stride_data,
+                    w_stride_data,
+                    y_stride_data,
+                    y,  # for torch compile mutates_args
+                    empty_x_data,  # for kernel type dispatch
+                    weight_column_major,
+                )
+            case "sm80":
+                (
+                    all_problems,
+                    x_data,
+                    w_data,
+                    y_data,
+                    x_ld_data,
+                    w_ld_data,
+                    y_ld_data,
+                ) = launch_compute_sm80_group_gemm_args(
+                    x,
+                    weights,
+                    y,
+                    weight_column_major,
+                    batch_size,
+                    seg_indptr,
+                    weight_indices,
+                )
+                get_gemm_module().cutlass_segment_gemm(
+                    self._int_workspace_buffer,
+                    all_problems,
+                    x_data,
+                    w_data,
+                    y_data,
+                    x_ld_data,
+                    w_ld_data,
+                    y_ld_data,
+                    y,
+                    empty_x_data,
+                    weight_column_major,
+                )
+            case _:
+                raise ValueError(f"Unsupported gemm backend: {backend}")
         return y
 
     forward = run
```
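For context, the backend resolution introduced in the `run` hunk can be exercised standalone. A minimal sketch (the `resolve_backend` helper is illustrative, not part of this commit; `determine_gemm_backend` is added in the utils.py hunk below). Note also that the `match` statement used here requires Python 3.10 or newer:

```python
import torch
from flashinfer.utils import determine_gemm_backend

def resolve_backend(requested: str, device: torch.device) -> str:
    # mirrors SegmentGEMMWrapper.run: "auto" defers to the device's compute
    # capability; any explicit value is passed through and validated later
    # by the match statement ("sm90" / "sm80" / ValueError otherwise)
    if requested == "auto":
        return determine_gemm_backend(device)
    return requested

print(resolve_backend("auto", torch.device("cuda:0")))  # "sm90" or "sm80"
```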

python/flashinfer/utils.py (+8)

```diff
@@ -252,3 +252,11 @@ def register_fake_op(
 
 def get_cuda_stream(device: torch.device) -> int:
     return torch.cuda.current_stream(device).cuda_stream
+
+
+def determine_gemm_backend(device: torch.device) -> str:
+    major, _ = get_compute_capability(device)
+    if major >= 9:
+        return "sm90"
+    else:
+        return "sm80"
```

tests/test_group_gemm.py (+11 -5)

```diff
@@ -18,6 +18,7 @@
 import torch
 
 import flashinfer
+from flashinfer.utils import determine_gemm_backend
 
 DTYPES = [torch.float16]
 CUDA_DEVICES = ["cuda:0"]
@@ -31,6 +32,7 @@
 @pytest.mark.parametrize("column_major", [False, True])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("backend", ["auto", "sm90", "sm80"])
 def test_segment_gemm(
     batch_size,
     num_rows_per_batch,
@@ -40,12 +42,16 @@ def test_segment_gemm(
     column_major,
     dtype,
     device,
+    backend,
 ):
     if batch_size * num_rows_per_batch > 8192:
         pytest.skip("batch_size * num_rows_per_batch too large for test.")
+    latest_supported_backend = determine_gemm_backend(torch.device(device))
+    if backend == "sm90" and latest_supported_backend == "sm80":
+        pytest.skip("sm90 backend not supported on this device.")
     torch.manual_seed(42)
     workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(device)
-    segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer)
+    segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer, backend=backend)
     x = torch.randn(batch_size * num_rows_per_batch, d_in, dtype=dtype).to(device)
     if use_weight_indices:
         num_weights = 1024
@@ -99,7 +105,7 @@ def test_segment_gemm(
 
 
 if __name__ == "__main__":
-    test_segment_gemm(199, 17, 128, 1024, False, False, torch.float16, "cuda:0")
-    test_segment_gemm(199, 17, 128, 1024, False, True, torch.float16, "cuda:0")
-    test_segment_gemm(199, 17, 128, 1024, True, False, torch.float16, "cuda:0")
-    test_segment_gemm(199, 17, 128, 1024, True, True, torch.float16, "cuda:0")
+    test_segment_gemm(199, 17, 128, 1024, False, False, torch.float16, "cuda:0", "auto")
+    test_segment_gemm(199, 17, 128, 1024, False, True, torch.float16, "cuda:0", "auto")
+    test_segment_gemm(199, 17, 128, 1024, True, False, torch.float16, "cuda:0", "auto")
+    test_segment_gemm(199, 17, 128, 1024, True, True, torch.float16, "cuda:0", "auto")
```
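For a quick manual run outside pytest's parametrization, the `__main__` block above can also pin a backend explicitly (illustrative; the final argument is the new `backend` parameter, and `"sm80"` is accepted on any device the test supports):

```python
# exercise the sm80 kernel path explicitly
test_segment_gemm(199, 17, 128, 1024, False, False, torch.float16, "cuda:0", "sm80")
```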
