Skip to content

Commit 9448ba9

Browse files
authored
Performance: Add CUDA Aware MPI (#5930)
* add CUDA-aware MPI * update docs
1 parent b06a163 commit 9448ba9

File tree

6 files changed

+49
-0
lines changed

6 files changed

+49
-0
lines changed

CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ option(ENABLE_PEXSI "Enable support for PEXSI." OFF)
4242
option(ENABLE_CUSOLVERMP "Enable cusolvermp." OFF)
4343
option(USE_DSP "Enable DSP usage." OFF)
4444
option(USE_CUDA_ON_DCU "Enable CUDA on DCU" OFF)
45+
option(USE_CUDA_MPI "Enable CUDA-aware MPI" OFF)
4546

4647
# enable json support
4748
if(ENABLE_RAPIDJSON)
@@ -132,6 +133,10 @@ if (USE_CUDA_ON_DCU)
132133
add_compile_definitions(__CUDA_ON_DCU)
133134
endif()
134135

136+
if (USE_CUDA_MPI)
137+
add_compile_definitions(__CUDA_MPI)
138+
endif()
139+
135140
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
136141

137142
if(ENABLE_COVERAGE)

docs/advanced/install.md

+2
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ To build NVIDIA GPU support for ABACUS, define `USE_CUDA` flag. You can also spe
115115
cmake -B build -DUSE_CUDA=1 -DCMAKE_CUDA_COMPILER=${path to cuda toolkit}/bin/nvcc
116116
```
117117

118+
If you are confident that your MPI implementation is CUDA-aware, you can add `-DUSE_CUDA_MPI=ON`. In this case, MPI communication operates directly on GPU memory, rather than transferring data to the CPU first before communication. Note that if your MPI implementation is not CUDA-aware, building with `-DUSE_CUDA_MPI=ON` will cause the program to throw an error at runtime.
119+
118120
## Build math library from source
119121

120122
> Note: We recommend using the latest available compiler sets, since they offer faster implementations of math functions.

source/module_base/para_gemm.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "kernels/math_kernel_op.h"
44
#include "parallel_device.h"
5+
#include "module_base/timer.h"
56
namespace ModuleBase
67
{
78
template <typename T, typename Device>
@@ -109,6 +110,7 @@ void PGemmCN<T, Device>::set_dimension(
109110
template <typename T, typename Device>
110111
void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T beta, T* C)
111112
{
113+
ModuleBase::timer::tick("PGemmCN", "multiply");
112114
#ifdef __MPI
113115
if (this->col_nproc > 1)
114116
{
@@ -126,6 +128,7 @@ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T
126128
{
127129
multiply_single(alpha, A, B, beta, C);
128130
}
131+
ModuleBase::timer::tick("PGemmCN", "multiply");
129132
}
130133

131134
template <typename T, typename Device>
@@ -154,10 +157,12 @@ void PGemmCN<T, Device>::multiply_col(const T alpha, const T* A, const T* B, con
154157

155158
std::vector<T> B_tmp(max_colA * LDA);
156159
std::vector<T> isend_tmp;
160+
#ifndef __CUDA_MPI
157161
if (std::is_same<Device, base_device::DEVICE_GPU>::value)
158162
{
159163
isend_tmp.resize(max_colA * LDA);
160164
}
165+
#endif
161166
for (int ip = 0; ip < col_nproc; ip++)
162167
{
163168
if (col_rank != ip)
@@ -244,6 +249,13 @@ void PGemmCN<T, Device>::multiply_col(const T alpha, const T* A, const T* B, con
244249

245250
if (this->gatherC)
246251
{
252+
#ifdef __CUDA_MPI
253+
if (this->row_nproc > 1)
254+
{
255+
Parallel_Common::reduce_data(C_local, size_C_local, row_world);
256+
}
257+
Parallel_Common::gatherv_data(C_local, size_C_local, C, recv_counts.data(), displs.data(), col_world);
258+
#else
247259
T* Cglobal_cpu = nullptr;
248260
T* Clocal_cpu = C_tmp.data();
249261
std::vector<T> cpu_tmp;
@@ -277,6 +289,7 @@ void PGemmCN<T, Device>::multiply_col(const T alpha, const T* A, const T* B, con
277289
{
278290
syncmem_h2d_op()(C, Cglobal_cpu, size_C_global);
279291
}
292+
#endif
280293
}
281294
else
282295
{

source/module_base/parallel_device.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ void gatherv_data(const std::complex<float>* sendbuf, int sendcount, std::comple
9999
MPI_Allgatherv(sendbuf, sendcount, MPI_COMPLEX, recvbuf, recvcounts, displs, MPI_COMPLEX, comm);
100100
}
101101

102+
#ifndef __CUDA_MPI
102103
template <typename T>
103104
struct object_cpu_point<T, base_device::DEVICE_GPU>
104105
{
@@ -171,6 +172,7 @@ template struct object_cpu_point<float, base_device::DEVICE_CPU>;
171172
template struct object_cpu_point<float, base_device::DEVICE_GPU>;
172173
template struct object_cpu_point<std::complex<float>, base_device::DEVICE_CPU>;
173174
template struct object_cpu_point<std::complex<float>, base_device::DEVICE_GPU>;
175+
#endif
174176

175177
} // namespace Parallel_Common
176178
#endif

source/module_base/parallel_device.h

+22
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ void gatherv_data(const std::complex<double>* sendbuf, int sendcount, std::compl
3232
void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
3333
void gatherv_data(const std::complex<float>* sendbuf, int sendcount, std::complex<float>* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
3434

35+
#ifndef __CUDA_MPI
3536
template<typename T, typename Device>
3637
struct object_cpu_point
3738
{
@@ -41,6 +42,7 @@ struct object_cpu_point
4142
void sync_d2h(T* object_cpu, const T* object, const int& n);
4243
void sync_h2d(T* object, const T* object_cpu, const int& n);
4344
};
45+
#endif
4446

4547
/**
4648
* @brief send data in Device
@@ -49,11 +51,15 @@ struct object_cpu_point
4951
template <typename T, typename Device>
5052
void send_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, T* tmp_space = nullptr)
5153
{
54+
#ifdef __CUDA_MPI
55+
send_data(object, count, dest, tag, comm);
56+
#else
5257
object_cpu_point<T,Device> o;
5358
T* object_cpu = o.get(object, count, tmp_space);
5459
o.sync_d2h(object_cpu, object, count);
5560
send_data(object_cpu, count, dest, tag, comm);
5661
o.del(object_cpu);
62+
#endif
5763
return;
5864
}
5965

@@ -65,11 +71,15 @@ void send_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, T*
6571
template <typename T, typename Device>
6672
void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request, T* send_space)
6773
{
74+
#ifdef __CUDA_MPI
75+
isend_data(object, count, dest, tag, comm, request);
76+
#else
6877
object_cpu_point<T,Device> o;
6978
T* object_cpu = o.get(object, count, send_space);
7079
o.sync_d2h(object_cpu, object, count);
7180
isend_data(object_cpu, count, dest, tag, comm, request);
7281
o.del(object_cpu);
82+
#endif
7383
return;
7484
}
7585

@@ -80,11 +90,15 @@ void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MP
8090
template <typename T, typename Device>
8191
void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status, T* tmp_space = nullptr)
8292
{
93+
#ifdef __CUDA_MPI
94+
recv_data(object, count, source, tag, comm, status);
95+
#else
8396
object_cpu_point<T,Device> o;
8497
T* object_cpu = o.get(object, count, tmp_space);
8598
recv_data(object_cpu, count, source, tag, comm, status);
8699
o.sync_h2d(object, object_cpu, count);
87100
o.del(object_cpu);
101+
#endif
88102
return;
89103
}
90104

@@ -102,24 +116,32 @@ void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Sta
102116
template <typename T, typename Device>
103117
void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
104118
{
119+
#ifdef __CUDA_MPI
120+
bcast_data(object, n, comm);
121+
#else
105122
object_cpu_point<T,Device> o;
106123
T* object_cpu = o.get(object, n, tmp_space);
107124
o.sync_d2h(object_cpu, object, n);
108125
bcast_data(object_cpu, n, comm);
109126
o.sync_h2d(object, object_cpu, n);
110127
o.del(object_cpu);
128+
#endif
111129
return;
112130
}
113131

114132
template <typename T, typename Device>
115133
void reduce_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
116134
{
135+
#ifdef __CUDA_MPI
136+
reduce_data(object, n, comm);
137+
#else
117138
object_cpu_point<T,Device> o;
118139
T* object_cpu = o.get(object, n, tmp_space);
119140
o.sync_d2h(object_cpu, object, n);
120141
reduce_data(object_cpu, n, comm);
121142
o.sync_h2d(object, object_cpu, n);
122143
o.del(object_cpu);
144+
#endif
123145
return;
124146
}
125147
}

source/module_hsolver/para_linear_transform.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "para_linear_transform.h"
2+
#include "module_base/timer.h"
23

34
#include <algorithm>
45
#include <vector>
@@ -54,6 +55,7 @@ void PLinearTransform<T, Device>::set_dimension(const int nrowA,
5455
template <typename T, typename Device>
5556
void PLinearTransform<T, Device>::act(const T alpha, const T* A, const T* U, const T beta, T* B)
5657
{
58+
ModuleBase::timer::tick("PLinearTransform", "act");
5759
const Device* ctx = {};
5860
#ifdef __MPI
5961
if (nproc_col > 1)
@@ -65,7 +67,9 @@ void PLinearTransform<T, Device>::act(const T alpha, const T* A, const T* U, con
6567
if (std::is_same<Device, base_device::DEVICE_GPU>::value)
6668
{
6769
A_tmp_device = nullptr;
70+
#ifndef __CUDA_MPI
6871
isend_tmp.resize(max_colA * LDA);
72+
#endif
6973
resmem_dev_op()(A_tmp_device, max_colA * LDA);
7074
}
7175
T* B_tmp = nullptr;
@@ -168,6 +172,7 @@ void PLinearTransform<T, Device>::act(const T alpha, const T* A, const T* U, con
168172
B,
169173
LDA);
170174
}
175+
ModuleBase::timer::tick("PLinearTransform", "act");
171176
};
172177

173178
template struct PLinearTransform<double, base_device::DEVICE_CPU>;

0 commit comments

Comments
 (0)