Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement groupreduce API #559

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
13 changes: 11 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,22 @@
@uniform
@groupsize
@ndrange
synchronize
allocate
```

### Reduction

```@docs
@groupreduce
@warp_groupreduce
KernelAbstractions.shfl_down
KernelAbstractions.supports_warp_reduction
```

## Host language

```@docs
synchronize
allocate
KernelAbstractions.zeros
```

Expand Down
2 changes: 2 additions & 0 deletions src/KernelAbstractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,8 @@ argconvert(k::Kernel{T}, arg) where {T} =
supports_enzyme(::Backend) = false
function __fake_compiler_job end

include("groupreduction.jl")

###
# Extras
# - LoopInfo
Expand Down
120 changes: 120 additions & 0 deletions src/groupreduction.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
export @groupreduce, @warp_groupreduce

"""
    @groupreduce op val [groupsize]

Perform group reduction of `val` using `op`.

# Arguments

- `op`: a binary reduction operator, applied pairwise to the values held by the
  threads of the workgroup.

- `groupsize` specifies size of the workgroup.
    If a kernel does not specify `groupsize` statically, then it is required to
    provide `groupsize`.
    Also can be used to perform reduction across first `groupsize` threads
    (if `groupsize < @groupsize()`).

# Returns

Result of the reduction.
"""
macro groupreduce(op, val)
    # Static-groupsize variant: the workgroup size is taken from the kernel
    # compiler context (`__ctx__`), so it must be known statically.
    :(__thread_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), Val(prod($groupsize($(esc(:__ctx__)))))))
end
macro groupreduce(op, val, groupsize)
    # Explicit-groupsize variant: caller supplies the group size, which may be
    # smaller than `@groupsize()` to reduce over only the first threads.
    :(__thread_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), Val($(esc(groupsize)))))
end

# Shared-memory tree reduction across the first `groupsize` threads of the
# workgroup. Returns the reduced value on the thread with local index 1;
# every other thread returns its own (unreduced) input `val`.
function __thread_groupreduce(__ctx__, op, val::T, ::Val{groupsize}) where {T, groupsize}
    storage = @localmem T groupsize

    local_idx = @index(Local)
    @inbounds local_idx ≤ groupsize && (storage[local_idx] = val)
    @synchronize()

    # Start the stride at half the smallest power of two ≥ `groupsize`:
    # with plain `groupsize ÷ 2` a non-power-of-two group size silently drops
    # elements (e.g. `groupsize = 6` never folds in `storage[3]`). The
    # `other_idx ≤ groupsize` guard below keeps out-of-range partners inert,
    # and for power-of-two sizes this is identical to `groupsize ÷ 2`.
    s::UInt64 = nextpow(2, groupsize) ÷ 0x02
    while s > 0x00
        if (local_idx - 0x01) < s
            other_idx = local_idx + s
            if other_idx ≤ groupsize
                @inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx])
            end
        end
        # Barrier each step: every thread reaches it (loop bounds are uniform),
        # so reads in the next iteration see all writes from this one.
        @synchronize()
        s >>= 0x01
    end

    # Only the first thread holds the final result.
    if local_idx == 0x01
        @inbounds val = storage[local_idx]
    end
    return val
end

# Warp groupreduce.

"""
    @warp_groupreduce op val neutral [groupsize]

Perform group reduction of `val` using `op`.
Each warp within a workgroup performs its own reduction using [`shfl_down`](@ref) intrinsic,
followed by final reduction over results of individual warp reductions.

# Arguments

- `neutral` should be a neutral w.r.t. `op`, such that `op(neutral, x) == x`.
  It pads the per-warp result bins for threads that have no bin to read.

- `groupsize` specifies size of the workgroup.
  If a kernel does not specify `groupsize` statically, then it is required to
  provide `groupsize`.

!!! note

    Use [`supports_warp_reduction`](@ref) to query if given backend supports warp reduction.
"""
macro warp_groupreduce(op, val, neutral)
    # Static-groupsize variant: workgroup size comes from the kernel context.
    :(__warp_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), $(esc(neutral)), Val(prod($groupsize($(esc(:__ctx__)))))))
end
macro warp_groupreduce(op, val, neutral, groupsize)
    # Explicit-groupsize variant: caller supplies the (compile-time) group size.
    :(__warp_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), $(esc(neutral)), Val($(esc(groupsize)))))
end

"""
    shfl_down(val::T, offset::Integer)::T where T

Read `val` from a lane with higher id given by `offset`.

No generic fallback is defined here; backends that support warp reduction
are expected to provide a method for this function.
"""
function shfl_down end

# NOTE(review): zero-argument method with no `Backend` — presumably a
# device-side default usable inside kernels; confirm it is intentional and
# not an accidental duplicate of the `::Backend` method defined below.
supports_warp_reduction() = false

"""
    supports_warp_reduction(::Backend)

Query if given backend supports [`shfl_down`](@ref) intrinsic and thus warp reduction.
"""
supports_warp_reduction(::Backend) = false  # conservative default; backends opt in

# Assume warp is 32 lanes.
# NOTE(review): some hardware uses 64-wide wavefronts; backends with a
# different warp width would need these constants generalized — confirm scope.
const __warpsize = UInt32(32)
# Maximum number of warps (for a groupsize = 1024).
const __warp_bins = UInt32(32)

# Reduce `val` across the lanes of one warp with a shuffle-down tree:
# each step folds in the value held `delta` lanes above, halving `delta`
# until the first lane has accumulated every lane's contribution.
@inline function __warp_reduce(val, op)
    delta = __warpsize >> 0x01
    while delta > 0x00
        val = op(val, shfl_down(val, delta))
        delta >>= 0x01
    end
    return val
end

# Two-stage group reduction: each warp reduces its lanes via `__warp_reduce`,
# writes its partial result into a shared-memory bin, and the first warp then
# reduces the bins. Returns the reduced value on the threads of the first
# warp (other threads return their bin value or `neutral`).
function __warp_groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}) where {T, groupsize}
    storage = @localmem T __warp_bins

    local_idx = @index(Local)
    lane = (local_idx - 0x01) % __warpsize + 0x01
    warp_id = (local_idx - 0x01) ÷ __warpsize + 0x01

    # Each warp performs a reduction and writes results into its own bin in `storage`.
    val = __warp_reduce(val, op)
    @inbounds lane == 0x01 && (storage[warp_id] = val)
    @synchronize()

    # Number of bins actually written: one per (possibly partial) warp.
    # `cld` (instead of `÷`) ensures the trailing partial warp's bin is not
    # dropped when `groupsize` is not a multiple of `__warpsize`, and that the
    # single bin is still read when `groupsize < __warpsize`. Identical to the
    # flooring division whenever `groupsize` is a multiple of the warp size.
    within_storage = (local_idx - 0x01) < cld(groupsize, __warpsize)
    @inbounds val = within_storage ? storage[lane] : neutral
    warp_id == 0x01 && (val = __warp_reduce(val, op))
    return val
end
70 changes: 70 additions & 0 deletions test/groupreduce.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
@kernel cpu=false function groupreduce_1!(y, x, op, neutral)
i = @index(Global)
Comment on lines +1 to +2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@kernel cpu=false function groupreduce_1!(y, x, op, neutral)
i = @index(Global)
@kernel cpu = false function groupreduce_1!(y, x, op, neutral)
@kernel cpu = false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}

Comment on lines +1 to +2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@kernel cpu=false function groupreduce_1!(y, x, op, neutral)
i = @index(Global)
@kernel cpu = false function groupreduce_1!(y, x, op, neutral)
@kernel cpu = false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}

Comment on lines +1 to +2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@kernel cpu=false function groupreduce_1!(y, x, op, neutral)
i = @index(Global)
@kernel cpu = false function groupreduce_1!(y, x, op, neutral)
@kernel cpu = false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}

val = i > length(x) ? neutral : x[i]
res = @groupreduce(op, val)
i == 1 && (y[1] = res)
end

# Reduce `x` with `op` via `@groupreduce`, passing an explicit compile-time
# `groupsize`; threads whose global index exceeds `length(x)` contribute
# `neutral`, and the thread with global index 1 stores the result in `y[1]`.
@kernel cpu=false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
    i = @index(Global)
    val = i > length(x) ? neutral : x[i]
    res = @groupreduce(op, val, groupsize)
    i == 1 && (y[1] = res)
end

# Reduce `x` with `op` via `@warp_groupreduce` using the kernel's static
# workgroup size; threads past `length(x)` contribute `neutral`, and the
# thread with global index 1 stores the result in `y[1]`.
@kernel cpu=false function warp_groupreduce_1!(y, x, op, neutral)
    i = @index(Global)
    val = i > length(x) ? neutral : x[i]
    res = @warp_groupreduce(op, val, neutral)
    i == 1 && (y[1] = res)
end

# Same as `warp_groupreduce_1!` but with an explicit compile-time `groupsize`
# forwarded to `@warp_groupreduce`; result is stored in `y[1]` by the thread
# with global index 1.
@kernel cpu=false function warp_groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
    i = @index(Global)
    val = i > length(x) ? neutral : x[i]
    res = @warp_groupreduce(op, val, neutral, groupsize)
    i == 1 && (y[1] = res)
end

function groupreduce_testsuite(backend, AT)
# TODO should be a better way of querying max groupsize
groupsizes = "$backend" == "oneAPIBackend" ?
(256,) :
(256, 512, 1024)

@testset "@groupreduce" begin
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@testset "@groupreduce" begin
return @testset "@groupreduce" begin

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@testset "@groupreduce" begin
return @testset "@groupreduce" begin

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@testset "@groupreduce" begin
return @testset "@groupreduce" begin

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@testset "@groupreduce" begin
return @testset "@groupreduce" begin

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@testset "@groupreduce" begin
return @testset "@groupreduce" begin

@testset "T=$T, n=$n" for T in (Float16, Float32, Int16, Int32, Int64), n in groupsizes
x = AT(ones(T, n))
y = AT(zeros(T, 1))
neutral = zero(T)
op = +

groupreduce_1!(backend(), n)(y, x, op, neutral; ndrange = n)
@test Array(y)[1] == n

for groupsize in (64, 128)
groupreduce_2!(backend())(y, x, op, neutral, Val(groupsize); ndrange = n)
@test Array(y)[1] == groupsize
end
end
end

if KernelAbstractions.supports_warp_reduction(backend())
@testset "@warp_groupreduce" begin
@testset "T=$T, n=$n" for T in (Float16, Float32, Int16, Int32, Int64), n in groupsizes
x = AT(ones(T, n))
y = AT(zeros(T, 1))
neutral = zero(T)
op = +

warp_groupreduce_1!(backend(), n)(y, x, op, neutral; ndrange = n)
@test Array(y)[1] == n

for groupsize in (64, 128)
warp_groupreduce_2!(backend())(y, x, op, neutral, Val(groupsize); ndrange = n)
@test Array(y)[1] == groupsize
end
end
end
end
end
8 changes: 8 additions & 0 deletions test/testsuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ include("reflection.jl")
include("examples.jl")
include("convert.jl")
include("specialfunctions.jl")
include("groupreduce.jl")

function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{String}())
@conditional_testset "Unittests" skip_tests begin
Expand Down Expand Up @@ -92,6 +93,13 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{
examples_testsuite(backend_str)
end

# TODO @index(Local) only works as a top-level expression on CPU.
if backend != CPU
@conditional_testset "@groupreduce" skip_tests begin
groupreduce_testsuite(backend, AT)
end
end

return
end

Expand Down
Loading