Skip to content

Commit 23688e2

Browse files
authored
Merge reusable permutation and transpose
This adds a new permute_reuse and transpose_reuse interface to enable reusable operations, i.e. operations that precompute information about their symbolic component, speeding up future applications of value updates. Related PR: #1338
2 parents 25b59ec + e21e7ab commit 23688e2

File tree

5 files changed

+383
-24
lines changed

5 files changed

+383
-24
lines changed

core/matrix/csr.cpp

+138-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -26,6 +26,7 @@
2626
#include "core/components/absolute_array_kernels.hpp"
2727
#include "core/components/fill_array_kernels.hpp"
2828
#include "core/components/format_conversion_kernels.hpp"
29+
#include "core/components/precision_conversion_kernels.hpp"
2930
#include "core/components/prefix_sum_kernels.hpp"
3031
#include "core/matrix/csr_kernels.hpp"
3132
#include "core/matrix/ell_kernels.hpp"
@@ -48,6 +49,7 @@ GKO_REGISTER_OPERATION(spgeam, csr::spgeam);
4849
GKO_REGISTER_OPERATION(convert_idxs_to_ptrs, components::convert_idxs_to_ptrs);
4950
GKO_REGISTER_OPERATION(convert_ptrs_to_idxs, components::convert_ptrs_to_idxs);
5051
GKO_REGISTER_OPERATION(fill_in_dense, csr::fill_in_dense);
52+
GKO_REGISTER_OPERATION(fill_seq_array, components::fill_seq_array);
5153
GKO_REGISTER_OPERATION(compute_slice_sets, sellp::compute_slice_sets);
5254
GKO_REGISTER_OPERATION(convert_to_sellp, csr::convert_to_sellp);
5355
GKO_REGISTER_OPERATION(compute_max_row_nnz, ell::compute_max_row_nnz);
@@ -83,6 +85,7 @@ GKO_REGISTER_OPERATION(is_sorted_by_column_index,
8385
csr::is_sorted_by_column_index);
8486
GKO_REGISTER_OPERATION(extract_diagonal, csr::extract_diagonal);
8587
GKO_REGISTER_OPERATION(fill_array, components::fill_array);
88+
GKO_REGISTER_OPERATION(convert_precision, components::convert_precision);
8689
GKO_REGISTER_OPERATION(prefix_sum_nonnegative,
8790
components::prefix_sum_nonnegative);
8891
GKO_REGISTER_OPERATION(inplace_absolute_array,
@@ -618,6 +621,92 @@ void Csr<ValueType, IndexType>::write(mat_data& data) const
618621
}
619622

620623

