@@ -395,16 +395,16 @@ __global__ __launch_bounds__(default_block_size) void fill_in_sellp(
395
395
template <typename ValueType, typename OutputType, typename IndexType>
396
396
__global__ __launch_bounds__(default_block_size) void row_scatter(
397
397
size_type num_sets, IndexType* __restrict__ row_set_begins,
398
- IndexType* __restrict__ row_set_offsets, size_type orig_num_rows ,
399
- size_type num_cols, size_type orig_stride,
398
+ IndexType* __restrict__ row_set_offsets, size_type target_num_rows ,
399
+ size_type num_cols, size_type orig_num_rows, size_type orig_stride,
400
400
const ValueType* __restrict__ orig_values, size_type target_stride,
401
- OutputType* __restrict__ target_values)
401
+ OutputType* __restrict__ target_values, bool* __restrict__ invalid_access )
402
402
{
403
403
auto id = thread::get_thread_id_flat();
404
404
auto row = id / num_cols;
405
405
auto col = id % num_cols;
406
406
407
- if (row >= orig_num_rows) {
407
+ if (row >= orig_num_rows || *invalid_access ) {
408
408
return ;
409
409
}
410
410
@@ -415,6 +415,11 @@ __global__ __launch_bounds__(default_block_size) void row_scatter(
415
415
auto set_local_row = row - row_set_offsets[set_id];
416
416
auto target_row = set_local_row + row_set_begins[set_id];
417
417
418
+ if (target_row >= target_num_rows) {
419
+ *invalid_access = true ;
420
+ return ;
421
+ }
422
+
418
423
target_values[target_row * target_stride + col] =
419
424
orig_values[row * orig_stride + col];
420
425
}
@@ -653,19 +658,28 @@ template <typename ValueType, typename OutputType, typename IndexType>
653
658
void row_scatter(std::shared_ptr<const DefaultExecutor> exec,
654
659
const index_set<IndexType>* row_idxs,
655
660
const matrix::Dense<ValueType>* orig,
656
- matrix::Dense<OutputType>* target)
661
+ matrix::Dense<OutputType>* target, bool& invalid_access )
657
662
{
658
- auto size = orig->get_size();
659
- if (size) {
663
+ auto orig_size = orig->get_size();
664
+ auto target_size = target->get_size();
665
+
666
+ array<bool> invalid_access_arr(exec, {false });
667
+
668
+ if (orig_size) {
660
669
constexpr auto block_size = default_block_size;
661
- auto num_blocks = ceildiv(size [0] * size [1], block_size);
670
+ auto num_blocks = ceildiv(orig_size [0] * orig_size [1], block_size);
662
671
kernel::row_scatter<<<num_blocks, block_size, 0, exec->get_stream()>>>(
663
672
row_idxs->get_num_subsets(),
664
673
as_device_type(row_idxs->get_subsets_begin()),
665
- as_device_type(row_idxs->get_superset_indices()), size[0], size[1],
666
- orig->get_stride(), as_device_type(orig->get_const_values()),
667
- target->get_stride(), as_device_type(target->get_values()));
674
+ as_device_type(row_idxs->get_superset_indices()), target_size[0],
675
+ target_size[1], orig_size[0], orig->get_stride(),
676
+ as_device_type(orig->get_const_values()), target->get_stride(),
677
+ as_device_type(target->get_values()),
678
+ as_device_type(invalid_access_arr.get_data()));
668
679
}
680
+
681
+ invalid_access =
682
+ exec->copy_val_to_host(invalid_access_arr.get_const_data());
669
683
}
670
684
671
685
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2(
0 commit comments