Skip to content

Commit 5e86325

Browse files
fangfangssjfangfangssj
and
fangfangssj
authored
【Paddle Tensor 规范化第二期】paddle.svd support complex and 0-size (#72169)
* support complex * fix * fix * fix * fix * fix ci * rerun ci * fix * add test * fix ci --------- Co-authored-by: fangfangssj <[email protected]>
1 parent 03b4d4a commit 5e86325

File tree

7 files changed

+673
-191
lines changed

7 files changed

+673
-191
lines changed

paddle/phi/kernels/cpu/svd_grad_kernel.cc

+8-2
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,11 @@
1818
#include "paddle/phi/core/kernel_registry.h"
1919
#include "paddle/phi/kernels/impl/svd_grad_kernel_impl.h"
2020

21-
PD_REGISTER_KERNEL(
22-
svd_grad, CPU, ALL_LAYOUT, phi::SvdGradKernel, float, double) {}
21+
PD_REGISTER_KERNEL(svd_grad,
22+
CPU,
23+
ALL_LAYOUT,
24+
phi::SvdGradKernel,
25+
float,
26+
double,
27+
phi::dtype::complex<float>,
28+
phi::dtype::complex<double>) {}

paddle/phi/kernels/cpu/svd_kernel.cc

+14-9
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ void SvdKernel(const Context& dev_ctx,
106106
int full = full_matrices;
107107
/*Create Tensors and output, set the dim ...*/
108108
auto numel = X.numel();
109+
if (numel == 0) {
110+
dev_ctx.template Alloc<T>(U);
111+
dev_ctx.template Alloc<phi::dtype::Real<T>>(S);
112+
dev_ctx.template Alloc<T>(VH);
113+
return;
114+
}
109115
DenseTensor trans_x =
110116
::phi::TransposeLast2Dim<T>(dev_ctx, Conj<T, Context>(dev_ctx, X));
111117
auto x_dims = X.dims();
@@ -114,14 +120,6 @@ void SvdKernel(const Context& dev_ctx,
114120
// int k = std::min(rows, cols);
115121
// int col_u = full ? rows : k;
116122
// int col_v = full ? cols : k;
117-
PADDLE_ENFORCE_LT(
118-
0,
119-
rows,
120-
errors::InvalidArgument("The row of Input(X) should be greater than 0."));
121-
PADDLE_ENFORCE_LT(
122-
0,
123-
cols,
124-
errors::InvalidArgument("The col of Input(X) should be greater than 0."));
125123
auto* x_data = trans_x.data<T>();
126124
int batches = static_cast<int>(numel / (rows * cols));
127125
auto* U_out = dev_ctx.template Alloc<T>(U);
@@ -148,4 +146,11 @@ void SvdKernel(const Context& dev_ctx,
148146

149147
} // namespace phi
150148

151-
PD_REGISTER_KERNEL(svd, CPU, ALL_LAYOUT, phi::SvdKernel, float, double) {}
149+
PD_REGISTER_KERNEL(svd,
150+
CPU,
151+
ALL_LAYOUT,
152+
phi::SvdKernel,
153+
float,
154+
double,
155+
phi::dtype::complex<float>,
156+
phi::dtype::complex<double>) {}

paddle/phi/kernels/gpu/svd_grad_kernel.cu

+8-2
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,11 @@
1818
#include "paddle/phi/core/kernel_registry.h"
1919
#include "paddle/phi/kernels/impl/svd_grad_kernel_impl.h"
2020

21-
PD_REGISTER_KERNEL(
22-
svd_grad, GPU, ALL_LAYOUT, phi::SvdGradKernel, float, double) {}
21+
PD_REGISTER_KERNEL(svd_grad,
22+
GPU,
23+
ALL_LAYOUT,
24+
phi::SvdGradKernel,
25+
float,
26+
double,
27+
phi::dtype::complex<float>,
28+
phi::dtype::complex<double>) {}

paddle/phi/kernels/gpu/svd_kernel.cu

+180-14
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "paddle/phi/backends/dynload/cusolver.h"
2121
#include "paddle/phi/common/memory_utils.h"
2222
#include "paddle/phi/core/kernel_registry.h"
23+
#include "paddle/phi/kernels/complex_kernel.h"
2324
#include "paddle/phi/kernels/empty_kernel.h"
2425
#include "paddle/phi/kernels/funcs/complex_functors.h"
2526
#include "paddle/phi/kernels/transpose_kernel.h"
@@ -35,7 +36,7 @@ static void GesvdjBatched(const phi::GPUContext& dev_ctx,
3536
T* A,
3637
T* U,
3738
T* V,
38-
T* S,
39+
phi::dtype::Real<T>* S,
3940
int* info,
4041
int thin_UV = 1);
4142

@@ -201,13 +202,185 @@ void GesvdjBatched<double>(const phi::GPUContext& dev_ctx,
201202
phi::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params));
202203
}
203204

205+
template <>
206+
void GesvdjBatched<phi::dtype::complex<float>>(const phi::GPUContext& dev_ctx,
207+
int batchSize,
208+
int m,
209+
int n,
210+
int k,
211+
phi::dtype::complex<float>* A,
212+
phi::dtype::complex<float>* U,
213+
phi::dtype::complex<float>* V,
214+
float* S,
215+
int* info,
216+
int thin_UV) {
217+
/* compute singular vectors */
218+
const cusolverEigMode_t jobz =
219+
CUSOLVER_EIG_MODE_VECTOR; /* compute singular vectors */
220+
gesvdjInfo_t gesvdj_params = NULL;
221+
int lda = m;
222+
int ldu = m;
223+
int ldt = n;
224+
int lwork = 0;
225+
auto handle = dev_ctx.cusolver_dn_handle();
226+
PADDLE_ENFORCE_GPU_SUCCESS(
227+
phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params));
228+
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgesvdj_bufferSize(
229+
handle,
230+
jobz,
231+
thin_UV,
232+
m,
233+
n,
234+
reinterpret_cast<cuComplex*>(A),
235+
lda,
236+
S,
237+
reinterpret_cast<cuComplex*>(U),
238+
ldu,
239+
reinterpret_cast<cuComplex*>(V),
240+
ldt,
241+
&lwork,
242+
gesvdj_params));
243+
auto workspace = phi::memory_utils::Alloc(
244+
dev_ctx.GetPlace(),
245+
lwork * sizeof(phi::dtype::complex<float>),
246+
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
247+
phi::dtype::complex<float>* workspace_ptr =
248+
reinterpret_cast<phi::dtype::complex<float>*>(workspace->ptr());
249+
int stride_A = lda * n;
250+
int stride_U = ldu * (thin_UV ? k : m);
251+
int stride_V = ldt * (thin_UV ? k : n);
252+
for (int i = 0; i < batchSize; ++i) {
253+
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgesvdj(
254+
handle,
255+
jobz,
256+
thin_UV,
257+
m,
258+
n,
259+
reinterpret_cast<cuComplex*>(A + stride_A * i),
260+
lda,
261+
reinterpret_cast<float*>(S + k * i),
262+
reinterpret_cast<cuComplex*>(U + stride_U * i),
263+
ldu,
264+
reinterpret_cast<cuComplex*>(V + stride_V * i),
265+
ldt,
266+
reinterpret_cast<cuComplex*>(workspace_ptr),
267+
lwork,
268+
info,
269+
gesvdj_params));
270+
// check the error info
271+
int error_info;
272+
memory_utils::Copy(phi::CPUPlace(),
273+
&error_info,
274+
dev_ctx.GetPlace(),
275+
info,
276+
sizeof(int),
277+
dev_ctx.stream());
278+
PADDLE_ENFORCE_EQ(
279+
error_info,
280+
0,
281+
common::errors::PreconditionNotMet(
282+
"For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info));
283+
}
284+
PADDLE_ENFORCE_GPU_SUCCESS(
285+
phi::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params));
286+
}
287+
288+
template <>
289+
void GesvdjBatched<phi::dtype::complex<double>>(const phi::GPUContext& dev_ctx,
290+
int batchSize,
291+
int m,
292+
int n,
293+
int k,
294+
phi::dtype::complex<double>* A,
295+
phi::dtype::complex<double>* U,
296+
phi::dtype::complex<double>* V,
297+
double* S,
298+
int* info,
299+
int thin_UV) {
300+
/* compute singular vectors */
301+
const cusolverEigMode_t jobz =
302+
CUSOLVER_EIG_MODE_VECTOR; /* compute singular vectors */
303+
gesvdjInfo_t gesvdj_params = NULL;
304+
int lda = m;
305+
int ldu = m;
306+
int ldt = n;
307+
int lwork = 0;
308+
auto handle = dev_ctx.cusolver_dn_handle();
309+
PADDLE_ENFORCE_GPU_SUCCESS(
310+
phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params));
311+
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgesvdj_bufferSize(
312+
handle,
313+
jobz,
314+
thin_UV,
315+
m,
316+
n,
317+
reinterpret_cast<cuDoubleComplex*>(A),
318+
lda,
319+
S,
320+
reinterpret_cast<cuDoubleComplex*>(U),
321+
ldu,
322+
reinterpret_cast<cuDoubleComplex*>(V),
323+
ldt,
324+
&lwork,
325+
gesvdj_params));
326+
auto workspace = phi::memory_utils::Alloc(
327+
dev_ctx.GetPlace(),
328+
lwork * sizeof(phi::dtype::complex<double>),
329+
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
330+
phi::dtype::complex<double>* workspace_ptr =
331+
reinterpret_cast<phi::dtype::complex<double>*>(workspace->ptr());
332+
int stride_A = lda * n;
333+
int stride_U = ldu * (thin_UV ? k : m);
334+
int stride_V = ldt * (thin_UV ? k : n);
335+
for (int i = 0; i < batchSize; ++i) {
336+
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgesvdj(
337+
handle,
338+
jobz,
339+
thin_UV,
340+
m,
341+
n,
342+
reinterpret_cast<cuDoubleComplex*>(A + stride_A * i),
343+
lda,
344+
reinterpret_cast<double*>(S + k * i),
345+
reinterpret_cast<cuDoubleComplex*>(U + stride_U * i),
346+
ldu,
347+
reinterpret_cast<cuDoubleComplex*>(V + stride_V * i),
348+
ldt,
349+
reinterpret_cast<cuDoubleComplex*>(workspace_ptr),
350+
lwork,
351+
info,
352+
gesvdj_params));
353+
// check the error info
354+
int error_info;
355+
memory_utils::Copy(phi::CPUPlace(),
356+
&error_info,
357+
dev_ctx.GetPlace(),
358+
info,
359+
sizeof(int),
360+
dev_ctx.stream());
361+
PADDLE_ENFORCE_EQ(
362+
error_info,
363+
0,
364+
common::errors::PreconditionNotMet(
365+
"For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info));
366+
}
367+
PADDLE_ENFORCE_GPU_SUCCESS(
368+
phi::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params));
369+
}
370+
204371
template <typename T, typename Context>
205372
void SvdKernel(const Context& dev_ctx,
206373
const DenseTensor& X,
207374
bool full_matrices,
208375
DenseTensor* U,
209376
DenseTensor* S,
210377
DenseTensor* VH) {
378+
if (X.numel() == 0) {
379+
dev_ctx.template Alloc<T>(U);
380+
dev_ctx.template Alloc<phi::dtype::Real<T>>(S);
381+
dev_ctx.template Alloc<T>(VH);
382+
return;
383+
}
211384
auto& dims = X.dims();
212385
int batch_count = 1;
213386
for (int i = 0; i < dims.size() - 2; i++) {
@@ -217,17 +390,8 @@ void SvdKernel(const Context& dev_ctx,
217390
int m = dims[rank - 2];
218391
int n = dims[rank - 1];
219392

220-
PADDLE_ENFORCE_LT(
221-
0,
222-
m,
223-
errors::InvalidArgument("The row of Input(X) should be greater than 0."));
224-
PADDLE_ENFORCE_LT(
225-
0,
226-
n,
227-
errors::InvalidArgument("The col of Input(X) should be greater than 0."));
228-
229-
auto* u_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(U);
230-
auto* vh_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(VH);
393+
auto* u_data = dev_ctx.template Alloc<T>(U);
394+
auto* vh_data = dev_ctx.template Alloc<T>(VH);
231395
auto* s_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(S);
232396
// NOTE:(@xiongkun03)
233397
// matrices are assumed to be stored in column-major order in cusolver
@@ -253,7 +417,7 @@ void SvdKernel(const Context& dev_ctx,
253417
auto UT_dim = U->dims();
254418
std::swap(UT_dim[rank - 1], UT_dim[rank - 2]); // Get the dim of UT_dim
255419
U->Resize(UT_dim); // U is entirely UT
256-
auto tmp_U = TransposeLast2Dim<T>(dev_ctx, *U);
420+
auto tmp_U = TransposeLast2Dim<T>(dev_ctx, Conj<T, Context>(dev_ctx, *U));
257421
U->ShareDataWith(tmp_U); // U becomse UT, aka VT;
258422
}
259423
} // namespace phi
@@ -263,6 +427,8 @@ PD_REGISTER_KERNEL(svd, // cuda_only
263427
ALL_LAYOUT,
264428
phi::SvdKernel,
265429
float,
266-
double) {}
430+
double,
431+
phi::dtype::complex<float>,
432+
phi::dtype::complex<double>) {}
267433

268434
#endif // not PADDLE_WITH_HIP

0 commit comments

Comments
 (0)