-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement groupreduce API #559
base: main
Are you sure you want to change the base?
Changes from all commits
e1a110f
ff4097f
6a35eb8
224e8c8
4a8e707
7c923fb
a647992
cbc8bd5
bb77270
db5abc5
618c840
344d484
1cd2d2f
7c96e5a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
export @groupreduce, @warp_groupreduce | ||
|
||
""" | ||
@groupreduce op val neutral [groupsize] | ||
|
||
Perform group reduction of `val` using `op`. | ||
|
||
# Arguments | ||
|
||
- `neutral` should be a neutral w.r.t. `op`, such that `op(neutral, x) == x`. | ||
|
||
- `groupsize` specifies size of the workgroup. | ||
If a kernel does not specifies `groupsize` statically, then it is required to | ||
provide `groupsize`. | ||
Also can be used to perform reduction accross first `groupsize` threads | ||
(if `groupsize < @groupsize()`). | ||
|
||
# Returns | ||
|
||
Result of the reduction. | ||
""" | ||
macro groupreduce(op, val) | ||
:(__thread_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), Val(prod($groupsize($(esc(:__ctx__))))))) | ||
end | ||
macro groupreduce(op, val, groupsize) | ||
:(__thread_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), Val($(esc(groupsize))))) | ||
end | ||
|
||
function __thread_groupreduce(__ctx__, op, val::T, ::Val{groupsize}) where {T, groupsize} | ||
storage = @localmem T groupsize | ||
|
||
local_idx = @index(Local) | ||
@inbounds local_idx ≤ groupsize && (storage[local_idx] = val) | ||
@synchronize() | ||
|
||
s::UInt64 = groupsize ÷ 0x02 | ||
while s > 0x00 | ||
if (local_idx - 0x01) < s | ||
other_idx = local_idx + s | ||
if other_idx ≤ groupsize | ||
@inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx]) | ||
end | ||
end | ||
@synchronize() | ||
s >>= 0x01 | ||
end | ||
|
||
if local_idx == 0x01 | ||
@inbounds val = storage[local_idx] | ||
end | ||
return val | ||
end | ||
|
||
# Warp groupreduce. | ||
|
||
""" | ||
@warp_groupreduce op val neutral [groupsize] | ||
|
||
Perform group reduction of `val` using `op`. | ||
Each warp within a workgroup performs its own reduction using [`shfl_down`](@ref) intrinsic, | ||
followed by final reduction over results of individual warp reductions. | ||
|
||
!!! note | ||
|
||
Use [`supports_warp_reduction`](@ref) to query if given backend supports warp reduction. | ||
""" | ||
macro warp_groupreduce(op, val, neutral) | ||
:(__warp_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), $(esc(neutral)), Val(prod($groupsize($(esc(:__ctx__))))))) | ||
end | ||
macro warp_groupreduce(op, val, neutral, groupsize) | ||
:(__warp_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), $(esc(neutral)), Val($(esc(groupsize))))) | ||
end | ||
|
||
""" | ||
shfl_down(val::T, offset::Integer)::T where T | ||
|
||
Read `val` from a lane with higher id given by `offset`. | ||
""" | ||
function shfl_down end | ||
supports_warp_reduction() = false | ||
|
||
""" | ||
supports_warp_reduction(::Backend) | ||
|
||
Query if given backend supports [`shfl_down`](@ref) intrinsic and thus warp reduction. | ||
""" | ||
supports_warp_reduction(::Backend) = false | ||
|
||
# Assume warp is 32 lanes. | ||
const __warpsize = UInt32(32) | ||
# Maximum number of warps (for a groupsize = 1024). | ||
const __warp_bins = UInt32(32) | ||
|
||
@inline function __warp_reduce(val, op) | ||
offset::UInt32 = __warpsize ÷ 0x02 | ||
while offset > 0x00 | ||
val = op(val, shfl_down(val, offset)) | ||
offset >>= 0x01 | ||
end | ||
return val | ||
end | ||
|
||
function __warp_groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}) where {T, groupsize} | ||
storage = @localmem T __warp_bins | ||
|
||
local_idx = @index(Local) | ||
lane = (local_idx - 0x01) % __warpsize + 0x01 | ||
warp_id = (local_idx - 0x01) ÷ __warpsize + 0x01 | ||
|
||
# Each warp performs a reduction and writes results into its own bin in `storage`. | ||
val = __warp_reduce(val, op) | ||
@inbounds lane == 0x01 && (storage[warp_id] = val) | ||
@synchronize() | ||
|
||
# Final reduction of the `storage` on the first warp. | ||
within_storage = (local_idx - 0x01) < groupsize ÷ __warpsize | ||
@inbounds val = within_storage ? storage[lane] : neutral | ||
warp_id == 0x01 && (val = __warp_reduce(val, op)) | ||
return val | ||
end |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,70 @@ | ||||||||||||||||||||||
@kernel cpu=false function groupreduce_1!(y, x, op, neutral) | ||||||||||||||||||||||
i = @index(Global) | ||||||||||||||||||||||
Comment on lines
+1
to
+2
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Comment on lines
+1
to
+2
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||
val = i > length(x) ? neutral : x[i] | ||||||||||||||||||||||
res = @groupreduce(op, val) | ||||||||||||||||||||||
i == 1 && (y[1] = res) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
@kernel cpu=false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize} | ||||||||||||||||||||||
i = @index(Global) | ||||||||||||||||||||||
val = i > length(x) ? neutral : x[i] | ||||||||||||||||||||||
res = @groupreduce(op, val, groupsize) | ||||||||||||||||||||||
i == 1 && (y[1] = res) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
@kernel cpu=false function warp_groupreduce_1!(y, x, op, neutral) | ||||||||||||||||||||||
i = @index(Global) | ||||||||||||||||||||||
val = i > length(x) ? neutral : x[i] | ||||||||||||||||||||||
res = @warp_groupreduce(op, val, neutral) | ||||||||||||||||||||||
i == 1 && (y[1] = res) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
@kernel cpu=false function warp_groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize} | ||||||||||||||||||||||
i = @index(Global) | ||||||||||||||||||||||
val = i > length(x) ? neutral : x[i] | ||||||||||||||||||||||
res = @warp_groupreduce(op, val, neutral, groupsize) | ||||||||||||||||||||||
i == 1 && (y[1] = res) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
function groupreduce_testsuite(backend, AT) | ||||||||||||||||||||||
# TODO should be a better way of querying max groupsize | ||||||||||||||||||||||
groupsizes = "$backend" == "oneAPIBackend" ? | ||||||||||||||||||||||
(256,) : | ||||||||||||||||||||||
(256, 512, 1024) | ||||||||||||||||||||||
|
||||||||||||||||||||||
@testset "@groupreduce" begin | ||||||||||||||||||||||
pxl-th marked this conversation as resolved.
Show resolved
Hide resolved
pxl-th marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
pxl-th marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||
@testset "T=$T, n=$n" for T in (Float16, Float32, Int16, Int32, Int64), n in groupsizes | ||||||||||||||||||||||
x = AT(ones(T, n)) | ||||||||||||||||||||||
y = AT(zeros(T, 1)) | ||||||||||||||||||||||
neutral = zero(T) | ||||||||||||||||||||||
op = + | ||||||||||||||||||||||
|
||||||||||||||||||||||
groupreduce_1!(backend(), n)(y, x, op, neutral; ndrange = n) | ||||||||||||||||||||||
@test Array(y)[1] == n | ||||||||||||||||||||||
|
||||||||||||||||||||||
for groupsize in (64, 128) | ||||||||||||||||||||||
groupreduce_2!(backend())(y, x, op, neutral, Val(groupsize); ndrange = n) | ||||||||||||||||||||||
@test Array(y)[1] == groupsize | ||||||||||||||||||||||
end | ||||||||||||||||||||||
end | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
if KernelAbstractions.supports_warp_reduction(backend()) | ||||||||||||||||||||||
@testset "@warp_groupreduce" begin | ||||||||||||||||||||||
@testset "T=$T, n=$n" for T in (Float16, Float32, Int16, Int32, Int64), n in groupsizes | ||||||||||||||||||||||
x = AT(ones(T, n)) | ||||||||||||||||||||||
y = AT(zeros(T, 1)) | ||||||||||||||||||||||
neutral = zero(T) | ||||||||||||||||||||||
op = + | ||||||||||||||||||||||
|
||||||||||||||||||||||
warp_groupreduce_1!(backend(), n)(y, x, op, neutral; ndrange = n) | ||||||||||||||||||||||
@test Array(y)[1] == n | ||||||||||||||||||||||
|
||||||||||||||||||||||
for groupsize in (64, 128) | ||||||||||||||||||||||
warp_groupreduce_2!(backend())(y, x, op, neutral, Val(groupsize); ndrange = n) | ||||||||||||||||||||||
@test Array(y)[1] == groupsize | ||||||||||||||||||||||
end | ||||||||||||||||||||||
end | ||||||||||||||||||||||
end | ||||||||||||||||||||||
end | ||||||||||||||||||||||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.