#include "paddle/phi/backends/dynload/cusolver.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
+ #include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/transpose_kernel.h"
@@ -35,7 +36,7 @@ static void GesvdjBatched(const phi::GPUContext& dev_ctx,
                          T* A,
                          T* U,
                          T* V,
-                         T* S,
+                         phi::dtype::Real<T>* S,
                          int* info,
                          int thin_UV = 1);

@@ -201,13 +202,185 @@ void GesvdjBatched<double>(const phi::GPUContext& dev_ctx,
      phi::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params));
}

+ template <>
+ void GesvdjBatched<phi::dtype::complex<float>>(const phi::GPUContext& dev_ctx,
+                                                int batchSize,
+                                                int m,
+                                                int n,
+                                                int k,
+                                                phi::dtype::complex<float>* A,
+                                                phi::dtype::complex<float>* U,
+                                                phi::dtype::complex<float>* V,
+                                                float* S,
+                                                int* info,
+                                                int thin_UV) {
+   /* compute singular vectors */
+   const cusolverEigMode_t jobz =
+       CUSOLVER_EIG_MODE_VECTOR; /* compute singular vectors */
+   gesvdjInfo_t gesvdj_params = NULL;
+   int lda = m;
+   int ldu = m;
+   int ldt = n;
+   int lwork = 0;
+   auto handle = dev_ctx.cusolver_dn_handle();
+   PADDLE_ENFORCE_GPU_SUCCESS(
+       phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params));
+   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgesvdj_bufferSize(
+       handle,
+       jobz,
+       thin_UV,
+       m,
+       n,
+       reinterpret_cast<cuComplex*>(A),
+       lda,
+       S,
+       reinterpret_cast<cuComplex*>(U),
+       ldu,
+       reinterpret_cast<cuComplex*>(V),
+       ldt,
+       &lwork,
+       gesvdj_params));
+   auto workspace = phi::memory_utils::Alloc(
+       dev_ctx.GetPlace(),
+       lwork * sizeof(phi::dtype::complex<float>),
+       phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
+   phi::dtype::complex<float>* workspace_ptr =
+       reinterpret_cast<phi::dtype::complex<float>*>(workspace->ptr());
+   int stride_A = lda * n;
+   int stride_U = ldu * (thin_UV ? k : m);
+   int stride_V = ldt * (thin_UV ? k : n);
+   for (int i = 0; i < batchSize; ++i) {
+     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgesvdj(
+         handle,
+         jobz,
+         thin_UV,
+         m,
+         n,
+         reinterpret_cast<cuComplex*>(A + stride_A * i),
+         lda,
+         reinterpret_cast<float*>(S + k * i),
+         reinterpret_cast<cuComplex*>(U + stride_U * i),
+         ldu,
+         reinterpret_cast<cuComplex*>(V + stride_V * i),
+         ldt,
+         reinterpret_cast<cuComplex*>(workspace_ptr),
+         lwork,
+         info,
+         gesvdj_params));
+     // check the error info
+     int error_info;
+     memory_utils::Copy(phi::CPUPlace(),
+                        &error_info,
+                        dev_ctx.GetPlace(),
+                        info,
+                        sizeof(int),
+                        dev_ctx.stream());
+     PADDLE_ENFORCE_EQ(
+         error_info,
+         0,
+         common::errors::PreconditionNotMet(
+             "For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info));
+   }
+   PADDLE_ENFORCE_GPU_SUCCESS(
+       phi::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params));
+ }
+
+ template <>
+ void GesvdjBatched<phi::dtype::complex<double>>(const phi::GPUContext& dev_ctx,
+                                                 int batchSize,
+                                                 int m,
+                                                 int n,
+                                                 int k,
+                                                 phi::dtype::complex<double>* A,
+                                                 phi::dtype::complex<double>* U,
+                                                 phi::dtype::complex<double>* V,
+                                                 double* S,
+                                                 int* info,
+                                                 int thin_UV) {
+   /* compute singular vectors */
+   const cusolverEigMode_t jobz =
+       CUSOLVER_EIG_MODE_VECTOR; /* compute singular vectors */
+   gesvdjInfo_t gesvdj_params = NULL;
+   int lda = m;
+   int ldu = m;
+   int ldt = n;
+   int lwork = 0;
+   auto handle = dev_ctx.cusolver_dn_handle();
+   PADDLE_ENFORCE_GPU_SUCCESS(
+       phi::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params));
+   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgesvdj_bufferSize(
+       handle,
+       jobz,
+       thin_UV,
+       m,
+       n,
+       reinterpret_cast<cuDoubleComplex*>(A),
+       lda,
+       S,
+       reinterpret_cast<cuDoubleComplex*>(U),
+       ldu,
+       reinterpret_cast<cuDoubleComplex*>(V),
+       ldt,
+       &lwork,
+       gesvdj_params));
+   auto workspace = phi::memory_utils::Alloc(
+       dev_ctx.GetPlace(),
+       lwork * sizeof(phi::dtype::complex<double>),
+       phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
+   phi::dtype::complex<double>* workspace_ptr =
+       reinterpret_cast<phi::dtype::complex<double>*>(workspace->ptr());
+   int stride_A = lda * n;
+   int stride_U = ldu * (thin_UV ? k : m);
+   int stride_V = ldt * (thin_UV ? k : n);
+   for (int i = 0; i < batchSize; ++i) {
+     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgesvdj(
+         handle,
+         jobz,
+         thin_UV,
+         m,
+         n,
+         reinterpret_cast<cuDoubleComplex*>(A + stride_A * i),
+         lda,
+         reinterpret_cast<double*>(S + k * i),
+         reinterpret_cast<cuDoubleComplex*>(U + stride_U * i),
+         ldu,
+         reinterpret_cast<cuDoubleComplex*>(V + stride_V * i),
+         ldt,
+         reinterpret_cast<cuDoubleComplex*>(workspace_ptr),
+         lwork,
+         info,
+         gesvdj_params));
+     // check the error info
+     int error_info;
+     memory_utils::Copy(phi::CPUPlace(),
+                        &error_info,
+                        dev_ctx.GetPlace(),
+                        info,
+                        sizeof(int),
+                        dev_ctx.stream());
+     PADDLE_ENFORCE_EQ(
+         error_info,
+         0,
+         common::errors::PreconditionNotMet(
+             "For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info));
+   }
+   PADDLE_ENFORCE_GPU_SUCCESS(
+       phi::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params));
+ }
+

template <typename T, typename Context>
void SvdKernel(const Context& dev_ctx,
               const DenseTensor& X,
               bool full_matrices,
               DenseTensor* U,
               DenseTensor* S,
               DenseTensor* VH) {
+   if (X.numel() == 0) {
+     dev_ctx.template Alloc<T>(U);
+     dev_ctx.template Alloc<phi::dtype::Real<T>>(S);
+     dev_ctx.template Alloc<T>(VH);
+     return;
+   }
  auto& dims = X.dims();
  int batch_count = 1;
  for (int i = 0; i < dims.size() - 2; i++) {
@@ -217,17 +390,8 @@ void SvdKernel(const Context& dev_ctx,
  int m = dims[rank - 2];
  int n = dims[rank - 1];

-  PADDLE_ENFORCE_LT(
-      0,
-      m,
-      errors::InvalidArgument("The row of Input(X) should be greater than 0."));
-  PADDLE_ENFORCE_LT(
-      0,
-      n,
-      errors::InvalidArgument("The col of Input(X) should be greater than 0."));
-
-  auto* u_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(U);
-  auto* vh_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(VH);
+  auto* u_data = dev_ctx.template Alloc<T>(U);
+  auto* vh_data = dev_ctx.template Alloc<T>(VH);
  auto* s_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(S);
  // NOTE:(@xiongkun03)
  // matrices are assumed to be stored in column-major order in cusolver
@@ -253,7 +417,7 @@ void SvdKernel(const Context& dev_ctx,
  auto UT_dim = U->dims();
  std::swap(UT_dim[rank - 1], UT_dim[rank - 2]);  // Get the dim of UT_dim
  U->Resize(UT_dim);  // U is entirely UT
-  auto tmp_U = TransposeLast2Dim<T>(dev_ctx, *U);
+  auto tmp_U = TransposeLast2Dim<T>(dev_ctx, Conj<T, Context>(dev_ctx, *U));
  U->ShareDataWith(tmp_U);  // U becomes UT, aka VT;
}
}  // namespace phi
@@ -263,6 +427,8 @@ PD_REGISTER_KERNEL(svd, // cuda_only
                   ALL_LAYOUT,
                   phi::SvdKernel,
                   float,
-                  double) {}
+                  double,
+                  phi::dtype::complex<float>,
+                  phi::dtype::complex<double>) {}

#endif  // not PADDLE_WITH_HIP
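
Editorial note (not part of the diff): for readers who want to see the gesvdj path this PR wraps in isolation, below is a minimal standalone sketch that calls the same cuSOLVER entry points (`cusolverDnCgesvdj_bufferSize` / `cusolverDnCgesvdj`) without Paddle's dynload and allocator layers. The matrix contents and the economy-size setting (`econ = 1`, the counterpart of `thin_UV` above) are illustrative assumptions, and error handling is abbreviated.

```cpp
// Minimal sketch: complex<float> Jacobi SVD of one 3x2 matrix via cuSOLVER.
// Build with: nvcc gesvdj_demo.cu -lcusolver
#include <cuComplex.h>
#include <cusolverDn.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  const int m = 3, n = 2;  // cuSOLVER expects column-major storage
  std::vector<cuComplex> hA = {
      {1.f, 1.f}, {0.f, 2.f}, {3.f, 0.f},    // column 0
      {2.f, -1.f}, {1.f, 0.f}, {0.f, 1.f}};  // column 1

  cusolverDnHandle_t handle;
  cusolverDnCreate(&handle);
  gesvdjInfo_t params;
  cusolverDnCreateGesvdjInfo(&params);

  cuComplex *dA, *dU, *dV, *dWork;
  float* dS;  // singular values of a complex matrix are real
  int* dInfo;
  cudaMalloc(&dA, sizeof(cuComplex) * m * n);
  cudaMalloc(&dU, sizeof(cuComplex) * m * m);
  cudaMalloc(&dV, sizeof(cuComplex) * n * n);
  cudaMalloc(&dS, sizeof(float) * n);
  cudaMalloc(&dInfo, sizeof(int));
  cudaMemcpy(dA, hA.data(), sizeof(cuComplex) * m * n, cudaMemcpyHostToDevice);

  // Query workspace size, then run the economy-size Jacobi SVD (econ = 1).
  int lwork = 0;
  cusolverDnCgesvdj_bufferSize(handle, CUSOLVER_EIG_MODE_VECTOR, /*econ=*/1,
                               m, n, dA, m, dS, dU, m, dV, n, &lwork, params);
  cudaMalloc(&dWork, sizeof(cuComplex) * lwork);
  cusolverDnCgesvdj(handle, CUSOLVER_EIG_MODE_VECTOR, /*econ=*/1, m, n, dA, m,
                    dS, dU, m, dV, n, dWork, lwork, dInfo, params);

  // info == 0 means convergence, mirroring the error_info check in the kernel.
  int info = 0;
  cudaMemcpy(&info, dInfo, sizeof(int), cudaMemcpyDeviceToHost);
  float S[2];
  cudaMemcpy(S, dS, sizeof(float) * n, cudaMemcpyDeviceToHost);
  printf("info = %d, singular values: %f %f\n", info, S[0], S[1]);

  cudaFree(dA); cudaFree(dU); cudaFree(dV); cudaFree(dS);
  cudaFree(dWork); cudaFree(dInfo);
  cusolverDnDestroyGesvdjInfo(params);
  cusolverDnDestroy(handle);
  return 0;
}
```

Note that gesvdj returns V rather than V^H. For a complex matrix the decomposition is A = U * Sigma * V^H, so recovering V^H needs a conjugate as well as a transpose; that is why the kernel's final step now transposes `Conj<T, Context>(dev_ctx, *U)` instead of `*U` alone.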