# pyre-unsafe

import functools
+ import inspect
import logging

from typing import Optional

    for num_ctas in [1]
]

- _NV_WS_CONFIGS = [
-     triton.Config(
-         {
-             "BLOCK_SIZE_M": block_size_m,
-             "BLOCK_SIZE_N": block_size_n,
-             "BLOCK_SIZE_K": block_size_k,
-             "NUM_CONSUMER_GROUPS": max(1, num_consumer_groups),
-             "USE_TMA_LOAD_ON_SCALES": use_tma_load_on_scales,
-             "USE_TMA_STORE": use_tma_store,
-         },
-         num_stages=num_stages,
-         num_warps=num_warps,
-         num_ctas=num_ctas,
-         num_consumer_groups=num_consumer_groups,
-         num_buffers_warp_spec=num_stages,
-     )
-     for block_size_m in [64, 128, 256]
-     for block_size_n in [64, 128, 256]
-     for block_size_k in [64, 128, 256]
-     for num_stages in [2, 3, 4]
-     for num_warps in [4, 8, 16]
-     # TODO(shikaili): Resolve LLVM error.
-     for num_ctas in [1]
-     for num_consumer_groups in [0, 2]
-     for use_tma_load_on_scales in [True, False]
-     # TODO(shikaili): Resolve compatibility with ws.
-     for use_tma_store in [False]
- ]
+ _HAS_WS_SUPPORT = None
+
+
+ def _check_ws_support():
+     if not hasattr(tl, "async_task"):
+         return False
+     config_signature = inspect.signature(triton.Config).parameters
+     if (
+         "num_consumer_groups" not in config_signature
+         or "num_buffers_warp_spec" not in config_signature
+     ):
+         return False
+     if not utils.HAS_TMA_DESC:
+         return False
+     return True
+
+
+ def _set_ws_support():
+     global _HAS_WS_SUPPORT
+     if _HAS_WS_SUPPORT is None:
+         _HAS_WS_SUPPORT = _check_ws_support()
+
+
+ _set_ws_support()
+
+ if _HAS_WS_SUPPORT:
+     _NV_WS_CONFIGS = [
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": block_size_m,
+                 "BLOCK_SIZE_N": block_size_n,
+                 "BLOCK_SIZE_K": block_size_k,
+                 "NUM_CONSUMER_GROUPS": max(1, num_consumer_groups),
+                 "USE_TMA_LOAD_ON_SCALES": use_tma_load_on_scales,
+                 "USE_TMA_STORE": use_tma_store,
+             },
+             num_stages=num_stages,
+             num_warps=num_warps,
+             num_ctas=num_ctas,
+             num_consumer_groups=num_consumer_groups,
+             num_buffers_warp_spec=num_stages,
+         )
+         for block_size_m in [64, 128, 256]
+         for block_size_n in [64, 128, 256]
+         for block_size_k in [64, 128, 256]
+         for num_stages in [2, 3, 4]
+         for num_warps in [4, 8, 16]
+         # TODO(shikaili): Resolve LLVM error.
+         for num_ctas in [1]
+         for num_consumer_groups in [0, 2]
+         for use_tma_load_on_scales in [True, False]
+         # TODO(shikaili): Resolve compatibility with ws.
+         for use_tma_store in [False]
+     ]
+ else:
+     _NV_WS_CONFIGS = _NV_CONFIGS
+
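The new guard probes the installed Triton build for warp-specialization hooks instead of assuming them: it checks for `tl.async_task`, for the `num_consumer_groups` / `num_buffers_warp_spec` keywords on `triton.Config`, and for TMA descriptor support, and only then builds the warp-specialized autotune grid. The core trick is keyword introspection via `inspect.signature`; a minimal standalone sketch of that detection pattern (the `make_config` callable below is a hypothetical stand-in, not part of this diff):

import inspect

def accepts_kwargs(fn, *names: str) -> bool:
    # True only if `fn` declares every requested keyword parameter.
    params = inspect.signature(fn).parameters
    return all(name in params for name in names)

# Hypothetical stand-in for a Config class from a build without warp-spec support:
def make_config(kwargs, num_stages=3, num_warps=4, num_ctas=1):
    return (kwargs, num_stages, num_warps, num_ctas)

print(accepts_kwargs(make_config, "num_ctas"))                                       # True
print(accepts_kwargs(make_config, "num_consumer_groups", "num_buffers_warp_spec"))   # False -> fall back to _NV_CONFIGS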

_AMD_CONFIGS = [
    triton.Config(
@@ -880,15 +910,16 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws(


def _grouped_gemm(
+     *,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
-     x_scale: Optional[torch.Tensor] = None,
-     w_scale: Optional[torch.Tensor] = None,
-     use_fast_accum: bool = False,
-     use_warp_specialization: bool = False,
-     output_tensor: Optional[torch.Tensor] = None,
-     scatter_add_indices: Optional[torch.Tensor] = None,
+     x_scale: Optional[torch.Tensor],
+     w_scale: Optional[torch.Tensor],
+     use_fast_accum: bool,
+     use_warp_specialization: bool,
+     output_tensor: Optional[torch.Tensor],
+     scatter_add_indices: Optional[torch.Tensor],
) -> torch.Tensor:

    USE_TMA_LOAD = not torch.version.hip
@@ -902,12 +933,19 @@ def _grouped_gemm(
        USE_TMA_STORE = False
        logging.warning("TMA store is disabled as there is no TMA descriptor support!")

+     # TODO(shikaili): Check the readiness of WS on ROCm side in Meta's Triton.
    if use_warp_specialization and torch.version.hip:
        logging.warning(
            "Warp specialization is disabled as it is not supported on ROCm."
        )
        use_warp_specialization = False

+     if use_warp_specialization and not _HAS_WS_SUPPORT:
+         logging.warning(
+             "Warp specialization is disabled as the Triton build in the current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs."
+         )
+         use_warp_specialization = False
+
    if use_warp_specialization:
        assert utils.HAS_TMA_DESC
        USE_TMA_STORE = True  # Tuning decision
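The two guards give ROCm builds and Triton builds without warp-specialization support the same silent downgrade before a kernel is selected. A small sketch of that decision chain in isolation (the `resolve_warp_specialization` helper is illustrative, not a function in this file):

import logging

def resolve_warp_specialization(requested: bool, is_hip: bool, has_ws_support: bool) -> bool:
    # Mirrors the guard chain above: ROCm first, then a Triton build without
    # warp-specialization hooks; either condition downgrades the request.
    if requested and is_hip:
        logging.warning("Warp specialization is disabled: not supported on ROCm.")
        return False
    if requested and not has_ws_support:
        logging.warning("Warp specialization is disabled: Triton build lacks support.")
        return False
    return requested

assert resolve_warp_specialization(True, is_hip=False, has_ws_support=True) is True
assert resolve_warp_specialization(True, is_hip=True, has_ws_support=True) is False
assert resolve_warp_specialization(True, is_hip=False, has_ws_support=False) is False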
@@ -1063,14 +1101,16 @@ def grouped_gemm(
    m_sizes: torch.Tensor,
    use_fast_accum: bool = True,
    *,
-     _use_warp_specialization: bool = False,
+     _use_warp_specialization: bool = True,
    _output_tensor: Optional[torch.Tensor] = None,
    _scatter_add_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return _grouped_gemm(
-         x,
-         w,
-         m_sizes,
+         x=x,
+         w=w,
+         m_sizes=m_sizes,
+         x_scale=None,
+         w_scale=None,
        use_fast_accum=use_fast_accum,
        use_warp_specialization=_use_warp_specialization,
        output_tensor=_output_tensor,
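With `_use_warp_specialization` now defaulting to `True`, callers of the public API get the warp-specialized path automatically when `_HAS_WS_SUPPORT` holds and fall back transparently otherwise. A hedged usage sketch (the commented-out import path, group count, shapes, and bf16 dtype are illustrative assumptions, not requirements stated in this diff):

import torch

# Assumed import location; adjust to wherever this module lives in your tree.
# from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import grouped_gemm

G, M, N, K = 4, 128, 256, 512
x = torch.randn(G * M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(G * N, K, device="cuda", dtype=torch.bfloat16)
m_sizes = torch.full((G,), M, device="cuda", dtype=torch.int32)

out = grouped_gemm(x, w, m_sizes)  # warp specialization on by default
out = grouped_gemm(x, w, m_sizes, _use_warp_specialization=False)  # explicit opt-out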
@@ -1086,16 +1126,16 @@ def grouped_gemm_fp8_rowwise(
    w_scale: torch.Tensor,
    use_fast_accum: bool = True,
    *,
-     _use_warp_specialization: bool = False,
+     _use_warp_specialization: bool = True,
    _output_tensor: Optional[torch.Tensor] = None,
    _scatter_add_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return _grouped_gemm(
-         x,
-         w,
-         m_sizes,
-         x_scale,
-         w_scale,
+         x=x,
+         w=w,
+         m_sizes=m_sizes,
+         x_scale=x_scale,
+         w_scale=w_scale,
        use_fast_accum=use_fast_accum,
        use_warp_specialization=_use_warp_specialization,
        output_tensor=_output_tensor,
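The FP8 rowwise wrapper changes the same way, except the per-row scales are now forwarded by keyword as well. A hedged sketch of a call (scale shapes and the FP8 dtype follow common rowwise-quantization conventions and are assumptions, not guarantees from this diff):

import torch

G, M, N, K = 4, 128, 256, 512
x = torch.randn(G * M, K, device="cuda").to(torch.float8_e4m3fn)
w = torch.randn(G * N, K, device="cuda").to(torch.float8_e4m3fn)
m_sizes = torch.full((G,), M, device="cuda", dtype=torch.int32)
x_scale = torch.rand(G * M, device="cuda", dtype=torch.float32)  # assumed per-row scales for x
w_scale = torch.rand(G * N, device="cuda", dtype=torch.float32)  # assumed per-row scales for w

out = grouped_gemm_fp8_rowwise(x, w, m_sizes, x_scale, w_scale)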