Skip to content

Commit 6a991bc

Browse files
committed
update minres to new solver base
1 parent 16fc9af commit 6a991bc

File tree

6 files changed

+70
-116
lines changed

6 files changed

+70
-116
lines changed

common/unified/solver/minres_kernels.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ void initialize(
7474
matrix::Dense<ValueType>* cos_prev, matrix::Dense<ValueType>* cos,
7575
matrix::Dense<ValueType>* sin_prev, matrix::Dense<ValueType>* sin,
7676
matrix::Dense<ValueType>* eta_next, matrix::Dense<ValueType>* eta,
77-
Array<stopping_status>* stop_status)
77+
array<stopping_status>* stop_status)
7878
{
7979
run_kernel(
8080
exec,
@@ -137,7 +137,7 @@ void step_1(std::shared_ptr<const DefaultExecutor> exec,
137137
matrix::Dense<ValueType>* sin_prev, matrix::Dense<ValueType>* sin,
138138
matrix::Dense<ValueType>* eta, matrix::Dense<ValueType>* eta_next,
139139
typename matrix::Dense<ValueType>::absolute_type* tau,
140-
const Array<stopping_status>* stop_status)
140+
const array<stopping_status>* stop_status)
141141
{
142142
run_kernel(
143143
exec,
@@ -183,7 +183,7 @@ void step_2(std::shared_ptr<const DefaultExecutor> exec,
183183
matrix::Dense<ValueType>* beta, matrix::Dense<ValueType>* gamma,
184184
matrix::Dense<ValueType>* delta, matrix::Dense<ValueType>* cos,
185185
matrix::Dense<ValueType>* eta,
186-
const Array<stopping_status>* stop_status)
186+
const array<stopping_status>* stop_status)
187187
{
188188
run_kernel_solver(
189189
exec,

core/solver/minres.cpp

+11-9
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ std::unique_ptr<LinOp> Minres<ValueType>::transpose() const
6464
return build()
6565
.with_generated_preconditioner(
6666
share(as<Transposable>(this->get_preconditioner())->transpose()))
67-
.with_criteria(this->stop_criterion_factory_)
67+
.with_criteria(this->get_stop_criterion_factory())
6868
.on(this->get_executor())
6969
->generate(
7070
share(as<Transposable>(this->get_system_matrix())->transpose()));
@@ -77,7 +77,7 @@ std::unique_ptr<LinOp> Minres<ValueType>::conj_transpose() const
7777
return build()
7878
.with_generated_preconditioner(share(
7979
as<Transposable>(this->get_preconditioner())->conj_transpose()))
80-
.with_criteria(this->stop_criterion_factory_)
80+
.with_criteria(this->get_stop_criterion_factory())
8181
.on(this->get_executor())
8282
->generate(share(
8383
as<Transposable>(this->get_system_matrix())->conj_transpose()));
@@ -157,21 +157,22 @@ void Minres<ValueType>::apply_dense_impl(
157157
auto sin = Vector::create_with_config_of(alpha.get());
158158

159159
bool one_changed{};
160-
Array<stopping_status> stop_status(alpha->get_executor(),
160+
array<stopping_status> stop_status(alpha->get_executor(),
161161
dense_b->get_size()[1]);
162162

163163
// r = dense_b
164164
r = gko::clone(dense_b);
165-
system_matrix_->apply(neg_one_op.get(), dense_x, one_op.get(), r.get());
166-
auto stop_criterion = stop_criterion_factory_->generate(
167-
system_matrix_,
165+
this->get_system_matrix()->apply(neg_one_op.get(), dense_x, one_op.get(),
166+
r.get());
167+
auto stop_criterion = this->get_stop_criterion_factory()->generate(
168+
this->get_system_matrix(),
168169
std::shared_ptr<const LinOp>(dense_b, [](const LinOp*) {}), dense_x,
169170
r.get());
170171

171172
// z = M^-1 * r
172173
// beta = <r, z>
173174
// tau = ||z||_2
174-
get_preconditioner()->apply(r.get(), z.get());
175+
this->get_preconditioner()->apply(r.get(), z.get());
175176
r->compute_conj_dot(z.get(), beta.get());
176177
z->compute_norm2(tau.get());
177178

@@ -214,10 +215,11 @@ void Minres<ValueType>::apply_dense_impl(
214215
// v = v - alpha * q
215216
// z_tilde = M * v
216217
// beta = <v, z_tilde>
217-
system_matrix_->apply(one_op.get(), z.get(), neg_one_op.get(), v.get());
218+
this->get_system_matrix()->apply(one_op.get(), z.get(),
219+
neg_one_op.get(), v.get());
218220
v->compute_conj_dot(z.get(), alpha.get());
219221
v->sub_scaled(alpha.get(), q.get());
220-
get_preconditioner()->apply(v.get(), z_tilde.get());
222+
this->get_preconditioner()->apply(v.get(), z_tilde.get());
221223
v->compute_conj_dot(z_tilde.get(), beta.get());
222224

223225
// Updates scalars (row vectors)

core/solver/minres_kernels.hpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ namespace cg {
6262
matrix::Dense<_type>* cos_prev, matrix::Dense<_type>* cos, \
6363
matrix::Dense<_type>* sin_prev, matrix::Dense<_type>* sin, \
6464
matrix::Dense<_type>* eta_next, matrix::Dense<_type>* eta, \
65-
Array<stopping_status>* stop_status)
65+
array<stopping_status>* stop_status)
6666

6767

6868
#define GKO_DECLARE_MINRES_STEP_1_KERNEL(_type) \
@@ -73,7 +73,7 @@ namespace cg {
7373
matrix::Dense<_type>* sin_prev, matrix::Dense<_type>* sin, \
7474
matrix::Dense<_type>* eta, matrix::Dense<_type>* eta_next, \
7575
typename matrix::Dense<_type>::absolute_type* tau, \
76-
const Array<stopping_status>* stop_status)
76+
const array<stopping_status>* stop_status)
7777

7878
#define GKO_DECLARE_MINRES_STEP_2_KERNEL(_type) \
7979
void step_2(std::shared_ptr<const DefaultExecutor> exec, \
@@ -84,7 +84,7 @@ namespace cg {
8484
matrix::Dense<_type>* alpha, matrix::Dense<_type>* beta, \
8585
matrix::Dense<_type>* gamma, matrix::Dense<_type>* delta, \
8686
matrix::Dense<_type>* cos, matrix::Dense<_type>* eta, \
87-
const Array<stopping_status>* stop_status)
87+
const array<stopping_status>* stop_status)
8888

8989

9090
#define GKO_DECLARE_ALL_AS_TEMPLATES \

include/ginkgo/core/solver/minres.hpp

+9-59
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4545
#include <ginkgo/core/log/logger.hpp>
4646
#include <ginkgo/core/matrix/dense.hpp>
4747
#include <ginkgo/core/matrix/identity.hpp>
48+
#include <ginkgo/core/solver/solver_base.hpp>
4849
#include <ginkgo/core/stop/combined.hpp>
4950
#include <ginkgo/core/stop/criterion.hpp>
5051

@@ -69,26 +70,17 @@ namespace solver {
6970
* @ingroup LinOp
7071
*/
7172
template <typename ValueType = default_precision>
72-
class Minres : public EnableLinOp<Minres<ValueType>>,
73-
public Preconditionable,
74-
public Transposable {
73+
class Minres
74+
: public EnableLinOp<Minres<ValueType>>,
75+
public EnablePreconditionedIterativeSolver<ValueType, Minres<ValueType>>,
76+
public Transposable {
7577
friend class EnableLinOp<Minres>;
7678
friend class EnablePolymorphicObject<Minres, LinOp>;
7779

7880
public:
7981
using value_type = ValueType;
8082
using transposed_type = Minres<ValueType>;
8183

82-
/**
83-
* Gets the system operator (matrix) of the linear system.
84-
*
85-
* @return the system operator (matrix)
86-
*/
87-
std::shared_ptr<const LinOp> get_system_matrix() const
88-
{
89-
return system_matrix_;
90-
}
91-
9284
std::unique_ptr<LinOp> transpose() const override;
9385

9486
std::unique_ptr<LinOp> conj_transpose() const override;
@@ -100,28 +92,6 @@ class Minres : public EnableLinOp<Minres<ValueType>>,
10092
*/
10193
bool apply_uses_initial_guess() const override { return true; }
10294

103-
/**
104-
* Gets the stopping criterion factory of the solver.
105-
*
106-
* @return the stopping criterion factory
107-
*/
108-
std::shared_ptr<const stop::CriterionFactory> get_stop_criterion_factory()
109-
const
110-
{
111-
return stop_criterion_factory_;
112-
}
113-
114-
/**
115-
* Sets the stopping criterion of the solver.
116-
*
117-
* @param other the new stopping criterion factory
118-
*/
119-
void set_stop_criterion_factory(
120-
std::shared_ptr<const stop::CriterionFactory> other)
121-
{
122-
stop_criterion_factory_ = std::move(other);
123-
}
124-
12595
GKO_CREATE_FACTORY_PARAMETERS(parameters, Factory)
12696
{
12797
/**
@@ -163,30 +133,10 @@ class Minres : public EnableLinOp<Minres<ValueType>>,
163133
std::shared_ptr<const LinOp> system_matrix)
164134
: EnableLinOp<Minres>(factory->get_executor(),
165135
gko::transpose(system_matrix->get_size())),
166-
parameters_{factory->get_parameters()},
167-
system_matrix_{std::move(system_matrix)}
168-
{
169-
GKO_ASSERT_IS_SQUARE_MATRIX(system_matrix_);
170-
if (parameters_.generated_preconditioner) {
171-
GKO_ASSERT_EQUAL_DIMENSIONS(parameters_.generated_preconditioner,
172-
this);
173-
Preconditionable::set_preconditioner(
174-
parameters_.generated_preconditioner);
175-
} else if (parameters_.preconditioner) {
176-
Preconditionable::set_preconditioner(
177-
parameters_.preconditioner->generate(system_matrix_));
178-
} else {
179-
Preconditionable::set_preconditioner(
180-
matrix::Identity<ValueType>::create(this->get_executor(),
181-
this->get_size()));
182-
}
183-
stop_criterion_factory_ =
184-
stop::combine(std::move(parameters_.criteria));
185-
}
186-
187-
private:
188-
std::shared_ptr<const LinOp> system_matrix_{};
189-
std::shared_ptr<const stop::CriterionFactory> stop_criterion_factory_{};
136+
EnablePreconditionedIterativeSolver<ValueType, Minres<ValueType>>{
137+
std::move(system_matrix), factory->get_parameters()},
138+
parameters_{factory->get_parameters()}
139+
{}
190140
};
191141

192142

reference/solver/minres_kernels.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ void initialize(
6161
matrix::Dense<ValueType>* cos_prev, matrix::Dense<ValueType>* cos,
6262
matrix::Dense<ValueType>* sin_prev, matrix::Dense<ValueType>* sin,
6363
matrix::Dense<ValueType>* eta_next, matrix::Dense<ValueType>* eta,
64-
Array<stopping_status>* stop_status)
64+
array<stopping_status>* stop_status)
6565
{
6666
for (size_type j = 0; j < r->get_size()[1]; ++j) {
6767
delta->at(j) = gamma->at(j) = cos_prev->at(j) = sin_prev->at(j) =
@@ -110,7 +110,7 @@ void step_1(std::shared_ptr<const DefaultExecutor> exec,
110110
matrix::Dense<ValueType>* sin_prev, matrix::Dense<ValueType>* sin,
111111
matrix::Dense<ValueType>* eta, matrix::Dense<ValueType>* eta_next,
112112
typename matrix::Dense<ValueType>::absolute_type* tau,
113-
const Array<stopping_status>* stop_status)
113+
const array<stopping_status>* stop_status)
114114
{
115115
for (size_type j = 0; j < alpha->get_size()[1]; ++j) {
116116
if (stop_status->get_const_data()[j].has_stopped()) {
@@ -149,7 +149,7 @@ void step_2(std::shared_ptr<const DefaultExecutor> exec,
149149
matrix::Dense<ValueType>* beta, matrix::Dense<ValueType>* gamma,
150150
matrix::Dense<ValueType>* delta, matrix::Dense<ValueType>* cos,
151151
matrix::Dense<ValueType>* eta,
152-
const Array<stopping_status>* stop_status)
152+
const array<stopping_status>* stop_status)
153153
{
154154
for (size_type i = 0; i < x->get_size()[0]; ++i) {
155155
for (size_type j = 0; j < x->get_size()[1]; ++j) {

test/solver/minres_kernels.cpp

+41-39
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4949

5050
#include "core/solver/minres_kernels.hpp"
5151
#include "core/test/utils.hpp"
52-
#include "core/test/utils/matrix_utils.hpp"
52+
#include "core/utils/matrix_utils.hpp"
5353
#include "test/utils/executor.hpp"
5454

5555
namespace {
@@ -80,14 +80,19 @@ class Minres : public ::testing::Test {
8080
}
8181

8282
std::unique_ptr<Mtx> gen_mtx(gko::size_type num_rows,
83-
gko::size_type num_cols, gko::size_type stride)
83+
gko::size_type num_cols, gko::size_type stride,
84+
bool make_hermitian)
8485
{
85-
auto tmp_mtx = gko::test::generate_random_matrix<Mtx>(
86+
auto tmp_mtx = gko::test::generate_random_matrix_data<value_type , gko::int32>(
8687
num_rows, num_cols,
8788
std::uniform_int_distribution<>(num_cols, num_cols),
88-
std::normal_distribution<value_type>(-1.0, 1.0), rand_engine, ref);
89+
std::normal_distribution<value_type>(-1.0, 1.0), rand_engine);
90+
if (make_hermitian) {
91+
gko::utils::make_unit_diagonal(tmp_mtx);
92+
gko::utils::make_hermitian(tmp_mtx);
93+
}
8994
auto result = Mtx::create(ref, gko::dim<2>{num_rows, num_cols}, stride);
90-
result->copy_from(tmp_mtx.get());
95+
result->read(tmp_mtx);
9196
return result;
9297
}
9398

@@ -96,31 +101,31 @@ class Minres : public ::testing::Test {
96101
gko::size_type m = 597;
97102
gko::size_type n = 43;
98103
// all vectors need the same stride as b, except x
99-
b = gen_mtx(m, n, n + 2);
100-
r = gen_mtx(m, n, n + 2);
101-
z = gen_mtx(m, n, n + 2);
102-
z_tilde = gen_mtx(m, n, n + 2);
103-
p = gen_mtx(m, n, n + 2);
104-
p_prev = gen_mtx(m, n, n + 2);
105-
q = gen_mtx(m, n, n + 2);
106-
q_prev = gen_mtx(m, n, n + 2);
107-
v = gen_mtx(m, n, n + 2);
108-
x = gen_mtx(m, n, n + 3);
109-
alpha = gen_mtx(1, n, n);
110-
beta = gen_mtx(1, n, n)->compute_absolute();
111-
gamma = gen_mtx(1, n, n);
112-
delta = gen_mtx(1, n, n);
113-
cos_prev = gen_mtx(1, n, n);
114-
cos = gen_mtx(1, n, n);
115-
sin_prev = gen_mtx(1, n, n);
116-
sin = gen_mtx(1, n, n);
117-
eta_next = gen_mtx(1, n, n);
118-
eta = gen_mtx(1, n, n);
119-
tau = gen_mtx(1, n, n)->compute_absolute();
104+
b = gen_mtx(m, n, n + 2, false);
105+
r = gen_mtx(m, n, n + 2, false);
106+
z = gen_mtx(m, n, n + 2, false);
107+
z_tilde = gen_mtx(m, n, n + 2, false);
108+
p = gen_mtx(m, n, n + 2, false);
109+
p_prev = gen_mtx(m, n, n + 2, false);
110+
q = gen_mtx(m, n, n + 2, false);
111+
q_prev = gen_mtx(m, n, n + 2, false);
112+
v = gen_mtx(m, n, n + 2, false);
113+
x = gen_mtx(m, n, n + 3, false);
114+
alpha = gen_mtx(1, n, n, false);
115+
beta = gen_mtx(1, n, n, false)->compute_absolute();
116+
gamma = gen_mtx(1, n, n, false);
117+
delta = gen_mtx(1, n, n, false);
118+
cos_prev = gen_mtx(1, n, n, false);
119+
cos = gen_mtx(1, n, n, false);
120+
sin_prev = gen_mtx(1, n, n, false);
121+
sin = gen_mtx(1, n, n, false);
122+
eta_next = gen_mtx(1, n, n, false);
123+
eta = gen_mtx(1, n, n, false);
124+
tau = gen_mtx(1, n, n, false)->compute_absolute();
120125
// check correct handling for zero values
121126
beta->at(2) = gko::zero<value_type>();
122127
stop_status =
123-
std::make_unique<gko::Array<gko::stopping_status>>(ref, n);
128+
std::make_unique<gko::array<gko::stopping_status>>(ref, n);
124129
for (size_t i = 0; i < stop_status->get_num_elems(); ++i) {
125130
stop_status->get_data()[i].reset();
126131
}
@@ -148,7 +153,7 @@ class Minres : public ::testing::Test {
148153
d_cos = gko::clone(exec, cos);
149154
d_sin_prev = gko::clone(exec, sin_prev);
150155
d_sin = gko::clone(exec, sin);
151-
d_stop_status = std::make_unique<gko::Array<gko::stopping_status>>(
156+
d_stop_status = std::make_unique<gko::array<gko::stopping_status>>(
152157
exec, *stop_status);
153158
}
154159

@@ -203,8 +208,8 @@ class Minres : public ::testing::Test {
203208
std::unique_ptr<Mtx> d_sin_prev;
204209
std::unique_ptr<Mtx> d_sin;
205210

206-
std::unique_ptr<gko::Array<gko::stopping_status>> stop_status;
207-
std::unique_ptr<gko::Array<gko::stopping_status>> d_stop_status;
211+
std::unique_ptr<gko::array<gko::stopping_status>> stop_status;
212+
std::unique_ptr<gko::array<gko::stopping_status>> d_stop_status;
208213
};
209214

210215

@@ -296,10 +301,9 @@ TEST_F(Minres, MinresStep2IsEquivalentToStep2)
296301

297302
TEST_F(Minres, ApplyIsEquivalentToRef)
298303
{
299-
auto mtx = gen_mtx(50, 50, 53);
300-
gko::test::make_hermitian(mtx.get());
301-
auto x = gen_mtx(50, 1, 5);
302-
auto b = gen_mtx(50, 1, 4);
304+
auto mtx = gen_mtx(50, 50, 53, true);
305+
auto x = gen_mtx(50, 1, 5, false);
306+
auto b = gen_mtx(50, 1, 4, false);
303307
auto d_mtx = gko::clone(exec, mtx);
304308
auto d_x = gko::clone(exec, x);
305309
auto d_b = gko::clone(exec, b);
@@ -331,11 +335,9 @@ TEST_F(Minres, ApplyIsEquivalentToRef)
331335

332336
TEST_F(Minres, PreconditionedApplyIsEquivalentToRef)
333337
{
334-
335-
auto mtx = gen_mtx(50, 50, 53);
336-
gko::test::make_hpd(mtx.get());
337-
auto x = gen_mtx(50, 1, 5);
338-
auto b = gen_mtx(50, 1, 4);
338+
auto mtx = gen_mtx(50, 50, 53, true);
339+
auto x = gen_mtx(50, 1, 5, false);
340+
auto b = gen_mtx(50, 1, 4, false);
339341
auto d_mtx = gko::clone(exec, mtx);
340342
auto d_x = gko::clone(exec, x);
341343
auto d_b = gko::clone(exec, b);

0 commit comments

Comments
 (0)