-
Notifications
You must be signed in to change notification settings - Fork 71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
groupreduction and subgroupreduction #421
Changes from 12 commits
cc510ac
1c1e459
dd3a0ca
546e8c9
3602808
42a7960
c96a24a
d2d65be
128a5f0
b899685
1cdb6d6
1fea4cc
41356d3
88662f8
45844ce
c5dc356
e2c8f84
700d5f2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
export @groupreduce, @subgroupreduce | ||
|
||
""" | ||
|
||
@subgroupreduce(op, val) | ||
|
||
reduce values across a subgroup. This operation is only supported if subgroups are supported by the backend. | ||
""" | ||
macro subgroupreduce(op, val) | ||
quote | ||
$__subgroupreduce($(esc(op)),$(esc(val))) | ||
end | ||
end | ||
|
||
function __subgroupreduce(op, val) | ||
error("@subgroupreduce used outside kernel, not captured, or not supported") | ||
end | ||
|
||
""" | ||
|
||
@groupreduce(op, val, neutral, use_subgroups) | ||
|
||
Reduce values across a block | ||
- `op`: the operator of the reduction | ||
- `val`: value that each thread contibutes to the values that need to be reduced | ||
- `netral`: value of the operator, so that `op(netural, neutral) = neutral`` | ||
- `use_subgroups`: make use of the subgroupreduction of the groupreduction | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see the value of having a common implementation. So I would define:
And then maybe:
And then we could define:
as you have here.
|
||
""" | ||
macro groupreduce(op, val, neutral, use_subgroups) | ||
quote | ||
$__groupreduce($(esc(:__ctx__)),$(esc(op)), $(esc(val)), $(esc(neutral)), $(esc(typeof(val))), Val(use_subgroups)) | ||
end | ||
end | ||
|
||
@inline function __groupreduce(__ctx__, op, val, neutral, ::Type{T}, ::Val{true}) where {T} | ||
idx_in_group = @index(Local) | ||
groupsize = @groupsize()[1] | ||
subgroupsize = @subgroupsize() | ||
|
||
localmem = @localmem(T, subgroupsize) | ||
|
||
idx_subgroup, idx_in_subgroup = fldmod1(idx_in_group, subgroupsize) | ||
|
||
# first subgroup reduction | ||
val = @subgroupreduce(op, val) | ||
|
||
# store partial results in local memory | ||
if idx_in_subgroup == 1 | ||
@inbounds localmem[idx_in_subgroup] = val | ||
end | ||
|
||
@synchronize() | ||
|
||
val = if idx_in_subgroup <= fld1(groupsize, subgroupsize) | ||
@inbounds localmem[idx_in_subgroup] | ||
else | ||
neutral | ||
end | ||
|
||
# second subgroup reduction to reduce partial results | ||
if idx_in_subgroup == 1 | ||
val = @subgroupreduce(op, val) | ||
end | ||
|
||
return val | ||
end | ||
|
||
@inline function __groupreduce(__ctx__, op, val, neutral, ::Type{T}, ::Val{false}) where {T} | ||
idx_in_group = @index(Local) | ||
groupsize = @groupsize()[1] | ||
|
||
localmem = @localmem(T, groupsize) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see if we can do a subgroupreduce the memory we need here is much reduced. |
||
|
||
@inbounds localmem[idx_in_group] = val | ||
|
||
# perform the reduction | ||
d = 1 | ||
while d < groupsize | ||
@synchronize() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Workaround? |
||
index = 2 * d * (idx_in_group-1) + 1 | ||
@inbounds if index <= groupsize | ||
other_val = if index + d <= groupsize | ||
localmem[index+d] | ||
else | ||
neutral | ||
end | ||
localmem[index] = op(localmem[index], other_val) | ||
end | ||
d *= 2 | ||
end | ||
|
||
# load the final value on the first thread | ||
if idx_in_group == 1 | ||
val = @inbounds localmem[idx_in_group] | ||
end | ||
|
||
return val | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
using KernelAbstractions, Test | ||
|
||
|
||
|
||
|
||
@kernel function reduce(a, b, op, neutral) | ||
idx_in_group = @index(Local) | ||
|
||
val = a[idx_in_group] | ||
|
||
val = @groupreduce(op, val, netral, false) | ||
|
||
b[1] = val | ||
end | ||
|
||
function(backend, ArrayT) | ||
@testset "groupreduce one group" begin | ||
@testset for op in (+,*,max,min) | ||
@testset for type in (Int32, Float32, Float64) | ||
@test test_1group_groupreduce(backend, ArrayT ,op, type, op(neutral)) | ||
end | ||
end | ||
end | ||
end | ||
|
||
function test_1group_groupreduce(backend,ArrayT, op, type, neutral) | ||
a = rand(type, 32) | ||
b = ArrayT(a) | ||
|
||
c = similar(b,1) | ||
reduce(a, c, op, neutral) | ||
|
||
expected = mapreduce(x->x^2, +, a) | ||
actual = c[1] | ||
return expected = actual | ||
end | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.