Skip to content

Commit 71ac6a4

Browse files
MarcelKoch, yhmtsai, pratikvn, upsj
committed
[dist] review updates:
- reduce tests
- update docs
- minor refactoring
- fix binary search usage in cuda/hip
- refactor
- add tests for segmented array assertion

Co-authored-by: Yu-Hsiang M. Tsai <[email protected]>
Co-authored-by: Pratik Nayak <[email protected]>
Co-authored-by: Tobias Ribizel <[email protected]>
1 parent 9588885 commit 71ac6a4

File tree

15 files changed: +192 -247 lines changed

common/cuda_hip/distributed/index_map_kernels.cpp

+21-16
Original file line numberDiff line numberDiff line change
@@ -302,20 +302,20 @@ void map_to_global(
302302
device_partition<const LocalIndexType, const GlobalIndexType> partition,
303303
device_segmented_array<const GlobalIndexType> remote_global_idxs,
304304
experimental::distributed::comm_index_type rank,
305-
const array<LocalIndexType>& local_ids,
305+
const array<LocalIndexType>& local_idxs,
306306
experimental::distributed::index_space is,
307-
array<GlobalIndexType>& global_ids)
307+
array<GlobalIndexType>& global_idxs)
308308
{
309309
auto range_bounds = partition.offsets_begin;
310310
auto starting_indices = partition.starting_indices_begin;
311311
const auto& ranges_by_part = partition.ranges_by_part;
312-
auto local_ids_it = local_ids.get_const_data();
313-
auto input_size = local_ids.get_size();
312+
auto local_idxs_it = local_idxs.get_const_data();
313+
auto input_size = local_idxs.get_size();
314314

315315
auto policy = thrust_policy(exec);
316316

317-
global_ids.resize_and_reset(local_ids.get_size());
318-
auto global_ids_it = global_ids.get_data();
317+
global_idxs.resize_and_reset(local_idxs.get_size());
318+
auto global_idxs_it = global_idxs.get_data();
319319

320320
auto map_local = [rank, ranges_by_part, range_bounds, starting_indices,
321321
partition] __device__(auto lid) {
@@ -330,11 +330,16 @@ void map_to_global(
330330
auto local_ranges_size =
331331
static_cast<int64>(local_ranges.end - local_ranges.begin);
332332

333-
auto it = binary_search(int64(0), local_ranges_size, [=](const auto i) {
334-
return starting_indices[local_ranges.begin[i]] >= lid;
335-
});
333+
// the binary search finds the first local range, such that the starting
334+
// index is larger than lid, thus lid is contained in the local range
335+
// before that one
336336
auto local_range_id =
337-
it != local_ranges_size ? it : max(int64(0), it - 1);
337+
binary_search(int64(0), local_ranges_size,
338+
[=](const auto i) {
339+
return starting_indices[local_ranges.begin[i]] >
340+
lid;
341+
}) -
342+
1;
338343
auto range_id = local_ranges.begin[local_range_id];
339344

340345
return static_cast<GlobalIndexType>(lid - starting_indices[range_id]) +
@@ -363,16 +368,16 @@ void map_to_global(
363368
};
364369

365370
if (is == experimental::distributed::index_space::local) {
366-
thrust::transform(policy, local_ids_it, local_ids_it + input_size,
367-
global_ids_it, map_local);
371+
thrust::transform(policy, local_idxs_it, local_idxs_it + input_size,
372+
global_idxs_it, map_local);
368373
}
369374
if (is == experimental::distributed::index_space::non_local) {
370-
thrust::transform(policy, local_ids_it, local_ids_it + input_size,
371-
global_ids_it, map_non_local);
375+
thrust::transform(policy, local_idxs_it, local_idxs_it + input_size,
376+
global_idxs_it, map_non_local);
372377
}
373378
if (is == experimental::distributed::index_space::combined) {
374-
thrust::transform(policy, local_ids_it, local_ids_it + input_size,
375-
global_ids_it, map_combined);
379+
thrust::transform(policy, local_idxs_it, local_idxs_it + input_size,
380+
global_idxs_it, map_combined);
376381
}
377382
}
378383

core/distributed/device_partition.hpp

+2-33
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

5-
#ifndef GINKGO_PARTITION_HPP
6-
#define GINKGO_PARTITION_HPP
5+
#pragma once
76

87
#include <ginkgo/core/distributed/partition.hpp>
98

@@ -34,34 +33,7 @@ struct device_partition {
3433

3534

3635
/**
37-
* Create device_segmented_array from a segmented_array.
38-
*/
39-
template <typename LocalIndexType, typename GlobalIndexType>
40-
constexpr device_partition<const LocalIndexType, const GlobalIndexType>
41-
to_device(
42-
const experimental::distributed::Partition<LocalIndexType, GlobalIndexType>*
43-
partition)
44-
{
45-
auto num_ranges = partition->get_num_ranges();
46-
auto num_parts = partition->get_num_parts();
47-
return {num_parts,
48-
partition->get_num_empty_parts(),
49-
partition->get_size(),
50-
partition->get_range_bounds(),
51-
partition->get_range_bounds() + num_ranges + 1,
52-
partition->get_range_starting_indices(),
53-
partition->get_range_starting_indices() + num_ranges,
54-
partition->get_part_sizes(),
55-
partition->get_part_sizes() + num_parts,
56-
partition->get_part_ids(),
57-
partition->get_part_ids() + num_parts,
58-
to_device(partition->get_ranges_by_part())};
59-
}
60-
61-
/**
62-
* Explicitly create a const version of device_segmented_array.
63-
*
64-
* This is mostly relevant for tests.
36+
* Explicitly create a const version of device_partition.
6537
*/
6638
template <typename LocalIndexType, typename GlobalIndexType>
6739
constexpr device_partition<const LocalIndexType, const GlobalIndexType>
@@ -87,6 +59,3 @@ to_device_const(
8759

8860

8961
} // namespace gko
90-
91-
92-
#endif // GINKGO_PARTITION_HPP

core/distributed/index_map.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,15 @@ array<LocalIndexType> index_map<LocalIndexType, GlobalIndexType>::map_to_local(
9393
template <typename LocalIndexType, typename GlobalIndexType>
9494
array<GlobalIndexType>
9595
index_map<LocalIndexType, GlobalIndexType>::map_to_global(
96-
const array<LocalIndexType>& local_ids, index_space index_space_v) const
96+
const array<LocalIndexType>& local_idxs, index_space index_space_v) const
9797
{
98-
array<GlobalIndexType> global_ids(exec_);
98+
array<GlobalIndexType> global_idxs(exec_);
9999

100100
exec_->run(index_map_kernels::make_map_to_global(
101-
to_device(partition_.get()), to_device(remote_global_idxs_), rank_,
102-
local_ids, index_space_v, global_ids));
101+
to_device_const(partition_.get()), to_device(remote_global_idxs_),
102+
rank_, local_idxs, index_space_v, global_idxs));
103103

104-
return global_ids;
104+
return global_idxs;
105105
}
106106

107107

core/distributed/index_map_kernels.hpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@ namespace kernels {
5454
* space defined by is. The resulting indices are stored in local_ids.
5555
* The index map is defined by the input parameters:
5656
*
57-
* - partition: the global partition
57+
* - partition: the global partition
5858
* - remote_target_ids: the owning part ids of each segment of
5959
* remote_global_idxs
6060
* - remote_global_idxs: the remote global indices, segmented by the owning part
61-
* ids
61+
* ids, and each segment sorted
6262
* - rank: the part id of this process
6363
*
6464
* Any global index that is not in the specified local index space is mapped
@@ -81,7 +81,7 @@ namespace kernels {
8181
*
8282
* The relevant input parameter from the index map are:
8383
*
84-
* - partition: the global partition
84+
* - partition: the global partition
8585
* - remote_global_idxs: the remote global indices, segmented by the owning part
8686
* ids
8787
* - rank: the part id of this process
@@ -95,8 +95,8 @@ namespace kernels {
9595
device_partition<const _ltype, const _gtype> partition, \
9696
device_segmented_array<const _gtype> remote_global_idxs, \
9797
experimental::distributed::comm_index_type rank, \
98-
const array<_ltype>& local_ids, \
99-
experimental::distributed::index_space is, array<_gtype>& global_ids)
98+
const array<_ltype>& local_idxs, \
99+
experimental::distributed::index_space is, array<_gtype>& global_idxs)
100100

101101

102102
#define GKO_DECLARE_ALL_AS_TEMPLATES \

core/test/utils/assertions.hpp

+19-6
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <ginkgo/core/base/math.hpp>
2525
#include <ginkgo/core/base/mtx_io.hpp>
2626
#include <ginkgo/core/base/name_demangling.hpp>
27+
#include <ginkgo/core/base/segmented_array.hpp>
2728
#include <ginkgo/core/matrix/dense.hpp>
2829

2930
#include "core/base/batch_utilities.hpp"
@@ -1014,19 +1015,19 @@ ::testing::AssertionResult segmented_array_equal(
10141015
second.get_const_flat_data())
10151016
.copy_to_array();
10161017

1017-
auto buffer_result = array_equal(first_expression, second_expression,
1018-
view_first, view_second);
1019-
if (buffer_result == ::testing::AssertionFailure()) {
1020-
return buffer_result << "Buffers of the segmented arrays mismatch";
1021-
}
1022-
10231018
auto offsets_result =
10241019
array_equal(first_expression, second_expression, first.get_offsets(),
10251020
second.get_offsets());
10261021
if (offsets_result == ::testing::AssertionFailure()) {
10271022
return offsets_result << "Offsets of the segmented arrays mismatch";
10281023
}
10291024

1025+
auto buffer_result = array_equal(first_expression, second_expression,
1026+
view_first, view_second);
1027+
if (buffer_result == ::testing::AssertionFailure()) {
1028+
return buffer_result << "Buffers of the segmented arrays mismatch";
1029+
}
1030+
10301031
return ::testing::AssertionSuccess();
10311032
}
10321033

@@ -1414,6 +1415,18 @@ T* plain_ptr(T* ptr)
14141415
}
14151416

