Skip to content

Commit c984d91

Browse files
committed
add work estimate framework
1 parent 23688e2 commit c984d91

11 files changed

+336
-75
lines changed

core/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ target_sources(
4343
base/segmented_array.cpp
4444
base/timer.cpp
4545
base/version.cpp
46+
base/work_estimate.cpp
4647
components/range_minimum_query.cpp
4748
config/config.cpp
4849
config/config_helper.cpp

core/base/work_estimate.cpp

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#include <ginkgo/core/base/work_estimate.hpp>
6+
7+
8+
namespace gko {
9+
10+
11+
// Combines two compute-bound estimates into one whose flop count is the sum
// of both operands' flop counts.
compute_bound_work_estimate operator+(compute_bound_work_estimate a,
                                      compute_bound_work_estimate b)
{
    const auto total_flops = a.flops + b.flops;
    return {total_flops};
}
16+
17+
18+
// Accumulates `other` into this estimate by delegating to the binary
// operator+ and returns a reference to the updated object.
compute_bound_work_estimate& compute_bound_work_estimate::operator+=(
    compute_bound_work_estimate other)
{
    return *this = *this + other;
}
24+
25+
26+
// Combines two memory-bound estimates by summing the bytes read and the
// bytes written of both operands separately.
memory_bound_work_estimate operator+(memory_bound_work_estimate a,
                                     memory_bound_work_estimate b)
{
    const auto total_read = a.bytes_read + b.bytes_read;
    const auto total_written = a.bytes_written + b.bytes_written;
    return {total_read, total_written};
}
31+
32+
33+
// Accumulates `other` into this estimate by delegating to the binary
// operator+ and returns a reference to the updated object.
memory_bound_work_estimate& memory_bound_work_estimate::operator+=(
    memory_bound_work_estimate other)
{
    return *this = *this + other;
}
39+
40+
41+
// Combines two custom estimates by summing their operation counts. Both
// operands are expected to count the same kind of operation, which is checked
// with GKO_ASSERT; the result carries the operation name of `a`.
custom_work_estimate operator+(custom_work_estimate a, custom_work_estimate b)
{
    GKO_ASSERT(a.operation_count_name == b.operation_count_name);
    const auto total_operations = a.operations + b.operations;
    return {a.operation_count_name, total_operations};
}
46+
47+
48+
// Accumulates `other` into this estimate by delegating to the binary
// operator+ (which asserts matching operation names) and returns a reference
// to the updated object.
custom_work_estimate& custom_work_estimate::operator+=(
    custom_work_estimate other)
{
    return *this = *this + other;
}
54+
55+
56+
// Adds two kernel work estimates that hold the same alternative by forwarding
// to the operator+ of that concrete estimate type.
kernel_work_estimate operator+(kernel_work_estimate a, kernel_work_estimate b)
{
    // Dispatch on the alternative stored in `a` and require `b` to store the
    // same one: std::get throws std::bad_variant_access if the two estimates
    // are of different types.
    const auto add_matching = [&b](auto lhs) -> kernel_work_estimate {
        using estimate_type = decltype(lhs);
        return lhs + std::get<estimate_type>(b);
    };
    return std::visit(add_matching, a);
}
66+
67+
68+
} // namespace gko

core/components/prefix_sum_kernels.hpp

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -10,6 +10,7 @@
1010

1111
#include <ginkgo/core/base/executor.hpp>
1212
#include <ginkgo/core/base/types.hpp>
13+
#include <ginkgo/core/base/work_estimate.hpp>
1314

1415
#include "core/base/kernel_declaration.hpp"
1516

@@ -53,6 +54,19 @@ GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(components,
5354
#undef GKO_DECLARE_ALL_AS_TEMPLATES
5455

5556

57+
namespace work_estimate::components {
58+
59+
60+
template <typename IndexType>
61+
kernel_work_estimate prefix_sum_nonnegative(IndexType* counts,
62+
size_type num_entries)
63+
{
64+
return memory_bound_work_estimate{num_entries * sizeof(IndexType),
65+
num_entries * sizeof(IndexType)};
66+
}
67+
68+
69+
} // namespace work_estimate::components
5670
} // namespace kernels
5771
} // namespace gko
5872

