Skip to content

Commit ae31729

Browse files
levendleefacebook-github-bot
authored andcommitted
Use cudaMemset/hipMemset to setup IndexShuffling kernel. (pytorch#4016)
Summary: X-link: facebookresearch/FBGEMM#1104 It is too expensive to launch a ATen kernel to do setup. Use cudaMemset/hipMemset instead. Differential Revision: D73602755
1 parent f5b4267 commit ae31729

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

fbgemm_gpu/experimental/gen_ai/src/moe/index_shuffling.cu

+5-1
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,11 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
276276
at::Tensor shuffled_expert_indices = allocate_index_tensor(num_tokens);
277277
at::Tensor shuffled_token_indices = allocate_index_tensor(num_tokens);
278278

279-
counts.zero_();
279+
#ifdef USE_ROCM
280+
hipMemset(counts.data_ptr(), 0, counts.numel() * counts.dtype().itemsize());
281+
#else
282+
cudaMemset(counts.data_ptr(), 0, counts.numel() * counts.dtype().itemsize());
283+
#endif
280284

281285
// Avoid expensive `cudaGetDeviceProperties` call.
282286
if (num_sms < 0) {

0 commit comments

Comments
 (0)