|
| 1 | +// SPDX-FileCopyrightText: 2017-2023 The Ginkgo authors |
| 2 | +// |
| 3 | +// SPDX-License-Identifier: BSD-3-Clause |
| 4 | + |
| 5 | +#include "core/solver/minres_kernels.hpp" |
| 6 | + |
| 7 | + |
| 8 | +#include <ginkgo/core/base/executor.hpp> |
| 9 | + |
| 10 | + |
| 11 | +#include "common/unified/base/kernel_launch_solver.hpp" |
| 12 | + |
| 13 | + |
| 14 | +namespace gko { |
| 15 | +namespace kernels { |
| 16 | +namespace GKO_DEVICE_NAMESPACE { |
| 17 | +/** |
| 18 | + * @brief The Minres solver namespace. |
| 19 | + * |
| 20 | + * @ingroup minres |
| 21 | + */ |
| 22 | +namespace minres { |
| 23 | +namespace detail { |
| 24 | + |
| 25 | + |
| 26 | +template <typename T, typename U> |
| 27 | +GKO_INLINE GKO_ATTRIBUTES void swap(T& a, U& b) |
| 28 | +{ |
| 29 | + U tmp{b}; |
| 30 | + b = a; |
| 31 | + a = tmp; |
| 32 | +} |
| 33 | + |
| 34 | + |
| 35 | +} // namespace detail |
| 36 | + |
| 37 | + |
| 38 | +template <typename ValueType> |
| 39 | +void initialize( |
| 40 | + std::shared_ptr<const DefaultExecutor> exec, |
| 41 | + const matrix::Dense<ValueType>* r, matrix::Dense<ValueType>* z, |
| 42 | + matrix::Dense<ValueType>* p, matrix::Dense<ValueType>* p_prev, |
| 43 | + matrix::Dense<ValueType>* q, matrix::Dense<ValueType>* q_prev, |
| 44 | + matrix::Dense<ValueType>* v, matrix::Dense<ValueType>* beta, |
| 45 | + matrix::Dense<ValueType>* gamma, matrix::Dense<ValueType>* delta, |
| 46 | + matrix::Dense<ValueType>* cos_prev, matrix::Dense<ValueType>* cos, |
| 47 | + matrix::Dense<ValueType>* sin_prev, matrix::Dense<ValueType>* sin, |
| 48 | + matrix::Dense<ValueType>* eta_next, matrix::Dense<ValueType>* eta, |
| 49 | + array<stopping_status>* stop_status) |
| 50 | +{ |
| 51 | + run_kernel( |
| 52 | + exec, |
| 53 | + [] GKO_KERNEL(auto col, auto beta, auto gamma, auto delta, |
| 54 | + auto cos_prev, auto cos, auto sin_prev, auto sin, |
| 55 | + auto eta_next, auto eta, auto stop) { |
| 56 | + delta[col] = gamma[col] = cos_prev[col] = sin_prev[col] = sin[col] = |
| 57 | + zero(*delta); |
| 58 | + cos[col] = one(*delta); |
| 59 | + eta_next[col] = eta[col] = beta[col] = sqrt(beta[col]); |
| 60 | + stop[col].reset(); |
| 61 | + }, |
| 62 | + beta->get_num_stored_elements(), row_vector(beta), row_vector(gamma), |
| 63 | + row_vector(delta), row_vector(cos_prev), row_vector(cos), |
| 64 | + row_vector(sin_prev), row_vector(sin), row_vector(eta_next), |
| 65 | + row_vector(eta), *stop_status); |
| 66 | + |
| 67 | + run_kernel_solver( |
| 68 | + exec, |
| 69 | + [] GKO_KERNEL(auto row, auto col, auto r, auto z, auto p, auto p_prev, |
| 70 | + auto q, auto q_prev, auto v, auto beta, auto stop) { |
| 71 | + q(row, col) = safe_divide(r(row, col), beta[col]); |
| 72 | + z(row, col) = safe_divide(z(row, col), beta[col]); |
| 73 | + p(row, col) = p_prev(row, col) = q_prev(row, col) = v(row, col) = |
| 74 | + zero(p(row, col)); |
| 75 | + }, |
| 76 | + r->get_size(), r->get_stride(), default_stride(r), default_stride(z), |
| 77 | + default_stride(p), default_stride(p_prev), default_stride(q), |
| 78 | + default_stride(q_prev), default_stride(v), row_vector(beta), |
| 79 | + *stop_status); |
| 80 | +} |
| 81 | + |
| 82 | +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_MINRES_INITIALIZE_KERNEL); |
| 83 | + |
| 84 | + |
| 85 | +template <typename ValueType> |
| 86 | +GKO_KERNEL void update_givens_rotation(ValueType& alpha, const ValueType& beta, |
| 87 | + ValueType& cos, ValueType& sin) |
| 88 | +{ |
| 89 | + if (alpha == zero(alpha)) { |
| 90 | + cos = zero(cos); |
| 91 | + sin = one(sin); |
| 92 | + } else { |
| 93 | + const auto scale = abs(alpha) + abs(beta); |
| 94 | + const auto hypotenuse = |
| 95 | + scale * sqrt(abs(alpha / scale) * abs(alpha / scale) + |
| 96 | + abs(beta / scale) * abs(beta / scale)); |
| 97 | + cos = conj(alpha) / hypotenuse; |
| 98 | + sin = conj(beta) / hypotenuse; |
| 99 | + } |
| 100 | + alpha = cos * alpha + sin * beta; |
| 101 | +} |
| 102 | + |
| 103 | + |
| 104 | +template <typename ValueType> |
| 105 | +void step_1(std::shared_ptr<const DefaultExecutor> exec, |
| 106 | + matrix::Dense<ValueType>* alpha, matrix::Dense<ValueType>* beta, |
| 107 | + matrix::Dense<ValueType>* gamma, matrix::Dense<ValueType>* delta, |
| 108 | + matrix::Dense<ValueType>* cos_prev, matrix::Dense<ValueType>* cos, |
| 109 | + matrix::Dense<ValueType>* sin_prev, matrix::Dense<ValueType>* sin, |
| 110 | + matrix::Dense<ValueType>* eta, matrix::Dense<ValueType>* eta_next, |
| 111 | + typename matrix::Dense<ValueType>::absolute_type* tau, |
| 112 | + const array<stopping_status>* stop_status) |
| 113 | +{ |
| 114 | + run_kernel( |
| 115 | + exec, |
| 116 | + [] GKO_KERNEL(auto col, auto alpha, auto beta, auto gamma, auto delta, |
| 117 | + auto cos_prev, auto cos, auto sin_prev, auto sin, |
| 118 | + auto eta_next, auto eta, auto tau, auto stop) { |
| 119 | + if (!stop[col].has_stopped()) { |
| 120 | + beta[col] = sqrt(beta[col]); |
| 121 | + delta[col] = sin_prev[col] * gamma[col]; |
| 122 | + const auto tmp_d = gamma[col]; |
| 123 | + const auto tmp_a = alpha[col]; |
| 124 | + gamma[col] = |
| 125 | + cos_prev[col] * cos[col] * tmp_d + sin[col] * tmp_a; |
| 126 | + alpha[col] = |
| 127 | + -conj(sin[col]) * cos_prev[col] * tmp_d + cos[col] * tmp_a; |
| 128 | + |
| 129 | + detail::swap(cos[col], cos_prev[col]); |
| 130 | + detail::swap(sin[col], sin_prev[col]); |
| 131 | + update_givens_rotation(alpha[col], beta[col], cos[col], |
| 132 | + sin[col]); |
| 133 | + |
| 134 | + tau[col] = abs(sin[col]) * tau[col]; |
| 135 | + eta[col] = eta_next[col]; |
| 136 | + eta_next[col] = -conj(sin[col]) * eta[col]; |
| 137 | + } |
| 138 | + }, |
| 139 | + alpha->get_num_stored_elements(), row_vector(alpha), row_vector(beta), |
| 140 | + row_vector(gamma), row_vector(delta), row_vector(cos_prev), |
| 141 | + row_vector(cos), row_vector(sin_prev), row_vector(sin), |
| 142 | + row_vector(eta_next), row_vector(eta), row_vector(tau), *stop_status); |
| 143 | +} |
| 144 | + |
| 145 | +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_MINRES_STEP_1_KERNEL); |
| 146 | + |
| 147 | + |
| 148 | +template <typename ValueType> |
| 149 | +void step_2(std::shared_ptr<const DefaultExecutor> exec, |
| 150 | + matrix::Dense<ValueType>* x, matrix::Dense<ValueType>* p, |
| 151 | + const matrix::Dense<ValueType>* p_prev, matrix::Dense<ValueType>* z, |
| 152 | + const matrix::Dense<ValueType>* z_tilde, |
| 153 | + matrix::Dense<ValueType>* q, matrix::Dense<ValueType>* q_prev, |
| 154 | + matrix::Dense<ValueType>* v, const matrix::Dense<ValueType>* alpha, |
| 155 | + const matrix::Dense<ValueType>* beta, |
| 156 | + const matrix::Dense<ValueType>* gamma, |
| 157 | + const matrix::Dense<ValueType>* delta, |
| 158 | + const matrix::Dense<ValueType>* cos, |
| 159 | + const matrix::Dense<ValueType>* eta, |
| 160 | + const array<stopping_status>* stop_status) |
| 161 | +{ |
| 162 | + run_kernel_solver( |
| 163 | + exec, |
| 164 | + [] GKO_KERNEL(auto row, auto col, auto x, auto p, auto p_prev, auto q, |
| 165 | + auto q_prev, auto v, auto z, auto z_tilde, auto alpha, |
| 166 | + auto beta, auto gamma, auto delta, auto cos, auto eta, |
| 167 | + auto stop) { |
| 168 | + if (!stop[col].has_stopped()) { |
| 169 | + p(row, col) = |
| 170 | + safe_divide(z(row, col) - gamma[col] * p_prev(row, col) - |
| 171 | + delta[col] * p(row, col), |
| 172 | + alpha[col]); |
| 173 | + x(row, col) = x(row, col) + cos[col] * eta[col] * p(row, col); |
| 174 | + |
| 175 | + q_prev(row, col) = v(row, col); |
| 176 | + const auto tmp = q(row, col); |
| 177 | + z(row, col) = safe_divide(z_tilde(row, col), beta[col]); |
| 178 | + q(row, col) = safe_divide(v(row, col), beta[col]); |
| 179 | + v(row, col) = tmp * beta[col]; |
| 180 | + } |
| 181 | + }, |
| 182 | + x->get_size(), p->get_stride(), x, default_stride(p), |
| 183 | + default_stride(p_prev), default_stride(q), default_stride(q_prev), |
| 184 | + default_stride(v), default_stride(z), default_stride(z_tilde), |
| 185 | + row_vector(alpha), row_vector(beta), row_vector(gamma), |
| 186 | + row_vector(delta), row_vector(cos), row_vector(eta), *stop_status); |
| 187 | +} |
| 188 | + |
| 189 | +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_MINRES_STEP_2_KERNEL); |
| 190 | + |
| 191 | + |
| 192 | +} // namespace minres |
| 193 | +} // namespace GKO_DEVICE_NAMESPACE |
| 194 | +} // namespace kernels |
| 195 | +} // namespace gko |
0 commit comments