@@ -302,20 +302,20 @@ void map_to_global(
302
302
device_partition<const LocalIndexType, const GlobalIndexType> partition,
303
303
device_segmented_array<const GlobalIndexType> remote_global_idxs,
304
304
experimental::distributed::comm_index_type rank,
305
- const array<LocalIndexType>& local_ids ,
305
+ const array<LocalIndexType>& local_idxs ,
306
306
experimental::distributed::index_space is,
307
- array<GlobalIndexType>& global_ids )
307
+ array<GlobalIndexType>& global_idxs )
308
308
{
309
309
auto range_bounds = partition.offsets_begin ;
310
310
auto starting_indices = partition.starting_indices_begin ;
311
311
const auto & ranges_by_part = partition.ranges_by_part ;
312
- auto local_ids_it = local_ids .get_const_data ();
313
- auto input_size = local_ids .get_size ();
312
+ auto local_idxs_it = local_idxs .get_const_data ();
313
+ auto input_size = local_idxs .get_size ();
314
314
315
315
auto policy = thrust_policy (exec);
316
316
317
- global_ids .resize_and_reset (local_ids .get_size ());
318
- auto global_ids_it = global_ids .get_data ();
317
+ global_idxs .resize_and_reset (local_idxs .get_size ());
318
+ auto global_idxs_it = global_idxs .get_data ();
319
319
320
320
auto map_local = [rank, ranges_by_part, range_bounds, starting_indices,
321
321
partition] __device__ (auto lid) {
@@ -330,11 +330,16 @@ void map_to_global(
330
330
auto local_ranges_size =
331
331
static_cast <int64>(local_ranges.end - local_ranges.begin );
332
332
333
- auto it = binary_search ( int64 ( 0 ), local_ranges_size, [=]( const auto i) {
334
- return starting_indices[local_ranges. begin [i]] >= lid;
335
- });
333
+ // the binary search finds the first local range, such that the starting
334
+ // index is larger than lid, thus lid is contained in the local range
335
+ // before that one
336
336
auto local_range_id =
337
- it != local_ranges_size ? it : max (int64 (0 ), it - 1 );
337
+ binary_search (int64 (0 ), local_ranges_size,
338
+ [=](const auto i) {
339
+ return starting_indices[local_ranges.begin [i]] >
340
+ lid;
341
+ }) -
342
+ 1 ;
338
343
auto range_id = local_ranges.begin [local_range_id];
339
344
340
345
return static_cast <GlobalIndexType>(lid - starting_indices[range_id]) +
@@ -363,16 +368,16 @@ void map_to_global(
363
368
};
364
369
365
370
if (is == experimental::distributed::index_space::local) {
366
- thrust::transform (policy, local_ids_it, local_ids_it + input_size,
367
- global_ids_it , map_local);
371
+ thrust::transform (policy, local_idxs_it, local_idxs_it + input_size,
372
+ global_idxs_it , map_local);
368
373
}
369
374
if (is == experimental::distributed::index_space::non_local) {
370
- thrust::transform (policy, local_ids_it, local_ids_it + input_size,
371
- global_ids_it , map_non_local);
375
+ thrust::transform (policy, local_idxs_it, local_idxs_it + input_size,
376
+ global_idxs_it , map_non_local);
372
377
}
373
378
if (is == experimental::distributed::index_space::combined) {
374
- thrust::transform (policy, local_ids_it, local_ids_it + input_size,
375
- global_ids_it , map_combined);
379
+ thrust::transform (policy, local_idxs_it, local_idxs_it + input_size,
380
+ global_idxs_it , map_combined);
376
381
}
377
382
}
378
383
0 commit comments