@@ -215,10 +215,30 @@ public:
         local_sparsity_ = repart_loc_sparsity;
         non_local_sparsity_ = repart_non_loc_sparsity;
 
+        std::cout << __FILE__ << " :" << __LINE__ << " rank " << rank
+                  << " build_localized_partition\n";
+        // << " dim " << local_sparsity_->dim[0] << " send idxs size "
+        // << dst_comm_pattern.send_idxs.size() << " target ids "
+        // << dst_comm_pattern.target_ids << " target sizes "
+        // << dst_comm_pattern.target_sizes << "\n";
+
         auto localized_partition = local_part_type::build_from_blocked_recv(
             exec, local_sparsity_->dim[0], dst_comm_pattern->send_idxs,
             dst_comm_pattern->target_ids, dst_comm_pattern->target_sizes);
 
+        std::cout << __FILE__ << " rank " << rank << " local sparsity size "
+                  << local_sparsity_->size_ << " local sparsity dim ["
+                  << local_sparsity_->dim[0] << " x" << local_sparsity_->dim[1]
+                  << " ] non_local sparsity size " << non_local_sparsity_->size_
+                  << " non local sparsity dim [" << non_local_sparsity_->dim[0]
+                  << " x" << non_local_sparsity_->dim[1] << " ] target_ids "
+                  << dst_comm_pattern->target_ids << " target_sizes "
+                  << dst_comm_pattern->target_sizes << " target_send_idxs.size "
+                  << dst_comm_pattern->send_idxs.size()
+                  << " non_local_sparsity.size " << non_local_sparsity_->size_
+                  << " get_recv_indices "
+                  << localized_partition->get_recv_indices().get_num_elems()
+                  << "\n";
 
         auto sparse_comm =
             sparse_communicator::create(comm, localized_partition);
@@ -264,11 +284,15 @@ public:
                 non_local_sparsity_->row_idxs,
                 non_local_sparsity_->col_idxs, non_local_coeffs),
             sparse_comm);
+        std::cout << __FILE__ << " :" << __LINE__ << " rank " << rank
+                  << " done read distributed\n";
 
 
         update_impl(exec_handler, matrix_format, repartitioner, host_A, dist_A,
                     local_sparsity_, non_local_sparsity_, src_comm_pattern,
                     local_interfaces);
+        std::cout << __FILE__ << " :" << __LINE__ << " rank " << rank
+                  << " done update impl\n";
 
         auto ret = std::make_shared<RepartDistMatrix>(
             exec, comm, repartitioner.get_repart_dim(), dist_A->get_size(),
@@ -305,6 +329,8 @@ public:
         auto exec = exec_handler.get_ref_exec();
         auto device_exec = exec_handler.get_device_exec();
         auto ranks_per_gpu = repartitioner.get_ranks_per_gpu();
+        bool requires_host_buffer = exec_handler.get_gko_force_host_buffer();
+
         label rank{repartitioner.get_rank(exec_handler)};
         label owner_rank = repartitioner.get_owner_rank(exec_handler);
         bool owner = repartitioner.is_owner(exec_handler);
@@ -314,29 +340,43 @@ public:
         auto diag_comm_pattern = compute_send_recv_counts(
             exec_handler, ranks_per_gpu, nrows, local_matrix_nnz,
             local_matrix_nnz - nrows, 0);
+        std::cout << __FILE__ << " :" << __LINE__ << " rank " << rank
+                  << " diag comm pattern\n";
 
 
         label upper_nnz = host_A->get_upper_nnz();
         auto upper_comm_pattern = compute_send_recv_counts(
             exec_handler, ranks_per_gpu, upper_nnz, local_matrix_nnz, 0,
             local_matrix_nnz - upper_nnz);
+        std::cout << __FILE__ << " :" << __LINE__ << " rank " << rank
+                  << " upper comm pattern\n";
         auto lower_comm_pattern =
             compute_send_recv_counts(exec_handler, ranks_per_gpu, upper_nnz,
                                      local_matrix_nnz, upper_nnz, nrows);
 
         scalar *local_ptr;
+        scalar *local_ptr_2;
+        label nnz = 0;
 
         // update main values
+        std::vector<scalar> loc_buffer;
         if (owner) {
             using Coo = gko::matrix::Coo<scalar, label>;
             auto local_mtx = dist_A->get_local_matrix();
 
+
             std::shared_ptr<const Coo> local =
                 gko::as<Coo>(gko::as<CombinationMatrix<scalar, label, Coo>>(
                                  dist_A->get_local_matrix())
                                  ->get_combination()
                                  ->get_operators()[0]);
-            local_ptr = const_cast<scalar *>(local->get_const_values());
+            if (requires_host_buffer) {
+                loc_buffer.resize(local->get_num_stored_elements());
+                local_ptr = loc_buffer.data();
+                local_ptr_2 = const_cast<scalar *>(local->get_const_values());
+            } else {
+                local_ptr = const_cast<scalar *>(local->get_const_values());
+            }
         }
         communicate_values(exec_handler, diag_comm_pattern, host_A->get_diag(),
                            local_ptr);
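
Note: the hunk above pulls the local Coo operator out of the project's CombinationMatrix wrapper and then takes a writable pointer to its value array so the communicated coefficients can be written in place. A minimal stand-alone sketch of that last step with plain Ginkgo types follows; the CombinationMatrix layer is omitted and the matrix contents are hypothetical.

#include <ginkgo/ginkgo.hpp>
#include <memory>

int main()
{
    auto exec = gko::ReferenceExecutor::create();
    using Coo = gko::matrix::Coo<double, gko::int32>;

    // stand-in for the LinOp returned by dist_A->get_local_matrix()
    std::shared_ptr<const gko::LinOp> op = Coo::create(
        exec, gko::dim<2>{2, 2}, gko::array<double>(exec, {1.0, 2.0}),
        gko::array<gko::int32>(exec, {0, 1}), gko::array<gko::int32>(exec, {0, 1}));

    // downcast the type-erased LinOp back to Coo and expose its value storage
    auto coo = gko::as<Coo>(op);
    double *values = const_cast<double *>(coo->get_const_values());
    values[0] = 42.0;  // overwrite coefficients in place, as the update path does
}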
@@ -352,6 +392,18 @@ public:
             communicate_values(exec_handler, lower_comm_pattern,
                                host_A->get_lower(), local_ptr);
         }
+        std::cout << __FILE__ << " :" << __LINE__ << " rank " << rank
+                  << " done comm local mtx\n";
+
+        if (requires_host_buffer) {
+            auto host_buffer_view =
+                gko::array<scalar>::view(exec, nnz, local_ptr);
+            auto target_buffer_view =
+                gko::array<scalar>::view(device_exec, nnz, local_ptr_2);
+            target_buffer_view = host_buffer_view;
+        }
+        std::cout << __FILE__ << " :" << __LINE__ << " rank " << rank
+                  << " done copy to device\n";
 
         // copy interface values
         auto comm = *exec_handler.get_communicator().get();
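
Note: when the force-host-buffer flag is set, the received coefficients land in a host std::vector and are pushed to the device afterwards by assigning one gko::array view to another; an assignment between equally sized views copies the data onto the target view's executor. A self-contained sketch of that idiom, with hypothetical sizes and an OmpExecutor standing in for the handler's real device executor:

#include <ginkgo/ginkgo.hpp>
#include <vector>

int main()
{
    auto host_exec = gko::ReferenceExecutor::create();
    // stands in for exec_handler.get_device_exec(); a GPU executor in the real setup
    auto device_exec = gko::OmpExecutor::create();

    const gko::size_type nnz = 4;
    gko::array<double> device_values(device_exec, nnz);  // plays the role of the matrix value array
    std::vector<double> loc_buffer{1.0, 2.0, 3.0, 4.0};  // host staging buffer, e.g. filled by MPI

    auto host_buffer_view =
        gko::array<double>::view(host_exec, nnz, loc_buffer.data());
    auto target_buffer_view =
        gko::array<double>::view(device_exec, nnz, device_values.get_data());
    // assigning equally sized views copies host data onto the target executor
    target_buffer_view = host_buffer_view;
}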
@@ -364,6 +416,7 @@ public:
         label tag = 0;
         label comm_rank, comm_size;
         scalar *recv_buffer_ptr;
+        std::vector<scalar> host_recv_buffer;
         label remain_host_interfaces = host_A->get_interface_size();
         for (auto [is_local, comm_rank] : local_interfaces) {
             label &ctr = (is_local) ? loc_ctr : nloc_ctr;
@@ -383,9 +436,18 @@ public:
                 comm_size =
                     non_local_sparsity->interface_spans[ctr].length();
             }
-            recv_buffer_ptr = const_cast<scalar *>(mtx->get_const_values());
+
+            if (requires_host_buffer) {
+                host_recv_buffer.resize(comm_size);
+                recv_buffer_ptr = host_recv_buffer.data();
+            } else {
+                recv_buffer_ptr = const_cast<scalar *>(mtx->get_const_values());
+            }
 
             if (comm_rank != rank) {
+                std::cout << __FILE__ << " :" << __LINE__ << " rank " << rank
+                          << " comm_rank " << comm_rank << " rank " << rank << "\n";
+
                 comm.recv(exec, recv_buffer_ptr, comm_size, comm_rank, tag);
             } else {
                 // if data is already on this rank
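
Note: the branch above decides where an incoming interface message is written: straight into the matrix values, or into a host staging vector when the communication layer must not touch device pointers. A plain-MPI sketch of the same decision; the helper below is hypothetical, and the diff itself goes through Ginkgo's communicator wrapper instead of raw MPI_Recv.

#include <mpi.h>
#include <vector>

// hypothetical helper mirroring the receive path above
void receive_interface(bool requires_host_buffer, double *device_values,
                       int comm_size, int source_rank, int tag, MPI_Comm comm)
{
    std::vector<double> host_recv_buffer;
    double *recv_buffer_ptr = device_values;
    if (requires_host_buffer) {
        host_recv_buffer.resize(comm_size);
        recv_buffer_ptr = host_recv_buffer.data();
    }
    MPI_Recv(recv_buffer_ptr, comm_size, MPI_DOUBLE, source_rank, tag, comm,
             MPI_STATUS_IGNORE);
    // when staged on the host, the values still have to be copied into
    // device memory afterwards, e.g. with the array-view assignment shown earlier
}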
@@ -427,33 +489,47 @@ public:
             }
         }
 
+        std::cout << __FILE__ << " :" << __LINE__ << " rank " << rank
+                  << " reorder\n";
         // reorder updated values
-        if (owner) {
-            // NOTE local sparsity size includes the interfaces
-            using Coo = gko::matrix::Coo<scalar, label>;
-            using dim_type = gko::dim<2>::dimension_type;
-            std::shared_ptr<const Coo> local =
-                gko::as<Coo>(gko::as<CombinationMatrix<scalar, label, Coo>>(
-                                 dist_A->get_local_matrix())
-                                 ->get_combination()
-                                 ->get_operators()[0]);
-            auto local_elements = local->get_num_stored_elements();
-            local_ptr = const_cast<scalar *>(local->get_const_values());
-            // TODO make sure this doesn't copy
-            // create a non owning dense matrix of local_values
-
-            auto row_collection = gko::share(gko::matrix::Dense<scalar>::create(
-                exec, gko::dim<2>{static_cast<dim_type>(local_elements), 1},
-                gko::array<scalar>::view(exec, local_elements, local_ptr), 1));
-
-            auto mapping_view = gko::array<label>::view(
-                exec, local_elements, local_sparsity->ldu_mapping.get_data());
-
-
-            // TODO this needs to copy ldu_mapping to the device
-            auto dense_vec = row_collection->clone();
-            dense_vec->row_gather(&mapping_view, row_collection.get());
-        }
+        if (owner) {
+            // NOTE local sparsity size includes the interfaces
+            using Coo = gko::matrix::Coo<scalar, label>;
+            using dim_type = gko::dim<2>::dimension_type;
+            std::shared_ptr<const Coo> local =
+                gko::as<Coo>(gko::as<CombinationMatrix<scalar, label, Coo>>(
+                                 dist_A->get_local_matrix())
+                                 ->get_combination()
+                                 ->get_operators()[0]);
+            auto local_elements = local->get_num_stored_elements();
+            local_ptr = const_cast<scalar *>(local->get_const_values());
+            // TODO make sure this doesn't copy
+            // create a non owning dense matrix of local_values
+
+            auto row_collection = gko::share(gko::matrix::Dense<scalar>::create(
+                device_exec, gko::dim<2>{static_cast<dim_type>(local_elements), 1},
+                gko::array<scalar>::view(device_exec, local_elements, local_ptr), 1));
+            std::cout << __FILE__ << " :" << __LINE__ << " rank " << rank << " local_elements " << local_elements
+                      << " reorder\n";
+
+            auto mapping_view = gko::array<label>::view(
+                exec, local_elements, local_sparsity->ldu_mapping.get_data());
+            std::cout << __FILE__ << " :" << __LINE__ << " rank " << rank
+                      << " reorder\n";
+
+
+            // TODO this needs to copy ldu_mapping to the device
+            auto dense_vec = row_collection->clone();
+            // auto dense_vec = gko::share(gko::matrix::Dense<scalar>::create(exec, row_collection->get_size()));
+
+            std::cout << __FILE__ << " :" << __LINE__ << " rank " << rank
+                      << " reorder\n";
+            dense_vec->row_gather(&mapping_view, row_collection.get());
+            std::cout << __FILE__ << " :" << __LINE__ << " rank " << rank
+                      << " reorder\n";
+        }
+        std::cout << __FILE__ << " :" << __LINE__ << " rank " << rank
+                  << " done reorder\n";
     };
 
     RepartDistMatrix(
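
Note: the owner block above reorders the freshly written coefficients by wrapping the Coo value array in a one-column Dense and gathering its rows through ldu_mapping. A minimal sketch of that row_gather pattern with made-up data; the real code builds the views on device_exec and uses the project's ldu_mapping instead of the toy map below.

#include <ginkgo/ginkgo.hpp>

int main()
{
    auto exec = gko::ReferenceExecutor::create();
    using Dense = gko::matrix::Dense<double>;

    const gko::size_type n = 4;
    double values[] = {10.0, 20.0, 30.0, 40.0};  // values in their communicated order
    gko::int32 mapping[] = {2, 0, 3, 1};         // made-up reordering map

    // non-owning one-column view of the raw value array (stride 1)
    auto row_collection = gko::share(Dense::create(
        exec, gko::dim<2>{n, 1}, gko::array<double>::view(exec, n, values), 1));

    // gather from a copy so the in-place write does not read already permuted data
    auto dense_vec = row_collection->clone();
    auto mapping_view = gko::array<gko::int32>::view(exec, n, mapping);
    dense_vec->row_gather(&mapping_view, row_collection.get());
    // values[] now holds {30, 10, 40, 20}: row i receives row mapping[i]
}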