Skip to content

Commit 777448c

Browse files
MarcelKoch, Slaedr, hartwiganzt
committed
add minres implementation
Based on 'Iterative Methods for Solving Linear Systems' (https://doi.org/10.1137/1.9781611970937) and 'Iterative Methods for Singular Linear Equations and Least-Squares Problems' (PhD thesis, Stanford University). Co-authored-by: Aditya Kashi <[email protected]>. Co-authored-by: Hartwig Anzt <[email protected]>.
1 parent a47745f commit 777448c

File tree

16 files changed

+1655
-11
lines changed

16 files changed

+1655
-11
lines changed

benchmark/solver/solver_common.hpp

+8-7
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,14 @@ DEFINE_bool(
3434
rel_residual, false,
3535
"Use relative residual instead of residual reduction stopping criterion");
3636

37-
DEFINE_string(solvers, "cg",
38-
"A comma-separated list of solvers to run. "
39-
"Supported values are: bicgstab, bicg, cb_gmres_keep, "
40-
"cb_gmres_reduce1, cb_gmres_reduce2, cb_gmres_integer, "
41-
"cb_gmres_ireduce1, cb_gmres_ireduce2, cg, cgs, fcg, gmres, idr, "
42-
"lower_trs, upper_trs, spd_direct, symm_direct, "
43-
"near_symm_direct, direct, overhead");
37+
// Command-line flag `solvers`: a comma-separated list of solver names the
// benchmark should run. The default value is "cg"; the help text below lists
// the accepted names, which now include "minres".
DEFINE_string(
38+
solvers, "cg",
39+
"A comma-separated list of solvers to run. "
40+
"Supported values are: bicgstab, bicg, cb_gmres_keep, "
41+
"cb_gmres_reduce1, cb_gmres_reduce2, cb_gmres_integer, "
42+
"cb_gmres_ireduce1, cb_gmres_ireduce2, cg, cgs, direct, fcg, gmres, idr, "
43+
"lower_trs, minres, near_symm_direct, upper_trs, spd_direct, symm_direct, "
44+
"overhead");
4445

4546
DEFINE_uint32(
4647
nrhs, 1,

common/unified/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ set(UNIFIED_SOURCES
2828
solver/gcr_kernels.cpp
2929
solver/gmres_kernels.cpp
3030
solver/ir_kernels.cpp
31+
solver/minres_kernels.cpp
3132
)
3233
list(TRANSFORM UNIFIED_SOURCES PREPEND ${CMAKE_CURRENT_SOURCE_DIR}/)
33-
set(GKO_UNIFIED_COMMON_SOURCES ${UNIFIED_SOURCES} PARENT_SCOPE)
34+
set(GKO_UNIFIED_COMMON_SOURCES ${UNIFIED_SOURCES} PARENT_SCOPE)
+195
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
// SPDX-FileCopyrightText: 2017-2023 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#include "core/solver/minres_kernels.hpp"
6+
7+
8+
#include <ginkgo/core/base/executor.hpp>
9+
10+
11+
#include "common/unified/base/kernel_launch_solver.hpp"
12+
13+
14+
namespace gko {
15+
namespace kernels {
16+
namespace GKO_DEVICE_NAMESPACE {
17+
/**
18+
* @brief The Minres solver namespace.
19+
*
20+
* @ingroup minres
21+
*/
22+
namespace minres {
23+
namespace detail {
24+
25+
26+
template <typename T, typename U>
27+
GKO_INLINE GKO_ATTRIBUTES void swap(T& a, U& b)
28+
{
29+
U tmp{b};
30+
b = a;
31+
a = tmp;
32+
}
33+
34+
35+
} // namespace detail
36+
37+
38+
template <typename ValueType>
39+
void initialize(
40+
std::shared_ptr<const DefaultExecutor> exec,
41+
const matrix::Dense<ValueType>* r, matrix::Dense<ValueType>* z,
42+
matrix::Dense<ValueType>* p, matrix::Dense<ValueType>* p_prev,
43+
matrix::Dense<ValueType>* q, matrix::Dense<ValueType>* q_prev,
44+
matrix::Dense<ValueType>* v, matrix::Dense<ValueType>* beta,
45+
matrix::Dense<ValueType>* gamma, matrix::Dense<ValueType>* delta,
46+
matrix::Dense<ValueType>* cos_prev, matrix::Dense<ValueType>* cos,
47+
matrix::Dense<ValueType>* sin_prev, matrix::Dense<ValueType>* sin,
48+
matrix::Dense<ValueType>* eta_next, matrix::Dense<ValueType>* eta,
49+
array<stopping_status>* stop_status)
50+
{
51+
run_kernel(
52+
exec,
53+
[] GKO_KERNEL(auto col, auto beta, auto gamma, auto delta,
54+
auto cos_prev, auto cos, auto sin_prev, auto sin,
55+
auto eta_next, auto eta, auto stop) {
56+
delta[col] = gamma[col] = cos_prev[col] = sin_prev[col] = sin[col] =
57+
zero(*delta);
58+
cos[col] = one(*delta);
59+
eta_next[col] = eta[col] = beta[col] = sqrt(beta[col]);
60+
stop[col].reset();
61+
},
62+
beta->get_num_stored_elements(), row_vector(beta), row_vector(gamma),
63+
row_vector(delta), row_vector(cos_prev), row_vector(cos),
64+
row_vector(sin_prev), row_vector(sin), row_vector(eta_next),
65+
row_vector(eta), *stop_status);
66+
67+
run_kernel_solver(
68+
exec,
69+
[] GKO_KERNEL(auto row, auto col, auto r, auto z, auto p, auto p_prev,
70+
auto q, auto q_prev, auto v, auto beta, auto stop) {
71+
q(row, col) = safe_divide(r(row, col), beta[col]);
72+
z(row, col) = safe_divide(z(row, col), beta[col]);
73+
p(row, col) = p_prev(row, col) = q_prev(row, col) = v(row, col) =
74+
zero(p(row, col));
75+
},
76+
r->get_size(), r->get_stride(), default_stride(r), default_stride(z),
77+
default_stride(p), default_stride(p_prev), default_stride(q),
78+
default_stride(q_prev), default_stride(v), row_vector(beta),
79+
*stop_status);
80+
}
81+
82+
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_MINRES_INITIALIZE_KERNEL);
83+
84+
85+
template <typename ValueType>
86+
GKO_KERNEL void update_givens_rotation(ValueType& alpha, const ValueType& beta,
87+
ValueType& cos, ValueType& sin)
88+
{
89+
if (alpha == zero(alpha)) {
90+
cos = zero(cos);
91+
sin = one(sin);
92+
} else {
93+
const auto scale = abs(alpha) + abs(beta);
94+
const auto hypotenuse =
95+
scale * sqrt(abs(alpha / scale) * abs(alpha / scale) +
96+
abs(beta / scale) * abs(beta / scale));
97+
cos = conj(alpha) / hypotenuse;
98+
sin = conj(beta) / hypotenuse;
99+
}
100+
alpha = cos * alpha + sin * beta;
101+
}
102+
103+
104+
template <typename ValueType>
105+
void step_1(std::shared_ptr<const DefaultExecutor> exec,
106+
matrix::Dense<ValueType>* alpha, matrix::Dense<ValueType>* beta,
107+
matrix::Dense<ValueType>* gamma, matrix::Dense<ValueType>* delta,
108+
matrix::Dense<ValueType>* cos_prev, matrix::Dense<ValueType>* cos,
109+
matrix::Dense<ValueType>* sin_prev, matrix::Dense<ValueType>* sin,
110+
matrix::Dense<ValueType>* eta, matrix::Dense<ValueType>* eta_next,
111+
typename matrix::Dense<ValueType>::absolute_type* tau,
112+
const array<stopping_status>* stop_status)
113+
{
114+
run_kernel(
115+
exec,
116+
[] GKO_KERNEL(auto col, auto alpha, auto beta, auto gamma, auto delta,
117+
auto cos_prev, auto cos, auto sin_prev, auto sin,
118+
auto eta_next, auto eta, auto tau, auto stop) {
119+
if (!stop[col].has_stopped()) {
120+
beta[col] = sqrt(beta[col]);
121+
delta[col] = sin_prev[col] * gamma[col];
122+
const auto tmp_d = gamma[col];
123+
const auto tmp_a = alpha[col];
124+
gamma[col] =
125+
cos_prev[col] * cos[col] * tmp_d + sin[col] * tmp_a;
126+
alpha[col] =
127+
-conj(sin[col]) * cos_prev[col] * tmp_d + cos[col] * tmp_a;
128+
129+
detail::swap(cos[col], cos_prev[col]);
130+
detail::swap(sin[col], sin_prev[col]);
131+
update_givens_rotation(alpha[col], beta[col], cos[col],
132+
sin[col]);
133+
134+
tau[col] = abs(sin[col]) * tau[col];
135+
eta[col] = eta_next[col];
136+
eta_next[col] = -conj(sin[col]) * eta[col];
137+
}
138+
},
139+
alpha->get_num_stored_elements(), row_vector(alpha), row_vector(beta),
140+
row_vector(gamma), row_vector(delta), row_vector(cos_prev),
141+
row_vector(cos), row_vector(sin_prev), row_vector(sin),
142+
row_vector(eta_next), row_vector(eta), row_vector(tau), *stop_status);
143+
}
144+
145+
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_MINRES_STEP_1_KERNEL);
146+
147+
148+
template <typename ValueType>
149+
void step_2(std::shared_ptr<const DefaultExecutor> exec,
150+
matrix::Dense<ValueType>* x, matrix::Dense<ValueType>* p,
151+
const matrix::Dense<ValueType>* p_prev, matrix::Dense<ValueType>* z,
152+
const matrix::Dense<ValueType>* z_tilde,
153+
matrix::Dense<ValueType>* q, matrix::Dense<ValueType>* q_prev,
154+
matrix::Dense<ValueType>* v, const matrix::Dense<ValueType>* alpha,
155+
const matrix::Dense<ValueType>* beta,
156+
const matrix::Dense<ValueType>* gamma,
157+
const matrix::Dense<ValueType>* delta,
158+
const matrix::Dense<ValueType>* cos,
159+
const matrix::Dense<ValueType>* eta,
160+
const array<stopping_status>* stop_status)
161+
{
162+
run_kernel_solver(
163+
exec,
164+
[] GKO_KERNEL(auto row, auto col, auto x, auto p, auto p_prev, auto q,
165+
auto q_prev, auto v, auto z, auto z_tilde, auto alpha,
166+
auto beta, auto gamma, auto delta, auto cos, auto eta,
167+
auto stop) {
168+
if (!stop[col].has_stopped()) {
169+
p(row, col) =
170+
safe_divide(z(row, col) - gamma[col] * p_prev(row, col) -
171+
delta[col] * p(row, col),
172+
alpha[col]);
173+
x(row, col) = x(row, col) + cos[col] * eta[col] * p(row, col);
174+
175+
q_prev(row, col) = v(row, col);
176+
const auto tmp = q(row, col);
177+
z(row, col) = safe_divide(z_tilde(row, col), beta[col]);
178+
q(row, col) = safe_divide(v(row, col), beta[col]);
179+
v(row, col) = tmp * beta[col];
180+
}
181+
},
182+
x->get_size(), p->get_stride(), x, default_stride(p),
183+
default_stride(p_prev), default_stride(q), default_stride(q_prev),
184+
default_stride(v), default_stride(z), default_stride(z_tilde),
185+
row_vector(alpha), row_vector(beta), row_vector(gamma),
186+
row_vector(delta), row_vector(cos), row_vector(eta), *stop_status);
187+
}
188+
189+
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_MINRES_STEP_2_KERNEL);
190+
191+
192+
} // namespace minres
193+
} // namespace GKO_DEVICE_NAMESPACE
194+
} // namespace kernels
195+
} // namespace gko

