1
- // SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1
+ // SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2
2
//
3
3
// SPDX-License-Identifier: BSD-3-Clause
4
4
26
26
#include " core/components/absolute_array_kernels.hpp"
27
27
#include " core/components/fill_array_kernels.hpp"
28
28
#include " core/components/format_conversion_kernels.hpp"
29
+ #include " core/components/precision_conversion_kernels.hpp"
29
30
#include " core/components/prefix_sum_kernels.hpp"
30
31
#include " core/matrix/csr_kernels.hpp"
31
32
#include " core/matrix/ell_kernels.hpp"
@@ -48,6 +49,7 @@ GKO_REGISTER_OPERATION(spgeam, csr::spgeam);
48
49
// Kernel operations used by this file. The first macro argument is the local
// operation name; the second is the kernel it refers to. Code further down
// invokes these through a `make_<name>` helper (e.g. `make_fill_seq_array`).
GKO_REGISTER_OPERATION(convert_idxs_to_ptrs, components::convert_idxs_to_ptrs);
GKO_REGISTER_OPERATION(convert_ptrs_to_idxs, components::convert_ptrs_to_idxs);
GKO_REGISTER_OPERATION(fill_in_dense, csr::fill_in_dense);
GKO_REGISTER_OPERATION(fill_seq_array, components::fill_seq_array);
GKO_REGISTER_OPERATION(compute_slice_sets, sellp::compute_slice_sets);
GKO_REGISTER_OPERATION(convert_to_sellp, csr::convert_to_sellp);
GKO_REGISTER_OPERATION(compute_max_row_nnz, ell::compute_max_row_nnz);
@@ -83,6 +85,7 @@ GKO_REGISTER_OPERATION(is_sorted_by_column_index,
83
85
csr::is_sorted_by_column_index);
84
86
// Diagonal extraction plus generic array helpers (fill, precision conversion).
GKO_REGISTER_OPERATION(extract_diagonal, csr::extract_diagonal);
GKO_REGISTER_OPERATION(fill_array, components::fill_array);
GKO_REGISTER_OPERATION(convert_precision, components::convert_precision);
86
89
GKO_REGISTER_OPERATION (prefix_sum_nonnegative,
87
90
components::prefix_sum_nonnegative);
88
91
GKO_REGISTER_OPERATION (inplace_absolute_array,
@@ -618,6 +621,92 @@ void Csr<ValueType, IndexType>::write(mat_data& data) const
618
621
}
619
622
620
623
624
+ template <typename ValueType, typename IndexType, typename TransformClosure>
625
+ std::pair<std::unique_ptr<Csr<ValueType, IndexType>>,
626
+ typename Csr<ValueType, IndexType>::permuting_reuse_info>
627
+ transform_reusable (const Csr<ValueType, IndexType>* input, gko::dim<2 > out_size,
628
+ size_type nnz, TransformClosure closure)
629
+ {
630
+ using FloatIndexType =
631
+ std::conditional_t <std::is_same_v<IndexType, int32>, float , double >;
632
+ static_assert (sizeof (FloatIndexType) == sizeof (IndexType));
633
+ static_assert (alignof (FloatIndexType) == alignof (IndexType));
634
+ auto exec = input->get_executor ();
635
+ auto in_size = input->get_size ();
636
+ auto transformed = Csr<ValueType, IndexType>::create (exec, out_size, nnz);
637
+ // transform matrix with integer values from 0 to nnz - 1 reinterpret_cast
638
+ // as float
639
+ array<IndexType> iota_values{exec, nnz};
640
+ exec->run (csr::make_fill_seq_array (iota_values.get_data (), nnz));
641
+ auto iota_float_view = make_array_view (
642
+ exec, nnz, reinterpret_cast <FloatIndexType*>(iota_values.get_data ()));
643
+ auto iota_mtx = Csr<FloatIndexType, IndexType>::create_const (
644
+ exec, input->get_size (), iota_float_view.as_const_view (),
645
+ make_const_array_view (exec, nnz, input->get_const_col_idxs ()),
646
+ make_const_array_view (exec, in_size[0 ] + 1 ,
647
+ input->get_const_row_ptrs ()),
648
+ std::make_shared<typename Csr<FloatIndexType, IndexType>::sparselib>());
649
+ auto transformed_iota = closure (iota_mtx.get ());
650
+ exec->copy (out_size[0 ] + 1 , transformed_iota->get_const_row_ptrs (),
651
+ transformed->get_row_ptrs ());
652
+ exec->copy (nnz, transformed_iota->get_const_col_idxs (),
653
+ transformed->get_col_idxs ());
654
+ exec->copy (nnz,
655
+ reinterpret_cast <const IndexType*>(
656
+ transformed_iota->get_const_values ()),
657
+ iota_values.get_data ());
658
+ auto transform_permutation =
659
+ Permutation<IndexType>::create (exec, std::move (iota_values));
660
+ transformed->set_strategy (input->get_strategy ());
661
+ // permute values into output matrix
662
+ input->create_const_value_view ()->permute (transform_permutation,
663
+ transformed->create_value_view (),
664
+ permute_mode::rows);
665
+
666
+ return std::make_pair (
667
+ std::move (transformed),
668
+ typename Csr<ValueType, IndexType>::permuting_reuse_info{
669
+ std::move (transform_permutation)});
670
+ }
671
+
672
+
673
+ template <typename ValueType, typename IndexType>
674
+ Csr<ValueType, IndexType>::permuting_reuse_info::permuting_reuse_info()
675
+ : permuting_reuse_info{nullptr }
676
+ {}
677
+
678
+
679
+ template <typename ValueType, typename IndexType>
680
+ Csr<ValueType, IndexType>::permuting_reuse_info::permuting_reuse_info(
681
+ std::unique_ptr<Permutation<index_type>> value_permutation)
682
+ : value_permutation{std::move (value_permutation)}
683
+ {}
684
+
685
+
686
+ template <typename ValueType, typename IndexType>
687
+ void Csr<ValueType, IndexType>::permuting_reuse_info::update_values(
688
+ ptr_param<const Csr> input, ptr_param<Csr> output) const
689
+ {
690
+ if (!value_permutation) {
691
+ GKO_NOT_SUPPORTED (value_permutation);
692
+ }
693
+ input->create_const_value_view ()->permute (
694
+ value_permutation, output->create_value_view (), permute_mode::rows);
695
+ }
696
+
697
+
698
+ template <typename ValueType, typename IndexType>
699
+ auto Csr<ValueType, IndexType>::transpose_reuse() const
700
+ -> std::pair<std::unique_ptr<Csr>, Csr::permuting_reuse_info>
701
+ {
702
+ return transform_reusable (
703
+ this , gko::transpose (this ->get_size ()), this ->get_num_stored_elements (),
704
+ [](auto mtx) {
705
+ return as<gko::detail::pointee<decltype (mtx)>>(mtx->transpose ());
706
+ });
707
+ }
708
+
709
+
621
710
template <typename ValueType, typename IndexType>
622
711
std::unique_ptr<LinOp> Csr<ValueType, IndexType>::transpose() const
623
712
{
@@ -733,6 +822,31 @@ std::unique_ptr<Csr<ValueType, IndexType>> Csr<ValueType, IndexType>::permute(
733
822
}
734
823
735
824
825
+ template <typename ValueType, typename IndexType>
826
+ auto Csr<ValueType, IndexType>::permute_reuse(
827
+ ptr_param<const Permutation<index_type>> permutation,
828
+ permute_mode mode) const
829
+ -> std::pair<std::unique_ptr<Csr>, permuting_reuse_info>
830
+ {
831
+ return transform_reusable (
832
+ this , this ->get_size (), this ->get_num_stored_elements (),
833
+ [&](auto mtx) { return mtx->permute (permutation, mode); });
834
+ }
835
+
836
+
837
+ template <typename ValueType, typename IndexType>
838
+ auto Csr<ValueType, IndexType>::permute_reuse(
839
+ ptr_param<const Permutation<index_type>> row_permutation,
840
+ ptr_param<const Permutation<index_type>> column_permutation,
841
+ bool invert) const -> std::pair<std::unique_ptr<Csr>, permuting_reuse_info>
842
+ {
843
+ return transform_reusable (
844
+ this , this ->get_size (), this ->get_num_stored_elements (), [&](auto mtx) {
845
+ return mtx->permute (row_permutation, column_permutation, invert);
846
+ });
847
+ }
848
+
849
+
736
850
template <typename ValueType, typename IndexType>
737
851
std::unique_ptr<Csr<ValueType, IndexType>>
738
852
Csr<ValueType, IndexType>::scale_permute(
@@ -986,6 +1100,29 @@ Csr<ValueType, IndexType>::create_submatrix(
986
1100
}
987
1101
988
1102
1103
+ template <typename ValueType, typename IndexType>
1104
+ std::unique_ptr<Dense<ValueType>> Csr<ValueType, IndexType>::create_value_view()
1105
+ {
1106
+ const auto nnz = this ->get_num_stored_elements ();
1107
+ const auto exec = this ->get_executor ();
1108
+ return Dense<ValueType>::create (
1109
+ exec, gko::dim<2 >{nnz, 1 },
1110
+ make_array_view (exec, nnz, this ->get_values ()), 1 );
1111
+ }
1112
+
1113
+
1114
+ template <typename ValueType, typename IndexType>
1115
+ std::unique_ptr<const Dense<ValueType>>
1116
+ Csr<ValueType, IndexType>::create_const_value_view() const
1117
+ {
1118
+ const auto nnz = this ->get_num_stored_elements ();
1119
+ const auto exec = this ->get_executor ();
1120
+ return Dense<ValueType>::create_const (
1121
+ exec, gko::dim<2 >{nnz, 1 },
1122
+ make_const_array_view (exec, nnz, this ->get_const_values ()), 1 );
1123
+ }
1124
+
1125
+
989
1126
template <typename ValueType, typename IndexType>
990
1127
std::unique_ptr<Diagonal<ValueType>>
991
1128
Csr<ValueType, IndexType>::extract_diagonal() const
0 commit comments