Skip to content

Commit 4865e6d

Browse files
levendleefacebook-github-bot
authored andcommitted
Use cudaMemset/hipMemset to setup IndexShuffling kernel. (pytorch#4016)
Summary: Pull Request resolved: pytorch#4016 X-link: facebookresearch/FBGEMM#1104 It is too expensive to launch a ATen kernel to do setup. Use cudaMemset/hipMemset instead. Reviewed By: Alkaid-Benetnash Differential Revision: D73602755
1 parent 0f00a8a commit 4865e6d

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

fbgemm_gpu/experimental/gen_ai/src/moe/index_shuffling.cu

+7-1
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,13 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
276276
at::Tensor shuffled_expert_indices = allocate_index_tensor(num_tokens);
277277
at::Tensor shuffled_token_indices = allocate_index_tensor(num_tokens);
278278

279-
counts.zero_();
279+
#ifdef USE_ROCM
280+
hipMemsetAsync(
281+
counts.data_ptr(), 0, counts.numel() * counts.dtype().itemsize());
282+
#else
283+
cudaMemsetAsync(
284+
counts.data_ptr(), 0, counts.numel() * counts.dtype().itemsize());
285+
#endif
280286

281287
// Avoid expensive `cudaGetDeviceProperties` call.
282288
if (num_sms < 0) {

0 commit comments

Comments
 (0)