|
18 | 18 |
|
19 | 19 | #include "core/test/utils.hpp"
|
20 | 20 |
|
| 21 | + |
21 | 22 | namespace {
|
22 | 23 |
|
| 24 | + |
23 | 25 | template <typename T>
|
24 | 26 | class Minres : public ::testing::Test {
|
25 | 27 | protected:
|
@@ -178,6 +180,32 @@ TYPED_TEST(Minres, KernelInitialize)
|
178 | 180 | }
|
179 | 181 |
|
180 | 182 |
|
| 183 | +TYPED_TEST(Minres, KernelInitializeWithSafeDivide) |
| 184 | +{ |
| 185 | + using Mtx = typename TestFixture::Mtx; |
| 186 | + using vt = typename TestFixture::value_type; |
| 187 | + this->small_r = |
| 188 | + gko::initialize<Mtx>(I<I<vt>>({{1, 2}, {3, 4}}), this->exec); |
| 189 | + this->small_z = gko::initialize<Mtx>(I<I<vt>>{{4, 3}, {2, 1}}, this->exec); |
| 190 | + auto zero = gko::zero<vt>(); |
| 191 | + this->beta = gko::initialize<Mtx>(I<I<vt>>{{zero, zero}}, this->exec); |
| 192 | + this->small_q->fill(1); |
| 193 | + |
| 194 | + gko::kernels::reference::minres::initialize( |
| 195 | + this->exec, this->small_r.get(), this->small_z.get(), |
| 196 | + this->small_p.get(), this->small_p_prev.get(), this->small_q.get(), |
| 197 | + this->small_q_prev.get(), this->small_v.get(), this->beta.get(), |
| 198 | + this->gamma.get(), this->delta.get(), this->cos_prev.get(), |
| 199 | + this->cos.get(), this->sin_prev.get(), this->sin.get(), |
| 200 | + this->eta_next.get(), this->eta.get(), &this->small_stop); |
| 201 | + |
| 202 | + GKO_ASSERT_MTX_NEAR(this->small_q, l({{0.0, 2. / 5}, {0.0, 4. / 5}}), |
| 203 | + r<vt>::value); |
| 204 | + GKO_ASSERT_MTX_NEAR(this->small_z, l({{0.0, 3. / 5}, {0.0, 1. / 5}}), |
| 205 | + r<vt>::value); |
| 206 | +} |
| 207 | + |
| 208 | + |
181 | 209 | TYPED_TEST(Minres, KernelStep1)
|
182 | 210 | {
|
183 | 211 | using Mtx = typename TestFixture::Mtx;
|
@@ -273,6 +301,36 @@ TYPED_TEST(Minres, KernelStep2)
|
273 | 301 | }
|
274 | 302 |
|
275 | 303 |
|
| 304 | +TYPED_TEST(Minres, KernelStep2WithSafeDivide) |
| 305 | +{ |
| 306 | + using Mtx = typename TestFixture::Mtx; |
| 307 | + using vt = typename TestFixture::value_type; |
| 308 | + this->small_q = gko::initialize<Mtx>(I<I<vt>>{{4, 9}, {7, 11}}, this->exec); |
| 309 | + this->small_p = gko::initialize<Mtx>(I<I<vt>>{{1, 2}, {3, 4}}, this->exec); |
| 310 | + this->small_z = gko::initialize<Mtx>(I<I<vt>>{{6, 1}, {7, 3}}, this->exec); |
| 311 | + auto zero = gko::zero<vt>(); |
| 312 | + this->alpha = gko::initialize<Mtx>(I<I<vt>>{{zero, zero}}, this->exec); |
| 313 | + this->beta = gko::initialize<Mtx>(I<I<vt>>{{zero, zero}}, this->exec); |
| 314 | + auto old_small_q = gko::clone(this->small_q); |
| 315 | + auto old_small_v = gko::clone(this->small_v); |
| 316 | + auto old_small_q_scaled = gko::clone(this->small_v); |
| 317 | + auto old_small_z_tilde_scaled = gko::clone(this->small_z_tilde); |
| 318 | + auto old_small_v_scaled = gko::clone(this->small_q); |
| 319 | + |
| 320 | + gko::kernels::reference::minres::step_2( |
| 321 | + this->exec, this->small_x.get(), this->small_p.get(), |
| 322 | + this->small_p_prev.get(), this->small_z.get(), |
| 323 | + this->small_z_tilde.get(), this->small_q.get(), |
| 324 | + this->small_q_prev.get(), this->small_v.get(), this->alpha.get(), |
| 325 | + this->beta.get(), this->gamma.get(), this->delta.get(), this->cos.get(), |
| 326 | + this->eta.get(), &this->small_stop); |
| 327 | + |
| 328 | + GKO_ASSERT_MTX_NEAR(this->small_q, l({{0.0, 0.0}, {0.0, 0.0}}), 0.); |
| 329 | + GKO_ASSERT_MTX_NEAR(this->small_z, l({{0.0, 0.0}, {0.0, 0.0}}), 0.); |
| 330 | + GKO_ASSERT_MTX_NEAR(this->small_p, l({{0.0, 0.0}, {0.0, 0.0}}), 0.); |
| 331 | +} |
| 332 | + |
| 333 | + |
276 | 334 | TYPED_TEST(Minres, SolvesSystem)
|
277 | 335 | {
|
278 | 336 | using Mtx = typename TestFixture::Mtx;
|
|
0 commit comments