Skip to content

Commit eeee38e

Browse files
levendlee and facebook-github-bot
authored and committed
Use cudaMemsetAsync to setup IndexShuffling kernel. (#4016)
Summary: Pull Request resolved: #4016 X-link: facebookresearch/FBGEMM#1104 It is too expensive to launch an ATen kernel to do setup. Use cudaMemsetAsync instead. hipMemsetAsync is somehow more expensive than launching a kernel. Avoid doing so for now. Reviewed By: Alkaid-Benetnash Differential Revision: D73602755 fbshipit-source-id: 6cdb2b5d489a7fd9c1a2791fcc1c26f2e6dd63ec
1 parent 0485fcf commit eeee38e

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

fbgemm_gpu/experimental/gen_ai/src/moe/index_shuffling.cu

+16-1
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,23 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
276276
at::Tensor shuffled_expert_indices = allocate_index_tensor(num_tokens);
277277
at::Tensor shuffled_token_indices = allocate_index_tensor(num_tokens);
278278

279+
#ifdef USE_ROCM
279280
counts.zero_();
281+
// TODO(shikaili): hipMemsetAsync is more expensive than ATen set zero.
282+
/*
283+
hipMemsetAsync(
284+
counts.data_ptr(),
285+
0,
286+
counts.numel() * counts.dtype().itemsize(),
287+
at::cuda::getCurrentCUDAStream());
288+
*/
289+
#else
290+
cudaMemsetAsync(
291+
counts.data_ptr(),
292+
0,
293+
counts.numel() * counts.dtype().itemsize(),
294+
at::cuda::getCurrentCUDAStream());
295+
#endif
280296

281297
// Avoid expensive `cudaGetDeviceProperties` call.
282298
if (num_sms < 0) {
@@ -298,7 +314,6 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
298314
kernel = (void*)index_shuffling_kernel<DataType, IndexType, E, B>; \
299315
smem_size = sizeof(SharedStorage<DataType, IndexType, E, B>);
300316

301-
int num_tokens_per_tile;
302317
if (num_experts == 16) {
303318
DISPATCH(16, kNumTokensPerTile);
304319
} else {

0 commit comments

Comments
 (0)