core/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ target_sources(ginkgo
7878
solver/gmres.cpp
7979
solver/idr.cpp
8080
solver/ir.cpp
81+
solver/minres.cpp
8182
solver/lower_trs.cpp
8283
solver/multigrid.cpp
8384
solver/upper_trs.cpp

core/device_hooks/common_kernels.inc.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
#include "core/solver/idr_kernels.hpp"
6262
#include "core/solver/ir_kernels.hpp"
6363
#include "core/solver/lower_trs_kernels.hpp"
64+
#include "core/solver/minres_kernels.hpp"
6465
#include "core/solver/multigrid_kernels.hpp"
6566
#include "core/solver/upper_trs_kernels.hpp"
6667
#include "core/stop/criterion_kernels.hpp"
@@ -562,6 +563,17 @@ GKO_STUB_NON_COMPLEX_VALUE_TYPE(GKO_DECLARE_MULTIGRID_KCYCLE_CHECK_STOP_KERNEL);
562563
} // namespace multigrid
563564

564565

566+
// Device-hook entries for the Minres kernels: GKO_STUB_VALUE_TYPE is applied
// to the declarations of the three kernels (initialize, step_1, step_2).
// NOTE(review): presumably the macro generates placeholder definitions for
// every value type, used when a backend is not compiled -- confirm against
// the macro definition.
namespace minres {
567+
568+
569+
GKO_STUB_VALUE_TYPE(GKO_DECLARE_MINRES_INITIALIZE_KERNEL);
570+
GKO_STUB_VALUE_TYPE(GKO_DECLARE_MINRES_STEP_1_KERNEL);
571+
GKO_STUB_VALUE_TYPE(GKO_DECLARE_MINRES_STEP_2_KERNEL);
572+
573+
574+
} // namespace minres
575+
576+
565577
namespace sparsity_csr {
566578

567579

0 commit comments

Comments
 (0)