Skip to content

Commit 83bc786

Browse files
committed
wip
1 parent 45e5a95 commit 83bc786

File tree

3 files changed

+56
-3
lines changed

3 files changed

+56
-3
lines changed

core/matrix/row_scatterer.cpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,22 @@ void RowScatterer<IndexType>::apply_impl(const LinOp* b, LinOp* x) const
7878
template <typename IndexType>
7979
void RowScatterer<IndexType>::apply_impl(const LinOp* alpha, const LinOp* b,
8080
const LinOp* beta, LinOp* x) const
81-
{}
81+
{
82+
auto x_copy = gko::clone(x);
83+
run<Dense,
84+
#if GINKGO_ENABLE_HALF
85+
gko::half, std::complex<gko::half>,
86+
#endif
87+
float, double, std::complex<float>, std::complex<double>>(
88+
x, [&](auto* target) {
89+
using dense_type = std::decay_t<decltype(*target)>;
90+
as<dense_type>(x_copy)->fill(
91+
gko::zero<typename dense_type::value_type>());
92+
this->apply_impl(b, x_copy);
93+
target->scale(beta);
94+
target->add_scaled(alpha, x_copy);
95+
});
96+
}
8297

8398

8499
#define GKO_DECLARE_ROW_SCATTER(_type) class RowScatterer<_type>

reference/matrix/row_scatterer_kernels.cpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@ void row_scatter(std::shared_ptr<const ReferenceExecutor> exec,
2525
return;
2626
}
2727
for (size_type j = 0; j < orig->get_size()[1]; ++j) {
28-
target->at(rows[i], j) = orig->at(i, j);
28+
target->at(rows[i], j) = zero<OutputType>();
29+
}
30+
}
31+
32+
for (size_type i = 0; i < row_idxs->get_size(); ++i) {
33+
for (size_type j = 0; j < orig->get_size()[1]; ++j) {
34+
target->at(rows[i], j) += orig->at(i, j);
2935
}
3036
}
3137
}

reference/test/matrix/row_scatterer_kernels.cpp

+33-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class RowScatter : public ::testing::Test {
2121
using index_type = std::tuple_element_t<2, InValueOutValueIndexType>;
2222
using DenseIn = gko::matrix::Dense<in_value_type>;
2323
using DenseOut = gko::matrix::Dense<out_value_type>;
24+
using Scatterer = gko::matrix::RowScatterer<index_type>;
2425

2526

2627
std::shared_ptr<gko::ReferenceExecutor> exec =
@@ -32,6 +33,9 @@ class RowScatter : public ::testing::Test {
3233
I<I<out_value_type>>{{11, 22}, {33, 44}, {55, 66}, {77, 88}}, exec);
3334

3435
gko::array<index_type> idxs = {exec, {3, 1}};
36+
37+
std::unique_ptr<Scatterer> scatterer =
38+
Scatterer::create(exec, idxs, out->get_size()[0]);
3539
};
3640

3741
#ifdef GINKGO_MIXED_PRECISION
@@ -43,7 +47,7 @@ TYPED_TEST_SUITE(RowScatter, gko::test::MixedPresisionValueIndexTypes,
4347
#endif
4448

4549

46-
TYPED_TEST(RowScatter, CanScatter)
50+
TYPED_TEST(RowScatter, CanRowScatter)
4751
{
4852
bool invalid_access = false;
4953

@@ -70,3 +74,31 @@ TYPED_TEST(RowScatter, CanDetectInvalidAccess)
7074

7175
ASSERT_TRUE(invalid_access);
7276
}
77+
78+
79+
TYPED_TEST(RowScatter, CanRowScatterSimpleApply)
80+
{
81+
this->scatterer->apply(this->in.get(), this->out.get());
82+
83+
auto expected = gko::initialize<typename TestFixture::DenseOut>(
84+
I<I<typename TestFixture::out_value_type>>{
85+
{11, 22}, {3, 4}, {55, 66}, {1, 2}},
86+
this->exec);
87+
GKO_ASSERT_MTX_NEAR(this->out, expected, 0.0);
88+
}
89+
90+
91+
TYPED_TEST(RowScatter, CanRowScatterAdvancedApply)
92+
{
93+
auto alpha = gko::initialize<typename TestFixture::DenseIn>(
94+
{-gko::one<typename TestFixture::in_value_type>()}, this->exec);
95+
auto beta = gko::initialize<typename TestFixture::DenseOut>(
96+
{-2 * gko::one<typename TestFixture::out_value_type>()}, this->exec);
97+
this->scatterer->apply(this->in.get(), this->out.get());
98+
99+
auto expected = gko::initialize<typename TestFixture::DenseOut>(
100+
I<I<typename TestFixture::out_value_type>>{
101+
{-11, -22}, {3, 4}, {-55, -66}, {1, 2}},
102+
this->exec);
103+
GKO_ASSERT_MTX_NEAR(this->out, expected, 0.0);
104+
}

0 commit comments

Comments
 (0)