4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ import argparse
7
8
import itertools
8
- from typing import List , Tuple
9
+ from typing import List , Optional , Tuple
9
10
10
11
import torch
11
12
import triton # noqa: F401
12
13
from fbgemm_gpu .experimental .gen_ai .moe import (
14
+ combine_shuffling ,
13
15
gather_along_first_dim ,
14
16
gather_scale_dense_tokens ,
15
17
gather_scale_quant_dense_tokens ,
16
18
index_shuffling ,
17
19
scatter_add_along_first_dim ,
20
+ split_shuffling ,
18
21
)
19
22
from triton .testing import do_bench , do_bench_cudagraph
20
23
24
# Device used for every benchmark tensor, resolved once at import time via the
# PyTorch accelerator API (returns the current accelerator, e.g. cuda/xpu;
# NOTE(review): may be None on a CPU-only build — confirm benchmarks are
# only run with an accelerator present).
_ACCELERATOR_TAG = torch.accelerator.current_accelerator()
25
+
21
26
22
27
def bench_gather_along_first_dim (M : int , N : int , K : int ) -> None :
23
- src = torch .randn ([M , K ], device = "cuda" , dtype = torch .bfloat16 ).abs ()
28
+ src = torch .randn ([M , K ], device = _ACCELERATOR_TAG , dtype = torch .bfloat16 ).abs ()
24
29
if M == N :
25
- indices = torch .randperm (N , device = "cuda" , dtype = torch .int32 )
30
+ indices = torch .randperm (N , device = _ACCELERATOR_TAG , dtype = torch .int32 )
26
31
else :
27
- indices = torch .randint (0 , M , [N ], device = "cuda" , dtype = torch .int32 )
32
+ indices = torch .randint (0 , M , [N ], device = _ACCELERATOR_TAG , dtype = torch .int32 )
28
33
29
34
def fn ():
30
35
return gather_along_first_dim (src , indices )
@@ -51,12 +56,14 @@ def ref_fn():
51
56
52
57
53
58
def bench_scatter_add_along_first_dim (M : int , N : int , K : int ) -> None :
54
- src = torch .randn ([M , K ], device = "cuda" , dtype = torch .bfloat16 ).abs ()
55
- dst = torch .randn ([N , K ], device = "cuda" , dtype = torch .bfloat16 ).abs ()
59
+ src = torch .randn ([M , K ], device = _ACCELERATOR_TAG , dtype = torch .bfloat16 ).abs ()
60
+ dst = torch .randn ([N , K ], device = _ACCELERATOR_TAG , dtype = torch .bfloat16 ).abs ()
56
61
if M == N :
57
- indices_1d = torch .randperm (N , device = "cuda" , dtype = torch .int64 )
62
+ indices_1d = torch .randperm (N , device = _ACCELERATOR_TAG , dtype = torch .int64 )
58
63
else :
59
- indices_1d = torch .randint (0 , N , [M ], device = "cuda" , dtype = torch .int64 )
64
+ indices_1d = torch .randint (
65
+ 0 , N , [M ], device = _ACCELERATOR_TAG , dtype = torch .int64
66
+ )
60
67
61
68
indices_2d = indices_1d .to (torch .int64 ).unsqueeze (1 ).expand (- 1 , K )
62
69
@@ -88,10 +95,10 @@ def ref_fn():
88
95
89
96
90
97
def bench_gather_scale_dense_tokens (E : int , T : int , D : int , quantize : bool ):
91
- x = torch .randn ((T , D ), dtype = torch .bfloat16 , device = "cuda" ).abs ()
92
- expert_indices = torch .randint (0 , E , (T ,), device = "cuda" )
93
- token_indices = torch .randperm (T , device = "cuda" )
94
- scores = torch .rand ((E , T ), dtype = torch .bfloat16 , device = "cuda" )
98
+ x = torch .randn ((T , D ), dtype = torch .bfloat16 , device = _ACCELERATOR_TAG ).abs ()
99
+ expert_indices = torch .randint (0 , E , (T ,), device = _ACCELERATOR_TAG )
100
+ token_indices = torch .randperm (T , device = _ACCELERATOR_TAG )
101
+ scores = torch .rand ((E , T ), dtype = torch .bfloat16 , device = _ACCELERATOR_TAG )
95
102
96
103
def torch_fn ():
97
104
shuffled_x = torch .index_select (x , dim = 0 , index = token_indices )
@@ -134,12 +141,13 @@ def triton_fn():
134
141
)
135
142
136
143
137
- def bench_top1_index_shuffling (num_tokens : int , num_experts : int ) -> None :
144
+ def bench_top1_index_shuffling (T : int , E : int ) -> None :
138
145
torch .manual_seed (0 )
139
146
147
+ num_rotating_buffers = max (2 , triton .cdiv (1024 * 1024 * 1024 , T * E * 2 ))
140
148
scores_list : List [torch .Tensor ] = [
141
- torch .randn (num_tokens , num_experts , device = "cuda" , dtype = torch .bfloat16 )
142
- for i in range (100 )
149
+ torch .randn (T , E , device = _ACCELERATOR_TAG , dtype = torch .bfloat16 )
150
+ for i in range (num_rotating_buffers )
143
151
]
144
152
145
153
def fn () -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
@@ -152,39 +160,171 @@ def ref_fn() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
152
160
expert_indices , _ = torch .sort (selected_expert_indices , dim = 0 )
153
161
_ = (
154
162
expert_indices [:, None ]
155
- == torch .arange (num_experts , device = expert_indices .device )[None , :]
163
+ == torch .arange (E , device = expert_indices .device )[None , :]
156
164
).sum (dim = 0 )
157
165
158
- fbgemm_time = do_bench_cudagraph (fn ) * 1e3 / 100
159
- torch_time = do_bench_cudagraph (ref_fn ) * 1e3 / 100
166
+ fbgemm_time = do_bench_cudagraph (fn ) * 1e3 / num_rotating_buffers
167
+ torch_time = do_bench_cudagraph (ref_fn ) * 1e3 / num_rotating_buffers
160
168
print (
161
- f"Benchmark index_shuffling, num_tokens={ num_tokens :4} , num_experts={ num_experts :4} , "
169
+ f"Benchmark index_shuffling, num_tokens={ T :4} , num_experts={ E :4} , "
162
170
f"fbgemm_time={ fbgemm_time :7.3f} us, torch_time={ torch_time :7.3f} us"
163
171
)
164
172
165
173
166
- def main ():
174
def bench_combine_or_split_shuffling(
    T: int,
    # EP is the expert-parallelism degree (call sites pass 2/16 and the body
    # computes E // EP), so it is an int, not the `bool` the original annotated.
    D: int,
    E: int,
    EP: int,
    is_padded: bool,
    is_balanced: bool,
    is_combine_shuffling: bool,
) -> None:
    """Benchmark the combine_shuffling or split_shuffling kernel.

    Args:
        T: Number of tokens per rank.
        D: Hidden dimension of each token.
        E: Total number of experts; must be divisible by EP.
        EP: Expert-parallelism degree (number of EP ranks).
        is_padded: True models the graph/allgather layout (EP * T input tokens
            spread over all E experts, benchmarking a window of local experts);
            False models the eager/all2all layout (T tokens over the E // EP
            local experts).
        is_balanced: When False, tokens are skewed from the first benchmarked
            expert onto the last one instead of being spread evenly.
        is_combine_shuffling: Select combine_shuffling (True) or
            split_shuffling (False).

    Prints the measured kernel time and effective bandwidth.
    """
    torch.manual_seed(0)

    assert E % EP == 0
    if is_padded:
        # graph. allgather
        input_num_tokens: int = EP * T
        input_num_experts: int = E
        output_num_experts: int = E // EP
        start_expert_index: int = 1
        end_expert_index: int = 1 + output_num_experts
    else:
        # eager. all2all
        input_num_tokens: int = T
        input_num_experts: int = E // EP
        output_num_experts: int = E // EP
        start_expert_index: int = 0
        end_expert_index: int = output_num_experts

    tokens = torch.randn(
        input_num_tokens, D, device=_ACCELERATOR_TAG, dtype=torch.bfloat16
    )

    # Skip shapes whose tokens cannot be spread evenly over (EP, experts).
    # BUGFIX: the original chained comparison
    #   `input_num_tokens < (EP * input_num_experts) != 0`
    # only caught the too-few-tokens case (the `!= 0` half is always true),
    # letting non-divisible counts fall through to a failing sum assert below.
    if input_num_tokens % (EP * input_num_experts) != 0:
        return

    input_num_tokens_per_expert: int = input_num_tokens // (EP * input_num_experts)
    token_counts: torch.Tensor = (
        torch.ones(
            [EP, input_num_experts],
            dtype=torch.int32,
            device=_ACCELERATOR_TAG,
        )
        * input_num_tokens_per_expert
    )
    if not is_balanced:
        # Move the first benchmarked expert's tokens onto the last one so the
        # per-expert distribution is maximally skewed while the total holds.
        for i in range(EP):
            token_counts[i, start_expert_index] -= input_num_tokens_per_expert
            token_counts[i, end_expert_index - 1] += input_num_tokens_per_expert

    assert token_counts.sum().item() == input_num_tokens

    # Rotate through ~1GiB worth of input buffers so repeated runs inside the
    # CUDA graph do not benefit from cache reuse.
    num_rotating_buffers = triton.cdiv(1024 * 1024 * 1024, tokens.numel() * 2)
    token_list: List[torch.Tensor] = [
        tokens.clone() for _ in range(num_rotating_buffers)
    ]
    token_count_list: List[torch.Tensor] = [
        token_counts.clone() for _ in range(num_rotating_buffers)
    ]

    # Outputs are intentionally discarded; only kernel time matters here.
    def fn() -> None:
        for tokens, token_counts in zip(token_list, token_count_list):
            if is_combine_shuffling:
                combine_shuffling(
                    tokens,
                    token_counts,
                    expert_start=start_expert_index,
                    expert_end=end_expert_index,
                    is_balanced=is_balanced,
                )
            else:
                split_shuffling(
                    tokens,
                    token_counts,
                    expert_start=start_expert_index,
                    expert_end=end_expert_index,
                    is_balanced=is_balanced,
                )

    # Warm up (compiles/caches the kernels) before graph capture.
    fn()

    # Count only the tokens that fall inside the benchmarked expert window.
    output_num_tokens = 0
    for per_rank_counts in token_counts.tolist():
        for expert_index, per_expert_counts in enumerate(per_rank_counts):
            if expert_index >= start_expert_index and expert_index < end_expert_index:
                output_num_tokens += per_expert_counts

    # Each moved token is read once and written once in bf16 (2 bytes each).
    mem_bytes = output_num_tokens * D * 2 * 2
    fbgemm_time = do_bench_cudagraph(fn) * 1e3 / num_rotating_buffers
    fbgemm_bw = mem_bytes * 1e-9 / (fbgemm_time * 1e-6)

    print(
        f"Benchmark {'combine_shuffling' if is_combine_shuffling else 'split_shuffling'}, "
        f"num_tokens={T:4}, dim={D:4}, num_experts={E:4}, expert_parallelism={EP:4}, output_num_tokens={output_num_tokens:4}, "
        f"{is_balanced=}, {is_padded=}, "
        f"fbgemm_time={fbgemm_time:7.3f}us, fbgemm_bw={fbgemm_bw:8.3f}GBytes/s."
    )
