Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add zero-checks to axpy-like operations #1573

Merged
merged 4 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions common/cuda_hip/matrix/csr_kernels.template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,10 +462,17 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_merge_path_spmv(
using type = typename output_accessor::arithmetic_type;
const type alpha_val = alpha[0];
const type beta_val = beta[0];
merge_path_spmv<items_per_thread>(
num_rows, val, col_idxs, row_ptrs, srow, b, c, row_out, val_out,
[&alpha_val](const type& x) { return alpha_val * x; },
[&beta_val](const type& x) { return beta_val * x; });
if (is_zero(beta_val)) {
merge_path_spmv<items_per_thread>(
num_rows, val, col_idxs, row_ptrs, srow, b, c, row_out, val_out,
[&alpha_val](const type& x) { return alpha_val * x; },
[](const type& x) { return zero<type>(); });
} else {
merge_path_spmv<items_per_thread>(
num_rows, val, col_idxs, row_ptrs, srow, b, c, row_out, val_out,
[&alpha_val](const type& x) { return alpha_val * x; },
[&beta_val](const type& x) { return beta_val * x; });
}
}


Expand Down Expand Up @@ -562,11 +569,19 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_classical_spmv(
using type = typename output_accessor::arithmetic_type;
const type alpha_val = alpha[0];
const type beta_val = beta[0];
device_classical_spmv<subwarp_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val, &beta_val](const type& x, const type& y) {
return alpha_val * x + beta_val * y;
});
if (is_zero(beta_val)) {
device_classical_spmv<subwarp_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val](const type& x, const type& y) {
return alpha_val * x;
});
} else {
device_classical_spmv<subwarp_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val, &beta_val](const type& x, const type& y) {
return alpha_val * x + beta_val * y;
});
}
}


Expand Down
27 changes: 19 additions & 8 deletions common/cuda_hip/matrix/ell_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -192,13 +192,24 @@ __global__ __launch_bounds__(default_block_size) void spmv(
return static_cast<OutputValueType>(alpha_val * x);
});
} else {
spmv_kernel<num_thread_per_worker, atomic>(
num_rows, num_worker_per_row, val, col, stride,
num_stored_elements_per_row, b, c, c_stride,
[&alpha_val, &beta_val](const auto& x, const OutputValueType& y) {
return static_cast<OutputValueType>(
alpha_val * x + static_cast<arithmetic_type>(beta_val * y));
});
if (is_zero(beta_val)) {
spmv_kernel<num_thread_per_worker, atomic>(
num_rows, num_worker_per_row, val, col, stride,
num_stored_elements_per_row, b, c, c_stride,
[&alpha_val](const auto& x, const OutputValueType& y) {
return static_cast<OutputValueType>(alpha_val * x);
});
} else {
spmv_kernel<num_thread_per_worker, atomic>(
num_rows, num_worker_per_row, val, col, stride,
num_stored_elements_per_row, b, c, c_stride,
[&alpha_val, &beta_val](const auto& x,
const OutputValueType& y) {
return static_cast<OutputValueType>(
alpha_val * x +
static_cast<arithmetic_type>(beta_val * y));
});
}
}
}

Expand Down
6 changes: 4 additions & 2 deletions common/cuda_hip/matrix/sellp_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -82,7 +82,9 @@ __global__ __launch_bounds__(default_block_size) void advanced_spmv_kernel(
}
}
c[row * c_stride + column_id] =
beta[0] * c[row * c_stride + column_id] + alpha[0] * val;
is_zero(beta[0])
? alpha[0] * val
: beta[0] * c[row * c_stride + column_id] + alpha[0] * val;
}
}

Expand Down
20 changes: 14 additions & 6 deletions common/cuda_hip/matrix/sparsity_csr_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -121,11 +121,19 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_classical_spmv(
using type = typename output_accessor::arithmetic_type;
const type alpha_val = alpha[0];
const type beta_val = beta[0];
device_classical_spmv<subwarp_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val, &beta_val](const type& x, const type& y) {
return alpha_val * x + beta_val * y;
});
if (is_zero(beta_val)) {
device_classical_spmv<subwarp_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val](const type& x, const type& y) {
return alpha_val * x;
});
} else {
device_classical_spmv<subwarp_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val, &beta_val](const type& x, const type& y) {
return alpha_val * x + beta_val * y;
});
}
}


