@@ -56,6 +56,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
56
56
#include " dpcpp/base/onemkl_bindings.hpp"
57
57
#include " dpcpp/components/cooperative_groups.dp.hpp"
58
58
#include " dpcpp/components/reduction.dp.hpp"
59
+ #include " dpcpp/components/searching.dp.hpp"
59
60
#include " dpcpp/components/thread_ids.dp.hpp"
60
61
#include " dpcpp/components/uninitialized_array.hpp"
61
62
#include " dpcpp/synthesizer/implementation_selection.hpp"
@@ -606,7 +607,45 @@ template <typename ValueType, typename OutputType, typename IndexType>
606
607
void row_scatter (std::shared_ptr<const DefaultExecutor> exec,
607
608
const index_set<IndexType>* row_idxs,
608
609
const matrix::Dense<ValueType>* orig,
609
- matrix::Dense<OutputType>* target) GKO_NOT_IMPLEMENTED;
610
+ matrix::Dense<OutputType>* target)
611
+ {
612
+ const auto num_sets = row_idxs->get_num_subsets ();
613
+ const auto num_rows = row_idxs->get_num_elems ();
614
+ const auto num_cols = orig->get_size ()[1 ];
615
+
616
+ const auto * row_set_begins = row_idxs->get_subsets_begin ();
617
+ const auto * row_set_offsets = row_idxs->get_superset_indices ();
618
+
619
+ const auto orig_stride = orig->get_stride ();
620
+ const auto * orig_values = orig->get_const_values ();
621
+
622
+ const auto target_stride = target->get_stride ();
623
+ auto * target_values = target->get_values ();
624
+
625
+ exec->get_queue ()->submit ([&](sycl::handler& cgh) {
626
+ cgh.parallel_for (
627
+ static_cast <size_type>(num_rows * num_cols),
628
+ [=](sycl::item<1 > item) {
629
+ const auto row = static_cast <size_type>(item[0 ]) / num_cols;
630
+ const auto col = static_cast <size_type>(item[0 ]) % num_cols;
631
+
632
+ if (row >= num_rows) {
633
+ return ;
634
+ }
635
+
636
+ auto set_id =
637
+ binary_search<size_type>(
638
+ 0 , num_sets + 1 ,
639
+ [=](auto i) { return row < row_set_offsets[i]; }) -
640
+ 1 ;
641
+ auto set_local_row = row - row_set_offsets[set_id];
642
+ auto target_row = set_local_row + row_set_begins[set_id];
643
+
644
+ target_values[target_row * target_stride + col] =
645
+ orig_values[row * orig_stride + col];
646
+ });
647
+ });
648
+ }
610
649
611
650
// Passes the GKO_DECLARE_DENSE_ROW_SCATTER_INDEX_SET_KERNEL declaration macro
// to the instantiation macro. Presumably this expands to explicit
// instantiations of row_scatter for each supported mixed value-type pair and
// index type -- confirm against the macro definitions.
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2 (
612
651
GKO_DECLARE_DENSE_ROW_SCATTER_INDEX_SET_KERNEL);
0 commit comments