@@ -1288,38 +1288,39 @@ void Dense<ValueType>::row_gather_impl(const Dense<ValueType>* alpha,
1288
1288
}
1289
1289
1290
1290
1291
- template <typename ValueType>
1292
- template <typename OutputType, typename IndexType>
1293
- void Dense<ValueType>::row_scatter_impl(const array<IndexType>* row_idxs,
1294
- Dense<OutputType>* target) const
1291
+ template <typename IndexType>
1292
+ size_type get_size (const array<IndexType>* arr)
1295
1293
{
1296
- auto exec = this ->get_executor ();
1297
- dim<2 > expected_dim{row_idxs->get_num_elems (), this ->get_size ()[1 ]};
1298
- GKO_ASSERT_EQUAL_DIMENSIONS (expected_dim, this );
1299
- GKO_ASSERT_EQUAL_COLS (this , target);
1300
- // @todo check that indices are inbounds for target
1294
+ return arr->get_size ();
1295
+ }
1301
1296
1302
- exec->run (dense::make_row_scatter (
1303
- make_temporary_clone (exec, row_idxs).get (), this ,
1304
- make_temporary_clone (exec, target).get ()));
1297
+ template <typename IndexType>
1298
+ size_type get_size (const index_set<IndexType>* is)
1299
+ {
1300
+ return is->get_num_elems ();
1305
1301
}
1306
1302
1307
1303
1308
- template <typename ValueType>
1309
- template <typename OutputType, typename IndexType>
1310
- void Dense<ValueType>::row_scatter_impl(const index_set<IndexType>* row_idxs,
1311
- Dense<OutputType>* target) const
1304
+ template <typename ValueType, typename OutputType, typename IndexContainer>
1305
+ void row_scatter_impl (const IndexContainer* row_idxs,
1306
+ const Dense<ValueType>* orig, Dense<OutputType>* target)
1312
1307
{
1313
- auto exec = this ->get_executor ();
1314
- dim<2 > expected_dim{static_cast <size_type>(row_idxs->get_num_elems ()),
1315
- this ->get_size ()[1 ]};
1316
- GKO_ASSERT_EQUAL_DIMENSIONS (expected_dim, this );
1317
- GKO_ASSERT_EQUAL_COLS (this , target);
1318
- // @todo check that indices are inbounds for target
1308
+ auto exec = orig->get_executor ();
1309
+ dim<2 > expected_dim{static_cast <size_type>(get_size (row_idxs)),
1310
+ orig->get_size ()[1 ]};
1311
+ GKO_ASSERT_EQUAL_DIMENSIONS (expected_dim, orig);
1312
+ GKO_ASSERT_EQUAL_COLS (orig, target);
1313
+
1314
+ bool invalid_access = false ;
1319
1315
1320
1316
exec->run (dense::make_row_scatter (
1321
- make_temporary_clone (exec, row_idxs).get (), this ,
1322
- make_temporary_clone (exec, target).get ()));
1317
+ make_temporary_clone (exec, row_idxs).get (), orig,
1318
+ make_temporary_clone (exec, target).get (), invalid_access));
1319
+
1320
+ if (invalid_access) {
1321
+ GKO_INVALID_STATE (
1322
+ " Out-of-bounds access detected during kernel execution." );
1323
+ }
1323
1324
}
1324
1325
1325
1326
@@ -1633,7 +1634,7 @@ void Dense<ValueType>::row_scatter(const array<IndexType>* row_idxs,
1633
1634
ptr_param<LinOp> row_collection) const
1634
1635
{
1635
1636
gather_mixed_real_complex<ValueType>(
1636
- [&](auto dense) { this -> row_scatter_impl (row_idxs, dense); },
1637
+ [&](auto dense) { row_scatter_impl (row_idxs, this , dense); },
1637
1638
row_collection.get ());
1638
1639
}
1639
1640
@@ -1644,7 +1645,7 @@ void Dense<ValueType>::row_scatter(const index_set<IndexType>* row_idxs,
1644
1645
ptr_param<LinOp> row_collection) const
1645
1646
{
1646
1647
gather_mixed_real_complex<ValueType>(
1647
- [&](auto dense) { this -> row_scatter_impl (row_idxs, dense); },
1648
+ [&](auto dense) { row_scatter_impl (row_idxs, this , dense); },
1648
1649
row_collection.get ());
1649
1650
}
1650
1651
0 commit comments