10
10
11
11
#include < ginkgo/core/matrix/row_scatterer.hpp>
12
12
13
+ #include " core/components/bit_packed_storage.hpp"
13
14
#include " core/test/utils.hpp"
14
15
15
16
@@ -22,20 +23,25 @@ class RowScatter : public ::testing::Test {
22
23
using DenseIn = gko::matrix::Dense<in_value_type>;
23
24
using DenseOut = gko::matrix::Dense<out_value_type>;
24
25
using Scatterer = gko::matrix::RowScatterer<index_type>;
25
-
26
+ using mask_type = gko::bit_packed_span< bool , index_type, gko::uint32>;
26
27
27
28
std::shared_ptr<gko::ReferenceExecutor> exec =
28
29
gko::ReferenceExecutor::create ();
29
30
30
31
std::unique_ptr<DenseIn> in =
31
32
gko::initialize<DenseIn>(I<I<in_value_type>>{{1 , 2 }, {3 , 4 }}, exec);
32
- std::unique_ptr<DenseIn> in_repeated =
33
- gko::initialize<DenseIn>( I<I<in_value_type>>{{1 , 2 }, {3 , 4 }, {5 , 6 }}, exec);
33
+ std::unique_ptr<DenseIn> in_repeated = gko::initialize<DenseIn>(
34
+ I<I<in_value_type>>{{1 , 2 }, {3 , 4 }, {5 , 6 }}, exec);
34
35
std::unique_ptr<DenseOut> out = gko::initialize<DenseOut>(
35
36
I<I<out_value_type>>{{11 , 22 }, {33 , 44 }, {55 , 66 }, {77 , 88 }}, exec);
37
+ std::unique_ptr<DenseIn> alpha =
38
+ gko::initialize<DenseIn>({in_value_type{-1.0 }}, this ->exec);
39
+ std::unique_ptr<DenseOut> beta =
40
+ gko::initialize<DenseOut>({out_value_type{-2.0 }}, this ->exec);
36
41
37
42
gko::array<index_type> idxs = {exec, {3 , 1 }};
38
43
gko::array<index_type> idxs_repeated = {exec, {3 , 3 , 1 }};
44
+ gko::array<gko::uint32> mask = {exec, mask_type::storage_size (4 , 1 )};
39
45
40
46
std::unique_ptr<Scatterer> scatterer =
41
47
Scatterer::create (exec, idxs, out->get_size ()[0]);
@@ -50,7 +56,7 @@ TYPED_TEST_SUITE(RowScatter, gko::test::MixedPresisionValueIndexTypes,
50
56
#endif
51
57
52
58
53
- TYPED_TEST (RowScatter, CanRowScatter )
59
+ TYPED_TEST (RowScatter, CanSimpleRowScatter )
54
60
{
55
61
bool invalid_access = false ;
56
62
@@ -72,8 +78,8 @@ TYPED_TEST(RowScatter, SimpleRowScatterIsAdditive)
72
78
bool invalid_access = false ;
73
79
74
80
gko::kernels::reference::row_scatter::row_scatter (
75
- this ->exec , &this ->idxs_repeated , this ->in_repeated .get (), this -> out . get (),
76
- invalid_access);
81
+ this ->exec , &this ->idxs_repeated , this ->in_repeated .get (),
82
+ this -> out . get (), invalid_access);
77
83
78
84
ASSERT_FALSE (invalid_access);
79
85
auto expected = gko::initialize<typename TestFixture::DenseOut>(
@@ -84,6 +90,48 @@ TYPED_TEST(RowScatter, SimpleRowScatterIsAdditive)
84
90
}
85
91
86
92
93
+ TYPED_TEST (RowScatter, CanAdvancedRowScatter)
94
+ {
95
+ using mask_type = typename TestFixture::mask_type;
96
+ bool invalid_access = false ;
97
+ this ->mask .fill (0 );
98
+
99
+ gko::kernels::reference::row_scatter::advanced_row_scatter (
100
+ this ->exec , &this ->idxs , this ->alpha .get (), this ->in .get (),
101
+ this ->beta .get (), this ->out .get (),
102
+ mask_type (this ->mask .get_data (), 1 , this ->out ->get_size ()[0 ]),
103
+ invalid_access);
104
+
105
+ ASSERT_FALSE (invalid_access);
106
+ auto expected = gko::initialize<typename TestFixture::DenseOut>(
107
+ I<I<typename TestFixture::out_value_type>>{
108
+ {11 , 22 }, {-69 , -92 }, {55 , 66 }, {-155 , -178 }},
109
+ this ->exec );
110
+ GKO_ASSERT_MTX_NEAR (this ->out , expected, 0.0 );
111
+ }
112
+
113
+
114
+ TYPED_TEST (RowScatter, AdvancedRowScatterIsAdditive)
115
+ {
116
+ using mask_type = typename TestFixture::mask_type;
117
+ bool invalid_access = false ;
118
+ this ->mask .fill (0 );
119
+
120
+ gko::kernels::reference::row_scatter::advanced_row_scatter (
121
+ this ->exec , &this ->idxs_repeated , this ->alpha .get (),
122
+ this ->in_repeated .get (), this ->beta .get (), this ->out .get (),
123
+ mask_type (this ->mask .get_data (), 1 , this ->out ->get_size ()[0 ]),
124
+ invalid_access);
125
+
126
+ ASSERT_FALSE (invalid_access);
127
+ auto expected = gko::initialize<typename TestFixture::DenseOut>(
128
+ I<I<typename TestFixture::out_value_type>>{
129
+ {11 , 22 }, {-71 , -94 }, {55 , 66 }, {-158 , -182 }},
130
+ this ->exec );
131
+ GKO_ASSERT_MTX_NEAR (this ->out , expected, 0.0 );
132
+ }
133
+
134
+
87
135
TYPED_TEST (RowScatter, CanDetectInvalidAccess)
88
136
{
89
137
bool invalid_access = false ;
@@ -106,3 +154,16 @@ TYPED_TEST(RowScatter, CanRowScatterSimpleApply)
106
154
this ->exec );
107
155
GKO_ASSERT_MTX_NEAR (this ->out , expected, 0.0 );
108
156
}
157
+
158
+
159
+ TYPED_TEST (RowScatter, CanRowScatterAdvancedApply)
160
+ {
161
+ this ->scatterer ->apply (this ->alpha , this ->in .get (), this ->beta ,
162
+ this ->out .get ());
163
+
164
+ auto expected = gko::initialize<typename TestFixture::DenseOut>(
165
+ I<I<typename TestFixture::out_value_type>>{
166
+ {11 , 22 }, {-69 , -92 }, {55 , 66 }, {-155 , -178 }},
167
+ this ->exec );
168
+ GKO_ASSERT_MTX_NEAR (this ->out , expected, 0.0 );
169
+ }
0 commit comments