Skip to content

Commit 47f4b10

Browse files
Bangsheng Tang, facebook-github-bot
Bangsheng Tang
authored and committed
add AMD specific includes in cuda_prelude.h (pytorch#3614)
Summary: X-link: facebookresearch/FBGEMM#691 as title Reviewed By: q10 Differential Revision: D68638427
1 parent 1aff241 commit 47f4b10

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh

+15
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,31 @@
99
#pragma once
1010

1111
#include <ATen/ATen.h>
12+
1213
#include <cuda.h>
14+
15+
#ifdef __HIP_PLATFORM_AMD__
16+
#include <ATen/cuda/CUDAGeneratorImpl.h>
17+
#include <ATen/cuda/detail/UnpackRaw.cuh> // For at::cuda::philox::unpack
18+
19+
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h> // @manual
20+
#else
1321
#include <ATen/cuda/CUDAGraphsUtils.cuh>
22+
#endif
1423
#include <cassert>
1524

1625
namespace {

/// Returns the multiprocessor count reported for the current device.
///
/// When `__HIP_PLATFORM_AMD__` is defined, the properties are queried with
/// `hipGetDeviceProperties` for `c10::hip::current_device()`; otherwise they
/// come from `at::cuda::getDeviceProperties` for
/// `c10::cuda::current_device()`.
///
/// @return `multiProcessorCount` of the queried device properties.
/// @throws c10::Error (via TORCH_CHECK) on the HIP path if the property query
///         does not return `hipSuccess`.
inline int get_device_sm_cnt_() {
#ifdef __HIP_PLATFORM_AMD__
  hipDeviceProp_t deviceProp;
  // The status must be checked: if the query fails, `deviceProp` is left
  // uninitialized and reading `multiProcessorCount` would yield garbage.
  const hipError_t status =
      hipGetDeviceProperties(&deviceProp, c10::hip::current_device());
  TORCH_CHECK(
      status == hipSuccess,
      "hipGetDeviceProperties failed: ",
      hipGetErrorString(status));
  return deviceProp.multiProcessorCount;
#else
  cudaDeviceProp* deviceProp =
      at::cuda::getDeviceProperties(c10::cuda::current_device());
  return deviceProp->multiProcessorCount;
#endif
}

} // namespace

0 commit comments

Comments
 (0)