14161417

1418+
/**
1419+
* Checks if two `gko::segmented_array`s are equal.
1420+
*
1421+
* Both the flat array buffer and the offsets of both arrays are tested
1422+
* for equality.
1423+
*
1424+
* Has to be called from within a google test unit test.
1425+
* Internally calls gko::test::assertions::segmented_array_equal().
1426+
*
1427+
* @param _array1 first segmented array
1428+
* @param _array2 second segmented array
1429+
*/
14171430
#define GKO_ASSERT_SEGMENTED_ARRAY_EQ(_array1, _array2) \
14181431
{ \
14191432
ASSERT_PRED_FORMAT2(::gko::test::assertions::segmented_array_equal, \

core/test/utils/assertions_test.cpp

+50-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -8,6 +8,7 @@
88

99
#include <gtest/gtest.h>
1010

11+
#include <ginkgo/core/base/segmented_array.hpp>
1112
#include <ginkgo/core/matrix/csr.hpp>
1213
#include <ginkgo/core/matrix/dense.hpp>
1314

@@ -218,4 +219,52 @@ TEST_F(ArraysNear, CanUseShortNotation)
218219
}
219220

220221

222+
class SegmentedArraysEqual : public ::testing::Test {
223+
protected:
224+
using array = gko::array<double>;
225+
using iarray = gko::array<gko::int64>;
226+
using segmented_array = gko::segmented_array<double>;
227+
228+
std::shared_ptr<gko::Executor> exec = gko::ReferenceExecutor::create();
229+
230+
segmented_array arr1 = segmented_array::create_from_sizes(
231+
array{exec, {1, 2, 3, 4, 5}}, iarray{exec, {2, 1, 2}});
232+
segmented_array arr2 = segmented_array::create_from_sizes(
233+
array{exec, {1, 2, 3, 4, 5}}, iarray{exec, {2, 1, 2}});
234+
segmented_array arr3 = segmented_array::create_from_sizes(
235+
array{exec, {1, 2, 3, 5, 6}}, iarray{exec, {2, 1, 2}});
236+
segmented_array arr4 = segmented_array::create_from_sizes(
237+
array{exec, {1, 2, 3, 4, 5}}, iarray{exec, {3, 2}});
238+
segmented_array arr5 = segmented_array::create_from_sizes(
239+
array{exec, {1, 2, 3, 4, 5}}, iarray{exec, {1, 2, 2}});
240+
};
241+
242+
243+
TEST_F(SegmentedArraysEqual, SucceedsIfEqual)
244+
{
245+
GKO_ASSERT_SEGMENTED_ARRAY_EQ(arr1, arr2);
246+
}
247+
248+
249+
TEST_F(SegmentedArraysEqual, FailsIfValuesDifferent)
250+
{
251+
ASSERT_PRED_FORMAT2(!::gko::test::assertions::segmented_array_equal, arr1,
252+
arr3);
253+
}
254+
255+
256+
TEST_F(SegmentedArraysEqual, FailsIfOffsetsDifferent1)
257+
{
258+
ASSERT_PRED_FORMAT2(!::gko::test::assertions::segmented_array_equal, arr1,
259+
arr4);
260+
}
261+
262+
263+
TEST_F(SegmentedArraysEqual, FailsIfOffsetsDifferent2)
264+
{
265+
ASSERT_PRED_FORMAT2(!::gko::test::assertions::segmented_array_equal, arr1,
266+
arr5);
267+
}
268+
269+
221270
} // namespace