269
+
270
+
271
def main(kernels: Optional[str]) -> None:
    """Run the MoE kernel benchmark suite.

    Args:
        kernels: Comma-separated kernel names to benchmark, or None to run
            every available kernel.
    """
    selected = None if kernels is None else kernels.split(",")

    def should_bench_kernel(fn) -> bool:
        # Benchmark a kernel only if it was imported successfully and either
        # no filter was given or its name was explicitly requested.
        if fn is None:
            return False
        return selected is None or fn.__name__ in selected

    Es = [16, 128]
    Ts = [1, 128, 2048, 4096, 8192, 16384]
    Ds = [5120]

    # Gather/Scatter
    if should_bench_kernel(gather_scale_dense_tokens):
        for E, T, D in itertools.product(Es, Ts, Ds):
            bench_gather_scale_dense_tokens(E, T, D, quantize=False)

    if should_bench_kernel(gather_scale_quant_dense_tokens):
        for E, T, D in itertools.product(Es, Ts, Ds):
            bench_gather_scale_dense_tokens(E, T, D, quantize=True)

    if should_bench_kernel(gather_along_first_dim):
        for T, D in itertools.product(Ts, Ds):
            bench_gather_along_first_dim(T, T, D)

    if should_bench_kernel(scatter_add_along_first_dim):
        for T, D in itertools.product(Ts, Ds):
            bench_scatter_add_along_first_dim(T, T, D)

    # Shuffling
    if should_bench_kernel(index_shuffling):
        for T, E in itertools.product(Ts, Es):
            bench_top1_index_shuffling(T, E)

    EPs = [2, 16]
    Ts = [32, 128, 2048, 4096, 8192, 16384]
    padded = [True, False]
    balanced = [True, False]

    if should_bench_kernel(combine_shuffling):
        for T, D, E, EP, p, b in itertools.product(Ts, Ds, Es, EPs, padded, balanced):
            bench_combine_or_split_shuffling(
                T, D, E, EP, p, b, is_combine_shuffling=True
            )

    if should_bench_kernel(split_shuffling):
        for T, D, E, EP, p, b in itertools.product(Ts, Ds, Es, EPs, padded, balanced):
            bench_combine_or_split_shuffling(
                T, D, E, EP, p, b, is_combine_shuffling=False
            )
187
320
188
321
189
322
if __name__ == "__main__":
    # CLI entry point: forward the optional kernel-name filter to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--kernels",
        default=None,
        help="Comma separated list of kernels to benchmark. Defaults to all kernels.",
    )
    main(parser.parse_args().kernels)
0 commit comments