Expand Down
24 changes: 18 additions & 6 deletions common/unified/matrix/dense_kernels.template.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -83,7 +83,11 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto x) {
x(row, col) *= alpha[0];
if (is_zero(alpha[0])) {
x(row, col) = zero(alpha[0]);
} else {
x(row, col) *= alpha[0];
}
},
x->get_size(), alpha->get_const_values(), x);
}
Expand Down Expand Up @@ -129,7 +133,9 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto x, auto y) {
y(row, col) += alpha[0] * x(row, col);
if (is_nonzero(alpha[0])) {
y(row, col) += alpha[0] * x(row, col);
}
},
x->get_size(), alpha->get_const_values(), x, y);
}
Expand All @@ -152,7 +158,9 @@ void sub_scaled(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto x, auto y) {
y(row, col) -= alpha[0] * x(row, col);
if (is_nonzero(alpha[0])) {
y(row, col) -= alpha[0] * x(row, col);
}
},
x->get_size(), alpha->get_const_values(), x, y);
}
Expand All @@ -169,7 +177,9 @@ void add_scaled_diag(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto i, auto alpha, auto diag, auto y) {
y(i, i) += alpha[0] * diag[i];
if (is_nonzero(alpha[0])) {
y(i, i) += alpha[0] * diag[i];
}
},
x->get_size()[0], alpha->get_const_values(), x->get_const_values(), y);
}
Expand All @@ -185,7 +195,9 @@ void sub_scaled_diag(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto i, auto alpha, auto diag, auto y) {
y(i, i) -= alpha[0] * diag[i];
if (is_nonzero(alpha[0])) {
y(i, i) -= alpha[0] * diag[i];
}
},
x->get_size()[0], alpha->get_const_values(), x->get_const_values(), y);
}
Expand Down
41 changes: 29 additions & 12 deletions dpcpp/matrix/csr_kernels.dp.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -486,11 +486,19 @@ void abstract_merge_path_spmv(
using type = typename output_accessor::arithmetic_type;
const type alpha_val = static_cast<type>(alpha[0]);
const type beta_val = static_cast<type>(beta[0]);
merge_path_spmv<items_per_thread>(
num_rows, val, col_idxs, row_ptrs, srow, b, c, row_out, val_out,
[&alpha_val](const type& x) { return alpha_val * x; },
[&beta_val](const type& x) { return beta_val * x; }, item_ct1,
shared_row_ptrs);
if (is_zero(beta_val)) {
merge_path_spmv<items_per_thread>(
num_rows, val, col_idxs, row_ptrs, srow, b, c, row_out, val_out,
[&alpha_val](const type& x) { return alpha_val * x; },
[](const type& x) { return zero<type>(); }, item_ct1,
shared_row_ptrs);
} else {
merge_path_spmv<items_per_thread>(
num_rows, val, col_idxs, row_ptrs, srow, b, c, row_out, val_out,
[&alpha_val](const type& x) { return alpha_val * x; },
[&beta_val](const type& x) { return beta_val * x; }, item_ct1,
shared_row_ptrs);
}
}

template <int items_per_thread, typename matrix_accessor,
Expand Down Expand Up @@ -701,12 +709,21 @@ void abstract_classical_spmv(
using type = typename output_accessor::arithmetic_type;
const type alpha_val = static_cast<type>(alpha[0]);
const type beta_val = static_cast<type>(beta[0]);
device_classical_spmv<subgroup_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val, &beta_val](const type& x, const type& y) {
return alpha_val * x + beta_val * y;
},
item_ct1);
if (is_zero(beta_val)) {
device_classical_spmv<subgroup_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val](const type& x, const type& y) {
return alpha_val * x;
},
item_ct1);
} else {
device_classical_spmv<subgroup_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val, &beta_val](const type& x, const type& y) {
return alpha_val * x + beta_val * y;
},
item_ct1);
}
}

