@@ -423,16 +423,16 @@ __global__ __launch_bounds__(default_block_size) void fill_in_sellp(
423
423
template <typename ValueType, typename OutputType, typename IndexType>
424
424
__global__ __launch_bounds__(default_block_size) void row_scatter(
425
425
size_type num_sets, IndexType* __restrict__ row_set_begins,
426
- IndexType* __restrict__ row_set_offsets, size_type orig_num_rows ,
427
- size_type num_cols, size_type orig_stride,
426
+ IndexType* __restrict__ row_set_offsets, size_type target_num_rows ,
427
+ size_type num_cols, size_type orig_num_rows, size_type orig_stride,
428
428
const ValueType* __restrict__ orig_values, size_type target_stride,
429
- OutputType* __restrict__ target_values)
429
+ OutputType* __restrict__ target_values, bool* __restrict__ invalid_access )
430
430
{
431
431
auto id = thread::get_thread_id_flat();
432
432
auto row = id / num_cols;
433
433
auto col = id % num_cols;
434
434
435
- if (row >= orig_num_rows) {
435
+ if (row >= orig_num_rows || *invalid_access ) {
436
436
return ;
437
437
}
438
438
@@ -443,6 +443,11 @@ __global__ __launch_bounds__(default_block_size) void row_scatter(
443
443
auto set_local_row = row - row_set_offsets[set_id];
444
444
auto target_row = set_local_row + row_set_begins[set_id];
445
445
446
+ if (target_row >= target_num_rows) {
447
+ *invalid_access = true ;
448
+ return ;
449
+ }
450
+
446
451
target_values[target_row * target_stride + col] =
447
452
orig_values[row * orig_stride + col];
448
453
}
@@ -681,19 +686,28 @@ template <typename ValueType, typename OutputType, typename IndexType>
681
686
void row_scatter(std::shared_ptr<const DefaultExecutor> exec,
682
687
const index_set<IndexType>* row_idxs,
683
688
const matrix::Dense<ValueType>* orig,
684
- matrix::Dense<OutputType>* target)
689
+ matrix::Dense<OutputType>* target, bool& invalid_access )
685
690
{
686
- auto size = orig->get_size();
687
- if (size) {
691
+ auto orig_size = orig->get_size();
692
+ auto target_size = target->get_size();
693
+
694
+ array<bool> invalid_access_arr(exec, {false });
695
+
696
+ if (orig_size) {
688
697
constexpr auto block_size = default_block_size;
689
- auto num_blocks = ceildiv(size [0] * size [1], block_size);
698
+ auto num_blocks = ceildiv(orig_size [0] * orig_size [1], block_size);
690
699
kernel::row_scatter<<<num_blocks, block_size, 0, exec->get_stream()>>>(
691
700
row_idxs->get_num_subsets(),
692
701
as_device_type(row_idxs->get_subsets_begin()),
693
- as_device_type(row_idxs->get_superset_indices()), size[0], size[1],
694
- orig->get_stride(), as_device_type(orig->get_const_values()),
695
- target->get_stride(), as_device_type(target->get_values()));
702
+ as_device_type(row_idxs->get_superset_indices()), target_size[0],
703
+ target_size[1], orig_size[0], orig->get_stride(),
704
+ as_device_type(orig->get_const_values()), target->get_stride(),
705
+ as_device_type(target->get_values()),
706
+ as_device_type(invalid_access_arr.get_data()));
696
707
}
708
+
709
+ invalid_access =
710
+ exec->copy_val_to_host(invalid_access_arr.get_const_data());
697
711
}
698
712
699
713
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2(
0 commit comments