Skip to content

Commit dcaabd9

Browse files
committedJul 3, 2023
adds dpcpp kernel
1 parent 26b4525 commit dcaabd9

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed
 

‎dpcpp/matrix/dense_kernels.dp.cpp

+40-1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
5656
#include "dpcpp/base/onemkl_bindings.hpp"
5757
#include "dpcpp/components/cooperative_groups.dp.hpp"
5858
#include "dpcpp/components/reduction.dp.hpp"
59+
#include "dpcpp/components/searching.dp.hpp"
5960
#include "dpcpp/components/thread_ids.dp.hpp"
6061
#include "dpcpp/components/uninitialized_array.hpp"
6162
#include "dpcpp/synthesizer/implementation_selection.hpp"
@@ -606,7 +607,45 @@ template <typename ValueType, typename OutputType, typename IndexType>
606607
void row_scatter(std::shared_ptr<const DefaultExecutor> exec,
607608
const index_set<IndexType>* row_idxs,
608609
const matrix::Dense<ValueType>* orig,
609-
matrix::Dense<OutputType>* target) GKO_NOT_IMPLEMENTED;
610+
matrix::Dense<OutputType>* target)
611+
{
612+
const auto num_sets = row_idxs->get_num_subsets();
613+
const auto num_rows = row_idxs->get_num_elems();
614+
const auto num_cols = orig->get_size()[1];
615+
616+
const auto* row_set_begins = row_idxs->get_subsets_begin();
617+
const auto* row_set_offsets = row_idxs->get_superset_indices();
618+
619+
const auto orig_stride = orig->get_stride();
620+
const auto* orig_values = orig->get_const_values();
621+
622+
const auto target_stride = target->get_stride();
623+
auto* target_values = target->get_values();
624+
625+
exec->get_queue()->submit([&](sycl::handler& cgh) {
626+
cgh.parallel_for(
627+
static_cast<size_type>(num_rows * num_cols),
628+
[=](sycl::item<1> item) {
629+
const auto row = static_cast<size_type>(item[0]) / num_cols;
630+
const auto col = static_cast<size_type>(item[0]) % num_cols;
631+
632+
if (row >= num_rows) {
633+
return;
634+
}
635+
636+
auto set_id =
637+
binary_search<size_type>(
638+
0, num_sets + 1,
639+
[=](auto i) { return row < row_set_offsets[i]; }) -
640+
1;
641+
auto set_local_row = row - row_set_offsets[set_id];
642+
auto target_row = set_local_row + row_set_begins[set_id];
643+
644+
target_values[target_row * target_stride + col] =
645+
orig_values[row * orig_stride + col];
646+
});
647+
});
648+
}
610649

611650
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2(
612651
GKO_DECLARE_DENSE_ROW_SCATTER_INDEX_SET_KERNEL);

0 commit comments

Comments
 (0)
Please sign in to comment.