
Implement groupreduce API #559

Draft · pxl-th wants to merge 14 commits into main
Conversation

pxl-th (Member) commented on Jan 30, 2025

Implements a group-level reduction API. Supports two types of algorithms:

  • thread: reduction performed by threads; uses shmem of length groupsize, with no bank conflicts and no divergence.
  • warp: reduction performed by shfl_down within warps; uses shmem of length 32. Each warp reduces its values and stores the result in shmem, followed by a final warp reduction over the values stored in shmem. Backends are only required to implement the shfl_down intrinsic, which AMDGPU/CUDA/Metal have (not sure about other backends).
  • A query function, KA.__supports_warp_reduction(), checks whether a backend supports warp reduction.

res = @groupreduce op val neutral

  • Optionally, limit the number of threads that participate in the reduction:

res = @groupreduce op val neutral 128 # the first 128 threads perform the reduction
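
For context, this is how the macro is exercised in this PR's tests (a minimal sketch mirroring the test kernel; GPU-only, since @groupreduce uses a non-top-level @synchronize):

    using KernelAbstractions

    # Each workgroup reduces its chunk of `x` with `op`; the first thread
    # writes the group's result into `y`.
    @kernel cpu = false function groupreduce_1!(y, x, op, neutral)
        i = @index(Global)
        val = i > length(x) ? neutral : x[i]
        res = @groupreduce(op, val, neutral)
        i == 1 && (y[1] = res)
    end

It is launched with the usual KernelAbstractions pattern, e.g. groupreduce_1!(backend, 256)(y, x, +, 0f0; ndrange = 256).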

github-actions bot (Contributor) commented on Jan 30, 2025

Benchmark Results

Benchmark | main | 7c96e5a... | Ratio (main / 7c96e5a4ab5554...)
saxpy/default/Float16/1024 0.0586 ± 0.026 ms 0.742 ± 0.0059 μs 79
saxpy/default/Float16/1048576 0.89 ± 0.025 ms 0.173 ± 0.0045 ms 5.15
saxpy/default/Float16/16384 0.063 ± 0.027 ms 3.33 ± 0.023 μs 18.9
saxpy/default/Float16/2048 0.0591 ± 0.023 ms 0.92 ± 0.012 μs 64.2
saxpy/default/Float16/256 0.0623 ± 0.026 ms 0.597 ± 0.0047 μs 104
saxpy/default/Float16/262144 0.269 ± 0.026 ms 0.0439 ± 0.00033 ms 6.14
saxpy/default/Float16/32768 0.0749 ± 0.027 ms 6.01 ± 0.047 μs 12.5
saxpy/default/Float16/4096 0.0626 ± 0.025 ms 1.31 ± 0.023 μs 47.9
saxpy/default/Float16/512 0.0605 ± 0.026 ms 0.657 ± 0.0054 μs 92.2
saxpy/default/Float16/64 0.0633 ± 0.026 ms 0.566 ± 0.0048 μs 112
saxpy/default/Float16/65536 0.103 ± 0.027 ms 11.6 ± 0.13 μs 8.88
saxpy/default/Float32/1024 0.0616 ± 0.026 ms 0.65 ± 0.011 μs 94.9
saxpy/default/Float32/1048576 0.476 ± 0.024 ms 0.234 ± 0.015 ms 2.04
saxpy/default/Float32/16384 0.055 ± 0.026 ms 2.77 ± 0.12 μs 19.9
saxpy/default/Float32/2048 0.0531 ± 0.023 ms 0.757 ± 0.043 μs 70.1
saxpy/default/Float32/256 0.0604 ± 0.026 ms 0.576 ± 0.0062 μs 105
saxpy/default/Float32/262144 0.16 ± 0.034 ms 0.0577 ± 0.0033 ms 2.78
saxpy/default/Float32/32768 0.0601 ± 0.026 ms 5.31 ± 0.29 μs 11.3
saxpy/default/Float32/4096 0.061 ± 0.024 ms 1.13 ± 0.086 μs 53.8
saxpy/default/Float32/512 0.0607 ± 0.026 ms 0.612 ± 0.01 μs 99.1
saxpy/default/Float32/64 0.0623 ± 0.026 ms 0.563 ± 0.006 μs 111
saxpy/default/Float32/65536 0.0745 ± 0.029 ms 12.5 ± 0.66 μs 5.98
saxpy/default/Float64/1024 0.057 ± 0.026 ms 0.751 ± 0.058 μs 75.9
saxpy/default/Float64/1048576 0.504 ± 0.049 ms 0.502 ± 0.022 ms 1
saxpy/default/Float64/16384 0.0548 ± 0.025 ms 5.24 ± 0.19 μs 10.4
saxpy/default/Float64/2048 0.0519 ± 0.023 ms 1.13 ± 0.078 μs 45.9
saxpy/default/Float64/256 0.0619 ± 0.026 ms 0.582 ± 0.0068 μs 106
saxpy/default/Float64/262144 0.17 ± 0.029 ms 0.0886 ± 0.0076 ms 1.92
saxpy/default/Float64/32768 0.0626 ± 0.025 ms 11.9 ± 0.92 μs 5.25
saxpy/default/Float64/4096 0.0605 ± 0.024 ms 1.68 ± 0.11 μs 36
saxpy/default/Float64/512 0.0612 ± 0.026 ms 0.635 ± 0.01 μs 96.5
saxpy/default/Float64/64 0.0621 ± 0.025 ms 0.555 ± 0.0061 μs 112
saxpy/default/Float64/65536 0.0847 ± 0.026 ms 23.1 ± 1.8 μs 3.67
saxpy/static workgroup=(1024,)/Float16/1024 0.056 ± 0.026 ms 2.16 ± 0.027 μs 25.9
saxpy/static workgroup=(1024,)/Float16/1048576 0.899 ± 0.028 ms 0.158 ± 0.0069 ms 5.7
saxpy/static workgroup=(1024,)/Float16/16384 0.0594 ± 0.025 ms 4.41 ± 0.079 μs 13.5
saxpy/static workgroup=(1024,)/Float16/2048 0.0571 ± 0.023 ms 2.33 ± 0.027 μs 24.5
saxpy/static workgroup=(1024,)/Float16/256 0.0603 ± 0.025 ms 2.81 ± 0.033 μs 21.5
saxpy/static workgroup=(1024,)/Float16/262144 0.268 ± 0.027 ms 0.0428 ± 0.0018 ms 6.26
saxpy/static workgroup=(1024,)/Float16/32768 0.0724 ± 0.025 ms 6.81 ± 0.15 μs 10.6
saxpy/static workgroup=(1024,)/Float16/4096 0.0619 ± 0.026 ms 2.67 ± 0.035 μs 23.2
saxpy/static workgroup=(1024,)/Float16/512 0.0585 ± 0.026 ms 3.25 ± 0.035 μs 18
saxpy/static workgroup=(1024,)/Float16/64 0.0598 ± 0.025 ms 2.51 ± 0.22 μs 23.9
saxpy/static workgroup=(1024,)/Float16/65536 0.101 ± 0.025 ms 12.6 ± 0.38 μs 8.06
saxpy/static workgroup=(1024,)/Float32/1024 0.0588 ± 0.026 ms 2.23 ± 0.03 μs 26.4
saxpy/static workgroup=(1024,)/Float32/1048576 0.46 ± 0.025 ms 0.201 ± 0.024 ms 2.29
saxpy/static workgroup=(1024,)/Float32/16384 0.0518 ± 0.024 ms 4.4 ± 0.25 μs 11.8
saxpy/static workgroup=(1024,)/Float32/2048 0.0519 ± 0.022 ms 2.4 ± 0.04 μs 21.7
saxpy/static workgroup=(1024,)/Float32/256 0.0605 ± 0.025 ms 2.68 ± 0.043 μs 22.6
saxpy/static workgroup=(1024,)/Float32/262144 0.159 ± 0.035 ms 0.0485 ± 0.0037 ms 3.28
saxpy/static workgroup=(1024,)/Float32/32768 0.0573 ± 0.025 ms 7.49 ± 0.42 μs 7.66
saxpy/static workgroup=(1024,)/Float32/4096 0.0556 ± 0.025 ms 2.66 ± 0.065 μs 20.9
saxpy/static workgroup=(1024,)/Float32/512 0.0581 ± 0.026 ms 2.69 ± 0.031 μs 21.6
saxpy/static workgroup=(1024,)/Float32/64 0.0604 ± 0.025 ms 2.7 ± 5.6 μs 22.4
saxpy/static workgroup=(1024,)/Float32/65536 0.0714 ± 0.028 ms 14.6 ± 1.3 μs 4.89
saxpy/static workgroup=(1024,)/Float64/1024 0.056 ± 0.025 ms 2.32 ± 0.048 μs 24.1
saxpy/static workgroup=(1024,)/Float64/1048576 0.499 ± 0.044 ms 0.499 ± 0.051 ms 0.999
saxpy/static workgroup=(1024,)/Float64/16384 0.0541 ± 0.025 ms 7.41 ± 0.49 μs 7.3
saxpy/static workgroup=(1024,)/Float64/2048 0.0507 ± 0.023 ms 2.61 ± 0.067 μs 19.4
saxpy/static workgroup=(1024,)/Float64/256 0.0605 ± 0.025 ms 2.66 ± 0.061 μs 22.7
saxpy/static workgroup=(1024,)/Float64/262144 0.168 ± 0.029 ms 0.0992 ± 0.0086 ms 1.7
saxpy/static workgroup=(1024,)/Float64/32768 0.0613 ± 0.025 ms 14.5 ± 1.4 μs 4.23
saxpy/static workgroup=(1024,)/Float64/4096 0.0538 ± 0.025 ms 3.15 ± 0.14 μs 17.1
saxpy/static workgroup=(1024,)/Float64/512 0.059 ± 0.025 ms 2.66 ± 0.062 μs 22.2
saxpy/static workgroup=(1024,)/Float64/64 0.0618 ± 0.025 ms 2.62 ± 0.065 μs 23.6
saxpy/static workgroup=(1024,)/Float64/65536 0.0834 ± 0.027 ms 26.5 ± 2.2 μs 3.15
time_to_load 1.09 ± 0.0082 s 0.304 ± 0.0042 s 3.6

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

pxl-th (Member, Author) commented on Jan 30, 2025

@vchuravy not sure about CPU errors (regarding @index(Local)). Any idea?

UPD: #218 (comment)

end

function groupreduce_testsuite(backend, AT)
    @testset "@groupreduce" begin
Contributor:

Suggested change
-@testset "@groupreduce" begin
+return @testset "@groupreduce" begin

groupsizes = "$backend" == "oneAPIBackend" ?
    (256,) :
    (256, 512, 1024)
@testset "@groupreduce" begin
Contributor:

Suggested change
-@testset "@groupreduce" begin
+return @testset "@groupreduce" begin

src/reduce.jl Outdated
Comment on lines 15 to 22
- `algo` specifies which reduction algorithm to use:
    - `Reduction.thread`:
      Perform thread group reduction (requires `groupsize * sizeof(T)` bytes of shared memory).
      Available across all backends.
    - `Reduction.warp`:
      Perform warp group reduction (requires `32 * sizeof(T)` bytes of shared memory).
      Potentially faster, since it requires fewer writes to shared memory.
      To query whether a backend supports warp reduction, use `supports_warp_reduction(backend)`.
Member:

Why is that needed? Shouldn't the backend go and use warp reductions if it can?

Member Author:

I'm now doing auto-selection of the algorithm based on the device function __supports_warp_reduction().
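
A hedged sketch of what that auto-selection might look like (only __supports_warp_reduction() is named in this PR; __groupreduce, __warp_groupreduce, and __thread_groupreduce are hypothetical names here):

    # Hypothetical dispatch: use the warp algorithm when the device supports
    # shfl_down-based reduction, otherwise the shared-memory thread algorithm.
    @inline function __groupreduce(op, val, neutral, groupsize)
        if __supports_warp_reduction()
            __warp_groupreduce(op, val, neutral, groupsize)   # 32 * sizeof(T) bytes of shmem
        else
            __thread_groupreduce(op, val, neutral, groupsize) # groupsize * sizeof(T) bytes of shmem
        end
    end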

src/reduce.jl Outdated
Comment on lines 70 to 77
while s > 0x00
    if (local_idx - 0x01) < s
        other_idx = local_idx + s
        if other_idx ≤ groupsize
            @inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx])
        end
    end
    @synchronize()
Member:

Currently this is not legal: the @synchronize here sits inside the while loop, i.e. it is non-top-level.

#262 might need to wait until #556

Member:

(I assume this code is GPU only anyways)

src/reduce.jl Outdated
Comment on lines 89 to 93
macro shfl_down(val, offset)
    return quote
        $__shfl_down($(esc(val)), $(esc(offset)))
    end
end
Member:

If it isn't user-facing and doesn't need special CPU handling, you don't need to introduce a new macro.

Member Author:

Yeah, removed macro.
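
For reference, the per-warp step described in the PR body typically reduces via shuffle-down like this (a sketch assuming a warp size of 32; warp_reduce is a hypothetical name, while __shfl_down is the intrinsic backends must provide):

    # Tree reduction within one warp: after 5 halving steps (16, 8, 4, 2, 1),
    # the first lane holds `op` applied over all 32 lanes.
    @inline function warp_reduce(op, val)
        offset = 0x10  # warpsize ÷ 2
        while offset > 0x00
            val = op(val, __shfl_down(val, offset))
            offset >>= 0x01
        end
        return val
    end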

Comment on lines 1 to 13
@kernel function groupreduce_1!(y, x, op, neutral, algo)
    i = @index(Global)
    val = i > length(x) ? neutral : x[i]
    res = @groupreduce(op, val, neutral, algo)
    i == 1 && (y[1] = res)
end

@kernel function groupreduce_2!(y, x, op, neutral, algo, ::Val{groupsize}) where {groupsize}
    i = @index(Global)
    val = i > length(x) ? neutral : x[i]
    res = @groupreduce(op, val, neutral, algo, groupsize)
    i == 1 && (y[1] = res)
end
Member:

These need to be cpu=false since you are using non-top-level @synchronize

Member Author:

Done.
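
For completeness, a sketch of how such a kernel is driven from the host with the standard KernelAbstractions launch API (run_groupreduce_1 is a hypothetical helper; the workgroup size 256 and Float32 element type are illustrative):

    # Reduce one workgroup's worth of data and read the result back.
    function run_groupreduce_1(backend, AT; groupsize = 256)
        x = AT(ones(Float32, groupsize))
        y = AT(zeros(Float32, 1))
        groupreduce_1!(backend, groupsize)(y, x, +, 0.0f0; ndrange = groupsize)
        KernelAbstractions.synchronize(backend)
        return Array(y)[1]  # expected: Float32(groupsize)
    end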

vchuravy (Member) left a comment

Thanks! This currently doesn't fully work since calling a function with a __ctx__ argument is GPU only.

#558 is also rearing its ugly head. I suspect we will need a macro-free kernel language in KA to write this correctly.

Comment on lines +1 to +2
@kernel cpu=false function groupreduce_1!(y, x, op, neutral)
    i = @index(Global)

Contributor:

Suggested change
-@kernel cpu=false function groupreduce_1!(y, x, op, neutral)
-i = @index(Global)
+@kernel cpu = false function groupreduce_1!(y, x, op, neutral)
+@kernel cpu = false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}

pxl-th requested a review from vchuravy on February 3, 2025 22:56
pxl-th marked this pull request as draft on February 5, 2025 22:48
vchuravy mentioned this pull request on Feb 6, 2025