Commit 6d64f90

levendlee authored and facebook-github-bot committed
Autodetect Triton WS support. (pytorch#4009)
Summary:
Pull Request resolved: pytorch#4009

X-link: facebookresearch/FBGEMM#1096

By default, enable warp specialization if the Triton version of the Python environment has warp specialization support. Otherwise, disable the feature to avoid API errors.

Reviewed By: htyu

Differential Revision: D73518121

fbshipit-source-id: c50ed6312b3705e7a083ef2ef8d58fe432a63430
1 parent 9651fb2 commit 6d64f90
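
For readers skimming the diff, the autodetection added in grouped_gemm.py boils down to a runtime capability probe of the installed Triton. A minimal standalone sketch of the same idea (assuming only that triton is importable; the in-tree check additionally requires TMA descriptor support via its internal utils.HAS_TMA_DESC flag, omitted here):

import inspect

import triton
import triton.language as tl


def has_ws_support() -> bool:
    # Warp-specialized kernels rely on the tl.async_task region API.
    if not hasattr(tl, "async_task"):
        return False
    # The WS autotuner also needs the extra triton.Config keyword arguments.
    params = inspect.signature(triton.Config).parameters
    return "num_consumer_groups" in params and "num_buffers_warp_spec" in params


print(has_ws_support())  # True only on a warp-specialization-enabled Triton build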

2 files changed: +104, -69 lines changed


fbgemm_gpu/experimental/gemm/test/grouped_gemm_test.py (+20, -25)

@@ -8,32 +8,19 @@
 # pyre-ignore-all-errors[53]

 import logging
-import os
 import unittest
 from typing import Tuple

 import torch

-try:
-    # pyre-ignore[21]
-    # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
-    from fbgemm_gpu import open_source
-except Exception:
-    open_source: bool = False
-
-
-if not open_source and torch.cuda.is_available():
-    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
-    from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
-        grouped_gemm,
-        grouped_gemm_fp8_rowwise,
-    )
+from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
+from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
+    _HAS_WS_SUPPORT,
+    grouped_gemm,
+    grouped_gemm_fp8_rowwise,
+)


-@unittest.skipIf(
-    open_source,
-    "TypeError: __init__() got an unexpected keyword argument 'num_consumer_groups'",
-)
 @unittest.skipIf(
     not torch.cuda.is_available(),
     "Skip when CUDA is not available",
@@ -42,10 +29,11 @@ class TestGroupedGEMM(unittest.TestCase):
     def setUp(self) -> None:
         torch.manual_seed(0)

-    # pyre-ignore [56]
-    @unittest.skipIf(
-        os.getenv("GITHUB_ENV") is not None,
-        """This test fails on the GitHub runners: module 'triton.language' has no attribute 'async_task'""",
+    @unittest.skipIf(  # pyre-ignore [56]
+        (not torch.cuda.is_available())
+        or (torch.version.hip is None)
+        and (torch.cuda.get_device_properties(0).major < 9),
+        "Skip FP8 test on architectures before SM90.",
     )
     def test_grouped_gemm_fp8_rowwise(self) -> None:
         def _test_grouped_gemm_fp8_rowwise(
@@ -154,6 +142,8 @@ def msg(s: str) -> str:
         for fast_accu in (True, False):
             for ws in (True, False):
                 for fuse_scatter_add in (True, False):
+                    if ws and not _HAS_WS_SUPPORT:
+                        continue
                     if not ws and fuse_scatter_add:
                         continue
                     logging.info(
@@ -168,9 +158,12 @@ def msg(s: str) -> str:
                         fuse_scatter_add=fuse_scatter_add,
                     )

+    # TODO(shikaili): Re-enable the test for SM80 after fixing TMA issues.
     @unittest.skipIf(  # pyre-ignore [56]
-        os.getenv("GITHUB_ENV") is not None,
-        """This test fails on the GitHub runners: "type fp8e4nv not supported in this architecture. The supported fp8 dtypes are ('fp8e4b15', 'fp8e5')""",
+        (not torch.cuda.is_available())
+        or (torch.version.hip is None)
+        and (torch.cuda.get_device_properties(0).major < 9),
+        "Skip BF16 test on architectures before SM90.",
     )
     def test_grouped_gemm_bf16(self) -> None:
         def _test_grouped_gemm_bf16(
@@ -251,6 +244,8 @@ def msg(s: str) -> str:
         for M in (0, 64, 512, 1000000):
             for ws in (True, False):
                 for fuse_scatter_add in (True, False):
+                    if ws and not _HAS_WS_SUPPORT:
+                        continue
                     if not ws and fuse_scatter_add:
                         continue
                     logging.info(
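
One subtlety in the new skip conditions above: Python's `and` binds tighter than `or`, so the decorator skips when CUDA is unavailable, or when running a CUDA (non-HIP) build on a device with compute capability below 9, i.e. anything older than SM90. The same predicate with explicit parentheses, as a sketch:

import torch

# Short-circuiting keeps get_device_properties() from running when CUDA is absent.
_SKIP_PRE_SM90 = (not torch.cuda.is_available()) or (
    (torch.version.hip is None)
    and (torch.cuda.get_device_properties(0).major < 9)
)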

fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py (+84, -44)

@@ -7,6 +7,7 @@
 # pyre-unsafe

 import functools
+import inspect
 import logging

 from typing import Optional
@@ -41,34 +42,63 @@
     for num_ctas in [1]
 ]

-_NV_WS_CONFIGS = [
-    triton.Config(
-        {
-            "BLOCK_SIZE_M": block_size_m,
-            "BLOCK_SIZE_N": block_size_n,
-            "BLOCK_SIZE_K": block_size_k,
-            "NUM_CONSUMER_GROUPS": max(1, num_consumer_groups),
-            "USE_TMA_LOAD_ON_SCALES": use_tma_load_on_scales,
-            "USE_TMA_STORE": use_tma_store,
-        },
-        num_stages=num_stages,
-        num_warps=num_warps,
-        num_ctas=num_ctas,
-        num_consumer_groups=num_consumer_groups,
-        num_buffers_warp_spec=num_stages,
-    )
-    for block_size_m in [64, 128, 256]
-    for block_size_n in [64, 128, 256]
-    for block_size_k in [64, 128, 256]
-    for num_stages in [2, 3, 4]
-    for num_warps in [4, 8, 16]
-    # TODO(shikaili): Resolve LLVM error.
-    for num_ctas in [1]
-    for num_consumer_groups in [0, 2]
-    for use_tma_load_on_scales in [True, False]
-    # TODO(shikaili): Resolve compatibility with ws.
-    for use_tma_store in [False]
-]
+_HAS_WS_SUPPORT = None
+
+
+def _check_ws_support():
+    if not hasattr(tl, "async_task"):
+        return False
+    config_signature = inspect.signature(triton.Config).parameters
+    if (
+        "num_consumer_groups" not in config_signature
+        or "num_buffers_warp_spec" not in config_signature
+    ):
+        return False
+    if not utils.HAS_TMA_DESC:
+        return False
+    return True
+
+
+def _set_ws_support():
+    global _HAS_WS_SUPPORT
+    if _HAS_WS_SUPPORT is None:
+        _HAS_WS_SUPPORT = _check_ws_support()
+
+
+_set_ws_support()
+
+if _HAS_WS_SUPPORT:
+    _NV_WS_CONFIGS = [
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": block_size_m,
+                "BLOCK_SIZE_N": block_size_n,
+                "BLOCK_SIZE_K": block_size_k,
+                "NUM_CONSUMER_GROUPS": max(1, num_consumer_groups),
+                "USE_TMA_LOAD_ON_SCALES": use_tma_load_on_scales,
+                "USE_TMA_STORE": use_tma_store,
+            },
+            num_stages=num_stages,
+            num_warps=num_warps,
+            num_ctas=num_ctas,
+            num_consumer_groups=num_consumer_groups,
+            num_buffers_warp_spec=num_stages,
+        )
+        for block_size_m in [64, 128, 256]
+        for block_size_n in [64, 128, 256]
+        for block_size_k in [64, 128, 256]
+        for num_stages in [2, 3, 4]
+        for num_warps in [4, 8, 16]
+        # TODO(shikaili): Resolve LLVM error.
+        for num_ctas in [1]
+        for num_consumer_groups in [0, 2]
+        for use_tma_load_on_scales in [True, False]
+        # TODO(shikaili): Resolve compatibility with ws.
+        for use_tma_store in [False]
+    ]
+else:
+    _NV_WS_CONFIGS = _NV_CONFIGS
+

 _AMD_CONFIGS = [
     triton.Config(
@@ -880,15 +910,16 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws(


 def _grouped_gemm(
+    *,
     x: torch.Tensor,
     w: torch.Tensor,
     m_sizes: torch.Tensor,
-    x_scale: Optional[torch.Tensor] = None,
-    w_scale: Optional[torch.Tensor] = None,
-    use_fast_accum: bool = False,
-    use_warp_specialization: bool = False,
-    output_tensor: Optional[torch.Tensor] = None,
-    scatter_add_indices: Optional[torch.Tensor] = None,
+    x_scale: Optional[torch.Tensor],
+    w_scale: Optional[torch.Tensor],
+    use_fast_accum: bool,
+    use_warp_specialization: bool,
+    output_tensor: Optional[torch.Tensor],
+    scatter_add_indices: Optional[torch.Tensor],
 ) -> torch.Tensor:

     USE_TMA_LOAD = not torch.version.hip
@@ -902,12 +933,19 @@ def _grouped_gemm(
         USE_TMA_STORE = False
         logging.warning("TMA store is disabled as there is no TMA descriptor support!")

+    # TODO(shikaili): Check the readniess of WS on ROCm side in Meta's Triton.
     if use_warp_specialization and torch.version.hip:
         logging.warning(
             "Warp specialization is disabled as it is not supported on ROCm."
         )
         use_warp_specialization = False

+    if use_warp_specialization and not _HAS_WS_SUPPORT:
+        logging.warning(
+            "Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs."
+        )
+        use_warp_specialization = False
+
     if use_warp_specialization:
         assert utils.HAS_TMA_DESC
         USE_TMA_STORE = True  # Tuning decision
@@ -1063,14 +1101,16 @@ def grouped_gemm(
     m_sizes: torch.Tensor,
     use_fast_accum: bool = True,
     *,
-    _use_warp_specialization: bool = False,
+    _use_warp_specialization: bool = True,
     _output_tensor: Optional[torch.Tensor] = None,
     _scatter_add_indices: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     return _grouped_gemm(
-        x,
-        w,
-        m_sizes,
+        x=x,
+        w=w,
+        m_sizes=m_sizes,
+        x_scale=None,
+        w_scale=None,
         use_fast_accum=use_fast_accum,
         use_warp_specialization=_use_warp_specialization,
         output_tensor=_output_tensor,
@@ -1086,16 +1126,16 @@ def grouped_gemm_fp8_rowwise(
     w_scale: torch.Tensor,
     use_fast_accum: bool = True,
     *,
-    _use_warp_specialization: bool = False,
+    _use_warp_specialization: bool = True,
     _output_tensor: Optional[torch.Tensor] = None,
     _scatter_add_indices: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     return _grouped_gemm(
-        x,
-        w,
-        m_sizes,
-        x_scale,
-        w_scale,
+        x=x,
+        w=w,
+        m_sizes=m_sizes,
+        x_scale=x_scale,
+        w_scale=w_scale,
         use_fast_accum=use_fast_accum,
         use_warp_specialization=_use_warp_specialization,
         output_tensor=_output_tensor,
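
With the defaults changed above, grouped_gemm and grouped_gemm_fp8_rowwise now request warp specialization automatically and fall back (with a logged warning) when the detection fails or on ROCm. A minimal call sketch follows; the tensor shapes and dtypes are illustrative assumptions (x as [sum(M_i), K], w as [G * N, K], bf16, one int32 group size per group), not documented API:

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import grouped_gemm

device = "cuda"
# Two groups of 128 and 256 rows sharing K=512; G=2 weight blocks of N=1024 each.
x = torch.randn(128 + 256, 512, dtype=torch.bfloat16, device=device)
w = torch.randn(2 * 1024, 512, dtype=torch.bfloat16, device=device)
m_sizes = torch.tensor([128, 256], dtype=torch.int32, device=device)

# Warp specialization is on by default; it is disabled automatically (with a
# logged warning) when the installed Triton lacks WS support.
out = grouped_gemm(x, w, m_sizes)

# Explicit opt-out, e.g. to exercise the non-WS code path.
out_no_ws = grouped_gemm(x, w, m_sizes, _use_warp_specialization=False)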