624+
template <typename ValueType, typename IndexType, typename TransformClosure>
625+
std::pair<std::unique_ptr<Csr<ValueType, IndexType>>,
626+
typename Csr<ValueType, IndexType>::permuting_reuse_info>
627+
transform_reusable(const Csr<ValueType, IndexType>* input, gko::dim<2> out_size,
628+
size_type nnz, TransformClosure closure)
629+
{
630+
using FloatIndexType =
631+
std::conditional_t<std::is_same_v<IndexType, int32>, float, double>;
632+
static_assert(sizeof(FloatIndexType) == sizeof(IndexType));
633+
static_assert(alignof(FloatIndexType) == alignof(IndexType));
634+
auto exec = input->get_executor();
635+
auto in_size = input->get_size();
636+
auto transformed = Csr<ValueType, IndexType>::create(exec, out_size, nnz);
637+
// transform matrix with integer values from 0 to nnz - 1 reinterpret_cast
638+
// as float
639+
array<IndexType> iota_values{exec, nnz};
640+
exec->run(csr::make_fill_seq_array(iota_values.get_data(), nnz));
641+
auto iota_float_view = make_array_view(
642+
exec, nnz, reinterpret_cast<FloatIndexType*>(iota_values.get_data()));
643+
auto iota_mtx = Csr<FloatIndexType, IndexType>::create_const(
644+
exec, input->get_size(), iota_float_view.as_const_view(),
645+
make_const_array_view(exec, nnz, input->get_const_col_idxs()),
646+
make_const_array_view(exec, in_size[0] + 1,
647+
input->get_const_row_ptrs()),
648+
std::make_shared<typename Csr<FloatIndexType, IndexType>::sparselib>());
649+
auto transformed_iota = closure(iota_mtx.get());
650+
exec->copy(out_size[0] + 1, transformed_iota->get_const_row_ptrs(),
651+
transformed->get_row_ptrs());
652+
exec->copy(nnz, transformed_iota->get_const_col_idxs(),
653+
transformed->get_col_idxs());
654+
exec->copy(nnz,
655+
reinterpret_cast<const IndexType*>(
656+
transformed_iota->get_const_values()),
657+
iota_values.get_data());
658+
auto transform_permutation =
659+
Permutation<IndexType>::create(exec, std::move(iota_values));
660+
transformed->set_strategy(input->get_strategy());
661+
// permute values into output matrix
662+
input->create_const_value_view()->permute(transform_permutation,
663+
transformed->create_value_view(),
664+
permute_mode::rows);
665+
666+
return std::make_pair(
667+
std::move(transformed),
668+
typename Csr<ValueType, IndexType>::permuting_reuse_info{
669+
std::move(transform_permutation)});
670+
}
671+
672+
673+
template <typename ValueType, typename IndexType>
674+
Csr<ValueType, IndexType>::permuting_reuse_info::permuting_reuse_info()
675+
: permuting_reuse_info{nullptr}
676+
{}
677+
678+
679+
template <typename ValueType, typename IndexType>
680+
Csr<ValueType, IndexType>::permuting_reuse_info::permuting_reuse_info(
681+
std::unique_ptr<Permutation<index_type>> value_permutation)
682+
: value_permutation{std::move(value_permutation)}
683+
{}
684+
685+
686+
template <typename ValueType, typename IndexType>
687+
void Csr<ValueType, IndexType>::permuting_reuse_info::update_values(
688+
ptr_param<const Csr> input, ptr_param<Csr> output) const
689+
{
690+
if (!value_permutation) {
691+
GKO_NOT_SUPPORTED(value_permutation);
692+
}
693+
input->create_const_value_view()->permute(
694+
value_permutation, output->create_value_view(), permute_mode::rows);
695+
}
696+
697+
698+
template <typename ValueType, typename IndexType>
699+
auto Csr<ValueType, IndexType>::transpose_reuse() const
700+
-> std::pair<std::unique_ptr<Csr>, Csr::permuting_reuse_info>
701+
{
702+
return transform_reusable(
703+
this, gko::transpose(this->get_size()), this->get_num_stored_elements(),
704+
[](auto mtx) {
705+
return as<gko::detail::pointee<decltype(mtx)>>(mtx->transpose());
706+
});
707+
}
708+
709+
621710
template <typename ValueType, typename IndexType>
622711
std::unique_ptr<LinOp> Csr<ValueType, IndexType>::transpose() const
623712
{
@@ -733,6 +822,31 @@ std::unique_ptr<Csr<ValueType, IndexType>> Csr<ValueType, IndexType>::permute(
733822
}
734823

735824

825+
template <typename ValueType, typename IndexType>
826+
auto Csr<ValueType, IndexType>::permute_reuse(
827+
ptr_param<const Permutation<index_type>> permutation,
828+
permute_mode mode) const
829+
-> std::pair<std::unique_ptr<Csr>, permuting_reuse_info>
830+
{
831+
return transform_reusable(
832+
this, this->get_size(), this->get_num_stored_elements(),
833+
[&](auto mtx) { return mtx->permute(permutation, mode); });
834+
}
835+
836+
837+
template <typename ValueType, typename IndexType>
838+
auto Csr<ValueType, IndexType>::permute_reuse(
839+
ptr_param<const Permutation<index_type>> row_permutation,
840+
ptr_param<const Permutation<index_type>> column_permutation,
841+
bool invert) const -> std::pair<std::unique_ptr<Csr>, permuting_reuse_info>
842+
{
843+
return transform_reusable(
844+
this, this->get_size(), this->get_num_stored_elements(), [&](auto mtx) {
845+
return mtx->permute(row_permutation, column_permutation, invert);
846+
});
847+
}
848+
849+
736850
template <typename ValueType, typename IndexType>
737851
std::unique_ptr<Csr<ValueType, IndexType>>
738852
Csr<ValueType, IndexType>::scale_permute(
@@ -986,6 +1100,29 @@ Csr<ValueType, IndexType>::create_submatrix(
9861100
}
9871101

9881102

1103+
template <typename ValueType, typename IndexType>
1104+
std::unique_ptr<Dense<ValueType>> Csr<ValueType, IndexType>::create_value_view()
1105+
{
1106+
const auto nnz = this->get_num_stored_elements();
1107+
const auto exec = this->get_executor();
1108+
return Dense<ValueType>::create(
1109+
exec, gko::dim<2>{nnz, 1},
1110+
make_array_view(exec, nnz, this->get_values()), 1);
1111+
}
1112+
1113+
1114+
template <typename ValueType, typename IndexType>
1115+
std::unique_ptr<const Dense<ValueType>>
1116+
Csr<ValueType, IndexType>::create_const_value_view() const
1117+
{
1118+
const auto nnz = this->get_num_stored_elements();
1119+
const auto exec = this->get_executor();
1120+
return Dense<ValueType>::create_const(
1121+
exec, gko::dim<2>{nnz, 1},
1122+
make_const_array_view(exec, nnz, this->get_const_values()), 1);
1123+
}
1124+
1125+
9891126
template <typename ValueType, typename IndexType>
9901127
std::unique_ptr<Diagonal<ValueType>>
9911128
Csr<ValueType, IndexType>::extract_diagonal() const

core/test/matrix/csr.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -8,6 +8,7 @@
88
#include <ginkgo/core/matrix/csr.hpp>
99

1010
#include "core/test/utils.hpp"
11+
#include "ginkgo/core/base/exception.hpp"
1112

1213

1314
namespace {
@@ -419,4 +420,13 @@ TYPED_TEST(Csr, GeneratesCorrectMatrixData)
419420
}
420421

421422

423+
TYPED_TEST(Csr, PermutingReuseInfoDefaultUpdateException)
424+
{
425+
using Mtx = typename TestFixture::Mtx;
426+
typename Mtx::permuting_reuse_info reuse;
427+
428+
ASSERT_THROW(reuse.update_values(this->mtx, this->mtx), gko::NotSupported);
429+
}
430+
431+
422432
} // namespace

include/ginkgo/core/base/precision_dispatch.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ void precision_dispatch(Function fn, Args*... linops)
8787
* Calls the given function with the given LinOps temporarily converted to
8888
* matrix::Dense<ValueType>* as parameters.
8989
* If ValueType is real and both input vectors are complex, uses
90-
* matrix::Dense::get_real_view() to convert them into real matrices after
90+
* matrix::Dense::create_real_view() to convert them into real matrices after
9191
* precision conversion.
9292
*
9393
* @see precision_dispatch()
@@ -121,7 +121,7 @@ void precision_dispatch_real_complex(Function fn, const LinOp* in, LinOp* out)
121121
* Calls the given function with the given LinOps temporarily converted to
122122
* matrix::Dense<ValueType>* as parameters.
123123
* If ValueType is real and both `in` and `out` are complex, uses
124-
* matrix::Dense::get_real_view() to convert them into real matrices after
124+
* matrix::Dense::create_real_view() to convert them into real matrices after
125125
* precision conversion.
126126
*
127127
* @see precision_dispatch()

include/ginkgo/core/matrix/csr.hpp

+102-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -48,6 +48,9 @@ class Fbcsr;
4848
template <typename ValueType, typename IndexType>
4949
class CsrBuilder;
5050

51+
template <typename IndexType>
52+
class Permutation;
53+
5154

5255
namespace detail {
5356

@@ -754,6 +757,45 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
754757

755758
std::unique_ptr<LinOp> conj_transpose() const override;
756759

760+
/**
761+
* A struct describing a transformation of the matrix that reorders the
762+
* values of the matrix into the transformed matrix.
763+
*/
764+
struct permuting_reuse_info {
765+
/** Creates an empty reuse info. */
766+
explicit permuting_reuse_info();
767+
768+
/** Creates a reuse info structure from its value permutation. */
769+
permuting_reuse_info(
770+
std::unique_ptr<Permutation<index_type>> value_permutation);
771+
772+
/**
773+
* Propagates the values from an input matrix to the transformed matrix.
774+
* The output matrix needs to have been computed using the
775+
* transformation that was also used to generate this reuse data.
776+
* Internally, this permutes the input value vector into the output
777+
* value vector.
778+
*/
779+
void update_values(ptr_param<const Csr> input,
780+
ptr_param<Csr> output) const;
781+
782+
std::unique_ptr<Permutation<IndexType>> value_permutation;
783+
};
784+
785+
/**
786+
* Computes the necessary data to update a transposed matrix from its
787+
* original matrix.
788+
* ```
789+
* auto [transposed, reuse] = matrix->transpose_reuse();
790+
* change_values(matrix);
791+
* reuse->update_values(matrix, transposed);
792+
* ```
793+
* @return the reuse info struct that can be used to update values in the
794+
* transposed matrix.
795+
*/
796+
std::pair<std::unique_ptr<Csr>, permuting_reuse_info> transpose_reuse()
797+
const;
798+
757799
/**
758800
* Creates a permuted copy $A'$ of this matrix $A$ with the given
759801
* permutation $P$. By default, this computes a symmetric permutation
@@ -790,6 +832,53 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
790832
ptr_param<const Permutation<index_type>> column_permutation,
791833
bool invert = false) const;
792834

835+
/**
836+
* Computes the operations necessary to propagate changed values from a
837+
* matrix A to a permuted matrix.
838+
* The semantics of this function match those of
839+
* permute(ptr_param<const Permutation<index_type>>, permute_mode).
840+
* Updating values works as follows:
841+
* ```
842+
* auto [permuted, reuse] = matrix->permute_reuse(permutation, mode);
843+
* change_values(matrix);
844+
* reuse->update_values(matrix, permuted);
845+
* ```
846+
* @param permutation The input permutation.
847+
* @param mode The permutation mode. If permute_mode::inverse is set, we
848+
* use the inverse permutation $P^{-1}$ instead of $P$.
849+
* If permute_mode::rows is set, the rows will be permuted.
850+
* If permute_mode::columns is set, the columns will be
851+
* permuted.
852+
* @return an std::pair consisting of the permuted matrix and the reuse info
853+
* that can be used to update values in the permuted matrix.
854+
*/
855+
std::pair<std::unique_ptr<Csr>, permuting_reuse_info> permute_reuse(
856+
ptr_param<const Permutation<index_type>> permutation,
857+
permute_mode mode = permute_mode::symmetric) const;
858+
859+
/**
860+
* Computes the operations necessary to propagate changed values from a
861+
* matrix A to a permuted matrix.
862+
* The semantics of this function match those of
863+
* permute(ptr_param<const Permutation<index_type>>, ptr_param<const
864+
* Permutation<index_type>>, bool). Updating values works as follows:
865+
* ```
866+
* auto [permuted, reuse] = matrix->permute_reuse(row_perm, col_perm, inv);
867+
* change_values(matrix);
868+
* reuse->update_values(matrix, permuted);
869+
* ```
870+
* @param row_permutation The permutation $P$ to apply to the rows
871+
* @param column_permutation The permutation $Q$ to apply to the columns
872+
* @param invert If set to `false`, uses the input permutations, otherwise
873+
* uses their inverses $P^{-1}, Q^{-1}$
874+
* @return an std::pair consisting of the permuted matrix and the reuse info
875+
* that can be used to update values in the permuted matrix.
876+
*/
877+
std::pair<std::unique_ptr<Csr>, permuting_reuse_info> permute_reuse(
878+
ptr_param<const Permutation<index_type>> row_permutation,
879+
ptr_param<const Permutation<index_type>> column_permutation,
880+
bool invert = false) const;
881+
793882
/**
794883
* Creates a scaled and permuted copy of this matrix.
795884
* For an explanation of the permutation modes, see
@@ -878,6 +967,18 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
878967
return values_.get_const_data();
879968
}
880969

970+
/**
971+
* Creates a Dense view of the value array of this matrix as a column
972+
* vector of dimensions nnz x 1.
973+
*/
974+
std::unique_ptr<Dense<ValueType>> create_value_view();
975+
976+
/**
977+
* Creates a const Dense view of the value array of this matrix as a column
978+
* vector of dimensions nnz x 1.
979+
*/
980+
std::unique_ptr<const Dense<ValueType>> create_const_value_view() const;
981+
881982
/**
882983
* Returns the column indexes of the matrix.
883984
*

0 commit comments

Comments
 (0)