core/matrix/csr.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,8 @@ GKO_REGISTER_OPERATION(is_sorted_by_column_index,
8585
csr::is_sorted_by_column_index);
8686
GKO_REGISTER_OPERATION(extract_diagonal, csr::extract_diagonal);
8787
GKO_REGISTER_OPERATION(fill_array, components::fill_array);
88-
GKO_REGISTER_OPERATION(convert_precision, components::convert_precision);
89-
GKO_REGISTER_OPERATION(prefix_sum_nonnegative,
90-
components::prefix_sum_nonnegative);
88+
GKO_REGISTER_OPERATION_WITH_WORK_ESTIMATE(prefix_sum_nonnegative,
89+
components::prefix_sum_nonnegative);
9190
GKO_REGISTER_OPERATION(inplace_absolute_array,
9291
components::inplace_absolute_array);
9392
GKO_REGISTER_OPERATION(outplace_absolute_array,

core/matrix/dense.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -40,17 +40,18 @@ namespace dense {
4040
namespace {
4141

4242

43-
GKO_REGISTER_OPERATION(simple_apply, dense::simple_apply);
43+
GKO_REGISTER_OPERATION_WITH_WORK_ESTIMATE(simple_apply, dense::simple_apply);
4444
GKO_REGISTER_OPERATION(apply, dense::apply);
45-
GKO_REGISTER_OPERATION(copy, dense::copy);
46-
GKO_REGISTER_OPERATION(fill, dense::fill);
45+
GKO_REGISTER_OPERATION_WITH_WORK_ESTIMATE(copy, dense::copy);
46+
GKO_REGISTER_OPERATION_WITH_WORK_ESTIMATE(fill, dense::fill);
4747
GKO_REGISTER_OPERATION(scale, dense::scale);
4848
GKO_REGISTER_OPERATION(inv_scale, dense::inv_scale);
4949
GKO_REGISTER_OPERATION(add_scaled, dense::add_scaled);
5050
GKO_REGISTER_OPERATION(sub_scaled, dense::sub_scaled);
5151
GKO_REGISTER_OPERATION(add_scaled_diag, dense::add_scaled_diag);
5252
GKO_REGISTER_OPERATION(sub_scaled_diag, dense::sub_scaled_diag);
53-
GKO_REGISTER_OPERATION(compute_dot, dense::compute_dot_dispatch);
53+
GKO_REGISTER_OPERATION_WITH_WORK_ESTIMATE(compute_dot,
54+
dense::compute_dot_dispatch);
5455
GKO_REGISTER_OPERATION(compute_conj_dot, dense::compute_conj_dot_dispatch);
5556
GKO_REGISTER_OPERATION(compute_norm2, dense::compute_norm2_dispatch);
5657
GKO_REGISTER_OPERATION(compute_norm1, dense::compute_norm1);

core/matrix/dense_kernels.hpp

+49-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -10,6 +10,7 @@
1010

1111
#include <ginkgo/core/base/math.hpp>
1212
#include <ginkgo/core/base/types.hpp>
13+
#include <ginkgo/core/base/work_estimate.hpp>
1314
#include <ginkgo/core/matrix/dense.hpp>
1415
#include <ginkgo/core/matrix/diagonal.hpp>
1516

@@ -476,6 +477,53 @@ GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(dense, GKO_DECLARE_ALL_AS_TEMPLATES);
476477
#undef GKO_DECLARE_ALL_AS_TEMPLATES
477478

478479

480+
namespace work_estimate {
481+
namespace dense {
482+
483+
484+
template <typename ValueType>
485+
kernel_work_estimate simple_apply(const matrix::Dense<ValueType>* a,
486+
const matrix::Dense<ValueType>* b,
487+
matrix::Dense<ValueType>* c)
488+
{
489+
const auto a_rows = a->get_size()[0];
490+
const auto a_cols = a->get_size()[1];
491+
const auto b_cols = b->get_size()[1];
492+
return compute_bound_work_estimate{2 * a_rows * a_cols * b_cols};
493+
}
494+
495+
496+
template <typename InValueType, typename OutValueType>
497+
kernel_work_estimate copy(const matrix::Dense<InValueType>* input,
498+
matrix::Dense<OutValueType>* output)
499+
{
500+
const auto memsize = input->get_size()[0] * input->get_size()[1];
501+
return memory_bound_work_estimate{memsize * sizeof(InValueType),
502+
memsize * sizeof(OutValueType)};
503+
}
504+
505+
506+
template <typename ValueType>
507+
kernel_work_estimate fill(matrix::Dense<ValueType>* mat, ValueType value)
508+
{
509+
return memory_bound_work_estimate{
510+
0, mat->get_size()[0] * mat->get_size()[1] * sizeof(ValueType)};
511+
}
512+
513+
514+
template <typename ValueType>
515+
kernel_work_estimate compute_dot_dispatch(const matrix::Dense<ValueType>* x,
516+
const matrix::Dense<ValueType>* y,
517+
matrix::Dense<ValueType>* result,
518+
array<char>& tmp)
519+
{
520+
const auto num_elements = x->get_size()[0] * x->get_size()[1];
521+
return memory_bound_work_estimate{2 * num_elements * sizeof(ValueType), 0};
522+
}
523+
524+
525+
} // namespace dense
526+
} // namespace work_estimate
479527
} // namespace kernels
480528
} // namespace gko
481529

core/matrix/ell.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -42,8 +42,8 @@ GKO_REGISTER_OPERATION(convert_to_csr, ell::convert_to_csr);
4242
GKO_REGISTER_OPERATION(count_nonzeros_per_row, ell::count_nonzeros_per_row);
4343
GKO_REGISTER_OPERATION(extract_diagonal, ell::extract_diagonal);
4444
GKO_REGISTER_OPERATION(fill_array, components::fill_array);
45-
GKO_REGISTER_OPERATION(prefix_sum_nonnegative,
46-
components::prefix_sum_nonnegative);
45+
GKO_REGISTER_OPERATION_WITH_WORK_ESTIMATE(prefix_sum_nonnegative,
46+
components::prefix_sum_nonnegative);
4747
GKO_REGISTER_OPERATION(inplace_absolute_array,
4848
components::inplace_absolute_array);
4949
GKO_REGISTER_OPERATION(outplace_absolute_array,

core/matrix/hybrid.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -42,8 +42,8 @@ GKO_REGISTER_OPERATION(compute_coo_row_ptrs, hybrid::compute_coo_row_ptrs);
4242
GKO_REGISTER_OPERATION(convert_idxs_to_ptrs, components::convert_idxs_to_ptrs);
4343
GKO_REGISTER_OPERATION(convert_to_csr, hybrid::convert_to_csr);
4444
GKO_REGISTER_OPERATION(fill_array, components::fill_array);
45-
GKO_REGISTER_OPERATION(prefix_sum_nonnegative,
46-
components::prefix_sum_nonnegative);
45+
GKO_REGISTER_OPERATION_WITH_WORK_ESTIMATE(prefix_sum_nonnegative,
46+
components::prefix_sum_nonnegative);
4747
GKO_REGISTER_OPERATION(inplace_absolute_array,
4848
components::inplace_absolute_array);
4949
GKO_REGISTER_OPERATION(outplace_absolute_array,

core/matrix/sellp.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -31,8 +31,8 @@ namespace {
3131
GKO_REGISTER_OPERATION(spmv, sellp::spmv);
3232
GKO_REGISTER_OPERATION(advanced_spmv, sellp::advanced_spmv);
3333
GKO_REGISTER_OPERATION(convert_idxs_to_ptrs, components::convert_idxs_to_ptrs);
34-
GKO_REGISTER_OPERATION(prefix_sum_nonnegative,
35-
components::prefix_sum_nonnegative);
34+
GKO_REGISTER_OPERATION_WITH_WORK_ESTIMATE(prefix_sum_nonnegative,
35+
components::prefix_sum_nonnegative);
3636
GKO_REGISTER_OPERATION(compute_slice_sets, sellp::compute_slice_sets);
3737
GKO_REGISTER_OPERATION(fill_in_matrix_data, sellp::fill_in_matrix_data);
3838
GKO_REGISTER_OPERATION(fill_in_dense, sellp::fill_in_dense);

0 commit comments

Comments
 (0)