Skip to content

Commit 4908ff5

Browse files
MarcelKochyhmtsai
andcommitted
review updates:
- documentation - use templated function - fix test Co-authored-by: Yu-Hsiang M. Tsai <[email protected]>
1 parent b8c3f1f commit 4908ff5

File tree

4 files changed

+33
-41
lines changed

4 files changed

+33
-41
lines changed

core/matrix/dense.cpp

+16-22
Original file line numberDiff line numberDiff line change
@@ -1447,27 +1447,8 @@ void Dense<ValueType>::row_gather(ptr_param<const LinOp> alpha,
14471447

14481448

14491449
template <typename ValueType>
1450-
void Dense<ValueType>::row_scatter(const array<int32>* row_idxs,
1451-
ptr_param<LinOp> row_collection) const
1452-
{
1453-
gather_mixed_real_complex<ValueType>(
1454-
[&](auto dense) { this->row_scatter_impl(row_idxs, dense); },
1455-
row_collection.get());
1456-
}
1457-
1458-
1459-
template <typename ValueType>
1460-
void Dense<ValueType>::row_scatter(const array<int64>* row_idxs,
1461-
ptr_param<LinOp> row_collection) const
1462-
{
1463-
gather_mixed_real_complex<ValueType>(
1464-
[&](auto dense) { this->row_scatter_impl(row_idxs, dense); },
1465-
row_collection.get());
1466-
}
1467-
1468-
1469-
template <typename ValueType>
1470-
void Dense<ValueType>::row_scatter(const index_set<int32>* row_idxs,
1450+
template <typename IndexType>
1451+
void Dense<ValueType>::row_scatter(const array<IndexType>* row_idxs,
14711452
ptr_param<LinOp> row_collection) const
14721453
{
14731454
gather_mixed_real_complex<ValueType>(
@@ -1477,7 +1458,8 @@ void Dense<ValueType>::row_scatter(const index_set<int32>* row_idxs,
14771458

14781459

14791460
template <typename ValueType>
1480-
void Dense<ValueType>::row_scatter(const index_set<int64>* row_idxs,
1461+
template <typename IndexType>
1462+
void Dense<ValueType>::row_scatter(const index_set<IndexType>* row_idxs,
14811463
ptr_param<LinOp> row_collection) const
14821464
{
14831465
gather_mixed_real_complex<ValueType>(
@@ -1786,6 +1768,18 @@ std::unique_ptr<Dense<ValueType>> Dense<ValueType>::create_submatrix_impl(
17861768
#define GKO_DECLARE_DENSE_MATRIX(_type) class Dense<_type>
17871769
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_MATRIX);
17881770

1771+
#define GKO_DECLARE_DENSE_ROW_SCATTER_ARRAY(_vtype, _itype) \
1772+
void Dense<_vtype>::row_scatter(const array<_itype>* row_idxs, \
1773+
ptr_param<LinOp> row_collection) const
1774+
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
1775+
GKO_DECLARE_DENSE_ROW_SCATTER_ARRAY);
1776+
1777+
#define GKO_DECLARE_DENSE_ROW_SCATTER_INDEX_SET(_vtype, _itype) \
1778+
void Dense<_vtype>::row_scatter(const index_set<_itype>* row_idxs, \
1779+
ptr_param<LinOp> row_collection) const
1780+
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
1781+
GKO_DECLARE_DENSE_ROW_SCATTER_INDEX_SET);
1782+
17891783

17901784
} // namespace matrix
17911785
} // namespace gko

include/ginkgo/core/base/index_set.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,8 @@ struct temporary_clone_helper<const index_set<T>> {
471471
return std::make_unique<const index_set<T>>(std::move(exec), *ptr);
472472
}
473473
};
474+
475+
474476
} // namespace detail
475477

476478

include/ginkgo/core/matrix/dense.hpp

+12-16
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,8 @@ class Dense
546546
/**
547547
* Copies this matrix into the given rows of the target matrix.
548548
*
549+
* @tparam IndexType the index type, either int32 or int 64
550+
*
549551
* @param scatter_indices row indices of the target matrix. It must
550552
* have the same number of indices as rows in
551553
* this matrix.
@@ -554,33 +556,27 @@ class Dense
554556
*
555557
* @warning scatter_indices may not contain duplicates, unless if
556558
* for indices `i, j` with `scatter_indices[i] ==
557-
* scatter_indices[j]` the rows `i, j` of this matrix are identical.
559+
* scatter_indices[j]` the rows `i, j` of this matrix are
560+
* identical.
558561
*/
559-
void row_scatter(const array<int64>* scatter_indices,
560-
ptr_param<LinOp> target) const;
561-
562-
/**
563-
* @copydoc row_scatter(const array<int64>*, ptr_param<LinOp>) const
564-
*/
565-
void row_scatter(const array<int32>* scatter_indices,
562+
template <typename IndexType>
563+
void row_scatter(const array<IndexType>* scatter_indices,
566564
ptr_param<LinOp> target) const;
567565

568566
/**
569567
* Copies this matrix into the given rows of the target matrix.
570568
*
569+
* @tparam IndexType the index type, either int32 or int 64
570+
*
571571
* @param scatter_indices row indices of the target matrix. It must
572572
* have the same number of indices as rows in
573573
* this matrix.
574574
* @param target matrix where the scattered rows are stored, i.e.
575-
* `target(scatter_indices[i], j) = this(i, j)`
576-
*/
577-
void row_scatter(const index_set<int64>* scatter_indices,
578-
ptr_param<LinOp> target) const;
579-
580-
/**
581-
* @copydoc row_scatter(const index_set<int64>*, ptr_param<LinOp>) const
575+
* `target(scatter_indices.get_global_index(i), j)
576+
* = this(i, j)`
582577
*/
583-
void row_scatter(const index_set<int32>* scatter_indices,
578+
template <typename IndexType>
579+
void row_scatter(const index_set<IndexType>* scatter_indices,
584580
ptr_param<LinOp> target) const;
585581

586582
std::unique_ptr<LinOp> column_permute(

test/matrix/dense_kernels.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class Dense : public CommonTestFixture {
163163
rscatter_idxs = std::unique_ptr<Arr>(
164164
new Arr{ref, tmp2.begin(), tmp2.begin() + u->get_size()[0]});
165165
rscatter_idxs_sub = std::unique_ptr<Arr>(
166-
new Arr{ref, tmp2.begin(), tmp2.begin() + u->get_size()[0]});
166+
new Arr{ref, tmp4.begin(), tmp4.begin() + u->get_size()[0]});
167167
}
168168

169169
template <typename ConvertedType, typename InputType>
@@ -1324,8 +1324,8 @@ TEST_F(Dense, CanScatterRowsIntoDenseCrossExecutor)
13241324
{
13251325
set_up_apply_data();
13261326

1327-
u->row_scatter(rgather_idxs.get(), x);
1328-
u->row_scatter(rgather_idxs.get(), dx);
1327+
u->row_scatter(rscatter_idxs.get(), x);
1328+
u->row_scatter(rscatter_idxs.get(), dx);
13291329

13301330
GKO_ASSERT_MTX_NEAR(x, dx, 0);
13311331
}

0 commit comments

Comments
 (0)