|
1 |
| -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors |
| 1 | +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors |
2 | 2 | //
|
3 | 3 | // SPDX-License-Identifier: BSD-3-Clause
|
4 | 4 |
|
@@ -296,6 +296,95 @@ GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE(
|
296 | 296 | GKO_DECLARE_INDEX_MAP_MAP_TO_LOCAL);
|
297 | 297 |
|
298 | 298 |
|
| 299 | +template <typename LocalIndexType, typename GlobalIndexType> |
| 300 | +void map_to_global( |
| 301 | + std::shared_ptr<const DefaultExecutor> exec, |
| 302 | + device_partition<const LocalIndexType, const GlobalIndexType> partition, |
| 303 | + device_segmented_array<const GlobalIndexType> remote_global_idxs, |
| 304 | + experimental::distributed::comm_index_type rank, |
| 305 | + const array<LocalIndexType>& local_idxs, |
| 306 | + experimental::distributed::index_space is, |
| 307 | + array<GlobalIndexType>& global_idxs) |
| 308 | +{ |
| 309 | + auto range_bounds = partition.offsets_begin; |
| 310 | + auto starting_indices = partition.starting_indices_begin; |
| 311 | + const auto& ranges_by_part = partition.ranges_by_part; |
| 312 | + auto local_idxs_it = local_idxs.get_const_data(); |
| 313 | + auto input_size = local_idxs.get_size(); |
| 314 | + |
| 315 | + auto policy = thrust_policy(exec); |
| 316 | + |
| 317 | + global_idxs.resize_and_reset(local_idxs.get_size()); |
| 318 | + auto global_idxs_it = global_idxs.get_data(); |
| 319 | + |
| 320 | + auto map_local = [rank, ranges_by_part, range_bounds, starting_indices, |
| 321 | + partition] __device__(auto lid) { |
| 322 | + auto local_size = |
| 323 | + static_cast<LocalIndexType>(partition.part_sizes_begin[rank]); |
| 324 | + |
| 325 | + if (lid < 0 || lid >= local_size) { |
| 326 | + return invalid_index<GlobalIndexType>(); |
| 327 | + } |
| 328 | + |
| 329 | + auto local_ranges = ranges_by_part.get_segment(rank); |
| 330 | + auto local_ranges_size = |
| 331 | + static_cast<int64>(local_ranges.end - local_ranges.begin); |
| 332 | + |
| 333 | + // the binary search finds the first local range, such that the starting |
| 334 | + // index is larger than lid, thus lid is contained in the local range |
| 335 | + // before that one |
| 336 | + auto local_range_id = |
| 337 | + binary_search(int64(0), local_ranges_size, |
| 338 | + [=](const auto i) { |
| 339 | + return starting_indices[local_ranges.begin[i]] > |
| 340 | + lid; |
| 341 | + }) - |
| 342 | + 1; |
| 343 | + auto range_id = local_ranges.begin[local_range_id]; |
| 344 | + |
| 345 | + return static_cast<GlobalIndexType>(lid - starting_indices[range_id]) + |
| 346 | + range_bounds[range_id]; |
| 347 | + }; |
| 348 | + auto map_non_local = [remote_global_idxs] __device__(auto lid) { |
| 349 | + auto remote_size = static_cast<LocalIndexType>( |
| 350 | + remote_global_idxs.flat_end - remote_global_idxs.flat_begin); |
| 351 | + |
| 352 | + if (lid < 0 || lid >= remote_size) { |
| 353 | + return invalid_index<GlobalIndexType>(); |
| 354 | + } |
| 355 | + |
| 356 | + return remote_global_idxs.flat_begin[lid]; |
| 357 | + }; |
| 358 | + auto map_combined = [map_local, map_non_local, partition, |
| 359 | + rank] __device__(auto lid) { |
| 360 | + auto local_size = |
| 361 | + static_cast<LocalIndexType>(partition.part_sizes_begin[rank]); |
| 362 | + |
| 363 | + if (lid < local_size) { |
| 364 | + return map_local(lid); |
| 365 | + } else { |
| 366 | + return map_non_local(lid - local_size); |
| 367 | + } |
| 368 | + }; |
| 369 | + |
| 370 | + if (is == experimental::distributed::index_space::local) { |
| 371 | + thrust::transform(policy, local_idxs_it, local_idxs_it + input_size, |
| 372 | + global_idxs_it, map_local); |
| 373 | + } |
| 374 | + if (is == experimental::distributed::index_space::non_local) { |
| 375 | + thrust::transform(policy, local_idxs_it, local_idxs_it + input_size, |
| 376 | + global_idxs_it, map_non_local); |
| 377 | + } |
| 378 | + if (is == experimental::distributed::index_space::combined) { |
| 379 | + thrust::transform(policy, local_idxs_it, local_idxs_it + input_size, |
| 380 | + global_idxs_it, map_combined); |
| 381 | + } |
| 382 | +} |
| 383 | + |
| 384 | +GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE( |
| 385 | + GKO_DECLARE_INDEX_MAP_MAP_TO_GLOBAL); |
| 386 | + |
| 387 | + |
299 | 388 | } // namespace index_map
|
300 | 389 | } // namespace GKO_DEVICE_NAMESPACE
|
301 | 390 | } // namespace kernels
|
|
0 commit comments