Commit f7c8214

add zero-checks to axpy-like operations
This prevents NaNs from polluting the output
1 parent 4b31772 commit f7c8214

23 files changed: +173 -41 lines changed
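For context, the guard matters because IEEE 754 defines 0 * NaN and 0 * Inf as NaN: an advanced apply c = alpha * A * b + beta * c that blindly multiplies by a zero beta still reads c and propagates whatever garbage sits in it. A minimal standalone sketch of the problem and of the guarded update (plain doubles and a plain comparison instead of Ginkgo's is_zero/zero helpers; the values are made up for illustration):

    #include <cstdio>
    #include <limits>

    int main()
    {
        const double alpha = 1.0;
        const double beta = 0.0;
        const double Ab = 3.0;  // stands in for a finite entry of A * b
        // c was only allocated, so it may hold NaN or Inf garbage
        const double c = std::numeric_limits<double>::quiet_NaN();

        // unguarded update: 0.0 * NaN evaluates to NaN and pollutes the result
        const double unguarded = alpha * Ab + beta * c;
        // guarded update (what this commit adds): beta == 0 never touches c
        const double guarded = beta == 0.0 ? alpha * Ab : alpha * Ab + beta * c;

        std::printf("unguarded = %f, guarded = %f\n", unguarded, guarded);
        // prints: unguarded = nan, guarded = 3.000000
    }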

common/cuda_hip/matrix/csr_kernels.hpp.inc

+5 -2

@@ -380,7 +380,9 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_merge_path_spmv(
     merge_path_spmv<items_per_thread>(
         num_rows, val, col_idxs, row_ptrs, srow, b, c, row_out, val_out,
         [&alpha_val](const type& x) { return alpha_val * x; },
-        [&beta_val](const type& x) { return beta_val * x; });
+        [&beta_val](const type& x) {
+            return is_zero(beta_val) ? zero(beta_val) : beta_val * x;
+        });
 }

@@ -480,7 +482,8 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_classical_spmv(
     device_classical_spmv<subwarp_size>(
         num_rows, val, col_idxs, row_ptrs, b, c,
         [&alpha_val, &beta_val](const type& x, const type& y) {
-            return alpha_val * x + beta_val * y;
+            return is_zero(beta_val) ? alpha_val * x
+                                     : alpha_val * x + beta_val * y;
         });
 }

common/cuda_hip/matrix/ell_kernels.hpp.inc

+4 -1

@@ -124,7 +124,10 @@ __global__ __launch_bounds__(default_block_size) void spmv(
         num_stored_elements_per_row, b, c, c_stride,
         [&alpha_val, &beta_val](const auto& x, const OutputValueType& y) {
             return static_cast<OutputValueType>(
-                alpha_val * x + static_cast<arithmetic_type>(beta_val * y));
+                is_zero(beta_val)
+                    ? alpha_val * x
+                    : alpha_val * x +
+                          static_cast<arithmetic_type>(beta_val * y));
         });
     }
 }

common/cuda_hip/matrix/sellp_kernels.hpp.inc

+3 -1

@@ -51,7 +51,9 @@ __global__ __launch_bounds__(default_block_size) void advanced_spmv_kernel(
             }
         }
         c[row * c_stride + column_id] =
-            beta[0] * c[row * c_stride + column_id] + alpha[0] * val;
+            is_zero(beta[0])
+                ? alpha[0] * val
+                : beta[0] * c[row * c_stride + column_id] + alpha[0] * val;
     }
 }

common/cuda_hip/matrix/sparsity_csr_kernels.hpp.inc

+2 -1

@@ -74,7 +74,8 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_classical_spmv(
     device_classical_spmv<subwarp_size>(
         num_rows, val, col_idxs, row_ptrs, b, c,
         [&alpha_val, &beta_val](const type& x, const type& y) {
-            return alpha_val * x + beta_val * y;
+            return is_zero(beta_val) ? alpha_val * x
+                                     : alpha_val * x + beta_val * y;
         });
 }

common/unified/matrix/dense_kernels.template.cpp

+22 -6

@@ -77,14 +77,22 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
         run_kernel(
             exec,
             [] GKO_KERNEL(auto row, auto col, auto alpha, auto x) {
-                x(row, col) *= alpha[col];
+                if (is_zero(alpha[col])) {
+                    x(row, col) = zero(alpha[col]);
+                } else {
+                    x(row, col) *= alpha[col];
+                }
             },
             x->get_size(), alpha->get_const_values(), x);
     } else {
         run_kernel(
             exec,
             [] GKO_KERNEL(auto row, auto col, auto alpha, auto x) {
-                x(row, col) *= alpha[0];
+                if (is_zero(alpha[0])) {
+                    x(row, col) = zero(alpha[0]);
+                } else {
+                    x(row, col) *= alpha[0];
+                }
             },
             x->get_size(), alpha->get_const_values(), x);
     }

@@ -130,7 +138,9 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
         run_kernel(
             exec,
             [] GKO_KERNEL(auto row, auto col, auto alpha, auto x, auto y) {
-                y(row, col) += alpha[0] * x(row, col);
+                if (is_nonzero(alpha[0])) {
+                    y(row, col) += alpha[0] * x(row, col);
+                }
             },
             x->get_size(), alpha->get_const_values(), x, y);
     }

@@ -153,7 +163,9 @@ void sub_scaled(std::shared_ptr<const DefaultExecutor> exec,
         run_kernel(
             exec,
             [] GKO_KERNEL(auto row, auto col, auto alpha, auto x, auto y) {
-                y(row, col) -= alpha[0] * x(row, col);
+                if (is_nonzero(alpha[0])) {
+                    y(row, col) -= alpha[0] * x(row, col);
+                }
             },
             x->get_size(), alpha->get_const_values(), x, y);
     }

@@ -170,7 +182,9 @@ void add_scaled_diag(std::shared_ptr<const DefaultExecutor> exec,
     run_kernel(
         exec,
         [] GKO_KERNEL(auto i, auto alpha, auto diag, auto y) {
-            y(i, i) += alpha[0] * diag[i];
+            if (is_nonzero(alpha[0])) {
+                y(i, i) += alpha[0] * diag[i];
+            }
         },
         x->get_size()[0], alpha->get_const_values(), x->get_const_values(), y);
 }

@@ -186,7 +200,9 @@ void sub_scaled_diag(std::shared_ptr<const DefaultExecutor> exec,
     run_kernel(
         exec,
         [] GKO_KERNEL(auto i, auto alpha, auto diag, auto y) {
-            y(i, i) -= alpha[0] * diag[i];
+            if (is_nonzero(alpha[0])) {
+                y(i, i) -= alpha[0] * diag[i];
+            }
         },
         x->get_size()[0], alpha->get_const_values(), x->get_const_values(), y);
 }
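Editorial aside (not part of the commit message): these hunks introduce two distinct conventions. scale with a zero alpha overwrites each entry with an exact zero, so NaN/Inf entries already in x are cleared, while add_scaled, sub_scaled, and the *_diag variants with a zero alpha skip the update entirely and leave y untouched. A small self-contained sketch of the two behaviours, again with plain doubles in place of the GKO_KERNEL lambdas and with made-up values:

    #include <cmath>
    #include <cstdio>

    int main()
    {
        const double alpha = 0.0;
        double x = INFINITY;  // a polluted input entry
        double y = 2.0;       // a valid output entry

        // scale: assign an exact zero instead of computing x * 0 (which is NaN)
        x = alpha == 0.0 ? 0.0 : x * alpha;
        // add_scaled: skip the axpy entirely, so y keeps its previous value
        if (alpha != 0.0) { y += alpha * x; }

        std::printf("x = %f, y = %f\n", x, y);  // prints: x = 0.000000, y = 2.000000
    }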

dpcpp/matrix/csr_kernels.dp.cpp

+6 -3

@@ -490,8 +490,10 @@ void abstract_merge_path_spmv(
     merge_path_spmv<items_per_thread>(
         num_rows, val, col_idxs, row_ptrs, srow, b, c, row_out, val_out,
         [&alpha_val](const type& x) { return alpha_val * x; },
-        [&beta_val](const type& x) { return beta_val * x; }, item_ct1,
-        shared_row_ptrs);
+        [&beta_val](const type& x) {
+            return is_zero(beta_val) ? zero(beta_val) : beta_val * x;
+        },
+        item_ct1, shared_row_ptrs);
 }

 template <int items_per_thread, typename matrix_accessor,

@@ -713,7 +715,8 @@ void abstract_classical_spmv(
     device_classical_spmv<subgroup_size>(
         num_rows, val, col_idxs, row_ptrs, b, c,
         [&alpha_val, &beta_val](const type& x, const type& y) {
-            return alpha_val * x + beta_val * y;
+            return is_zero(beta_val) ? alpha_val * x
+                                     : alpha_val * x + beta_val * y;
         },
         item_ct1);
 }

dpcpp/matrix/ell_kernels.dp.cpp

+4 -1

@@ -239,7 +239,10 @@ void spmv(
         num_stored_elements_per_row, b, c, c_stride,
         [&alpha_val, &beta_val](const auto& x, const OutputValueType& y) {
             return static_cast<OutputValueType>(
-                alpha_val * x + static_cast<arithmetic_type>(beta_val * y));
+                is_zero(beta_val)
+                    ? alpha_val * x
+                    : alpha_val * x +
+                          static_cast<arithmetic_type>(beta_val * y));
         },
         item_ct1, storage);
 }

dpcpp/matrix/sellp_kernels.dp.cpp

+3 -1

@@ -96,7 +96,9 @@ void advanced_spmv_kernel(size_type num_rows, size_type num_right_hand_sides,
             }
         }
         c[row * c_stride + column_id] =
-            beta[0] * c[row * c_stride + column_id] + alpha[0] * val;
+            is_zero(beta[0])
+                ? alpha[0] * val
+                : alpha[0] * val + beta[0] * c[row * c_stride + column_id];
     }
 }

dpcpp/matrix/sparsity_csr_kernels.dp.cpp

+2 -1

@@ -132,7 +132,8 @@ void abstract_classical_spmv(
     device_classical_spmv<subgroup_size>(
         num_rows, val, col_idxs, row_ptrs, b, c,
         [&alpha_val, &beta_val](const type& x, const type& y) {
-            return alpha_val * x + beta_val * y;
+            return is_zero(beta_val) ? alpha_val * x
+                                     : alpha_val * x + beta_val * y;
         },
         item_ct1);
 }

omp/matrix/csr_kernels.cpp

+1 -1

@@ -110,7 +110,7 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
 #pragma omp parallel for
     for (size_type row = 0; row < a->get_size()[0]; ++row) {
         for (size_type j = 0; j < c->get_size()[1]; ++j) {
-            auto sum = c_vals(row, j) * vbeta;
+            auto sum = is_zero(vbeta) ? zero(vbeta) : c_vals(row, j) * vbeta;
             for (size_type k = row_ptrs[row];
                  k < static_cast<size_type>(row_ptrs[row + 1]); ++k) {
                 arithmetic_type val = a_vals(k);

omp/matrix/dense_kernels.cpp

+1 -1

@@ -124,7 +124,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
 #pragma omp parallel for
     for (size_type row = 0; row < c->get_size()[0]; ++row) {
         for (size_type col = 0; col < c->get_size()[1]; ++col) {
-            c->at(row, col) *= zero<ValueType>();
+            c->at(row, col) = zero<ValueType>();
         }
     }
 }
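This one-character change fixes a slightly different instance of the same issue: multiplying by an explicit zero<ValueType>() does not clear NaN or Inf entries already sitting in c, whereas assigning zero does. In plain C++ terms (illustrative values only):

    #include <cmath>
    #include <cstdio>

    int main()
    {
        double c = NAN;
        c *= 0.0;                // still NaN, since NaN * 0 == NaN
        std::printf("%f\n", c);  // prints: nan
        c = 0.0;                 // assignment actually clears the entry
        std::printf("%f\n", c);  // prints: 0.000000
    }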

omp/matrix/ell_kernels.cpp

+3 -1

@@ -211,7 +211,9 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
     const auto alpha_val = arithmetic_type{alpha->at(0, 0)};
     const auto beta_val = arithmetic_type{beta->at(0, 0)};
     auto out = [&](auto i, auto j, auto value) {
-        return alpha_val * value + beta_val * arithmetic_type{c->at(i, j)};
+        return is_zero(beta_val) ? alpha_val * value
+                                 : alpha_val * value +
+                                       beta_val * arithmetic_type{c->at(i, j)};
     };
     if (num_rhs == 1) {
         spmv_small_rhs<1>(exec, a, b, c, out);

omp/matrix/fbcsr_kernels.cpp

+5 -1

@@ -105,7 +105,11 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
     for (IndexType ibrow = 0; ibrow < nbrows; ++ibrow) {
         for (IndexType row = ibrow * bs; row < (ibrow + 1) * bs; ++row) {
             for (IndexType rhs = 0; rhs < nvecs; rhs++) {
-                c->at(row, rhs) *= vbeta;
+                if (is_zero(vbeta)) {
+                    c->at(row, rhs) = zero(vbeta);
+                } else {
+                    c->at(row, rhs) *= vbeta;
+                }
             }
         }
         for (IndexType inz = row_ptrs[ibrow]; inz < row_ptrs[ibrow + 1];

omp/matrix/sellp_kernels.cpp

+2 -1

@@ -176,7 +176,8 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
     const auto alpha_val = alpha->at(0, 0);
     const auto beta_val = beta->at(0, 0);
     auto out = [&](auto i, auto j, auto value) {
-        return alpha_val * value + beta_val * c->at(i, j);
+        return is_zero(beta_val) ? alpha_val * value
+                                 : alpha_val * value + beta_val * c->at(i, j);
     };
     if (num_rhs == 1) {
         spmv_small_rhs<1>(exec, a, b, c, out);

omp/matrix/sparsity_csr_kernels.cpp

+3 -1

@@ -93,7 +93,9 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
                     val * static_cast<arithmetic_type>(b->at(col_idxs[k], j));
             }
             c->at(row, j) = static_cast<OutputValueType>(
-                vbeta * static_cast<arithmetic_type>(c->at(row, j)) +
+                (is_zero(vbeta)
+                     ? zero(vbeta)
+                     : vbeta * static_cast<arithmetic_type>(c->at(row, j))) +
                 valpha * temp_val);
         }
     }

reference/matrix/csr_kernels.cpp

+1 -1

@@ -106,7 +106,7 @@ void advanced_spmv(std::shared_ptr<const ReferenceExecutor> exec,
     auto c_vals = acc::helper::build_rrm_accessor<arithmetic_type>(c);
     for (size_type row = 0; row < a->get_size()[0]; ++row) {
         for (size_type j = 0; j < c->get_size()[1]; ++j) {
-            auto sum = c_vals(row, j) * vbeta;
+            auto sum = is_zero(vbeta) ? zero(vbeta) : c_vals(row, j) * vbeta;
             for (size_type k = row_ptrs[row];
                  k < static_cast<size_type>(row_ptrs[row + 1]); ++k) {
                 arithmetic_type val = a_vals(k);

reference/matrix/dense_kernels.cpp

+24 -12

@@ -77,7 +77,7 @@ void apply(std::shared_ptr<const ReferenceExecutor> exec,
     } else {
         for (size_type row = 0; row < c->get_size()[0]; ++row) {
             for (size_type col = 0; col < c->get_size()[1]; ++col) {
-                c->at(row, col) *= zero<ValueType>();
+                c->at(row, col) = zero<ValueType>();
             }
         }
     }

@@ -133,7 +133,11 @@ void scale(std::shared_ptr<const ReferenceExecutor> exec,
     if (alpha->get_size()[1] == 1) {
         for (size_type i = 0; i < x->get_size()[0]; ++i) {
             for (size_type j = 0; j < x->get_size()[1]; ++j) {
-                x->at(i, j) *= alpha->at(0, 0);
+                if (is_zero(alpha->at(0, 0))) {
+                    x->at(i, j) = zero<ValueType>();
+                } else {
+                    x->at(i, j) *= alpha->at(0, 0);
+                }
             }
         }
     } else {

@@ -178,9 +182,11 @@ void add_scaled(std::shared_ptr<const ReferenceExecutor> exec,
                 const matrix::Dense<ValueType>* x, matrix::Dense<ValueType>* y)
 {
     if (alpha->get_size()[1] == 1) {
-        for (size_type i = 0; i < x->get_size()[0]; ++i) {
-            for (size_type j = 0; j < x->get_size()[1]; ++j) {
-                y->at(i, j) += alpha->at(0, 0) * x->at(i, j);
+        if (is_nonzero(alpha->at(0, 0))) {
+            for (size_type i = 0; i < x->get_size()[0]; ++i) {
+                for (size_type j = 0; j < x->get_size()[1]; ++j) {
+                    y->at(i, j) += alpha->at(0, 0) * x->at(i, j);
+                }
             }
         }
     } else {

@@ -202,9 +208,11 @@ void sub_scaled(std::shared_ptr<const ReferenceExecutor> exec,
                 const matrix::Dense<ValueType>* x, matrix::Dense<ValueType>* y)
 {
     if (alpha->get_size()[1] == 1) {
-        for (size_type i = 0; i < x->get_size()[0]; ++i) {
-            for (size_type j = 0; j < x->get_size()[1]; ++j) {
-                y->at(i, j) -= alpha->at(0, 0) * x->at(i, j);
+        if (is_nonzero(alpha->at(0, 0))) {
+            for (size_type i = 0; i < x->get_size()[0]; ++i) {
+                for (size_type j = 0; j < x->get_size()[1]; ++j) {
+                    y->at(i, j) -= alpha->at(0, 0) * x->at(i, j);
+                }
             }
         }
     } else {

@@ -227,8 +235,10 @@ void add_scaled_diag(std::shared_ptr<const ReferenceExecutor> exec,
                      matrix::Dense<ValueType>* y)
 {
     const auto diag_values = x->get_const_values();
-    for (size_type i = 0; i < x->get_size()[0]; i++) {
-        y->at(i, i) += alpha->at(0, 0) * diag_values[i];
+    if (is_nonzero(alpha->at(0, 0))) {
+        for (size_type i = 0; i < x->get_size()[0]; i++) {
+            y->at(i, i) += alpha->at(0, 0) * diag_values[i];
+        }
     }
 }

@@ -242,8 +252,10 @@ void sub_scaled_diag(std::shared_ptr<const ReferenceExecutor> exec,
                      matrix::Dense<ValueType>* y)
 {
     const auto diag_values = x->get_const_values();
-    for (size_type i = 0; i < x->get_size()[0]; i++) {
-        y->at(i, i) -= alpha->at(0, 0) * diag_values[i];
+    if (is_nonzero(alpha->at(0, 0))) {
+        for (size_type i = 0; i < x->get_size()[0]; i++) {
+            y->at(i, i) -= alpha->at(0, 0) * diag_values[i];
+        }
     }
 }

reference/matrix/ell_kernels.cpp

+2 -2

@@ -109,8 +109,8 @@ void advanced_spmv(std::shared_ptr<const ReferenceExecutor> exec,

     for (size_type j = 0; j < c->get_size()[1]; j++) {
         for (size_type row = 0; row < a->get_size()[0]; row++) {
-            arithmetic_type result = c->at(row, j);
-            result *= beta_val;
+            arithmetic_type result =
+                is_zero(beta_val) ? zero(beta_val) : beta_val * c->at(row, j);
             for (size_type i = 0; i < num_stored_elements_per_row; i++) {
                 arithmetic_type val = a_vals(row + i * stride);
                 auto col = a->col_at(row, i);

reference/matrix/fbcsr_kernels.cpp

+5 -1

@@ -103,7 +103,11 @@ void advanced_spmv(const std::shared_ptr<const ReferenceExecutor>,
     for (IndexType ibrow = 0; ibrow < nbrows; ++ibrow) {
         for (IndexType row = ibrow * bs; row < (ibrow + 1) * bs; ++row) {
             for (IndexType rhs = 0; rhs < nvecs; rhs++) {
-                c->at(row, rhs) *= vbeta;
+                if (is_zero(vbeta)) {
+                    c->at(row, rhs) = zero(vbeta);
+                } else {
+                    c->at(row, rhs) *= vbeta;
+                }
             }
         }

reference/matrix/sellp_kernels.cpp

+5 -1

@@ -83,7 +83,11 @@ void advanced_spmv(std::shared_ptr<const ReferenceExecutor> exec,
                 break;
             }
             for (size_type j = 0; j < c->get_size()[1]; j++) {
-                c->at(global_row, j) *= vbeta;
+                if (is_nonzero(vbeta)) {
+                    c->at(global_row, j) *= vbeta;
+                } else {
+                    c->at(global_row, j) = zero<ValueType>();
+                }
             }
             for (size_type i = 0; i < slice_lengths[slice]; i++) {
                 auto val = a->val_at(row, slice_sets[slice], i);

reference/matrix/sparsity_csr_kernels.cpp

+3 -1

@@ -89,7 +89,9 @@ void advanced_spmv(std::shared_ptr<const ReferenceExecutor> exec,
                     val * static_cast<arithmetic_type>(b->at(col_idxs[k], j));
             }
             c->at(row, j) = static_cast<OutputValueType>(
-                vbeta * static_cast<arithmetic_type>(c->at(row, j)) +
+                (is_zero(vbeta)
+                     ? zero(vbeta)
+                     : vbeta * static_cast<arithmetic_type>(c->at(row, j))) +
                 valpha * temp_val);
         }
     }
