
Commit 2cffc7b

levendlee authored and facebook-github-bot committed
Cleanup shuffling ops. (pytorch#4013)
Summary:
Pull Request resolved: pytorch#4013
X-link: facebookresearch/FBGEMM#1101

- Sort the file structure.
- Ensure compatibility with `torch.compile`.
- Move the benchmark from the test to a dedicated benchmark script.

Reviewed By: jianyuh
Differential Revision: D73471757
fbshipit-source-id: 7067c9e50a4f5924f1ed3e02726b2e38ec8a28f6
1 parent 717642e commit 2cffc7b

File tree: 4 files changed (+259 -71 lines)

fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py (+169 -29)
@@ -4,27 +4,32 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import argparse
 import itertools
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import triton  # noqa: F401
 from fbgemm_gpu.experimental.gen_ai.moe import (
+    combine_shuffling,
     gather_along_first_dim,
     gather_scale_dense_tokens,
     gather_scale_quant_dense_tokens,
     index_shuffling,
     scatter_add_along_first_dim,
+    split_shuffling,
 )
 from triton.testing import do_bench, do_bench_cudagraph
 
+_ACCELERATOR_TAG = torch.accelerator.current_accelerator()
+
 
 def bench_gather_along_first_dim(M: int, N: int, K: int) -> None:
-    src = torch.randn([M, K], device="cuda", dtype=torch.bfloat16).abs()
+    src = torch.randn([M, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
     if M == N:
-        indices = torch.randperm(N, device="cuda", dtype=torch.int32)
+        indices = torch.randperm(N, device=_ACCELERATOR_TAG, dtype=torch.int32)
     else:
-        indices = torch.randint(0, M, [N], device="cuda", dtype=torch.int32)
+        indices = torch.randint(0, M, [N], device=_ACCELERATOR_TAG, dtype=torch.int32)
 
     def fn():
         return gather_along_first_dim(src, indices)
5156

5257

5358
def bench_scatter_add_along_first_dim(M: int, N: int, K: int) -> None:
54-
src = torch.randn([M, K], device="cuda", dtype=torch.bfloat16).abs()
55-
dst = torch.randn([N, K], device="cuda", dtype=torch.bfloat16).abs()
59+
src = torch.randn([M, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
60+
dst = torch.randn([N, K], device=_ACCELERATOR_TAG, dtype=torch.bfloat16).abs()
5661
if M == N:
57-
indices_1d = torch.randperm(N, device="cuda", dtype=torch.int64)
62+
indices_1d = torch.randperm(N, device=_ACCELERATOR_TAG, dtype=torch.int64)
5863
else:
59-
indices_1d = torch.randint(0, N, [M], device="cuda", dtype=torch.int64)
64+
indices_1d = torch.randint(
65+
0, N, [M], device=_ACCELERATOR_TAG, dtype=torch.int64
66+
)
6067

6168
indices_2d = indices_1d.to(torch.int64).unsqueeze(1).expand(-1, K)
6269

@@ -88,10 +95,10 @@ def ref_fn():
 
 
 def bench_gather_scale_dense_tokens(E: int, T: int, D: int, quantize: bool):
-    x = torch.randn((T, D), dtype=torch.bfloat16, device="cuda").abs()
-    expert_indices = torch.randint(0, E, (T,), device="cuda")
-    token_indices = torch.randperm(T, device="cuda")
-    scores = torch.rand((E, T), dtype=torch.bfloat16, device="cuda")
+    x = torch.randn((T, D), dtype=torch.bfloat16, device=_ACCELERATOR_TAG).abs()
+    expert_indices = torch.randint(0, E, (T,), device=_ACCELERATOR_TAG)
+    token_indices = torch.randperm(T, device=_ACCELERATOR_TAG)
+    scores = torch.rand((E, T), dtype=torch.bfloat16, device=_ACCELERATOR_TAG)
 
     def torch_fn():
         shuffled_x = torch.index_select(x, dim=0, index=token_indices)
@@ -134,12 +141,13 @@ def triton_fn():
     )
 
 
-def bench_top1_index_shuffling(num_tokens: int, num_experts: int) -> None:
+def bench_top1_index_shuffling(T: int, E: int) -> None:
     torch.manual_seed(0)
 
+    num_rotating_buffers = max(2, triton.cdiv(1024 * 1024 * 1024, T * E * 2))
     scores_list: List[torch.Tensor] = [
-        torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.bfloat16)
-        for i in range(100)
+        torch.randn(T, E, device=_ACCELERATOR_TAG, dtype=torch.bfloat16)
+        for i in range(num_rotating_buffers)
     ]
 
     def fn() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -152,39 +160,171 @@ def ref_fn() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         expert_indices, _ = torch.sort(selected_expert_indices, dim=0)
         _ = (
             expert_indices[:, None]
-            == torch.arange(num_experts, device=expert_indices.device)[None, :]
+            == torch.arange(E, device=expert_indices.device)[None, :]
         ).sum(dim=0)
 
-    fbgemm_time = do_bench_cudagraph(fn) * 1e3 / 100
-    torch_time = do_bench_cudagraph(ref_fn) * 1e3 / 100
+    fbgemm_time = do_bench_cudagraph(fn) * 1e3 / num_rotating_buffers
+    torch_time = do_bench_cudagraph(ref_fn) * 1e3 / num_rotating_buffers
     print(
-        f"Benchmark index_shuffling, num_tokens={num_tokens:4}, num_experts={num_experts:4}, "
+        f"Benchmark index_shuffling, num_tokens={T:4}, num_experts={E:4}, "
         f"fbgemm_time={fbgemm_time:7.3f}us, torch_time={torch_time:7.3f}us"
     )
 
 
-def main():
+def bench_combine_or_split_shuffling(
+    T: int,
+    D: int,
+    E: int,
+    EP: int,
+    is_padded: bool,
+    is_balanced: bool,
+    is_combine_shuffling: bool,
+):
+    torch.manual_seed(0)
+
+    assert E % EP == 0
+    if is_padded:
+        # graph. allgather
+        input_num_tokens: int = EP * T
+        input_num_experts: int = E
+        output_num_experts: int = E // EP
+        start_expert_index: int = 1
+        end_expert_index: int = 1 + output_num_experts
+    else:
+        # eager. all2all
+        input_num_tokens: int = T
+        input_num_experts: int = E // EP
+        output_num_experts: int = E // EP
+        start_expert_index: int = 0
+        end_expert_index: int = output_num_experts
+
+    tokens = torch.randn(
+        input_num_tokens, D, device=_ACCELERATOR_TAG, dtype=torch.bfloat16
+    )
+
+    if input_num_tokens % (EP * input_num_experts) != 0:
+        return
+
+    input_num_tokens_per_expert: int = input_num_tokens // (EP * input_num_experts)
+    token_counts: torch.Tensor = (
+        torch.ones(
+            [EP, input_num_experts],
+            dtype=torch.int32,
+            device=_ACCELERATOR_TAG,
+        )
+        * input_num_tokens_per_expert
+    )
+    if not is_balanced:
+        for i in range(EP):
+            token_counts[i, start_expert_index] -= input_num_tokens_per_expert
+            token_counts[i, end_expert_index - 1] += input_num_tokens_per_expert
+
+    assert token_counts.sum().item() == input_num_tokens
+
+    num_rotating_buffers = triton.cdiv(1024 * 1024 * 1024, tokens.numel() * 2)
+    token_list: List[torch.Tensor] = [
+        tokens.clone() for _ in range(num_rotating_buffers)
+    ]
+    token_count_list: List[torch.Tensor] = [
+        token_counts.clone() for _ in range(num_rotating_buffers)
+    ]
+
+    def fn() -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        for tokens, token_counts in zip(token_list, token_count_list):
+            if is_combine_shuffling:
+                combine_shuffling(
+                    tokens,
+                    token_counts,
+                    expert_start=start_expert_index,
+                    expert_end=end_expert_index,
+                    is_balanced=is_balanced,
+                )
+            else:
+                split_shuffling(
+                    tokens,
+                    token_counts,
+                    expert_start=start_expert_index,
+                    expert_end=end_expert_index,
+                    is_balanced=is_balanced,
+                )
+
+    fn()
+
+    output_num_tokens = 0
+    for per_rank_counts in token_counts.tolist():
+        for expert_index, per_expert_counts in enumerate(per_rank_counts):
+            if expert_index >= start_expert_index and expert_index < end_expert_index:
+                output_num_tokens += per_expert_counts
+
+    mem_bytes = output_num_tokens * D * 2 * 2
+    fbgemm_time = do_bench_cudagraph(fn) * 1e3 / num_rotating_buffers
+    fbgemm_bw = mem_bytes * 1e-9 / (fbgemm_time * 1e-6)
+
+    print(
+        f"Benchmark {'combine_shuffling' if is_combine_shuffling else 'split_shuffling'}, "
+        f"num_tokens={T:4}, dim={D:4}, num_experts={E:4}, expert_parallelism={EP:4}, output_num_tokens={output_num_tokens:4}, "
+        f"{is_balanced=}, {is_padded=}, "
+        f"fbgemm_time={fbgemm_time:7.3f}us, fbgemm_bw={fbgemm_bw:8.3f}GBytes/s."
+    )
+
+
+def main(kernels: Optional[str]):
+    if kernels is not None:
+        kernels = kernels.split(",")
+
+    def should_bench_kernel(fn):
+        return (fn is not None) and (kernels is None or fn.__name__ in kernels)
+
     Es = [16, 128]
     Ts = [1, 128, 2048, 4096, 8192, 16384]
     Ds = [5120]
 
-    for E, T, D in itertools.product(Es, Ts, Ds):
-        bench_gather_scale_dense_tokens(E, T, D, quantize=False)
+    # Gather/Scatter
+    if should_bench_kernel(gather_scale_dense_tokens):
+        for E, T, D in itertools.product(Es, Ts, Ds):
+            bench_gather_scale_dense_tokens(E, T, D, quantize=False)
 
-    for E, T, D in itertools.product(Es, Ts, Ds):
-        bench_gather_scale_dense_tokens(E, T, D, quantize=True)
+    if should_bench_kernel(gather_scale_quant_dense_tokens):
+        for E, T, D in itertools.product(Es, Ts, Ds):
+            bench_gather_scale_dense_tokens(E, T, D, quantize=True)
 
-    if gather_along_first_dim is not None:
+    if should_bench_kernel(gather_along_first_dim):
         for T, D in itertools.product(Ts, Ds):
             bench_gather_along_first_dim(T, T, D)
 
-    if scatter_add_along_first_dim is not None:
+    if should_bench_kernel(scatter_add_along_first_dim):
         for T, D in itertools.product(Ts, Ds):
             bench_scatter_add_along_first_dim(T, T, D)
 
-    for T, E in itertools.product(Ts, Es):
-        bench_top1_index_shuffling(T, E)
+    # Shuffling
+    if should_bench_kernel(index_shuffling):
+        for T, E in itertools.product(Ts, Es):
+            bench_top1_index_shuffling(T, E)
+
+    EPs = [2, 16]
+    Ts = [32, 128, 2048, 4096, 8192, 16384]
+    padded = [True, False]
+    balanced = [True, False]
+
+    if should_bench_kernel(combine_shuffling):
+        for T, D, E, EP, p, b in itertools.product(Ts, Ds, Es, EPs, padded, balanced):
+            bench_combine_or_split_shuffling(
+                T, D, E, EP, p, b, is_combine_shuffling=True
+            )
+
+    if should_bench_kernel(split_shuffling):
+        for T, D, E, EP, p, b in itertools.product(Ts, Ds, Es, EPs, padded, balanced):
+            bench_combine_or_split_shuffling(
+                T, D, E, EP, p, b, is_combine_shuffling=False
+            )
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--kernels",
+        default=None,
+        help="Comma separated list of kernels to benchmark. Defaults to all kernels.",
+    )
+    args = parser.parse_args()
+    main(args.kernels)
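
Note on the timing pattern above: the new benchmarks pair `do_bench_cudagraph` with a pool of rotating input buffers sized to roughly 1 GiB (`num_rotating_buffers`), so CUDA-graph replays stream fresh data instead of re-reading cache-resident tensors, and the reported per-call time is the graph time divided by the pool size. Below is a minimal standalone sketch of that pattern with illustrative sizes and `torch.topk` as a stand-in op, not the script's exact benchmark:

import torch
import triton
from triton.testing import do_bench_cudagraph

T, E = 4096, 128  # illustrative sizes only

# Rotate through ~1 GiB of distinct bf16 buffers so repeated CUDA-graph
# replays do not keep hitting L2-resident data.
num_rotating_buffers = max(2, triton.cdiv(1024 * 1024 * 1024, T * E * 2))
scores_list = [
    torch.randn(T, E, device="cuda", dtype=torch.bfloat16)
    for _ in range(num_rotating_buffers)
]


def fn() -> None:
    for scores in scores_list:
        torch.topk(scores, k=1, dim=1)  # stand-in for the benchmarked op


# do_bench_cudagraph returns milliseconds for one pass over every buffer;
# convert to microseconds per call.
time_us = do_bench_cudagraph(fn) * 1e3 / num_rotating_buffers
print(f"{time_us:7.3f}us")

The script is also filterable now: `python gather_scatter_bench.py --kernels=index_shuffling,split_shuffling` runs only the named benchmarks, matching each entry against the imported kernel's `__name__`.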

fbgemm_gpu/experimental/gen_ai/gen_ai/moe/shuffling.py (+87)
@@ -13,6 +13,7 @@
 import triton.language as tl
 
 
+# Function APIs
 def combine_shuffling(
     tokens: torch.Tensor,
     token_counts: torch.Tensor,
@@ -115,6 +116,92 @@ def _combine_or_split_shuffling(
     return output_tokens
 
 
+# Torch Custom Op Registrations
+_COMBINE_SHUFFLING_OP_NAME = "fbgemm::combine_shuffling"
+
+torch.library.define(
+    "fbgemm::combine_shuffling",
+    "(Tensor tokens, Tensor token_counts, int expert_start, int expert_end, bool is_balanced) -> (Tensor, Tensor)",
+)
+
+
+@torch.library.impl(_COMBINE_SHUFFLING_OP_NAME, "Meta")
+def combine_shuffling_meta(
+    tokens,
+    token_counts,
+    expert_start,
+    expert_end,
+    is_balanced,
+):
+    _, E = token_counts.shape
+    if expert_start is None:
+        expert_start = 0
+    if expert_end is None:
+        expert_end = E
+
+    EG: int = expert_end - expert_start
+    output_tokens = torch.empty_like(tokens)
+    output_token_counts = torch.empty(
+        EG + 1, dtype=token_counts.dtype, device=token_counts.device
+    )
+    return output_tokens, output_token_counts
+
+
+@torch.library.impl(_COMBINE_SHUFFLING_OP_NAME, "CUDA")
+def combine_shuffling_cuda(
+    tokens,
+    token_counts,
+    expert_start=None,
+    expert_end=None,
+    is_balanced=False,
+):
+    return combine_shuffling(
+        tokens,
+        token_counts,
+        expert_start,
+        expert_end,
+        is_balanced,
+    )
+
+
+_SPLIT_SHUFFLING_OP_NAME = "fbgemm::split_shuffling"
+
+torch.library.define(
+    "fbgemm::split_shuffling",
+    "(Tensor tokens, Tensor token_counts, int expert_start, int expert_end, bool is_balanced) -> Tensor",
+)
+
+
+@torch.library.impl(_SPLIT_SHUFFLING_OP_NAME, "Meta")
+def split_shuffling_meta(
+    tokens,
+    token_counts,
+    expert_start,
+    expert_end,
+    is_balanced,
+):
+    output_tokens = torch.empty_like(tokens)
+    return output_tokens
+
+
+@torch.library.impl(_SPLIT_SHUFFLING_OP_NAME, "CUDA")
+def split_shuffling_cuda(
+    tokens,
+    token_counts,
+    expert_start=None,
+    expert_end=None,
+    is_balanced=False,
+):
+    return split_shuffling(
+        tokens,
+        token_counts,
+        expert_start,
+        expert_end,
+        is_balanced,
+    )
+
+
+# Kernel Implementations
 _NV_CONFIGS = [
     triton.Config(
         {
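
These registrations back the commit's `torch.compile` bullet: each op gets a `CUDA` implementation that dispatches to the Triton-backed function at runtime, plus a `Meta` implementation that reports output shapes and dtypes without launching a kernel, which is what lets the compiler trace the op with fake tensors. A hedged sketch of calling the registered op through `torch.ops` under `torch.compile`, assuming a CUDA build of fbgemm_gpu where importing the gen_ai `moe` module performs these registrations (the `route` function and all shapes are illustrative):

import torch

# Assumption: importing the module runs the torch.library registrations above.
import fbgemm_gpu.experimental.gen_ai.moe  # noqa: F401


@torch.compile(fullgraph=True)
def route(tokens: torch.Tensor, token_counts: torch.Tensor):
    # Traced against the Meta kernel; executed by the CUDA kernel.
    E = token_counts.shape[1]
    return torch.ops.fbgemm.combine_shuffling(tokens, token_counts, 0, E, False)


if torch.cuda.is_available():
    EP, E, tokens_per_expert, D = 2, 4, 8, 16
    tokens = torch.randn(
        EP * E * tokens_per_expert, D, device="cuda", dtype=torch.bfloat16
    )
    token_counts = torch.full(
        (EP, E), tokens_per_expert, device="cuda", dtype=torch.int32
    )
    out_tokens, out_counts = route(tokens, token_counts)

Per the Meta kernel, `out_counts` carries `EG + 1` entries for a window of `EG = expert_end - expert_start` experts, alongside an output token buffer shaped like the input.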

fbgemm_gpu/experimental/gen_ai/test/moe/gather_scatter_test.py (-2)
@@ -8,7 +8,6 @@
 # pyre-ignore-all-errors[16,21,53,56]
 
 import logging
-import os
 import unittest
 from typing import Tuple
 
@@ -26,7 +25,6 @@
 logger: logging.Logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
-_BENCHMARK_IN_TEST: bool = os.environ.get("BENCHMARK_IN_TEST", "0") == "1"
 _MAX_SAMPLES: int = 100
 
 