dpcpp/distributed/index_map_kernels.dp.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ void map_to_global(
5050
device_partition<const LocalIndexType, const GlobalIndexType> partition,
5151
device_segmented_array<const GlobalIndexType> remote_global_idxs,
5252
experimental::distributed::comm_index_type rank,
53-
const array<LocalIndexType>& local_ids,
53+
const array<LocalIndexType>& local_idxs,
5454
experimental::distributed::index_space is,
55-
array<GlobalIndexType>& global_ids) GKO_NOT_IMPLEMENTED;
55+
array<GlobalIndexType>& global_idxs) GKO_NOT_IMPLEMENTED;
5656

5757
GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE(
5858
GKO_DECLARE_INDEX_MAP_MAP_TO_GLOBAL);

dpcpp/distributed/partition_kernels.dp.cpp

+1-4
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,7 @@ void build_ranges_by_part(std::shared_ptr<const DefaultExecutor> exec,
140140

141141
range_ids.resize_and_reset(num_ranges);
142142
auto range_ids_ptr = range_ids.get_data();
143-
// fill range_ids with 0,...,num_ranges - 1
144-
run_kernel(
145-
exec, [] GKO_KERNEL(auto i, auto rid) { rid[i] = i; }, num_ranges,
146-
range_ids_ptr);
143+
components::fill_seq_array(exec, range_ids_ptr, num_ranges);
147144

148145
oneapi::dpl::stable_sort(policy, range_ids_ptr, range_ids_ptr + num_ranges,
149146
[range_parts](const auto rid_a, const auto rid_b) {

include/ginkgo/core/distributed/index_map.hpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,16 @@ struct index_map {
8585
/**
8686
* Maps local indices to global indices
8787
*
88-
* @param local_ids the local indices to map
88+
* @param local_idxs the local indices to map
8989
* @param index_space_v the index space in which the passed-in local
9090
* indices are defined
9191
*
9292
* @return the mapped global indices. Any local index, that is not in the
9393
* specified index space is mapped to invalid_index
9494
*/
95-
array<GlobalIndexType> map_to_global(const array<LocalIndexType>& local_ids,
96-
index_space index_space_v) const;
95+
array<GlobalIndexType> map_to_global(
96+
const array<LocalIndexType>& local_idxs,
97+
index_space index_space_v) const;
9798

9899
/**
99100
* \brief get size of index_space::local

0 commit comments

Comments
 (0)