Skip to content

Commit d5da55e

Browse files
committed
adds advanced scatter operation
1 parent 9b9fa61 commit d5da55e

File tree

7 files changed

+175
-24
lines changed

7 files changed

+175
-24
lines changed

common/unified/matrix/row_scatterer.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,20 @@ GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2(
4040
GKO_DECLARE_ROW_SCATTER_SIMPLE_APPLY);
4141

4242

43+
// Advanced row scatter for the unified (common) kernel backends. The parameter
// list matches GKO_DECLARE_ROW_SCATTER_ADVANCED_APPLY, which is used right below
// to instantiate it for the mixed value/index type combinations.
// No body is provided here: the definition is supplied by GKO_NOT_IMPLEMENTED
// (presumably it reports a "not implemented" error when called -- confirm
// against the macro definition).
template <typename ValueType, typename OutputType, typename IndexType>
44+
void advanced_row_scatter(std::shared_ptr<const DefaultExecutor> exec,
45+
const array<IndexType>* row_idxs,
46+
const matrix::Dense<ValueType>* alpha,
47+
const matrix::Dense<ValueType>* orig,
48+
const matrix::Dense<OutputType>* beta,
49+
matrix::Dense<OutputType>* target,
50+
bit_packed_span<bool, IndexType, uint32> mask,
51+
bool& invalid_access) GKO_NOT_IMPLEMENTED;
52+
53+
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2(
54+
GKO_DECLARE_ROW_SCATTER_ADVANCED_APPLY);
55+
56+
4357
} // namespace row_scatter
4458
} // namespace GKO_DEVICE_NAMESPACE
4559
} // namespace kernels

core/device_hooks/common_kernels.inc.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -1122,9 +1122,10 @@ namespace row_scatter {
11221122

11231123

11241124
GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_2(GKO_DECLARE_ROW_SCATTER_SIMPLE_APPLY);
1125+
GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_2(GKO_DECLARE_ROW_SCATTER_ADVANCED_APPLY);
11251126

11261127

1127-
}
1128+
} // namespace row_scatter
11281129
} // namespace GKO_HOOK_MODULE
11291130
} // namespace kernels
11301131
} // namespace gko

core/matrix/row_scatterer.cpp

+34-14
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <ginkgo/core/matrix/dense.hpp>
99

1010
#include "core/base/dispatch_helper.hpp"
11+
#include "core/components/bit_packed_storage.hpp"
1112
#include "core/matrix/row_scatterer_kernels.hpp"
1213

1314
namespace gko {
@@ -16,9 +17,10 @@ namespace {
1617

1718

1819
GKO_REGISTER_OPERATION(row_scatter, row_scatter::row_scatter);
20+
GKO_REGISTER_OPERATION(advanced_row_scatter, row_scatter::advanced_row_scatter);
1921

2022

21-
}
23+
} // namespace
2224

2325

2426
template <typename IndexType>
@@ -41,20 +43,20 @@ template <typename IndexType>
4143
RowScatterer<IndexType>::RowScatterer(std::shared_ptr<const Executor> exec,
4244
array<IndexType> idxs, size_type to_size)
4345
: EnableLinOp<RowScatterer<IndexType>>(exec, {to_size, idxs.get_size()}),
44-
idxs_(exec, std::move(idxs))
46+
idxs_(exec, std::move(idxs)),
47+
mask_(exec,
48+
bit_packed_span<bool, IndexType, uint32>::storage_size(to_size, 1))
4549
{}
4650

4751

4852
template <typename IndexType>
4953
void RowScatterer<IndexType>::apply_impl(const LinOp* b, LinOp* x) const
5054
{
5155
auto impl = [&](const auto* orig, auto* target) {
52-
auto exec = orig->get_executor();
56+
auto exec = this->get_executor();
5357
bool invalid_access = false;
5458

55-
exec->run(make_row_scatter(
56-
make_temporary_clone(exec, &idxs_).get(), orig,
57-
make_temporary_clone(exec, target).get(), invalid_access));
59+
exec->run(make_row_scatter(&idxs_, orig, target, invalid_access));
5860

5961
// TODO: find a uniform way to handle device-side errors
6062
if (invalid_access) {
@@ -79,19 +81,37 @@ template <typename IndexType>
7981
void RowScatterer<IndexType>::apply_impl(const LinOp* alpha, const LinOp* b,
8082
const LinOp* beta, LinOp* x) const
8183
{
82-
auto x_copy = gko::clone(x);
84+
auto impl = [&](const auto* orig, auto* target) {
85+
auto exec = this->get_executor();
86+
bool invalid_access = false;
87+
88+
auto dense_alpha = make_temporary_conversion<
89+
typename std::decay_t<decltype(*orig)>::value_type>(alpha);
90+
auto dense_beta = make_temporary_conversion<
91+
typename std::decay_t<decltype(*target)>::value_type>(beta);
92+
93+
exec->run(make_advanced_row_scatter(
94+
&idxs_, dense_alpha.get(), orig, dense_beta.get(), target,
95+
bit_packed_span<bool, IndexType, uint32>(mask_.get_data(), 1,
96+
this->get_size()[0]),
97+
invalid_access));
98+
99+
if (invalid_access) {
100+
GKO_INVALID_STATE("Out-of-bounds scatter index detected.");
101+
}
102+
};
103+
104+
mask_.fill(uint32{});
105+
83106
run<Dense,
84107
#if GINKGO_ENABLE_HALF
85108
gko::half, std::complex<gko::half>,
86109
#endif
87110
float, double, std::complex<float>, std::complex<double>>(
88-
x, [&](auto* target) {
89-
using dense_type = std::decay_t<decltype(*target)>;
90-
as<dense_type>(x_copy)->fill(
91-
gko::zero<typename dense_type::value_type>());
92-
this->apply_impl(b, x_copy);
93-
target->scale(beta);
94-
target->add_scaled(alpha, x_copy);
111+
b, [&](auto* orig) {
112+
using value_type =
113+
typename std::decay_t<decltype(*orig)>::value_type;
114+
mixed_precision_dispatch_real_complex<value_type>(impl, orig, x);
95115
});
96116
}
97117

core/matrix/row_scatterer_kernels.hpp

+14-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <ginkgo/core/matrix/dense.hpp>
99

1010
#include "core/base/kernel_declaration.hpp"
11+
#include "core/components/bit_packed_storage.hpp"
1112

1213

1314
namespace gko {
@@ -19,10 +20,20 @@ namespace kernels {
1920
const matrix::Dense<_vtype>* orig, \
2021
matrix::Dense<_otype>* target, bool& invalid_access)
2122

23+
// Declares the advanced row-scatter kernel `advanced_row_scatter` for input
// value type _vtype, output value type _otype and index type _itype.
// alpha and orig use _vtype, beta and target use _otype; mask is a
// bit_packed_span of bool flags over uint32 storage, and invalid_access is a
// flag handed in by reference so the kernel can report a bad index.
#define GKO_DECLARE_ROW_SCATTER_ADVANCED_APPLY(_vtype, _otype, _itype) \
24+
void advanced_row_scatter( \
25+
std::shared_ptr<const DefaultExecutor> exec, \
26+
const array<_itype>* row_idxs, const matrix::Dense<_vtype>* alpha, \
27+
const matrix::Dense<_vtype>* orig, const matrix::Dense<_otype>* beta, \
28+
matrix::Dense<_otype>* target, \
29+
bit_packed_span<bool, _itype, uint32> mask, bool& invalid_access)
2230

23-
#define GKO_DECLARE_ALL_AS_TEMPLATES \
24-
template <typename ValueType, typename OutputType, typename IndexType> \
25-
GKO_DECLARE_ROW_SCATTER_SIMPLE_APPLY(ValueType, OutputType, IndexType)
31+
32+
// Bundles the template declarations of both row-scatter kernels (the simple
// and the advanced apply) into a single macro. The last declaration carries no
// trailing semicolon, so the expansion site has to supply it.
// NOTE(review): presumably consumed by GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES
// to emit the declarations once per executor namespace -- confirm at the use
// site.
#define GKO_DECLARE_ALL_AS_TEMPLATES \
33+
template <typename ValueType, typename OutputType, typename IndexType> \
34+
GKO_DECLARE_ROW_SCATTER_SIMPLE_APPLY(ValueType, OutputType, IndexType); \
35+
template <typename ValueType, typename OutputType, typename IndexType> \
36+
GKO_DECLARE_ROW_SCATTER_ADVANCED_APPLY(ValueType, OutputType, IndexType)
2637

2738

2839
GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(row_scatter,

include/ginkgo/core/matrix/row_scatterer.hpp

+10
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,18 @@
77
#include <ginkgo/core/base/lin_op.hpp>
88

99
namespace gko {
10+
11+
template <typename ValueType, typename IndexType, typename StorageType>
12+
class bit_packed_span;
13+
14+
1015
namespace matrix {
1116

1217

18+
/**
19+
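* RowScatterer is a LinOp that scatters the rows of a dense input matrix into
* the rows of a dense output matrix: row i of the input is scattered into row
* idxs[i] of the output. The operator has dimensions to_size x idxs.get_size().
*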
20+
* @tparam IndexType type for defining the scatter-to indices
21+
*/
1322
template <typename IndexType = int32>
1423
class RowScatterer : public EnableLinOp<RowScatterer<IndexType>> {
1524
friend class EnablePolymorphicObject<RowScatterer<IndexType>, LinOp>;
@@ -32,6 +41,7 @@ class RowScatterer : public EnableLinOp<RowScatterer<IndexType>> {
3241
array<IndexType> idxs, size_type to_size);
3342

3443
array<IndexType> idxs_;
44+
mutable array<uint32> mask_;
3545
};
3646

3747

reference/matrix/row_scatterer_kernels.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,40 @@ void row_scatter(std::shared_ptr<const ReferenceExecutor> exec,
3939
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2(
4040
GKO_DECLARE_ROW_SCATTER_SIMPLE_APPLY);
4141

42+
template <typename ValueType, typename OutputType, typename IndexType>
43+
void advanced_row_scatter(std::shared_ptr<const ReferenceExecutor> exec,
44+
const array<IndexType>* row_idxs,
45+
const matrix::Dense<ValueType>* alpha,
46+
const matrix::Dense<ValueType>* orig,
47+
const matrix::Dense<OutputType>* beta,
48+
matrix::Dense<OutputType>* target,
49+
bit_packed_span<bool, IndexType, uint32> mask,
50+
bool& invalid_access)
51+
{
52+
using type = highest_precision<ValueType, OutputType>;
53+
auto scalar_alpha = alpha->at(0, 0);
54+
auto scalar_beta = beta->at(0, 0);
55+
auto rows = row_idxs->get_const_data();
56+
for (size_type i = 0; i < row_idxs->get_size(); ++i) {
57+
if (rows[i] >= target->get_size()[0]) {
58+
invalid_access = true;
59+
return;
60+
}
61+
62+
bool scaled = mask.get(rows[i]);
63+
mask.set(rows[i], true);
64+
for (size_type j = 0; j < orig->get_size()[1]; ++j) {
65+
target->at(rows[i], j) =
66+
static_cast<type>(scalar_alpha * orig->at(i, j)) +
67+
(scaled ? type{1.0} : static_cast<type>(scalar_beta)) *
68+
static_cast<type>(target->at(rows[i], j));
69+
}
70+
}
71+
}
72+
73+
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2(
74+
GKO_DECLARE_ROW_SCATTER_ADVANCED_APPLY);
75+
4276

4377
} // namespace row_scatter
4478
} // namespace reference

reference/test/matrix/row_scatterer_kernels.cpp

+67-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <ginkgo/core/matrix/row_scatterer.hpp>
1212

13+
#include "core/components/bit_packed_storage.hpp"
1314
#include "core/test/utils.hpp"
1415

1516

@@ -22,20 +23,25 @@ class RowScatter : public ::testing::Test {
2223
using DenseIn = gko::matrix::Dense<in_value_type>;
2324
using DenseOut = gko::matrix::Dense<out_value_type>;
2425
using Scatterer = gko::matrix::RowScatterer<index_type>;
25-
26+
using mask_type = gko::bit_packed_span<bool, index_type, gko::uint32>;
2627

2728
std::shared_ptr<gko::ReferenceExecutor> exec =
2829
gko::ReferenceExecutor::create();
2930

3031
std::unique_ptr<DenseIn> in =
3132
gko::initialize<DenseIn>(I<I<in_value_type>>{{1, 2}, {3, 4}}, exec);
32-
std::unique_ptr<DenseIn> in_repeated =
33-
gko::initialize<DenseIn>(I<I<in_value_type>>{{1, 2}, {3, 4}, {5, 6}}, exec);
33+
std::unique_ptr<DenseIn> in_repeated = gko::initialize<DenseIn>(
34+
I<I<in_value_type>>{{1, 2}, {3, 4}, {5, 6}}, exec);
3435
std::unique_ptr<DenseOut> out = gko::initialize<DenseOut>(
3536
I<I<out_value_type>>{{11, 22}, {33, 44}, {55, 66}, {77, 88}}, exec);
37+
std::unique_ptr<DenseIn> alpha =
38+
gko::initialize<DenseIn>({in_value_type{-1.0}}, this->exec);
39+
std::unique_ptr<DenseOut> beta =
40+
gko::initialize<DenseOut>({out_value_type{-2.0}}, this->exec);
3641

3742
gko::array<index_type> idxs = {exec, {3, 1}};
3843
gko::array<index_type> idxs_repeated = {exec, {3, 3, 1}};
44+
gko::array<gko::uint32> mask = {exec, mask_type::storage_size(4, 1)};
3945

4046
std::unique_ptr<Scatterer> scatterer =
4147
Scatterer::create(exec, idxs, out->get_size()[0]);
@@ -50,7 +56,7 @@ TYPED_TEST_SUITE(RowScatter, gko::test::MixedPresisionValueIndexTypes,
5056
#endif
5157

5258

53-
TYPED_TEST(RowScatter, CanRowScatter)
59+
TYPED_TEST(RowScatter, CanSimpleRowScatter)
5460
{
5561
bool invalid_access = false;
5662

@@ -72,8 +78,8 @@ TYPED_TEST(RowScatter, SimpleRowScatterIsAdditive)
7278
bool invalid_access = false;
7379

7480
gko::kernels::reference::row_scatter::row_scatter(
75-
this->exec, &this->idxs_repeated, this->in_repeated.get(), this->out.get(),
76-
invalid_access);
81+
this->exec, &this->idxs_repeated, this->in_repeated.get(),
82+
this->out.get(), invalid_access);
7783

7884
ASSERT_FALSE(invalid_access);
7985
auto expected = gko::initialize<typename TestFixture::DenseOut>(
@@ -84,6 +90,48 @@ TYPED_TEST(RowScatter, SimpleRowScatterIsAdditive)
8490
}
8591

8692

93+
// Checks the reference advanced row-scatter kernel with unique target indices:
// the addressed rows (1 and 3) must become alpha * in + beta * out, while the
// remaining rows (0 and 2) keep their original values.
TYPED_TEST(RowScatter, CanAdvancedRowScatter)
{
    using mask_type = typename TestFixture::mask_type;
    using DenseOut = typename TestFixture::DenseOut;
    using out_value_type = typename TestFixture::out_value_type;
    // Clear the mask first so that no target row counts as already visited.
    this->mask.fill(0);
    bool invalid_access = false;

    gko::kernels::reference::row_scatter::advanced_row_scatter(
        this->exec, &this->idxs, this->alpha.get(), this->in.get(),
        this->beta.get(), this->out.get(),
        mask_type(this->mask.get_data(), 1, this->out->get_size()[0]),
        invalid_access);

    ASSERT_FALSE(invalid_access);
    auto expected = gko::initialize<DenseOut>(
        I<I<out_value_type>>{{11, 22}, {-69, -92}, {55, 66}, {-155, -178}},
        this->exec);
    GKO_ASSERT_MTX_NEAR(this->out, expected, 0.0);
}
112+
113+
114+
// Checks the reference advanced row-scatter kernel with a repeated target
// index: row 3 is addressed twice, so beta may be applied to it only on the
// first visit and the second contribution is simply added on top.
TYPED_TEST(RowScatter, AdvancedRowScatterIsAdditive)
{
    using mask_type = typename TestFixture::mask_type;
    using DenseOut = typename TestFixture::DenseOut;
    using out_value_type = typename TestFixture::out_value_type;
    // Clear the mask first so that no target row counts as already visited.
    this->mask.fill(0);
    bool invalid_access = false;

    gko::kernels::reference::row_scatter::advanced_row_scatter(
        this->exec, &this->idxs_repeated, this->alpha.get(),
        this->in_repeated.get(), this->beta.get(), this->out.get(),
        mask_type(this->mask.get_data(), 1, this->out->get_size()[0]),
        invalid_access);

    ASSERT_FALSE(invalid_access);
    auto expected = gko::initialize<DenseOut>(
        I<I<out_value_type>>{{11, 22}, {-71, -94}, {55, 66}, {-158, -182}},
        this->exec);
    GKO_ASSERT_MTX_NEAR(this->out, expected, 0.0);
}
133+
134+
87135
TYPED_TEST(RowScatter, CanDetectInvalidAccess)
88136
{
89137
bool invalid_access = false;
@@ -106,3 +154,16 @@ TYPED_TEST(RowScatter, CanRowScatterSimpleApply)
106154
this->exec);
107155
GKO_ASSERT_MTX_NEAR(this->out, expected, 0.0);
108156
}
157+
158+
159+
// Runs the advanced scatter through the public four-argument apply of the
// RowScatterer and expects the same result as the direct kernel call with
// unique indices.
TYPED_TEST(RowScatter, CanRowScatterAdvancedApply)
{
    using DenseOut = typename TestFixture::DenseOut;
    using out_value_type = typename TestFixture::out_value_type;

    this->scatterer->apply(this->alpha, this->in.get(), this->beta,
                           this->out.get());

    auto expected = gko::initialize<DenseOut>(
        I<I<out_value_type>>{{11, 22}, {-69, -92}, {55, 66}, {-155, -178}},
        this->exec);
    GKO_ASSERT_MTX_NEAR(this->out, expected, 0.0);
}

0 commit comments

Comments
 (0)