@@ -146,7 +146,7 @@ public:
         // "for performance reasons"
         // << abort(FatalError);
         gko::experimental::EnableDistributedLinOp<
-            RepartDistMatrix>::operator=(std::move(other));
+            RepartDistMatrix>::operator=(other);
         this->dist_mtx_ = other.dist_mtx_;
         this->local_sparsity_ = other.local_sparsity_;
         this->non_local_sparsity_ = other.non_local_sparsity_;
@@ -215,10 +215,30 @@ public:
         local_sparsity_ = repart_loc_sparsity;
         non_local_sparsity_ = repart_non_loc_sparsity;

+        std::cout << __FILE__ << ":" << __LINE__ << " rank " << rank
+                  << " build_localized_partition\n";
+        // << " dim " << local_sparsity_->dim[0] << " send idxs size "
+        // << dst_comm_pattern.send_idxs.size() << " target ids "
+        // << dst_comm_pattern.target_ids << " target sizes "
+        // << dst_comm_pattern.target_sizes << "\n";
+
         auto localized_partition = local_part_type::build_from_blocked_recv(
             exec, local_sparsity_->dim[0], dst_comm_pattern->send_idxs,
             dst_comm_pattern->target_ids, dst_comm_pattern->target_sizes);

+        std::cout << __FILE__ << " rank " << rank << " local sparsity size "
+                  << local_sparsity_->size_ << " local sparsity dim ["
+                  << local_sparsity_->dim[0] << " x" << local_sparsity_->dim[1]
+                  << "] non_local sparsity size " << non_local_sparsity_->size_
+                  << " non local sparsity dim [" << non_local_sparsity_->dim[0]
+                  << " x" << non_local_sparsity_->dim[1] << "] target_ids "
+                  << dst_comm_pattern->target_ids << " target_sizes "
+                  << dst_comm_pattern->target_sizes << " target_send_idxs.size "
+                  << dst_comm_pattern->send_idxs.size()
+                  << " non_local_sparsity.size " << non_local_sparsity_->size_
+                  << " get_recv_indices "
+                  << localized_partition->get_recv_indices().get_num_elems()
+                  << "\n";

         auto sparse_comm =
             sparse_communicator::create(comm, localized_partition);
@@ -264,11 +284,11 @@ public:
                 non_local_sparsity_->row_idxs,
                 non_local_sparsity_->col_idxs, non_local_coeffs),
             sparse_comm);
-
-
         update_impl(exec_handler, matrix_format, repartitioner, host_A, dist_A,
                     local_sparsity_, non_local_sparsity_, src_comm_pattern,
                     local_interfaces);
+        std::cout << __FILE__ << ":" << __LINE__ << " rank " << rank
+                  << " done update impl\n";

         auto ret = std::make_shared<RepartDistMatrix>(
             exec, comm, repartitioner.get_repart_dim(), dist_A->get_size(),
@@ -305,6 +325,8 @@ public:
         auto exec = exec_handler.get_ref_exec();
         auto device_exec = exec_handler.get_device_exec();
         auto ranks_per_gpu = repartitioner.get_ranks_per_gpu();
+        bool requires_host_buffer = exec_handler.get_gko_force_host_buffer();
+
         label rank{repartitioner.get_rank(exec_handler)};
         label owner_rank = repartitioner.get_owner_rank(exec_handler);
         bool owner = repartitioner.is_owner(exec_handler);
@@ -314,8 +336,6 @@ public:
         auto diag_comm_pattern = compute_send_recv_counts(
             exec_handler, ranks_per_gpu, nrows, local_matrix_nnz,
             local_matrix_nnz - nrows, 0);
-
-
         label upper_nnz = host_A->get_upper_nnz();
         auto upper_comm_pattern = compute_send_recv_counts(
             exec_handler, ranks_per_gpu, upper_nnz, local_matrix_nnz, 0,
@@ -325,18 +345,29 @@ public:
             local_matrix_nnz, upper_nnz, nrows);

         scalar *local_ptr;
+        scalar *local_ptr_2;
+        label nnz = 0;

         // update main values
+        std::vector<scalar> loc_buffer;
         if (owner) {
             using Coo = gko::matrix::Coo<scalar, label>;
             auto local_mtx = dist_A->get_local_matrix();

+
             std::shared_ptr<const Coo> local =
                 gko::as<Coo>(gko::as<CombinationMatrix<scalar, label, Coo>>(
                                  dist_A->get_local_matrix())
                                  ->get_combination()
                                  ->get_operators()[0]);
-            local_ptr = const_cast<scalar *>(local->get_const_values());
+            nnz = local->get_num_stored_elements();
+            if (requires_host_buffer) {
+                loc_buffer.resize(nnz);
+                local_ptr = loc_buffer.data();
+                local_ptr_2 = const_cast<scalar *>(local->get_const_values());
+            } else {
+                local_ptr = const_cast<scalar *>(local->get_const_values());
+            }
         }
         communicate_values(exec_handler, diag_comm_pattern, host_A->get_diag(),
                            local_ptr);
@@ -352,7 +383,26 @@ public:
             communicate_values(exec_handler, lower_comm_pattern,
                                host_A->get_lower(), local_ptr);
         }
-
+        std::cout << __FILE__ << ":" << __LINE__ << " rank " << rank
+                  << " done comm local mtx\n";
+
+        if (requires_host_buffer) {
+            auto host_buffer_view =
+                gko::array<scalar>::view(exec, nnz, local_ptr);
+            auto target_buffer_view =
+                gko::array<scalar>::view(device_exec, nnz, local_ptr_2);
+            target_buffer_view = host_buffer_view;
+        }
+        std::cout << __FILE__ << ":" << __LINE__ << " rank " << rank
+                  << " done copy to device\n";
+
+        if (requires_host_buffer) {
+            auto host_buffer_view =
+                gko::array<scalar>::view(exec, nnz, local_ptr);
+            auto target_buffer_view =
+                gko::array<scalar>::view(device_exec, nnz, local_ptr_2);
+            target_buffer_view = host_buffer_view;
+        }
         // copy interface values
         auto comm = *exec_handler.get_communicator().get();
         if (owner) {
@@ -364,6 +414,8 @@ public:
             label tag = 0;
             label comm_rank, comm_size;
             scalar *recv_buffer_ptr;
+            scalar *recv_buffer_ptr_2;
+            std::vector<scalar> host_recv_buffer;
             label remain_host_interfaces = host_A->get_interface_size();
             for (auto [is_local, comm_rank] : local_interfaces) {
                 label &ctr = (is_local) ? loc_ctr : nloc_ctr;
@@ -383,19 +435,35 @@ public:
                     comm_size =
                         non_local_sparsity->interface_spans[ctr].length();
                 }
-                recv_buffer_ptr = const_cast<scalar *>(mtx->get_const_values());
+
+                if (requires_host_buffer) {
+                    host_recv_buffer.resize(comm_size);
+                    recv_buffer_ptr = host_recv_buffer.data();
+                    recv_buffer_ptr_2 = const_cast<scalar *>(mtx->get_const_values());
+                } else {
+                    recv_buffer_ptr = const_cast<scalar *>(mtx->get_const_values());
+                }

                 if (comm_rank != rank) {
-                    comm.recv(exec, recv_buffer_ptr, comm_size, comm_rank, tag);
+                    comm.recv(device_exec, recv_buffer_ptr, comm_size, comm_rank, tag);
+                    if (requires_host_buffer) {
+                        auto host_buffer_view =
+                            gko::array<scalar>::view(exec, comm_size, recv_buffer_ptr);
+                        auto target_buffer_view =
+                            gko::array<scalar>::view(device_exec, comm_size, recv_buffer_ptr_2);
+                        target_buffer_view = host_buffer_view;
+                    }
+
                 } else {
                     // if data is already on this rank
                     auto data_view = gko::array<scalar>::const_view(
                         exec, comm_size,
                         host_A->get_interface_data(host_interface_ctr));

                     // TODO FIXME this needs target executor
+                    recv_buffer_ptr = const_cast<scalar *>(mtx->get_const_values());
                     auto target_view = gko::array<scalar>::view(
-                        exec, comm_size, recv_buffer_ptr);
+                        device_exec, comm_size, recv_buffer_ptr);

                     target_view = data_view;

@@ -409,7 +477,7 @@ public:
                 auto neg_one = gko::initialize<vec>({-1.0}, exec);
                 auto interface_dense = vec::create(
                     exec, gko::dim<2>{comm_size, 1},
-                    gko::array<scalar>::view(exec, comm_size, recv_buffer_ptr),
+                    gko::array<scalar>::view(device_exec, comm_size, recv_buffer_ptr),
                     1);

                 interface_dense->scale(neg_one);
@@ -423,37 +491,43 @@ public:
                 label comm_size =
                     src_comm_pattern->target_sizes.get_const_data()[i];
                 const scalar *send_buffer_ptr = host_A->get_interface_data(i);
-                comm.send(exec, send_buffer_ptr, comm_size, owner_rank, tag);
+                comm.send(device_exec, send_buffer_ptr, comm_size, owner_rank, tag);
             }
         }
-
         // reorder updated values
-        if (owner) {
-            // NOTE local sparsity size includes the interfaces
-            using Coo = gko::matrix::Coo<scalar, label>;
-            using dim_type = gko::dim<2>::dimension_type;
-            std::shared_ptr<const Coo> local =
-                gko::as<Coo>(gko::as<CombinationMatrix<scalar, label, Coo>>(
-                                 dist_A->get_local_matrix())
-                                 ->get_combination()
-                                 ->get_operators()[0]);
-            auto local_elements = local->get_num_stored_elements();
-            local_ptr = const_cast<scalar *>(local->get_const_values());
-            // TODO make sure this doesn't copy
-            // create a non owning dense matrix of local_values
-
-            auto row_collection = gko::share(gko::matrix::Dense<scalar>::create(
-                exec, gko::dim<2>{static_cast<dim_type>(local_elements), 1},
-                gko::array<scalar>::view(exec, local_elements, local_ptr), 1));
-
-            auto mapping_view = gko::array<label>::view(
-                exec, local_elements, local_sparsity->ldu_mapping.get_data());
-
-
-            // TODO this needs to copy ldu_mapping to the device
-            auto dense_vec = row_collection->clone();
-            dense_vec->row_gather(&mapping_view, row_collection.get());
-        }
+        if (owner) {
+            // NOTE local sparsity size includes the interfaces
+            using Coo = gko::matrix::Coo<scalar, label>;
+            using dim_type = gko::dim<2>::dimension_type;
+            std::shared_ptr<const Coo> local =
+                gko::as<Coo>(gko::as<CombinationMatrix<scalar, label, Coo>>(
+                                 dist_A->get_local_matrix())
+                                 ->get_combination()
+                                 ->get_operators()[0]);
+            auto local_elements = local->get_num_stored_elements();
+            local_ptr = const_cast<scalar *>(local->get_const_values());
+            // TODO make sure this doesn't copy
+            // create a non owning dense matrix of local_values
+
+            auto row_collection = gko::share(gko::matrix::Dense<scalar>::create(
+                device_exec, gko::dim<2>{static_cast<dim_type>(local_elements), 1},
+                gko::array<scalar>::view(device_exec, local_elements, local_ptr), 1));
+            auto mapping_view = gko::array<label>::view(
+                exec, local_elements, local_sparsity->ldu_mapping.get_data());
+
+
+            // TODO this needs to copy ldu_mapping to the device
+            auto dense_vec = row_collection->clone();
+            // auto dense_vec = gko::share(gko::matrix::Dense<scalar>::create(exec, row_collection->get_size()));
+
+            std::cout << __FILE__ << ":" << __LINE__ << " rank " << rank
+                      << " reorder\n";
+            dense_vec->row_gather(&mapping_view, row_collection.get());
+            std::cout << __FILE__ << ":" << __LINE__ << " rank " << rank
+                      << " reorder\n";
+        }
+        std::cout << __FILE__ << ":" << __LINE__ << " rank " << rank
+                  << " done reorder\n";
     };

     RepartDistMatrix(
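
Note on the requires_host_buffer branches above: they all follow the same staging pattern, where MPI data lands in a plain host vector and is then pushed to the device by assigning one non-owning gko::array view to another. The following is a minimal, self-contained sketch of just that copy, not the code in this diff; it assumes a CUDA device, double values, and illustrative names (host_buffer, device_values, nnz).

#include <ginkgo/ginkgo.hpp>

#include <vector>

int main()
{
    // Host (reference) and device executors; the device is assumed to be CUDA.
    auto host_exec = gko::ReferenceExecutor::create();
    auto device_exec = gko::CudaExecutor::create(0, host_exec);

    // Stand-in for an MPI receive buffer that lives in host memory.
    std::vector<double> host_buffer{1.0, 2.0, 3.0, 4.0};
    gko::size_type nnz = host_buffer.size();

    // Stand-in for the device-resident matrix value array.
    gko::array<double> device_values(device_exec, nnz);

    // Non-owning views over both buffers; assigning one view to the other
    // copies the data across executors, here host -> device.
    auto host_view =
        gko::array<double>::view(host_exec, nnz, host_buffer.data());
    auto device_view =
        gko::array<double>::view(device_exec, nnz, device_values.get_data());
    device_view = host_view;

    return 0;
}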
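The reorder block at the end permutes the gathered values through Dense::row_gather using the ldu_mapping indices. A small sketch of that call in isolation, assuming double values and gko::int32 indices; the mapping values here are made up for illustration.

#include <ginkgo/ginkgo.hpp>

int main()
{
    auto exec = gko::ReferenceExecutor::create();
    using Dense = gko::matrix::Dense<double>;

    // One value per stored matrix entry, held as a column vector.
    auto values = gko::initialize<Dense>({10.0, 20.0, 30.0, 40.0}, exec);

    // Gather indices: output row i is taken from input row mapping[i].
    gko::array<gko::int32> mapping(exec, {2, 0, 3, 1});

    // Destination with the same shape as the source vector.
    auto reordered = Dense::create(exec, values->get_size());
    values->row_gather(&mapping, reordered.get());
    // reordered now holds (30, 10, 40, 20).

    return 0;
}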