template <size_type subgroup_size, typename matrix_accessor,
Expand Down
30 changes: 21 additions & 9 deletions dpcpp/matrix/ell_kernels.dp.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -233,14 +233,26 @@ void spmv(
},
item_ct1, storage);
} else {
spmv_kernel<num_thread_per_worker, atomic>(
num_rows, num_worker_per_row, val, col, stride,
num_stored_elements_per_row, b, c, c_stride,
[&alpha_val, &beta_val](const auto& x, const OutputValueType& y) {
return static_cast<OutputValueType>(
alpha_val * x + static_cast<arithmetic_type>(beta_val * y));
},
item_ct1, storage);
if (is_zero(beta_val)) {
spmv_kernel<num_thread_per_worker, atomic>(
num_rows, num_worker_per_row, val, col, stride,
num_stored_elements_per_row, b, c, c_stride,
[&alpha_val](const auto& x, const OutputValueType& y) {
return static_cast<OutputValueType>(alpha_val * x);
},
item_ct1, storage);
} else {
spmv_kernel<num_thread_per_worker, atomic>(
num_rows, num_worker_per_row, val, col, stride,
num_stored_elements_per_row, b, c, c_stride,
[&alpha_val, &beta_val](const auto& x,
const OutputValueType& y) {
return static_cast<OutputValueType>(
alpha_val * x +
static_cast<arithmetic_type>(beta_val * y));
},
item_ct1, storage);
}
}
}

Expand Down
6 changes: 4 additions & 2 deletions dpcpp/matrix/sellp_kernels.dp.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -93,7 +93,9 @@ void advanced_spmv_kernel(size_type num_rows, size_type num_right_hand_sides,
}
}
c[row * c_stride + column_id] =
beta[0] * c[row * c_stride + column_id] + alpha[0] * val;
is_zero(beta[0])
? alpha[0] * val
: alpha[0] * val + beta[0] * c[row * c_stride + column_id];
}
}

Expand Down
23 changes: 16 additions & 7 deletions dpcpp/matrix/sparsity_csr_kernels.dp.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -129,12 +129,21 @@ void abstract_classical_spmv(
using type = typename output_accessor::arithmetic_type;
const type alpha_val = static_cast<type>(alpha[0]);
const type beta_val = static_cast<type>(beta[0]);
device_classical_spmv<subgroup_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val, &beta_val](const type& x, const type& y) {
return alpha_val * x + beta_val * y;
},
item_ct1);
if (is_zero(beta_val)) {
device_classical_spmv<subgroup_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val](const type& x, const type& y) {
return alpha_val * x;
},
item_ct1);
} else {
device_classical_spmv<subgroup_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val, &beta_val](const type& x, const type& y) {
return alpha_val * x + beta_val * y;
},
item_ct1);
}
}

template <size_type subgroup_size, typename MatrixValueType,
Expand Down
1 change: 1 addition & 0 deletions extensions/test/config/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
find_package(yaml-cpp 0.8.0 QUIET)
if(NOT yaml-cpp_FOUND)
message(STATUS "Fetching external yaml-cpp")
include(FetchContent)
FetchContent_Declare(
yaml-cpp
GIT_REPOSITORY https://github.com/jbeder/yaml-cpp.git
Expand Down
4 changes: 2 additions & 2 deletions omp/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -106,7 +106,7 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
#pragma omp parallel for
for (size_type row = 0; row < a->get_size()[0]; ++row) {
for (size_type j = 0; j < c->get_size()[1]; ++j) {
auto sum = c_vals(row, j) * vbeta;
auto sum = is_zero(vbeta) ? zero(vbeta) : c_vals(row, j) * vbeta;
for (size_type k = row_ptrs[row];
k < static_cast<size_type>(row_ptrs[row + 1]); ++k) {
arithmetic_type val = a_vals(k);
Expand Down
4 changes: 2 additions & 2 deletions omp/matrix/dense_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -120,7 +120,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
#pragma omp parallel for
for (size_type row = 0; row < c->get_size()[0]; ++row) {
for (size_type col = 0; col < c->get_size()[1]; ++col) {
c->at(row, col) *= zero<ValueType>();
c->at(row, col) = zero<ValueType>();
}
}
}
Expand Down
Loading