I'm trying to port the classic GPU tree reduction to KernelAbstractions.jl.
See this for the direct CUDA implementation I'm porting from.
This is what I have implemented currently:
```julia
const TBSize = 1024::Int
const DotBlocks = 256::Int

# `T` (element type of `a` and `b`) is assumed to be bound in the enclosing scope.
@kernel function dot(@Const(a), @Const(b), size, partial)
    local_i = @index(Local)
    group_i = @index(Group)
    tb_sum = @localmem T TBSize
    @inbounds tb_sum[local_i] = zero(T)

    # do dot first: each work item accumulates a grid-strided partial sum
    i = @index(Global)
    while i <= size
        @inbounds tb_sum[local_i] += a[i] * b[i]
        i += TBSize * DotBlocks
    end

    # then tree reduction: halve the active range each step
    offset = @private Int64 (1,)
    @inbounds begin
        offset[1] = @groupsize()[1] ÷ 2
        while offset[1] > 0
            @synchronize  # all partial sums of the previous step must be visible
            if (local_i - 1) < offset[1]
                tb_sum[local_i] += tb_sum[local_i+offset[1]]
            end
            offset[1] ÷= 2
        end
    end

    # the first work item writes the group's result
    if (local_i == 1)
        @inbounds partial[group_i] = tb_sum[local_i]
    end
end

# driver
wait(dot(backendDevice, TBSize)(a, b, size, partial_sum, ndrange = TBSize * DotBlocks))
```
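To illustrate why the `@synchronize` inside the while loop matters, here is a minimal plain-Julia sketch (no KernelAbstractions) that emulates the tree-reduction stage for a single work group. The function names are hypothetical, and the "without barrier" variant assumes the work items of a group are executed one after another, each running its whole loop before the next one starts, which is one plausible sequential CPU schedule rather than a description of the library's internals.

```julia
# Emulates one work group's tree reduction; `tb_sum` stands in for the
# @localmem buffer and its length is assumed to be a power of two.

# With a barrier per step: every work item finishes step k before any item
# starts step k+1. Within a step, items only write the lower half and only
# read the upper half, so running them in sequence is safe.
function tree_reduce_with_barrier!(tb_sum::Vector{T}) where {T}
    offset = length(tb_sum) ÷ 2
    while offset > 0
        for local_i in 1:length(tb_sum)   # one phase between two barriers
            if (local_i - 1) < offset
                tb_sum[local_i] += tb_sum[local_i + offset]
            end
        end
        offset ÷= 2
    end
    return tb_sum[1]
end

# Without barriers, run sequentially: item 1 completes all of its steps
# first, so it reads partner slots that have not been reduced yet.
function tree_reduce_without_barrier!(tb_sum::Vector{T}) where {T}
    n = length(tb_sum)
    for local_i in 1:n
        offset = n ÷ 2
        while offset > 0
            if (local_i - 1) < offset
                tb_sum[local_i] += tb_sum[local_i + offset]
            end
            offset ÷= 2
        end
    end
    return tb_sum[1]
end

partials = collect(1.0:8.0)  # stand-in for per-work-item dot partial sums
println(tree_reduce_with_barrier!(copy(partials)))     # 36.0
println(tree_reduce_without_barrier!(copy(partials)))  # 11.0
```

With the barrier, the sum is correct (1 + 2 + ... + 8 = 36); without it, item 1 only picks up `tb_sum[5]`, `tb_sum[3]` and `tb_sum[2]` before they are updated, which matches the observation that dropping `@synchronize` yields an incorrect answer.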
I was able to get correct results, and performance seems mostly on par with our [CUDA.jl](https://github.com/UoB-HPC/BabelStream/blob/7c1e04a42b9b03b0e5c5d0b07c0ef9f4bdd59353/JuliaStream.jl/src/CUDAStream.jl#L112) and [AMDGPU.jl](https://github.com/UoB-HPC/BabelStream/blob/7c1e04a42b9b03b0e5c5d0b07c0ef9f4bdd59353/JuliaStream.jl/src/AMDGPUStream.jl#L135) implementations.
On CPU, however, I got the following error:
Removing the `@synchronize` macro in the while loop makes the error go away, but the answer becomes incorrect. I've tried `@print eltype(offset[1])`; it prints the correct generic type (`Float32` in this case), so I'm not sure what `@synchronize` is doing here.
For reference, here is what `pkg status` says:
And the complete `Test.jl` reproducer: