Skip to content

Commit 8cb941b

Browse files
levendleefacebook-github-bot
authored andcommitted
Use cudaMemsetAsync/hipMemsetAsync to setup IndexShuffling kernel. (pytorch#4016)
Summary: Pull Request resolved: pytorch#4016 X-link: facebookresearch/FBGEMM#1104 It is too expensive to launch a ATen kernel to do setup. Use cudaMemsetAsync/hipMemsetAsync instead. Note we need to use cudaMemsetAsync/hipMemsetAsync on current stream to be compatible with CUDAGraph capture. Reviewed By: Alkaid-Benetnash Differential Revision: D73602755
1 parent 0f00a8a commit 8cb941b

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

fbgemm_gpu/experimental/gen_ai/src/moe/index_shuffling.cu

+13-1
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,19 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
276276
at::Tensor shuffled_expert_indices = allocate_index_tensor(num_tokens);
277277
at::Tensor shuffled_token_indices = allocate_index_tensor(num_tokens);
278278

279-
counts.zero_();
279+
#ifdef USE_ROCM
280+
hipMemsetAsync(
281+
counts.data_ptr(),
282+
0,
283+
counts.numel() * counts.dtype().itemsize(),
284+
at::cuda::getCurrentCUDAStream());
285+
#else
286+
cudaMemsetAsync(
287+
counts.data_ptr(),
288+
0,
289+
counts.numel() * counts.dtype().itemsize(),
290+
at::cuda::getCurrentCUDAStream());
291+
#endif
280292

281293
// Avoid expensive `cudaGetDeviceProperties` call.
282294
if (num_sms < 0) {

0 commit comments

Comments
 (0)